implement fit for TokenCountVectorizer

Arkadiusz Kondas 2016-06-17 00:33:48 +02:00
parent be7423350f
commit 424519cd83
2 changed files with 32 additions and 40 deletions


@@ -24,11 +24,6 @@ class TokenCountVectorizer implements Transformer
      */
     private $vocabulary;
 
-    /**
-     * @var array
-     */
-    private $tokens;
-
     /**
      * @var array
      */
@@ -51,7 +46,7 @@ class TokenCountVectorizer implements Transformer
      */
     public function fit(array $samples)
     {
-        // TODO: Implement fit() method.
+        $this->buildVocabulary($samples);
     }
 
     /**
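
Note: fit() now only learns the vocabulary from the training samples; transform() (below) reuses it. A minimal usage sketch, assuming the library's Phpml\FeatureExtraction and Phpml\Tokenization namespaces, which are not shown in this diff:

    use Phpml\FeatureExtraction\TokenCountVectorizer;
    use Phpml\Tokenization\WhitespaceTokenizer;

    $samples = ['Lorem ipsum dolor sit amet', 'Mauris placerat ipsum dolor'];

    $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer());
    $vectorizer->fit($samples);                 // builds the vocabulary only
    $vocabulary = $vectorizer->getVocabulary(); // token => column index map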
@@ -59,13 +54,11 @@ class TokenCountVectorizer implements Transformer
      */
     public function transform(array &$samples)
     {
-        $this->buildVocabulary($samples);
-
-        foreach ($samples as $index => $sample) {
-            $samples[$index] = $this->transformSample($index);
+        foreach ($samples as &$sample) {
+            $this->transformSample($sample);
         }
 
-        $samples = $this->checkDocumentFrequency($samples);
+        $this->checkDocumentFrequency($samples);
     }
 
     /**
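
Note: transform() now rewrites $samples in place rather than returning a new array, relying on a by-reference foreach. An illustrative sketch of that pattern in plain PHP (not library code):

    $items = ['a', 'b'];
    foreach ($items as &$item) {
        $item = strtoupper($item); // write back through the reference
    }
    unset($item); // drop the dangling reference after the loop

    print_r($items); // Array ( [0] => A [1] => B )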
@@ -86,22 +79,20 @@ class TokenCountVectorizer implements Transformer
             foreach ($tokens as $token) {
                 $this->addTokenToVocabulary($token);
             }
-            $this->tokens[$index] = $tokens;
         }
     }
 
     /**
-     * @param int $index
-     *
-     * @return array
+     * @param string $sample
      */
-    private function transformSample(int $index)
+    private function transformSample(string &$sample)
     {
         $counts = [];
-        $tokens = $this->tokens[$index];
+        $tokens = $this->tokenizer->tokenize($sample);
 
         foreach ($tokens as $token) {
             $index = $this->getTokenIndex($token);
+            if(false !== $index) {
             $this->updateFrequency($token);
             if (!isset($counts[$index])) {
                 $counts[$index] = 0;
@@ -109,6 +100,7 @@ class TokenCountVectorizer implements Transformer
 
             ++$counts[$index];
         }
+        }
 
         foreach ($this->vocabulary as $index) {
             if (!isset($counts[$index])) {
@@ -116,17 +108,17 @@ class TokenCountVectorizer implements Transformer
             }
         }
 
-        return $counts;
+        $sample = $counts;
     }
 
     /**
      * @param string $token
      *
-     * @return int
+     * @return int|bool
      */
-    private function getTokenIndex(string $token): int
+    private function getTokenIndex(string $token)
     {
-        return $this->vocabulary[$token];
+        return isset($this->vocabulary[$token]) ? $this->vocabulary[$token] : false;
     }
 
     /**
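
Note: getTokenIndex() now returns the vocabulary index or false for unknown tokens, and transformSample() only counts a token when the `false !== $index` check passes. The strict comparison matters because a valid index can be 0, which is falsy in PHP:

    $index = 0;                  // first token in the vocabulary
    var_dump(false !== $index);  // bool(true)  -> token gets counted
    var_dump((bool) $index);     // bool(false) -> a loose `if ($index)` would skip it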
@@ -156,31 +148,25 @@ class TokenCountVectorizer implements Transformer
      *
      * @return array
      */
-    private function checkDocumentFrequency(array $samples)
+    private function checkDocumentFrequency(array &$samples)
     {
         if ($this->minDF > 0) {
             $beyondMinimum = $this->getBeyondMinimumIndexes(count($samples));
-            foreach ($samples as $index => $sample) {
-                $samples[$index] = $this->resetBeyondMinimum($sample, $beyondMinimum);
+            foreach ($samples as &$sample) {
+                $this->resetBeyondMinimum($sample, $beyondMinimum);
             }
         }
-
-        return $samples;
     }
 
     /**
     * @param array $sample
     * @param array $beyondMinimum
-     *
-     * @return array
     */
-    private function resetBeyondMinimum(array $sample, array $beyondMinimum)
+    private function resetBeyondMinimum(array &$sample, array $beyondMinimum)
     {
         foreach ($beyondMinimum as $index) {
             $sample[$index] = 0;
         }
-
-        return $sample;
     }
 
     /**
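
Note: checkDocumentFrequency() and resetBeyondMinimum() now also take their arguments by reference, zeroing the columns of tokens that fall below the minimum document frequency instead of returning modified copies. A standalone sketch of that zeroing step (illustrative only, not the library code):

    function zeroColumns(array &$vector, array $columns)
    {
        foreach ($columns as $column) {
            $vector[$column] = 0; // keep the column, drop the count
        }
    }

    $vector = [0 => 2, 1 => 1, 2 => 3];
    zeroColumns($vector, [1, 2]);
    print_r($vector); // [0 => 2, 1 => 0, 2 => 0]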


@@ -37,10 +37,12 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
         ];
 
         $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer());
-        $vectorizer->transform($samples);
 
-        $this->assertEquals($tokensCounts, $samples);
+        $vectorizer->fit($samples);
         $this->assertEquals($vocabulary, $vectorizer->getVocabulary());
+
+        $vectorizer->transform($samples);
+        $this->assertEquals($tokensCounts, $samples);
     }
 
     public function testMinimumDocumentTokenCountFrequency()
@@ -69,11 +71,14 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
         ];
 
         $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5);
-        $vectorizer->transform($samples);
 
-        $this->assertEquals($tokensCounts, $samples);
+        $vectorizer->fit($samples);
         $this->assertEquals($vocabulary, $vectorizer->getVocabulary());
 
+        $vectorizer->transform($samples);
+        $this->assertEquals($tokensCounts, $samples);
+
         // word at least once in all samples
         $samples = [
             'Lorem ipsum dolor sit amet',
@@ -88,6 +93,7 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
         ];
 
         $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1);
+        $vectorizer->fit($samples);
         $vectorizer->transform($samples);
 
         $this->assertEquals($tokensCounts, $samples);
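
Note: after this change the tests call fit() before transform(), which mirrors the intended usage. A sketch of the full flow with a minimum document frequency, assuming the same namespaces as above:

    use Phpml\FeatureExtraction\TokenCountVectorizer;
    use Phpml\Tokenization\WhitespaceTokenizer;

    $samples = [
        'Lorem ipsum dolor sit amet',
        'Lorem ipsum sit amet',
    ];

    // With minDF = 0.5, tokens that occur in too few of the documents
    // are zeroed out by checkDocumentFrequency() during transform().
    $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5);
    $vectorizer->fit($samples);       // learn the vocabulary
    $vectorizer->transform($samples); // replace each text with its count vector
    print_r($samples);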