mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-02-03 04:28:33 +00:00
AdaBoost improvements (#53)
* AdaBoost improvements * AdaBoost improvements & test case resolved * Some coding style fixes
This commit is contained in:
parent
e8c6005aec
commit
c028a73985
@ -110,12 +110,13 @@ class DecisionTree implements Classifier
|
||||
}
|
||||
}
|
||||
|
||||
protected function getColumnTypes(array $samples)
|
||||
public static function getColumnTypes(array $samples)
|
||||
{
|
||||
$types = [];
|
||||
for ($i=0; $i<$this->featureCount; $i++) {
|
||||
$featureCount = count($samples[0]);
|
||||
for ($i=0; $i < $featureCount; $i++) {
|
||||
$values = array_column($samples, $i);
|
||||
$isCategorical = $this->isCategoricalColumn($values);
|
||||
$isCategorical = self::isCategoricalColumn($values);
|
||||
$types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS;
|
||||
}
|
||||
return $types;
|
||||
@ -327,13 +328,13 @@ class DecisionTree implements Classifier
|
||||
* @param array $columnValues
|
||||
* @return bool
|
||||
*/
|
||||
protected function isCategoricalColumn(array $columnValues)
|
||||
protected static function isCategoricalColumn(array $columnValues)
|
||||
{
|
||||
$count = count($columnValues);
|
||||
|
||||
// There are two main indicators that *may* show whether a
|
||||
// column is composed of discrete set of values:
|
||||
// 1- Column may contain string values and not float values
|
||||
// 1- Column may contain string values and non-float values
|
||||
// 2- Number of unique values in the column is only a small fraction of
|
||||
// all values in that column (Lower than or equal to %20 of all values)
|
||||
$numericValues = array_filter($columnValues, 'is_numeric');
|
||||
|
@ -5,6 +5,9 @@ declare(strict_types=1);
|
||||
namespace Phpml\Classification\Ensemble;
|
||||
|
||||
use Phpml\Classification\Linear\DecisionStump;
|
||||
use Phpml\Classification\WeightedClassifier;
|
||||
use Phpml\Math\Statistic\Mean;
|
||||
use Phpml\Math\Statistic\StandardDeviation;
|
||||
use Phpml\Classification\Classifier;
|
||||
use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\Trainable;
|
||||
@ -44,7 +47,7 @@ class AdaBoost implements Classifier
|
||||
protected $weights = [];
|
||||
|
||||
/**
|
||||
* Base classifiers
|
||||
* List of selected 'weak' classifiers
|
||||
*
|
||||
* @var array
|
||||
*/
|
||||
@ -57,17 +60,39 @@ class AdaBoost implements Classifier
|
||||
*/
|
||||
protected $alpha = [];
|
||||
|
||||
/**
|
||||
* @var string
|
||||
*/
|
||||
protected $baseClassifier = DecisionStump::class;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $classifierOptions = [];
|
||||
|
||||
/**
|
||||
* ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
|
||||
* improve classification performance of 'weak' classifiers such as
|
||||
* DecisionStump (default base classifier of AdaBoost).
|
||||
*
|
||||
*/
|
||||
public function __construct(int $maxIterations = 30)
|
||||
public function __construct(int $maxIterations = 50)
|
||||
{
|
||||
$this->maxIterations = $maxIterations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the base classifier that will be used for boosting (default = DecisionStump)
|
||||
*
|
||||
* @param string $baseClassifier
|
||||
* @param array $classifierOptions
|
||||
*/
|
||||
public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = [])
|
||||
{
|
||||
$this->baseClassifier = $baseClassifier;
|
||||
$this->classifierOptions = $classifierOptions;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
@ -77,7 +102,7 @@ class AdaBoost implements Classifier
|
||||
// Initialize usual variables
|
||||
$this->labels = array_keys(array_count_values($targets));
|
||||
if (count($this->labels) != 2) {
|
||||
throw new \Exception("AdaBoost is a binary classifier and can only classify between two classes");
|
||||
throw new \Exception("AdaBoost is a binary classifier and can classify between two classes only");
|
||||
}
|
||||
|
||||
// Set all target values to either -1 or 1
|
||||
@ -98,9 +123,12 @@ class AdaBoost implements Classifier
|
||||
// Execute the algorithm for a maximum number of iterations
|
||||
$currIter = 0;
|
||||
while ($this->maxIterations > $currIter++) {
|
||||
|
||||
// Determine the best 'weak' classifier based on current weights
|
||||
// and update alpha & weight values at each iteration
|
||||
list($classifier, $errorRate) = $this->getBestClassifier();
|
||||
$classifier = $this->getBestClassifier();
|
||||
$errorRate = $this->evaluateClassifier($classifier);
|
||||
|
||||
// Update alpha & weight values at each iteration
|
||||
$alpha = $this->calculateAlpha($errorRate);
|
||||
$this->updateWeights($classifier, $alpha);
|
||||
|
||||
@ -117,24 +145,71 @@ class AdaBoost implements Classifier
|
||||
*/
|
||||
protected function getBestClassifier()
|
||||
{
|
||||
// This method works only for "DecisionStump" classifier, for now.
|
||||
// As a future task, it will be generalized enough to work with other
|
||||
// classifiers as well
|
||||
$minErrorRate = 1.0;
|
||||
$bestClassifier = null;
|
||||
for ($i=0; $i < $this->featureCount; $i++) {
|
||||
$stump = new DecisionStump($i);
|
||||
$stump->setSampleWeights($this->weights);
|
||||
$stump->train($this->samples, $this->targets);
|
||||
$ref = new \ReflectionClass($this->baseClassifier);
|
||||
if ($this->classifierOptions) {
|
||||
$classifier = $ref->newInstanceArgs($this->classifierOptions);
|
||||
} else {
|
||||
$classifier = $ref->newInstance();
|
||||
}
|
||||
|
||||
$errorRate = $stump->getTrainingErrorRate();
|
||||
if ($errorRate < $minErrorRate) {
|
||||
$bestClassifier = $stump;
|
||||
$minErrorRate = $errorRate;
|
||||
if (is_subclass_of($classifier, WeightedClassifier::class)) {
|
||||
$classifier->setSampleWeights($this->weights);
|
||||
$classifier->train($this->samples, $this->targets);
|
||||
} else {
|
||||
list($samples, $targets) = $this->resample();
|
||||
$classifier->train($samples, $targets);
|
||||
}
|
||||
|
||||
return $classifier;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resamples the dataset in accordance with the weights and
|
||||
* returns the new dataset
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
protected function resample()
|
||||
{
|
||||
$weights = $this->weights;
|
||||
$std = StandardDeviation::population($weights);
|
||||
$mean= Mean::arithmetic($weights);
|
||||
$min = min($weights);
|
||||
$minZ= (int)round(($min - $mean) / $std);
|
||||
|
||||
$samples = [];
|
||||
$targets = [];
|
||||
foreach ($weights as $index => $weight) {
|
||||
$z = (int)round(($weight - $mean) / $std) - $minZ + 1;
|
||||
for ($i=0; $i < $z; $i++) {
|
||||
if (rand(0, 1) == 0) {
|
||||
continue;
|
||||
}
|
||||
$samples[] = $this->samples[$index];
|
||||
$targets[] = $this->targets[$index];
|
||||
}
|
||||
}
|
||||
|
||||
return [$bestClassifier, $minErrorRate];
|
||||
return [$samples, $targets];
|
||||
}
|
||||
|
||||
/**
|
||||
* Evaluates the classifier and returns the classification error rate
|
||||
*
|
||||
* @param Classifier $classifier
|
||||
*/
|
||||
protected function evaluateClassifier(Classifier $classifier)
|
||||
{
|
||||
$total = (float) array_sum($this->weights);
|
||||
$wrong = 0;
|
||||
foreach ($this->samples as $index => $sample) {
|
||||
$predicted = $classifier->predict($sample);
|
||||
if ($predicted != $this->targets[$index]) {
|
||||
$wrong += $this->weights[$index];
|
||||
}
|
||||
}
|
||||
|
||||
return $wrong / $total;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -154,10 +229,10 @@ class AdaBoost implements Classifier
|
||||
/**
|
||||
* Updates the sample weights
|
||||
*
|
||||
* @param DecisionStump $classifier
|
||||
* @param Classifier $classifier
|
||||
* @param float $alpha
|
||||
*/
|
||||
protected function updateWeights(DecisionStump $classifier, float $alpha)
|
||||
protected function updateWeights(Classifier $classifier, float $alpha)
|
||||
{
|
||||
$sumOfWeights = array_sum($this->weights);
|
||||
$weightsT1 = [];
|
||||
|
@ -76,10 +76,16 @@ class Adaline extends Perceptron
|
||||
// Batch learning is executed:
|
||||
$currIter = 0;
|
||||
while ($this->maxIterations > $currIter++) {
|
||||
$weights = $this->weights;
|
||||
|
||||
$outputs = array_map([$this, 'output'], $this->samples);
|
||||
$updates = array_map([$this, 'gradient'], $this->targets, $outputs);
|
||||
|
||||
$this->updateWeights($updates);
|
||||
|
||||
if ($this->earlyStop($weights)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,18 +6,19 @@ namespace Phpml\Classification\Linear;
|
||||
|
||||
use Phpml\Helper\Predictable;
|
||||
use Phpml\Helper\Trainable;
|
||||
use Phpml\Classification\Classifier;
|
||||
use Phpml\Classification\WeightedClassifier;
|
||||
use Phpml\Classification\DecisionTree;
|
||||
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
|
||||
|
||||
class DecisionStump extends DecisionTree
|
||||
class DecisionStump extends WeightedClassifier
|
||||
{
|
||||
use Trainable, Predictable;
|
||||
|
||||
const AUTO_SELECT = -1;
|
||||
|
||||
/**
|
||||
* @var int
|
||||
*/
|
||||
protected $columnIndex;
|
||||
protected $givenColumnIndex;
|
||||
|
||||
|
||||
/**
|
||||
@ -36,6 +37,31 @@ class DecisionStump extends DecisionTree
|
||||
*/
|
||||
protected $trainingErrorRate;
|
||||
|
||||
/**
|
||||
* @var int
|
||||
*/
|
||||
protected $column;
|
||||
|
||||
/**
|
||||
* @var mixed
|
||||
*/
|
||||
protected $value;
|
||||
|
||||
/**
|
||||
* @var string
|
||||
*/
|
||||
protected $operator;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $columnTypes;
|
||||
|
||||
/**
|
||||
* @var float
|
||||
*/
|
||||
protected $numSplitCount = 10.0;
|
||||
|
||||
/**
|
||||
* A DecisionStump classifier is a one-level deep DecisionTree. It is generally
|
||||
* used with ensemble algorithms as in the weak classifier role. <br>
|
||||
@ -46,11 +72,9 @@ class DecisionStump extends DecisionTree
|
||||
*
|
||||
* @param int $columnIndex
|
||||
*/
|
||||
public function __construct(int $columnIndex = -1)
|
||||
public function __construct(int $columnIndex = self::AUTO_SELECT)
|
||||
{
|
||||
$this->columnIndex = $columnIndex;
|
||||
|
||||
parent::__construct(1);
|
||||
$this->givenColumnIndex = $columnIndex;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -59,95 +83,167 @@ class DecisionStump extends DecisionTree
|
||||
*/
|
||||
public function train(array $samples, array $targets)
|
||||
{
|
||||
if ($this->columnIndex > count($samples[0]) - 1) {
|
||||
$this->columnIndex = -1;
|
||||
$this->samples = array_merge($this->samples, $samples);
|
||||
$this->targets = array_merge($this->targets, $targets);
|
||||
|
||||
// DecisionStump is capable of classifying between two classes only
|
||||
$labels = array_count_values($this->targets);
|
||||
$this->labels = array_keys($labels);
|
||||
if (count($this->labels) != 2) {
|
||||
throw new \Exception("DecisionStump can classify between two classes only:" . implode(',', $this->labels));
|
||||
}
|
||||
|
||||
if ($this->columnIndex >= 0) {
|
||||
$this->setSelectedFeatures([$this->columnIndex]);
|
||||
// If a column index is given, it should be among the existing columns
|
||||
if ($this->givenColumnIndex > count($samples[0]) - 1) {
|
||||
$this->givenColumnIndex = self::AUTO_SELECT;
|
||||
}
|
||||
|
||||
// Check the size of the weights given.
|
||||
// If none given, then assign 1 as a weight to each sample
|
||||
if ($this->weights) {
|
||||
$numWeights = count($this->weights);
|
||||
if ($numWeights != count($samples)) {
|
||||
if ($numWeights != count($this->samples)) {
|
||||
throw new \Exception("Number of sample weights does not match with number of samples");
|
||||
}
|
||||
} else {
|
||||
$this->weights = array_fill(0, count($samples), 1);
|
||||
}
|
||||
|
||||
parent::train($samples, $targets);
|
||||
// Determine type of each column as either "continuous" or "nominal"
|
||||
$this->columnTypes = DecisionTree::getColumnTypes($this->samples);
|
||||
|
||||
$this->columnIndex = $this->tree->columnIndex;
|
||||
// Try to find the best split in the columns of the dataset
|
||||
// by calculating error rate for each split point in each column
|
||||
$columns = range(0, count($samples[0]) - 1);
|
||||
if ($this->givenColumnIndex != self::AUTO_SELECT) {
|
||||
$columns = [$this->givenColumnIndex];
|
||||
}
|
||||
|
||||
// For numerical values, try to optimize the value by finding a different threshold value
|
||||
if ($this->columnTypes[$this->columnIndex] == self::CONTINUOS) {
|
||||
$this->optimizeDecision($samples, $targets);
|
||||
$bestSplit = [
|
||||
'value' => 0, 'operator' => '',
|
||||
'column' => 0, 'trainingErrorRate' => 1.0];
|
||||
foreach ($columns as $col) {
|
||||
if ($this->columnTypes[$col] == DecisionTree::CONTINUOS) {
|
||||
$split = $this->getBestNumericalSplit($col);
|
||||
} else {
|
||||
$split = $this->getBestNominalSplit($col);
|
||||
}
|
||||
|
||||
if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
|
||||
$bestSplit = $split;
|
||||
}
|
||||
}
|
||||
|
||||
// Assign determined best values to the stump
|
||||
foreach ($bestSplit as $name => $value) {
|
||||
$this->{$name} = $value;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Used to set sample weights.
|
||||
* While finding best split point for a numerical valued column,
|
||||
* DecisionStump looks for equally distanced values between minimum and maximum
|
||||
* values in the column. Given <i>$count</i> value determines how many split
|
||||
* points to be probed. The more split counts, the better performance but
|
||||
* worse processing time (Default value is 10.0)
|
||||
*
|
||||
* @param array $weights
|
||||
* @param float $count
|
||||
*/
|
||||
public function setSampleWeights(array $weights)
|
||||
public function setNumericalSplitCount(float $count)
|
||||
{
|
||||
$this->weights = $weights;
|
||||
$this->numSplitCount = $count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the training error rate, the proportion of wrong predictions
|
||||
* over the total number of samples
|
||||
* Determines best split point for the given column
|
||||
*
|
||||
* @return float
|
||||
*/
|
||||
public function getTrainingErrorRate()
|
||||
{
|
||||
return $this->trainingErrorRate;
|
||||
}
|
||||
|
||||
/**
|
||||
* Tries to optimize the threshold by probing a range of different values
|
||||
* between the minimum and maximum values in the selected column
|
||||
* @param int $col
|
||||
*
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
* @return array
|
||||
*/
|
||||
protected function optimizeDecision(array $samples, array $targets)
|
||||
protected function getBestNumericalSplit(int $col)
|
||||
{
|
||||
$values = array_column($samples, $this->columnIndex);
|
||||
$values = array_column($this->samples, $col);
|
||||
$minValue = min($values);
|
||||
$maxValue = max($values);
|
||||
$stepSize = ($maxValue - $minValue) / 100.0;
|
||||
$stepSize = ($maxValue - $minValue) / $this->numSplitCount;
|
||||
|
||||
$leftLabel = $this->tree->leftLeaf->classValue;
|
||||
$rightLabel= $this->tree->rightLeaf->classValue;
|
||||
|
||||
$bestOperator = $this->tree->operator;
|
||||
$bestThreshold = $this->tree->numericValue;
|
||||
$bestErrorRate = $this->calculateErrorRate(
|
||||
$bestThreshold, $bestOperator, $values, $targets, $leftLabel, $rightLabel);
|
||||
$split = null;
|
||||
|
||||
foreach (['<=', '>'] as $operator) {
|
||||
// Before trying all possible split points, let's first try
|
||||
// the average value for the cut point
|
||||
$threshold = array_sum($values) / (float) count($values);
|
||||
$errorRate = $this->calculateErrorRate($threshold, $operator, $values);
|
||||
if ($split == null || $errorRate < $split['trainingErrorRate']) {
|
||||
$split = ['value' => $threshold, 'operator' => $operator,
|
||||
'column' => $col, 'trainingErrorRate' => $errorRate];
|
||||
}
|
||||
|
||||
// Try other possible points one by one
|
||||
for ($step = $minValue; $step <= $maxValue; $step+= $stepSize) {
|
||||
$threshold = (float)$step;
|
||||
$errorRate = $this->calculateErrorRate(
|
||||
$threshold, $operator, $values, $targets, $leftLabel, $rightLabel);
|
||||
|
||||
if ($errorRate < $bestErrorRate) {
|
||||
$bestErrorRate = $errorRate;
|
||||
$bestThreshold = $threshold;
|
||||
$bestOperator = $operator;
|
||||
$errorRate = $this->calculateErrorRate($threshold, $operator, $values);
|
||||
if ($errorRate < $split['trainingErrorRate']) {
|
||||
$split = ['value' => $threshold, 'operator' => $operator,
|
||||
'column' => $col, 'trainingErrorRate' => $errorRate];
|
||||
}
|
||||
}// for
|
||||
}
|
||||
|
||||
// Update the tree node value
|
||||
$this->tree->numericValue = $bestThreshold;
|
||||
$this->tree->operator = $bestOperator;
|
||||
$this->tree->value = "$bestOperator $bestThreshold";
|
||||
$this->trainingErrorRate = $bestErrorRate;
|
||||
return $split;
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param int $col
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
protected function getBestNominalSplit(int $col)
|
||||
{
|
||||
$values = array_column($this->samples, $col);
|
||||
$valueCounts = array_count_values($values);
|
||||
$distinctVals= array_keys($valueCounts);
|
||||
|
||||
$split = null;
|
||||
|
||||
foreach (['=', '!='] as $operator) {
|
||||
foreach ($distinctVals as $val) {
|
||||
$errorRate = $this->calculateErrorRate($val, $operator, $values);
|
||||
|
||||
if ($split == null || $split['trainingErrorRate'] < $errorRate) {
|
||||
$split = ['value' => $val, 'operator' => $operator,
|
||||
'column' => $col, 'trainingErrorRate' => $errorRate];
|
||||
}
|
||||
}// for
|
||||
}
|
||||
|
||||
return $split;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
*
|
||||
* @param type $leftValue
|
||||
* @param type $operator
|
||||
* @param type $rightValue
|
||||
*
|
||||
* @return boolean
|
||||
*/
|
||||
protected function evaluate($leftValue, $operator, $rightValue)
|
||||
{
|
||||
switch ($operator) {
|
||||
case '>': return $leftValue > $rightValue;
|
||||
case '>=': return $leftValue >= $rightValue;
|
||||
case '<': return $leftValue < $rightValue;
|
||||
case '<=': return $leftValue <= $rightValue;
|
||||
case '=': return $leftValue == $rightValue;
|
||||
case '!=':
|
||||
case '<>': return $leftValue != $rightValue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -157,23 +253,42 @@ class DecisionStump extends DecisionTree
|
||||
* @param float $threshold
|
||||
* @param string $operator
|
||||
* @param array $values
|
||||
* @param array $targets
|
||||
* @param mixed $leftLabel
|
||||
* @param mixed $rightLabel
|
||||
*/
|
||||
protected function calculateErrorRate(float $threshold, string $operator, array $values, array $targets, $leftLabel, $rightLabel)
|
||||
protected function calculateErrorRate(float $threshold, string $operator, array $values)
|
||||
{
|
||||
$total = (float) array_sum($this->weights);
|
||||
$wrong = 0;
|
||||
|
||||
$wrong = 0.0;
|
||||
$leftLabel = $this->labels[0];
|
||||
$rightLabel= $this->labels[1];
|
||||
foreach ($values as $index => $value) {
|
||||
eval("\$predicted = \$value $operator \$threshold ? \$leftLabel : \$rightLabel;");
|
||||
if ($this->evaluate($threshold, $operator, $value)) {
|
||||
$predicted = $leftLabel;
|
||||
} else {
|
||||
$predicted = $rightLabel;
|
||||
}
|
||||
|
||||
if ($predicted != $targets[$index]) {
|
||||
if ($predicted != $this->targets[$index]) {
|
||||
$wrong += $this->weights[$index];
|
||||
}
|
||||
}
|
||||
|
||||
return $wrong / $total;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $sample
|
||||
* @return mixed
|
||||
*/
|
||||
protected function predictSample(array $sample)
|
||||
{
|
||||
if ($this->evaluate($this->value, $this->operator, $sample[$this->column])) {
|
||||
return $this->labels[0];
|
||||
}
|
||||
return $this->labels[1];
|
||||
}
|
||||
|
||||
public function __toString()
|
||||
{
|
||||
return "$this->column $this->operator $this->value";
|
||||
}
|
||||
}
|
||||
|
@ -61,6 +61,14 @@ class Perceptron implements Classifier
|
||||
*/
|
||||
protected $normalizer;
|
||||
|
||||
/**
|
||||
* Minimum amount of change in the weights between iterations
|
||||
* that needs to be obtained to continue the training
|
||||
*
|
||||
* @var float
|
||||
*/
|
||||
protected $threshold = 1e-5;
|
||||
|
||||
/**
|
||||
* Initalize a perceptron classifier with given learning rate and maximum
|
||||
* number of iterations used while training the perceptron <br>
|
||||
@ -89,6 +97,20 @@ class Perceptron implements Classifier
|
||||
$this->maxIterations = $maxIterations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets minimum value for the change in the weights
|
||||
* between iterations to continue the iterations.<br>
|
||||
*
|
||||
* If the weight change is less than given value then the
|
||||
* algorithm will stop training
|
||||
*
|
||||
* @param float $threshold
|
||||
*/
|
||||
public function setChangeThreshold(float $threshold = 1e-5)
|
||||
{
|
||||
$this->threshold = $threshold;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
@ -97,7 +119,7 @@ class Perceptron implements Classifier
|
||||
{
|
||||
$this->labels = array_keys(array_count_values($targets));
|
||||
if (count($this->labels) > 2) {
|
||||
throw new \Exception("Perceptron is for only binary (two-class) classification");
|
||||
throw new \Exception("Perceptron is for binary (two-class) classification only");
|
||||
}
|
||||
|
||||
if ($this->normalizer) {
|
||||
@ -130,11 +152,20 @@ class Perceptron implements Classifier
|
||||
protected function runTraining()
|
||||
{
|
||||
$currIter = 0;
|
||||
$bestWeights = null;
|
||||
$bestScore = count($this->samples);
|
||||
$bestWeightIter = 0;
|
||||
|
||||
while ($this->maxIterations > $currIter++) {
|
||||
$weights = $this->weights;
|
||||
$misClassified = 0;
|
||||
foreach ($this->samples as $index => $sample) {
|
||||
$target = $this->targets[$index];
|
||||
$prediction = $this->{static::$errorFunction}($sample);
|
||||
$update = $target - $prediction;
|
||||
if ($target != $prediction) {
|
||||
$misClassified++;
|
||||
}
|
||||
// Update bias
|
||||
$this->weights[0] += $update * $this->learningRate; // Bias
|
||||
// Update other weights
|
||||
@ -142,7 +173,45 @@ class Perceptron implements Classifier
|
||||
$this->weights[$i] += $update * $sample[$i - 1] * $this->learningRate;
|
||||
}
|
||||
}
|
||||
|
||||
// Save the best weights in the "pocket" so that
|
||||
// any future weights worse than this will be disregarded
|
||||
if ($bestWeights == null || $misClassified <= $bestScore) {
|
||||
$bestWeights = $weights;
|
||||
$bestScore = $misClassified;
|
||||
$bestWeightIter = $currIter;
|
||||
}
|
||||
|
||||
// Check for early stop
|
||||
if ($this->earlyStop($weights)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// The weights in the pocket are better than or equal to the last state
|
||||
// so, we use these weights
|
||||
$this->weights = $bestWeights;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $oldWeights
|
||||
*
|
||||
* @return boolean
|
||||
*/
|
||||
protected function earlyStop($oldWeights)
|
||||
{
|
||||
// Check for early stop: No change larger than 1e-5
|
||||
$diff = array_map(
|
||||
function ($w1, $w2) {
|
||||
return abs($w1 - $w2) > 1e-5 ? 1 : 0;
|
||||
},
|
||||
$oldWeights, $this->weights);
|
||||
|
||||
if (array_sum($diff) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
|
20
src/Phpml/Classification/WeightedClassifier.php
Normal file
20
src/Phpml/Classification/WeightedClassifier.php
Normal file
@ -0,0 +1,20 @@
|
||||
<?php declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Classification;
|
||||
|
||||
use Phpml\Classification\Classifier;
|
||||
|
||||
abstract class WeightedClassifier implements Classifier
|
||||
{
|
||||
protected $weights = null;
|
||||
|
||||
/**
|
||||
* Sets the array including a weight for each sample
|
||||
*
|
||||
* @param array $weights
|
||||
*/
|
||||
public function setSampleWeights(array $weights)
|
||||
{
|
||||
$this->weights = $weights;
|
||||
}
|
||||
}
|
@ -13,20 +13,20 @@ class PerceptronTest extends TestCase
|
||||
public function testPredictSingleSample()
|
||||
{
|
||||
// AND problem
|
||||
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.9, 0.8]];
|
||||
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.6, 0.6]];
|
||||
$targets = [0, 0, 0, 1, 1];
|
||||
$classifier = new Perceptron(0.001, 5000);
|
||||
$classifier->train($samples, $targets);
|
||||
$this->assertEquals(0, $classifier->predict([0.1, 0.2]));
|
||||
$this->assertEquals(0, $classifier->predict([0.1, 0.99]));
|
||||
$this->assertEquals(0, $classifier->predict([0, 1]));
|
||||
$this->assertEquals(1, $classifier->predict([1.1, 0.8]));
|
||||
|
||||
// OR problem
|
||||
$samples = [[0, 0], [0.1, 0.2], [1, 0], [0, 1], [1, 1]];
|
||||
$targets = [0, 0, 1, 1, 1];
|
||||
$classifier = new Perceptron(0.001, 5000);
|
||||
$samples = [[0.1, 0.1], [0.4, 0.], [0., 0.3], [1, 0], [0, 1], [1, 1]];
|
||||
$targets = [0, 0, 0, 1, 1, 1];
|
||||
$classifier = new Perceptron(0.001, 5000, false);
|
||||
$classifier->train($samples, $targets);
|
||||
$this->assertEquals(0, $classifier->predict([0, 0]));
|
||||
$this->assertEquals(0, $classifier->predict([0., 0.]));
|
||||
$this->assertEquals(1, $classifier->predict([0.1, 0.99]));
|
||||
$this->assertEquals(1, $classifier->predict([1.1, 0.8]));
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user