diff --git a/src/Phpml/Classifier/NaiveBayes.php b/src/Phpml/Classifier/NaiveBayes.php index 7324d79..05cb120 100644 --- a/src/Phpml/Classifier/NaiveBayes.php +++ b/src/Phpml/Classifier/NaiveBayes.php @@ -6,12 +6,24 @@ namespace Phpml\Classifier; class NaiveBayes implements Classifier { + /** + * @var array + */ + private $samples; + + /** + * @var array + */ + private $labels; + /** * @param array $samples * @param array $labels */ public function train(array $samples, array $labels) { + $this->samples = $samples; + $this->labels = $labels; } /** @@ -21,5 +33,38 @@ class NaiveBayes implements Classifier */ public function predict(array $samples) { + if (!is_array($samples[0])) { + $predicted = $this->predictSample($samples); + } else { + $predicted = []; + foreach ($samples as $index => $sample) { + $predicted[$index] = $this->predictSample($sample); + } + } + + return $predicted; + } + + /** + * @param array $sample + * + * @return mixed + */ + private function predictSample(array $sample) + { + $predictions = []; + foreach ($this->labels as $index => $label) { + $predictions[$label] = 0; + foreach ($sample as $token => $count) { + if (array_key_exists($token, $this->samples[$index])) { + $predictions[$label] += $count * $this->samples[$index][$token]; + } + } + } + + arsort($predictions, SORT_NUMERIC); + reset($predictions); + + return key($predictions); } } diff --git a/tests/Phpml/Classifier/NaiveBayesTest.php b/tests/Phpml/Classifier/NaiveBayesTest.php new file mode 100644 index 0000000..ce52bbc --- /dev/null +++ b/tests/Phpml/Classifier/NaiveBayesTest.php @@ -0,0 +1,38 @@ +train($samples, $labels); + + $this->assertEquals('a', $classifier->predict([3, 1, 1])); + $this->assertEquals('b', $classifier->predict([1, 4, 1])); + $this->assertEquals('c', $classifier->predict([1, 1, 6])); + } + + public function testPredictArrayOfSamples() + { + $trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]]; + $trainLabels = ['a', 'b', 'c']; + + $testSamples = [[3, 1, 1], [5, 1, 1], [4, 3, 8], [1, 1, 2], [2, 3, 2], [1, 2, 1], [9, 5, 1], [3, 1, 2]]; + $testLabels = ['a', 'a', 'c', 'c', 'b', 'b', 'a', 'a']; + + $classifier = new NaiveBayes(); + $classifier->train($trainSamples, $trainLabels); + $predicted = $classifier->predict($testSamples); + + $this->assertEquals($testLabels, $predicted); + } +}