implement fit fot TokenCountVectorizer

This commit is contained in:
Arkadiusz Kondas 2016-06-17 00:33:48 +02:00
parent be7423350f
commit 424519cd83
2 changed files with 32 additions and 40 deletions

View File

@ -24,11 +24,6 @@ class TokenCountVectorizer implements Transformer
*/
private $vocabulary;
/**
* @var array
*/
private $tokens;
/**
* @var array
*/
@ -51,7 +46,7 @@ class TokenCountVectorizer implements Transformer
*/
public function fit(array $samples)
{
// TODO: Implement fit() method.
$this->buildVocabulary($samples);
}
/**
@ -59,13 +54,11 @@ class TokenCountVectorizer implements Transformer
*/
public function transform(array &$samples)
{
$this->buildVocabulary($samples);
foreach ($samples as $index => $sample) {
$samples[$index] = $this->transformSample($index);
foreach ($samples as &$sample) {
$this->transformSample($sample);
}
$samples = $this->checkDocumentFrequency($samples);
$this->checkDocumentFrequency($samples);
}
/**
@ -86,22 +79,20 @@ class TokenCountVectorizer implements Transformer
foreach ($tokens as $token) {
$this->addTokenToVocabulary($token);
}
$this->tokens[$index] = $tokens;
}
}
/**
* @param int $index
*
* @return array
* @param string $sample
*/
private function transformSample(int $index)
private function transformSample(string &$sample)
{
$counts = [];
$tokens = $this->tokens[$index];
$tokens = $this->tokenizer->tokenize($sample);
foreach ($tokens as $token) {
$index = $this->getTokenIndex($token);
if(false !== $index) {
$this->updateFrequency($token);
if (!isset($counts[$index])) {
$counts[$index] = 0;
@ -109,6 +100,7 @@ class TokenCountVectorizer implements Transformer
++$counts[$index];
}
}
foreach ($this->vocabulary as $index) {
if (!isset($counts[$index])) {
@ -116,17 +108,17 @@ class TokenCountVectorizer implements Transformer
}
}
return $counts;
$sample = $counts;
}
/**
* @param string $token
*
* @return int
* @return int|bool
*/
private function getTokenIndex(string $token): int
private function getTokenIndex(string $token)
{
return $this->vocabulary[$token];
return isset($this->vocabulary[$token]) ? $this->vocabulary[$token] : false;
}
/**
@ -156,31 +148,25 @@ class TokenCountVectorizer implements Transformer
*
* @return array
*/
private function checkDocumentFrequency(array $samples)
private function checkDocumentFrequency(array &$samples)
{
if ($this->minDF > 0) {
$beyondMinimum = $this->getBeyondMinimumIndexes(count($samples));
foreach ($samples as $index => $sample) {
$samples[$index] = $this->resetBeyondMinimum($sample, $beyondMinimum);
foreach ($samples as &$sample) {
$this->resetBeyondMinimum($sample, $beyondMinimum);
}
}
return $samples;
}
/**
* @param array $sample
* @param array $beyondMinimum
*
* @return array
*/
private function resetBeyondMinimum(array $sample, array $beyondMinimum)
private function resetBeyondMinimum(array &$sample, array $beyondMinimum)
{
foreach ($beyondMinimum as $index) {
$sample[$index] = 0;
}
return $sample;
}
/**

View File

@ -37,10 +37,12 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer());
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $samples);
$vectorizer->fit($samples);
$this->assertEquals($vocabulary, $vectorizer->getVocabulary());
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $samples);
}
public function testMinimumDocumentTokenCountFrequency()
@ -69,11 +71,14 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5);
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $samples);
$vectorizer->fit($samples);
$this->assertEquals($vocabulary, $vectorizer->getVocabulary());
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $samples);
// word at least once in all samples
$samples = [
'Lorem ipsum dolor sit amet',
@ -88,6 +93,7 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1);
$vectorizer->fit($samples);
$vectorizer->transform($samples);
$this->assertEquals($tokensCounts, $samples);