diff --git a/src/Phpml/Classifier/Classifier.php b/src/Phpml/Classifier/Classifier.php index 6fad67a..90250a9 100644 --- a/src/Phpml/Classifier/Classifier.php +++ b/src/Phpml/Classifier/Classifier.php @@ -13,9 +13,9 @@ interface Classifier public function train(array $samples, array $labels); /** - * @param array $sample + * @param array $samples * * @return mixed */ - public function predict(array $sample); + public function predict(array $samples); } diff --git a/src/Phpml/Classifier/KNearestNeighbors.php b/src/Phpml/Classifier/KNearestNeighbors.php index 5a998d3..53c66b0 100644 --- a/src/Phpml/Classifier/KNearestNeighbors.php +++ b/src/Phpml/Classifier/KNearestNeighbors.php @@ -43,12 +43,31 @@ class KNearestNeighbors implements Classifier $this->labels = $labels; } + /** + * @param array $samples + * + * @return mixed + */ + public function predict(array $samples) + { + if(!is_array($samples[0])) { + $predicted = $this->predictSample($samples); + } else { + $predicted = []; + foreach ($samples as $index => $sample) { + $predicted[$index] = $this->predictSample($sample); + } + } + + return $predicted; + } + /** * @param array $sample * * @return mixed */ - public function predict(array $sample) + private function predictSample(array $sample) { $distances = $this->kNeighborsDistances($sample); diff --git a/src/Phpml/Classifier/NaiveBayes.php b/src/Phpml/Classifier/NaiveBayes.php index c1cc902..7324d79 100644 --- a/src/Phpml/Classifier/NaiveBayes.php +++ b/src/Phpml/Classifier/NaiveBayes.php @@ -15,11 +15,11 @@ class NaiveBayes implements Classifier } /** - * @param array $sample + * @param array $samples * * @return mixed */ - public function predict(array $sample) + public function predict(array $samples) { } } diff --git a/tests/Phpml/Classifier/KNearestNeighborsTest.php b/tests/Phpml/Classifier/KNearestNeighborsTest.php index 2786ade..06ae42f 100644 --- a/tests/Phpml/Classifier/KNearestNeighborsTest.php +++ b/tests/Phpml/Classifier/KNearestNeighborsTest.php @@ -26,4 +26,19 @@ class KNearestNeighborsTest extends \PHPUnit_Framework_TestCase $this->assertEquals('a', $classifier->predict([1, 5])); $this->assertEquals('a', $classifier->predict([3, 10])); } + + public function testPredictArrayOfSamples() + { + $trainSamples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; + $trainLabels = ['a', 'a', 'a', 'b', 'b', 'b']; + + $testSamples = [[3, 2], [5, 1], [4, 3], [4, -5], [2, 3], [1, 2], [1, 5], [3, 10]]; + $testLabels = ['b', 'b', 'b', 'b', 'a', 'a', 'a', 'a',]; + + $classifier = new KNearestNeighbors(); + $classifier->train($trainSamples, $trainLabels); + $predicted = $classifier->predict($testSamples); + + $this->assertEquals($testLabels, $predicted); + } }