AdaBoost improvements (#53)

* AdaBoost improvements

* AdaBoost improvements & test case resolved

* Some coding style fixes
Mustafa Karabulut 2017-02-28 23:45:18 +03:00 committed by Arkadiusz Kondas
parent e8c6005aec
commit c028a73985
7 changed files with 385 additions and 99 deletions


@@ -110,12 +110,13 @@ class DecisionTree implements Classifier
}
}
protected function getColumnTypes(array $samples)
public static function getColumnTypes(array $samples)
{
$types = [];
for ($i=0; $i<$this->featureCount; $i++) {
$featureCount = count($samples[0]);
for ($i=0; $i < $featureCount; $i++) {
$values = array_column($samples, $i);
$isCategorical = $this->isCategoricalColumn($values);
$isCategorical = self::isCategoricalColumn($values);
$types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS;
}
return $types;
@@ -327,13 +328,13 @@ class DecisionTree implements Classifier
* @param array $columnValues
* @return bool
*/
protected function isCategoricalColumn(array $columnValues)
protected static function isCategoricalColumn(array $columnValues)
{
$count = count($columnValues);
// There are two main indicators that *may* show whether a
// column is composed of a discrete set of values:
// 1- Column may contain string values and not float values
// 1- Column may contain string values and non-float values
// 2- Number of unique values in the column is only a small fraction of
// all values in that column (lower than or equal to 20% of all values)
$numericValues = array_filter($columnValues, 'is_numeric');
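
With getColumnTypes() now public static, code outside DecisionTree (the reworked DecisionStump below, for instance) can reuse the column-type detection. A minimal usage sketch with illustrative data:

<?php
use Phpml\Classification\DecisionTree;

// Column 0 is numeric, column 1 holds nominal string values (illustrative data)
$samples = [
    [1.0, 'sunny'],
    [2.5, 'rainy'],
    [3.1, 'sunny'],
];

// Each entry of the result is either DecisionTree::NOMINAL or DecisionTree::CONTINUOS
$types = DecisionTree::getColumnTypes($samples);
var_dump($types);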


@@ -5,6 +5,9 @@ declare(strict_types=1);
namespace Phpml\Classification\Ensemble;
use Phpml\Classification\Linear\DecisionStump;
use Phpml\Classification\WeightedClassifier;
use Phpml\Math\Statistic\Mean;
use Phpml\Math\Statistic\StandardDeviation;
use Phpml\Classification\Classifier;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
@@ -44,7 +47,7 @@ class AdaBoost implements Classifier
protected $weights = [];
/**
* Base classifiers
* List of selected 'weak' classifiers
*
* @var array
*/
@@ -57,17 +60,39 @@ class AdaBoost implements Classifier
*/
protected $alpha = [];
/**
* @var string
*/
protected $baseClassifier = DecisionStump::class;
/**
* @var array
*/
protected $classifierOptions = [];
/**
* ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
* improve the classification performance of 'weak' classifiers such as
* DecisionStump (the default base classifier of AdaBoost).
*
*/
public function __construct(int $maxIterations = 30)
public function __construct(int $maxIterations = 50)
{
$this->maxIterations = $maxIterations;
}
/**
* Sets the base classifier that will be used for boosting (default = DecisionStump)
*
* @param string $baseClassifier
* @param array $classifierOptions
*/
public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = [])
{
$this->baseClassifier = $baseClassifier;
$this->classifierOptions = $classifierOptions;
}
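
A usage sketch for the new base-classifier hook; the data is illustrative, and a non-empty option array would be forwarded to the classifier's constructor via newInstanceArgs:

<?php
use Phpml\Classification\Ensemble\AdaBoost;
use Phpml\Classification\Linear\DecisionStump;

$samples = [[0, 0], [0, 1], [1, 0], [1, 1]];
$targets = [0, 0, 0, 1];                    // AND problem, two classes

$classifier = new AdaBoost(30);             // up to 30 boosting iterations
// DecisionStump is already the default; the call only illustrates the hook.
// An empty option array lets the stump auto-select its split column.
$classifier->setBaseClassifier(DecisionStump::class, []);

$classifier->train($samples, $targets);
echo $classifier->predict([1, 1]);          // label predicted by the boosted ensemble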
/**
* @param array $samples
* @param array $targets
@@ -77,7 +102,7 @@ class AdaBoost implements Classifier
// Initialize usual variables
$this->labels = array_keys(array_count_values($targets));
if (count($this->labels) != 2) {
throw new \Exception("AdaBoost is a binary classifier and can only classify between two classes");
throw new \Exception("AdaBoost is a binary classifier and can classify between two classes only");
}
// Set all target values to either -1 or 1
@@ -98,9 +123,12 @@ class AdaBoost implements Classifier
// Execute the algorithm for a maximum number of iterations
$currIter = 0;
while ($this->maxIterations > $currIter++) {
// Determine the best 'weak' classifier based on current weights
// and update alpha & weight values at each iteration
list($classifier, $errorRate) = $this->getBestClassifier();
$classifier = $this->getBestClassifier();
$errorRate = $this->evaluateClassifier($classifier);
// Update alpha & weight values at each iteration
$alpha = $this->calculateAlpha($errorRate);
$this->updateWeights($classifier, $alpha);
@@ -117,24 +145,71 @@ class AdaBoost implements Classifier
*/
protected function getBestClassifier()
{
// This method works only for "DecisionStump" classifier, for now.
// As a future task, it will be generalized enough to work with other
// classifiers as well
$minErrorRate = 1.0;
$bestClassifier = null;
for ($i=0; $i < $this->featureCount; $i++) {
$stump = new DecisionStump($i);
$stump->setSampleWeights($this->weights);
$stump->train($this->samples, $this->targets);
$ref = new \ReflectionClass($this->baseClassifier);
if ($this->classifierOptions) {
$classifier = $ref->newInstanceArgs($this->classifierOptions);
} else {
$classifier = $ref->newInstance();
}
$errorRate = $stump->getTrainingErrorRate();
if ($errorRate < $minErrorRate) {
$bestClassifier = $stump;
$minErrorRate = $errorRate;
if (is_subclass_of($classifier, WeightedClassifier::class)) {
$classifier->setSampleWeights($this->weights);
$classifier->train($this->samples, $this->targets);
} else {
list($samples, $targets) = $this->resample();
$classifier->train($samples, $targets);
}
return $classifier;
}
/**
* Resamples the dataset in accordance with the weights and
* returns the new dataset
*
* @return array
*/
protected function resample()
{
$weights = $this->weights;
$std = StandardDeviation::population($weights);
$mean= Mean::arithmetic($weights);
$min = min($weights);
$minZ= (int)round(($min - $mean) / $std);
$samples = [];
$targets = [];
foreach ($weights as $index => $weight) {
$z = (int)round(($weight - $mean) / $std) - $minZ + 1;
for ($i=0; $i < $z; $i++) {
if (rand(0, 1) == 0) {
continue;
}
$samples[] = $this->samples[$index];
$targets[] = $this->targets[$index];
}
}
return [$bestClassifier, $minErrorRate];
return [$samples, $targets];
}
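
The z-score duplication in resample() is easier to follow on a toy weight vector; a sketch of just the draw-count computation, reusing the same statistics helpers with illustrative values:

<?php
use Phpml\Math\Statistic\Mean;
use Phpml\Math\Statistic\StandardDeviation;

$weights = [0.05, 0.10, 0.10, 0.75];        // illustrative sample weights

$std  = StandardDeviation::population($weights);
$mean = Mean::arithmetic($weights);
$minZ = (int) round((min($weights) - $mean) / $std);

foreach ($weights as $index => $weight) {
    // Heavier samples get more draws; resample() then keeps each draw
    // with 50% probability (the rand(0, 1) check above)
    $draws = (int) round(($weight - $mean) / $std) - $minZ + 1;
    echo "sample #$index: up to $draws draws\n";
}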
/**
* Evaluates the classifier and returns the classification error rate
*
* @param Classifier $classifier
*/
protected function evaluateClassifier(Classifier $classifier)
{
$total = (float) array_sum($this->weights);
$wrong = 0;
foreach ($this->samples as $index => $sample) {
$predicted = $classifier->predict($sample);
if ($predicted != $this->targets[$index]) {
$wrong += $this->weights[$index];
}
}
return $wrong / $total;
}
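
The weighted error rate favours classifiers that get the currently heavily weighted samples right. A small worked example with illustrative numbers:

<?php
// Three samples with weights 1, 1 and 2: only the heavily weighted sample
// is misclassified, so the weighted error is 2 / 4 = 0.5 even though
// just one of three predictions is wrong.
$weights     = [1, 1, 2];
$targets     = [1, -1, 1];
$predictions = [1, -1, -1];

$wrong = 0;
foreach ($predictions as $index => $predicted) {
    if ($predicted != $targets[$index]) {
        $wrong += $weights[$index];
    }
}
echo $wrong / array_sum($weights);          // 0.5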
/**
@@ -154,10 +229,10 @@ class AdaBoost implements Classifier
/**
* Updates the sample weights
*
* @param DecisionStump $classifier
* @param Classifier $classifier
* @param float $alpha
*/
protected function updateWeights(DecisionStump $classifier, float $alpha)
protected function updateWeights(Classifier $classifier, float $alpha)
{
$sumOfWeights = array_sum($this->weights);
$weightsT1 = [];
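
For reference, the textbook AdaBoost update that calculateAlpha() and updateWeights() are expected to follow; a sketch of the standard formulas, not a verbatim copy of the library's code:

<?php
// Textbook AdaBoost update for binary targets encoded as -1 / +1
function adaBoostAlpha(float $errorRate): float
{
    // Low-error classifiers get a large vote in the final ensemble
    return 0.5 * log((1 - $errorRate) / $errorRate);
}

function adaBoostReweight(array $weights, array $targets, array $predictions, float $alpha): array
{
    foreach ($weights as $index => $weight) {
        // Misclassified samples (target * prediction = -1) gain weight
        $weights[$index] = $weight * exp(-$alpha * $targets[$index] * $predictions[$index]);
    }

    // Normalize so the weights sum to 1 again
    $sum = array_sum($weights);
    return array_map(function ($weight) use ($sum) {
        return $weight / $sum;
    }, $weights);
}

$alpha = adaBoostAlpha(0.25);   // 0.5 * log(3), about 0.55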


@@ -76,10 +76,16 @@ class Adaline extends Perceptron
// Batch learning is executed:
$currIter = 0;
while ($this->maxIterations > $currIter++) {
$weights = $this->weights;
$outputs = array_map([$this, 'output'], $this->samples);
$updates = array_map([$this, 'gradient'], $this->targets, $outputs);
$this->updateWeights($updates);
if ($this->earlyStop($weights)) {
break;
}
}
}
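
A minimal usage sketch for batch-trained Adaline; the constructor arguments (learning rate, maximum number of iterations) are assumed from the rest of the library and may differ:

<?php
use Phpml\Classification\Linear\Adaline;

$samples = [[0, 0], [0, 1], [1, 0], [1, 1]];
$targets = [0, 1, 1, 1];                     // OR problem

// Assumed constructor arguments: learning rate and maximum number of iterations
$classifier = new Adaline(0.001, 1000);
$classifier->train($samples, $targets);

// With the early-stop check above, batch training can finish well before
// the iteration limit once the weights stop changing
echo $classifier->predict([0, 1]);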


@@ -6,18 +6,19 @@ namespace Phpml\Classification\Linear;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Classification\Classifier;
use Phpml\Classification\WeightedClassifier;
use Phpml\Classification\DecisionTree;
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
class DecisionStump extends DecisionTree
class DecisionStump extends WeightedClassifier
{
use Trainable, Predictable;
const AUTO_SELECT = -1;
/**
* @var int
*/
protected $columnIndex;
protected $givenColumnIndex;
/**
@@ -36,6 +37,31 @@ class DecisionStump extends DecisionTree
*/
protected $trainingErrorRate;
/**
* @var int
*/
protected $column;
/**
* @var mixed
*/
protected $value;
/**
* @var string
*/
protected $operator;
/**
* @var array
*/
protected $columnTypes;
/**
* @var float
*/
protected $numSplitCount = 10.0;
/**
* A DecisionStump classifier is a one-level-deep DecisionTree. It is generally
* used with ensemble algorithms in the role of a 'weak' classifier. <br>
@@ -46,11 +72,9 @@ class DecisionStump extends DecisionTree
*
* @param int $columnIndex
*/
public function __construct(int $columnIndex = -1)
public function __construct(int $columnIndex = self::AUTO_SELECT)
{
$this->columnIndex = $columnIndex;
parent::__construct(1);
$this->givenColumnIndex = $columnIndex;
}
/**
@@ -59,95 +83,167 @@ class DecisionStump extends DecisionTree
*/
public function train(array $samples, array $targets)
{
if ($this->columnIndex > count($samples[0]) - 1) {
$this->columnIndex = -1;
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);
// DecisionStump is capable of classifying between two classes only
$labels = array_count_values($this->targets);
$this->labels = array_keys($labels);
if (count($this->labels) != 2) {
throw new \Exception("DecisionStump can classify between two classes only: " . implode(', ', $this->labels));
}
if ($this->columnIndex >= 0) {
$this->setSelectedFeatures([$this->columnIndex]);
// If a column index is given, it should be among the existing columns
if ($this->givenColumnIndex > count($samples[0]) - 1) {
$this->givenColumnIndex = self::AUTO_SELECT;
}
// Check the size of the weights given.
// If none given, then assign 1 as a weight to each sample
if ($this->weights) {
$numWeights = count($this->weights);
if ($numWeights != count($samples)) {
if ($numWeights != count($this->samples)) {
throw new \Exception("Number of sample weights does not match the number of samples");
}
} else {
$this->weights = array_fill(0, count($samples), 1);
}
parent::train($samples, $targets);
// Determine type of each column as either "continuous" or "nominal"
$this->columnTypes = DecisionTree::getColumnTypes($this->samples);
$this->columnIndex = $this->tree->columnIndex;
// Try to find the best split in the columns of the dataset
// by calculating error rate for each split point in each column
$columns = range(0, count($samples[0]) - 1);
if ($this->givenColumnIndex != self::AUTO_SELECT) {
$columns = [$this->givenColumnIndex];
}
// For numerical values, try to optimize the value by finding a different threshold value
if ($this->columnTypes[$this->columnIndex] == self::CONTINUOS) {
$this->optimizeDecision($samples, $targets);
$bestSplit = [
'value' => 0, 'operator' => '',
'column' => 0, 'trainingErrorRate' => 1.0];
foreach ($columns as $col) {
if ($this->columnTypes[$col] == DecisionTree::CONTINUOS) {
$split = $this->getBestNumericalSplit($col);
} else {
$split = $this->getBestNominalSplit($col);
}
if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
$bestSplit = $split;
}
}
// Assign determined best values to the stump
foreach ($bestSplit as $name => $value) {
$this->{$name} = $value;
}
}
/**
* Used to set sample weights.
* While finding the best split point for a numerical column,
* DecisionStump probes equally spaced values between the minimum and maximum
* values in the column. The given <i>$count</i> value determines how many split
* points will be probed. More split points give a better chance of finding a
* good split at the cost of longer processing time (default value is 10.0)
*
* @param array $weights
* @param float $count
*/
public function setSampleWeights(array $weights)
public function setNumericalSplitCount(float $count)
{
$this->weights = $weights;
$this->numSplitCount = $count;
}
/**
* Returns the training error rate, the proportion of wrong predictions
* over the total number of samples
* Determines best split point for the given column
*
* @return float
*/
public function getTrainingErrorRate()
{
return $this->trainingErrorRate;
}
/**
* Tries to optimize the threshold by probing a range of different values
* between the minimum and maximum values in the selected column
* @param int $col
*
* @param array $samples
* @param array $targets
* @return array
*/
protected function optimizeDecision(array $samples, array $targets)
protected function getBestNumericalSplit(int $col)
{
$values = array_column($samples, $this->columnIndex);
$values = array_column($this->samples, $col);
$minValue = min($values);
$maxValue = max($values);
$stepSize = ($maxValue - $minValue) / 100.0;
$stepSize = ($maxValue - $minValue) / $this->numSplitCount;
$leftLabel = $this->tree->leftLeaf->classValue;
$rightLabel= $this->tree->rightLeaf->classValue;
$bestOperator = $this->tree->operator;
$bestThreshold = $this->tree->numericValue;
$bestErrorRate = $this->calculateErrorRate(
$bestThreshold, $bestOperator, $values, $targets, $leftLabel, $rightLabel);
$split = null;
foreach (['<=', '>'] as $operator) {
// Before trying all possible split points, let's first try
// the average value for the cut point
$threshold = array_sum($values) / (float) count($values);
$errorRate = $this->calculateErrorRate($threshold, $operator, $values);
if ($split == null || $errorRate < $split['trainingErrorRate']) {
$split = ['value' => $threshold, 'operator' => $operator,
'column' => $col, 'trainingErrorRate' => $errorRate];
}
// Try other possible points one by one
for ($step = $minValue; $step <= $maxValue; $step+= $stepSize) {
$threshold = (float)$step;
$errorRate = $this->calculateErrorRate(
$threshold, $operator, $values, $targets, $leftLabel, $rightLabel);
if ($errorRate < $bestErrorRate) {
$bestErrorRate = $errorRate;
$bestThreshold = $threshold;
$bestOperator = $operator;
$errorRate = $this->calculateErrorRate($threshold, $operator, $values);
if ($errorRate < $split['trainingErrorRate']) {
$split = ['value' => $threshold, 'operator' => $operator,
'column' => $col, 'trainingErrorRate' => $errorRate];
}
}// for
}
// Update the tree node value
$this->tree->numericValue = $bestThreshold;
$this->tree->operator = $bestOperator;
$this->tree->value = "$bestOperator $bestThreshold";
$this->trainingErrorRate = $bestErrorRate;
return $split;
}
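
A quick illustration of the candidate thresholds probed for a numerical column: with values spanning 0 to 10 and the default split count of 10, the step size is 1.0, so the column mean plus eleven evenly spaced points are tried, each with both '<=' and '>' (illustrative values only):

<?php
$values        = [0.0, 2.0, 4.5, 7.0, 10.0]; // illustrative column values
$numSplitCount = 10.0;                       // default in DecisionStump

$stepSize = (max($values) - min($values)) / $numSplitCount;   // 1.0 here

$candidates = [array_sum($values) / count($values)];          // the mean is tried first
for ($threshold = min($values); $threshold <= max($values); $threshold += $stepSize) {
    $candidates[] = $threshold;
}

// Every candidate is evaluated with both '<=' and '>' and the split
// with the lowest weighted error rate wins
print_r($candidates);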
/**
*
* @param int $col
*
* @return array
*/
protected function getBestNominalSplit(int $col)
{
$values = array_column($this->samples, $col);
$valueCounts = array_count_values($values);
$distinctVals= array_keys($valueCounts);
$split = null;
foreach (['=', '!='] as $operator) {
foreach ($distinctVals as $val) {
$errorRate = $this->calculateErrorRate($val, $operator, $values);
if ($split == null || $errorRate < $split['trainingErrorRate']) {
$split = ['value' => $val, 'operator' => $operator,
'column' => $col, 'trainingErrorRate' => $errorRate];
}
}// for
}
return $split;
}
/**
*
* @param mixed $leftValue
* @param string $operator
* @param mixed $rightValue
*
* @return boolean
*/
protected function evaluate($leftValue, $operator, $rightValue)
{
switch ($operator) {
case '>': return $leftValue > $rightValue;
case '>=': return $leftValue >= $rightValue;
case '<': return $leftValue < $rightValue;
case '<=': return $leftValue <= $rightValue;
case '=': return $leftValue == $rightValue;
case '!=':
case '<>': return $leftValue != $rightValue;
}
return false;
}
/**
@@ -157,23 +253,42 @@ class DecisionStump extends DecisionTree
* @param float $threshold
* @param string $operator
* @param array $values
* @param array $targets
* @param mixed $leftLabel
* @param mixed $rightLabel
*/
protected function calculateErrorRate(float $threshold, string $operator, array $values, array $targets, $leftLabel, $rightLabel)
protected function calculateErrorRate(float $threshold, string $operator, array $values)
{
$total = (float) array_sum($this->weights);
$wrong = 0;
$wrong = 0.0;
$leftLabel = $this->labels[0];
$rightLabel= $this->labels[1];
foreach ($values as $index => $value) {
eval("\$predicted = \$value $operator \$threshold ? \$leftLabel : \$rightLabel;");
if ($this->evaluate($threshold, $operator, $value)) {
$predicted = $leftLabel;
} else {
$predicted = $rightLabel;
}
if ($predicted != $targets[$index]) {
if ($predicted != $this->targets[$index]) {
$wrong += $this->weights[$index];
}
}
return $wrong / $total;
}
/**
* @param array $sample
* @return mixed
*/
protected function predictSample(array $sample)
{
if ($this->evaluate($this->value, $this->operator, $sample[$this->column])) {
return $this->labels[0];
}
return $this->labels[1];
}
public function __toString()
{
return "$this->column $this->operator $this->value";
}
}
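
A usage sketch that ties the reworked stump together: column auto-selection, split-count tuning and the new __toString() summary (data is illustrative):

<?php
use Phpml\Classification\Linear\DecisionStump;

$samples = [[1], [2], [7], [9]];             // illustrative single-column data
$targets = ['no', 'no', 'yes', 'yes'];

$stump = new DecisionStump();                // AUTO_SELECT: search all columns
$stump->setNumericalSplitCount(20.0);        // probe more thresholds than the default 10
$stump->train($samples, $targets);

echo $stump;                                 // "column operator value", e.g. "0 > 4.75"
echo $stump->predict([8]);                   // expected: "yes"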


@@ -61,6 +61,14 @@ class Perceptron implements Classifier
*/
protected $normalizer;
/**
* Minimum amount of change in the weights between iterations
* that is required for training to continue
*
* @var float
*/
protected $threshold = 1e-5;
/**
* Initialize a perceptron classifier with the given learning rate and maximum
* number of iterations used while training the perceptron <br>
@@ -89,6 +97,20 @@ class Perceptron implements Classifier
$this->maxIterations = $maxIterations;
}
/**
* Sets the minimum amount of change in the weights
* required between iterations for training to continue.<br>
*
* If the change in the weights is less than the given value,
* the algorithm stops training
*
* @param float $threshold
*/
public function setChangeThreshold(float $threshold = 1e-5)
{
$this->threshold = $threshold;
}
/**
* @param array $samples
* @param array $targets
@@ -97,7 +119,7 @@ class Perceptron implements Classifier
{
$this->labels = array_keys(array_count_values($targets));
if (count($this->labels) > 2) {
throw new \Exception("Perceptron is for only binary (two-class) classification");
throw new \Exception("Perceptron is for binary (two-class) classification only");
}
if ($this->normalizer) {
@@ -130,11 +152,20 @@ class Perceptron implements Classifier
protected function runTraining()
{
$currIter = 0;
$bestWeights = null;
$bestScore = count($this->samples);
$bestWeightIter = 0;
while ($this->maxIterations > $currIter++) {
$weights = $this->weights;
$misClassified = 0;
foreach ($this->samples as $index => $sample) {
$target = $this->targets[$index];
$prediction = $this->{static::$errorFunction}($sample);
$update = $target - $prediction;
if ($target != $prediction) {
$misClassified++;
}
// Update bias
$this->weights[0] += $update * $this->learningRate; // Bias
// Update other weights
@@ -142,7 +173,45 @@ class Perceptron implements Classifier
$this->weights[$i] += $update * $sample[$i - 1] * $this->learningRate;
}
}
// Save the best weights in the "pocket" so that
// any future weights worse than these are disregarded
if ($bestWeights == null || $misClassified <= $bestScore) {
$bestWeights = $weights;
$bestScore = $misClassified;
$bestWeightIter = $currIter;
}
// Check for early stop
if ($this->earlyStop($weights)) {
break;
}
}
// The weights in the pocket are at least as good as the ones from the last iteration,
// so we use them as the final weights
$this->weights = $bestWeights;
}
/**
* @param array $oldWeights
*
* @return boolean
*/
protected function earlyStop($oldWeights)
{
// Check for early stop: no weight changed by more than the threshold
$diff = array_map(
function ($w1, $w2) {
return abs($w1 - $w2) > $this->threshold ? 1 : 0;
},
$oldWeights, $this->weights);
if (array_sum($diff) == 0) {
return true;
}
return false;
}
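
A usage sketch for the updated Perceptron: the constructor arguments mirror the test case below, setChangeThreshold() tunes the early stop, and the pocket keeps the best weights seen during training:

<?php
use Phpml\Classification\Linear\Perceptron;

$samples = [[0, 0], [1, 0], [0, 1], [1, 1]];
$targets = [0, 1, 1, 1];                      // OR problem

// Third argument (input normalization) as used in the test case below
$classifier = new Perceptron(0.001, 5000, false);
$classifier->setChangeThreshold(1e-6);        // stricter early-stop criterion
$classifier->train($samples, $targets);

// The pocket keeps the best weights seen over all iterations,
// so the final model is never worse than an intermediate one
echo $classifier->predict([1, 1]);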
/**


@@ -0,0 +1,20 @@
<?php declare(strict_types=1);
namespace Phpml\Classification;
use Phpml\Classification\Classifier;
abstract class WeightedClassifier implements Classifier
{
protected $weights = null;
/**
* Sets the array including a weight for each sample
*
* @param array $weights
*/
public function setSampleWeights(array $weights)
{
$this->weights = $weights;
}
}
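
The new abstract class gives AdaBoost a uniform way to hand sample weights to any base learner. A sketch of how a custom weighted classifier could hook in; the class name and method bodies are hypothetical:

<?php
use Phpml\Classification\WeightedClassifier;

// Hypothetical weight-aware learner; only the weight handling is sketched
class MyWeightedLearner extends WeightedClassifier
{
    public function train(array $samples, array $targets)
    {
        // Fall back to uniform weights when none were set
        $weights = $this->weights ?: array_fill(0, count($samples), 1);
        // ... fit the model, giving each sample its weight ...
    }

    public function predict($samples)
    {
        // ... return the predicted label(s) ...
    }
}

$learner = new MyWeightedLearner();
$learner->setSampleWeights([0.1, 0.4, 0.5]);  // AdaBoost calls this when boosting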


@@ -13,20 +13,20 @@ class PerceptronTest extends TestCase
public function testPredictSingleSample()
{
// AND problem
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.9, 0.8]];
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 1];
$classifier = new Perceptron(0.001, 5000);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.2]));
$this->assertEquals(0, $classifier->predict([0.1, 0.99]));
$this->assertEquals(0, $classifier->predict([0, 1]));
$this->assertEquals(1, $classifier->predict([1.1, 0.8]));
// OR problem
$samples = [[0, 0], [0.1, 0.2], [1, 0], [0, 1], [1, 1]];
$targets = [0, 0, 1, 1, 1];
$classifier = new Perceptron(0.001, 5000);
$samples = [[0.1, 0.1], [0.4, 0.], [0., 0.3], [1, 0], [0, 1], [1, 1]];
$targets = [0, 0, 0, 1, 1, 1];
$classifier = new Perceptron(0.001, 5000, false);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0, 0]));
$this->assertEquals(0, $classifier->predict([0., 0.]));
$this->assertEquals(1, $classifier->predict([0.1, 0.99]));
$this->assertEquals(1, $classifier->predict([1.1, 0.8]));