From 2d3b44f1a048f8625c2ea6896105efe906539ee5 Mon Sep 17 00:00:00 2001 From: Maxime COLIN Date: Wed, 24 May 2017 09:06:54 +0200 Subject: [PATCH] Fix samples transformation in Pipeline training (#94) --- src/Phpml/Pipeline.php | 17 +++++----------- tests/Phpml/PipelineTest.php | 39 ++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/Phpml/Pipeline.php b/src/Phpml/Pipeline.php index a6b3d56..ca2914a 100644 --- a/src/Phpml/Pipeline.php +++ b/src/Phpml/Pipeline.php @@ -67,8 +67,11 @@ class Pipeline implements Estimator */ public function train(array $samples, array $targets) { - $this->fitTransformers($samples); - $this->transformSamples($samples); + foreach ($this->transformers as $transformer) { + $transformer->fit($samples); + $transformer->transform($samples); + } + $this->estimator->train($samples, $targets); } @@ -84,16 +87,6 @@ class Pipeline implements Estimator return $this->estimator->predict($samples); } - /** - * @param array $samples - */ - private function fitTransformers(array &$samples) - { - foreach ($this->transformers as $transformer) { - $transformer->fit($samples); - } - } - /** * @param array $samples */ diff --git a/tests/Phpml/PipelineTest.php b/tests/Phpml/PipelineTest.php index 4454dee..92a6223 100644 --- a/tests/Phpml/PipelineTest.php +++ b/tests/Phpml/PipelineTest.php @@ -6,11 +6,13 @@ namespace tests; use Phpml\Classification\SVC; use Phpml\FeatureExtraction\TfIdfTransformer; +use Phpml\FeatureExtraction\TokenCountVectorizer; use Phpml\Pipeline; use Phpml\Preprocessing\Imputer; use Phpml\Preprocessing\Normalizer; use Phpml\Preprocessing\Imputer\Strategy\MostFrequentStrategy; use Phpml\Regression\SVR; +use Phpml\Tokenization\WordTokenizer; use PHPUnit\Framework\TestCase; class PipelineTest extends TestCase @@ -65,4 +67,41 @@ class PipelineTest extends TestCase $this->assertEquals(4, $predicted[0]); } + + public function testPipelineTransformers() + { + $transformers = [ + new TokenCountVectorizer(new WordTokenizer()), + new TfIdfTransformer() + ]; + + $estimator = new SVC(); + + $samples = [ + 'Hello Paul', + 'Hello Martin', + 'Goodbye Tom', + 'Hello John', + 'Goodbye Alex', + 'Bye Tony', + ]; + + $targets = [ + 'greetings', + 'greetings', + 'farewell', + 'greetings', + 'farewell', + 'farewell', + ]; + + $pipeline = new Pipeline($transformers, $estimator); + $pipeline->train($samples, $targets); + + $expected = ['greetings', 'farewell']; + + $predicted = $pipeline->predict(['Hello Max', 'Goodbye Mark']); + + $this->assertEquals($expected, $predicted); + } }