php-ml/tests/SupportVectorMachine/SupportVectorMachineTest.php

207 lines
5.9 KiB
PHP
Raw Normal View History

2016-05-05 23:29:11 +02:00
<?php
2016-11-20 22:53:17 +01:00
declare(strict_types=1);
2016-05-05 23:29:11 +02:00
namespace Phpml\Tests\SupportVectorMachine;
2016-05-05 23:29:11 +02:00
use Phpml\Exception\InvalidArgumentException;
use Phpml\Exception\InvalidOperationException;
2018-01-26 22:07:22 +01:00
use Phpml\Exception\LibsvmCommandException;
2016-05-05 23:29:11 +02:00
use Phpml\SupportVectorMachine\Kernel;
use Phpml\SupportVectorMachine\SupportVectorMachine;
use Phpml\SupportVectorMachine\Type;
2017-02-03 12:58:25 +01:00
use PHPUnit\Framework\TestCase;
2016-05-05 23:29:11 +02:00
2017-02-03 12:58:25 +01:00
/**
 * Tests for the libsvm-backed SupportVectorMachine wrapper.
 *
 * Several assertions compare against libsvm's own serialized model text and
 * error messages, so those literals must stay byte-for-byte as libsvm emits them.
 */
class SupportVectorMachineTest extends TestCase
{
    /**
     * Two separable clusters: points with y > x are labelled 'a', points with y < x are 'b'.
     */
    private const BINARY_SAMPLES = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];

    private const BINARY_LABELS = ['a', 'a', 'a', 'b', 'b', 'b'];

    /**
     * Query points for the binary tests: [3, 2] and [4, -5] fall on the 'b' side
     * of y = x, [2, 3] on the 'a' side; [4, -5] is much farther from it than [3, 2].
     */
    private const BINARY_QUERIES = [[3, 2], [2, 3], [4, -5]];

    public function testTrainCSVCModelWithLinearKernel(): void
    {
        // The continuation lines of this literal must stay at column 0:
        // the whole string is compared against libsvm's model output.
        $model =
            'svm_type c_svc
kernel_type linear
nr_class 2
total_sv 2
rho 0
label 0 1
nr_sv 1 1
SV
0.25 1:2 2:4
-0.25 1:4 2:2
';

        $svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0);
        $svm->train(self::BINARY_SAMPLES, self::BINARY_LABELS);

        self::assertSame($model, $svm->getModel());
    }

    public function testTrainCSVCModelWithProbabilityEstimate(): void
    {
        $svm = $this->createProbabilisticSvm();
        $svm->train(self::BINARY_SAMPLES, self::BINARY_LABELS);

        // A model trained with probability estimates carries probA/probB lines.
        $model = $svm->getModel();
        self::assertStringContainsString(PHP_EOL.'probA ', $model);
        self::assertStringContainsString(PHP_EOL.'probB ', $model);
    }

    public function testPredictSampleWithLinearKernel(): void
    {
        $svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0);
        $svm->train(self::BINARY_SAMPLES, self::BINARY_LABELS);

        $predictions = $svm->predict(self::BINARY_QUERIES);

        self::assertSame('b', $predictions[0]);
        self::assertSame('a', $predictions[1]);
        self::assertSame('b', $predictions[2]);
    }

    public function testPredictSampleFromMultipleClassWithRbfKernel(): void
    {
        $samples = [
            [1, 3], [1, 4], [1, 4],
            [3, 1], [4, 1], [4, 2],
            [-3, -1], [-4, -1], [-4, -2],
        ];
        $labels = [
            'a', 'a', 'a',
            'b', 'b', 'b',
            'c', 'c', 'c',
        ];

        $svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF, 100.0);
        $svm->train($samples, $labels);

        $predictions = $svm->predict([
            [1, 5],
            [4, 3],
            [-4, -3],
        ]);

        self::assertSame('a', $predictions[0]);
        self::assertSame('b', $predictions[1]);
        self::assertSame('c', $predictions[2]);
    }

    public function testPredictProbability(): void
    {
        $svm = $this->createProbabilisticSvm();
        $svm->train(self::BINARY_SAMPLES, self::BINARY_LABELS);

        $predictions = $svm->predictProbability(self::BINARY_QUERIES);

        // assertLessThan($limit, $actual) checks $actual < $limit and reports both values on failure.
        self::assertLessThan($predictions[0]['b'], $predictions[0]['a']);
        self::assertGreaterThan($predictions[1]['b'], $predictions[1]['a']);
        self::assertLessThan($predictions[2]['b'], $predictions[2]['a']);
        // [4, -5] is farther from the decision boundary than [3, 2], so 'b' must be more likely there.
        self::assertLessThan($predictions[2]['b'], $predictions[0]['b']);
    }

    public function testThrowExceptionWhenVarPathIsNotWritable(): void
    {
        $this->expectException(InvalidArgumentException::class);
        $this->expectExceptionMessage('is not writable');

        $svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF);
        $svm->setVarPath('var-path');
    }

    public function testThrowExceptionWhenBinPathDoesNotExist(): void
    {
        $this->expectException(InvalidArgumentException::class);
        $this->expectExceptionMessage('does not exist');

        $svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF);
        $svm->setBinPath('bin-path');
    }

    public function testThrowExceptionWhenFileIsNotFoundInBinPath(): void
    {
        $this->expectException(InvalidArgumentException::class);
        $this->expectExceptionMessage('not found');

        $svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF);
        $svm->setBinPath('var');
    }

    public function testThrowExceptionWhenLibsvmFailsDuringTrain(): void
    {
        $this->expectException(LibsvmCommandException::class);
        // Message text comes from libsvm itself; 99 is deliberately not a valid svm type.
        $this->expectExceptionMessage('ERROR: unknown svm type');

        $svm = new SupportVectorMachine(99, Kernel::RBF);
        $svm->train([], []);
    }

    public function testThrowExceptionWhenLibsvmFailsDuringPredict(): void
    {
        $this->expectException(LibsvmCommandException::class);
        $this->expectExceptionMessage('can\'t open model file');

        // Predicting before train() means there is no model for libsvm to open.
        $svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF);
        $svm->predict([1]);
    }

    public function testThrowExceptionWhenPredictProbabilityCalledWithoutProperModel(): void
    {
        $this->expectException(InvalidOperationException::class);
        // Spelling must match the library's exception message exactly.
        $this->expectExceptionMessage('Model does not support probabiliy estimates');

        $svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0);
        $svm->train(self::BINARY_SAMPLES, self::BINARY_LABELS);
        $svm->predictProbability(self::BINARY_QUERIES);
    }

    /**
     * Builds the linear C-SVC used by the probability tests, with probability
     * estimates enabled so the trained model contains probA/probB.
     */
    private function createProbabilisticSvm(): SupportVectorMachine
    {
        return new SupportVectorMachine(
            Type::C_SVC,
            Kernel::LINEAR,
            100.0,
            0.5,
            3,
            null,
            0.0,
            0.1,
            0.01,
            100,
            true,
            true // NOTE(review): presumably the probability-estimates flag — confirm against the constructor
        );
    }
}