diff --git a/src/Phpml/FeatureExtraction/TokenCountVectorizer.php b/src/Phpml/FeatureExtraction/TokenCountVectorizer.php index 14fc69c..cde5278 100644 --- a/src/Phpml/FeatureExtraction/TokenCountVectorizer.php +++ b/src/Phpml/FeatureExtraction/TokenCountVectorizer.php @@ -23,6 +23,11 @@ class TokenCountVectorizer implements Vectorizer */ private $vocabulary; + /** + * @var array + */ + private $tokens; + /** * @var array */ @@ -47,8 +52,10 @@ class TokenCountVectorizer implements Vectorizer */ public function transform(array $samples): array { + $this->buildVocabulary($samples); + foreach ($samples as $index => $sample) { - $samples[$index] = $this->transformSample($sample); + $samples[$index] = $this->transformSample($index); } $samples = $this->checkDocumentFrequency($samples); @@ -65,14 +72,29 @@ class TokenCountVectorizer implements Vectorizer } /** - * @param string $sample + * @param array $samples + */ + private function buildVocabulary(array &$samples) + { + foreach ($samples as $index => $sample) { + $tokens = $this->tokenizer->tokenize($sample); + foreach ($tokens as $token) { + $this->addTokenToVocabulary($token); + } + $this->tokens[$index] = $tokens; + } + } + + /** + * @param int $index * * @return array */ - private function transformSample(string $sample) + private function transformSample(int $index) { $counts = []; - $tokens = $this->tokenizer->tokenize($sample); + $tokens = $this->tokens[$index]; + foreach ($tokens as $token) { $index = $this->getTokenIndex($token); $this->updateFrequency($token); @@ -83,21 +105,33 @@ class TokenCountVectorizer implements Vectorizer ++$counts[$index]; } + foreach ($this->vocabulary as $index) { + if (!isset($counts[$index])) { + $counts[$index] = 0; + } + } + return $counts; } /** * @param string $token * - * @return mixed + * @return int */ - private function getTokenIndex(string $token) + private function getTokenIndex(string $token): int + { + return $this->vocabulary[$token]; + } + + /** + * @param string $token + */ + private function addTokenToVocabulary(string $token) { if (!isset($this->vocabulary[$token])) { $this->vocabulary[$token] = count($this->vocabulary); } - - return $this->vocabulary[$token]; } /** @@ -122,7 +156,7 @@ class TokenCountVectorizer implements Vectorizer if ($this->minDF > 0) { $beyondMinimum = $this->getBeyondMinimumIndexes(count($samples)); foreach ($samples as $index => $sample) { - $samples[$index] = $this->unsetBeyondMinimum($sample, $beyondMinimum); + $samples[$index] = $this->resetBeyondMinimum($sample, $beyondMinimum); } } @@ -135,10 +169,10 @@ class TokenCountVectorizer implements Vectorizer * * @return array */ - private function unsetBeyondMinimum(array $sample, array $beyondMinimum) + private function resetBeyondMinimum(array $sample, array $beyondMinimum) { foreach ($beyondMinimum as $index) { - unset($sample[$index]); + $sample[$index] = 0; } return $sample; diff --git a/tests/Phpml/FeatureExtraction/TokenCountVectorizerTest.php b/tests/Phpml/FeatureExtraction/TokenCountVectorizerTest.php index 64ac569..5166575 100644 --- a/tests/Phpml/FeatureExtraction/TokenCountVectorizerTest.php +++ b/tests/Phpml/FeatureExtraction/TokenCountVectorizerTest.php @@ -17,16 +17,28 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase 'Mauris diam eros fringilla diam', ]; - $vocabulary = ['Lorem', 'ipsum', 'dolor', 'sit', 'amet', 'Mauris', 'placerat', 'diam', 'eros', 'fringilla']; - $vector = [ - [0 => 1, 1 => 1, 2 => 2, 3 => 1, 4 => 1], - [5 => 1, 6 => 1, 1 => 1, 2 => 1], - [5 => 1, 7 => 2, 8 => 1, 9 => 1], + $vocabulary = [ + 0 => 'Lorem', + 1 => 'ipsum', + 2 => 'dolor', + 3 => 'sit', + 4 => 'amet', + 5 => 'Mauris', + 6 => 'placerat', + 7 => 'diam', + 8 => 'eros', + 9 => 'fringilla', + ]; + + $tokensCounts = [ + [0 => 1, 1 => 1, 2 => 2, 3 => 1, 4 => 1, 5 => 0, 6 => 0, 7 => 0, 8 => 0, 9 => 0], + [0 => 0, 1 => 1, 2 => 1, 3 => 0, 4 => 0, 5 => 1, 6 => 1, 7 => 0, 8 => 0, 9 => 0], + [0 => 0, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 1, 6 => 0, 7 => 2, 8 => 1, 9 => 1], ]; $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer()); - $this->assertEquals($vector, $vectorizer->transform($samples)); + $this->assertEquals($tokensCounts, $vectorizer->transform($samples)); $this->assertEquals($vocabulary, $vectorizer->getVocabulary()); } @@ -40,34 +52,41 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase 'ipsum sit amet', ]; - $vocabulary = ['Lorem', 'ipsum', 'dolor', 'sit', 'amet']; - $vector = [ - [0 => 1, 1 => 1, 3 => 1, 4 => 1], - [0 => 1, 1 => 1, 3 => 1, 4 => 1], - [1 => 1, 3 => 1, 4 => 1], - [1 => 1, 3 => 1, 4 => 1], + $vocabulary = [ + 0 => 'Lorem', + 1 => 'ipsum', + 2 => 'dolor', + 3 => 'sit', + 4 => 'amet', + ]; + + $tokensCounts = [ + [0 => 1, 1 => 1, 2 => 0, 3 => 1, 4 => 1], + [0 => 1, 1 => 1, 2 => 0, 3 => 1, 4 => 1], + [0 => 0, 1 => 1, 2 => 0, 3 => 1, 4 => 1], + [0 => 0, 1 => 1, 2 => 0, 3 => 1, 4 => 1], ]; $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5); - $this->assertEquals($vector, $vectorizer->transform($samples)); + $this->assertEquals($tokensCounts, $vectorizer->transform($samples)); $this->assertEquals($vocabulary, $vectorizer->getVocabulary()); - // word at least in all samples + // word at least once in all samples $samples = [ 'Lorem ipsum dolor sit amet', - 'Morbi quis lacinia arcu. Sed eu sagittis Lorem', - 'Suspendisse gravida consequat eros Lorem', + 'Morbi quis sagittis Lorem', + 'eros Lorem', ]; - $vector = [ - [0 => 1], - [0 => 1], - [0 => 1], + $tokensCounts = [ + [0 => 1, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 0, 6 => 0, 7 => 0, 8 => 0], + [0 => 1, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 0, 6 => 0, 7 => 0, 8 => 0], + [0 => 1, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 0, 6 => 0, 7 => 0, 8 => 0], ]; $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1); - $this->assertEquals($vector, $vectorizer->transform($samples)); + $this->assertEquals($tokensCounts, $vectorizer->transform($samples)); } } diff --git a/tests/Phpml/Metric/AccuracyTest.php b/tests/Phpml/Metric/AccuracyTest.php index 1e71fcf..6f28d94 100644 --- a/tests/Phpml/Metric/AccuracyTest.php +++ b/tests/Phpml/Metric/AccuracyTest.php @@ -52,5 +52,4 @@ class AccuracyTest extends \PHPUnit_Framework_TestCase $this->assertEquals(0.959, $accuracy, '', 0.01); } - }