Fix pipeline transformers

This commit is contained in:
Arkadiusz Kondas 2018-02-14 19:05:48 +01:00
parent 998879b6fc
commit b4b190de7f
2 changed files with 15 additions and 1 deletions

View File

@ -54,7 +54,7 @@ class Pipeline implements Estimator
public function train(array $samples, array $targets): void
{
foreach ($this->transformers as $transformer) {
$transformer->fit($samples);
$transformer->fit($samples, $targets);
$transformer->transform($samples);
}

View File

@ -7,6 +7,8 @@ namespace Phpml\Tests;
use Phpml\Classification\SVC;
use Phpml\FeatureExtraction\TfIdfTransformer;
use Phpml\FeatureExtraction\TokenCountVectorizer;
use Phpml\FeatureSelection\ScoringFunction\ANOVAFValue;
use Phpml\FeatureSelection\SelectKBest;
use Phpml\ModelManager;
use Phpml\Pipeline;
use Phpml\Preprocessing\Imputer;
@ -106,6 +108,18 @@ class PipelineTest extends TestCase
$this->assertEquals($expected, $predicted);
}
public function testPipelineTransformersWithTargets() : void
{
$samples = [[1, 2, 1], [1, 3, 4], [5, 2, 1], [1, 3, 3], [1, 3, 4], [0, 3, 5]];
$targets = ['a', 'a', 'a', 'b', 'b', 'b'];
$pipeline = new Pipeline([$selector = new SelectKBest(2)], new SVC());
$pipeline->train($samples, $targets);
self::assertEquals([1.47058823, 4.0, 3.0], $selector->scores(), '', 0.00000001);
self::assertEquals(['b'], $pipeline->predict([[1, 3, 5]]));
}
public function testSaveAndRestore(): void
{
$pipeline = new Pipeline([