From ec091b5ea3d3bfa8f3168c476cb0c71725b3d344 Mon Sep 17 00:00:00 2001 From: Yuji Uchiyama Date: Wed, 7 Feb 2018 04:39:25 +0900 Subject: [PATCH] Support probability estimation in SVC (#218) * Add test for svm model with probability estimation * Extract buildPredictCommand method * Fix test to use PHP_EOL * Add predictProbability method (not completed) * Add test for DataTransformer::predictions * Fix SVM to use PHP_EOL * Support probability estimation in SVM * Add documentation * Add InvalidOperationException class * Throw InvalidOperationException before executing libsvm if probability estimation is not supported --- docs/machine-learning/classification/svc.md | 39 +++++++++ .../Exception/InvalidOperationException.php | 11 +++ .../SupportVectorMachine/DataTransformer.php | 31 ++++++++ .../SupportVectorMachine.php | 78 +++++++++++++++--- .../DataTransformerTest.php | 41 ++++++++++ .../SupportVectorMachineTest.php | 79 +++++++++++++++++++ 6 files changed, 268 insertions(+), 11 deletions(-) create mode 100644 src/Phpml/Exception/InvalidOperationException.php diff --git a/docs/machine-learning/classification/svc.md b/docs/machine-learning/classification/svc.md index 62da509..da0511c 100644 --- a/docs/machine-learning/classification/svc.md +++ b/docs/machine-learning/classification/svc.md @@ -47,3 +47,42 @@ $classifier->predict([3, 2]); $classifier->predict([[3, 2], [1, 5]]); // return ['b', 'a'] ``` + +### Probability estimation + +To predict probabilities you must build a classifier with `$probabilityEstimates` set to true. Example: + +``` +use Phpml\Classification\SVC; +use Phpml\SupportVectorMachine\Kernel; + +$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; +$labels = ['a', 'a', 'a', 'b', 'b', 'b']; + +$classifier = new SVC( + Kernel::LINEAR, // $kernel + 1.0, // $cost + 3, // $degree + null, // $gamma + 0.0, // $coef0 + 0.001, // $tolerance + 100, // $cacheSize + true, // $shrinking + true // $probabilityEstimates, set to true +); + +$classifier->train($samples, $labels); +``` + +Then use `predictProbability` method instead of `predict`: + +``` +$classifier->predictProbability([3, 2]); +// return ['a' => 0.349833, 'b' => 0.650167] + +$classifier->predictProbability([[3, 2], [1, 5]]); +// return [ +// ['a' => 0.349833, 'b' => 0.650167], +// ['a' => 0.922664, 'b' => 0.0773364], +// ] +``` diff --git a/src/Phpml/Exception/InvalidOperationException.php b/src/Phpml/Exception/InvalidOperationException.php new file mode 100644 index 0000000..0eba973 --- /dev/null +++ b/src/Phpml/Exception/InvalidOperationException.php @@ -0,0 +1,11 @@ + $prob) { + $result[$columnLabels[$i]] = (float) $prob; + } + + $results[] = $result; + } + + return $results; + } + public static function numericLabels(array $labels): array { $numericLabels = []; diff --git a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php index ce7a7ba..ddd843f 100644 --- a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php +++ b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php @@ -5,6 +5,7 @@ declare(strict_types=1); namespace Phpml\SupportVectorMachine; use Phpml\Exception\InvalidArgumentException; +use Phpml\Exception\InvalidOperationException; use Phpml\Exception\LibsvmCommandException; use Phpml\Helper\Trainable; @@ -178,13 +179,61 @@ class SupportVectorMachine * @throws LibsvmCommandException */ public function predict(array $samples) + { + $predictions = $this->runSvmPredict($samples, false); + + if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) { + $predictions = DataTransformer::predictions($predictions, $this->targets); + } else { + $predictions = explode(PHP_EOL, trim($predictions)); + } + + if (!is_array($samples[0])) { + return $predictions[0]; + } + + return $predictions; + } + + /** + * @return array|string + * + * @throws LibsvmCommandException + */ + public function predictProbability(array $samples) + { + if (!$this->probabilityEstimates) { + throw new InvalidOperationException('Model does not support probabiliy estimates'); + } + + $predictions = $this->runSvmPredict($samples, true); + + if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) { + $predictions = DataTransformer::probabilities($predictions, $this->targets); + } else { + $predictions = explode(PHP_EOL, trim($predictions)); + } + + if (!is_array($samples[0])) { + return $predictions[0]; + } + + return $predictions; + } + + private function runSvmPredict(array $samples, bool $probabilityEstimates): string { $testSet = DataTransformer::testSet($samples); file_put_contents($testSetFileName = $this->varPath.uniqid('phpml', true), $testSet); file_put_contents($modelFileName = $testSetFileName.'-model', $this->model); $outputFileName = $testSetFileName.'-output'; - $command = sprintf('%ssvm-predict%s %s %s %s', $this->binPath, $this->getOSExtension(), $testSetFileName, $modelFileName, $outputFileName); + $command = $this->buildPredictCommand( + $testSetFileName, + $modelFileName, + $outputFileName, + $probabilityEstimates + ); $output = []; exec(escapeshellcmd($command).' 2>&1', $output, $return); @@ -198,16 +247,6 @@ class SupportVectorMachine throw LibsvmCommandException::failedToRun($command, array_pop($output)); } - if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) { - $predictions = DataTransformer::predictions($predictions, $this->targets); - } else { - $predictions = explode(PHP_EOL, trim($predictions)); - } - - if (!is_array($samples[0])) { - return $predictions[0]; - } - return $predictions; } @@ -246,6 +285,23 @@ class SupportVectorMachine ); } + private function buildPredictCommand( + string $testSetFileName, + string $modelFileName, + string $outputFileName, + bool $probabilityEstimates + ): string { + return sprintf( + '%ssvm-predict%s -b %d %s %s %s', + $this->binPath, + $this->getOSExtension(), + $probabilityEstimates ? 1 : 0, + escapeshellarg($testSetFileName), + escapeshellarg($modelFileName), + escapeshellarg($outputFileName) + ); + } + private function ensureDirectorySeparator(string &$path): void { if (substr($path, -1) !== DIRECTORY_SEPARATOR) { diff --git a/tests/Phpml/SupportVectorMachine/DataTransformerTest.php b/tests/Phpml/SupportVectorMachine/DataTransformerTest.php index 1db1fdf..79dcb49 100644 --- a/tests/Phpml/SupportVectorMachine/DataTransformerTest.php +++ b/tests/Phpml/SupportVectorMachine/DataTransformerTest.php @@ -37,4 +37,45 @@ class DataTransformerTest extends TestCase $this->assertEquals($testSet, DataTransformer::testSet($samples)); } + + public function testPredictions(): void + { + $labels = ['a', 'a', 'b', 'b']; + $rawPredictions = implode(PHP_EOL, [0, 1, 0, 0]); + + $predictions = ['a', 'b', 'a', 'a']; + + $this->assertEquals($predictions, DataTransformer::predictions($rawPredictions, $labels)); + } + + public function testProbabilities(): void + { + $labels = ['a', 'b', 'c']; + $rawPredictions = implode(PHP_EOL, [ + 'labels 0 1 2', + '1 0.1 0.7 0.2', + '2 0.2 0.3 0.5', + '0 0.6 0.1 0.3', + ]); + + $probabilities = [ + [ + 'a' => 0.1, + 'b' => 0.7, + 'c' => 0.2, + ], + [ + 'a' => 0.2, + 'b' => 0.3, + 'c' => 0.5, + ], + [ + 'a' => 0.6, + 'b' => 0.1, + 'c' => 0.3, + ], + ]; + + $this->assertEquals($probabilities, DataTransformer::probabilities($rawPredictions, $labels)); + } } diff --git a/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php b/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php index 466c962..899fa40 100644 --- a/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php +++ b/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php @@ -5,6 +5,7 @@ declare(strict_types=1); namespace Phpml\Tests\SupportVectorMachine; use Phpml\Exception\InvalidArgumentException; +use Phpml\Exception\InvalidOperationException; use Phpml\Exception\LibsvmCommandException; use Phpml\SupportVectorMachine\Kernel; use Phpml\SupportVectorMachine\SupportVectorMachine; @@ -37,6 +38,31 @@ SV $this->assertEquals($model, $svm->getModel()); } + public function testTrainCSVCModelWithProbabilityEstimate(): void + { + $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; + $labels = ['a', 'a', 'a', 'b', 'b', 'b']; + + $svm = new SupportVectorMachine( + Type::C_SVC, + Kernel::LINEAR, + 100.0, + 0.5, + 3, + null, + 0.0, + 0.1, + 0.01, + 100, + true, + true + ); + $svm->train($samples, $labels); + + $this->assertContains(PHP_EOL.'probA ', $svm->getModel()); + $this->assertContains(PHP_EOL.'probB ', $svm->getModel()); + } + public function testPredictSampleWithLinearKernel(): void { $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; @@ -83,6 +109,41 @@ SV $this->assertEquals('c', $predictions[2]); } + public function testPredictProbability(): void + { + $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; + $labels = ['a', 'a', 'a', 'b', 'b', 'b']; + + $svm = new SupportVectorMachine( + Type::C_SVC, + Kernel::LINEAR, + 100.0, + 0.5, + 3, + null, + 0.0, + 0.1, + 0.01, + 100, + true, + true + ); + $svm->train($samples, $labels); + + $predictions = $svm->predictProbability([ + [3, 2], + [2, 3], + [4, -5], + ]); + + $this->assertTrue($predictions[0]['a'] < $predictions[0]['b']); + $this->assertTrue($predictions[1]['a'] > $predictions[1]['b']); + $this->assertTrue($predictions[2]['a'] < $predictions[2]['b']); + + // Should be true because the latter is farther from the decision boundary + $this->assertTrue($predictions[0]['b'] < $predictions[2]['b']); + } + public function testThrowExceptionWhenVarPathIsNotWritable(): void { $this->expectException(InvalidArgumentException::class); @@ -124,4 +185,22 @@ SV $svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF); $svm->predict([1]); } + + public function testThrowExceptionWhenPredictProbabilityCalledWithoutProperModel(): void + { + $this->expectException(InvalidOperationException::class); + $this->expectExceptionMessage('Model does not support probabiliy estimates'); + + $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; + $labels = ['a', 'a', 'a', 'b', 'b', 'b']; + + $svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0); + $svm->train($samples, $labels); + + $predictions = $svm->predictProbability([ + [3, 2], + [2, 3], + [4, -5], + ]); + } }