classifier predict array of samples or one sample

This commit is contained in:
Arkadiusz Kondas 2016-04-08 22:25:15 +02:00
parent f1c81638d6
commit e7d2780150
4 changed files with 39 additions and 5 deletions

View File

@ -13,9 +13,9 @@ interface Classifier
public function train(array $samples, array $labels);
/**
* @param array $sample
* @param array $samples
*
* @return mixed
*/
public function predict(array $sample);
public function predict(array $samples);
}

View File

@ -43,12 +43,31 @@ class KNearestNeighbors implements Classifier
$this->labels = $labels;
}
/**
* @param array $samples
*
* @return mixed
*/
public function predict(array $samples)
{
if(!is_array($samples[0])) {
$predicted = $this->predictSample($samples);
} else {
$predicted = [];
foreach ($samples as $index => $sample) {
$predicted[$index] = $this->predictSample($sample);
}
}
return $predicted;
}
/**
* @param array $sample
*
* @return mixed
*/
public function predict(array $sample)
private function predictSample(array $sample)
{
$distances = $this->kNeighborsDistances($sample);

View File

@ -15,11 +15,11 @@ class NaiveBayes implements Classifier
}
/**
* @param array $sample
* @param array $samples
*
* @return mixed
*/
public function predict(array $sample)
public function predict(array $samples)
{
}
}

View File

@ -26,4 +26,19 @@ class KNearestNeighborsTest extends \PHPUnit_Framework_TestCase
$this->assertEquals('a', $classifier->predict([1, 5]));
$this->assertEquals('a', $classifier->predict([3, 10]));
}
public function testPredictArrayOfSamples()
{
$trainSamples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$trainLabels = ['a', 'a', 'a', 'b', 'b', 'b'];
$testSamples = [[3, 2], [5, 1], [4, 3], [4, -5], [2, 3], [1, 2], [1, 5], [3, 10]];
$testLabels = ['b', 'b', 'b', 'b', 'a', 'a', 'a', 'a',];
$classifier = new KNearestNeighbors();
$classifier->train($trainSamples, $trainLabels);
$predicted = $classifier->predict($testSamples);
$this->assertEquals($testLabels, $predicted);
}
}