From 26f2cbabc4c40b950cc17989ecaa2fbe097eed2e Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Thu, 16 Jun 2016 10:26:29 +0200 Subject: [PATCH] fix Pipeline transformation --- src/Phpml/Pipeline.php | 2 +- tests/Phpml/PipelineTest.php | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/Phpml/Pipeline.php b/src/Phpml/Pipeline.php index 3f98f29..2c2d984 100644 --- a/src/Phpml/Pipeline.php +++ b/src/Phpml/Pipeline.php @@ -68,7 +68,7 @@ class Pipeline implements Estimator public function train(array $samples, array $targets) { foreach ($this->transformers as $transformer) { - $samples = $transformer->transform($samples); + $transformer->transform($samples); } $this->estimator->train($samples, $targets); diff --git a/tests/Phpml/PipelineTest.php b/tests/Phpml/PipelineTest.php index b00a60a..b1d5bf7 100644 --- a/tests/Phpml/PipelineTest.php +++ b/tests/Phpml/PipelineTest.php @@ -7,6 +7,9 @@ namespace tests; use Phpml\Classification\SVC; use Phpml\FeatureExtraction\TfIdfTransformer; use Phpml\Pipeline; +use Phpml\Preprocessing\Imputer; +use Phpml\Preprocessing\Normalizer; +use Phpml\Preprocessing\Imputer\Strategy\MostFrequentStrategy; class PipelineTest extends \PHPUnit_Framework_TestCase { @@ -22,4 +25,33 @@ class PipelineTest extends \PHPUnit_Framework_TestCase $this->assertEquals($transformers, $pipeline->getTransformers()); $this->assertEquals($estimator, $pipeline->getEstimator()); } + + public function testPipelineWorkflow() + { + $transformers = [ + new Imputer(null, new MostFrequentStrategy()), + new Normalizer(), + ]; + $estimator = new SVC(); + + $samples = [ + [1, -1, 2], + [2, 0, null], + [null, 1, -1], + ]; + + $targets = [ + 4, + 1, + 4 + ]; + + $pipeline = new Pipeline($transformers, $estimator); + $pipeline->train($samples, $targets); + + $predicted = $pipeline->predict([[0, 0, 0]]); + + $this->assertEquals(4, $predicted[0]); + } + }