diff --git a/src/Phpml/Classification/NaiveBayes.php b/src/Phpml/Classification/NaiveBayes.php
index 38c857d..8f09257 100644
--- a/src/Phpml/Classification/NaiveBayes.php
+++ b/src/Phpml/Classification/NaiveBayes.php
@@ -66,8 +66,9 @@ class NaiveBayes implements Classifier
         $this->sampleCount = count($this->samples);
         $this->featureCount = count($this->samples[0]);
 
-        $labelCounts = array_count_values($this->targets);
-        $this->labels = array_keys($labelCounts);
+        // De-duplicate the targets with a double flip. PHP casts numeric-string array keys
+        // (e.g. '1996') to int along the way, so cast every label back to string afterwards.
+        $this->labels = array_map('strval', array_flip(array_flip($this->targets)));
         foreach ($this->labels as $label) {
             $samples = $this->getSamplesByLabel($label);
             $this->p[$label] = count($samples) / $this->sampleCount;
diff --git a/tests/Phpml/Classification/NaiveBayesTest.php b/tests/Phpml/Classification/NaiveBayesTest.php
index 8312e9c..7db8645 100644
--- a/tests/Phpml/Classification/NaiveBayesTest.php
+++ b/tests/Phpml/Classification/NaiveBayesTest.php
@@ -68,4 +68,65 @@ class NaiveBayesTest extends TestCase
         $this->assertEquals($classifier, $restoredClassifier);
         $this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
     }
+
+    public function testPredictSimpleNumericLabels(): void
+    {
+        $samples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]];
+        $labels = ['1996', '1997', '1998'];
+
+        $classifier = new NaiveBayes();
+        $classifier->train($samples, $labels);
+
+        $this->assertEquals('1996', $classifier->predict([3, 1, 1]));
+        $this->assertEquals('1997', $classifier->predict([1, 4, 1]));
+        $this->assertEquals('1998', $classifier->predict([1, 1, 6]));
+    }
+
+    public function testPredictArrayOfSamplesNumericalLabels(): void
+    {
+        $trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]];
+        $trainLabels = ['1996', '1997', '1998'];
+
+        $testSamples = [[3, 1, 1], [5, 1, 1], [4, 3, 8], [1, 1, 2], [2, 3, 2], [1, 2, 1], [9, 5, 1], [3, 1, 2]];
+        $testLabels = ['1996', '1996', '1998', '1998', '1997', '1997', '1996', '1996'];
+
+        $classifier = new NaiveBayes();
+        $classifier->train($trainSamples, $trainLabels);
+        $predicted = $classifier->predict($testSamples);
+
+        $this->assertEquals($testLabels, $predicted);
+
+        // Feed an extra set of training data.
+        $samples = [[1, 1, 6]];
+        $labels = ['1999'];
+        $classifier->train($samples, $labels);
+
+        $testSamples = [[1, 1, 6], [5, 1, 1]];
+        $testLabels = ['1999', '1996'];
+        $this->assertEquals($testLabels, $classifier->predict($testSamples));
+    }
+
+    public function testSaveAndRestoreNumericLabels(): void
+    {
+        $trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]];
+        $trainLabels = ['1996', '1997', '1998'];
+
+        $testSamples = [[3, 1, 1], [5, 1, 1], [4, 3, 8]];
+        $testLabels = ['1996', '1996', '1998'];
+
+        $classifier = new NaiveBayes();
+        $classifier->train($trainSamples, $trainLabels);
+        $predicted = $classifier->predict($testSamples);
+
+        $filename = 'naive-bayes-test-'.random_int(100, 999).'-'.uniqid();
+        $filepath = tempnam(sys_get_temp_dir(), $filename);
+        $modelManager = new ModelManager();
+        $modelManager->saveToFile($classifier, $filepath);
+
+        $restoredClassifier = $modelManager->restoreFromFile($filepath);
+        // tempnam() created a real file on disk; remove it so test runs do not leak temp files.
+        unlink($filepath);
+        $this->assertEquals($classifier, $restoredClassifier);
+        $this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
+    }
 }