libsvm predict program implementation

This commit is contained in:
Arkadiusz Kondas 2016-05-06 22:55:41 +02:00
parent dfb7b6b108
commit 7b5b6418f4
4 changed files with 64 additions and 2 deletions

View File

@ -38,6 +38,23 @@ class DataTransformer
return $set;
}
/**
* @param string $resultString
* @param array $labels
*
* @return array
*/
public static function results(string $resultString, array $labels): array
{
$numericLabels = self::numericLabels($labels);
$results = [];
foreach (explode(PHP_EOL, $resultString) as $result) {
$results[] = array_search($result, $numericLabels);
}
return $results;
}
/**
* @param array $labels
*

View File

@ -36,6 +36,11 @@ class SupportVectorMachine
*/
private $model;
/**
* @var array
*/
private $labels;
/**
* @param int $type
* @param int $kernel
@ -59,6 +64,7 @@ class SupportVectorMachine
*/
public function train(array $samples, array $labels)
{
$this->labels = $labels;
$trainingSet = DataTransformer::trainingSet($samples, $labels);
file_put_contents($trainingSetFileName = $this->varPath.uniqid(), $trainingSet);
$modelFileName = $trainingSetFileName.'-model';
@ -81,8 +87,29 @@ class SupportVectorMachine
return $this->model;
}
/**
* @param array $samples
*
* @return array
*/
public function predict(array $samples)
{
$testSet = DataTransformer::testSet();
$testSet = DataTransformer::testSet($samples);
file_put_contents($testSetFileName = $this->varPath.uniqid(), $testSet);
$modelFileName = $testSetFileName.'-model';
file_put_contents($modelFileName, $this->model);
$outputFileName = $testSetFileName.'-output';
$command = sprintf('%ssvm-predict %s %s %s', $this->binPath, $testSetFileName, $modelFileName, $outputFileName);
$output = '';
exec(escapeshellcmd($command), $output);
$predictions = file_get_contents($outputFileName);
unlink($testSetFileName);
unlink($modelFileName);
unlink($outputFileName);
return DataTransformer::results($predictions, $this->labels);
}
}

View File

@ -36,5 +36,4 @@ class DataTransformerTest extends \PHPUnit_Framework_TestCase
$this->assertEquals($testSet, DataTransformer::testSet($samples));
}
}

View File

@ -33,4 +33,23 @@ SV
$this->assertEquals($model, $svm->getModel());
}
public function testPredictCSVCModelWithLinearKernel()
{
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
$svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0);
$svm->train($samples, $labels);
$predictions = $svm->predict([
[3, 2],
[2, 3],
[4, -5],
]);
$this->assertEquals('b', $predictions[0]);
$this->assertEquals('a', $predictions[1]);
$this->assertEquals('b', $predictions[2]);
}
}