2017-02-16 22:23:55 +00:00
|
|
|
<?php
|
|
|
|
|
|
|
|
declare(strict_types=1);
|
|
|
|
|
|
|
|
namespace Phpml\Classification\Linear;
|
|
|
|
|
2017-11-22 21:16:10 +00:00
|
|
|
use Closure;
|
|
|
|
use Exception;
|
2017-11-06 07:56:37 +00:00
|
|
|
use Phpml\Classification\Classifier;
|
2017-03-05 08:43:19 +00:00
|
|
|
use Phpml\Helper\OneVsRest;
|
2017-03-27 21:46:53 +00:00
|
|
|
use Phpml\Helper\Optimizer\GD;
|
2017-11-06 07:56:37 +00:00
|
|
|
use Phpml\Helper\Optimizer\StochasticGD;
|
|
|
|
use Phpml\Helper\Predictable;
|
2017-04-19 20:26:31 +00:00
|
|
|
use Phpml\IncrementalEstimator;
|
2017-11-06 07:56:37 +00:00
|
|
|
use Phpml\Preprocessing\Normalizer;
|
2017-02-16 22:23:55 +00:00
|
|
|
|
2017-04-19 20:26:31 +00:00
|
|
|
class Perceptron implements Classifier, IncrementalEstimator
{
    use Predictable, OneVsRest;

    /**
     * Optimizer used to fit the weights. It is created lazily by
     * runGradientDescent() on the first training run, reused by later runs,
     * and cleared by resetBinary().
     *
     * @var \Phpml\Helper\Optimizer\Optimizer|GD|StochasticGD|null
     */
    protected $optimizer;

    /**
     * Maps the internal binary classes (1 and -1) back to the original labels.
     *
     * @var array
     */
    protected $labels = [];

    /**
     * Number of features per sample, taken from the first training sample.
     *
     * @var int
     */
    protected $featureCount = 0;

    /**
     * Model weights: index 0 is the bias term, index i (i >= 1) is the
     * weight applied to feature i - 1 of a sample.
     *
     * @var array
     */
    protected $weights = [];

    /**
     * @var float
     */
    protected $learningRate;

    /**
     * @var int
     */
    protected $maxIterations;

    /**
     * Standard-score normalizer applied to samples before training and
     * prediction, or null when input normalization is disabled.
     *
     * @var Normalizer|null
     */
    protected $normalizer;

    /**
     * @var bool
     */
    protected $enableEarlyStop = true;

    /**
     * Cost values reported by the optimizer during the last training run.
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Initialize a perceptron classifier with the given learning rate and maximum
     * number of iterations used while training the perceptron.
     *
     * @param float $learningRate    Value between 0.0 (exclusive) and 1.0 (inclusive)
     * @param int   $maxIterations   Must be at least 1
     * @param bool  $normalizeInputs When true, samples are standardized with
     *                               Normalizer::NORM_STD before training and prediction
     *
     * @throws \Exception when $learningRate or $maxIterations is out of range
     */
    public function __construct(float $learningRate = 0.001, int $maxIterations = 1000, bool $normalizeInputs = true)
    {
        if ($learningRate <= 0.0 || $learningRate > 1.0) {
            throw new Exception('Learning rate should be a float value between 0.0(exclusive) and 1.0(inclusive)');
        }

        if ($maxIterations <= 0) {
            throw new Exception('Maximum number of iterations must be an integer greater than 0');
        }

        if ($normalizeInputs) {
            $this->normalizer = new Normalizer(Normalizer::NORM_STD);
        }

        $this->learningRate = $learningRate;
        $this->maxIterations = $maxIterations;
    }

    /**
     * Trains the model on an additional batch of samples.
     *
     * Delegates to trainByLabel(), which is not defined in this class
     * (presumably supplied by the OneVsRest trait — confirm there).
     *
     * @param array $samples Training samples
     * @param array $targets Target label for each sample
     * @param array $labels  Full set of labels; may be empty
     */
    public function partialTrain(array $samples, array $targets, array $labels = []): void
    {
        $this->trainByLabel($samples, $targets, $labels);
    }

    /**
     * Trains a binary perceptron on the given samples.
     *
     * @param array $samples Training samples, one feature array per sample
     * @param array $targets Target label for each sample
     * @param array $labels  The two labels: $labels[0] becomes internal class 1,
     *                       $labels[1] becomes internal class -1
     *
     * @throws \Exception when $samples is empty
     */
    public function trainBinary(array $samples, array $targets, array $labels): void
    {
        // Without at least one sample there is no feature count to derive
        // (and the normalizer would have nothing to fit on).
        if ($samples === []) {
            throw new Exception('Samples array must not be empty');
        }

        if ($this->normalizer !== null) {
            $this->normalizer->transform($samples);
        }

        // Set all target values to either -1 or 1
        $this->labels = [
            1 => $labels[0],
            -1 => $labels[1],
        ];

        // Compare as exact strings: loose == would treat distinct numeric-string
        // labels such as "01" and "1" (or "1e1" and "10") as the same label.
        $positiveLabel = (string) $this->labels[1];
        foreach ($targets as $key => $target) {
            $targets[$key] = (string) $target === $positiveLabel ? 1 : -1;
        }

        // Set samples and feature count vars
        $this->featureCount = count($samples[0]);

        $this->runTraining($samples, $targets);
    }

    /**
     * Normally enabling early stopping for the optimization procedure may
     * help saving processing time while in some cases it may result in
     * premature convergence.<br>
     *
     * If "false" is given, the optimization procedure will always be executed
     * for $maxIterations times.
     *
     * Note: the setting is read only when the optimizer is created (see
     * runGradientDescent()), so it does not change an optimizer that already
     * exists from a previous training run.
     *
     * @return $this
     */
    public function setEarlyStop(bool $enable = true)
    {
        $this->enableEarlyStop = $enable;

        return $this;
    }

    /**
     * Returns the cost values obtained during the last training run.
     */
    public function getCostValues(): array
    {
        return $this->costValues;
    }

    /**
     * Clears all learned state (labels, optimizer, feature count, weights and
     * cost history) so the instance can be trained from scratch.
     */
    protected function resetBinary(): void
    {
        $this->labels = [];
        $this->optimizer = null;
        $this->featureCount = 0;
        $this->weights = [];
        $this->costValues = [];
    }

    /**
     * Trains the perceptron model with Stochastic Gradient Descent optimization
     * to get the correct set of weights.
     *
     * @param array $samples Training samples (already normalized if enabled)
     * @param array $targets Targets mapped to 1 / -1
     */
    protected function runTraining(array $samples, array $targets)
    {
        // The cost function is the sum of squares. The optimizer passes its
        // current weights in, so they are stored on the instance before the
        // prediction (outputClass() reads $this->weights).
        $callback = function ($weights, $sample, $target) {
            $this->weights = $weights;

            // Both the prediction and the target are -1 or 1.
            $prediction = $this->outputClass($sample);
            $gradient = $prediction - $target;
            $error = $gradient ** 2;

            return [$error, $gradient];
        };

        $this->runGradientDescent($samples, $targets, $callback);
    }

    /**
     * Executes a Gradient Descent algorithm for the given cost function.
     *
     * The optimizer is created only if none exists yet; otherwise the existing
     * one is reused, so $isBatch and the current hyper-parameters only take
     * effect on the first run after construction or resetBinary().
     *
     * @param Closure $gradientFunc Callback receiving ($weights, $sample, $target)
     *                              and returning [$error, $gradient]
     * @param bool    $isBatch      true for batch GD, false for StochasticGD
     */
    protected function runGradientDescent(array $samples, array $targets, Closure $gradientFunc, bool $isBatch = false)
    {
        $class = $isBatch ? GD::class : StochasticGD::class;

        if ($this->optimizer === null) {
            $this->optimizer = (new $class($this->featureCount))
                ->setLearningRate($this->learningRate)
                ->setMaxIterations($this->maxIterations)
                ->setChangeThreshold(1e-6)
                ->setEarlyStop($this->enableEarlyStop);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Checks if the sample should be normalized and if so, returns the
     * normalized sample; otherwise returns it unchanged.
     */
    protected function checkNormalizedSample(array $sample): array
    {
        if ($this->normalizer !== null) {
            // The normalizer works on a list of samples, passed by reference.
            $samples = [$sample];
            $this->normalizer->transform($samples);
            $sample = $samples[0];
        }

        return $sample;
    }

    /**
     * Calculates net output of the network as a float value for the given input:
     * the bias (weight 0) plus the weighted sum of the sample's features.
     *
     * @return int|float
     */
    protected function output(array $sample)
    {
        $sum = 0;
        foreach ($this->weights as $index => $w) {
            if ($index === 0) {
                $sum += $w;
            } else {
                $sum += $w * $sample[$index - 1];
            }
        }

        return $sum;
    }

    /**
     * Returns the class value (either -1 or 1) for the given input.
     * A net output of exactly 0 is classified as -1.
     */
    protected function outputClass(array $sample): int
    {
        return $this->output($sample) > 0 ? 1 : -1;
    }

    /**
     * Returns the probability of the sample of belonging to the given label.
     *
     * The probability is simply taken as the distance of the sample
     * to the decision plane (the absolute net output); it is 0.0 when the
     * sample is not predicted as $label.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);

        // Exact string comparison, consistent with how labels are matched in trainBinary().
        if ((string) $predicted === (string) $label) {
            $sample = $this->checkNormalizedSample($sample);

            return (float) abs($this->output($sample));
        }

        return 0.0;
    }

    /**
     * Predicts the original label for a single sample.
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        $sample = $this->checkNormalizedSample($sample);

        $predictedClass = $this->outputClass($sample);

        return $this->labels[$predictedClass];
    }
}
|