implement support vector regression

This commit is contained in:
Arkadiusz Kondas 2016-05-07 23:04:58 +02:00
parent c409658483
commit 430c1078cf
11 changed files with 108 additions and 16 deletions

View File

@ -4,8 +4,8 @@ declare (strict_types = 1);
namespace Phpml\Classification; namespace Phpml\Classification;
use Phpml\Classification\Traits\Predictable; use Phpml\Helper\Predictable;
use Phpml\Classification\Traits\Trainable; use Phpml\Helper\Trainable;
use Phpml\Math\Distance; use Phpml\Math\Distance;
use Phpml\Math\Distance\Euclidean; use Phpml\Math\Distance\Euclidean;

View File

@ -4,8 +4,8 @@ declare (strict_types = 1);
namespace Phpml\Classification; namespace Phpml\Classification;
use Phpml\Classification\Traits\Predictable; use Phpml\Helper\Predictable;
use Phpml\Classification\Traits\Trainable; use Phpml\Helper\Trainable;
class NaiveBayes implements Classifier class NaiveBayes implements Classifier
{ {

View File

@ -4,6 +4,7 @@ declare (strict_types = 1);
namespace Phpml\Classification; namespace Phpml\Classification;
use Phpml\SupportVectorMachine\Kernel;
use Phpml\SupportVectorMachine\SupportVectorMachine; use Phpml\SupportVectorMachine\SupportVectorMachine;
use Phpml\SupportVectorMachine\Type; use Phpml\SupportVectorMachine\Type;
@ -21,7 +22,7 @@ class SVC extends SupportVectorMachine implements Classifier
* @param bool $probabilityEstimates * @param bool $probabilityEstimates
*/ */
public function __construct( public function __construct(
int $kernel, float $cost = 1.0, int $degree = 3, float $gamma = null, float $coef0 = 0.0, int $kernel = Kernel::LINEAR, float $cost = 1.0, int $degree = 3, float $gamma = null, float $coef0 = 0.0,
float $tolerance = 0.001, int $cacheSize = 100, bool $shrinking = true, float $tolerance = 0.001, int $cacheSize = 100, bool $shrinking = true,
bool $probabilityEstimates = false bool $probabilityEstimates = false
) { ) {

View File

@ -2,7 +2,7 @@
declare (strict_types = 1); declare (strict_types = 1);
namespace Phpml\Classification\Traits; namespace Phpml\Helper;
trait Predictable trait Predictable
{ {

View File

@ -2,7 +2,7 @@
declare (strict_types = 1); declare (strict_types = 1);
namespace Phpml\Classification\Traits; namespace Phpml\Helper;
trait Trainable trait Trainable
{ {

View File

@ -4,10 +4,12 @@ declare (strict_types = 1);
namespace Phpml\Regression; namespace Phpml\Regression;
use Phpml\Helper\Predictable;
use Phpml\Math\Matrix; use Phpml\Math\Matrix;
class LeastSquares implements Regression class LeastSquares implements Regression
{ {
use Predictable;
/** /**
* @var array * @var array
*/ */
@ -45,7 +47,7 @@ class LeastSquares implements Regression
* *
* @return mixed * @return mixed
*/ */
public function predict($sample) public function predictSample(array $sample)
{ {
$result = $this->intercept; $result = $this->intercept;
foreach ($this->coefficients as $index => $coefficient) { foreach ($this->coefficients as $index => $coefficient) {

View File

@ -13,9 +13,9 @@ interface Regression
public function train(array $samples, array $targets); public function train(array $samples, array $targets);
/** /**
* @param float $sample * @param array $samples
* *
* @return mixed * @return mixed
*/ */
public function predict($sample); public function predict(array $samples);
} }

View File

@ -0,0 +1,31 @@
<?php
declare (strict_types = 1);
namespace Phpml\Regression;
use Phpml\SupportVectorMachine\Kernel;
use Phpml\SupportVectorMachine\SupportVectorMachine;
use Phpml\SupportVectorMachine\Type;
/**
 * Regression front-end for the support vector machine wrapper.
 *
 * The only thing this class adds on top of SupportVectorMachine is a
 * constructor that pins the machine type to Type::EPSILON_SVR.
 */
class SVR extends SupportVectorMachine implements Regression
{
    /**
     * Configures an epsilon-SVR machine; every argument is optional.
     *
     * @param int        $kernel    kernel constant (Kernel::RBF when omitted)
     * @param int        $degree    forwarded to the parent unchanged
     * @param float      $epsilon   forwarded to the parent unchanged
     * @param float      $cost      forwarded to the parent unchanged
     * @param float|null $gamma     forwarded to the parent unchanged; null when omitted
     * @param float      $coef0     forwarded to the parent unchanged
     * @param float      $tolerance forwarded to the parent unchanged
     * @param int        $cacheSize forwarded to the parent unchanged
     * @param bool       $shrinking forwarded to the parent unchanged
     */
    public function __construct(
        int $kernel = Kernel::RBF,
        int $degree = 3,
        float $epsilon = 0.1,
        float $cost = 1.0,
        float $gamma = null,
        float $coef0 = 0.0,
        float $tolerance = 0.001,
        int $cacheSize = 100,
        bool $shrinking = true
    ) {
        // Two parent arguments are not exposed by this constructor and are
        // always passed as literals: 0.5 (4th) and false (last).
        // NOTE(review): presumably these are the parent's nu and
        // probabilityEstimates parameters - confirm against
        // SupportVectorMachine::__construct.
        parent::__construct(
            Type::EPSILON_SVR,
            $kernel,
            $cost,
            0.5,
            $degree,
            $gamma,
            $coef0,
            $epsilon,
            $tolerance,
            $cacheSize,
            $shrinking,
            false
        );
    }
}

View File

@ -9,15 +9,19 @@ class DataTransformer
/** /**
* @param array $samples * @param array $samples
* @param array $labels * @param array $labels
* @param bool $targets
* *
* @return string * @return string
*/ */
public static function trainingSet(array $samples, array $labels): string public static function trainingSet(array $samples, array $labels, bool $targets = false): string
{ {
$set = ''; $set = '';
if (!$targets) {
$numericLabels = self::numericLabels($labels); $numericLabels = self::numericLabels($labels);
}
foreach ($labels as $index => $label) { foreach ($labels as $index => $label) {
$set .= sprintf('%s %s %s', $numericLabels[$label], self::sampleRow($samples[$index]), PHP_EOL); $set .= sprintf('%s %s %s', ($targets ? $label : $numericLabels[$label]), self::sampleRow($samples[$index]), PHP_EOL);
} }
return $set; return $set;

View File

@ -131,7 +131,7 @@ class SupportVectorMachine
public function train(array $samples, array $labels) public function train(array $samples, array $labels)
{ {
$this->labels = $labels; $this->labels = $labels;
$trainingSet = DataTransformer::trainingSet($samples, $labels); $trainingSet = DataTransformer::trainingSet($samples, $labels, in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR]));
file_put_contents($trainingSetFileName = $this->varPath.uniqid(), $trainingSet); file_put_contents($trainingSetFileName = $this->varPath.uniqid(), $trainingSet);
$modelFileName = $trainingSetFileName.'-model'; $modelFileName = $trainingSetFileName.'-model';
@ -169,13 +169,17 @@ class SupportVectorMachine
$output = ''; $output = '';
exec(escapeshellcmd($command), $output); exec(escapeshellcmd($command), $output);
$rawPredictions = file_get_contents($outputFileName); $predictions = file_get_contents($outputFileName);
unlink($testSetFileName); unlink($testSetFileName);
unlink($modelFileName); unlink($modelFileName);
unlink($outputFileName); unlink($outputFileName);
$predictions = DataTransformer::predictions($rawPredictions, $this->labels); if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
$predictions = DataTransformer::predictions($predictions, $this->labels);
} else {
$predictions = explode(PHP_EOL, trim($predictions));
}
if (!is_array($samples[0])) { if (!is_array($samples[0])) {
return $predictions[0]; return $predictions[0];

View File

@ -0,0 +1,50 @@
<?php
declare (strict_types = 1);
namespace tests\Regression;
use Phpml\Regression\SVR;
use Phpml\SupportVectorMachine\Kernel;
/**
 * Checks SVR predictions with a linear kernel against fixed reference values.
 */
class SVRTest extends \PHPUnit_Framework_TestCase
{
    /**
     * Trains two models on one-feature samples and compares predictions
     * with the reference values, within an absolute tolerance of 0.01.
     */
    public function testPredictSingleFeatureSamples()
    {
        $delta = 0.01;

        $samples = [[60], [61], [62], [63], [65]];
        $targets = [3.1, 3.6, 3.8, 4, 4.1];

        $regression = new SVR(Kernel::LINEAR);
        $regression->train($samples, $targets);

        $this->assertPredictions($regression, [
            [4.03, [64]],
        ], $delta);

        $samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]];
        $targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];

        $regression = new SVR(Kernel::LINEAR);
        $regression->train($samples, $targets);

        // Each sample is checked once; the original listed [9300] twice
        // with the same expected value.
        $this->assertPredictions($regression, [
            [6236.12, [9300]],
            [4718.29, [57000]],
            [4081.69, [77006]],
            [1655.26, [153260]],
        ], $delta);
    }

    /**
     * Trains a model on two-feature samples and compares predictions
     * with the reference values, within an absolute tolerance of 0.01.
     */
    public function testPredictMultiFeaturesSamples()
    {
        $delta = 0.01;

        $samples = [[73676, 1996], [77006, 1998], [10565, 2000], [146088, 1995], [15000, 2001], [65940, 2000], [9300, 2000], [93739, 1996], [153260, 1994], [17764, 2002], [57000, 1998], [15000, 2000]];
        $targets = [2000, 2750, 15500, 960, 4400, 8800, 7100, 2550, 1025, 5900, 4600, 4400];

        $regression = new SVR(Kernel::LINEAR);
        $regression->train($samples, $targets);

        $this->assertPredictions($regression, [
            [4109.82, [60000, 1996]],
            [4112.28, [60000, 2000]],
        ], $delta);
    }

    /**
     * Asserts that the trained model predicts each expected value.
     *
     * @param SVR   $regression   trained model under test
     * @param array $expectations list of [expected value, sample] pairs
     * @param float $delta        allowed absolute difference per prediction
     */
    private function assertPredictions(SVR $regression, array $expectations, float $delta)
    {
        foreach ($expectations as list($expected, $sample)) {
            $this->assertEquals(
                $expected,
                $regression->predict($sample),
                // Name the offending sample so a failure is self-explanatory.
                sprintf('Unexpected prediction for sample [%s]', implode(', ', $sample)),
                $delta
            );
        }
    }
}