mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-06-02 16:40:50 +00:00
Support probability estimation in SVC (#218)
* Add test for svm model with probability estimation * Extract buildPredictCommand method * Fix test to use PHP_EOL * Add predictProbability method (not completed) * Add test for DataTransformer::predictions * Fix SVM to use PHP_EOL * Support probability estimation in SVM * Add documentation * Add InvalidOperationException class * Throw InvalidOperationException before executing libsvm if probability estimation is not supported
This commit is contained in:
parent
ed775fb232
commit
ec091b5ea3
|
@ -47,3 +47,42 @@ $classifier->predict([3, 2]);
|
||||||
$classifier->predict([[3, 2], [1, 5]]);
|
$classifier->predict([[3, 2], [1, 5]]);
|
||||||
// return ['b', 'a']
|
// return ['b', 'a']
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Probability estimation
|
||||||
|
|
||||||
|
To predict probabilities you must build a classifier with `$probabilityEstimates` set to true. Example:
|
||||||
|
|
||||||
|
```
|
||||||
|
use Phpml\Classification\SVC;
|
||||||
|
use Phpml\SupportVectorMachine\Kernel;
|
||||||
|
|
||||||
|
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
|
||||||
|
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
|
||||||
|
|
||||||
|
$classifier = new SVC(
|
||||||
|
Kernel::LINEAR, // $kernel
|
||||||
|
1.0, // $cost
|
||||||
|
3, // $degree
|
||||||
|
null, // $gamma
|
||||||
|
0.0, // $coef0
|
||||||
|
0.001, // $tolerance
|
||||||
|
100, // $cacheSize
|
||||||
|
true, // $shrinking
|
||||||
|
true // $probabilityEstimates, set to true
|
||||||
|
);
|
||||||
|
|
||||||
|
$classifier->train($samples, $labels);
|
||||||
|
```
|
||||||
|
|
||||||
|
Then use the `predictProbability` method instead of `predict`:
|
||||||
|
|
||||||
|
```
|
||||||
|
$classifier->predictProbability([3, 2]);
|
||||||
|
// return ['a' => 0.349833, 'b' => 0.650167]
|
||||||
|
|
||||||
|
$classifier->predictProbability([[3, 2], [1, 5]]);
|
||||||
|
// return [
|
||||||
|
// ['a' => 0.349833, 'b' => 0.650167],
|
||||||
|
// ['a' => 0.922664, 'b' => 0.0773364],
|
||||||
|
// ]
|
||||||
|
```
|
||||||
|
|
11
src/Phpml/Exception/InvalidOperationException.php
Normal file
11
src/Phpml/Exception/InvalidOperationException.php
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Exception;
|
||||||
|
|
||||||
|
use Exception;
|
||||||
|
|
||||||
|
/**
 * Signals that an operation was requested which the object it was called on
 * cannot perform, e.g. SupportVectorMachine::predictProbability() on a machine
 * that was built without probability estimates.
 */
class InvalidOperationException extends Exception
|
||||||
|
{
|
||||||
|
}
|
|
@ -49,6 +49,37 @@ class DataTransformer
|
||||||
return $results;
|
return $results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Converts raw probability output into one label => probability map per sample.
 *
 * The first line of $rawPredictions is treated as a header: its first
 * space-separated token is discarded and the remaining tokens are numeric
 * labels naming the probability columns. On every following line the first
 * token is discarded as well and the remaining tokens are read as floats,
 * keyed by the original label of their column.
 *
 * @param string $rawPredictions PHP_EOL-separated output text
 * @param array  $labels         labels passed to self::numericLabels() to rebuild the label => number map
 *
 * @return array list with one [label => probability] array per prediction line
 */
public static function probabilities(string $rawPredictions, array $labels): array
{
    $numericLabels = self::numericLabels($labels);

    $lines = explode(PHP_EOL, trim($rawPredictions));

    // Header line: drop the leading token, keep the numeric label of each column.
    $headerTokens = array_slice(explode(' ', array_shift($lines)), 1);

    // Translate each column's numeric label back to the original label.
    // The comparison is intentionally loose: header tokens are strings.
    $labelByColumn = [];
    foreach ($headerTokens as $token) {
        $labelByColumn[] = array_search($token, $numericLabels);
    }

    $results = [];
    foreach ($lines as $line) {
        // Drop the leading token of the row; the rest are per-column probabilities.
        $values = array_slice(explode(' ', $line), 1);

        $row = [];
        foreach ($values as $column => $value) {
            $row[$labelByColumn[$column]] = (float) $value;
        }

        $results[] = $row;
    }

    return $results;
}
|
||||||
|
|
||||||
public static function numericLabels(array $labels): array
|
public static function numericLabels(array $labels): array
|
||||||
{
|
{
|
||||||
$numericLabels = [];
|
$numericLabels = [];
|
||||||
|
|
|
@ -5,6 +5,7 @@ declare(strict_types=1);
|
||||||
namespace Phpml\SupportVectorMachine;
|
namespace Phpml\SupportVectorMachine;
|
||||||
|
|
||||||
use Phpml\Exception\InvalidArgumentException;
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
|
use Phpml\Exception\InvalidOperationException;
|
||||||
use Phpml\Exception\LibsvmCommandException;
|
use Phpml\Exception\LibsvmCommandException;
|
||||||
use Phpml\Helper\Trainable;
|
use Phpml\Helper\Trainable;
|
||||||
|
|
||||||
|
@ -178,13 +179,61 @@ class SupportVectorMachine
|
||||||
* @throws LibsvmCommandException
|
* @throws LibsvmCommandException
|
||||||
*/
|
*/
|
||||||
public function predict(array $samples)
{
    $rawOutput = $this->runSvmPredict($samples, false);

    // Classification types map the raw output back onto the training targets;
    // every other type returns the raw output lines as they are.
    if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
        $predictions = DataTransformer::predictions($rawOutput, $this->targets);
    } else {
        $predictions = explode(PHP_EOL, trim($rawOutput));
    }

    // A flat $samples array (first element is not an array) is a single sample,
    // so only the first prediction is returned.
    return is_array($samples[0]) ? $predictions : $predictions[0];
}
|
||||||
|
|
||||||
|
/**
 * Runs a prediction with probability estimation enabled.
 *
 * The machine must have been created with probability estimates switched on;
 * otherwise an exception is thrown before any libsvm command is executed.
 *
 * @return array|string the result for the single sample when $samples[0] is not
 *                      an array, otherwise a list with one result per sample
 *
 * @throws InvalidOperationException when probability estimates are not enabled
 * @throws LibsvmCommandException
 */
public function predictProbability(array $samples)
{
    if (!$this->probabilityEstimates) {
        // NOTE: the message text is asserted verbatim by the test suite.
        throw new InvalidOperationException('Model does not support probabiliy estimates');
    }

    $rawOutput = $this->runSvmPredict($samples, true);

    // Classification types get label => probability maps; every other type
    // returns the raw output lines as they are.
    $isClassification = in_array($this->type, [Type::C_SVC, Type::NU_SVC]);
    $estimates = $isClassification
        ? DataTransformer::probabilities($rawOutput, $this->targets)
        : explode(PHP_EOL, trim($rawOutput));

    // A flat $samples array (first element is not an array) is a single sample.
    return is_array($samples[0]) ? $estimates : $estimates[0];
}
|
||||||
|
|
||||||
|
private function runSvmPredict(array $samples, bool $probabilityEstimates): string
|
||||||
{
|
{
|
||||||
$testSet = DataTransformer::testSet($samples);
|
$testSet = DataTransformer::testSet($samples);
|
||||||
file_put_contents($testSetFileName = $this->varPath.uniqid('phpml', true), $testSet);
|
file_put_contents($testSetFileName = $this->varPath.uniqid('phpml', true), $testSet);
|
||||||
file_put_contents($modelFileName = $testSetFileName.'-model', $this->model);
|
file_put_contents($modelFileName = $testSetFileName.'-model', $this->model);
|
||||||
$outputFileName = $testSetFileName.'-output';
|
$outputFileName = $testSetFileName.'-output';
|
||||||
|
|
||||||
$command = sprintf('%ssvm-predict%s %s %s %s', $this->binPath, $this->getOSExtension(), $testSetFileName, $modelFileName, $outputFileName);
|
$command = $this->buildPredictCommand(
|
||||||
|
$testSetFileName,
|
||||||
|
$modelFileName,
|
||||||
|
$outputFileName,
|
||||||
|
$probabilityEstimates
|
||||||
|
);
|
||||||
$output = [];
|
$output = [];
|
||||||
exec(escapeshellcmd($command).' 2>&1', $output, $return);
|
exec(escapeshellcmd($command).' 2>&1', $output, $return);
|
||||||
|
|
||||||
|
@ -198,16 +247,6 @@ class SupportVectorMachine
|
||||||
throw LibsvmCommandException::failedToRun($command, array_pop($output));
|
throw LibsvmCommandException::failedToRun($command, array_pop($output));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
|
|
||||||
$predictions = DataTransformer::predictions($predictions, $this->targets);
|
|
||||||
} else {
|
|
||||||
$predictions = explode(PHP_EOL, trim($predictions));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!is_array($samples[0])) {
|
|
||||||
return $predictions[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
return $predictions;
|
return $predictions;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,6 +285,23 @@ class SupportVectorMachine
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Assembles the svm-predict command line.
 *
 * The three file names are shell-escaped individually; the "-b" option is
 * always emitted, with 1 when probability estimates are requested and 0 otherwise.
 */
private function buildPredictCommand(
    string $testSetFileName,
    string $modelFileName,
    string $outputFileName,
    bool $probabilityEstimates
): string {
    $executable = $this->binPath.'svm-predict'.$this->getOSExtension();
    $probabilityFlag = (int) $probabilityEstimates;

    $fileArguments = escapeshellarg($testSetFileName)
        .' '.escapeshellarg($modelFileName)
        .' '.escapeshellarg($outputFileName);

    return $executable.' -b '.$probabilityFlag.' '.$fileArguments;
}
|
||||||
|
|
||||||
private function ensureDirectorySeparator(string &$path): void
|
private function ensureDirectorySeparator(string &$path): void
|
||||||
{
|
{
|
||||||
if (substr($path, -1) !== DIRECTORY_SEPARATOR) {
|
if (substr($path, -1) !== DIRECTORY_SEPARATOR) {
|
||||||
|
|
|
@ -37,4 +37,45 @@ class DataTransformerTest extends TestCase
|
||||||
|
|
||||||
$this->assertEquals($testSet, DataTransformer::testSet($samples));
|
$this->assertEquals($testSet, DataTransformer::testSet($samples));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Raw numeric predictions are mapped back onto the original labels.
 */
public function testPredictions(): void
{
    $labels = ['a', 'a', 'b', 'b'];
    $expected = ['a', 'b', 'a', 'a'];

    // One numeric label per line, as passed to DataTransformer::predictions().
    $rawPredictions = implode(PHP_EOL, [0, 1, 0, 0]);

    $actual = DataTransformer::predictions($rawPredictions, $labels);

    $this->assertEquals($expected, $actual);
}
|
||||||
|
|
||||||
|
/**
 * Probability output is turned into one label => probability map per sample.
 */
public function testProbabilities(): void
{
    $labels = ['a', 'b', 'c'];

    // Header line followed by one "<first token> <p0> <p1> <p2>" line per sample.
    $outputLines = [
        'labels 0 1 2',
        '1 0.1 0.7 0.2',
        '2 0.2 0.3 0.5',
        '0 0.6 0.1 0.3',
    ];

    $expected = [
        ['a' => 0.1, 'b' => 0.7, 'c' => 0.2],
        ['a' => 0.2, 'b' => 0.3, 'c' => 0.5],
        ['a' => 0.6, 'b' => 0.1, 'c' => 0.3],
    ];

    $actual = DataTransformer::probabilities(implode(PHP_EOL, $outputLines), $labels);

    $this->assertEquals($expected, $actual);
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ declare(strict_types=1);
|
||||||
namespace Phpml\Tests\SupportVectorMachine;
|
namespace Phpml\Tests\SupportVectorMachine;
|
||||||
|
|
||||||
use Phpml\Exception\InvalidArgumentException;
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
|
use Phpml\Exception\InvalidOperationException;
|
||||||
use Phpml\Exception\LibsvmCommandException;
|
use Phpml\Exception\LibsvmCommandException;
|
||||||
use Phpml\SupportVectorMachine\Kernel;
|
use Phpml\SupportVectorMachine\Kernel;
|
||||||
use Phpml\SupportVectorMachine\SupportVectorMachine;
|
use Phpml\SupportVectorMachine\SupportVectorMachine;
|
||||||
|
@ -37,6 +38,31 @@ SV
|
||||||
$this->assertEquals($model, $svm->getModel());
|
$this->assertEquals($model, $svm->getModel());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * A model trained with the final constructor flag enabled must contain
 * both a "probA" and a "probB" line.
 */
public function testTrainCSVCModelWithProbabilityEstimate(): void
{
    $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
    $labels = ['a', 'a', 'a', 'b', 'b', 'b'];

    $svm = new SupportVectorMachine(
        Type::C_SVC, Kernel::LINEAR, 100.0, 0.5, 3, null,
        0.0, 0.1, 0.01, 100, true, true
    );
    $svm->train($samples, $labels);

    // Each entry must start a line of its own in the serialized model.
    foreach (['probA ', 'probB '] as $entry) {
        $this->assertContains(PHP_EOL.$entry, $svm->getModel());
    }
}
|
||||||
|
|
||||||
public function testPredictSampleWithLinearKernel(): void
|
public function testPredictSampleWithLinearKernel(): void
|
||||||
{
|
{
|
||||||
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
|
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
|
||||||
|
@ -83,6 +109,41 @@ SV
|
||||||
$this->assertEquals('c', $predictions[2]);
|
$this->assertEquals('c', $predictions[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Probability estimates must favour the class on whose side of the decision
 * boundary each sample lies.
 *
 * Uses assertLessThan()/assertGreaterThan() instead of assertTrue() on a
 * comparison so that a failure reports the two probabilities involved.
 */
public function testPredictProbability(): void
{
    $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
    $labels = ['a', 'a', 'a', 'b', 'b', 'b'];

    $svm = new SupportVectorMachine(
        Type::C_SVC,
        Kernel::LINEAR,
        100.0,
        0.5,
        3,
        null,
        0.0,
        0.1,
        0.01,
        100,
        true,
        true
    );
    $svm->train($samples, $labels);

    $predictions = $svm->predictProbability([
        [3, 2],
        [2, 3],
        [4, -5],
    ]);

    // assertLessThan($expected, $actual) passes when $actual < $expected.
    $this->assertLessThan($predictions[0]['b'], $predictions[0]['a']);
    // assertGreaterThan($expected, $actual) passes when $actual > $expected.
    $this->assertGreaterThan($predictions[1]['b'], $predictions[1]['a']);
    $this->assertLessThan($predictions[2]['b'], $predictions[2]['a']);

    // Should be true because the latter is farther from the decision boundary
    $this->assertLessThan($predictions[2]['b'], $predictions[0]['b']);
}
|
||||||
|
|
||||||
public function testThrowExceptionWhenVarPathIsNotWritable(): void
|
public function testThrowExceptionWhenVarPathIsNotWritable(): void
|
||||||
{
|
{
|
||||||
$this->expectException(InvalidArgumentException::class);
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
@ -124,4 +185,22 @@ SV
|
||||||
$svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF);
|
$svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF);
|
||||||
$svm->predict([1]);
|
$svm->predict([1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * predictProbability() must fail with InvalidOperationException when the
 * machine was created without probability estimates.
 */
public function testThrowExceptionWhenPredictProbabilityCalledWithoutProperModel(): void
{
    $this->expectException(InvalidOperationException::class);
    // Must match the message thrown by SupportVectorMachine::predictProbability() verbatim.
    $this->expectExceptionMessage('Model does not support probabiliy estimates');

    $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
    $labels = ['a', 'a', 'a', 'b', 'b', 'b'];

    // No probability-estimates flag is passed to the constructor here.
    $svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0);
    $svm->train($samples, $labels);

    // The call is expected to throw, so its return value is not captured.
    $svm->predictProbability([
        [3, 2],
        [2, 3],
        [4, -5],
    ]);
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user