mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-02-03 04:28:33 +00:00
AdaBoost algorithm along with some improvements (#51)
This commit is contained in:
parent
cf222bcce4
commit
4daa0a222a
@ -24,7 +24,7 @@ class DecisionTree implements Classifier
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $columnTypes;
|
||||
protected $columnTypes;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
@ -39,12 +39,12 @@ class DecisionTree implements Classifier
|
||||
/**
|
||||
* @var DecisionTreeLeaf
|
||||
*/
|
||||
private $tree = null;
|
||||
protected $tree = null;
|
||||
|
||||
/**
|
||||
* @var int
|
||||
*/
|
||||
private $maxDepth;
|
||||
protected $maxDepth;
|
||||
|
||||
/**
|
||||
* @var int
|
||||
@ -79,6 +79,7 @@ class DecisionTree implements Classifier
|
||||
{
|
||||
$this->maxDepth = $maxDepth;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
@ -209,6 +210,17 @@ class DecisionTree implements Classifier
|
||||
$split->columnIndex = $i;
|
||||
$split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS;
|
||||
$split->records = $records;
|
||||
|
||||
// If a numeric column is to be selected, then
|
||||
// the original numeric value and the selected operator
|
||||
// will also be saved into the leaf for future access
|
||||
if ($this->columnTypes[$i] == self::CONTINUOS) {
|
||||
$matches = [];
|
||||
preg_match("/^([<>=]{1,2})\s*(.*)/", strval($split->value), $matches);
|
||||
$split->operator = $matches[1];
|
||||
$split->numericValue = floatval($matches[2]);
|
||||
}
|
||||
|
||||
$bestSplit = $split;
|
||||
$bestGiniVal = $gini;
|
||||
}
|
||||
@ -318,15 +330,21 @@ class DecisionTree implements Classifier
|
||||
protected function isCategoricalColumn(array $columnValues)
|
||||
{
|
||||
$count = count($columnValues);
|
||||
|
||||
// There are two main indicators that *may* show whether a
|
||||
// column is composed of discrete set of values:
|
||||
// 1- Column may contain string values
|
||||
// 1- Column may contain string values and not float values
|
||||
// 2- Number of unique values in the column is only a small fraction of
|
||||
// all values in that column (Lower than or equal to %20 of all values)
|
||||
$numericValues = array_filter($columnValues, 'is_numeric');
|
||||
$floatValues = array_filter($columnValues, 'is_float');
|
||||
if ($floatValues) {
|
||||
return false;
|
||||
}
|
||||
if (count($numericValues) != $count) {
|
||||
return true;
|
||||
}
|
||||
|
||||
$distinctValues = array_count_values($columnValues);
|
||||
if (count($distinctValues) <= $count / 5) {
|
||||
return true;
|
||||
@ -357,9 +375,9 @@ class DecisionTree implements Classifier
|
||||
}
|
||||
|
||||
/**
|
||||
* Used to set predefined features to consider while deciding which column to use for a split,
|
||||
* Used to set predefined features to consider while deciding which column to use for a split
|
||||
*
|
||||
* @param array $features
|
||||
* @param array $selectedFeatures
|
||||
*/
|
||||
protected function setSelectedFeatures(array $selectedFeatures)
|
||||
{
|
||||
|
@ -11,6 +11,16 @@ class DecisionTreeLeaf
|
||||
*/
|
||||
public $value;
|
||||
|
||||
/**
|
||||
* @var float
|
||||
*/
|
||||
public $numericValue;
|
||||
|
||||
/**
|
||||
* @var string
|
||||
*/
|
||||
public $operator;
|
||||
|
||||
/**
|
||||
* @var int
|
||||
*/
|
||||
@ -66,13 +76,15 @@ class DecisionTreeLeaf
|
||||
public function evaluate($record)
{
    // Decides whether the given record satisfies the condition stored in this
    // node. The record is indexed by column; the value found at this node's
    // columnIndex is tested and the boolean outcome is returned.
    $recordField = $record[$this->columnIndex];

    if ($this->isContinuous) {
        // The comparison is carried out directly rather than by assembling PHP
        // source for eval(): the record value is caller-supplied data and must
        // never be interpolated into executable code.
        $fieldValue = (float) $recordField;
        $threshold = $this->numericValue;

        switch ($this->operator) {
            case '>':
                return $fieldValue > $threshold;
            case '>=':
                return $fieldValue >= $threshold;
            case '<':
                return $fieldValue < $threshold;
            case '<=':
                return $fieldValue <= $threshold;
            case '=':
            case '==':
                return $fieldValue == $threshold;
            case '<>':
                return $fieldValue != $threshold;
            default:
                // Anything else cannot be evaluated; fail loudly instead of
                // silently sending every record to the same branch.
                throw new \InvalidArgumentException(
                    'Unsupported comparison operator in decision tree leaf: ' . var_export($this->operator, true)
                );
        }
    }

    // Nominal columns: the record matches when its value equals the node value
    // (loose comparison is kept so that "1" and 1 are treated alike).
    return $recordField == $this->value;
}
|
||||
|
||||
|
190
src/Phpml/Classification/Ensemble/AdaBoost.php
Normal file
190
src/Phpml/Classification/Ensemble/AdaBoost.php
Normal file
@ -0,0 +1,190 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Classification\Ensemble;
|
||||
|
||||
use Phpml\Classification\Linear\DecisionStump;
|
||||
use Phpml\Classification\Classifier;
|
||||
use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\Trainable;
|
||||
|
||||
/**
 * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm that combines several
 * 'weak' classifiers into a stronger one. At every iteration the samples that
 * were misclassified gain weight, so that the next weak classifier focuses on
 * them. DecisionStump is the only base classifier supported for now.
 */
class AdaBoost implements Classifier
{
    use Predictable, Trainable;

    /**
     * Maps the internal target values (1 and -1) to the two labels that were
     * found in the targets array given to train()
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of samples seen so far
     *
     * @var int
     */
    protected $sampleCount;

    /**
     * Number of columns in a sample
     *
     * @var int
     */
    protected $featureCount;

    /**
     * Number of maximum iterations to be done
     *
     * @var int
     */
    protected $maxIterations;

    /**
     * Sample weights
     *
     * @var array
     */
    protected $weights = [];

    /**
     * Base classifiers, one per completed iteration
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * Base classifier weights; $alpha[$i] belongs to $classifiers[$i]
     *
     * @var array
     */
    protected $alpha = [];

    /**
     * @param int $maxIterations Upper bound on the number of boosting rounds
     */
    public function __construct(int $maxIterations = 30)
    {
        $this->maxIterations = $maxIterations;
    }

    /**
     * Trains the ensemble. Targets must contain exactly two distinct labels.
     *
     * @param array $samples
     * @param array $targets
     *
     * @throws \Exception if the targets do not contain exactly two classes or
     *                    if the number of samples and targets differ
     */
    public function train(array $samples, array $targets)
    {
        // Initialize usual variables
        $this->labels = array_keys(array_count_values($targets));
        if (count($this->labels) !== 2) {
            throw new \Exception("AdaBoost is a binary classifier and can only classify between two classes");
        }

        if (count($samples) !== count($targets)) {
            throw new \Exception("Number of samples does not match with number of targets");
        }

        // Set all target values to either -1 or 1
        $this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
        // array_count_values() turns numeric-string labels into integer keys, so
        // labels and targets are compared by their string form. A loose "=="
        // would also work for that case, but on PHP 7 it treats 0 and any
        // non-numeric string as equal and would mix up the two classes.
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $target) {
            $this->targets[] = (string) $target === $positiveLabel ? 1 : -1;
        }

        $this->samples = array_merge($this->samples, $samples);
        // reset() gives the first sample regardless of how the array is keyed
        $this->featureCount = count(reset($samples));
        $this->sampleCount = count($this->samples);

        // Initialize AdaBoost parameters
        $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
        $this->classifiers = [];
        $this->alpha = [];

        // Execute the algorithm for a maximum number of iterations
        for ($iteration = 0; $iteration < $this->maxIterations; $iteration++) {
            // Determine the best 'weak' classifier based on current weights
            // and update alpha & weight values at each iteration
            list($classifier, $errorRate) = $this->getBestClassifier();
            if ($classifier === null) {
                // No usable weak classifier could be found with the current
                // weights; further rounds would not find one either.
                break;
            }

            $alpha = $this->calculateAlpha($errorRate);
            $this->updateWeights($classifier, $alpha);

            $this->classifiers[] = $classifier;
            $this->alpha[] = $alpha;
        }
    }

    /**
     * Returns the classifier with the lowest error rate with the
     * consideration of current sample weights
     *
     * @return array Two elements: the best classifier (null when none
     *               qualified) and its error rate
     */
    protected function getBestClassifier()
    {
        // This method works only for "DecisionStump" classifier, for now.
        // As a future task, it will be generalized enough to work with other
        // classifiers as well
        $minErrorRate = 1.0;
        $bestClassifier = null;
        for ($column = 0; $column < $this->featureCount; $column++) {
            $stump = new DecisionStump($column);
            $stump->setSampleWeights($this->weights);
            $stump->train($this->samples, $this->targets);

            $errorRate = $stump->getTrainingErrorRate();
            if ($errorRate === null) {
                // The stump did not report an error rate (it is only computed
                // while optimizing a numeric threshold), so it cannot be ranked.
                continue;
            }

            if ($errorRate < $minErrorRate) {
                $bestClassifier = $stump;
                $minErrorRate = $errorRate;
            }
        }

        return [$bestClassifier, $minErrorRate];
    }

    /**
     * Calculates alpha of a classifier
     *
     * @param float $errorRate
     * @return float
     */
    protected function calculateAlpha(float $errorRate)
    {
        // Keep the ratio inside log() finite and non-zero at both extremes
        if ($errorRate == 0) {
            $errorRate = 1e-10;
        } elseif ($errorRate >= 1) {
            $errorRate = 1 - 1e-10;
        }

        return 0.5 * log((1 - $errorRate) / $errorRate);
    }

    /**
     * Updates the sample weights: samples the classifier got wrong become
     * heavier, samples it got right become lighter
     *
     * @param DecisionStump $classifier
     * @param float $alpha
     */
    protected function updateWeights(DecisionStump $classifier, float $alpha)
    {
        // The new weights are scaled by the sum of the weights as they were
        // before this update.
        $sumOfWeights = array_sum($this->weights);
        $weightsT1 = [];
        foreach ($this->weights as $index => $weight) {
            $desired = $this->targets[$index];
            $output = $classifier->predict($this->samples[$index]);

            // desired * output is 1 for a correct prediction and -1 otherwise
            $weightsT1[] = $weight * exp(-$alpha * $desired * $output) / $sumOfWeights;
        }

        $this->weights = $weightsT1;
    }

    /**
     * Predicts the label of a sample from the alpha-weighted vote of all
     * base classifiers
     *
     * @param array $sample
     * @return mixed
     */
    public function predictSample(array $sample)
    {
        $sum = 0;
        foreach ($this->alpha as $index => $alpha) {
            $vote = $this->classifiers[$index]->predict($sample);
            $sum += $vote * $alpha;
        }

        return $this->labels[$sum > 0 ? 1 : -1];
    }
}
|
@ -8,7 +8,6 @@ use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\Trainable;
|
||||
use Phpml\Classification\Classifier;
|
||||
use Phpml\Classification\Linear\Perceptron;
|
||||
use Phpml\Preprocessing\Normalizer;
|
||||
|
||||
class Adaline extends Perceptron
|
||||
{
|
||||
@ -38,11 +37,6 @@ class Adaline extends Perceptron
|
||||
*/
|
||||
protected $trainingType;
|
||||
|
||||
/**
|
||||
* @var Normalizer
|
||||
*/
|
||||
private $normalizer;
|
||||
|
||||
/**
|
||||
* Initialize an Adaline (ADAptive LInear NEuron) classifier with given learning rate and maximum
|
||||
* number of iterations used while training the classifier <br>
|
||||
@ -58,29 +52,13 @@ class Adaline extends Perceptron
|
||||
public function __construct(float $learningRate = 0.001, int $maxIterations = 1000,
    bool $normalizeInputs = true, int $trainingType = self::BATCH_TRAINING)
{
    // Only the two training modes declared on this class are accepted
    $supportedTypes = [self::BATCH_TRAINING, self::ONLINE_TRAINING];
    $isSupported = in_array($trainingType, $supportedTypes);
    if ($isSupported === false) {
        throw new \Exception("Adaline can only be trained with batch and online/stochastic gradient descent algorithm");
    }

    $this->trainingType = $trainingType;

    // The remaining arguments, including the normalization switch, are
    // handled by the parent constructor
    parent::__construct($learningRate, $maxIterations, $normalizeInputs);
}
|
||||
|
||||
/**
|
||||
@ -100,22 +78,8 @@ class Adaline extends Perceptron
|
||||
while ($this->maxIterations > $currIter++) {
|
||||
$outputs = array_map([$this, 'output'], $this->samples);
|
||||
$updates = array_map([$this, 'gradient'], $this->targets, $outputs);
|
||||
$sum = array_sum($updates);
|
||||
|
||||
// Updates all weights at once
|
||||
for ($i=0; $i <= $this->featureCount; $i++) {
|
||||
if ($i == 0) {
|
||||
$this->weights[0] += $this->learningRate * $sum;
|
||||
} else {
|
||||
$col = array_column($this->samples, $i - 1);
|
||||
$error = 0;
|
||||
foreach ($col as $index => $val) {
|
||||
$error += $val * $updates[$index];
|
||||
}
|
||||
|
||||
$this->weights[$i] += $this->learningRate * $error;
|
||||
}
|
||||
}
|
||||
$this->updateWeights($updates);
|
||||
}
|
||||
}
|
||||
|
||||
@ -132,17 +96,27 @@ class Adaline extends Perceptron
|
||||
}
|
||||
|
||||
/**
 * Applies one update step to every weight at once, using the per-sample
 * update values (direction of the gradient) computed by the caller
 *
 * @param array $updates One value per training sample, in sample order
 */
protected function updateWeights(array $updates)
{
    // The weight at index 0 has no input column: it moves by the plain sum
    // of the updates
    $this->weights[0] += $this->learningRate * array_sum($updates);

    // Weight k (k >= 1) belongs to sample column k - 1
    for ($weightIndex = 1; $weightIndex <= $this->featureCount; $weightIndex++) {
        $column = array_column($this->samples, $weightIndex - 1);

        $error = 0;
        foreach ($column as $sampleIndex => $featureValue) {
            $error += $featureValue * $updates[$sampleIndex];
        }

        $this->weights[$weightIndex] += $this->learningRate * $error;
    }
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\Trainable;
|
||||
use Phpml\Classification\Classifier;
|
||||
use Phpml\Classification\DecisionTree;
|
||||
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
|
||||
|
||||
class DecisionStump extends DecisionTree
|
||||
{
|
||||
@ -19,6 +20,22 @@ class DecisionStump extends DecisionTree
|
||||
protected $columnIndex;
|
||||
|
||||
|
||||
/**
|
||||
* Sample weights : If used the optimization on the decision value
|
||||
* will take these weights into account. If not given, all samples
|
||||
* will be weighed with the same value of 1
|
||||
*
|
||||
* @var array
|
||||
*/
|
||||
protected $weights = null;
|
||||
|
||||
/**
|
||||
* Lowest error rate obtained while training/optimizing the model
|
||||
*
|
||||
* @var float
|
||||
*/
|
||||
protected $trainingErrorRate;
|
||||
|
||||
/**
|
||||
* A DecisionStump classifier is a one-level deep DecisionTree. It is generally
|
||||
* used with ensemble algorithms as in the weak classifier role. <br>
|
||||
@ -42,8 +59,7 @@ class DecisionStump extends DecisionTree
|
||||
*/
|
||||
public function train(array $samples, array $targets)
|
||||
{
|
||||
// Check if a column index was given
|
||||
if ($this->columnIndex >= 0 && $this->columnIndex > count($samples[0]) - 1) {
|
||||
if ($this->columnIndex > count($samples[0]) - 1) {
|
||||
$this->columnIndex = -1;
|
||||
}
|
||||
|
||||
@ -51,6 +67,113 @@ class DecisionStump extends DecisionTree
|
||||
$this->setSelectedFeatures([$this->columnIndex]);
|
||||
}
|
||||
|
||||
if ($this->weights) {
|
||||
$numWeights = count($this->weights);
|
||||
if ($numWeights != count($samples)) {
|
||||
throw new \Exception("Number of sample weights does not match with number of samples");
|
||||
}
|
||||
} else {
|
||||
$this->weights = array_fill(0, count($samples), 1);
|
||||
}
|
||||
|
||||
parent::train($samples, $targets);
|
||||
|
||||
$this->columnIndex = $this->tree->columnIndex;
|
||||
|
||||
// For numerical values, try to optimize the value by finding a different threshold value
|
||||
if ($this->columnTypes[$this->columnIndex] == self::CONTINUOS) {
|
||||
$this->optimizeDecision($samples, $targets);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Stores the per-sample weights to be used by the next call to train().
 * train() expects one weight per sample and throws when the counts differ.
 *
 * @param array $weights
 */
public function setSampleWeights(array $weights)
{
    $this->weights = $weights;
}
|
||||
|
||||
/**
 * Returns the lowest training error rate found while optimizing the
 * threshold: the summed weight of the wrongly predicted samples divided
 * by the summed weight of all samples. The value is only assigned when a
 * numeric threshold was optimized during training.
 *
 * @return float
 */
public function getTrainingErrorRate()
{
    return $this->trainingErrorRate;
}
|
||||
|
||||
/**
 * Tries to optimize the threshold by probing a range of different values
 * between the minimum and maximum values in the selected column, once with
 * the '<=' operator and once with '>'. The combination with the lowest
 * weighted error rate is written back into the root node of the tree and
 * its error rate is kept as the training error rate.
 *
 * @param array $samples
 * @param array $targets
 */
protected function optimizeDecision(array $samples, array $targets)
{
    $values = array_column($samples, $this->columnIndex);
    if (count($values) === 0) {
        // Nothing to probe: the selected column holds no values
        return;
    }

    $minValue = min($values);
    $maxValue = max($values);
    $stepSize = ($maxValue - $minValue) / 100.0;

    // Candidate thresholds are the same for both operators, so they are
    // generated only once. The walk stops as soon as adding the step no longer
    // moves the value forward; without that check a column whose values are
    // all equal (step size 0) would loop forever.
    $thresholds = [];
    $step = $minValue;
    while ($step <= $maxValue) {
        $thresholds[] = (float) $step;

        $next = $step + $stepSize;
        if ($next <= $step) {
            break;
        }
        $step = $next;
    }

    $leftLabel = $this->tree->leftLeaf->classValue;
    $rightLabel = $this->tree->rightLeaf->classValue;

    // Start from the split chosen by the tree itself; it is only replaced
    // by a strictly better candidate
    $bestOperator = $this->tree->operator;
    $bestThreshold = $this->tree->numericValue;
    $bestErrorRate = $this->calculateErrorRate(
        $bestThreshold, $bestOperator, $values, $targets, $leftLabel, $rightLabel);

    foreach (['<=', '>'] as $operator) {
        foreach ($thresholds as $threshold) {
            $errorRate = $this->calculateErrorRate(
                $threshold, $operator, $values, $targets, $leftLabel, $rightLabel);

            if ($errorRate < $bestErrorRate) {
                $bestErrorRate = $errorRate;
                $bestThreshold = $threshold;
                $bestOperator = $operator;
            }
        }
    }

    // Update the tree node value
    $this->tree->numericValue = $bestThreshold;
    $this->tree->operator = $bestOperator;
    $this->tree->value = "$bestOperator $bestThreshold";
    $this->trainingErrorRate = $bestErrorRate;
}
|
||||
|
||||
/**
 * Calculates the weighted ratio of wrong predictions that the given
 * threshold and operator would produce: the weights of all misclassified
 * values are summed up and divided by the sum of all sample weights.
 *
 * A value that satisfies "value <operator> threshold" is predicted as
 * $leftLabel, every other value as $rightLabel.
 *
 * @param float $threshold
 * @param string $operator One of '<', '<=', '>', '>=', '=', '==', '<>'
 * @param array $values
 * @param array $targets
 * @param mixed $leftLabel
 * @param mixed $rightLabel
 *
 * @return float
 *
 * @throws \InvalidArgumentException if the operator is not supported
 * @throws \Exception if the sample weights do not sum to a positive value
 */
protected function calculateErrorRate(float $threshold, string $operator, array $values, array $targets, $leftLabel, $rightLabel)
{
    $total = (float) array_sum($this->weights);
    if ($total <= 0) {
        throw new \Exception("Sum of sample weights should be greater than zero");
    }

    // The comparison is selected once up front instead of building and
    // eval()-ing a code string for every single value
    switch ($operator) {
        case '<':
            $satisfies = function ($value) use ($threshold) {
                return $value < $threshold;
            };
            break;
        case '<=':
            $satisfies = function ($value) use ($threshold) {
                return $value <= $threshold;
            };
            break;
        case '>':
            $satisfies = function ($value) use ($threshold) {
                return $value > $threshold;
            };
            break;
        case '>=':
            $satisfies = function ($value) use ($threshold) {
                return $value >= $threshold;
            };
            break;
        case '=':
        case '==':
            $satisfies = function ($value) use ($threshold) {
                return $value == $threshold;
            };
            break;
        case '<>':
            $satisfies = function ($value) use ($threshold) {
                return $value != $threshold;
            };
            break;
        default:
            throw new \InvalidArgumentException("Unsupported comparison operator: $operator");
    }

    $wrong = 0;
    foreach ($values as $index => $value) {
        $predicted = $satisfies($value) ? $leftLabel : $rightLabel;

        // Loose comparison on purpose: labels may arrive as ints or as
        // numeric strings
        if ($predicted != $targets[$index]) {
            $wrong += $this->weights[$index];
        }
    }

    return $wrong / $total;
}
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ namespace Phpml\Classification\Linear;
|
||||
use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\Trainable;
|
||||
use Phpml\Classification\Classifier;
|
||||
use Phpml\Preprocessing\Normalizer;
|
||||
|
||||
class Perceptron implements Classifier
|
||||
{
|
||||
@ -55,6 +56,11 @@ class Perceptron implements Classifier
|
||||
*/
|
||||
protected $maxIterations;
|
||||
|
||||
/**
|
||||
* @var Normalizer
|
||||
*/
|
||||
protected $normalizer;
|
||||
|
||||
/**
|
||||
* Initialize a perceptron classifier with given learning rate and maximum
|
||||
* number of iterations used while training the perceptron <br>
|
||||
@ -64,7 +70,8 @@ class Perceptron implements Classifier
|
||||
* @param float $learningRate
|
||||
* @param int $maxIterations
|
||||
*/
|
||||
public function __construct(float $learningRate = 0.001, int $maxIterations = 1000)
|
||||
public function __construct(float $learningRate = 0.001, int $maxIterations = 1000,
|
||||
bool $normalizeInputs = true)
|
||||
{
|
||||
if ($learningRate <= 0.0 || $learningRate > 1.0) {
|
||||
throw new \Exception("Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)");
|
||||
@ -74,6 +81,10 @@ class Perceptron implements Classifier
|
||||
throw new \Exception("Maximum number of iterations should be an integer greater than 0");
|
||||
}
|
||||
|
||||
if ($normalizeInputs) {
|
||||
$this->normalizer = new Normalizer(Normalizer::NORM_STD);
|
||||
}
|
||||
|
||||
$this->learningRate = $learningRate;
|
||||
$this->maxIterations = $maxIterations;
|
||||
}
|
||||
@ -89,6 +100,10 @@ class Perceptron implements Classifier
|
||||
throw new \Exception("Perceptron is for only binary (two-class) classification");
|
||||
}
|
||||
|
||||
if ($this->normalizer) {
|
||||
$this->normalizer->transform($samples);
|
||||
}
|
||||
|
||||
// Set all target values to either -1 or 1
|
||||
$this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
|
||||
foreach ($targets as $target) {
|
||||
@ -167,6 +182,12 @@ class Perceptron implements Classifier
|
||||
*/
|
||||
protected function predictSample(array $sample)
|
||||
{
|
||||
if ($this->normalizer) {
|
||||
$samples = [$sample];
|
||||
$this->normalizer->transform($samples);
|
||||
$sample = $samples[0];
|
||||
}
|
||||
|
||||
$predictedClass = $this->outputClass($sample);
|
||||
|
||||
return $this->labels[ $predictedClass ];
|
||||
|
64
tests/Phpml/Classification/Ensemble/AdaBoostTest.php
Normal file
64
tests/Phpml/Classification/Ensemble/AdaBoostTest.php
Normal file
@ -0,0 +1,64 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace tests\Classification\Linear;
|
||||
|
||||
use Phpml\Classification\Ensemble\AdaBoost;
|
||||
use Phpml\ModelManager;
|
||||
use PHPUnit\Framework\TestCase;
|
||||
|
||||
class AdaBoostTest extends TestCase
{
    /**
     * Trains a fresh classifier on three small two-feature problems and checks
     * the predicted class of samples that were not part of the training set.
     */
    public function testPredictSingleSample()
    {
        // AND problem
        $samples = [[0.1, 0.3], [1, 0], [0, 1], [1, 1], [0.9, 0.8], [1.1, 1.1]];
        $targets = [0, 0, 0, 1, 1, 1];
        $classifier = new AdaBoost();
        $classifier->train($samples, $targets);
        $this->assertEquals(0, $classifier->predict([0.1, 0.2]));
        $this->assertEquals(0, $classifier->predict([0.1, 0.99]));
        $this->assertEquals(1, $classifier->predict([1.1, 0.8]));

        // OR problem
        $samples = [[0, 0], [0.1, 0.2], [0.2, 0.1], [1, 0], [0, 1], [1, 1]];
        $targets = [0, 0, 0, 1, 1, 1];
        $classifier = new AdaBoost();
        $classifier->train($samples, $targets);
        $this->assertEquals(0, $classifier->predict([0.1, 0.2]));
        $this->assertEquals(1, $classifier->predict([0.1, 0.99]));
        $this->assertEquals(1, $classifier->predict([1.1, 0.8]));

        // XOR problem
        $samples = [[0.1, 0.2], [1., 1.], [0.9, 0.8], [0., 1.], [1., 0.], [0.2, 0.8]];
        $targets = [0, 0, 0, 1, 1, 1];
        $classifier = new AdaBoost(5);
        $classifier->train($samples, $targets);
        $this->assertEquals(0, $classifier->predict([0.1, 0.1]));
        $this->assertEquals(1, $classifier->predict([0, 0.999]));
        $this->assertEquals(0, $classifier->predict([1.1, 0.8]));

        return $classifier;
    }

    /**
     * A classifier written to disk and read back must equal the original and
     * must give the same predictions.
     */
    public function testSaveAndRestore()
    {
        // Instantiate a new AdaBoost classifier trained for the OR problem
        $samples = [[0, 0], [1, 0], [0, 1], [1, 1]];
        $targets = [0, 1, 1, 1];
        $classifier = new AdaBoost();
        $classifier->train($samples, $targets);
        $testSamples = [[0, 1], [1, 1], [0.2, 0.1]];
        $predicted = $classifier->predict($testSamples);

        // tempnam() already creates a uniquely named file, so a fixed prefix
        // is sufficient
        $filepath = tempnam(sys_get_temp_dir(), 'adaboost-test-');

        try {
            $modelManager = new ModelManager();
            $modelManager->saveToFile($classifier, $filepath);

            $restoredClassifier = $modelManager->restoreFromFile($filepath);
            $this->assertEquals($classifier, $restoredClassifier);
            $this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
        } finally {
            // Do not leave the model file behind, even when an assertion fails
            if (is_string($filepath) && is_file($filepath)) {
                unlink($filepath);
            }
        }
    }
}
|
Loading…
x
Reference in New Issue
Block a user