change token count vectorizer to return full token counts

This commit is contained in:
Arkadiusz Kondas 2016-06-14 09:58:11 +02:00
parent 23eff0044a
commit 2f51716388
3 changed files with 85 additions and 33 deletions

View File

@ -23,6 +23,11 @@ class TokenCountVectorizer implements Vectorizer
*/ */
private $vocabulary; private $vocabulary;
/**
* @var array
*/
private $tokens;
/** /**
* @var array * @var array
*/ */
@ -47,8 +52,10 @@ class TokenCountVectorizer implements Vectorizer
*/ */
public function transform(array $samples): array public function transform(array $samples): array
{ {
$this->buildVocabulary($samples);
foreach ($samples as $index => $sample) { foreach ($samples as $index => $sample) {
$samples[$index] = $this->transformSample($sample); $samples[$index] = $this->transformSample($index);
} }
$samples = $this->checkDocumentFrequency($samples); $samples = $this->checkDocumentFrequency($samples);
@ -65,14 +72,29 @@ class TokenCountVectorizer implements Vectorizer
} }
/** /**
* @param string $sample * @param array $samples
*/
private function buildVocabulary(array $samples)
{
    // Tokenize every sample exactly once. Each token is registered in the
    // vocabulary, and the sample's token list is cached under the sample's
    // own key so that transformSample() can fetch it by index.
    //
    // $samples is taken by value because it is only read here; nothing in
    // this method writes back to the caller's array.
    foreach ($samples as $index => $sample) {
        $tokens = $this->tokenizer->tokenize($sample);

        foreach ($tokens as $token) {
            // Adds the token to the vocabulary only if it is not already known.
            $this->addTokenToVocabulary($token);
        }

        // Keep the raw token list (duplicates included) for the counting pass.
        $this->tokens[$index] = $tokens;
    }
}
/**
* @param int $index
* *
* @return array * @return array
*/ */
private function transformSample(string $sample) private function transformSample(int $index)
{ {
$counts = []; $counts = [];
$tokens = $this->tokenizer->tokenize($sample); $tokens = $this->tokens[$index];
foreach ($tokens as $token) { foreach ($tokens as $token) {
$index = $this->getTokenIndex($token); $index = $this->getTokenIndex($token);
$this->updateFrequency($token); $this->updateFrequency($token);
@ -83,21 +105,33 @@ class TokenCountVectorizer implements Vectorizer
++$counts[$index]; ++$counts[$index];
} }
foreach ($this->vocabulary as $index) {
if (!isset($counts[$index])) {
$counts[$index] = 0;
}
}
return $counts; return $counts;
} }
/** /**
* @param string $token * @param string $token
* *
* @return mixed * @return int
*/ */
private function getTokenIndex(string $token) private function getTokenIndex(string $token): int
{
return $this->vocabulary[$token];
}
/**
* @param string $token
*/
private function addTokenToVocabulary(string $token)
{ {
if (!isset($this->vocabulary[$token])) { if (!isset($this->vocabulary[$token])) {
$this->vocabulary[$token] = count($this->vocabulary); $this->vocabulary[$token] = count($this->vocabulary);
} }
return $this->vocabulary[$token];
} }
/** /**
@ -122,7 +156,7 @@ class TokenCountVectorizer implements Vectorizer
if ($this->minDF > 0) { if ($this->minDF > 0) {
$beyondMinimum = $this->getBeyondMinimumIndexes(count($samples)); $beyondMinimum = $this->getBeyondMinimumIndexes(count($samples));
foreach ($samples as $index => $sample) { foreach ($samples as $index => $sample) {
$samples[$index] = $this->unsetBeyondMinimum($sample, $beyondMinimum); $samples[$index] = $this->resetBeyondMinimum($sample, $beyondMinimum);
} }
} }
@ -135,10 +169,10 @@ class TokenCountVectorizer implements Vectorizer
* *
* @return array * @return array
*/ */
private function unsetBeyondMinimum(array $sample, array $beyondMinimum) private function resetBeyondMinimum(array $sample, array $beyondMinimum)
{ {
foreach ($beyondMinimum as $index) { foreach ($beyondMinimum as $index) {
unset($sample[$index]); $sample[$index] = 0;
} }
return $sample; return $sample;

View File

@ -17,16 +17,28 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
'Mauris diam eros fringilla diam', 'Mauris diam eros fringilla diam',
]; ];
$vocabulary = ['Lorem', 'ipsum', 'dolor', 'sit', 'amet', 'Mauris', 'placerat', 'diam', 'eros', 'fringilla']; $vocabulary = [
$vector = [ 0 => 'Lorem',
[0 => 1, 1 => 1, 2 => 2, 3 => 1, 4 => 1], 1 => 'ipsum',
[5 => 1, 6 => 1, 1 => 1, 2 => 1], 2 => 'dolor',
[5 => 1, 7 => 2, 8 => 1, 9 => 1], 3 => 'sit',
4 => 'amet',
5 => 'Mauris',
6 => 'placerat',
7 => 'diam',
8 => 'eros',
9 => 'fringilla',
];
$tokensCounts = [
[0 => 1, 1 => 1, 2 => 2, 3 => 1, 4 => 1, 5 => 0, 6 => 0, 7 => 0, 8 => 0, 9 => 0],
[0 => 0, 1 => 1, 2 => 1, 3 => 0, 4 => 0, 5 => 1, 6 => 1, 7 => 0, 8 => 0, 9 => 0],
[0 => 0, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 1, 6 => 0, 7 => 2, 8 => 1, 9 => 1],
]; ];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer()); $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer());
$this->assertEquals($vector, $vectorizer->transform($samples)); $this->assertEquals($tokensCounts, $vectorizer->transform($samples));
$this->assertEquals($vocabulary, $vectorizer->getVocabulary()); $this->assertEquals($vocabulary, $vectorizer->getVocabulary());
} }
@ -40,34 +52,41 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
'ipsum sit amet', 'ipsum sit amet',
]; ];
$vocabulary = ['Lorem', 'ipsum', 'dolor', 'sit', 'amet']; $vocabulary = [
$vector = [ 0 => 'Lorem',
[0 => 1, 1 => 1, 3 => 1, 4 => 1], 1 => 'ipsum',
[0 => 1, 1 => 1, 3 => 1, 4 => 1], 2 => 'dolor',
[1 => 1, 3 => 1, 4 => 1], 3 => 'sit',
[1 => 1, 3 => 1, 4 => 1], 4 => 'amet',
];
$tokensCounts = [
[0 => 1, 1 => 1, 2 => 0, 3 => 1, 4 => 1],
[0 => 1, 1 => 1, 2 => 0, 3 => 1, 4 => 1],
[0 => 0, 1 => 1, 2 => 0, 3 => 1, 4 => 1],
[0 => 0, 1 => 1, 2 => 0, 3 => 1, 4 => 1],
]; ];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5); $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5);
$this->assertEquals($vector, $vectorizer->transform($samples)); $this->assertEquals($tokensCounts, $vectorizer->transform($samples));
$this->assertEquals($vocabulary, $vectorizer->getVocabulary()); $this->assertEquals($vocabulary, $vectorizer->getVocabulary());
// word at least in all samples // word at least once in all samples
$samples = [ $samples = [
'Lorem ipsum dolor sit amet', 'Lorem ipsum dolor sit amet',
'Morbi quis lacinia arcu. Sed eu sagittis Lorem', 'Morbi quis sagittis Lorem',
'Suspendisse gravida consequat eros Lorem', 'eros Lorem',
]; ];
$vector = [ $tokensCounts = [
[0 => 1], [0 => 1, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 0, 6 => 0, 7 => 0, 8 => 0],
[0 => 1], [0 => 1, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 0, 6 => 0, 7 => 0, 8 => 0],
[0 => 1], [0 => 1, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 0, 6 => 0, 7 => 0, 8 => 0],
]; ];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1); $vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1);
$this->assertEquals($vector, $vectorizer->transform($samples)); $this->assertEquals($tokensCounts, $vectorizer->transform($samples));
} }
} }

View File

@ -52,5 +52,4 @@ class AccuracyTest extends \PHPUnit_Framework_TestCase
$this->assertEquals(0.959, $accuracy, '', 0.01); $this->assertEquals(0.959, $accuracy, '', 0.01);
} }
} }