mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-11 09:35:08 +00:00
Partial training base (#78)
* Cost values for multiclass OneVsRest uses * Partial training interface * Reduce linear classifiers memory usage * Testing partial training and isolated training * Partial trainer naming switched to incremental estimator Other changes according to review's feedback. * Clean optimization data once optimize is finished * Abstract resetBinary
This commit is contained in:
parent
c0463ae087
commit
e1854d44a2
@ -53,8 +53,11 @@ class Adaline extends Perceptron
|
||||
/**
|
||||
* Adapts the weights with respect to given samples and targets
|
||||
* by use of gradient descent learning rule
|
||||
*
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
*/
|
||||
protected function runTraining()
|
||||
protected function runTraining(array $samples, array $targets)
|
||||
{
|
||||
// The cost function is the sum of squares
|
||||
$callback = function ($weights, $sample, $target) {
|
||||
@ -69,6 +72,6 @@ class Adaline extends Perceptron
|
||||
|
||||
$isBatch = $this->trainingType == self::BATCH_TRAINING;
|
||||
|
||||
return parent::runGradientDescent($callback, $isBatch);
|
||||
return parent::runGradientDescent($samples, $targets, $callback, $isBatch);
|
||||
}
|
||||
}
|
||||
|
@ -89,15 +89,13 @@ class DecisionStump extends WeightedClassifier
|
||||
* @param array $targets
|
||||
* @throws \Exception
|
||||
*/
|
||||
protected function trainBinary(array $samples, array $targets)
|
||||
protected function trainBinary(array $samples, array $targets, array $labels)
|
||||
{
|
||||
$this->samples = array_merge($this->samples, $samples);
|
||||
$this->targets = array_merge($this->targets, $targets);
|
||||
$this->binaryLabels = array_keys(array_count_values($this->targets));
|
||||
$this->featureCount = count($this->samples[0]);
|
||||
$this->binaryLabels = $labels;
|
||||
$this->featureCount = count($samples[0]);
|
||||
|
||||
// If a column index is given, it should be among the existing columns
|
||||
if ($this->givenColumnIndex > count($this->samples[0]) - 1) {
|
||||
if ($this->givenColumnIndex > count($samples[0]) - 1) {
|
||||
$this->givenColumnIndex = self::AUTO_SELECT;
|
||||
}
|
||||
|
||||
@ -105,19 +103,19 @@ class DecisionStump extends WeightedClassifier
|
||||
// If none given, then assign 1 as a weight to each sample
|
||||
if ($this->weights) {
|
||||
$numWeights = count($this->weights);
|
||||
if ($numWeights != count($this->samples)) {
|
||||
if ($numWeights != count($samples)) {
|
||||
throw new \Exception("Number of sample weights does not match with number of samples");
|
||||
}
|
||||
} else {
|
||||
$this->weights = array_fill(0, count($this->samples), 1);
|
||||
$this->weights = array_fill(0, count($samples), 1);
|
||||
}
|
||||
|
||||
// Determine type of each column as either "continuous" or "nominal"
|
||||
$this->columnTypes = DecisionTree::getColumnTypes($this->samples);
|
||||
$this->columnTypes = DecisionTree::getColumnTypes($samples);
|
||||
|
||||
// Try to find the best split in the columns of the dataset
|
||||
// by calculating error rate for each split point in each column
|
||||
$columns = range(0, count($this->samples[0]) - 1);
|
||||
$columns = range(0, count($samples[0]) - 1);
|
||||
if ($this->givenColumnIndex != self::AUTO_SELECT) {
|
||||
$columns = [$this->givenColumnIndex];
|
||||
}
|
||||
@ -128,9 +126,9 @@ class DecisionStump extends WeightedClassifier
|
||||
'trainingErrorRate' => 1.0];
|
||||
foreach ($columns as $col) {
|
||||
if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
|
||||
$split = $this->getBestNumericalSplit($col);
|
||||
$split = $this->getBestNumericalSplit($samples, $targets, $col);
|
||||
} else {
|
||||
$split = $this->getBestNominalSplit($col);
|
||||
$split = $this->getBestNominalSplit($samples, $targets, $col);
|
||||
}
|
||||
|
||||
if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
|
||||
@ -161,13 +159,15 @@ class DecisionStump extends WeightedClassifier
|
||||
/**
|
||||
* Determines best split point for the given column
|
||||
*
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
* @param int $col
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
protected function getBestNumericalSplit(int $col)
|
||||
protected function getBestNumericalSplit(array $samples, array $targets, int $col)
|
||||
{
|
||||
$values = array_column($this->samples, $col);
|
||||
$values = array_column($samples, $col);
|
||||
// Trying all possible points may be accomplished in two general ways:
|
||||
// 1- Try all values in the $samples array ($values)
|
||||
// 2- Artificially split the range of values into several parts and try them
|
||||
@ -182,7 +182,7 @@ class DecisionStump extends WeightedClassifier
|
||||
// Before trying all possible split points, let's first try
|
||||
// the average value for the cut point
|
||||
$threshold = array_sum($values) / (float) count($values);
|
||||
list($errorRate, $prob) = $this->calculateErrorRate($threshold, $operator, $values);
|
||||
list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values);
|
||||
if ($split == null || $errorRate < $split['trainingErrorRate']) {
|
||||
$split = ['value' => $threshold, 'operator' => $operator,
|
||||
'prob' => $prob, 'column' => $col,
|
||||
@ -192,7 +192,7 @@ class DecisionStump extends WeightedClassifier
|
||||
// Try other possible points one by one
|
||||
for ($step = $minValue; $step <= $maxValue; $step+= $stepSize) {
|
||||
$threshold = (float)$step;
|
||||
list($errorRate, $prob) = $this->calculateErrorRate($threshold, $operator, $values);
|
||||
list($errorRate, $prob) = $this->calculateErrorRate($targets, $threshold, $operator, $values);
|
||||
if ($errorRate < $split['trainingErrorRate']) {
|
||||
$split = ['value' => $threshold, 'operator' => $operator,
|
||||
'prob' => $prob, 'column' => $col,
|
||||
@ -205,13 +205,15 @@ class DecisionStump extends WeightedClassifier
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
* @param int $col
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
protected function getBestNominalSplit(int $col) : array
|
||||
protected function getBestNominalSplit(array $samples, array $targets, int $col) : array
|
||||
{
|
||||
$values = array_column($this->samples, $col);
|
||||
$values = array_column($samples, $col);
|
||||
$valueCounts = array_count_values($values);
|
||||
$distinctVals= array_keys($valueCounts);
|
||||
|
||||
@ -219,7 +221,7 @@ class DecisionStump extends WeightedClassifier
|
||||
|
||||
foreach (['=', '!='] as $operator) {
|
||||
foreach ($distinctVals as $val) {
|
||||
list($errorRate, $prob) = $this->calculateErrorRate($val, $operator, $values);
|
||||
list($errorRate, $prob) = $this->calculateErrorRate($targets, $val, $operator, $values);
|
||||
|
||||
if ($split == null || $split['trainingErrorRate'] < $errorRate) {
|
||||
$split = ['value' => $val, 'operator' => $operator,
|
||||
@ -260,13 +262,14 @@ class DecisionStump extends WeightedClassifier
|
||||
* Calculates the ratio of wrong predictions based on the new threshold
|
||||
* value given as the parameter
|
||||
*
|
||||
* @param array $targets
|
||||
* @param float $threshold
|
||||
* @param string $operator
|
||||
* @param array $values
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
protected function calculateErrorRate(float $threshold, string $operator, array $values) : array
|
||||
protected function calculateErrorRate(array $targets, float $threshold, string $operator, array $values) : array
|
||||
{
|
||||
$wrong = 0.0;
|
||||
$prob = [];
|
||||
@ -280,8 +283,8 @@ class DecisionStump extends WeightedClassifier
|
||||
$predicted = $rightLabel;
|
||||
}
|
||||
|
||||
$target = $this->targets[$index];
|
||||
if (strval($predicted) != strval($this->targets[$index])) {
|
||||
$target = $targets[$index];
|
||||
if (strval($predicted) != strval($targets[$index])) {
|
||||
$wrong += $this->weights[$index];
|
||||
}
|
||||
|
||||
@ -340,6 +343,13 @@ class DecisionStump extends WeightedClassifier
|
||||
return $this->binaryLabels[1];
|
||||
}
|
||||
|
||||
/**
|
||||
* @return void
|
||||
*/
|
||||
protected function resetBinary()
|
||||
{
|
||||
}
|
||||
|
||||
/**
|
||||
* @return string
|
||||
*/
|
||||
|
@ -123,20 +123,23 @@ class LogisticRegression extends Adaline
|
||||
/**
|
||||
* Adapts the weights with respect to given samples and targets
|
||||
* by use of selected solver
|
||||
*
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
*/
|
||||
protected function runTraining()
|
||||
protected function runTraining(array $samples, array $targets)
|
||||
{
|
||||
$callback = $this->getCostFunction();
|
||||
|
||||
switch ($this->trainingType) {
|
||||
case self::BATCH_TRAINING:
|
||||
return $this->runGradientDescent($callback, true);
|
||||
return $this->runGradientDescent($samples, $targets, $callback, true);
|
||||
|
||||
case self::ONLINE_TRAINING:
|
||||
return $this->runGradientDescent($callback, false);
|
||||
return $this->runGradientDescent($samples, $targets, $callback, false);
|
||||
|
||||
case self::CONJUGATE_GRAD_TRAINING:
|
||||
return $this->runConjugateGradient($callback);
|
||||
return $this->runConjugateGradient($samples, $targets, $callback);
|
||||
}
|
||||
}
|
||||
|
||||
@ -144,13 +147,15 @@ class LogisticRegression extends Adaline
|
||||
* Executes Conjugate Gradient method to optimize the
|
||||
* weights of the LogReg model
|
||||
*/
|
||||
protected function runConjugateGradient(\Closure $gradientFunc)
|
||||
protected function runConjugateGradient(array $samples, array $targets, \Closure $gradientFunc)
|
||||
{
|
||||
$optimizer = (new ConjugateGradient($this->featureCount))
|
||||
->setMaxIterations($this->maxIterations);
|
||||
if (empty($this->optimizer)) {
|
||||
$this->optimizer = (new ConjugateGradient($this->featureCount))
|
||||
->setMaxIterations($this->maxIterations);
|
||||
}
|
||||
|
||||
$this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc);
|
||||
$this->costValues = $optimizer->getCostValues();
|
||||
$this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
|
||||
$this->costValues = $this->optimizer->getCostValues();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -10,20 +10,17 @@ use Phpml\Helper\Optimizer\StochasticGD;
|
||||
use Phpml\Helper\Optimizer\GD;
|
||||
use Phpml\Classification\Classifier;
|
||||
use Phpml\Preprocessing\Normalizer;
|
||||
use Phpml\IncrementalEstimator;
|
||||
use Phpml\Helper\PartiallyTrainable;
|
||||
|
||||
class Perceptron implements Classifier
|
||||
class Perceptron implements Classifier, IncrementalEstimator
|
||||
{
|
||||
use Predictable, OneVsRest;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $samples = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
* @var \Phpml\Helper\Optimizer\Optimizer
|
||||
*/
|
||||
protected $targets = [];
|
||||
protected $optimizer;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
@ -93,32 +90,47 @@ class Perceptron implements Classifier
|
||||
$this->maxIterations = $maxIterations;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
* @param array $labels
|
||||
*/
|
||||
public function partialTrain(array $samples, array $targets, array $labels = array())
|
||||
{
|
||||
return $this->trainByLabel($samples, $targets, $labels);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
* @param array $labels
|
||||
*/
|
||||
public function trainBinary(array $samples, array $targets)
|
||||
public function trainBinary(array $samples, array $targets, array $labels)
|
||||
{
|
||||
$this->labels = array_keys(array_count_values($targets));
|
||||
if (count($this->labels) > 2) {
|
||||
throw new \Exception("Perceptron is for binary (two-class) classification only");
|
||||
}
|
||||
|
||||
if ($this->normalizer) {
|
||||
$this->normalizer->transform($samples);
|
||||
}
|
||||
|
||||
// Set all target values to either -1 or 1
|
||||
$this->labels = [1 => $this->labels[0], -1 => $this->labels[1]];
|
||||
foreach ($targets as $target) {
|
||||
$this->targets[] = strval($target) == strval($this->labels[1]) ? 1 : -1;
|
||||
$this->labels = [1 => $labels[0], -1 => $labels[1]];
|
||||
foreach ($targets as $key => $target) {
|
||||
$targets[$key] = strval($target) == strval($this->labels[1]) ? 1 : -1;
|
||||
}
|
||||
|
||||
// Set samples and feature count vars
|
||||
$this->samples = array_merge($this->samples, $samples);
|
||||
$this->featureCount = count($this->samples[0]);
|
||||
$this->featureCount = count($samples[0]);
|
||||
|
||||
$this->runTraining();
|
||||
$this->runTraining($samples, $targets);
|
||||
}
|
||||
|
||||
protected function resetBinary()
|
||||
{
|
||||
$this->labels = [];
|
||||
$this->optimizer = null;
|
||||
$this->featureCount = 0;
|
||||
$this->weights = null;
|
||||
$this->costValues = [];
|
||||
}
|
||||
|
||||
/**
|
||||
@ -151,8 +163,11 @@ class Perceptron implements Classifier
|
||||
/**
|
||||
* Trains the perceptron model with Stochastic Gradient Descent optimization
|
||||
* to get the correct set of weights
|
||||
*
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
*/
|
||||
protected function runTraining()
|
||||
protected function runTraining(array $samples, array $targets)
|
||||
{
|
||||
// The cost function is the sum of squares
|
||||
$callback = function ($weights, $sample, $target) {
|
||||
@ -165,25 +180,30 @@ class Perceptron implements Classifier
|
||||
return [$error, $gradient];
|
||||
};
|
||||
|
||||
$this->runGradientDescent($callback);
|
||||
$this->runGradientDescent($samples, $targets, $callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes Stochastic Gradient Descent algorithm for
|
||||
* Executes a Gradient Descent algorithm for
|
||||
* the given cost function
|
||||
*
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
*/
|
||||
protected function runGradientDescent(\Closure $gradientFunc, bool $isBatch = false)
|
||||
protected function runGradientDescent(array $samples, array $targets, \Closure $gradientFunc, bool $isBatch = false)
|
||||
{
|
||||
$class = $isBatch ? GD::class : StochasticGD::class;
|
||||
|
||||
$optimizer = (new $class($this->featureCount))
|
||||
->setLearningRate($this->learningRate)
|
||||
->setMaxIterations($this->maxIterations)
|
||||
->setChangeThreshold(1e-6)
|
||||
->setEarlyStop($this->enableEarlyStop);
|
||||
if (empty($this->optimizer)) {
|
||||
$this->optimizer = (new $class($this->featureCount))
|
||||
->setLearningRate($this->learningRate)
|
||||
->setMaxIterations($this->maxIterations)
|
||||
->setChangeThreshold(1e-6)
|
||||
->setEarlyStop($this->enableEarlyStop);
|
||||
}
|
||||
|
||||
$this->weights = $optimizer->runOptimization($this->samples, $this->targets, $gradientFunc);
|
||||
$this->costValues = $optimizer->getCostValues();
|
||||
$this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
|
||||
$this->costValues = $this->optimizer->getCostValues();
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -6,30 +6,23 @@ namespace Phpml\Helper;
|
||||
|
||||
trait OneVsRest
|
||||
{
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $samples = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $targets = [];
|
||||
protected $classifiers = [];
|
||||
|
||||
/**
|
||||
* All provided training targets' labels.
|
||||
*
|
||||
* @var array
|
||||
*/
|
||||
protected $allLabels = [];
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $classifiers;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $labels;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $costValues;
|
||||
protected $costValues = [];
|
||||
|
||||
/**
|
||||
* Train a binary classifier in the OvR style
|
||||
@ -39,51 +32,111 @@ trait OneVsRest
|
||||
*/
|
||||
public function train(array $samples, array $targets)
|
||||
{
|
||||
// Clears previous stuff.
|
||||
$this->reset();
|
||||
|
||||
return $this->trainBylabel($samples, $targets);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
* @param array $allLabels All training set labels
|
||||
* @return void
|
||||
*/
|
||||
protected function trainByLabel(array $samples, array $targets, array $allLabels = array())
|
||||
{
|
||||
|
||||
// Overwrites the current value if it exist. $allLabels must be provided for each partialTrain run.
|
||||
if (!empty($allLabels)) {
|
||||
$this->allLabels = $allLabels;
|
||||
} else {
|
||||
$this->allLabels = array_keys(array_count_values($targets));
|
||||
}
|
||||
sort($this->allLabels, SORT_STRING);
|
||||
|
||||
// If there are only two targets, then there is no need to perform OvR
|
||||
if (count($this->allLabels) == 2) {
|
||||
|
||||
// Init classifier if required.
|
||||
if (empty($this->classifiers)) {
|
||||
$this->classifiers[0] = $this->getClassifierCopy();
|
||||
}
|
||||
|
||||
$this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
|
||||
} else {
|
||||
// Train a separate classifier for each label and memorize them
|
||||
|
||||
foreach ($this->allLabels as $label) {
|
||||
|
||||
// Init classifier if required.
|
||||
if (empty($this->classifiers[$label])) {
|
||||
$this->classifiers[$label] = $this->getClassifierCopy();
|
||||
}
|
||||
|
||||
list($binarizedTargets, $classifierLabels) = $this->binarizeTargets($targets, $label);
|
||||
$this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
|
||||
}
|
||||
}
|
||||
|
||||
// If the underlying classifier is capable of giving the cost values
|
||||
// during the training, then assign it to the relevant variable
|
||||
// Adding just the first classifier cost values to avoid complex average calculations.
|
||||
$classifierref = reset($this->classifiers);
|
||||
if (method_exists($classifierref, 'getCostValues')) {
|
||||
$this->costValues = $classifierref->getCostValues();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
|
||||
*/
|
||||
public function reset()
|
||||
{
|
||||
$this->classifiers = [];
|
||||
$this->allLabels = [];
|
||||
$this->costValues = [];
|
||||
|
||||
$this->resetBinary();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an instance of the current class after cleaning up OneVsRest stuff.
|
||||
*
|
||||
* @return \Phpml\Estimator
|
||||
*/
|
||||
protected function getClassifierCopy()
|
||||
{
|
||||
|
||||
// Clone the current classifier, so that
|
||||
// we don't mess up its variables while training
|
||||
// multiple instances of this classifier
|
||||
$classifier = clone $this;
|
||||
$this->classifiers = [];
|
||||
|
||||
// If there are only two targets, then there is no need to perform OvR
|
||||
$this->labels = array_keys(array_count_values($targets));
|
||||
if (count($this->labels) == 2) {
|
||||
$classifier->trainBinary($samples, $targets);
|
||||
$this->classifiers[] = $classifier;
|
||||
} else {
|
||||
// Train a separate classifier for each label and memorize them
|
||||
$this->samples = $samples;
|
||||
$this->targets = $targets;
|
||||
foreach ($this->labels as $label) {
|
||||
$predictor = clone $classifier;
|
||||
$targets = $this->binarizeTargets($label);
|
||||
$predictor->trainBinary($samples, $targets);
|
||||
$this->classifiers[$label] = $predictor;
|
||||
}
|
||||
}
|
||||
|
||||
// If the underlying classifier is capable of giving the cost values
|
||||
// during the training, then assign it to the relevant variable
|
||||
if (method_exists($this->classifiers[0], 'getCostValues')) {
|
||||
$this->costValues = $this->classifiers[0]->getCostValues();
|
||||
}
|
||||
$classifier->reset();
|
||||
return $classifier;
|
||||
}
|
||||
|
||||
/**
|
||||
* Groups all targets into two groups: Targets equal to
|
||||
* the given label and the others
|
||||
*
|
||||
* $targets is not passed by reference nor contains objects so this method
|
||||
* changes will not affect the caller $targets array.
|
||||
*
|
||||
* @param array $targets
|
||||
* @param mixed $label
|
||||
* @return array Binarized targets and target's labels
|
||||
*/
|
||||
private function binarizeTargets($label)
|
||||
private function binarizeTargets($targets, $label)
|
||||
{
|
||||
$targets = [];
|
||||
|
||||
foreach ($this->targets as $target) {
|
||||
$targets[] = $target == $label ? $label : "not_$label";
|
||||
$notLabel = "not_$label";
|
||||
foreach ($targets as $key => $target) {
|
||||
$targets[$key] = $target == $label ? $label : $notLabel;
|
||||
}
|
||||
|
||||
return $targets;
|
||||
$labels = array($label, $notLabel);
|
||||
return array($targets, $labels);
|
||||
}
|
||||
|
||||
|
||||
@ -94,7 +147,7 @@ trait OneVsRest
|
||||
*/
|
||||
protected function predictSample(array $sample)
|
||||
{
|
||||
if (count($this->labels) == 2) {
|
||||
if (count($this->allLabels) == 2) {
|
||||
return $this->classifiers[0]->predictSampleBinary($sample);
|
||||
}
|
||||
|
||||
@ -113,8 +166,16 @@ trait OneVsRest
|
||||
*
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
* @param array $labels
|
||||
*/
|
||||
abstract protected function trainBinary(array $samples, array $targets);
|
||||
abstract protected function trainBinary(array $samples, array $targets, array $labels);
|
||||
|
||||
/**
|
||||
* To be overwritten by OneVsRest classifiers.
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
abstract protected function resetBinary();
|
||||
|
||||
/**
|
||||
* Each classifier that make use of OvR approach should be able to
|
||||
|
@ -57,6 +57,8 @@ class ConjugateGradient extends GD
|
||||
}
|
||||
}
|
||||
|
||||
$this->clear();
|
||||
|
||||
return $this->theta;
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ class GD extends StochasticGD
|
||||
*
|
||||
* @var int
|
||||
*/
|
||||
protected $sampleCount;
|
||||
protected $sampleCount = null;
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
@ -49,6 +49,8 @@ class GD extends StochasticGD
|
||||
}
|
||||
}
|
||||
|
||||
$this->clear();
|
||||
|
||||
return $this->theta;
|
||||
}
|
||||
|
||||
@ -105,4 +107,15 @@ class GD extends StochasticGD
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the optimizer internal vars after the optimization process.
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
protected function clear()
|
||||
{
|
||||
$this->sampleCount = null;
|
||||
parent::clear();
|
||||
}
|
||||
}
|
||||
|
@ -16,14 +16,14 @@ class StochasticGD extends Optimizer
|
||||
*
|
||||
* @var array
|
||||
*/
|
||||
protected $samples;
|
||||
protected $samples = [];
|
||||
|
||||
/**
|
||||
* y (targets)
|
||||
*
|
||||
* @var array
|
||||
*/
|
||||
protected $targets;
|
||||
protected $targets = [];
|
||||
|
||||
/**
|
||||
* Callback function to get the gradient and cost value
|
||||
@ -31,7 +31,7 @@ class StochasticGD extends Optimizer
|
||||
*
|
||||
* @var \Closure
|
||||
*/
|
||||
protected $gradientCb;
|
||||
protected $gradientCb = null;
|
||||
|
||||
/**
|
||||
* Maximum number of iterations used to train the model
|
||||
@ -192,6 +192,8 @@ class StochasticGD extends Optimizer
|
||||
}
|
||||
}
|
||||
|
||||
$this->clear();
|
||||
|
||||
// Solution in the pocket is better than or equal to the last state
|
||||
// so, we use this solution
|
||||
return $this->theta = $bestTheta;
|
||||
@ -268,4 +270,16 @@ class StochasticGD extends Optimizer
|
||||
{
|
||||
return $this->costValues;
|
||||
}
|
||||
|
||||
/**
|
||||
* Clears the optimizer internal vars after the optimization process.
|
||||
*
|
||||
* @return void
|
||||
*/
|
||||
protected function clear()
|
||||
{
|
||||
$this->samples = [];
|
||||
$this->targets = [];
|
||||
$this->gradientCb = null;
|
||||
}
|
||||
}
|
||||
|
16
src/Phpml/IncrementalEstimator.php
Normal file
16
src/Phpml/IncrementalEstimator.php
Normal file
@ -0,0 +1,16 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml;
|
||||
|
||||
interface IncrementalEstimator
|
||||
{
|
||||
|
||||
/**
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
* @param array $labels
|
||||
*/
|
||||
public function partialTrain(array $samples, array $targets, array $labels = array());
|
||||
}
|
@ -45,7 +45,23 @@ class AdalineTest extends TestCase
|
||||
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
|
||||
$this->assertEquals(2, $classifier->predict([3.0, 9.5]));
|
||||
|
||||
return $classifier;
|
||||
// Extra partial training should lead to the same results.
|
||||
$classifier->partialTrain([[0, 1], [1, 0]], [0, 0], [0, 1, 2]);
|
||||
$this->assertEquals(0, $classifier->predict([0.5, 0.5]));
|
||||
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
|
||||
$this->assertEquals(2, $classifier->predict([3.0, 9.5]));
|
||||
|
||||
// Train should clear previous data.
|
||||
$samples = [
|
||||
[0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
|
||||
[5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
|
||||
[3, 10],[3, 10],[3, 8], [3, 9] // Third group : cluster at the top-middle
|
||||
];
|
||||
$targets = [2, 2, 2, 2, 0, 0, 0, 0, 1, 1, 1, 1];
|
||||
$classifier->train($samples, $targets);
|
||||
$this->assertEquals(2, $classifier->predict([0.5, 0.5]));
|
||||
$this->assertEquals(0, $classifier->predict([6.0, 5.0]));
|
||||
$this->assertEquals(1, $classifier->predict([3.0, 9.5]));
|
||||
}
|
||||
|
||||
public function testSaveAndRestore()
|
||||
|
@ -48,7 +48,23 @@ class PerceptronTest extends TestCase
|
||||
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
|
||||
$this->assertEquals(2, $classifier->predict([3.0, 9.5]));
|
||||
|
||||
return $classifier;
|
||||
// Extra partial training should lead to the same results.
|
||||
$classifier->partialTrain([[0, 1], [1, 0]], [0, 0], [0, 1, 2]);
|
||||
$this->assertEquals(0, $classifier->predict([0.5, 0.5]));
|
||||
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
|
||||
$this->assertEquals(2, $classifier->predict([3.0, 9.5]));
|
||||
|
||||
// Train should clear previous data.
|
||||
$samples = [
|
||||
[0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
|
||||
[5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
|
||||
[3, 10],[3, 10],[3, 8], [3, 9] // Third group : cluster at the top-middle
|
||||
];
|
||||
$targets = [2, 2, 2, 2, 0, 0, 0, 0, 1, 1, 1, 1];
|
||||
$classifier->train($samples, $targets);
|
||||
$this->assertEquals(2, $classifier->predict([0.5, 0.5]));
|
||||
$this->assertEquals(0, $classifier->predict([6.0, 5.0]));
|
||||
$this->assertEquals(1, $classifier->predict([3.0, 9.5]));
|
||||
}
|
||||
|
||||
public function testSaveAndRestore()
|
||||
|
Loading…
Reference in New Issue
Block a user