AdaBoost improvements (#53)

* AdaBoost improvements

* AdaBoost improvements & test case resolved

* Some coding style fixes
Mustafa Karabulut 2017-02-28 23:45:18 +03:00 committed by Arkadiusz Kondas
parent e8c6005aec
commit c028a73985
7 changed files with 385 additions and 99 deletions

View File

@@ -110,12 +110,13 @@ class DecisionTree implements Classifier
         }
     }
-    protected function getColumnTypes(array $samples)
+    public static function getColumnTypes(array $samples)
     {
         $types = [];
-        for ($i=0; $i<$this->featureCount; $i++) {
+        $featureCount = count($samples[0]);
+        for ($i=0; $i < $featureCount; $i++) {
             $values = array_column($samples, $i);
-            $isCategorical = $this->isCategoricalColumn($values);
+            $isCategorical = self::isCategoricalColumn($values);
             $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS;
         }
         return $types;
@@ -327,13 +328,13 @@ class DecisionTree implements Classifier
      * @param array $columnValues
      * @return bool
      */
-    protected function isCategoricalColumn(array $columnValues)
+    protected static function isCategoricalColumn(array $columnValues)
    {
         $count = count($columnValues);
         // There are two main indicators that *may* show whether a
         // column is composed of discrete set of values:
-        // 1- Column may contain string values and not float values
+        // 1- Column may contain string values and non-float values
         // 2- Number of unique values in the column is only a small fraction of
         //    all values in that column (Lower than or equal to %20 of all values)
         $numericValues = array_filter($columnValues, 'is_numeric');
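
Since getColumnTypes() is now public and static, other classes (DecisionStump below relies on this) can ask DecisionTree for the column types of a dataset without instantiating a tree. A minimal usage sketch, with invented sample data purely for illustration:

    use Phpml\Classification\DecisionTree;

    // Hypothetical dataset: column 0 is numeric, column 1 is categorical
    $samples = [
        [1.2, 'red'],
        [3.4, 'blue'],
        [2.2, 'red'],
    ];

    // One type constant per column, e.g. DecisionTree::CONTINUOS for column 0
    // and DecisionTree::NOMINAL for column 1
    $types = DecisionTree::getColumnTypes($samples);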

View File

@@ -5,6 +5,9 @@ declare(strict_types=1);
 namespace Phpml\Classification\Ensemble;

 use Phpml\Classification\Linear\DecisionStump;
+use Phpml\Classification\WeightedClassifier;
+use Phpml\Math\Statistic\Mean;
+use Phpml\Math\Statistic\StandardDeviation;
 use Phpml\Classification\Classifier;
 use Phpml\Helper\Predictable;
 use Phpml\Helper\Trainable;
@@ -44,7 +47,7 @@ class AdaBoost implements Classifier
     protected $weights = [];

     /**
-     * Base classifiers
+     * List of selected 'weak' classifiers
      *
      * @var array
      */
@@ -57,17 +60,39 @@ class AdaBoost implements Classifier
      */
     protected $alpha = [];

+    /**
+     * @var string
+     */
+    protected $baseClassifier = DecisionStump::class;
+
+    /**
+     * @var array
+     */
+    protected $classifierOptions = [];
+
     /**
      * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
      * improve classification performance of 'weak' classifiers such as
      * DecisionStump (default base classifier of AdaBoost).
      *
      */
-    public function __construct(int $maxIterations = 30)
+    public function __construct(int $maxIterations = 50)
     {
         $this->maxIterations = $maxIterations;
     }

+    /**
+     * Sets the base classifier that will be used for boosting (default = DecisionStump)
+     *
+     * @param string $baseClassifier
+     * @param array $classifierOptions
+     */
+    public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = [])
+    {
+        $this->baseClassifier = $baseClassifier;
+        $this->classifierOptions = $classifierOptions;
+    }
+
     /**
      * @param array $samples
      * @param array $targets
@@ -77,7 +102,7 @@ class AdaBoost implements Classifier
         // Initialize usual variables
         $this->labels = array_keys(array_count_values($targets));
         if (count($this->labels) != 2) {
-            throw new \Exception("AdaBoost is a binary classifier and can only classify between two classes");
+            throw new \Exception("AdaBoost is a binary classifier and can classify between two classes only");
         }

         // Set all target values to either -1 or 1
@@ -98,9 +123,12 @@ class AdaBoost implements Classifier
         // Execute the algorithm for a maximum number of iterations
         $currIter = 0;
         while ($this->maxIterations > $currIter++) {
             // Determine the best 'weak' classifier based on current weights
-            // and update alpha & weight values at each iteration
-            list($classifier, $errorRate) = $this->getBestClassifier();
+            $classifier = $this->getBestClassifier();
+            $errorRate = $this->evaluateClassifier($classifier);
+
+            // Update alpha & weight values at each iteration
             $alpha = $this->calculateAlpha($errorRate);
             $this->updateWeights($classifier, $alpha);
@@ -117,24 +145,71 @@ class AdaBoost implements Classifier
      */
     protected function getBestClassifier()
     {
-        // This method works only for "DecisionStump" classifier, for now.
-        // As a future task, it will be generalized enough to work with other
-        // classifiers as well
-        $minErrorRate = 1.0;
-        $bestClassifier = null;
-        for ($i=0; $i < $this->featureCount; $i++) {
-            $stump = new DecisionStump($i);
-            $stump->setSampleWeights($this->weights);
-            $stump->train($this->samples, $this->targets);
-            $errorRate = $stump->getTrainingErrorRate();
-            if ($errorRate < $minErrorRate) {
-                $bestClassifier = $stump;
-                $minErrorRate = $errorRate;
-            }
-        }
-        return [$bestClassifier, $minErrorRate];
+        $ref = new \ReflectionClass($this->baseClassifier);
+        if ($this->classifierOptions) {
+            $classifier = $ref->newInstanceArgs($this->classifierOptions);
+        } else {
+            $classifier = $ref->newInstance();
+        }
+
+        if (is_subclass_of($classifier, WeightedClassifier::class)) {
+            $classifier->setSampleWeights($this->weights);
+            $classifier->train($this->samples, $this->targets);
+        } else {
+            list($samples, $targets) = $this->resample();
+            $classifier->train($samples, $targets);
+        }
+
+        return $classifier;
+    }
+
+    /**
+     * Resamples the dataset in accordance with the weights and
+     * returns the new dataset
+     *
+     * @return array
+     */
+    protected function resample()
+    {
+        $weights = $this->weights;
+        $std = StandardDeviation::population($weights);
+        $mean= Mean::arithmetic($weights);
+        $min = min($weights);
+        $minZ= (int)round(($min - $mean) / $std);
+
+        $samples = [];
+        $targets = [];
+        foreach ($weights as $index => $weight) {
+            $z = (int)round(($weight - $mean) / $std) - $minZ + 1;
+            for ($i=0; $i < $z; $i++) {
+                if (rand(0, 1) == 0) {
+                    continue;
+                }
+                $samples[] = $this->samples[$index];
+                $targets[] = $this->targets[$index];
+            }
+        }
+
+        return [$samples, $targets];
+    }
+
+    /**
+     * Evaluates the classifier and returns the classification error rate
+     *
+     * @param Classifier $classifier
+     */
+    protected function evaluateClassifier(Classifier $classifier)
+    {
+        $total = (float) array_sum($this->weights);
+        $wrong = 0;
+        foreach ($this->samples as $index => $sample) {
+            $predicted = $classifier->predict($sample);
+            if ($predicted != $this->targets[$index]) {
+                $wrong += $this->weights[$index];
+            }
+        }
+
+        return $wrong / $total;
     }

     /**
@@ -154,10 +229,10 @@ class AdaBoost implements Classifier
     /**
      * Updates the sample weights
      *
-     * @param DecisionStump $classifier
+     * @param Classifier $classifier
      * @param float $alpha
      */
-    protected function updateWeights(DecisionStump $classifier, float $alpha)
+    protected function updateWeights(Classifier $classifier, float $alpha)
     {
         $sumOfWeights = array_sum($this->weights);
         $weightsT1 = [];
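
Taken together, the changes above generalize the booster beyond DecisionStump: the base classifier is now configurable, and getBestClassifier() either hands the boosting weights to a WeightedClassifier directly or trains on a weight-proportional resample of the data. A minimal usage sketch of that interface (the toy data and the expected prediction are illustrative assumptions, not taken from the test suite):

    use Phpml\Classification\Ensemble\AdaBoost;
    use Phpml\Classification\Linear\DecisionStump;

    // Hypothetical two-class toy data, for illustration only
    $samples = [[0.1, 0.2], [0.2, 0.1], [0.8, 0.9], [0.9, 0.8]];
    $targets = ['a', 'a', 'b', 'b'];

    $classifier = new AdaBoost(50);                       // at most 50 boosting rounds
    $classifier->setBaseClassifier(DecisionStump::class); // the default, shown for clarity
    $classifier->train($samples, $targets);

    $label = $classifier->predict([0.85, 0.9]);           // expected to be 'b'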

View File

@@ -76,10 +76,16 @@ class Adaline extends Perceptron
         // Batch learning is executed:
         $currIter = 0;
         while ($this->maxIterations > $currIter++) {
+            $weights = $this->weights;
+
             $outputs = array_map([$this, 'output'], $this->samples);
             $updates = array_map([$this, 'gradient'], $this->targets, $outputs);

             $this->updateWeights($updates);

+            if ($this->earlyStop($weights)) {
+                break;
+            }
         }
     }

View File

@@ -6,18 +6,19 @@ namespace Phpml\Classification\Linear;

 use Phpml\Helper\Predictable;
 use Phpml\Helper\Trainable;
-use Phpml\Classification\Classifier;
+use Phpml\Classification\WeightedClassifier;
 use Phpml\Classification\DecisionTree;
-use Phpml\Classification\DecisionTree\DecisionTreeLeaf;

-class DecisionStump extends DecisionTree
+class DecisionStump extends WeightedClassifier
 {
     use Trainable, Predictable;

+    const AUTO_SELECT = -1;
+
     /**
      * @var int
      */
-    protected $columnIndex;
+    protected $givenColumnIndex;

     /**
@@ -36,6 +37,31 @@ class DecisionStump extends DecisionTree
      */
     protected $trainingErrorRate;

+    /**
+     * @var int
+     */
+    protected $column;
+
+    /**
+     * @var mixed
+     */
+    protected $value;
+
+    /**
+     * @var string
+     */
+    protected $operator;
+
+    /**
+     * @var array
+     */
+    protected $columnTypes;
+
+    /**
+     * @var float
+     */
+    protected $numSplitCount = 10.0;
+
     /**
      * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
      * used with ensemble algorithms as in the weak classifier role. <br>
@@ -46,11 +72,9 @@ class DecisionStump extends DecisionTree
      *
      * @param int $columnIndex
      */
-    public function __construct(int $columnIndex = -1)
+    public function __construct(int $columnIndex = self::AUTO_SELECT)
     {
-        $this->columnIndex = $columnIndex;
-
-        parent::__construct(1);
+        $this->givenColumnIndex = $columnIndex;
     }

     /**
@@ -59,95 +83,167 @@ class DecisionStump extends DecisionTree
      */
     public function train(array $samples, array $targets)
     {
-        if ($this->columnIndex > count($samples[0]) - 1) {
-            $this->columnIndex = -1;
+        $this->samples = array_merge($this->samples, $samples);
+        $this->targets = array_merge($this->targets, $targets);
+
+        // DecisionStump is capable of classifying between two classes only
+        $labels = array_count_values($this->targets);
+        $this->labels = array_keys($labels);
+        if (count($this->labels) != 2) {
+            throw new \Exception("DecisionStump can classify between two classes only:" . implode(',', $this->labels));
         }

-        if ($this->columnIndex >= 0) {
-            $this->setSelectedFeatures([$this->columnIndex]);
+        // If a column index is given, it should be among the existing columns
+        if ($this->givenColumnIndex > count($samples[0]) - 1) {
+            $this->givenColumnIndex = self::AUTO_SELECT;
         }

+        // Check the size of the weights given.
+        // If none given, then assign 1 as a weight to each sample
         if ($this->weights) {
             $numWeights = count($this->weights);
-            if ($numWeights != count($samples)) {
+            if ($numWeights != count($this->samples)) {
                 throw new \Exception("Number of sample weights does not match with number of samples");
             }
         } else {
             $this->weights = array_fill(0, count($samples), 1);
         }

-        parent::train($samples, $targets);
+        // Determine type of each column as either "continuous" or "nominal"
+        $this->columnTypes = DecisionTree::getColumnTypes($this->samples);

-        $this->columnIndex = $this->tree->columnIndex;
+        // Try to find the best split in the columns of the dataset
+        // by calculating error rate for each split point in each column
+        $columns = range(0, count($samples[0]) - 1);
+        if ($this->givenColumnIndex != self::AUTO_SELECT) {
+            $columns = [$this->givenColumnIndex];
+        }

-        // For numerical values, try to optimize the value by finding a different threshold value
-        if ($this->columnTypes[$this->columnIndex] == self::CONTINUOS) {
-            $this->optimizeDecision($samples, $targets);
+        $bestSplit = [
+            'value' => 0, 'operator' => '',
+            'column' => 0, 'trainingErrorRate' => 1.0];
+        foreach ($columns as $col) {
+            if ($this->columnTypes[$col] == DecisionTree::CONTINUOS) {
+                $split = $this->getBestNumericalSplit($col);
+            } else {
+                $split = $this->getBestNominalSplit($col);
+            }
+
+            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
+                $bestSplit = $split;
+            }
+        }
+
+        // Assign determined best values to the stump
+        foreach ($bestSplit as $name => $value) {
+            $this->{$name} = $value;
         }
     }

     /**
-     * Used to set sample weights.
+     * While finding best split point for a numerical valued column,
+     * DecisionStump looks for equally distanced values between minimum and maximum
+     * values in the column. Given <i>$count</i> value determines how many split
+     * points to be probed. The more split counts, the better performance but
+     * worse processing time (Default value is 10.0)
      *
-     * @param array $weights
+     * @param float $count
      */
-    public function setSampleWeights(array $weights)
+    public function setNumericalSplitCount(float $count)
     {
-        $this->weights = $weights;
+        $this->numSplitCount = $count;
     }

     /**
-     * Returns the training error rate, the proportion of wrong predictions
-     * over the total number of samples
+     * Determines best split point for the given column
      *
-     * @return float
-     */
-    public function getTrainingErrorRate()
-    {
-        return $this->trainingErrorRate;
-    }
-
-    /**
-     * Tries to optimize the threshold by probing a range of different values
-     * between the minimum and maximum values in the selected column
+     * @param int $col
      *
-     * @param array $samples
-     * @param array $targets
+     * @return array
      */
-    protected function optimizeDecision(array $samples, array $targets)
+    protected function getBestNumericalSplit(int $col)
     {
-        $values = array_column($samples, $this->columnIndex);
+        $values = array_column($this->samples, $col);
         $minValue = min($values);
         $maxValue = max($values);
-        $stepSize = ($maxValue - $minValue) / 100.0;
-
-        $leftLabel = $this->tree->leftLeaf->classValue;
-        $rightLabel= $this->tree->rightLeaf->classValue;
-
-        $bestOperator = $this->tree->operator;
-        $bestThreshold = $this->tree->numericValue;
-        $bestErrorRate = $this->calculateErrorRate(
-            $bestThreshold, $bestOperator, $values, $targets, $leftLabel, $rightLabel);
+        $stepSize = ($maxValue - $minValue) / $this->numSplitCount;
+
+        $split = null;

         foreach (['<=', '>'] as $operator) {
+            // Before trying all possible split points, let's first try
+            // the average value for the cut point
+            $threshold = array_sum($values) / (float) count($values);
+            $errorRate = $this->calculateErrorRate($threshold, $operator, $values);
+            if ($split == null || $errorRate < $split['trainingErrorRate']) {
+                $split = ['value' => $threshold, 'operator' => $operator,
+                    'column' => $col, 'trainingErrorRate' => $errorRate];
+            }
+
+            // Try other possible points one by one
             for ($step = $minValue; $step <= $maxValue; $step+= $stepSize) {
                 $threshold = (float)$step;
-                $errorRate = $this->calculateErrorRate(
-                    $threshold, $operator, $values, $targets, $leftLabel, $rightLabel);
-
-                if ($errorRate < $bestErrorRate) {
-                    $bestErrorRate = $errorRate;
-                    $bestThreshold = $threshold;
-                    $bestOperator = $operator;
+                $errorRate = $this->calculateErrorRate($threshold, $operator, $values);
+                if ($errorRate < $split['trainingErrorRate']) {
+                    $split = ['value' => $threshold, 'operator' => $operator,
+                        'column' => $col, 'trainingErrorRate' => $errorRate];
                 }
             }// for
         }

-        // Update the tree node value
-        $this->tree->numericValue = $bestThreshold;
-        $this->tree->operator = $bestOperator;
-        $this->tree->value = "$bestOperator $bestThreshold";
-        $this->trainingErrorRate = $bestErrorRate;
+        return $split;
+    }
+
+    /**
+     *
+     * @param int $col
+     *
+     * @return array
+     */
+    protected function getBestNominalSplit(int $col)
+    {
+        $values = array_column($this->samples, $col);
+        $valueCounts = array_count_values($values);
+        $distinctVals= array_keys($valueCounts);
+
+        $split = null;
+
+        foreach (['=', '!='] as $operator) {
+            foreach ($distinctVals as $val) {
+                $errorRate = $this->calculateErrorRate($val, $operator, $values);
+
+                if ($split == null || $split['trainingErrorRate'] < $errorRate) {
+                    $split = ['value' => $val, 'operator' => $operator,
+                        'column' => $col, 'trainingErrorRate' => $errorRate];
+                }
+            }// for
+        }
+
+        return $split;
+    }
+
+    /**
+     *
+     * @param type $leftValue
+     * @param type $operator
+     * @param type $rightValue
+     *
+     * @return boolean
+     */
+    protected function evaluate($leftValue, $operator, $rightValue)
+    {
+        switch ($operator) {
+            case '>': return $leftValue > $rightValue;
+            case '>=': return $leftValue >= $rightValue;
+            case '<': return $leftValue < $rightValue;
+            case '<=': return $leftValue <= $rightValue;
+            case '=': return $leftValue == $rightValue;
+            case '!=':
+            case '<>': return $leftValue != $rightValue;
+        }
+
+        return false;
     }

     /**
@@ -157,23 +253,42 @@ class DecisionStump extends DecisionTree
      * @param float $threshold
      * @param string $operator
      * @param array $values
-     * @param array $targets
-     * @param mixed $leftLabel
-     * @param mixed $rightLabel
      */
-    protected function calculateErrorRate(float $threshold, string $operator, array $values, array $targets, $leftLabel, $rightLabel)
+    protected function calculateErrorRate(float $threshold, string $operator, array $values)
     {
         $total = (float) array_sum($this->weights);
-        $wrong = 0;
+        $wrong = 0.0;
+
+        $leftLabel = $this->labels[0];
+        $rightLabel= $this->labels[1];

         foreach ($values as $index => $value) {
-            eval("\$predicted = \$value $operator \$threshold ? \$leftLabel : \$rightLabel;");
+            if ($this->evaluate($threshold, $operator, $value)) {
+                $predicted = $leftLabel;
+            } else {
+                $predicted = $rightLabel;
+            }

-            if ($predicted != $targets[$index]) {
+            if ($predicted != $this->targets[$index]) {
                 $wrong += $this->weights[$index];
             }
         }

         return $wrong / $total;
     }
+
+    /**
+     * @param array $sample
+     * @return mixed
+     */
+    protected function predictSample(array $sample)
+    {
+        if ($this->evaluate($this->value, $this->operator, $sample[$this->column])) {
+            return $this->labels[0];
+        }
+
+        return $this->labels[1];
+    }
+
+    public function __toString()
+    {
+        return "$this->column $this->operator $this->value";
+    }
 }
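
With these changes the stump is a self-contained one-rule classifier: it keeps the column, operator and threshold with the lowest weighted error, probing equally spaced thresholds for continuous columns and every distinct value for nominal ones. A usage sketch under those assumptions (the data is invented for illustration; the printed rule will vary with the probed thresholds):

    use Phpml\Classification\Linear\DecisionStump;

    // Hypothetical one-feature toy data, for illustration only
    $samples = [[1], [2], [3], [8], [9], [10]];
    $targets = ['low', 'low', 'low', 'high', 'high', 'high'];

    $stump = new DecisionStump();         // AUTO_SELECT: search every column
    $stump->setNumericalSplitCount(20.0); // probe 20 thresholds instead of the default 10
    $stump->train($samples, $targets);

    echo $stump;                 // prints something like "0 > 5.5" (column, operator, value)
    echo $stump->predict([2.5]); // expected to be 'low'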

View File

@@ -61,6 +61,14 @@ class Perceptron implements Classifier
      */
     protected $normalizer;

+    /**
+     * Minimum amount of change in the weights between iterations
+     * that needs to be obtained to continue the training
+     *
+     * @var float
+     */
+    protected $threshold = 1e-5;
+
     /**
      * Initalize a perceptron classifier with given learning rate and maximum
      * number of iterations used while training the perceptron <br>
@@ -89,6 +97,20 @@ class Perceptron implements Classifier
         $this->maxIterations = $maxIterations;
     }

+    /**
+     * Sets minimum value for the change in the weights
+     * between iterations to continue the iterations.<br>
+     *
+     * If the weight change is less than given value then the
+     * algorithm will stop training
+     *
+     * @param float $threshold
+     */
+    public function setChangeThreshold(float $threshold = 1e-5)
+    {
+        $this->threshold = $threshold;
+    }
+
     /**
      * @param array $samples
      * @param array $targets
@@ -97,7 +119,7 @@ class Perceptron implements Classifier
     {
         $this->labels = array_keys(array_count_values($targets));
         if (count($this->labels) > 2) {
-            throw new \Exception("Perceptron is for only binary (two-class) classification");
+            throw new \Exception("Perceptron is for binary (two-class) classification only");
         }

         if ($this->normalizer) {
@@ -130,11 +152,20 @@ class Perceptron implements Classifier
     protected function runTraining()
     {
         $currIter = 0;
+        $bestWeights = null;
+        $bestScore = count($this->samples);
+        $bestWeightIter = 0;
+
         while ($this->maxIterations > $currIter++) {
+            $weights = $this->weights;
+            $misClassified = 0;
             foreach ($this->samples as $index => $sample) {
                 $target = $this->targets[$index];
                 $prediction = $this->{static::$errorFunction}($sample);
                 $update = $target - $prediction;
+                if ($target != $prediction) {
+                    $misClassified++;
+                }
                 // Update bias
                 $this->weights[0] += $update * $this->learningRate; // Bias
                 // Update other weights
@@ -142,7 +173,45 @@ class Perceptron implements Classifier
                     $this->weights[$i] += $update * $sample[$i - 1] * $this->learningRate;
                 }
             }
+
+            // Save the best weights in the "pocket" so that
+            // any future weights worse than this will be disregarded
+            if ($bestWeights == null || $misClassified <= $bestScore) {
+                $bestWeights = $weights;
+                $bestScore = $misClassified;
+                $bestWeightIter = $currIter;
+            }
+
+            // Check for early stop
+            if ($this->earlyStop($weights)) {
+                break;
+            }
         }
+
+        // The weights in the pocket are better than or equal to the last state
+        // so, we use these weights
+        $this->weights = $bestWeights;
+    }
+
+    /**
+     * @param array $oldWeights
+     *
+     * @return boolean
+     */
+    protected function earlyStop($oldWeights)
+    {
+        // Check for early stop: No change larger than 1e-5
+        $diff = array_map(
+            function ($w1, $w2) {
+                return abs($w1 - $w2) > 1e-5 ? 1 : 0;
+            },
+            $oldWeights, $this->weights);
+
+        if (array_sum($diff) == 0) {
+            return true;
+        }
+
+        return false;
     }

     /**
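
Two behavioural changes are bundled in this file: training now keeps the best weight vector seen so far in a "pocket", so a late bad update cannot overwrite a good solution, and a weight-change threshold with a setter is introduced for early stopping. A minimal sketch of the resulting API (the toy data is invented for illustration):

    use Phpml\Classification\Linear\Perceptron;

    // Hypothetical linearly separable toy data, for illustration only
    $samples = [[0, 0], [1, 0], [0, 1], [1, 1]];
    $targets = [0, 0, 0, 1];

    $classifier = new Perceptron(0.001, 5000);
    $classifier->setChangeThreshold(1e-6); // minimum weight change to keep training
    $classifier->train($samples, $targets);

    echo $classifier->predict([0.9, 0.9]); // expected to be 1

Note that earlyStop() as shown in this hunk still compares against a literal 1e-5; the configured $threshold is stored but not yet consulted there.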

View File

@@ -0,0 +1,20 @@
+<?php declare(strict_types=1);
+
+namespace Phpml\Classification;
+
+use Phpml\Classification\Classifier;
+
+abstract class WeightedClassifier implements Classifier
+{
+    protected $weights = null;
+
+    /**
+     * Sets the array including a weight for each sample
+     *
+     * @param array $weights
+     */
+    public function setSampleWeights(array $weights)
+    {
+        $this->weights = $weights;
+    }
+}
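
This abstract class is the hook AdaBoost uses to decide how to pass boosting weights to a base learner: anything extending WeightedClassifier receives the raw weights through setSampleWeights(), everything else is trained on a resampled dataset. A sketch of a custom weighted base learner (the class and its logic are hypothetical, shown only to illustrate the contract):

    use Phpml\Classification\WeightedClassifier;
    use Phpml\Helper\Predictable;
    use Phpml\Helper\Trainable;

    // Hypothetical example class, not part of the library
    class MajorityByWeight extends WeightedClassifier
    {
        use Trainable, Predictable;

        /** @var mixed */
        private $majorityLabel;

        public function train(array $samples, array $targets)
        {
            // Pick the label with the largest total sample weight;
            // $this->weights is filled by AdaBoost via setSampleWeights()
            $weights = $this->weights ?: array_fill(0, count($targets), 1);
            $sums = [];
            foreach ($targets as $i => $label) {
                $sums[$label] = ($sums[$label] ?? 0) + $weights[$i];
            }
            arsort($sums);
            $this->majorityLabel = key($sums);
        }

        protected function predictSample(array $sample)
        {
            return $this->majorityLabel;
        }
    }

With such a class in place, calling setBaseClassifier(MajorityByWeight::class) would make AdaBoost hand the per-sample weights straight to the learner instead of resampling.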

View File

@@ -13,20 +13,20 @@ class PerceptronTest extends TestCase
     public function testPredictSingleSample()
     {
         // AND problem
-        $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.9, 0.8]];
+        $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.6, 0.6]];
         $targets = [0, 0, 0, 1, 1];
         $classifier = new Perceptron(0.001, 5000);
         $classifier->train($samples, $targets);
         $this->assertEquals(0, $classifier->predict([0.1, 0.2]));
-        $this->assertEquals(0, $classifier->predict([0.1, 0.99]));
+        $this->assertEquals(0, $classifier->predict([0, 1]));
         $this->assertEquals(1, $classifier->predict([1.1, 0.8]));

         // OR problem
-        $samples = [[0, 0], [0.1, 0.2], [1, 0], [0, 1], [1, 1]];
-        $targets = [0, 0, 1, 1, 1];
-        $classifier = new Perceptron(0.001, 5000);
+        $samples = [[0.1, 0.1], [0.4, 0.], [0., 0.3], [1, 0], [0, 1], [1, 1]];
+        $targets = [0, 0, 0, 1, 1, 1];
+        $classifier = new Perceptron(0.001, 5000, false);
         $classifier->train($samples, $targets);
-        $this->assertEquals(0, $classifier->predict([0, 0]));
+        $this->assertEquals(0, $classifier->predict([0., 0.]));
         $this->assertEquals(1, $classifier->predict([0.1, 0.99]));
         $this->assertEquals(1, $classifier->predict([1.1, 0.8]));