2016-02-09 06:45:07 +00:00
|
|
|
<?php
|
2016-04-04 20:49:54 +00:00
|
|
|
|
2016-11-20 21:53:17 +00:00
|
|
|
declare(strict_types=1);
|
2016-02-09 06:45:07 +00:00
|
|
|
|
2016-04-30 21:45:21 +00:00
|
|
|
namespace Phpml\Classification;
|
2016-02-09 06:45:07 +00:00
|
|
|
|
2016-05-07 21:04:58 +00:00
|
|
|
use Phpml\Helper\Predictable;
|
|
|
|
use Phpml\Helper\Trainable;
|
2017-01-17 15:21:58 +00:00
|
|
|
use Phpml\Math\Statistic\Mean;
|
|
|
|
use Phpml\Math\Statistic\StandardDeviation;
|
2016-04-16 19:24:40 +00:00
|
|
|
|
2016-04-04 20:25:27 +00:00
|
|
|
class NaiveBayes implements Classifier
{
    use Trainable, Predictable;

    // Feature-type markers stored in $dataType. The "CONTINUOS" spelling is
    // part of the public API and is kept as-is for backward compatibility.
    const CONTINUOS = 1;
    const NOMINAL = 2;

    // Floor for zero/unseen probabilities and offset added to every standard
    // deviation, so that no logarithm or division ever receives zero.
    const EPSILON = 1e-10;

    /**
     * Per-label, per-feature population standard deviation (plus EPSILON) of continuous features.
     *
     * @var array
     */
    private $std = [];

    /**
     * Per-label, per-feature arithmetic mean of continuous features.
     *
     * @var array
     */
    private $mean = [];

    /**
     * Per-label, per-feature relative frequency of each observed value of nominal features.
     *
     * @var array
     */
    private $discreteProb = [];

    /**
     * Per-label, per-feature type marker: self::CONTINUOS or self::NOMINAL.
     *
     * @var array
     */
    private $dataType = [];

    /**
     * Prior P(label): the share of training samples carrying each label.
     *
     * @var array
     */
    private $p = [];

    /**
     * @var int
     */
    private $sampleCount = 0;

    /**
     * @var int
     */
    private $featureCount = 0;

    /**
     * Distinct labels seen during training, in order of first appearance.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Fits the model: stores the training data, then computes each label's
     * prior and per-feature statistics. Statistics from any previous call are
     * discarded.
     *
     * @param array $samples list of samples, each a list of feature values; the
     *                       feature count is taken from the first sample
     * @param array $targets label of each sample, index-aligned with $samples
     *
     * @throws \InvalidArgumentException if $samples is empty
     */
    public function train(array $samples, array $targets)
    {
        if ($samples === []) {
            throw new \InvalidArgumentException('NaiveBayes requires at least one training sample');
        }

        // Forget statistics of labels from an earlier training run.
        $this->std = [];
        $this->mean = [];
        $this->discreteProb = [];
        $this->dataType = [];
        $this->p = [];

        // Re-index so positional access ($samples[0], $this->samples[$i]) is valid.
        $this->samples = array_values($samples);
        $this->targets = array_values($targets);
        $this->sampleCount = count($this->samples);
        $this->featureCount = count($this->samples[0]);

        // array_unique() returns a new array; its result must be kept, otherwise
        // every label would be processed once per sample.
        $this->labels = array_values(array_unique($this->targets));

        foreach ($this->labels as $label) {
            $labelSamples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Calculates and caches the statistics of every feature for one label:
     * value frequencies for nominal columns, mean and standard deviation for
     * numeric ones.
     *
     * @param mixed $label
     * @param array $samples the training samples carrying $label
     */
    private function calculateStatistics($label, array $samples)
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; $i++) {
            // Values of the i-th column across this label's samples.
            $values = array_column($samples, $i);
            $numValues = count($values);

            // Any non-numeric value makes the whole column nominal (categorical) for this label.
            if ($values !== array_filter($values, 'is_numeric')) {
                $this->dataType[$label][$i] = self::NOMINAL;
                // array_map() with a single array preserves the value keys.
                $this->discreteProb[$label][$i] = array_map(
                    function ($count) use ($numValues) {
                        return $count / $numValues;
                    },
                    array_count_values($values)
                );
                continue;
            }

            $this->mean[$label][$i] = Mean::arithmetic($values);
            // EPSILON keeps the deviation strictly positive so the Gaussian density is always defined.
            $this->std[$label][$i] = self::EPSILON + StandardDeviation::population($values, false);
        }
    }

    /**
     * Returns log P(feature value | label) for one feature of a sample.
     *
     * Nominal features use the observed relative frequency, floored at EPSILON
     * for unseen values. Numeric features use the log of the normal density
     * with this label's mean and standard deviation.
     *
     * @param array $sample
     * @param int   $feature
     * @param mixed $label
     *
     * @return float
     */
    private function sampleLogProbability(array $sample, int $feature, $label): float
    {
        $value = $sample[$feature];

        if ($this->dataType[$label][$feature] === self::NOMINAL) {
            $prob = $this->discreteProb[$label][$feature][$value] ?? 0;

            return log(max($prob, self::EPSILON));
        }

        // A non-numeric value cannot be scored under this label's Gaussian; treat it as (nearly) impossible.
        if (!is_numeric($value)) {
            return log(self::EPSILON);
        }

        $variance = $this->std[$label][$feature] ** 2;
        $mean = $this->mean[$label][$feature];

        // Log of the normal/Gaussian density (https://en.wikipedia.org/wiki/Normal_distribution).
        // Working in log space avoids underflow when many small densities are combined,
        // as scikit-learn does (https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/naive_bayes.py).
        return -0.5 * log(2.0 * M_PI * $variance) - 0.5 * (($value - $mean) ** 2) / $variance;
    }

    /**
     * Returns the training samples whose target equals $label.
     *
     * @param mixed $label
     *
     * @return array
     */
    private function getSamplesByLabel($label): array
    {
        $samples = [];
        for ($i = 0; $i < $this->sampleCount; $i++) {
            // Loose comparison so int 1 and string '1' count as the same label,
            // as array_unique() also treats them.
            if ($this->targets[$i] == $label) {
                $samples[] = $this->samples[$i];
            }
        }

        return $samples;
    }

    /**
     * Predicts the label of a single sample, or of each sample when given a
     * list of samples.
     *
     * @param array $sample one sample (list of feature values) or a list of samples
     *
     * @return mixed a label, or a list of labels for a list of samples
     */
    protected function predictSample(array $sample)
    {
        $isArray = isset($sample[0]) && is_array($sample[0]);
        $samples = $isArray ? $sample : [$sample];

        $samplePredictions = [];
        foreach ($samples as $single) {
            $samplePredictions[] = $this->mostLikelyLabel($single);
        }

        return $isArray ? $samplePredictions : $samplePredictions[0];
    }

    /**
     * Applies the Naive Bayes assumption in log space:
     *   log P(label|features) ~ log P(label) + sum_i log P(feature_i|label)
     * and returns the label with the highest score (the first one on ties),
     * or null when no label has been trained.
     *
     * @param array $sample
     *
     * @return mixed
     */
    private function mostLikelyLabel(array $sample)
    {
        $bestLabel = null;
        $bestScore = null;

        foreach ($this->labels as $label) {
            $score = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; $i++) {
                $score += $this->sampleLogProbability($sample, $i, $label);
            }

            // Strict ">" keeps the earliest label when scores tie.
            if ($bestScore === null || $score > $bestScore) {
                $bestScore = $score;
                // Return the label value itself, not an array key (which would coerce '1' to 1).
                $bestLabel = $label;
            }
        }

        return $bestLabel;
    }
}
|