mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-09 00:20:53 +00:00
Fix logistic regression implementation (#169)
* Fix target value of LogisticRegression
* Fix probability calculation in LogisticRegression
* Change the default cost function to log-likelihood
* Remove redundant round function
* Fix for coding standard
This commit is contained in:
parent
946fbbc521
commit
c4f58f7f6f
@ -32,7 +32,7 @@ class LogisticRegression extends Adaline
|
||||
*
|
||||
* @var string
|
||||
*/
|
||||
protected $costFunction = 'sse';
|
||||
protected $costFunction = 'log';
|
||||
|
||||
/**
|
||||
* Regularization term: only 'L2' is supported
|
||||
@ -67,7 +67,7 @@ class LogisticRegression extends Adaline
|
||||
int $maxIterations = 500,
|
||||
bool $normalizeInputs = true,
|
||||
int $trainingType = self::CONJUGATE_GRAD_TRAINING,
|
||||
string $cost = 'sse',
|
||||
string $cost = 'log',
|
||||
string $penalty = 'L2'
|
||||
) {
|
||||
$trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
|
||||
@ -190,6 +190,8 @@ class LogisticRegression extends Adaline
|
||||
$hX = 1e-10;
|
||||
}
|
||||
|
||||
$y = $y < 0 ? 0 : 1;
|
||||
|
||||
$error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
|
||||
$gradient = $hX - $y;
|
||||
|
||||
@ -213,6 +215,8 @@ class LogisticRegression extends Adaline
|
||||
$this->weights = $weights;
|
||||
$hX = $this->output($sample);
|
||||
|
||||
$y = $y < 0 ? 0 : 1;
|
||||
|
||||
$error = ($y - $hX) ** 2;
|
||||
$gradient = -($y - $hX) * $hX * (1 - $hX);
|
||||
|
||||
@ -243,7 +247,7 @@ class LogisticRegression extends Adaline
|
||||
{
|
||||
$output = $this->output($sample);
|
||||
|
||||
if (round($output) > 0.5) {
|
||||
if ($output > 0.5) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
@ -260,14 +264,13 @@ class LogisticRegression extends Adaline
|
||||
*/
|
||||
protected function predictProbability(array $sample, $label): float
|
||||
{
|
||||
$predicted = $this->predictSampleBinary($sample);
|
||||
$sample = $this->checkNormalizedSample($sample);
|
||||
$probability = $this->output($sample);
|
||||
|
||||
if ((string) $predicted == (string) $label) {
|
||||
$sample = $this->checkNormalizedSample($sample);
|
||||
|
||||
return (float) abs($this->output($sample) - 0.5);
|
||||
if (array_search($label, $this->labels, true) > 0) {
|
||||
return $probability;
|
||||
}
|
||||
|
||||
return 0.0;
|
||||
return 1 - $probability;
|
||||
}
|
||||
}
|
||||
|
106
tests/Phpml/Classification/Linear/LogisticRegressionTest.php
Normal file
106
tests/Phpml/Classification/Linear/LogisticRegressionTest.php
Normal file
@ -0,0 +1,106 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace tests\Phpml\Classification\Linear;
|
||||
|
||||
use Phpml\Classification\Linear\LogisticRegression;
|
||||
use PHPUnit\Framework\TestCase;
|
||||
use ReflectionMethod;
|
||||
use ReflectionProperty;
|
||||
|
||||
/**
 * Tests for LogisticRegression: label prediction and the per-label
 * probabilities reported by its internal classifiers, for both a binary
 * and a three-class data set.
 */
class LogisticRegressionTest extends TestCase
{
    /**
     * Trains on a small binary data set and checks the predicted label of
     * one point near each corner of the unit square.
     */
    public function testPredictSingleSample(): void
    {
        // AND problem
        $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
        $targets = [0, 0, 0, 1, 0, 1];

        $classifier = new LogisticRegression();
        $classifier->train($samples, $targets);

        $this->assertEquals(0, $classifier->predict([0.1, 0.1]));
        $this->assertEquals(1, $classifier->predict([0.9, 0.9]));
    }

    /**
     * Trains on three clusters and checks that a point taken from each
     * cluster is assigned that cluster's label.
     */
    public function testPredictMultiClassSample(): void
    {
        // By use of One-v-Rest, LogisticRegression can perform multi-class classification
        // The samples should be separable by lines perpendicular to the dimensions
        $samples = [
            [0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
            [5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
            [3, 10], [3, 10], [3, 8], [3, 9], // Third group : cluster at the top-middle
        ];
        $targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];

        $classifier = new LogisticRegression();
        $classifier->train($samples, $targets);

        $this->assertEquals(0, $classifier->predict([0.5, 0.5]));
        $this->assertEquals(1, $classifier->predict([6.0, 5.0]));
        $this->assertEquals(2, $classifier->predict([3.0, 9.5]));
    }

    /**
     * For the binary data set, the probabilities reported for labels 0 and 1
     * must sum to 1 and must favour the label of the nearer training corner.
     */
    public function testPredictProbabilitySingleSample(): void
    {
        $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
        $targets = [0, 0, 0, 1, 0, 1];

        $classifier = new LogisticRegression();
        $classifier->train($samples, $targets);

        $zero = $this->invokePredictProbability($classifier, 0, [0.1, 0.1], 0);
        $one = $this->invokePredictProbability($classifier, 0, [0.1, 0.1], 1);
        // The message argument must be a string: passing null breaks under
        // strict_types once the assertion declares `string $message`.
        $this->assertEquals(1, $zero + $one, '', 1e-6);
        $this->assertGreaterThan($one, $zero);

        $zero = $this->invokePredictProbability($classifier, 0, [0.9, 0.9], 0);
        $one = $this->invokePredictProbability($classifier, 0, [0.9, 0.9], 1);
        $this->assertEquals(1, $zero + $one, '', 1e-6);
        $this->assertLessThan($one, $zero);
    }

    /**
     * For the three-class data set, each internal classifier must report
     * complementary probabilities for its label and the matching "not_<label>"
     * label, and the classifier for label 2 must score a sample taken from the
     * third cluster higher than the other two do.
     */
    public function testPredictProbabilityMultiClassSample(): void
    {
        $samples = [
            [0, 0], [0, 1], [1, 0], [1, 1],
            [5, 5], [6, 5], [5, 6], [6, 6],
            [3, 10], [3, 10], [3, 8], [3, 9],
        ];
        $targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];

        $classifier = new LogisticRegression();
        $classifier->train($samples, $targets);

        $sample = [3.0, 9.5];

        $zero = $this->invokePredictProbability($classifier, 0, $sample, 0);
        $notZero = $this->invokePredictProbability($classifier, 0, $sample, 'not_0');

        $one = $this->invokePredictProbability($classifier, 1, $sample, 1);
        $notOne = $this->invokePredictProbability($classifier, 1, $sample, 'not_1');

        $two = $this->invokePredictProbability($classifier, 2, $sample, 2);
        $notTwo = $this->invokePredictProbability($classifier, 2, $sample, 'not_2');

        $this->assertEquals(1, $zero + $notZero, '', 1e-6);
        $this->assertEquals(1, $one + $notOne, '', 1e-6);
        $this->assertEquals(1, $two + $notTwo, '', 1e-6);
        $this->assertLessThan($two, $zero);
        $this->assertLessThan($two, $one);
    }

    /**
     * Calls the non-public predictProbability() method of one of the
     * classifiers held in the trained model's `classifiers` property.
     *
     * Reflection is used because neither the property nor the method is
     * reachable through the public API.
     *
     * @param LogisticRegression $classifier trained model to inspect
     * @param int                $index      position of the internal classifier in `classifiers`
     * @param array              $sample     feature vector to evaluate
     * @param mixed              $label      label whose probability is requested
     *
     * @return float value returned by predictProbability() for $label
     */
    private function invokePredictProbability(
        LogisticRegression $classifier,
        int $index,
        array $sample,
        $label
    ): float {
        $property = new ReflectionProperty($classifier, 'classifiers');
        $property->setAccessible(true);
        $predictor = $property->getValue($classifier)[$index];

        $method = new ReflectionMethod($predictor, 'predictProbability');
        $method->setAccessible(true);

        return $method->invoke($predictor, $sample, $label);
    }
}
|
Loading…
Reference in New Issue
Block a user