From c4f58f7f6f35450bf276ab72e19550713b7f2ac2 Mon Sep 17 00:00:00 2001
From: Yuji Uchiyama
Date: Tue, 5 Dec 2017 20:03:55 +0900
Subject: [PATCH] Fix logistic regression implementation (#169)

* Fix target value of LogisticRegression
* Fix probability calculation in LogisticRegression
* Change the default cost function to log-likelihood
* Remove redundant round function
* Fix for coding standard
---
 .../Linear/LogisticRegression.php     |  21 ++--
 .../Linear/LogisticRegressionTest.php | 106 ++++++++++++++++++
 2 files changed, 118 insertions(+), 9 deletions(-)
 create mode 100644 tests/Phpml/Classification/Linear/LogisticRegressionTest.php

diff --git a/src/Phpml/Classification/Linear/LogisticRegression.php b/src/Phpml/Classification/Linear/LogisticRegression.php
index 6b8cdd5..3818161 100644
--- a/src/Phpml/Classification/Linear/LogisticRegression.php
+++ b/src/Phpml/Classification/Linear/LogisticRegression.php
@@ -32,7 +32,7 @@ class LogisticRegression extends Adaline
      *
      * @var string
      */
-    protected $costFunction = 'sse';
+    protected $costFunction = 'log';

     /**
      * Regularization term: only 'L2' is supported
@@ -67,7 +67,7 @@ class LogisticRegression extends Adaline
         int $maxIterations = 500,
         bool $normalizeInputs = true,
         int $trainingType = self::CONJUGATE_GRAD_TRAINING,
-        string $cost = 'sse',
+        string $cost = 'log',
         string $penalty = 'L2'
     ) {
         $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
@@ -190,6 +190,8 @@ class LogisticRegression extends Adaline
                        $hX = 1e-10;
                    }

+                   $y = $y < 0 ? 0 : 1;
+
                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

@@ -213,6 +215,8 @@ class LogisticRegression extends Adaline
                    $this->weights = $weights;
                    $hX = $this->output($sample);

+                   $y = $y < 0 ? 0 : 1;
+
                    $error = ($y - $hX) ** 2;
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

@@ -243,7 +247,7 @@ class LogisticRegression extends Adaline
     {
         $output = $this->output($sample);

-        if (round($output) > 0.5) {
+        if ($output > 0.5) {
             return 1;
         }

@@ -260,14 +264,13 @@ class LogisticRegression extends Adaline
      */
     protected function predictProbability(array $sample, $label): float
     {
-        $predicted = $this->predictSampleBinary($sample);
+        $sample = $this->checkNormalizedSample($sample);
+        $probability = $this->output($sample);

-        if ((string) $predicted == (string) $label) {
-            $sample = $this->checkNormalizedSample($sample);
-
-            return (float) abs($this->output($sample) - 0.5);
+        if (array_search($label, $this->labels, true) > 0) {
+            return $probability;
         }

-        return 0.0;
+        return 1 - $probability;
     }
 }
diff --git a/tests/Phpml/Classification/Linear/LogisticRegressionTest.php b/tests/Phpml/Classification/Linear/LogisticRegressionTest.php
new file mode 100644
index 0000000..85fc159
--- /dev/null
+++ b/tests/Phpml/Classification/Linear/LogisticRegressionTest.php
@@ -0,0 +1,106 @@
+<?php
+
+declare(strict_types=1);
+
+namespace tests\Phpml\Classification\Linear;
+
+use Phpml\Classification\Linear\LogisticRegression;
+use PHPUnit\Framework\TestCase;
+use ReflectionMethod;
+use ReflectionProperty;
+
+class LogisticRegressionTest extends TestCase
+{
+    public function testPredictSingleSample(): void
+    {
+        // AND problem
+        $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
+        $targets = [0, 0, 0, 1, 0, 1];
+        $classifier = new LogisticRegression();
+        $classifier->train($samples, $targets);
+        $this->assertEquals(0, $classifier->predict([0.1, 0.1]));
+        $this->assertEquals(1, $classifier->predict([0.9, 0.9]));
+    }
+
+    public function testPredictMultiClassSample(): void
+    {
+        // By use of One-v-Rest, LogisticRegression can perform multi-class classification
+        // The samples should be separable by lines perpendicular to the dimensions
+        $samples = [
+            [0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
+            [5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
+            [3, 10], [3, 10], [3, 8], [3, 9], // Third group : cluster at the top-middle
+        ];
+        $targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
+
+        $classifier = new LogisticRegression();
+        $classifier->train($samples, $targets);
+        $this->assertEquals(0, $classifier->predict([0.5, 0.5]));
+        $this->assertEquals(1, $classifier->predict([6.0, 5.0]));
+        $this->assertEquals(2, $classifier->predict([3.0, 9.5]));
+    }
+
+    public function testPredictProbabilitySingleSample(): void
+    {
+        $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
+        $targets = [0, 0, 0, 1, 0, 1];
+        $classifier = new LogisticRegression();
+        $classifier->train($samples, $targets);
+
+        $property = new ReflectionProperty($classifier, 'classifiers');
+        $property->setAccessible(true);
+        $predictor = $property->getValue($classifier)[0];
+        $method = new ReflectionMethod($predictor, 'predictProbability');
+        $method->setAccessible(true);
+
+        $zero = $method->invoke($predictor, [0.1, 0.1], 0);
+        $one = $method->invoke($predictor, [0.1, 0.1], 1);
+        $this->assertEquals(1, $zero + $one, null, 1e-6);
+        $this->assertTrue($zero > $one);
+
+        $zero = $method->invoke($predictor, [0.9, 0.9], 0);
+        $one = $method->invoke($predictor, [0.9, 0.9], 1);
+        $this->assertEquals(1, $zero + $one, null, 1e-6);
+        $this->assertTrue($zero < $one);
+    }
+
+    public function testPredictProbabilityMultiClassSample(): void
+    {
+        $samples = [
+            [0, 0], [0, 1], [1, 0], [1, 1],
+            [5, 5], [6, 5], [5, 6], [6, 6],
+            [3, 10], [3, 10], [3, 8], [3, 9],
+        ];
+        $targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];
+
+        $classifier = new LogisticRegression();
+        $classifier->train($samples, $targets);
+
+        $property = new ReflectionProperty($classifier, 'classifiers');
+        $property->setAccessible(true);
+
+        $predictor = $property->getValue($classifier)[0];
+        $method = new ReflectionMethod($predictor, 'predictProbability');
+        $method->setAccessible(true);
+        $zero = $method->invoke($predictor, [3.0, 9.5], 0);
+        $not_zero = $method->invoke($predictor, [3.0, 9.5], 'not_0');
+
+        $predictor = $property->getValue($classifier)[1];
+        $method = new ReflectionMethod($predictor, 'predictProbability');
+        $method->setAccessible(true);
+        $one = $method->invoke($predictor, [3.0, 9.5], 1);
+        $not_one = $method->invoke($predictor, [3.0, 9.5], 'not_1');
+
+        $predictor = $property->getValue($classifier)[2];
+        $method = new ReflectionMethod($predictor, 'predictProbability');
+        $method->setAccessible(true);
+        $two = $method->invoke($predictor, [3.0, 9.5], 2);
+        $not_two = $method->invoke($predictor, [3.0, 9.5], 'not_2');
+
+        $this->assertEquals(1, $zero + $not_zero, null, 1e-6);
+        $this->assertEquals(1, $one + $not_one, null, 1e-6);
+        $this->assertEquals(1, $two + $not_two, null, 1e-6);
+        $this->assertTrue($zero < $two);
+        $this->assertTrue($one < $two);
+    }
+}
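
Editor's note: for reference, a standalone sketch of the arithmetic this patch introduces is shown below. It is illustrative PHP only, not php-ml code; the function names (sigmoid, logLoss) and the example weights are made up for the demonstration. It maps the internal -1/+1 target to {0, 1}, clamps the sigmoid output before taking logs (mirroring the 1e-10 guard in the 'log' cost branch), and shows that the probabilities reported for the two binary labels are complementary, as the new predictProbability() returns $probability for the positive label and 1 - $probability otherwise.

<?php

// Illustrative sketch only, not part of php-ml. It reproduces the corrected
// binary logistic regression math from the patch above.

function sigmoid(float $z): float
{
    return 1.0 / (1.0 + exp(-$z));
}

// Negative log-likelihood and its gradient for one sample (bias at index 0).
function logLoss(array $weights, array $sample, int $target): array
{
    $z = $weights[0];
    foreach ($sample as $i => $value) {
        $z += $weights[$i + 1] * $value;
    }

    $hX = sigmoid($z);

    // Clamp to avoid log(0), mirroring the 1e-10 guard in the library code.
    $hX = max(1e-10, min(1 - 1e-10, $hX));

    // The binary target is stored internally as -1/+1; log-loss needs 0/1.
    $y = $target < 0 ? 0 : 1;

    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
    $gradient = $hX - $y;

    return [$error, $gradient, $hX];
}

[$error, $gradient, $p] = logLoss([0.1, 2.0, 2.0], [0.9, 0.9], 1);

// The two class probabilities sum to 1, which is what the probability tests
// above assert via $zero + $one == 1 within 1e-6.
printf("P(label=1) = %.4f, P(label=0) = %.4f\n", $p, 1 - $p);
printf("log-loss = %.4f, gradient = %.4f\n", $error, $gradient);

Because the hypothesis is a sigmoid, the derivative of this cost with respect to the weighted sum simplifies to $hX - $y, which is exactly the gradient used in the 'log' branch of the patched getCostFunction().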