2017-12-05 20:03:55 +09:00
|
|
|
<?php
|
|
|
|
|
|
|
|
declare(strict_types=1);
|
|
|
|
|
2018-01-06 21:25:47 +01:00
|
|
|
namespace Phpml\Tests\Classification\Linear;
|
2017-12-05 20:03:55 +09:00
|
|
|
|
|
|
|
use Phpml\Classification\Linear\LogisticRegression;
|
2018-03-07 07:26:36 +09:00
|
|
|
use Phpml\Exception\InvalidArgumentException;
|
2017-12-05 20:03:55 +09:00
|
|
|
use PHPUnit\Framework\TestCase;
|
|
|
|
use ReflectionMethod;
|
|
|
|
use ReflectionProperty;
|
|
|
|
|
|
|
|
/**
 * Tests for LogisticRegression: constructor argument validation, label
 * prediction with each training type / cost function / penalty setting,
 * and the per-class probabilities exposed by the internal predictors.
 */
class LogisticRegressionTest extends TestCase
{
    /**
     * Samples of the "AND problem" shared by the single-sample tests:
     * only points with both coordinates high are labelled 1.
     */
    private const AND_SAMPLES = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];

    /**
     * Labels matching AND_SAMPLES position by position.
     */
    private const AND_TARGETS = [0, 0, 0, 1, 0, 1];

    /**
     * Tolerance used when checking that complementary probabilities sum to 1.
     */
    private const PROBABILITY_DELTA = 1e-6;

    /**
     * An unknown training type (-1) must be rejected by the constructor.
     */
    public function testConstructorThrowWhenInvalidTrainingType(): void
    {
        $this->expectException(InvalidArgumentException::class);
        $this->expectExceptionMessage('Logistic regression can only be trained with '.
            'batch (gradient descent), online (stochastic gradient descent) '.
            'or conjugate batch (conjugate gradients) algorithms');

        new LogisticRegression(
            500,
            true,
            -1,
            'log',
            'L2'
        );
    }

    /**
     * An unknown cost function name must be rejected by the constructor.
     */
    public function testConstructorThrowWhenInvalidCost(): void
    {
        $this->expectException(InvalidArgumentException::class);
        $this->expectExceptionMessage("Logistic regression cost function can be one of the following: \n".
            "'log' for log-likelihood and 'sse' for sum of squared errors");

        new LogisticRegression(
            500,
            true,
            LogisticRegression::CONJUGATE_GRAD_TRAINING,
            'invalid',
            'L2'
        );
    }

    /**
     * An unknown penalty name must be rejected by the constructor.
     */
    public function testConstructorThrowWhenInvalidPenalty(): void
    {
        $this->expectException(InvalidArgumentException::class);
        $this->expectExceptionMessage('Logistic regression supports only \'L2\' regularization');

        new LogisticRegression(
            500,
            true,
            LogisticRegression::CONJUGATE_GRAD_TRAINING,
            'log',
            'invalid'
        );
    }

    /**
     * A classifier built with default constructor arguments learns the AND problem.
     */
    public function testPredictSingleSample(): void
    {
        $this->assertLearnsAndProblem(new LogisticRegression());
    }

    /**
     * Batch (gradient descent) training learns the AND problem.
     */
    public function testPredictSingleSampleWithBatchTraining(): void
    {
        // $maxIterations is set to 10000 as batch training needs more
        // iteration to converge than CG method in general.
        $classifier = new LogisticRegression(
            10000,
            true,
            LogisticRegression::BATCH_TRAINING,
            'log',
            'L2'
        );

        $this->assertLearnsAndProblem($classifier);
    }

    /**
     * Online (stochastic gradient descent) training learns the AND problem.
     */
    public function testPredictSingleSampleWithOnlineTraining(): void
    {
        // $penalty is set to empty (no penalty) because L2 penalty seems to
        // prevent convergence in online training for this dataset.
        $classifier = new LogisticRegression(
            10000,
            true,
            LogisticRegression::ONLINE_TRAINING,
            'log',
            ''
        );

        $this->assertLearnsAndProblem($classifier);
    }

    /**
     * The 'sse' cost function learns the AND problem.
     */
    public function testPredictSingleSampleWithSSECost(): void
    {
        $classifier = new LogisticRegression(
            500,
            true,
            LogisticRegression::CONJUGATE_GRAD_TRAINING,
            'sse',
            'L2'
        );

        $this->assertLearnsAndProblem($classifier);
    }

    /**
     * Training with an empty penalty string learns the AND problem.
     */
    public function testPredictSingleSampleWithoutPenalty(): void
    {
        $classifier = new LogisticRegression(
            500,
            true,
            LogisticRegression::CONJUGATE_GRAD_TRAINING,
            'log',
            ''
        );

        $this->assertLearnsAndProblem($classifier);
    }

    /**
     * Three well separated clusters are each predicted with their own label.
     */
    public function testPredictMultiClassSample(): void
    {
        // By use of One-v-Rest, the classifier can perform multi-class classification
        // The samples should be separable by lines perpendicular to the dimensions
        $samples = [
            [0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
            [5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
            [3, 10], [3, 10], [3, 8], [3, 9], // Third group : cluster at the top-middle
        ];
        $targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];

        $classifier = new LogisticRegression();
        $classifier->train($samples, $targets);

        $this->assertEquals(0, $classifier->predict([0.5, 0.5]));
        $this->assertEquals(1, $classifier->predict([6.0, 5.0]));
        $this->assertEquals(2, $classifier->predict([3.0, 9.5]));
    }

    /**
     * For the AND problem the probabilities of labels 0 and 1 sum to 1 and
     * the more likely label matches the side of the plane the sample is on.
     */
    public function testPredictProbabilitySingleSample(): void
    {
        $samples = self::AND_SAMPLES;
        $targets = self::AND_TARGETS;

        $classifier = new LogisticRegression();
        $classifier->train($samples, $targets);

        // Sample close to the origin: label 0 must be the more probable one.
        $zero = $this->predictProbability($classifier, 0, [0.1, 0.1], 0);
        $one = $this->predictProbability($classifier, 0, [0.1, 0.1], 1);
        $this->assertEqualsWithDelta(1, $zero + $one, self::PROBABILITY_DELTA);
        $this->assertGreaterThan($one, $zero);

        // Sample close to (1, 1): label 1 must be the more probable one.
        $zero = $this->predictProbability($classifier, 0, [0.9, 0.9], 0);
        $one = $this->predictProbability($classifier, 0, [0.9, 0.9], 1);
        $this->assertEqualsWithDelta(1, $zero + $one, self::PROBABILITY_DELTA);
        $this->assertLessThan($one, $zero);
    }

    /**
     * With three classes, each internal predictor is queried for its own
     * label and for the complementary 'not_<label>' label; each pair must
     * sum to 1, and label 2 must be the most probable for a sample taken
     * from the third cluster.
     */
    public function testPredictProbabilityMultiClassSample(): void
    {
        $samples = [
            [0, 0], [0, 1], [1, 0], [1, 1],
            [5, 5], [6, 5], [5, 6], [6, 6],
            [3, 10], [3, 10], [3, 8], [3, 9],
        ];
        $targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];

        $classifier = new LogisticRegression();
        $classifier->train($samples, $targets);

        $sample = [3.0, 9.5];

        $zero = $this->predictProbability($classifier, 0, $sample, 0);
        $notZero = $this->predictProbability($classifier, 0, $sample, 'not_0');

        $one = $this->predictProbability($classifier, 1, $sample, 1);
        $notOne = $this->predictProbability($classifier, 1, $sample, 'not_1');

        $two = $this->predictProbability($classifier, 2, $sample, 2);
        $notTwo = $this->predictProbability($classifier, 2, $sample, 'not_2');

        $this->assertEqualsWithDelta(1, $zero + $notZero, self::PROBABILITY_DELTA);
        $this->assertEqualsWithDelta(1, $one + $notOne, self::PROBABILITY_DELTA);
        $this->assertEqualsWithDelta(1, $two + $notTwo, self::PROBABILITY_DELTA);
        $this->assertLessThan($two, $zero);
        $this->assertLessThan($two, $one);
    }

    /**
     * Trains the given classifier on the AND problem and asserts that a
     * sample near the origin is labelled 0 and a sample near (1, 1) is
     * labelled 1.
     */
    private function assertLearnsAndProblem(LogisticRegression $classifier): void
    {
        // Copied into locals so train() receives plain variables, exactly
        // as the individual tests passed them before the extraction.
        $samples = self::AND_SAMPLES;
        $targets = self::AND_TARGETS;

        $classifier->train($samples, $targets);

        $this->assertEquals(0, $classifier->predict([0.1, 0.1]));
        $this->assertEquals(1, $classifier->predict([0.9, 0.9]));
    }

    /**
     * Calls the non-public predictProbability() method of one of the
     * predictors stored in the classifier's non-public 'classifiers'
     * property, using reflection.
     *
     * @param LogisticRegression $classifier     trained classifier to inspect
     * @param int                $predictorIndex key of the predictor inside 'classifiers'
     * @param array              $sample         feature vector to evaluate
     * @param int|string         $label          label whose probability is requested
     *
     * @return mixed whatever predictProbability() returns; callers treat it as a number
     */
    private function predictProbability(LogisticRegression $classifier, int $predictorIndex, array $sample, $label)
    {
        $property = new ReflectionProperty($classifier, 'classifiers');
        $property->setAccessible(true);
        $predictor = $property->getValue($classifier)[$predictorIndex];

        $method = new ReflectionMethod($predictor, 'predictProbability');
        $method->setAccessible(true);

        return $method->invoke($predictor, $sample, $label);
    }
}
|