mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-23 23:28:24 +00:00
Support probability estimation in SVC (#218)
* Add test for svm model with probability estimation * Extract buildPredictCommand method * Fix test to use PHP_EOL * Add predictProbability method (not completed) * Add test for DataTransformer::predictions * Fix SVM to use PHP_EOL * Support probability estimation in SVM * Add documentation * Add InvalidOperationException class * Throw InvalidOperationException before executing libsvm if probability estimation is not supported
This commit is contained in:
parent
ed775fb232
commit
ec091b5ea3
@ -47,3 +47,42 @@ $classifier->predict([3, 2]);
|
||||
$classifier->predict([[3, 2], [1, 5]]);
|
||||
// return ['b', 'a']
|
||||
```
|
||||
|
||||
### Probability estimation
|
||||
|
||||
To predict probabilities you must build a classifier with `$probabilityEstimates` set to true. Example:
|
||||
|
||||
```
|
||||
use Phpml\Classification\SVC;
|
||||
use Phpml\SupportVectorMachine\Kernel;
|
||||
|
||||
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
|
||||
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
|
||||
|
||||
$classifier = new SVC(
|
||||
Kernel::LINEAR, // $kernel
|
||||
1.0, // $cost
|
||||
3, // $degree
|
||||
null, // $gamma
|
||||
0.0, // $coef0
|
||||
0.001, // $tolerance
|
||||
100, // $cacheSize
|
||||
true, // $shrinking
|
||||
true // $probabilityEstimates, set to true
|
||||
);
|
||||
|
||||
$classifier->train($samples, $labels);
|
||||
```
|
||||
|
||||
Then use `predictProbability` method instead of `predict`:
|
||||
|
||||
```
|
||||
$classifier->predictProbability([3, 2]);
|
||||
// return ['a' => 0.349833, 'b' => 0.650167]
|
||||
|
||||
$classifier->predictProbability([[3, 2], [1, 5]]);
|
||||
// return [
|
||||
// ['a' => 0.349833, 'b' => 0.650167],
|
||||
// ['a' => 0.922664, 'b' => 0.0773364],
|
||||
// ]
|
||||
```
|
||||
|
11
src/Phpml/Exception/InvalidOperationException.php
Normal file
11
src/Phpml/Exception/InvalidOperationException.php
Normal file
@ -0,0 +1,11 @@
|
||||
<?php
|
||||
|
||||
declare(strict_types=1);
|
||||
|
||||
namespace Phpml\Exception;
|
||||
|
||||
use Exception;
|
||||
|
||||
class InvalidOperationException extends Exception
|
||||
{
|
||||
}
|
@ -49,6 +49,37 @@ class DataTransformer
|
||||
return $results;
|
||||
}
|
||||
|
||||
public static function probabilities(string $rawPredictions, array $labels): array
|
||||
{
|
||||
$numericLabels = self::numericLabels($labels);
|
||||
|
||||
$predictions = explode(PHP_EOL, trim($rawPredictions));
|
||||
|
||||
$header = array_shift($predictions);
|
||||
$headerColumns = explode(' ', $header);
|
||||
array_shift($headerColumns);
|
||||
|
||||
$columnLabels = [];
|
||||
foreach ($headerColumns as $numericLabel) {
|
||||
$columnLabels[] = array_search($numericLabel, $numericLabels);
|
||||
}
|
||||
|
||||
$results = [];
|
||||
foreach ($predictions as $rawResult) {
|
||||
$probabilities = explode(' ', $rawResult);
|
||||
array_shift($probabilities);
|
||||
|
||||
$result = [];
|
||||
foreach ($probabilities as $i => $prob) {
|
||||
$result[$columnLabels[$i]] = (float) $prob;
|
||||
}
|
||||
|
||||
$results[] = $result;
|
||||
}
|
||||
|
||||
return $results;
|
||||
}
|
||||
|
||||
public static function numericLabels(array $labels): array
|
||||
{
|
||||
$numericLabels = [];
|
||||
|
@ -5,6 +5,7 @@ declare(strict_types=1);
|
||||
namespace Phpml\SupportVectorMachine;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
use Phpml\Exception\InvalidOperationException;
|
||||
use Phpml\Exception\LibsvmCommandException;
|
||||
use Phpml\Helper\Trainable;
|
||||
|
||||
@ -178,13 +179,61 @@ class SupportVectorMachine
|
||||
* @throws LibsvmCommandException
|
||||
*/
|
||||
public function predict(array $samples)
|
||||
{
|
||||
$predictions = $this->runSvmPredict($samples, false);
|
||||
|
||||
if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
|
||||
$predictions = DataTransformer::predictions($predictions, $this->targets);
|
||||
} else {
|
||||
$predictions = explode(PHP_EOL, trim($predictions));
|
||||
}
|
||||
|
||||
if (!is_array($samples[0])) {
|
||||
return $predictions[0];
|
||||
}
|
||||
|
||||
return $predictions;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array|string
|
||||
*
|
||||
* @throws LibsvmCommandException
|
||||
*/
|
||||
public function predictProbability(array $samples)
|
||||
{
|
||||
if (!$this->probabilityEstimates) {
|
||||
throw new InvalidOperationException('Model does not support probabiliy estimates');
|
||||
}
|
||||
|
||||
$predictions = $this->runSvmPredict($samples, true);
|
||||
|
||||
if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
|
||||
$predictions = DataTransformer::probabilities($predictions, $this->targets);
|
||||
} else {
|
||||
$predictions = explode(PHP_EOL, trim($predictions));
|
||||
}
|
||||
|
||||
if (!is_array($samples[0])) {
|
||||
return $predictions[0];
|
||||
}
|
||||
|
||||
return $predictions;
|
||||
}
|
||||
|
||||
private function runSvmPredict(array $samples, bool $probabilityEstimates): string
|
||||
{
|
||||
$testSet = DataTransformer::testSet($samples);
|
||||
file_put_contents($testSetFileName = $this->varPath.uniqid('phpml', true), $testSet);
|
||||
file_put_contents($modelFileName = $testSetFileName.'-model', $this->model);
|
||||
$outputFileName = $testSetFileName.'-output';
|
||||
|
||||
$command = sprintf('%ssvm-predict%s %s %s %s', $this->binPath, $this->getOSExtension(), $testSetFileName, $modelFileName, $outputFileName);
|
||||
$command = $this->buildPredictCommand(
|
||||
$testSetFileName,
|
||||
$modelFileName,
|
||||
$outputFileName,
|
||||
$probabilityEstimates
|
||||
);
|
||||
$output = [];
|
||||
exec(escapeshellcmd($command).' 2>&1', $output, $return);
|
||||
|
||||
@ -198,16 +247,6 @@ class SupportVectorMachine
|
||||
throw LibsvmCommandException::failedToRun($command, array_pop($output));
|
||||
}
|
||||
|
||||
if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
|
||||
$predictions = DataTransformer::predictions($predictions, $this->targets);
|
||||
} else {
|
||||
$predictions = explode(PHP_EOL, trim($predictions));
|
||||
}
|
||||
|
||||
if (!is_array($samples[0])) {
|
||||
return $predictions[0];
|
||||
}
|
||||
|
||||
return $predictions;
|
||||
}
|
||||
|
||||
@ -246,6 +285,23 @@ class SupportVectorMachine
|
||||
);
|
||||
}
|
||||
|
||||
private function buildPredictCommand(
|
||||
string $testSetFileName,
|
||||
string $modelFileName,
|
||||
string $outputFileName,
|
||||
bool $probabilityEstimates
|
||||
): string {
|
||||
return sprintf(
|
||||
'%ssvm-predict%s -b %d %s %s %s',
|
||||
$this->binPath,
|
||||
$this->getOSExtension(),
|
||||
$probabilityEstimates ? 1 : 0,
|
||||
escapeshellarg($testSetFileName),
|
||||
escapeshellarg($modelFileName),
|
||||
escapeshellarg($outputFileName)
|
||||
);
|
||||
}
|
||||
|
||||
private function ensureDirectorySeparator(string &$path): void
|
||||
{
|
||||
if (substr($path, -1) !== DIRECTORY_SEPARATOR) {
|
||||
|
@ -37,4 +37,45 @@ class DataTransformerTest extends TestCase
|
||||
|
||||
$this->assertEquals($testSet, DataTransformer::testSet($samples));
|
||||
}
|
||||
|
||||
public function testPredictions(): void
|
||||
{
|
||||
$labels = ['a', 'a', 'b', 'b'];
|
||||
$rawPredictions = implode(PHP_EOL, [0, 1, 0, 0]);
|
||||
|
||||
$predictions = ['a', 'b', 'a', 'a'];
|
||||
|
||||
$this->assertEquals($predictions, DataTransformer::predictions($rawPredictions, $labels));
|
||||
}
|
||||
|
||||
public function testProbabilities(): void
|
||||
{
|
||||
$labels = ['a', 'b', 'c'];
|
||||
$rawPredictions = implode(PHP_EOL, [
|
||||
'labels 0 1 2',
|
||||
'1 0.1 0.7 0.2',
|
||||
'2 0.2 0.3 0.5',
|
||||
'0 0.6 0.1 0.3',
|
||||
]);
|
||||
|
||||
$probabilities = [
|
||||
[
|
||||
'a' => 0.1,
|
||||
'b' => 0.7,
|
||||
'c' => 0.2,
|
||||
],
|
||||
[
|
||||
'a' => 0.2,
|
||||
'b' => 0.3,
|
||||
'c' => 0.5,
|
||||
],
|
||||
[
|
||||
'a' => 0.6,
|
||||
'b' => 0.1,
|
||||
'c' => 0.3,
|
||||
],
|
||||
];
|
||||
|
||||
$this->assertEquals($probabilities, DataTransformer::probabilities($rawPredictions, $labels));
|
||||
}
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ declare(strict_types=1);
|
||||
namespace Phpml\Tests\SupportVectorMachine;
|
||||
|
||||
use Phpml\Exception\InvalidArgumentException;
|
||||
use Phpml\Exception\InvalidOperationException;
|
||||
use Phpml\Exception\LibsvmCommandException;
|
||||
use Phpml\SupportVectorMachine\Kernel;
|
||||
use Phpml\SupportVectorMachine\SupportVectorMachine;
|
||||
@ -37,6 +38,31 @@ SV
|
||||
$this->assertEquals($model, $svm->getModel());
|
||||
}
|
||||
|
||||
public function testTrainCSVCModelWithProbabilityEstimate(): void
|
||||
{
|
||||
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
|
||||
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
|
||||
|
||||
$svm = new SupportVectorMachine(
|
||||
Type::C_SVC,
|
||||
Kernel::LINEAR,
|
||||
100.0,
|
||||
0.5,
|
||||
3,
|
||||
null,
|
||||
0.0,
|
||||
0.1,
|
||||
0.01,
|
||||
100,
|
||||
true,
|
||||
true
|
||||
);
|
||||
$svm->train($samples, $labels);
|
||||
|
||||
$this->assertContains(PHP_EOL.'probA ', $svm->getModel());
|
||||
$this->assertContains(PHP_EOL.'probB ', $svm->getModel());
|
||||
}
|
||||
|
||||
public function testPredictSampleWithLinearKernel(): void
|
||||
{
|
||||
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
|
||||
@ -83,6 +109,41 @@ SV
|
||||
$this->assertEquals('c', $predictions[2]);
|
||||
}
|
||||
|
||||
public function testPredictProbability(): void
|
||||
{
|
||||
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
|
||||
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
|
||||
|
||||
$svm = new SupportVectorMachine(
|
||||
Type::C_SVC,
|
||||
Kernel::LINEAR,
|
||||
100.0,
|
||||
0.5,
|
||||
3,
|
||||
null,
|
||||
0.0,
|
||||
0.1,
|
||||
0.01,
|
||||
100,
|
||||
true,
|
||||
true
|
||||
);
|
||||
$svm->train($samples, $labels);
|
||||
|
||||
$predictions = $svm->predictProbability([
|
||||
[3, 2],
|
||||
[2, 3],
|
||||
[4, -5],
|
||||
]);
|
||||
|
||||
$this->assertTrue($predictions[0]['a'] < $predictions[0]['b']);
|
||||
$this->assertTrue($predictions[1]['a'] > $predictions[1]['b']);
|
||||
$this->assertTrue($predictions[2]['a'] < $predictions[2]['b']);
|
||||
|
||||
// Should be true because the latter is farther from the decision boundary
|
||||
$this->assertTrue($predictions[0]['b'] < $predictions[2]['b']);
|
||||
}
|
||||
|
||||
public function testThrowExceptionWhenVarPathIsNotWritable(): void
|
||||
{
|
||||
$this->expectException(InvalidArgumentException::class);
|
||||
@ -124,4 +185,22 @@ SV
|
||||
$svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF);
|
||||
$svm->predict([1]);
|
||||
}
|
||||
|
||||
public function testThrowExceptionWhenPredictProbabilityCalledWithoutProperModel(): void
|
||||
{
|
||||
$this->expectException(InvalidOperationException::class);
|
||||
$this->expectExceptionMessage('Model does not support probabiliy estimates');
|
||||
|
||||
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
|
||||
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
|
||||
|
||||
$svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0);
|
||||
$svm->train($samples, $labels);
|
||||
|
||||
$predictions = $svm->predictProbability([
|
||||
[3, 2],
|
||||
[2, 3],
|
||||
[4, -5],
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user