change transformer behavior to reference

This commit is contained in:
Arkadiusz Kondas 2016-06-16 10:01:40 +02:00
parent 15519ba122
commit 7c5e79d2c6
7 changed files with 14 additions and 22 deletions

View File

@ -15,10 +15,8 @@ class TfIdfTransformer implements Transformer
/** /**
* @param array $samples * @param array $samples
*
* @return array
*/ */
public function transform(array $samples): array public function transform(array &$samples)
{ {
$this->countTokensFrequency($samples); $this->countTokensFrequency($samples);
@ -32,8 +30,6 @@ class TfIdfTransformer implements Transformer
$feature = $feature * $this->idf[$index]; $feature = $feature * $this->idf[$index];
} }
} }
return $samples;
} }
/** /**

View File

@ -48,10 +48,8 @@ class TokenCountVectorizer implements Transformer
/** /**
* @param array $samples * @param array $samples
*
* @return array
*/ */
public function transform(array $samples): array public function transform(array &$samples)
{ {
$this->buildVocabulary($samples); $this->buildVocabulary($samples);
@ -60,8 +58,6 @@ class TokenCountVectorizer implements Transformer
} }
$samples = $this->checkDocumentFrequency($samples); $samples = $this->checkDocumentFrequency($samples);
return $samples;
} }
/** /**

View File

@ -18,7 +18,7 @@ class Pipeline implements Estimator
/** /**
* @param array|Transformer[] $transformers * @param array|Transformer[] $transformers
* @param Estimator $estimator * @param Estimator $estimator
*/ */
public function __construct(array $transformers = [], Estimator $estimator) public function __construct(array $transformers = [], Estimator $estimator)
{ {
@ -76,11 +76,11 @@ class Pipeline implements Estimator
/** /**
* @param array $samples * @param array $samples
*
* @return mixed * @return mixed
*/ */
public function predict(array $samples) public function predict(array $samples)
{ {
return $this->estimator->predict($samples); return $this->estimator->predict($samples);
} }
} }

View File

@ -8,8 +8,6 @@ interface Transformer
{ {
/** /**
* @param array $samples * @param array $samples
*
* @return array
*/ */
public function transform(array $samples): array; public function transform(array &$samples);
} }

View File

@ -23,7 +23,8 @@ class TfIdfTransformerTest extends \PHPUnit_Framework_TestCase
]; ];
$transformer = new TfIdfTransformer(); $transformer = new TfIdfTransformer();
$transformer->transform($samples);
$this->assertEquals($tfIdfSamples, $transformer->transform($samples), '', 0.001); $this->assertEquals($tfIdfSamples, $samples, '', 0.001);
} }
} }

View File

@ -37,8 +37,9 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
]; ];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer()); $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer());
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $vectorizer->transform($samples)); $this->assertEquals($tokensCounts, $samples);
$this->assertEquals($vocabulary, $vectorizer->getVocabulary()); $this->assertEquals($vocabulary, $vectorizer->getVocabulary());
} }
@ -68,8 +69,9 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
]; ];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5); $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5);
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $vectorizer->transform($samples)); $this->assertEquals($tokensCounts, $samples);
$this->assertEquals($vocabulary, $vectorizer->getVocabulary()); $this->assertEquals($vocabulary, $vectorizer->getVocabulary());
// word at least once in all samples // word at least once in all samples
@ -86,7 +88,8 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
]; ];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1); $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1);
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $vectorizer->transform($samples)); $this->assertEquals($tokensCounts, $samples);
} }
} }

View File

@ -10,11 +10,10 @@ use Phpml\Pipeline;
class PipelineTest extends \PHPUnit_Framework_TestCase class PipelineTest extends \PHPUnit_Framework_TestCase
{ {
public function testPipelineConstruction() public function testPipelineConstruction()
{ {
$transformers = [ $transformers = [
new TfIdfTransformer() new TfIdfTransformer(),
]; ];
$estimator = new SVC(); $estimator = new SVC();
@ -23,5 +22,4 @@ class PipelineTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($transformers, $pipeline->getTransformers()); $this->assertEquals($transformers, $pipeline->getTransformers());
$this->assertEquals($estimator, $pipeline->getEstimator()); $this->assertEquals($estimator, $pipeline->getEstimator());
} }
} }