diff --git a/src/Phpml/FeatureExtraction/TfIdfTransformer.php b/src/Phpml/FeatureExtraction/TfIdfTransformer.php index fade5b4..783bd46 100644 --- a/src/Phpml/FeatureExtraction/TfIdfTransformer.php +++ b/src/Phpml/FeatureExtraction/TfIdfTransformer.php @@ -15,10 +15,8 @@ class TfIdfTransformer implements Transformer /** * @param array $samples - * - * @return array */ - public function transform(array $samples): array + public function transform(array &$samples) { $this->countTokensFrequency($samples); @@ -32,8 +30,6 @@ class TfIdfTransformer implements Transformer $feature = $feature * $this->idf[$index]; } } - - return $samples; } /** diff --git a/src/Phpml/FeatureExtraction/TokenCountVectorizer.php b/src/Phpml/FeatureExtraction/TokenCountVectorizer.php index 5273f6b..f823778 100644 --- a/src/Phpml/FeatureExtraction/TokenCountVectorizer.php +++ b/src/Phpml/FeatureExtraction/TokenCountVectorizer.php @@ -48,10 +48,8 @@ class TokenCountVectorizer implements Transformer /** * @param array $samples - * - * @return array */ - public function transform(array $samples): array + public function transform(array &$samples) { $this->buildVocabulary($samples); @@ -60,8 +58,6 @@ class TokenCountVectorizer implements Transformer } $samples = $this->checkDocumentFrequency($samples); - - return $samples; } /** diff --git a/src/Phpml/Pipeline.php b/src/Phpml/Pipeline.php index f0017fe..3f98f29 100644 --- a/src/Phpml/Pipeline.php +++ b/src/Phpml/Pipeline.php @@ -18,7 +18,7 @@ class Pipeline implements Estimator /** * @param array|Transformer[] $transformers - * @param Estimator $estimator + * @param Estimator $estimator */ public function __construct(array $transformers = [], Estimator $estimator) { @@ -76,11 +76,11 @@ class Pipeline implements Estimator /** * @param array $samples + * * @return mixed */ public function predict(array $samples) { return $this->estimator->predict($samples); } - } diff --git a/src/Phpml/Transformer.php b/src/Phpml/Transformer.php index bdc809b..47c2ce3 100644 --- a/src/Phpml/Transformer.php +++ b/src/Phpml/Transformer.php @@ -8,8 +8,6 @@ interface Transformer { /** * @param array $samples - * - * @return array */ - public function transform(array $samples): array; + public function transform(array &$samples); } diff --git a/tests/Phpml/FeatureExtraction/TfIdfTransformerTest.php b/tests/Phpml/FeatureExtraction/TfIdfTransformerTest.php index 59d96c0..ca5db36 100644 --- a/tests/Phpml/FeatureExtraction/TfIdfTransformerTest.php +++ b/tests/Phpml/FeatureExtraction/TfIdfTransformerTest.php @@ -23,7 +23,8 @@ class TfIdfTransformerTest extends \PHPUnit_Framework_TestCase ]; $transformer = new TfIdfTransformer(); + $transformer->transform($samples); - $this->assertEquals($tfIdfSamples, $transformer->transform($samples), '', 0.001); + $this->assertEquals($tfIdfSamples, $samples, '', 0.001); } } diff --git a/tests/Phpml/FeatureExtraction/TokenCountVectorizerTest.php b/tests/Phpml/FeatureExtraction/TokenCountVectorizerTest.php index 5166575..80b7723 100644 --- a/tests/Phpml/FeatureExtraction/TokenCountVectorizerTest.php +++ b/tests/Phpml/FeatureExtraction/TokenCountVectorizerTest.php @@ -37,8 +37,9 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase ]; $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer()); + $vectorizer->transform($samples); - $this->assertEquals($tokensCounts, $vectorizer->transform($samples)); + $this->assertEquals($tokensCounts, $samples); $this->assertEquals($vocabulary, $vectorizer->getVocabulary()); } @@ -68,8 +69,9 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase ]; $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5); + $vectorizer->transform($samples); - $this->assertEquals($tokensCounts, $vectorizer->transform($samples)); + $this->assertEquals($tokensCounts, $samples); $this->assertEquals($vocabulary, $vectorizer->getVocabulary()); // word at least once in all samples @@ -86,7 +88,8 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase ]; $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1); + $vectorizer->transform($samples); - $this->assertEquals($tokensCounts, $vectorizer->transform($samples)); + $this->assertEquals($tokensCounts, $samples); } } diff --git a/tests/Phpml/PipelineTest.php b/tests/Phpml/PipelineTest.php index 270fb89..b00a60a 100644 --- a/tests/Phpml/PipelineTest.php +++ b/tests/Phpml/PipelineTest.php @@ -10,11 +10,10 @@ use Phpml\Pipeline; class PipelineTest extends \PHPUnit_Framework_TestCase { - public function testPipelineConstruction() { $transformers = [ - new TfIdfTransformer() + new TfIdfTransformer(), ]; $estimator = new SVC(); @@ -23,5 +22,4 @@ class PipelineTest extends \PHPUnit_Framework_TestCase $this->assertEquals($transformers, $pipeline->getTransformers()); $this->assertEquals($estimator, $pipeline->getEstimator()); } - }