From b4b190de7fd624892bade2bbd9c9e39f57da4fad Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Wed, 14 Feb 2018 19:05:48 +0100 Subject: [PATCH] Fix pipeline transformers --- src/Pipeline.php | 2 +- tests/PipelineTest.php | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/Pipeline.php b/src/Pipeline.php index 480a980..d57da87 100644 --- a/src/Pipeline.php +++ b/src/Pipeline.php @@ -54,7 +54,7 @@ class Pipeline implements Estimator public function train(array $samples, array $targets): void { foreach ($this->transformers as $transformer) { - $transformer->fit($samples); + $transformer->fit($samples, $targets); $transformer->transform($samples); } diff --git a/tests/PipelineTest.php b/tests/PipelineTest.php index 86ff2a9..6f1562c 100644 --- a/tests/PipelineTest.php +++ b/tests/PipelineTest.php @@ -7,6 +7,8 @@ namespace Phpml\Tests; use Phpml\Classification\SVC; use Phpml\FeatureExtraction\TfIdfTransformer; use Phpml\FeatureExtraction\TokenCountVectorizer; +use Phpml\FeatureSelection\ScoringFunction\ANOVAFValue; +use Phpml\FeatureSelection\SelectKBest; use Phpml\ModelManager; use Phpml\Pipeline; use Phpml\Preprocessing\Imputer; @@ -106,6 +108,18 @@ class PipelineTest extends TestCase $this->assertEquals($expected, $predicted); } + public function testPipelineTransformersWithTargets() : void + { + $samples = [[1, 2, 1], [1, 3, 4], [5, 2, 1], [1, 3, 3], [1, 3, 4], [0, 3, 5]]; + $targets = ['a', 'a', 'a', 'b', 'b', 'b']; + + $pipeline = new Pipeline([$selector = new SelectKBest(2)], new SVC()); + $pipeline->train($samples, $targets); + + self::assertEquals([1.47058823, 4.0, 3.0], $selector->scores(), '', 0.00000001); + self::assertEquals(['b'], $pipeline->predict([[1, 3, 5]])); + } + public function testSaveAndRestore(): void { $pipeline = new Pipeline([