Check if feature exists when predicting target in NaiveBayes (#327)

* Check if feature exists when predicting target in NaiveBayes

* Fix typo
This commit is contained in:
Arkadiusz Kondas 2018-11-07 09:39:51 +01:00 committed by GitHub
parent 18c36b971f
commit d30c212f3b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 0 deletions

View File

@ -4,6 +4,7 @@ declare(strict_types=1);
namespace Phpml\Classification; namespace Phpml\Classification;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable; use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable; use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean; use Phpml\Math\Statistic\Mean;
@ -137,6 +138,10 @@ class NaiveBayes implements Classifier
*/ */
private function sampleProbability(array $sample, int $feature, string $label): float private function sampleProbability(array $sample, int $feature, string $label): float
{ {
if (!isset($sample[$feature])) {
throw new InvalidArgumentException('Missing feature. All samples must have equal number of features');
}
$value = $sample[$feature]; $value = $sample[$feature];
if ($this->dataType[$label][$feature] == self::NOMINAL) { if ($this->dataType[$label][$feature] == self::NOMINAL) {
if (!isset($this->discreteProb[$label][$feature][$value]) || if (!isset($this->discreteProb[$label][$feature][$value]) ||

View File

@ -5,6 +5,7 @@ declare(strict_types=1);
namespace Phpml\Tests\Classification; namespace Phpml\Tests\Classification;
use Phpml\Classification\NaiveBayes; use Phpml\Classification\NaiveBayes;
use Phpml\Exception\InvalidArgumentException;
use Phpml\ModelManager; use Phpml\ModelManager;
use PHPUnit\Framework\TestCase; use PHPUnit\Framework\TestCase;
@ -125,4 +126,19 @@ class NaiveBayesTest extends TestCase
self::assertEquals($classifier, $restoredClassifier); self::assertEquals($classifier, $restoredClassifier);
self::assertEquals($predicted, $restoredClassifier->predict($testSamples)); self::assertEquals($predicted, $restoredClassifier->predict($testSamples));
} }
/**
 * Predicting must throw an InvalidArgumentException when one of the
 * samples to classify has fewer features than the training samples.
 */
public function testInconsistentFeaturesInSamples(): void
{
    $classifier = new NaiveBayes();
    $classifier->train(
        [[5, 1, 1], [1, 5, 1], [1, 1, 5]],
        ['1996', '1997', '1998']
    );

    // The second sample carries only two features; the training samples have three.
    $samplesWithMissingFeature = [[3, 1, 1], [5, 1], [4, 3, 8]];

    $this->expectException(InvalidArgumentException::class);
    $classifier->predict($samplesWithMissingFeature);
}
} }