From c028a73985434058171f0c95ca47b7375cf89634 Mon Sep 17 00:00:00 2001 From: Mustafa Karabulut Date: Tue, 28 Feb 2017 23:45:18 +0300 Subject: [PATCH] AdaBoost improvements (#53) * AdaBoost improvements * AdaBoost improvements & test case resolved * Some coding style fixes --- src/Phpml/Classification/DecisionTree.php | 11 +- .../Classification/Ensemble/AdaBoost.php | 117 +++++++-- src/Phpml/Classification/Linear/Adaline.php | 6 + .../Classification/Linear/DecisionStump.php | 247 +++++++++++++----- .../Classification/Linear/Perceptron.php | 71 ++++- .../Classification/WeightedClassifier.php | 20 ++ .../Classification/Linear/PerceptronTest.php | 12 +- 7 files changed, 385 insertions(+), 99 deletions(-) create mode 100644 src/Phpml/Classification/WeightedClassifier.php diff --git a/src/Phpml/Classification/DecisionTree.php b/src/Phpml/Classification/DecisionTree.php index 231d766..b2b4db3 100644 --- a/src/Phpml/Classification/DecisionTree.php +++ b/src/Phpml/Classification/DecisionTree.php @@ -110,12 +110,13 @@ class DecisionTree implements Classifier } } - protected function getColumnTypes(array $samples) + public static function getColumnTypes(array $samples) { $types = []; - for ($i=0; $i<$this->featureCount; $i++) { + $featureCount = count($samples[0]); + for ($i=0; $i < $featureCount; $i++) { $values = array_column($samples, $i); - $isCategorical = $this->isCategoricalColumn($values); + $isCategorical = self::isCategoricalColumn($values); $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS; } return $types; @@ -327,13 +328,13 @@ class DecisionTree implements Classifier * @param array $columnValues * @return bool */ - protected function isCategoricalColumn(array $columnValues) + protected static function isCategoricalColumn(array $columnValues) { $count = count($columnValues); // There are two main indicators that *may* show whether a // column is composed of discrete set of values: - // 1- Column may contain string values and not float values + // 1- Column may contain string values and non-float values // 2- Number of unique values in the column is only a small fraction of // all values in that column (Lower than or equal to %20 of all values) $numericValues = array_filter($columnValues, 'is_numeric'); diff --git a/src/Phpml/Classification/Ensemble/AdaBoost.php b/src/Phpml/Classification/Ensemble/AdaBoost.php index 70440a6..3d1e418 100644 --- a/src/Phpml/Classification/Ensemble/AdaBoost.php +++ b/src/Phpml/Classification/Ensemble/AdaBoost.php @@ -5,6 +5,9 @@ declare(strict_types=1); namespace Phpml\Classification\Ensemble; use Phpml\Classification\Linear\DecisionStump; +use Phpml\Classification\WeightedClassifier; +use Phpml\Math\Statistic\Mean; +use Phpml\Math\Statistic\StandardDeviation; use Phpml\Classification\Classifier; use Phpml\Helper\Predictable; use Phpml\Helper\Trainable; @@ -44,7 +47,7 @@ class AdaBoost implements Classifier protected $weights = []; /** - * Base classifiers + * List of selected 'weak' classifiers * * @var array */ @@ -57,17 +60,39 @@ class AdaBoost implements Classifier */ protected $alpha = []; + /** + * @var string + */ + protected $baseClassifier = DecisionStump::class; + + /** + * @var array + */ + protected $classifierOptions = []; + /** * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to * improve classification performance of 'weak' classifiers such as * DecisionStump (default base classifier of AdaBoost). * */ - public function __construct(int $maxIterations = 30) + public function __construct(int $maxIterations = 50) { $this->maxIterations = $maxIterations; } + /** + * Sets the base classifier that will be used for boosting (default = DecisionStump) + * + * @param string $baseClassifier + * @param array $classifierOptions + */ + public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []) + { + $this->baseClassifier = $baseClassifier; + $this->classifierOptions = $classifierOptions; + } + /** * @param array $samples * @param array $targets @@ -77,7 +102,7 @@ class AdaBoost implements Classifier // Initialize usual variables $this->labels = array_keys(array_count_values($targets)); if (count($this->labels) != 2) { - throw new \Exception("AdaBoost is a binary classifier and can only classify between two classes"); + throw new \Exception("AdaBoost is a binary classifier and can classify between two classes only"); } // Set all target values to either -1 or 1 @@ -98,9 +123,12 @@ class AdaBoost implements Classifier // Execute the algorithm for a maximum number of iterations $currIter = 0; while ($this->maxIterations > $currIter++) { + // Determine the best 'weak' classifier based on current weights - // and update alpha & weight values at each iteration - list($classifier, $errorRate) = $this->getBestClassifier(); + $classifier = $this->getBestClassifier(); + $errorRate = $this->evaluateClassifier($classifier); + + // Update alpha & weight values at each iteration $alpha = $this->calculateAlpha($errorRate); $this->updateWeights($classifier, $alpha); @@ -117,24 +145,71 @@ class AdaBoost implements Classifier */ protected function getBestClassifier() { - // This method works only for "DecisionStump" classifier, for now. - // As a future task, it will be generalized enough to work with other - // classifiers as well - $minErrorRate = 1.0; - $bestClassifier = null; - for ($i=0; $i < $this->featureCount; $i++) { - $stump = new DecisionStump($i); - $stump->setSampleWeights($this->weights); - $stump->train($this->samples, $this->targets); + $ref = new \ReflectionClass($this->baseClassifier); + if ($this->classifierOptions) { + $classifier = $ref->newInstanceArgs($this->classifierOptions); + } else { + $classifier = $ref->newInstance(); + } - $errorRate = $stump->getTrainingErrorRate(); - if ($errorRate < $minErrorRate) { - $bestClassifier = $stump; - $minErrorRate = $errorRate; + if (is_subclass_of($classifier, WeightedClassifier::class)) { + $classifier->setSampleWeights($this->weights); + $classifier->train($this->samples, $this->targets); + } else { + list($samples, $targets) = $this->resample(); + $classifier->train($samples, $targets); + } + + return $classifier; + } + + /** + * Resamples the dataset in accordance with the weights and + * returns the new dataset + * + * @return array + */ + protected function resample() + { + $weights = $this->weights; + $std = StandardDeviation::population($weights); + $mean= Mean::arithmetic($weights); + $min = min($weights); + $minZ= (int)round(($min - $mean) / $std); + + $samples = []; + $targets = []; + foreach ($weights as $index => $weight) { + $z = (int)round(($weight - $mean) / $std) - $minZ + 1; + for ($i=0; $i < $z; $i++) { + if (rand(0, 1) == 0) { + continue; + } + $samples[] = $this->samples[$index]; + $targets[] = $this->targets[$index]; } } - return [$bestClassifier, $minErrorRate]; + return [$samples, $targets]; + } + + /** + * Evaluates the classifier and returns the classification error rate + * + * @param Classifier $classifier + */ + protected function evaluateClassifier(Classifier $classifier) + { + $total = (float) array_sum($this->weights); + $wrong = 0; + foreach ($this->samples as $index => $sample) { + $predicted = $classifier->predict($sample); + if ($predicted != $this->targets[$index]) { + $wrong += $this->weights[$index]; + } + } + + return $wrong / $total; } /** @@ -154,10 +229,10 @@ class AdaBoost implements Classifier /** * Updates the sample weights * - * @param DecisionStump $classifier + * @param Classifier $classifier * @param float $alpha */ - protected function updateWeights(DecisionStump $classifier, float $alpha) + protected function updateWeights(Classifier $classifier, float $alpha) { $sumOfWeights = array_sum($this->weights); $weightsT1 = []; diff --git a/src/Phpml/Classification/Linear/Adaline.php b/src/Phpml/Classification/Linear/Adaline.php index aeff95e..13674f1 100644 --- a/src/Phpml/Classification/Linear/Adaline.php +++ b/src/Phpml/Classification/Linear/Adaline.php @@ -76,10 +76,16 @@ class Adaline extends Perceptron // Batch learning is executed: $currIter = 0; while ($this->maxIterations > $currIter++) { + $weights = $this->weights; + $outputs = array_map([$this, 'output'], $this->samples); $updates = array_map([$this, 'gradient'], $this->targets, $outputs); $this->updateWeights($updates); + + if ($this->earlyStop($weights)) { + break; + } } } diff --git a/src/Phpml/Classification/Linear/DecisionStump.php b/src/Phpml/Classification/Linear/DecisionStump.php index 1220d48..1605a20 100644 --- a/src/Phpml/Classification/Linear/DecisionStump.php +++ b/src/Phpml/Classification/Linear/DecisionStump.php @@ -6,18 +6,19 @@ namespace Phpml\Classification\Linear; use Phpml\Helper\Predictable; use Phpml\Helper\Trainable; -use Phpml\Classification\Classifier; +use Phpml\Classification\WeightedClassifier; use Phpml\Classification\DecisionTree; -use Phpml\Classification\DecisionTree\DecisionTreeLeaf; -class DecisionStump extends DecisionTree +class DecisionStump extends WeightedClassifier { use Trainable, Predictable; + const AUTO_SELECT = -1; + /** * @var int */ - protected $columnIndex; + protected $givenColumnIndex; /** @@ -36,6 +37,31 @@ class DecisionStump extends DecisionTree */ protected $trainingErrorRate; + /** + * @var int + */ + protected $column; + + /** + * @var mixed + */ + protected $value; + + /** + * @var string + */ + protected $operator; + + /** + * @var array + */ + protected $columnTypes; + + /** + * @var float + */ + protected $numSplitCount = 10.0; + /** * A DecisionStump classifier is a one-level deep DecisionTree. It is generally * used with ensemble algorithms as in the weak classifier role.
@@ -46,11 +72,9 @@ class DecisionStump extends DecisionTree * * @param int $columnIndex */ - public function __construct(int $columnIndex = -1) + public function __construct(int $columnIndex = self::AUTO_SELECT) { - $this->columnIndex = $columnIndex; - - parent::__construct(1); + $this->givenColumnIndex = $columnIndex; } /** @@ -59,95 +83,167 @@ class DecisionStump extends DecisionTree */ public function train(array $samples, array $targets) { - if ($this->columnIndex > count($samples[0]) - 1) { - $this->columnIndex = -1; + $this->samples = array_merge($this->samples, $samples); + $this->targets = array_merge($this->targets, $targets); + + // DecisionStump is capable of classifying between two classes only + $labels = array_count_values($this->targets); + $this->labels = array_keys($labels); + if (count($this->labels) != 2) { + throw new \Exception("DecisionStump can classify between two classes only:" . implode(',', $this->labels)); } - if ($this->columnIndex >= 0) { - $this->setSelectedFeatures([$this->columnIndex]); + // If a column index is given, it should be among the existing columns + if ($this->givenColumnIndex > count($samples[0]) - 1) { + $this->givenColumnIndex = self::AUTO_SELECT; } + // Check the size of the weights given. + // If none given, then assign 1 as a weight to each sample if ($this->weights) { $numWeights = count($this->weights); - if ($numWeights != count($samples)) { + if ($numWeights != count($this->samples)) { throw new \Exception("Number of sample weights does not match with number of samples"); } } else { $this->weights = array_fill(0, count($samples), 1); } - parent::train($samples, $targets); + // Determine type of each column as either "continuous" or "nominal" + $this->columnTypes = DecisionTree::getColumnTypes($this->samples); - $this->columnIndex = $this->tree->columnIndex; + // Try to find the best split in the columns of the dataset + // by calculating error rate for each split point in each column + $columns = range(0, count($samples[0]) - 1); + if ($this->givenColumnIndex != self::AUTO_SELECT) { + $columns = [$this->givenColumnIndex]; + } - // For numerical values, try to optimize the value by finding a different threshold value - if ($this->columnTypes[$this->columnIndex] == self::CONTINUOS) { - $this->optimizeDecision($samples, $targets); + $bestSplit = [ + 'value' => 0, 'operator' => '', + 'column' => 0, 'trainingErrorRate' => 1.0]; + foreach ($columns as $col) { + if ($this->columnTypes[$col] == DecisionTree::CONTINUOS) { + $split = $this->getBestNumericalSplit($col); + } else { + $split = $this->getBestNominalSplit($col); + } + + if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) { + $bestSplit = $split; + } + } + + // Assign determined best values to the stump + foreach ($bestSplit as $name => $value) { + $this->{$name} = $value; } } /** - * Used to set sample weights. + * While finding best split point for a numerical valued column, + * DecisionStump looks for equally distanced values between minimum and maximum + * values in the column. Given $count value determines how many split + * points to be probed. The more split counts, the better performance but + * worse processing time (Default value is 10.0) * - * @param array $weights + * @param float $count */ - public function setSampleWeights(array $weights) + public function setNumericalSplitCount(float $count) { - $this->weights = $weights; + $this->numSplitCount = $count; } /** - * Returns the training error rate, the proportion of wrong predictions - * over the total number of samples + * Determines best split point for the given column * - * @return float - */ - public function getTrainingErrorRate() - { - return $this->trainingErrorRate; - } - - /** - * Tries to optimize the threshold by probing a range of different values - * between the minimum and maximum values in the selected column + * @param int $col * - * @param array $samples - * @param array $targets + * @return array */ - protected function optimizeDecision(array $samples, array $targets) + protected function getBestNumericalSplit(int $col) { - $values = array_column($samples, $this->columnIndex); + $values = array_column($this->samples, $col); $minValue = min($values); $maxValue = max($values); - $stepSize = ($maxValue - $minValue) / 100.0; + $stepSize = ($maxValue - $minValue) / $this->numSplitCount; - $leftLabel = $this->tree->leftLeaf->classValue; - $rightLabel= $this->tree->rightLeaf->classValue; - - $bestOperator = $this->tree->operator; - $bestThreshold = $this->tree->numericValue; - $bestErrorRate = $this->calculateErrorRate( - $bestThreshold, $bestOperator, $values, $targets, $leftLabel, $rightLabel); + $split = null; foreach (['<=', '>'] as $operator) { + // Before trying all possible split points, let's first try + // the average value for the cut point + $threshold = array_sum($values) / (float) count($values); + $errorRate = $this->calculateErrorRate($threshold, $operator, $values); + if ($split == null || $errorRate < $split['trainingErrorRate']) { + $split = ['value' => $threshold, 'operator' => $operator, + 'column' => $col, 'trainingErrorRate' => $errorRate]; + } + + // Try other possible points one by one for ($step = $minValue; $step <= $maxValue; $step+= $stepSize) { $threshold = (float)$step; - $errorRate = $this->calculateErrorRate( - $threshold, $operator, $values, $targets, $leftLabel, $rightLabel); - - if ($errorRate < $bestErrorRate) { - $bestErrorRate = $errorRate; - $bestThreshold = $threshold; - $bestOperator = $operator; + $errorRate = $this->calculateErrorRate($threshold, $operator, $values); + if ($errorRate < $split['trainingErrorRate']) { + $split = ['value' => $threshold, 'operator' => $operator, + 'column' => $col, 'trainingErrorRate' => $errorRate]; } }// for } - // Update the tree node value - $this->tree->numericValue = $bestThreshold; - $this->tree->operator = $bestOperator; - $this->tree->value = "$bestOperator $bestThreshold"; - $this->trainingErrorRate = $bestErrorRate; + return $split; + } + + /** + * + * @param int $col + * + * @return array + */ + protected function getBestNominalSplit(int $col) + { + $values = array_column($this->samples, $col); + $valueCounts = array_count_values($values); + $distinctVals= array_keys($valueCounts); + + $split = null; + + foreach (['=', '!='] as $operator) { + foreach ($distinctVals as $val) { + $errorRate = $this->calculateErrorRate($val, $operator, $values); + + if ($split == null || $split['trainingErrorRate'] < $errorRate) { + $split = ['value' => $val, 'operator' => $operator, + 'column' => $col, 'trainingErrorRate' => $errorRate]; + } + }// for + } + + return $split; + } + + + /** + * + * @param type $leftValue + * @param type $operator + * @param type $rightValue + * + * @return boolean + */ + protected function evaluate($leftValue, $operator, $rightValue) + { + switch ($operator) { + case '>': return $leftValue > $rightValue; + case '>=': return $leftValue >= $rightValue; + case '<': return $leftValue < $rightValue; + case '<=': return $leftValue <= $rightValue; + case '=': return $leftValue == $rightValue; + case '!=': + case '<>': return $leftValue != $rightValue; + } + + return false; } /** @@ -157,23 +253,42 @@ class DecisionStump extends DecisionTree * @param float $threshold * @param string $operator * @param array $values - * @param array $targets - * @param mixed $leftLabel - * @param mixed $rightLabel */ - protected function calculateErrorRate(float $threshold, string $operator, array $values, array $targets, $leftLabel, $rightLabel) + protected function calculateErrorRate(float $threshold, string $operator, array $values) { $total = (float) array_sum($this->weights); - $wrong = 0; - + $wrong = 0.0; + $leftLabel = $this->labels[0]; + $rightLabel= $this->labels[1]; foreach ($values as $index => $value) { - eval("\$predicted = \$value $operator \$threshold ? \$leftLabel : \$rightLabel;"); + if ($this->evaluate($threshold, $operator, $value)) { + $predicted = $leftLabel; + } else { + $predicted = $rightLabel; + } - if ($predicted != $targets[$index]) { + if ($predicted != $this->targets[$index]) { $wrong += $this->weights[$index]; } } return $wrong / $total; } + + /** + * @param array $sample + * @return mixed + */ + protected function predictSample(array $sample) + { + if ($this->evaluate($this->value, $this->operator, $sample[$this->column])) { + return $this->labels[0]; + } + return $this->labels[1]; + } + + public function __toString() + { + return "$this->column $this->operator $this->value"; + } } diff --git a/src/Phpml/Classification/Linear/Perceptron.php b/src/Phpml/Classification/Linear/Perceptron.php index 78a204a..bc31da1 100644 --- a/src/Phpml/Classification/Linear/Perceptron.php +++ b/src/Phpml/Classification/Linear/Perceptron.php @@ -61,6 +61,14 @@ class Perceptron implements Classifier */ protected $normalizer; + /** + * Minimum amount of change in the weights between iterations + * that needs to be obtained to continue the training + * + * @var float + */ + protected $threshold = 1e-5; + /** * Initalize a perceptron classifier with given learning rate and maximum * number of iterations used while training the perceptron
@@ -89,6 +97,20 @@ class Perceptron implements Classifier $this->maxIterations = $maxIterations; } + /** + * Sets minimum value for the change in the weights + * between iterations to continue the iterations.
+ * + * If the weight change is less than given value then the + * algorithm will stop training + * + * @param float $threshold + */ + public function setChangeThreshold(float $threshold = 1e-5) + { + $this->threshold = $threshold; + } + /** * @param array $samples * @param array $targets @@ -97,7 +119,7 @@ class Perceptron implements Classifier { $this->labels = array_keys(array_count_values($targets)); if (count($this->labels) > 2) { - throw new \Exception("Perceptron is for only binary (two-class) classification"); + throw new \Exception("Perceptron is for binary (two-class) classification only"); } if ($this->normalizer) { @@ -130,11 +152,20 @@ class Perceptron implements Classifier protected function runTraining() { $currIter = 0; + $bestWeights = null; + $bestScore = count($this->samples); + $bestWeightIter = 0; + while ($this->maxIterations > $currIter++) { + $weights = $this->weights; + $misClassified = 0; foreach ($this->samples as $index => $sample) { $target = $this->targets[$index]; $prediction = $this->{static::$errorFunction}($sample); $update = $target - $prediction; + if ($target != $prediction) { + $misClassified++; + } // Update bias $this->weights[0] += $update * $this->learningRate; // Bias // Update other weights @@ -142,7 +173,45 @@ class Perceptron implements Classifier $this->weights[$i] += $update * $sample[$i - 1] * $this->learningRate; } } + + // Save the best weights in the "pocket" so that + // any future weights worse than this will be disregarded + if ($bestWeights == null || $misClassified <= $bestScore) { + $bestWeights = $weights; + $bestScore = $misClassified; + $bestWeightIter = $currIter; + } + + // Check for early stop + if ($this->earlyStop($weights)) { + break; + } } + + // The weights in the pocket are better than or equal to the last state + // so, we use these weights + $this->weights = $bestWeights; + } + + /** + * @param array $oldWeights + * + * @return boolean + */ + protected function earlyStop($oldWeights) + { + // Check for early stop: No change larger than 1e-5 + $diff = array_map( + function ($w1, $w2) { + return abs($w1 - $w2) > 1e-5 ? 1 : 0; + }, + $oldWeights, $this->weights); + + if (array_sum($diff) == 0) { + return true; + } + + return false; } /** diff --git a/src/Phpml/Classification/WeightedClassifier.php b/src/Phpml/Classification/WeightedClassifier.php new file mode 100644 index 0000000..36a294e --- /dev/null +++ b/src/Phpml/Classification/WeightedClassifier.php @@ -0,0 +1,20 @@ +weights = $weights; + } +} diff --git a/tests/Phpml/Classification/Linear/PerceptronTest.php b/tests/Phpml/Classification/Linear/PerceptronTest.php index bf1b384..64954f7 100644 --- a/tests/Phpml/Classification/Linear/PerceptronTest.php +++ b/tests/Phpml/Classification/Linear/PerceptronTest.php @@ -13,20 +13,20 @@ class PerceptronTest extends TestCase public function testPredictSingleSample() { // AND problem - $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.9, 0.8]]; + $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.6, 0.6]]; $targets = [0, 0, 0, 1, 1]; $classifier = new Perceptron(0.001, 5000); $classifier->train($samples, $targets); $this->assertEquals(0, $classifier->predict([0.1, 0.2])); - $this->assertEquals(0, $classifier->predict([0.1, 0.99])); + $this->assertEquals(0, $classifier->predict([0, 1])); $this->assertEquals(1, $classifier->predict([1.1, 0.8])); // OR problem - $samples = [[0, 0], [0.1, 0.2], [1, 0], [0, 1], [1, 1]]; - $targets = [0, 0, 1, 1, 1]; - $classifier = new Perceptron(0.001, 5000); + $samples = [[0.1, 0.1], [0.4, 0.], [0., 0.3], [1, 0], [0, 1], [1, 1]]; + $targets = [0, 0, 0, 1, 1, 1]; + $classifier = new Perceptron(0.001, 5000, false); $classifier->train($samples, $targets); - $this->assertEquals(0, $classifier->predict([0, 0])); + $this->assertEquals(0, $classifier->predict([0., 0.])); $this->assertEquals(1, $classifier->predict([0.1, 0.99])); $this->assertEquals(1, $classifier->predict([1.1, 0.8]));