Apply cs fixes for NaiveBayes

This commit is contained in:
Arkadiusz Kondas 2017-01-17 16:26:43 +01:00
parent e603d60841
commit d19ddb8507

View File

@ -12,24 +12,62 @@ use Phpml\Math\Statistic\StandardDeviation;
class NaiveBayes implements Classifier class NaiveBayes implements Classifier
{ {
use Trainable, Predictable; use Trainable, Predictable;
const CONTINUOS = 1; const CONTINUOS = 1;
const NOMINAL = 2; const NOMINAL = 2;
const EPSILON = 1e-10; const EPSILON = 1e-10;
private $std = array();
private $mean= array(); /**
private $discreteProb = array(); * @var array
private $dataType = array(); */
private $p = array(); private $std = [];
/**
* @var array
*/
private $mean= [];
/**
* @var array
*/
private $discreteProb = [];
/**
* @var array
*/
private $dataType = [];
/**
* @var array
*/
private $p = [];
/**
* @var int
*/
private $sampleCount = 0; private $sampleCount = 0;
/**
* @var int
*/
private $featureCount = 0; private $featureCount = 0;
private $labels = array();
/**
* @var array
*/
private $labels = [];
/**
* @param array $samples
* @param array $targets
*/
public function train(array $samples, array $targets) public function train(array $samples, array $targets)
{ {
$this->samples = $samples; $this->samples = $samples;
$this->targets = $targets; $this->targets = $targets;
$this->sampleCount = count($samples); $this->sampleCount = count($samples);
$this->featureCount = count($samples[0]); $this->featureCount = count($samples[0]);
// Get distinct targets
$this->labels = $targets; $this->labels = $targets;
array_unique($this->labels); array_unique($this->labels);
foreach ($this->labels as $label) { foreach ($this->labels as $label) {
@ -67,7 +105,7 @@ class NaiveBayes implements Classifier
}, $db); }, $db);
} else { } else {
$this->mean[$label][$i] = Mean::arithmetic($values); $this->mean[$label][$i] = Mean::arithmetic($values);
// Add epsilon in order to avoid zero stdev // Add epsilon in order to avoid zero stdev
$this->std[$label][$i] = 1e-10 + StandardDeviation::population($values, false); $this->std[$label][$i] = 1e-10 + StandardDeviation::population($values, false);
} }
} }
@ -75,10 +113,11 @@ class NaiveBayes implements Classifier
/** /**
* Calculates the probability P(label|sample_n) * Calculates the probability P(label|sample_n)
* *
* @param array $sample * @param array $sample
* @param int $feature * @param int $feature
* @param string $label * @param string $label
* @return float
*/ */
private function sampleProbability($sample, $feature, $label) private function sampleProbability($sample, $feature, $label)
{ {
@ -94,14 +133,14 @@ class NaiveBayes implements Classifier
$mean= $this->mean[$label][$feature]; $mean= $this->mean[$label][$feature];
// Calculate the probability density using the normal (Gaussian) distribution // Calculate the probability density using the normal (Gaussian) distribution
// Ref: https://en.wikipedia.org/wiki/Normal_distribution // Ref: https://en.wikipedia.org/wiki/Normal_distribution
// //
// To avoid numerical errors caused by small or zero values, // To avoid numerical errors caused by small or zero values,
// some libraries, such as scikit-learn, work with the logarithm // some libraries, such as scikit-learn, work with the logarithm
// of the density instead; the same approach is used here. // of the density instead; the same approach is used here.
// (See : https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/naive_bayes.py) // (See : https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/naive_bayes.py)
$pdf = -0.5 * log(2.0 * pi() * $std * $std); $pdf = -0.5 * log(2.0 * pi() * $std * $std);
$pdf -= 0.5 * pow($value - $mean, 2) / ($std * $std); $pdf -= 0.5 * pow($value - $mean, 2) / ($std * $std);
return $pdf; return $pdf;
} }
/** /**