diff --git a/src/Phpml/Pipeline.php b/src/Phpml/Pipeline.php index 2c2d984..f230db3 100644 --- a/src/Phpml/Pipeline.php +++ b/src/Phpml/Pipeline.php @@ -67,10 +67,7 @@ class Pipeline implements Estimator */ public function train(array $samples, array $targets) { - foreach ($this->transformers as $transformer) { - $transformer->transform($samples); - } - + $this->transformSamples($samples); $this->estimator->train($samples, $targets); } @@ -81,6 +78,18 @@ class Pipeline implements Estimator */ public function predict(array $samples) { + $this->transformSamples($samples); + return $this->estimator->predict($samples); } + + /** + * @param array $samples + */ + private function transformSamples(array &$samples) + { + foreach ($this->transformers as $transformer) { + $transformer->transform($samples); + } + } } diff --git a/src/Phpml/Preprocessing/Preprocessor.php b/src/Phpml/Preprocessing/Preprocessor.php index ff5530e..ae70941 100644 --- a/src/Phpml/Preprocessing/Preprocessor.php +++ b/src/Phpml/Preprocessing/Preprocessor.php @@ -8,5 +8,4 @@ use Phpml\Transformer; interface Preprocessor extends Transformer { - } diff --git a/tests/Phpml/PipelineTest.php b/tests/Phpml/PipelineTest.php index b1d5bf7..4e5815b 100644 --- a/tests/Phpml/PipelineTest.php +++ b/tests/Phpml/PipelineTest.php @@ -43,7 +43,7 @@ class PipelineTest extends \PHPUnit_Framework_TestCase $targets = [ 4, 1, - 4 + 4, ]; $pipeline = new Pipeline($transformers, $estimator); @@ -53,5 +53,4 @@ class PipelineTest extends \PHPUnit_Framework_TestCase $this->assertEquals(4, $predicted[0]); } - }