mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-22 04:55:10 +00:00
Check if feature exist when predict target in NaiveBayes (#327)
* Check if feature exist when predict target in NaiveBayes * Fix typo
This commit is contained in:
parent
18c36b971f
commit
d30c212f3b
@ -4,6 +4,7 @@ declare(strict_types=1);
|
|||||||
|
|
||||||
namespace Phpml\Classification;
|
namespace Phpml\Classification;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
use Phpml\Helper\Predictable;
|
use Phpml\Helper\Predictable;
|
||||||
use Phpml\Helper\Trainable;
|
use Phpml\Helper\Trainable;
|
||||||
use Phpml\Math\Statistic\Mean;
|
use Phpml\Math\Statistic\Mean;
|
||||||
@ -137,6 +138,10 @@ class NaiveBayes implements Classifier
|
|||||||
*/
|
*/
|
||||||
private function sampleProbability(array $sample, int $feature, string $label): float
|
private function sampleProbability(array $sample, int $feature, string $label): float
|
||||||
{
|
{
|
||||||
|
if (!isset($sample[$feature])) {
|
||||||
|
throw new InvalidArgumentException('Missing feature. All samples must have equal number of features');
|
||||||
|
}
|
||||||
|
|
||||||
$value = $sample[$feature];
|
$value = $sample[$feature];
|
||||||
if ($this->dataType[$label][$feature] == self::NOMINAL) {
|
if ($this->dataType[$label][$feature] == self::NOMINAL) {
|
||||||
if (!isset($this->discreteProb[$label][$feature][$value]) ||
|
if (!isset($this->discreteProb[$label][$feature][$value]) ||
|
||||||
|
@ -5,6 +5,7 @@ declare(strict_types=1);
|
|||||||
namespace Phpml\Tests\Classification;
|
namespace Phpml\Tests\Classification;
|
||||||
|
|
||||||
use Phpml\Classification\NaiveBayes;
|
use Phpml\Classification\NaiveBayes;
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
use Phpml\ModelManager;
|
use Phpml\ModelManager;
|
||||||
use PHPUnit\Framework\TestCase;
|
use PHPUnit\Framework\TestCase;
|
||||||
|
|
||||||
@ -125,4 +126,19 @@ class NaiveBayesTest extends TestCase
|
|||||||
self::assertEquals($classifier, $restoredClassifier);
|
self::assertEquals($classifier, $restoredClassifier);
|
||||||
self::assertEquals($predicted, $restoredClassifier->predict($testSamples));
|
self::assertEquals($predicted, $restoredClassifier->predict($testSamples));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public function testInconsistentFeaturesInSamples(): void
|
||||||
|
{
|
||||||
|
$trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]];
|
||||||
|
$trainLabels = ['1996', '1997', '1998'];
|
||||||
|
|
||||||
|
$testSamples = [[3, 1, 1], [5, 1], [4, 3, 8]];
|
||||||
|
|
||||||
|
$classifier = new NaiveBayes();
|
||||||
|
$classifier->train($trainSamples, $trainLabels);
|
||||||
|
|
||||||
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
|
||||||
|
$classifier->predict($testSamples);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user