Support probability estimation in SVC (#218)

* Add test for svm model with probability estimation

* Extract buildPredictCommand method

* Fix test to use PHP_EOL

* Add predictProbability method (not completed)

* Add test for DataTransformer::predictions

* Fix SVM to use PHP_EOL

* Support probability estimation in SVM

* Add documentation

* Add InvalidOperationException class

* Throw InvalidOperationException before executing libsvm if probability estimation is not supported
This commit is contained in:
Yuji Uchiyama 2018-02-07 04:39:25 +09:00 committed by Arkadiusz Kondas
parent ed775fb232
commit ec091b5ea3
6 changed files with 268 additions and 11 deletions

View File

@ -47,3 +47,42 @@ $classifier->predict([3, 2]);
$classifier->predict([[3, 2], [1, 5]]);
// return ['b', 'a']
```
### Probability estimation
To predict probabilities you must build a classifier with `$probabilityEstimates` set to true. Example:
```
use Phpml\Classification\SVC;
use Phpml\SupportVectorMachine\Kernel;
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
$classifier = new SVC(
Kernel::LINEAR, // $kernel
1.0, // $cost
3, // $degree
null, // $gamma
0.0, // $coef0
0.001, // $tolerance
100, // $cacheSize
true, // $shrinking
true // $probabilityEstimates, set to true
);
$classifier->train($samples, $labels);
```
Then use `predictProbability` method instead of `predict`:
```
$classifier->predictProbability([3, 2]);
// return ['a' => 0.349833, 'b' => 0.650167]
$classifier->predictProbability([[3, 2], [1, 5]]);
// return [
// ['a' => 0.349833, 'b' => 0.650167],
// ['a' => 0.922664, 'b' => 0.0773364],
// ]
```

View File

@ -0,0 +1,11 @@
<?php
declare(strict_types=1);
namespace Phpml\Exception;
use Exception;
class InvalidOperationException extends Exception
{
}

View File

@ -49,6 +49,37 @@ class DataTransformer
return $results;
}
public static function probabilities(string $rawPredictions, array $labels): array
{
$numericLabels = self::numericLabels($labels);
$predictions = explode(PHP_EOL, trim($rawPredictions));
$header = array_shift($predictions);
$headerColumns = explode(' ', $header);
array_shift($headerColumns);
$columnLabels = [];
foreach ($headerColumns as $numericLabel) {
$columnLabels[] = array_search($numericLabel, $numericLabels);
}
$results = [];
foreach ($predictions as $rawResult) {
$probabilities = explode(' ', $rawResult);
array_shift($probabilities);
$result = [];
foreach ($probabilities as $i => $prob) {
$result[$columnLabels[$i]] = (float) $prob;
}
$results[] = $result;
}
return $results;
}
public static function numericLabels(array $labels): array
{
$numericLabels = [];

View File

@ -5,6 +5,7 @@ declare(strict_types=1);
namespace Phpml\SupportVectorMachine;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Exception\InvalidOperationException;
use Phpml\Exception\LibsvmCommandException;
use Phpml\Helper\Trainable;
@ -178,13 +179,61 @@ class SupportVectorMachine
* @throws LibsvmCommandException
*/
public function predict(array $samples)
{
$predictions = $this->runSvmPredict($samples, false);
if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
$predictions = DataTransformer::predictions($predictions, $this->targets);
} else {
$predictions = explode(PHP_EOL, trim($predictions));
}
if (!is_array($samples[0])) {
return $predictions[0];
}
return $predictions;
}
/**
* @return array|string
*
* @throws LibsvmCommandException
*/
public function predictProbability(array $samples)
{
if (!$this->probabilityEstimates) {
throw new InvalidOperationException('Model does not support probabiliy estimates');
}
$predictions = $this->runSvmPredict($samples, true);
if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
$predictions = DataTransformer::probabilities($predictions, $this->targets);
} else {
$predictions = explode(PHP_EOL, trim($predictions));
}
if (!is_array($samples[0])) {
return $predictions[0];
}
return $predictions;
}
private function runSvmPredict(array $samples, bool $probabilityEstimates): string
{
$testSet = DataTransformer::testSet($samples);
file_put_contents($testSetFileName = $this->varPath.uniqid('phpml', true), $testSet);
file_put_contents($modelFileName = $testSetFileName.'-model', $this->model);
$outputFileName = $testSetFileName.'-output';
$command = sprintf('%ssvm-predict%s %s %s %s', $this->binPath, $this->getOSExtension(), $testSetFileName, $modelFileName, $outputFileName);
$command = $this->buildPredictCommand(
$testSetFileName,
$modelFileName,
$outputFileName,
$probabilityEstimates
);
$output = [];
exec(escapeshellcmd($command).' 2>&1', $output, $return);
@ -198,16 +247,6 @@ class SupportVectorMachine
throw LibsvmCommandException::failedToRun($command, array_pop($output));
}
if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
$predictions = DataTransformer::predictions($predictions, $this->targets);
} else {
$predictions = explode(PHP_EOL, trim($predictions));
}
if (!is_array($samples[0])) {
return $predictions[0];
}
return $predictions;
}
@ -246,6 +285,23 @@ class SupportVectorMachine
);
}
private function buildPredictCommand(
string $testSetFileName,
string $modelFileName,
string $outputFileName,
bool $probabilityEstimates
): string {
return sprintf(
'%ssvm-predict%s -b %d %s %s %s',
$this->binPath,
$this->getOSExtension(),
$probabilityEstimates ? 1 : 0,
escapeshellarg($testSetFileName),
escapeshellarg($modelFileName),
escapeshellarg($outputFileName)
);
}
private function ensureDirectorySeparator(string &$path): void
{
if (substr($path, -1) !== DIRECTORY_SEPARATOR) {

View File

@ -37,4 +37,45 @@ class DataTransformerTest extends TestCase
$this->assertEquals($testSet, DataTransformer::testSet($samples));
}
public function testPredictions(): void
{
$labels = ['a', 'a', 'b', 'b'];
$rawPredictions = implode(PHP_EOL, [0, 1, 0, 0]);
$predictions = ['a', 'b', 'a', 'a'];
$this->assertEquals($predictions, DataTransformer::predictions($rawPredictions, $labels));
}
public function testProbabilities(): void
{
$labels = ['a', 'b', 'c'];
$rawPredictions = implode(PHP_EOL, [
'labels 0 1 2',
'1 0.1 0.7 0.2',
'2 0.2 0.3 0.5',
'0 0.6 0.1 0.3',
]);
$probabilities = [
[
'a' => 0.1,
'b' => 0.7,
'c' => 0.2,
],
[
'a' => 0.2,
'b' => 0.3,
'c' => 0.5,
],
[
'a' => 0.6,
'b' => 0.1,
'c' => 0.3,
],
];
$this->assertEquals($probabilities, DataTransformer::probabilities($rawPredictions, $labels));
}
}

View File

@ -5,6 +5,7 @@ declare(strict_types=1);
namespace Phpml\Tests\SupportVectorMachine;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Exception\InvalidOperationException;
use Phpml\Exception\LibsvmCommandException;
use Phpml\SupportVectorMachine\Kernel;
use Phpml\SupportVectorMachine\SupportVectorMachine;
@ -37,6 +38,31 @@ SV
$this->assertEquals($model, $svm->getModel());
}
public function testTrainCSVCModelWithProbabilityEstimate(): void
{
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
$svm = new SupportVectorMachine(
Type::C_SVC,
Kernel::LINEAR,
100.0,
0.5,
3,
null,
0.0,
0.1,
0.01,
100,
true,
true
);
$svm->train($samples, $labels);
$this->assertContains(PHP_EOL.'probA ', $svm->getModel());
$this->assertContains(PHP_EOL.'probB ', $svm->getModel());
}
public function testPredictSampleWithLinearKernel(): void
{
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
@ -83,6 +109,41 @@ SV
$this->assertEquals('c', $predictions[2]);
}
public function testPredictProbability(): void
{
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
$svm = new SupportVectorMachine(
Type::C_SVC,
Kernel::LINEAR,
100.0,
0.5,
3,
null,
0.0,
0.1,
0.01,
100,
true,
true
);
$svm->train($samples, $labels);
$predictions = $svm->predictProbability([
[3, 2],
[2, 3],
[4, -5],
]);
$this->assertTrue($predictions[0]['a'] < $predictions[0]['b']);
$this->assertTrue($predictions[1]['a'] > $predictions[1]['b']);
$this->assertTrue($predictions[2]['a'] < $predictions[2]['b']);
// Should be true because the latter is farther from the decision boundary
$this->assertTrue($predictions[0]['b'] < $predictions[2]['b']);
}
public function testThrowExceptionWhenVarPathIsNotWritable(): void
{
$this->expectException(InvalidArgumentException::class);
@ -124,4 +185,22 @@ SV
$svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF);
$svm->predict([1]);
}
public function testThrowExceptionWhenPredictProbabilityCalledWithoutProperModel(): void
{
$this->expectException(InvalidOperationException::class);
$this->expectExceptionMessage('Model does not support probabiliy estimates');
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
$svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0);
$svm->train($samples, $labels);
$predictions = $svm->predictProbability([
[3, 2],
[2, 3],
[4, -5],
]);
}
}