change token count vectorizer to return full token counts

This commit is contained in:
Arkadiusz Kondas 2016-06-14 09:58:11 +02:00
parent 23eff0044a
commit 2f51716388
3 changed files with 85 additions and 33 deletions

View File

@ -23,6 +23,11 @@ class TokenCountVectorizer implements Vectorizer
*/
private $vocabulary;
/**
* @var array
*/
private $tokens;
/**
* @var array
*/
@ -47,8 +52,10 @@ class TokenCountVectorizer implements Vectorizer
*/
public function transform(array $samples): array
{
$this->buildVocabulary($samples);
foreach ($samples as $index => $sample) {
$samples[$index] = $this->transformSample($sample);
$samples[$index] = $this->transformSample($index);
}
$samples = $this->checkDocumentFrequency($samples);
@ -65,14 +72,29 @@ class TokenCountVectorizer implements Vectorizer
}
/**
 * Tokenizes every sample, registers each token in the vocabulary and keeps
 * the token list of each sample under that sample's own key.
 *
 * The samples are only read here, so the array is taken by value; PHP's
 * copy-on-write means no copy is made as long as it is not modified.
 *
 * @param array $samples samples to tokenize; the keys are reused as keys of $this->tokens
 */
private function buildVocabulary(array $samples)
{
    foreach ($samples as $index => $sample) {
        $tokens = $this->tokenizer->tokenize($sample);
        foreach ($tokens as $token) {
            $this->addTokenToVocabulary($token);
        }
        // Stored per sample key so the counting pass can look the tokens up by index.
        $this->tokens[$index] = $tokens;
    }
}
/**
* @param int $index
*
* @return array
*/
private function transformSample(string $sample)
private function transformSample(int $index)
{
$counts = [];
$tokens = $this->tokenizer->tokenize($sample);
$tokens = $this->tokens[$index];
foreach ($tokens as $token) {
$index = $this->getTokenIndex($token);
$this->updateFrequency($token);
@ -83,21 +105,33 @@ class TokenCountVectorizer implements Vectorizer
++$counts[$index];
}
foreach ($this->vocabulary as $index) {
if (!isset($counts[$index])) {
$counts[$index] = 0;
}
}
return $counts;
}
/**
* @param string $token
*
* @return mixed
* @return int
*/
private function getTokenIndex(string $token)
private function getTokenIndex(string $token): int
{
return $this->vocabulary[$token];
}
/**
 * Registers a token in the vocabulary if it is not there yet and returns
 * the index stored for it.
 *
 * @param string $token
 *
 * @return int index stored in the vocabulary for the token
 */
private function addTokenToVocabulary(string $token)
{
    // Known token: hand back the index it was given earlier.
    if (isset($this->vocabulary[$token])) {
        return $this->vocabulary[$token];
    }

    // New token: its index is the vocabulary size at the moment of insertion.
    $this->vocabulary[$token] = count($this->vocabulary);

    return $this->vocabulary[$token];
}
/**
@ -122,7 +156,7 @@ class TokenCountVectorizer implements Vectorizer
if ($this->minDF > 0) {
$beyondMinimum = $this->getBeyondMinimumIndexes(count($samples));
foreach ($samples as $index => $sample) {
$samples[$index] = $this->unsetBeyondMinimum($sample, $beyondMinimum);
$samples[$index] = $this->resetBeyondMinimum($sample, $beyondMinimum);
}
}
@ -135,10 +169,10 @@ class TokenCountVectorizer implements Vectorizer
*
* @return array
*/
private function unsetBeyondMinimum(array $sample, array $beyondMinimum)
private function resetBeyondMinimum(array $sample, array $beyondMinimum)
{
foreach ($beyondMinimum as $index) {
unset($sample[$index]);
$sample[$index] = 0;
}
return $sample;

View File

@ -17,16 +17,28 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
'Mauris diam eros fringilla diam',
];
$vocabulary = ['Lorem', 'ipsum', 'dolor', 'sit', 'amet', 'Mauris', 'placerat', 'diam', 'eros', 'fringilla'];
$vector = [
[0 => 1, 1 => 1, 2 => 2, 3 => 1, 4 => 1],
[5 => 1, 6 => 1, 1 => 1, 2 => 1],
[5 => 1, 7 => 2, 8 => 1, 9 => 1],
$vocabulary = [
0 => 'Lorem',
1 => 'ipsum',
2 => 'dolor',
3 => 'sit',
4 => 'amet',
5 => 'Mauris',
6 => 'placerat',
7 => 'diam',
8 => 'eros',
9 => 'fringilla',
];
$tokensCounts = [
[0 => 1, 1 => 1, 2 => 2, 3 => 1, 4 => 1, 5 => 0, 6 => 0, 7 => 0, 8 => 0, 9 => 0],
[0 => 0, 1 => 1, 2 => 1, 3 => 0, 4 => 0, 5 => 1, 6 => 1, 7 => 0, 8 => 0, 9 => 0],
[0 => 0, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 1, 6 => 0, 7 => 2, 8 => 1, 9 => 1],
];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer());
$this->assertEquals($vector, $vectorizer->transform($samples));
$this->assertEquals($tokensCounts, $vectorizer->transform($samples));
$this->assertEquals($vocabulary, $vectorizer->getVocabulary());
}
@ -40,34 +52,41 @@ class TokenCountVectorizerTest extends \PHPUnit_Framework_TestCase
'ipsum sit amet',
];
$vocabulary = ['Lorem', 'ipsum', 'dolor', 'sit', 'amet'];
$vector = [
[0 => 1, 1 => 1, 3 => 1, 4 => 1],
[0 => 1, 1 => 1, 3 => 1, 4 => 1],
[1 => 1, 3 => 1, 4 => 1],
[1 => 1, 3 => 1, 4 => 1],
$vocabulary = [
0 => 'Lorem',
1 => 'ipsum',
2 => 'dolor',
3 => 'sit',
4 => 'amet',
];
$tokensCounts = [
[0 => 1, 1 => 1, 2 => 0, 3 => 1, 4 => 1],
[0 => 1, 1 => 1, 2 => 0, 3 => 1, 4 => 1],
[0 => 0, 1 => 1, 2 => 0, 3 => 1, 4 => 1],
[0 => 0, 1 => 1, 2 => 0, 3 => 1, 4 => 1],
];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 0.5);
$this->assertEquals($vector, $vectorizer->transform($samples));
$this->assertEquals($tokensCounts, $vectorizer->transform($samples));
$this->assertEquals($vocabulary, $vectorizer->getVocabulary());
// word at least in all samples
// word at least once in all samples
$samples = [
'Lorem ipsum dolor sit amet',
'Morbi quis lacinia arcu. Sed eu sagittis Lorem',
'Suspendisse gravida consequat eros Lorem',
'Morbi quis sagittis Lorem',
'eros Lorem',
];
$vector = [
[0 => 1],
[0 => 1],
[0 => 1],
$tokensCounts = [
[0 => 1, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 0, 6 => 0, 7 => 0, 8 => 0],
[0 => 1, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 0, 6 => 0, 7 => 0, 8 => 0],
[0 => 1, 1 => 0, 2 => 0, 3 => 0, 4 => 0, 5 => 0, 6 => 0, 7 => 0, 8 => 0],
];
$vectorizer = new TokenCountVectorizer(new WhitespaceTokenizer(), 1);
$this->assertEquals($vector, $vectorizer->transform($samples));
$this->assertEquals($tokensCounts, $vectorizer->transform($samples));
}
}

View File

@ -52,5 +52,4 @@ class AccuracyTest extends \PHPUnit_Framework_TestCase
$this->assertEquals(0.959, $accuracy, '', 0.01);
}
}