From d30c212f3bedb802d4d1acaf9eecf2fe18aaf097 Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Wed, 7 Nov 2018 09:39:51 +0100 Subject: [PATCH] Check if feature exist when predict target in NaiveBayes (#327) * Check if feature exist when predict target in NaiveBayes * Fix typo --- src/Classification/NaiveBayes.php | 5 +++++ tests/Classification/NaiveBayesTest.php | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/Classification/NaiveBayes.php b/src/Classification/NaiveBayes.php index f14ada7..45075a4 100644 --- a/src/Classification/NaiveBayes.php +++ b/src/Classification/NaiveBayes.php @@ -4,6 +4,7 @@ declare(strict_types=1); namespace Phpml\Classification; +use Phpml\Exception\InvalidArgumentException; use Phpml\Helper\Predictable; use Phpml\Helper\Trainable; use Phpml\Math\Statistic\Mean; @@ -137,6 +138,10 @@ class NaiveBayes implements Classifier */ private function sampleProbability(array $sample, int $feature, string $label): float { + if (!isset($sample[$feature])) { + throw new InvalidArgumentException('Missing feature. All samples must have equal number of features'); + } + $value = $sample[$feature]; if ($this->dataType[$label][$feature] == self::NOMINAL) { if (!isset($this->discreteProb[$label][$feature][$value]) || diff --git a/tests/Classification/NaiveBayesTest.php b/tests/Classification/NaiveBayesTest.php index 4e27261..076a70d 100644 --- a/tests/Classification/NaiveBayesTest.php +++ b/tests/Classification/NaiveBayesTest.php @@ -5,6 +5,7 @@ declare(strict_types=1); namespace Phpml\Tests\Classification; use Phpml\Classification\NaiveBayes; +use Phpml\Exception\InvalidArgumentException; use Phpml\ModelManager; use PHPUnit\Framework\TestCase; @@ -125,4 +126,19 @@ class NaiveBayesTest extends TestCase self::assertEquals($classifier, $restoredClassifier); self::assertEquals($predicted, $restoredClassifier->predict($testSamples)); } + + public function testInconsistentFeaturesInSamples(): void + { + $trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]]; + $trainLabels = ['1996', '1997', '1998']; + + $testSamples = [[3, 1, 1], [5, 1], [4, 3, 8]]; + + $classifier = new NaiveBayes(); + $classifier->train($trainSamples, $trainLabels); + + $this->expectException(InvalidArgumentException::class); + + $classifier->predict($testSamples); + } }