Fix logistic regression implementation (#169)

* Fix target value of LogisticRegression

* Fix probability calculation in LogisticRegression

* Change the default cost function to log-likelihood

* Remove redundant round function

* Fix for coding standard
This commit is contained in:
Yuji Uchiyama 2017-12-05 20:03:55 +09:00 committed by Arkadiusz Kondas
parent 946fbbc521
commit c4f58f7f6f
2 changed files with 118 additions and 9 deletions

View File

@ -32,7 +32,7 @@ class LogisticRegression extends Adaline
* *
* @var string * @var string
*/ */
protected $costFunction = 'sse'; protected $costFunction = 'log';
/** /**
* Regularization term: only 'L2' is supported * Regularization term: only 'L2' is supported
@ -67,7 +67,7 @@ class LogisticRegression extends Adaline
int $maxIterations = 500, int $maxIterations = 500,
bool $normalizeInputs = true, bool $normalizeInputs = true,
int $trainingType = self::CONJUGATE_GRAD_TRAINING, int $trainingType = self::CONJUGATE_GRAD_TRAINING,
string $cost = 'sse', string $cost = 'log',
string $penalty = 'L2' string $penalty = 'L2'
) { ) {
$trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING); $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
@ -190,6 +190,8 @@ class LogisticRegression extends Adaline
$hX = 1e-10; $hX = 1e-10;
} }
$y = $y < 0 ? 0 : 1;
$error = -$y * log($hX) - (1 - $y) * log(1 - $hX); $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
$gradient = $hX - $y; $gradient = $hX - $y;
@ -213,6 +215,8 @@ class LogisticRegression extends Adaline
$this->weights = $weights; $this->weights = $weights;
$hX = $this->output($sample); $hX = $this->output($sample);
$y = $y < 0 ? 0 : 1;
$error = ($y - $hX) ** 2; $error = ($y - $hX) ** 2;
$gradient = -($y - $hX) * $hX * (1 - $hX); $gradient = -($y - $hX) * $hX * (1 - $hX);
@ -243,7 +247,7 @@ class LogisticRegression extends Adaline
{ {
$output = $this->output($sample); $output = $this->output($sample);
if (round($output) > 0.5) { if ($output > 0.5) {
return 1; return 1;
} }
@ -260,14 +264,13 @@ class LogisticRegression extends Adaline
*/ */
protected function predictProbability(array $sample, $label): float protected function predictProbability(array $sample, $label): float
{ {
$predicted = $this->predictSampleBinary($sample);
if ((string) $predicted == (string) $label) {
$sample = $this->checkNormalizedSample($sample); $sample = $this->checkNormalizedSample($sample);
$probability = $this->output($sample);
return (float) abs($this->output($sample) - 0.5); if (array_search($label, $this->labels, true) > 0) {
return $probability;
} }
return 0.0; return 1 - $probability;
} }
} }

View File

@ -0,0 +1,106 @@
<?php
declare(strict_types=1);
namespace tests\Phpml\Classification\Linear;
use Phpml\Classification\Linear\LogisticRegression;
use PHPUnit\Framework\TestCase;
use ReflectionMethod;
use ReflectionProperty;
class LogisticRegressionTest extends TestCase
{
public function testPredictSingleSample(): void
{
// AND problem
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
$classifier = new LogisticRegression();
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}
public function testPredictMultiClassSample(): void
{
// By use of One-v-Rest, Perceptron can perform multi-class classification
// The samples should be separable by lines perpendicular to the dimensions
$samples = [
[0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
[5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
[3, 10], [3, 10], [3, 8], [3, 9], // Third group : cluster at the top-middle
];
$targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
$classifier = new LogisticRegression();
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.5, 0.5]));
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
$this->assertEquals(2, $classifier->predict([3.0, 9.5]));
}
public function testPredictProbabilitySingleSample(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
$classifier = new LogisticRegression();
$classifier->train($samples, $targets);
$property = new ReflectionProperty($classifier, 'classifiers');
$property->setAccessible(true);
$predictor = $property->getValue($classifier)[0];
$method = new ReflectionMethod($predictor, 'predictProbability');
$method->setAccessible(true);
$zero = $method->invoke($predictor, [0.1, 0.1], 0);
$one = $method->invoke($predictor, [0.1, 0.1], 1);
$this->assertEquals(1, $zero + $one, null, 1e-6);
$this->assertTrue($zero > $one);
$zero = $method->invoke($predictor, [0.9, 0.9], 0);
$one = $method->invoke($predictor, [0.9, 0.9], 1);
$this->assertEquals(1, $zero + $one, null, 1e-6);
$this->assertTrue($zero < $one);
}
public function testPredictProbabilityMultiClassSample(): void
{
$samples = [
[0, 0], [0, 1], [1, 0], [1, 1],
[5, 5], [6, 5], [5, 6], [6, 6],
[3, 10], [3, 10], [3, 8], [3, 9],
];
$targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
$classifier = new LogisticRegression();
$classifier->train($samples, $targets);
$property = new ReflectionProperty($classifier, 'classifiers');
$property->setAccessible(true);
$predictor = $property->getValue($classifier)[0];
$method = new ReflectionMethod($predictor, 'predictProbability');
$method->setAccessible(true);
$zero = $method->invoke($predictor, [3.0, 9.5], 0);
$not_zero = $method->invoke($predictor, [3.0, 9.5], 'not_0');
$predictor = $property->getValue($classifier)[1];
$method = new ReflectionMethod($predictor, 'predictProbability');
$method->setAccessible(true);
$one = $method->invoke($predictor, [3.0, 9.5], 1);
$not_one = $method->invoke($predictor, [3.0, 9.5], 'not_1');
$predictor = $property->getValue($classifier)[2];
$method = new ReflectionMethod($predictor, 'predictProbability');
$method->setAccessible(true);
$two = $method->invoke($predictor, [3.0, 9.5], 2);
$not_two = $method->invoke($predictor, [3.0, 9.5], 'not_2');
$this->assertEquals(1, $zero + $not_zero, null, 1e-6);
$this->assertEquals(1, $one + $not_one, null, 1e-6);
$this->assertEquals(1, $two + $not_two, null, 1e-6);
$this->assertTrue($zero < $two);
$this->assertTrue($one < $two);
}
}