Add tests for LogisticRegression (#248)

This commit is contained in:
Yuji Uchiyama 2018-03-03 19:19:58 +09:00 committed by Arkadiusz Kondas
parent 9c195559df
commit af9ccfe722
1 changed files with 110 additions and 0 deletions

View File

@ -8,9 +8,49 @@ use Phpml\Classification\Linear\LogisticRegression;
use PHPUnit\Framework\TestCase;
use ReflectionMethod;
use ReflectionProperty;
use Throwable;
class LogisticRegressionTest extends TestCase
{
public function testConstructorThrowWhenInvalidTrainingType(): void
{
$this->expectException(Throwable::class);
$classifier = new LogisticRegression(
500,
true,
-1,
'log',
'L2'
);
}
public function testConstructorThrowWhenInvalidCost(): void
{
$this->expectException(Throwable::class);
$classifier = new LogisticRegression(
500,
true,
LogisticRegression::CONJUGATE_GRAD_TRAINING,
'invalid',
'L2'
);
}
public function testConstructorThrowWhenInvalidPenalty(): void
{
$this->expectException(Throwable::class);
$classifier = new LogisticRegression(
500,
true,
LogisticRegression::CONJUGATE_GRAD_TRAINING,
'log',
'invalid'
);
}
public function testPredictSingleSample(): void
{
// AND problem
@ -22,6 +62,76 @@ class LogisticRegressionTest extends TestCase
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}
public function testPredictSingleSampleWithBatchTraining(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
// $maxIterations is set to 10000 as batch training needs more
// iteration to converge than CG method in general.
$classifier = new LogisticRegression(
10000,
true,
LogisticRegression::BATCH_TRAINING,
'log',
'L2'
);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}
public function testPredictSingleSampleWithOnlineTraining(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
// $penalty is set to empty (no penalty) because L2 penalty seems to
// prevent convergence in online training for this dataset.
$classifier = new LogisticRegression(
10000,
true,
LogisticRegression::ONLINE_TRAINING,
'log',
''
);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}
public function testPredictSingleSampleWithSSECost(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
$classifier = new LogisticRegression(
500,
true,
LogisticRegression::CONJUGATE_GRAD_TRAINING,
'sse',
'L2'
);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}
public function testPredictSingleSampleWithoutPenalty(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
$classifier = new LogisticRegression(
500,
true,
LogisticRegression::CONJUGATE_GRAD_TRAINING,
'log',
''
);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}
public function testPredictMultiClassSample(): void
{
// By use of One-v-Rest, Perceptron can perform multi-class classification