mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-06-12 13:22:21 +00:00
Add tests for LogisticRegression (#248)
This commit is contained in:
parent
9c195559df
commit
af9ccfe722
|
@ -8,9 +8,49 @@ use Phpml\Classification\Linear\LogisticRegression;
|
||||||
use PHPUnit\Framework\TestCase;
|
use PHPUnit\Framework\TestCase;
|
||||||
use ReflectionMethod;
|
use ReflectionMethod;
|
||||||
use ReflectionProperty;
|
use ReflectionProperty;
|
||||||
|
use Throwable;
|
||||||
|
|
||||||
class LogisticRegressionTest extends TestCase
|
class LogisticRegressionTest extends TestCase
|
||||||
{
|
{
|
||||||
|
public function testConstructorThrowWhenInvalidTrainingType(): void
|
||||||
|
{
|
||||||
|
$this->expectException(Throwable::class);
|
||||||
|
|
||||||
|
$classifier = new LogisticRegression(
|
||||||
|
500,
|
||||||
|
true,
|
||||||
|
-1,
|
||||||
|
'log',
|
||||||
|
'L2'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testConstructorThrowWhenInvalidCost(): void
|
||||||
|
{
|
||||||
|
$this->expectException(Throwable::class);
|
||||||
|
|
||||||
|
$classifier = new LogisticRegression(
|
||||||
|
500,
|
||||||
|
true,
|
||||||
|
LogisticRegression::CONJUGATE_GRAD_TRAINING,
|
||||||
|
'invalid',
|
||||||
|
'L2'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testConstructorThrowWhenInvalidPenalty(): void
|
||||||
|
{
|
||||||
|
$this->expectException(Throwable::class);
|
||||||
|
|
||||||
|
$classifier = new LogisticRegression(
|
||||||
|
500,
|
||||||
|
true,
|
||||||
|
LogisticRegression::CONJUGATE_GRAD_TRAINING,
|
||||||
|
'log',
|
||||||
|
'invalid'
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
public function testPredictSingleSample(): void
|
public function testPredictSingleSample(): void
|
||||||
{
|
{
|
||||||
// AND problem
|
// AND problem
|
||||||
|
@ -22,6 +62,76 @@ class LogisticRegressionTest extends TestCase
|
||||||
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
|
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public function testPredictSingleSampleWithBatchTraining(): void
|
||||||
|
{
|
||||||
|
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
|
||||||
|
$targets = [0, 0, 0, 1, 0, 1];
|
||||||
|
|
||||||
|
// $maxIterations is set to 10000 as batch training needs more
|
||||||
|
// iteration to converge than CG method in general.
|
||||||
|
$classifier = new LogisticRegression(
|
||||||
|
10000,
|
||||||
|
true,
|
||||||
|
LogisticRegression::BATCH_TRAINING,
|
||||||
|
'log',
|
||||||
|
'L2'
|
||||||
|
);
|
||||||
|
$classifier->train($samples, $targets);
|
||||||
|
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
|
||||||
|
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testPredictSingleSampleWithOnlineTraining(): void
|
||||||
|
{
|
||||||
|
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
|
||||||
|
$targets = [0, 0, 0, 1, 0, 1];
|
||||||
|
|
||||||
|
// $penalty is set to empty (no penalty) because L2 penalty seems to
|
||||||
|
// prevent convergence in online training for this dataset.
|
||||||
|
$classifier = new LogisticRegression(
|
||||||
|
10000,
|
||||||
|
true,
|
||||||
|
LogisticRegression::ONLINE_TRAINING,
|
||||||
|
'log',
|
||||||
|
''
|
||||||
|
);
|
||||||
|
$classifier->train($samples, $targets);
|
||||||
|
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
|
||||||
|
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testPredictSingleSampleWithSSECost(): void
|
||||||
|
{
|
||||||
|
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
|
||||||
|
$targets = [0, 0, 0, 1, 0, 1];
|
||||||
|
$classifier = new LogisticRegression(
|
||||||
|
500,
|
||||||
|
true,
|
||||||
|
LogisticRegression::CONJUGATE_GRAD_TRAINING,
|
||||||
|
'sse',
|
||||||
|
'L2'
|
||||||
|
);
|
||||||
|
$classifier->train($samples, $targets);
|
||||||
|
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
|
||||||
|
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testPredictSingleSampleWithoutPenalty(): void
|
||||||
|
{
|
||||||
|
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
|
||||||
|
$targets = [0, 0, 0, 1, 0, 1];
|
||||||
|
$classifier = new LogisticRegression(
|
||||||
|
500,
|
||||||
|
true,
|
||||||
|
LogisticRegression::CONJUGATE_GRAD_TRAINING,
|
||||||
|
'log',
|
||||||
|
''
|
||||||
|
);
|
||||||
|
$classifier->train($samples, $targets);
|
||||||
|
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
|
||||||
|
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
|
||||||
|
}
|
||||||
|
|
||||||
public function testPredictMultiClassSample(): void
|
public function testPredictMultiClassSample(): void
|
||||||
{
|
{
|
||||||
// By use of One-v-Rest, Perceptron can perform multi-class classification
|
// By use of One-v-Rest, Perceptron can perform multi-class classification
|
||||||
|
|
Loading…
Reference in New Issue
Block a user