mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-24 15:48:24 +00:00
Fix logistic regression implementation (#169)
* Fix target value of LogisticRegression * Fix probability calculation in LogisticRegression * Change the default cost function to log-likelihood * Remove redundant round function * Fix for coding standard
This commit is contained in:
parent
946fbbc521
commit
c4f58f7f6f
@ -32,7 +32,7 @@ class LogisticRegression extends Adaline
|
|||||||
*
|
*
|
||||||
* @var string
|
* @var string
|
||||||
*/
|
*/
|
||||||
protected $costFunction = 'sse';
|
protected $costFunction = 'log';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Regularization term: only 'L2' is supported
|
* Regularization term: only 'L2' is supported
|
||||||
@ -67,7 +67,7 @@ class LogisticRegression extends Adaline
|
|||||||
int $maxIterations = 500,
|
int $maxIterations = 500,
|
||||||
bool $normalizeInputs = true,
|
bool $normalizeInputs = true,
|
||||||
int $trainingType = self::CONJUGATE_GRAD_TRAINING,
|
int $trainingType = self::CONJUGATE_GRAD_TRAINING,
|
||||||
string $cost = 'sse',
|
string $cost = 'log',
|
||||||
string $penalty = 'L2'
|
string $penalty = 'L2'
|
||||||
) {
|
) {
|
||||||
$trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
|
$trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
|
||||||
@ -190,6 +190,8 @@ class LogisticRegression extends Adaline
|
|||||||
$hX = 1e-10;
|
$hX = 1e-10;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
$y = $y < 0 ? 0 : 1;
|
||||||
|
|
||||||
$error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
|
$error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
|
||||||
$gradient = $hX - $y;
|
$gradient = $hX - $y;
|
||||||
|
|
||||||
@ -213,6 +215,8 @@ class LogisticRegression extends Adaline
|
|||||||
$this->weights = $weights;
|
$this->weights = $weights;
|
||||||
$hX = $this->output($sample);
|
$hX = $this->output($sample);
|
||||||
|
|
||||||
|
$y = $y < 0 ? 0 : 1;
|
||||||
|
|
||||||
$error = ($y - $hX) ** 2;
|
$error = ($y - $hX) ** 2;
|
||||||
$gradient = -($y - $hX) * $hX * (1 - $hX);
|
$gradient = -($y - $hX) * $hX * (1 - $hX);
|
||||||
|
|
||||||
@ -243,7 +247,7 @@ class LogisticRegression extends Adaline
|
|||||||
{
|
{
|
||||||
$output = $this->output($sample);
|
$output = $this->output($sample);
|
||||||
|
|
||||||
if (round($output) > 0.5) {
|
if ($output > 0.5) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,14 +264,13 @@ class LogisticRegression extends Adaline
|
|||||||
*/
|
*/
|
||||||
protected function predictProbability(array $sample, $label): float
|
protected function predictProbability(array $sample, $label): float
|
||||||
{
|
{
|
||||||
$predicted = $this->predictSampleBinary($sample);
|
|
||||||
|
|
||||||
if ((string) $predicted == (string) $label) {
|
|
||||||
$sample = $this->checkNormalizedSample($sample);
|
$sample = $this->checkNormalizedSample($sample);
|
||||||
|
$probability = $this->output($sample);
|
||||||
|
|
||||||
return (float) abs($this->output($sample) - 0.5);
|
if (array_search($label, $this->labels, true) > 0) {
|
||||||
|
return $probability;
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0.0;
|
return 1 - $probability;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
106
tests/Phpml/Classification/Linear/LogisticRegressionTest.php
Normal file
106
tests/Phpml/Classification/Linear/LogisticRegressionTest.php
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace tests\Phpml\Classification\Linear;
|
||||||
|
|
||||||
|
use Phpml\Classification\Linear\LogisticRegression;
|
||||||
|
use PHPUnit\Framework\TestCase;
|
||||||
|
use ReflectionMethod;
|
||||||
|
use ReflectionProperty;
|
||||||
|
|
||||||
|
class LogisticRegressionTest extends TestCase
|
||||||
|
{
|
||||||
|
public function testPredictSingleSample(): void
|
||||||
|
{
|
||||||
|
// AND problem
|
||||||
|
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
|
||||||
|
$targets = [0, 0, 0, 1, 0, 1];
|
||||||
|
$classifier = new LogisticRegression();
|
||||||
|
$classifier->train($samples, $targets);
|
||||||
|
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
|
||||||
|
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testPredictMultiClassSample(): void
|
||||||
|
{
|
||||||
|
// By use of One-v-Rest, Perceptron can perform multi-class classification
|
||||||
|
// The samples should be separable by lines perpendicular to the dimensions
|
||||||
|
$samples = [
|
||||||
|
[0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
|
||||||
|
[5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
|
||||||
|
[3, 10], [3, 10], [3, 8], [3, 9], // Third group : cluster at the top-middle
|
||||||
|
];
|
||||||
|
$targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
|
||||||
|
|
||||||
|
$classifier = new LogisticRegression();
|
||||||
|
$classifier->train($samples, $targets);
|
||||||
|
$this->assertEquals(0, $classifier->predict([0.5, 0.5]));
|
||||||
|
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
|
||||||
|
$this->assertEquals(2, $classifier->predict([3.0, 9.5]));
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testPredictProbabilitySingleSample(): void
|
||||||
|
{
|
||||||
|
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
|
||||||
|
$targets = [0, 0, 0, 1, 0, 1];
|
||||||
|
$classifier = new LogisticRegression();
|
||||||
|
$classifier->train($samples, $targets);
|
||||||
|
|
||||||
|
$property = new ReflectionProperty($classifier, 'classifiers');
|
||||||
|
$property->setAccessible(true);
|
||||||
|
$predictor = $property->getValue($classifier)[0];
|
||||||
|
$method = new ReflectionMethod($predictor, 'predictProbability');
|
||||||
|
$method->setAccessible(true);
|
||||||
|
|
||||||
|
$zero = $method->invoke($predictor, [0.1, 0.1], 0);
|
||||||
|
$one = $method->invoke($predictor, [0.1, 0.1], 1);
|
||||||
|
$this->assertEquals(1, $zero + $one, null, 1e-6);
|
||||||
|
$this->assertTrue($zero > $one);
|
||||||
|
|
||||||
|
$zero = $method->invoke($predictor, [0.9, 0.9], 0);
|
||||||
|
$one = $method->invoke($predictor, [0.9, 0.9], 1);
|
||||||
|
$this->assertEquals(1, $zero + $one, null, 1e-6);
|
||||||
|
$this->assertTrue($zero < $one);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testPredictProbabilityMultiClassSample(): void
|
||||||
|
{
|
||||||
|
$samples = [
|
||||||
|
[0, 0], [0, 1], [1, 0], [1, 1],
|
||||||
|
[5, 5], [6, 5], [5, 6], [6, 6],
|
||||||
|
[3, 10], [3, 10], [3, 8], [3, 9],
|
||||||
|
];
|
||||||
|
$targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
|
||||||
|
|
||||||
|
$classifier = new LogisticRegression();
|
||||||
|
$classifier->train($samples, $targets);
|
||||||
|
|
||||||
|
$property = new ReflectionProperty($classifier, 'classifiers');
|
||||||
|
$property->setAccessible(true);
|
||||||
|
|
||||||
|
$predictor = $property->getValue($classifier)[0];
|
||||||
|
$method = new ReflectionMethod($predictor, 'predictProbability');
|
||||||
|
$method->setAccessible(true);
|
||||||
|
$zero = $method->invoke($predictor, [3.0, 9.5], 0);
|
||||||
|
$not_zero = $method->invoke($predictor, [3.0, 9.5], 'not_0');
|
||||||
|
|
||||||
|
$predictor = $property->getValue($classifier)[1];
|
||||||
|
$method = new ReflectionMethod($predictor, 'predictProbability');
|
||||||
|
$method->setAccessible(true);
|
||||||
|
$one = $method->invoke($predictor, [3.0, 9.5], 1);
|
||||||
|
$not_one = $method->invoke($predictor, [3.0, 9.5], 'not_1');
|
||||||
|
|
||||||
|
$predictor = $property->getValue($classifier)[2];
|
||||||
|
$method = new ReflectionMethod($predictor, 'predictProbability');
|
||||||
|
$method->setAccessible(true);
|
||||||
|
$two = $method->invoke($predictor, [3.0, 9.5], 2);
|
||||||
|
$not_two = $method->invoke($predictor, [3.0, 9.5], 'not_2');
|
||||||
|
|
||||||
|
$this->assertEquals(1, $zero + $not_zero, null, 1e-6);
|
||||||
|
$this->assertEquals(1, $one + $not_one, null, 1e-6);
|
||||||
|
$this->assertEquals(1, $two + $not_two, null, 1e-6);
|
||||||
|
$this->assertTrue($zero < $two);
|
||||||
|
$this->assertTrue($one < $two);
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user