From 7b5b6418f42f743aa747629fc95487877853312b Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Fri, 6 May 2016 22:55:41 +0200 Subject: [PATCH] libsvm predict program implementation --- .../SupportVectorMachine/DataTransformer.php | 17 +++++++++++ .../SupportVectorMachine.php | 29 ++++++++++++++++++- .../DataTransformerTest.php | 1 - .../SupportVectorMachineTest.php | 19 ++++++++++++ 4 files changed, 64 insertions(+), 2 deletions(-) diff --git a/src/Phpml/SupportVectorMachine/DataTransformer.php b/src/Phpml/SupportVectorMachine/DataTransformer.php index 05599f1..1ce4bee 100644 --- a/src/Phpml/SupportVectorMachine/DataTransformer.php +++ b/src/Phpml/SupportVectorMachine/DataTransformer.php @@ -38,6 +38,23 @@ class DataTransformer return $set; } + /** + * @param string $resultString + * @param array $labels + * + * @return array + */ + public static function results(string $resultString, array $labels): array + { + $numericLabels = self::numericLabels($labels); + $results = []; + foreach (explode(PHP_EOL, $resultString) as $result) { + $results[] = array_search($result, $numericLabels); + } + + return $results; + } + /** * @param array $labels * diff --git a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php index 325fcf3..f14d534 100644 --- a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php +++ b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php @@ -36,6 +36,11 @@ class SupportVectorMachine */ private $model; + /** + * @var array + */ + private $labels; + /** * @param int $type * @param int $kernel @@ -59,6 +64,7 @@ class SupportVectorMachine */ public function train(array $samples, array $labels) { + $this->labels = $labels; $trainingSet = DataTransformer::trainingSet($samples, $labels); file_put_contents($trainingSetFileName = $this->varPath.uniqid(), $trainingSet); $modelFileName = $trainingSetFileName.'-model'; @@ -81,8 +87,29 @@ class SupportVectorMachine return $this->model; } + /** + * @param array $samples + * + * @return array + */ public function predict(array $samples) { - $testSet = DataTransformer::testSet(); + $testSet = DataTransformer::testSet($samples); + file_put_contents($testSetFileName = $this->varPath.uniqid(), $testSet); + $modelFileName = $testSetFileName.'-model'; + file_put_contents($modelFileName, $this->model); + $outputFileName = $testSetFileName.'-output'; + + $command = sprintf('%ssvm-predict %s %s %s', $this->binPath, $testSetFileName, $modelFileName, $outputFileName); + $output = ''; + exec(escapeshellcmd($command), $output); + + $predictions = file_get_contents($outputFileName); + + unlink($testSetFileName); + unlink($modelFileName); + unlink($outputFileName); + + return DataTransformer::results($predictions, $this->labels); } } diff --git a/tests/Phpml/SupportVectorMachine/DataTransformerTest.php b/tests/Phpml/SupportVectorMachine/DataTransformerTest.php index ff2a7c7..c07948a 100644 --- a/tests/Phpml/SupportVectorMachine/DataTransformerTest.php +++ b/tests/Phpml/SupportVectorMachine/DataTransformerTest.php @@ -36,5 +36,4 @@ class DataTransformerTest extends \PHPUnit_Framework_TestCase $this->assertEquals($testSet, DataTransformer::testSet($samples)); } - } diff --git a/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php b/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php index e4a8857..330f7f0 100644 --- a/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php +++ b/tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php @@ -33,4 +33,23 @@ SV $this->assertEquals($model, $svm->getModel()); } + + public function testPredictCSVCModelWithLinearKernel() + { + $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; + $labels = ['a', 'a', 'a', 'b', 'b', 'b']; + + $svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0); + $svm->train($samples, $labels); + + $predictions = $svm->predict([ + [3, 2], + [2, 3], + [4, -5], + ]); + + $this->assertEquals('b', $predictions[0]); + $this->assertEquals('a', $predictions[1]); + $this->assertEquals('b', $predictions[2]); + } }