change transformer behavior to reference

This commit is contained in:
Arkadiusz Kondas 2016-06-16 10:01:40 +02:00
parent 15519ba122
commit 7c5e79d2c6
7 changed files with 14 additions and 22 deletions

View File

@ -15,10 +15,8 @@ class TfIdfTransformer implements Transformer
/**
* @param array $samples
*
* @return array
*/
public function transform(array $samples): array
public function transform(array &$samples)
{
$this->countTokensFrequency($samples);
@ -32,8 +30,6 @@ class TfIdfTransformer implements Transformer
$feature = $feature * $this->idf[$index];
}
}
return $samples;
}
/**

View File

@ -48,10 +48,8 @@ class TokenCountVectorizer implements Transformer
/**
* @param array $samples
*
* @return array
*/
public function transform(array $samples): array
public function transform(array &$samples)
{
$this->buildVocabulary($samples);
@ -60,8 +58,6 @@ class TokenCountVectorizer implements Transformer
}
$samples = $this->checkDocumentFrequency($samples);
return $samples;
}
/**

View File

@ -18,7 +18,7 @@ class Pipeline implements Estimator
/**
* @param array|Transformer[] $transformers
* @param Estimator $estimator
* @param Estimator $estimator
*/
public function __construct(array $transformers = [], Estimator $estimator)
{
@ -76,11 +76,11 @@ class Pipeline implements Estimator
/**
* @param array $samples
*
* @return mixed
*/
public function predict(array $samples)
{
return $this->estimator->predict($samples);
}
}

View File

@ -8,8 +8,6 @@ interface Transformer
{
/**
* @param array $samples
*
* @return array
*/
public function transform(array $samples): array;
public function transform(array &$samples);
}

View File

@ -23,7 +23,8 @@ class TfIdfTransformerTest extends \PHPUnit_Framework_TestCase
];
$transformer = new TfIdfTransformer();
$transformer->transform($samples);
$this->assertEquals($tfIdfSamples, $transformer->transform($samples), '', 0.001);
$this->assertEquals($tfIdfSamples, $samples, '', 0.001);
}
}

View File

@ -37,8 +37,9 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer());
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $vectorizer->transform($samples));
$this->assertEquals($tokensCounts, $samples);
$this->assertEquals($vocabulary, $vectorizer->getVocabulary());
}
@ -68,8 +69,9 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5);
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $vectorizer->transform($samples));
$this->assertEquals($tokensCounts, $samples);
$this->assertEquals($vocabulary, $vectorizer->getVocabulary());
// word at least once in all samples
@ -86,7 +88,8 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1);
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $vectorizer->transform($samples));
$this->assertEquals($tokensCounts, $samples);
}
}

View File

@ -10,11 +10,10 @@ use Phpml\Pipeline;
class PipelineTest extends \PHPUnit_Framework_TestCase
{
public function testPipelineConstruction()
{
$transformers = [
new TfIdfTransformer()
new TfIdfTransformer(),
];
$estimator = new SVC();
@ -23,5 +22,4 @@ class PipelineTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($transformers, $pipeline->getTransformers());
$this->assertEquals($estimator, $pipeline->getEstimator());
}
}