From af9ccfe722f26c48307776842003ec2b1659fca0 Mon Sep 17 00:00:00 2001 From: Yuji Uchiyama Date: Sat, 3 Mar 2018 19:19:58 +0900 Subject: [PATCH] Add tests for LogisticRegression (#248) --- .../Linear/LogisticRegressionTest.php | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/tests/Classification/Linear/LogisticRegressionTest.php b/tests/Classification/Linear/LogisticRegressionTest.php index f60d308..ed9b878 100644 --- a/tests/Classification/Linear/LogisticRegressionTest.php +++ b/tests/Classification/Linear/LogisticRegressionTest.php @@ -8,9 +8,49 @@ use Phpml\Classification\Linear\LogisticRegression; use PHPUnit\Framework\TestCase; use ReflectionMethod; use ReflectionProperty; +use Throwable; class LogisticRegressionTest extends TestCase { + public function testConstructorThrowWhenInvalidTrainingType(): void + { + $this->expectException(Throwable::class); + + $classifier = new LogisticRegression( + 500, + true, + -1, + 'log', + 'L2' + ); + } + + public function testConstructorThrowWhenInvalidCost(): void + { + $this->expectException(Throwable::class); + + $classifier = new LogisticRegression( + 500, + true, + LogisticRegression::CONJUGATE_GRAD_TRAINING, + 'invalid', + 'L2' + ); + } + + public function testConstructorThrowWhenInvalidPenalty(): void + { + $this->expectException(Throwable::class); + + $classifier = new LogisticRegression( + 500, + true, + LogisticRegression::CONJUGATE_GRAD_TRAINING, + 'log', + 'invalid' + ); + } + public function testPredictSingleSample(): void { // AND problem @@ -22,6 +62,76 @@ class LogisticRegressionTest extends TestCase $this->assertEquals(1, $classifier->predict([0.9, 0.9])); } + public function testPredictSingleSampleWithBatchTraining(): void + { + $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]]; + $targets = [0, 0, 0, 1, 0, 1]; + + // $maxIterations is set to 10000 as batch training needs more + // iteration to converge than CG method in general. + $classifier = new LogisticRegression( + 10000, + true, + LogisticRegression::BATCH_TRAINING, + 'log', + 'L2' + ); + $classifier->train($samples, $targets); + $this->assertEquals(0, $classifier->predict([0.1, 0.1])); + $this->assertEquals(1, $classifier->predict([0.9, 0.9])); + } + + public function testPredictSingleSampleWithOnlineTraining(): void + { + $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]]; + $targets = [0, 0, 0, 1, 0, 1]; + + // $penalty is set to empty (no penalty) because L2 penalty seems to + // prevent convergence in online training for this dataset. + $classifier = new LogisticRegression( + 10000, + true, + LogisticRegression::ONLINE_TRAINING, + 'log', + '' + ); + $classifier->train($samples, $targets); + $this->assertEquals(0, $classifier->predict([0.1, 0.1])); + $this->assertEquals(1, $classifier->predict([0.9, 0.9])); + } + + public function testPredictSingleSampleWithSSECost(): void + { + $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]]; + $targets = [0, 0, 0, 1, 0, 1]; + $classifier = new LogisticRegression( + 500, + true, + LogisticRegression::CONJUGATE_GRAD_TRAINING, + 'sse', + 'L2' + ); + $classifier->train($samples, $targets); + $this->assertEquals(0, $classifier->predict([0.1, 0.1])); + $this->assertEquals(1, $classifier->predict([0.9, 0.9])); + } + + public function testPredictSingleSampleWithoutPenalty(): void + { + $samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]]; + $targets = [0, 0, 0, 1, 0, 1]; + $classifier = new LogisticRegression( + 500, + true, + LogisticRegression::CONJUGATE_GRAD_TRAINING, + 'log', + '' + ); + $classifier->train($samples, $targets); + $this->assertEquals(0, $classifier->predict([0.1, 0.1])); + $this->assertEquals(1, $classifier->predict([0.9, 0.9])); + } + public function testPredictMultiClassSample(): void { // By use of One-v-Rest, Perceptron can perform multi-class classification