mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-26 00:28:31 +00:00
One-v-Rest Classification technique applied to linear classifiers (#54)
* One-v-Rest Classification technique applied to linear classifiers * Fix for Apriori * Fixes for One-v-Rest * One-v-Rest test cases
This commit is contained in:
parent
63c63dfba2
commit
01bb82a2a7
@ -16,11 +16,6 @@ class DecisionTree implements Classifier
|
||||
const CONTINUOS = 1;
|
||||
const NOMINAL = 2;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $samples = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
|
@ -5,13 +5,13 @@ declare(strict_types=1);
|
||||
namespace Phpml\Classification\Linear;
|
||||
|
||||
use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\Trainable;
|
||||
use Phpml\Helper\OneVsRest;
|
||||
use Phpml\Classification\WeightedClassifier;
|
||||
use Phpml\Classification\DecisionTree;
|
||||
|
||||
class DecisionStump extends WeightedClassifier
|
||||
{
|
||||
use Trainable, Predictable;
|
||||
use Predictable, OneVsRest;
|
||||
|
||||
const AUTO_SELECT = -1;
|
||||
|
||||
@ -20,6 +20,10 @@ class DecisionStump extends WeightedClassifier
|
||||
*/
|
||||
protected $givenColumnIndex;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $binaryLabels;
|
||||
|
||||
/**
|
||||
* Sample weights : If used the optimization on the decision value
|
||||
@ -57,10 +61,22 @@ class DecisionStump extends WeightedClassifier
|
||||
*/
|
||||
protected $columnTypes;
|
||||
|
||||
/**
|
||||
* @var int
|
||||
*/
|
||||
protected $featureCount;
|
||||
|
||||
/**
|
||||
* @var float
|
||||
*/
|
||||
protected $numSplitCount = 10.0;
|
||||
protected $numSplitCount = 100.0;
|
||||
|
||||
/**
|
||||
* Distribution of samples in the leaves
|
||||
*
|
||||
* @var array
|
||||
*/
|
||||
protected $prob;
|
||||
|
||||
/**
|
||||
* A DecisionStump classifier is a one-level deep DecisionTree. It is generally
|
||||
@ -81,20 +97,15 @@ class DecisionStump extends WeightedClassifier
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
*/
|
||||
public function train(array $samples, array $targets)
|
||||
protected function trainBinary(array $samples, array $targets)
|
||||
{
|
||||
$this->samples = array_merge($this->samples, $samples);
|
||||
$this->targets = array_merge($this->targets, $targets);
|
||||
|
||||
// DecisionStump is capable of classifying between two classes only
|
||||
$labels = array_count_values($this->targets);
|
||||
$this->labels = array_keys($labels);
|
||||
if (count($this->labels) != 2) {
|
||||
throw new \Exception("DecisionStump can classify between two classes only:" . implode(',', $this->labels));
|
||||
}
|
||||
$this->binaryLabels = array_keys(array_count_values($this->targets));
|
||||
$this->featureCount = count($this->samples[0]);
|
||||
|
||||
// If a column index is given, it should be among the existing columns
|
||||
if ($this->givenColumnIndex > count($samples[0]) - 1) {
|
||||
if ($this->givenColumnIndex > count($this->samples[0]) - 1) {
|
||||
$this->givenColumnIndex = self::AUTO_SELECT;
|
||||
}
|
||||
|
||||
@ -106,7 +117,7 @@ class DecisionStump extends WeightedClassifier
|
||||
throw new \Exception("Number of sample weights does not match with number of samples");
|
||||
}
|
||||
} else {
|
||||
$this->weights = array_fill(0, count($samples), 1);
|
||||
$this->weights = array_fill(0, count($this->samples), 1);
|
||||
}
|
||||
|
||||
// Determine type of each column as either "continuous" or "nominal"
|
||||
@ -114,14 +125,15 @@ class DecisionStump extends WeightedClassifier
|
||||
|
||||
// Try to find the best split in the columns of the dataset
|
||||
// by calculating error rate for each split point in each column
|
||||
$columns = range(0, count($samples[0]) - 1);
|
||||
$columns = range(0, count($this->samples[0]) - 1);
|
||||
if ($this->givenColumnIndex != self::AUTO_SELECT) {
|
||||
$columns = [$this->givenColumnIndex];
|
||||
}
|
||||
|
||||
$bestSplit = [
|
||||
'value' => 0, 'operator' => '',
|
||||
'column' => 0, 'trainingErrorRate' => 1.0];
|
||||
'prob' => [], 'column' => 0,
|
||||
'trainingErrorRate' => 1.0];
|
||||
foreach ($columns as $col) {
|
||||
if ($this->columnTypes[$col] == DecisionTree::CONTINUOS) {
|
||||
$split = $this->getBestNumericalSplit($col);
|
||||
@ -164,6 +176,10 @@ class DecisionStump extends WeightedClassifier
|
||||
protected function getBestNumericalSplit(int $col)
|
||||
{
|
||||
$values = array_column($this->samples, $col);
|
||||
// Trying all possible points may be accomplished in two general ways:
|
||||
// 1- Try all values in the $samples array ($values)
|
||||
// 2- Artificially split the range of values into several parts and try them
|
||||
// We choose the second one because it is faster in larger datasets
|
||||
$minValue = min($values);
|
||||
$maxValue = max($values);
|
||||
$stepSize = ($maxValue - $minValue) / $this->numSplitCount;
|
||||
@ -174,19 +190,21 @@ class DecisionStump extends WeightedClassifier
|
||||
// Before trying all possible split points, let's first try
|
||||
// the average value for the cut point
|
||||
$threshold = array_sum($values) / (float) count($values);
|
||||
$errorRate = $this->calculateErrorRate($threshold, $operator, $values);
|
||||
list($errorRate, $prob) = $this->calculateErrorRate($threshold, $operator, $values);
|
||||
if ($split == null || $errorRate < $split['trainingErrorRate']) {
|
||||
$split = ['value' => $threshold, 'operator' => $operator,
|
||||
'column' => $col, 'trainingErrorRate' => $errorRate];
|
||||
'prob' => $prob, 'column' => $col,
|
||||
'trainingErrorRate' => $errorRate];
|
||||
}
|
||||
|
||||
// Try other possible points one by one
|
||||
for ($step = $minValue; $step <= $maxValue; $step+= $stepSize) {
|
||||
$threshold = (float)$step;
|
||||
$errorRate = $this->calculateErrorRate($threshold, $operator, $values);
|
||||
list($errorRate, $prob) = $this->calculateErrorRate($threshold, $operator, $values);
|
||||
if ($errorRate < $split['trainingErrorRate']) {
|
||||
$split = ['value' => $threshold, 'operator' => $operator,
|
||||
'column' => $col, 'trainingErrorRate' => $errorRate];
|
||||
'prob' => $prob, 'column' => $col,
|
||||
'trainingErrorRate' => $errorRate];
|
||||
}
|
||||
}// for
|
||||
}
|
||||
@ -210,11 +228,12 @@ class DecisionStump extends WeightedClassifier
|
||||
|
||||
foreach (['=', '!='] as $operator) {
|
||||
foreach ($distinctVals as $val) {
|
||||
$errorRate = $this->calculateErrorRate($val, $operator, $values);
|
||||
list($errorRate, $prob) = $this->calculateErrorRate($val, $operator, $values);
|
||||
|
||||
if ($split == null || $split['trainingErrorRate'] < $errorRate) {
|
||||
$split = ['value' => $val, 'operator' => $operator,
|
||||
'column' => $col, 'trainingErrorRate' => $errorRate];
|
||||
'prob' => $prob, 'column' => $col,
|
||||
'trainingErrorRate' => $errorRate];
|
||||
}
|
||||
}// for
|
||||
}
|
||||
@ -238,9 +257,9 @@ class DecisionStump extends WeightedClassifier
|
||||
case '>=': return $leftValue >= $rightValue;
|
||||
case '<': return $leftValue < $rightValue;
|
||||
case '<=': return $leftValue <= $rightValue;
|
||||
case '=': return $leftValue == $rightValue;
|
||||
case '=': return $leftValue === $rightValue;
|
||||
case '!=':
|
||||
case '<>': return $leftValue != $rightValue;
|
||||
case '<>': return $leftValue !== $rightValue;
|
||||
}
|
||||
|
||||
return false;
|
||||
@ -253,42 +272,90 @@ class DecisionStump extends WeightedClassifier
|
||||
* @param float $threshold
|
||||
* @param string $operator
|
||||
* @param array $values
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
protected function calculateErrorRate(float $threshold, string $operator, array $values)
|
||||
{
|
||||
$total = (float) array_sum($this->weights);
|
||||
$wrong = 0.0;
|
||||
$leftLabel = $this->labels[0];
|
||||
$rightLabel= $this->labels[1];
|
||||
$prob = [];
|
||||
$leftLabel = $this->binaryLabels[0];
|
||||
$rightLabel= $this->binaryLabels[1];
|
||||
|
||||
foreach ($values as $index => $value) {
|
||||
if ($this->evaluate($threshold, $operator, $value)) {
|
||||
if ($this->evaluate($value, $operator, $threshold)) {
|
||||
$predicted = $leftLabel;
|
||||
} else {
|
||||
$predicted = $rightLabel;
|
||||
}
|
||||
|
||||
if ($predicted != $this->targets[$index]) {
|
||||
$target = $this->targets[$index];
|
||||
if (strval($predicted) != strval($this->targets[$index])) {
|
||||
$wrong += $this->weights[$index];
|
||||
}
|
||||
|
||||
if (! isset($prob[$predicted][$target])) {
|
||||
$prob[$predicted][$target] = 0;
|
||||
}
|
||||
$prob[$predicted][$target]++;
|
||||
}
|
||||
|
||||
// Calculate probabilities: Proportion of labels in each leaf
|
||||
$dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
|
||||
foreach ($prob as $leaf => $counts) {
|
||||
$leafTotal = (float)array_sum($prob[$leaf]);
|
||||
foreach ($counts as $label => $count) {
|
||||
if (strval($leaf) == strval($label)) {
|
||||
$dist[$leaf] = $count / $leafTotal;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return $wrong / $total;
|
||||
return [$wrong / (float) array_sum($this->weights), $dist];
|
||||
}
|
||||
|
||||
/**
 * Returns the probability of the sample of belonging to the given label
 *
 * The sample is first classified with predictSampleBinary(). Only when the
 * predicted label matches $label is the value stored for that label in
 * $this->prob returned; in every other case the probability is 0.0.
 *
 * Probability of a sample is calculated as the proportion of the label
 * within the labels of the training samples in the decision node
 *
 * @param array $sample
 * @param mixed $label
 *
 * @return float
 */
protected function predictProbability(array $sample, $label)
{
    $predicted = $this->predictSampleBinary($sample);

    // Labels are compared through their string form because array keys turn
    // numeric strings into integers. The comparison has to be strict: a loose
    // "==" on two numeric strings compares them as numbers, so distinct
    // labels such as "1" and "1.0" would be treated as the same label.
    if (strval($predicted) === strval($label)) {
        return $this->prob[$label];
    }

    return 0.0;
}
|
||||
|
||||
/**
|
||||
* @param array $sample
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
protected function predictSample(array $sample)
|
||||
protected function predictSampleBinary(array $sample)
|
||||
{
|
||||
if ($this->evaluate($this->value, $this->operator, $sample[$this->column])) {
|
||||
return $this->labels[0];
|
||||
}
|
||||
return $this->labels[1];
|
||||
if ($this->evaluate($sample[$this->column], $this->operator, $this->value)) {
|
||||
return $this->binaryLabels[0];
|
||||
}
|
||||
|
||||
return $this->binaryLabels[1];
|
||||
}
|
||||
|
||||
/**
|
||||
* @return string
|
||||
*/
|
||||
public function __toString()
|
||||
{
|
||||
return "$this->column $this->operator $this->value";
|
||||
return "IF $this->column $this->operator $this->value " .
|
||||
"THEN " . $this->binaryLabels[0] . " ".
|
||||
"ELSE " . $this->binaryLabels[1];
|
||||
}
|
||||
}
|
||||
|
@ -5,12 +5,13 @@ declare(strict_types=1);
|
||||
namespace Phpml\Classification\Linear;
|
||||
|
||||
use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\OneVsRest;
|
||||
use Phpml\Classification\Classifier;
|
||||
use Phpml\Preprocessing\Normalizer;
|
||||
|
||||
class Perceptron implements Classifier
|
||||
{
|
||||
use Predictable;
|
||||
use Predictable, OneVsRest;
|
||||
|
||||
/**
|
||||
* The function whose result will be used to calculate the network error
|
||||
@ -114,7 +115,7 @@ class Perceptron implements Classifier
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
*/
|
||||
public function train(array $samples, array $targets)
|
||||
public function trainBinary(array $samples, array $targets)
|
||||
{
|
||||
$this->labels = array_keys(array_count_values($targets));
|
||||
if (count($this->labels) > 2) {
|
||||
@ -128,7 +129,7 @@ class Perceptron implements Classifier
|
||||
// Set all target values to either -1 or 1
|
||||
$this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
|
||||
foreach ($targets as $target) {
|
||||
$this->targets[] = $target == $this->labels[1] ? 1 : -1;
|
||||
$this->targets[] = strval($target) == strval($this->labels[1]) ? 1 : -1;
|
||||
}
|
||||
|
||||
// Set samples and feature count vars
|
||||
@ -213,6 +214,25 @@ class Perceptron implements Classifier
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
 * Runs the sample through the configured normalizer, if there is one
 *
 * Without a normalizer the sample is handed back untouched.
 *
 * @param array $sample
 *
 * @return array
 */
protected function checkNormalizedSample(array $sample)
{
    if (! $this->normalizer) {
        return $sample;
    }

    // The normalizer works on a list of samples, so the single sample is
    // wrapped, transformed and unwrapped again
    $batch = [$sample];
    $this->normalizer->transform($batch);

    return $batch[0];
}
|
||||
|
||||
/**
|
||||
* Calculates net output of the network as a float value for the given input
|
||||
*
|
||||
@ -244,17 +264,34 @@ class Perceptron implements Classifier
|
||||
return $this->output($sample) > 0 ? 1 : -1;
|
||||
}
|
||||
|
||||
/**
 * Returns the probability of the sample of belonging to the given label.
 *
 * The sample is first classified with predictSampleBinary(). When the
 * predicted label matches $label, the absolute value of the net output for
 * the (optionally normalized) sample is returned; otherwise 0.0.
 *
 * The probability is simply taken as the distance of the sample
 * to the decision plane.
 *
 * @param array $sample
 * @param mixed $label
 *
 * @return float
 */
protected function predictProbability(array $sample, $label)
{
    $predicted = $this->predictSampleBinary($sample);

    // Labels are compared through their string form, and strictly: a loose
    // "==" on two numeric strings compares them as numbers, so distinct
    // labels such as "1" and "1.0" would be treated as the same label.
    if (strval($predicted) === strval($label)) {
        $sample = $this->checkNormalizedSample($sample);

        return abs($this->output($sample));
    }

    return 0.0;
}
|
||||
|
||||
/**
|
||||
* @param array $sample
|
||||
* @return mixed
|
||||
*/
|
||||
protected function predictSample(array $sample)
|
||||
protected function predictSampleBinary(array $sample)
|
||||
{
|
||||
if ($this->normalizer) {
|
||||
$samples = [$sample];
|
||||
$this->normalizer->transform($samples);
|
||||
$sample = $samples[0];
|
||||
}
|
||||
$sample = $this->checkNormalizedSample($sample);
|
||||
|
||||
$predictedClass = $this->outputClass($sample);
|
||||
|
||||
|
126
src/Phpml/Helper/OneVsRest.php
Normal file
126
src/Phpml/Helper/OneVsRest.php
Normal file
@ -0,0 +1,126 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Helper;
|
||||
|
||||
/**
 * One-v-Rest (OvR) wrapper for classifiers that can only separate two classes.
 *
 * With exactly two distinct targets a single binary classifier is trained.
 * Otherwise one binary classifier is trained per label ("label" against
 * "not_label") and a sample is assigned to the label whose classifier
 * reports the highest probability.
 */
trait OneVsRest
{
    /**
     * Samples handed to the most recent multi-class train() call
     *
     * @var array
     */
    protected $samples = [];

    /**
     * Targets handed to the most recent multi-class train() call
     *
     * @var array
     */
    protected $targets = [];

    /**
     * Trained binary classifiers: a single entry at index 0 in the two-label
     * case, otherwise one entry per label, keyed by that label
     *
     * @var array
     */
    protected $classifiers;

    /**
     * Distinct labels found in the targets of the last train() call
     *
     * @var array
     */
    protected $labels;

    /**
     * Train a binary classifier in the OvR style
     *
     * Every call trains from scratch: classifiers of an earlier call are
     * discarded and do not influence the new ones.
     *
     * @param array $samples
     * @param array $targets
     */
    public function train(array $samples, array $targets)
    {
        // Forget classifiers of an earlier run before cloning, so that the
        // template below does not carry them into every new predictor
        $this->classifiers = [];

        // Clone the current classifier, so that
        // we don't mess up its variables while training
        // multiple instances of this classifier
        $classifier = clone $this;

        // trainBinary() implementations append to these arrays. Data left
        // over from an earlier multi-class run would therefore be mixed
        // with the binarized targets of this run.
        $classifier->samples = [];
        $classifier->targets = [];

        // If there are only two targets, then there is no need to perform OvR
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) == 2) {
            $classifier->trainBinary($samples, $targets);
            $this->classifiers[] = $classifier;

            return;
        }

        // Train a separate classifier for each label and memorize them
        $this->samples = $samples;
        $this->targets = $targets;
        foreach ($this->labels as $label) {
            $predictor = clone $classifier;
            $predictor->trainBinary($samples, $this->binarizeTargets($label));
            $this->classifiers[$label] = $predictor;
        }
    }

    /**
     * Groups all targets into two groups: Targets equal to
     * the given label and the others
     *
     * @param mixed $label
     *
     * @return array targets in their original order, each one replaced by
     *               either $label or the string "not_$label"
     */
    private function binarizeTargets($label)
    {
        $targets = [];

        // $label is an array key, so a numeric string target such as "1"
        // arrives here as the integer 1. Both sides are therefore compared
        // in string form, and strictly: a loose "==" would also merge
        // different labels like 10 and "1e1" (and, before PHP 8, 0 and any
        // non-numeric string).
        $wanted = strval($label);
        foreach ($this->targets as $target) {
            $targets[] = strval($target) === $wanted ? $label : "not_$label";
        }

        return $targets;
    }

    /**
     * Predicts the label of a single sample
     *
     * In the two-label case the only classifier decides directly. Otherwise
     * every per-label classifier is asked for a probability and the label
     * with the highest one is returned.
     *
     * @param array $sample
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->labels) == 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        foreach ($this->classifiers as $label => $predictor) {
            $probs[$label] = $predictor->predictProbability($sample, $label);
        }

        // Highest probability first; its key is the winning label
        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets)
     *
     * @param array $samples
     * @param array $targets
     */
    abstract protected function trainBinary(array $samples, array $targets);

    /**
     * Each classifier that makes use of the OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * NOTE(review): $label is declared as string here, but predictSample()
     * passes the array keys of $this->classifiers, which PHP stores as
     * integers for numeric labels. Implementations should accept an untyped
     * $label, as the classifiers shipped with this trait do - confirm before
     * tightening or loosening this declaration.
     *
     * @param array $sample
     * @param string $label
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample()
     *
     * @param array $sample
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);
}
|
60
src/Phpml/Math/Statistic/Gaussian.php
Normal file
60
src/Phpml/Math/Statistic/Gaussian.php
Normal file
@ -0,0 +1,60 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Math\Statistic;
|
||||
|
||||
/**
 * Normal (Gaussian) distribution described by its mean and standard deviation.
 */
class Gaussian
{
    /**
     * @var float
     */
    protected $mean;

    /**
     * @var float
     */
    protected $std;

    /**
     * @param float $mean
     * @param float $std
     */
    public function __construct(float $mean, float $std)
    {
        $this->mean = $mean;
        $this->std = $std;
    }

    /**
     * Returns probability density of the given <i>$value</i>
     *
     * @param float $value
     *
     * @return float
     *
     * @throws \DivisionByZeroError when the squared standard deviation is zero
     */
    public function pdf(float $value)
    {
        // Calculate the probability density by use of normal/Gaussian distribution
        // Ref: https://en.wikipedia.org/wiki/Normal_distribution
        $variance = $this->std ** 2;

        // A zero variance would divide by zero below. Depending on the PHP
        // version that is either a warning plus a NAN result or a bare
        // DivisionByZeroError, so fail the same way everywhere, with a
        // message that names the cause.
        if ($variance == 0.0) {
            throw new \DivisionByZeroError(
                'Gaussian density is undefined for a standard deviation of zero'
            );
        }

        $deviation = $value - $this->mean;

        return exp(- ($deviation ** 2) / (2 * $variance)) / sqrt(2 * $variance * pi());
    }

    /**
     * Returns probability density value of the given <i>$value</i> based on
     * given standard deviation and the mean
     *
     * @param float $mean
     * @param float $std
     * @param float $value
     *
     * @return float
     *
     * @throws \DivisionByZeroError when the squared standard deviation is zero
     */
    public static function distributionPdf(float $mean, float $std, float $value)
    {
        $normal = new self($mean, $std);

        return $normal->pdf($value);
    }
}
|
@ -30,6 +30,21 @@ class AdalineTest extends TestCase
|
||||
$this->assertEquals(1, $classifier->predict([0.1, 0.99]));
|
||||
$this->assertEquals(1, $classifier->predict([1.1, 0.8]));
|
||||
|
||||
// By use of One-v-Rest, Adaline can perform multi-class classification
|
||||
// The samples should be separable by lines perpendicular to the dimensions
|
||||
$samples = [
|
||||
[0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
|
||||
[5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
|
||||
[3, 10],[3, 10],[3, 8], [3, 9] // Third group : cluster at the top-middle
|
||||
];
|
||||
$targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
|
||||
|
||||
$classifier = new Adaline();
|
||||
$classifier->train($samples, $targets);
|
||||
$this->assertEquals(0, $classifier->predict([0.5, 0.5]));
|
||||
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
|
||||
$this->assertEquals(2, $classifier->predict([3.0, 9.5]));
|
||||
|
||||
return $classifier;
|
||||
}
|
||||
|
||||
|
@ -12,8 +12,9 @@ class DecisionStumpTest extends TestCase
|
||||
{
|
||||
public function testPredictSingleSample()
|
||||
{
|
||||
// Samples should be separable with a line perpendicular to any dimension
|
||||
// given in the dataset
|
||||
// Samples should be separable with a line perpendicular
|
||||
// to any dimension given in the dataset
|
||||
//
|
||||
// First: horizontal test
|
||||
$samples = [[0, 0], [1, 0], [0, 1], [1, 1]];
|
||||
$targets = [0, 0, 1, 1];
|
||||
@ -34,6 +35,21 @@ class DecisionStumpTest extends TestCase
|
||||
$this->assertEquals(1, $classifier->predict([1.0, 0.99]));
|
||||
$this->assertEquals(1, $classifier->predict([1.1, 0.1]));
|
||||
|
||||
// By use of One-v-Rest, DecisionStump can perform multi-class classification
|
||||
// The samples should be separable by lines perpendicular to the dimensions
|
||||
$samples = [
|
||||
[0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
|
||||
[5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
|
||||
[3, 10],[3, 10],[3, 8], [3, 9] // Third group : cluster at the top-middle
|
||||
];
|
||||
$targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
|
||||
|
||||
$classifier = new DecisionStump();
|
||||
$classifier->train($samples, $targets);
|
||||
$this->assertEquals(0, $classifier->predict([0.5, 0.5]));
|
||||
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
|
||||
$this->assertEquals(2, $classifier->predict([3.5, 9.5]));
|
||||
|
||||
return $classifier;
|
||||
}
|
||||
|
||||
|
@ -30,6 +30,21 @@ class PerceptronTest extends TestCase
|
||||
$this->assertEquals(1, $classifier->predict([0.1, 0.99]));
|
||||
$this->assertEquals(1, $classifier->predict([1.1, 0.8]));
|
||||
|
||||
// By use of One-v-Rest, Perceptron can perform multi-class classification
|
||||
// The samples should be separable by lines perpendicular to the dimensions
|
||||
$samples = [
|
||||
[0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
|
||||
[5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
|
||||
[3, 10],[3, 10],[3, 8], [3, 9] // Third group : cluster at the top-middle
|
||||
];
|
||||
$targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
|
||||
|
||||
$classifier = new Perceptron();
|
||||
$classifier->train($samples, $targets);
|
||||
$this->assertEquals(0, $classifier->predict([0.5, 0.5]));
|
||||
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
|
||||
$this->assertEquals(2, $classifier->predict([3.0, 9.5]));
|
||||
|
||||
return $classifier;
|
||||
}
|
||||
|
||||
|
28
tests/Phpml/Math/Statistic/GaussianTest.php
Normal file
28
tests/Phpml/Math/Statistic/GaussianTest.php
Normal file
@ -0,0 +1,28 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace test\Phpml\Math\StandardDeviation;
|
||||
|
||||
use Phpml\Math\Statistic\Gaussian;
|
||||
use PHPUnit\Framework\TestCase;
|
||||
|
||||
class GaussianTest extends TestCase
{
    /**
     * Checks the density of the standard normal distribution at a few points,
     * both through an instance and through the static shortcut.
     */
    public function testPdf()
    {
        $mean = 0.0;
        $std = 1.0;
        $gaussian = new Gaussian($mean, $std);

        // Allowable error
        $delta = 0.001;

        // Inputs and the densities expected for them, matched by position
        $inputs = [0, 0.1, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0];
        $expected = [0.3989, 0.3969, 0.3520, 0.2419, 0.1295, 0.0539, 0.0175, 0.0044];

        foreach ($inputs as $position => $input) {
            $density = $expected[$position];

            $this->assertEquals($density, $gaussian->pdf($input), '', $delta);

            $this->assertEquals($density, Gaussian::distributionPdf($mean, $std, $input), '', $delta);
        }
    }
}
|
Loading…
x
Reference in New Issue
Block a user