From 430c1078cfbbdfe2343d1d057dbdb1b2beb7d5b5 Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Sat, 7 May 2016 23:04:58 +0200 Subject: [PATCH] implement support vector regression --- .../Classification/KNearestNeighbors.php | 4 +- src/Phpml/Classification/NaiveBayes.php | 4 +- src/Phpml/Classification/SVC.php | 3 +- .../Traits => Helper}/Predictable.php | 2 +- .../Traits => Helper}/Trainable.php | 2 +- src/Phpml/Regression/LeastSquares.php | 4 +- src/Phpml/Regression/Regression.php | 4 +- src/Phpml/Regression/SVR.php | 31 ++++++++++++ .../SupportVectorMachine/DataTransformer.php | 10 ++-- .../SupportVectorMachine.php | 10 ++-- tests/Phpml/Regression/SVRTest.php | 50 +++++++++++++++++++ 11 files changed, 108 insertions(+), 16 deletions(-) rename src/Phpml/{Classification/Traits => Helper}/Predictable.php (94%) rename src/Phpml/{Classification/Traits => Helper}/Trainable.php (90%) create mode 100644 src/Phpml/Regression/SVR.php create mode 100644 tests/Phpml/Regression/SVRTest.php diff --git a/src/Phpml/Classification/KNearestNeighbors.php b/src/Phpml/Classification/KNearestNeighbors.php index 93991ae..f1a87cf 100644 --- a/src/Phpml/Classification/KNearestNeighbors.php +++ b/src/Phpml/Classification/KNearestNeighbors.php @@ -4,8 +4,8 @@ declare (strict_types = 1); namespace Phpml\Classification; -use Phpml\Classification\Traits\Predictable; -use Phpml\Classification\Traits\Trainable; +use Phpml\Helper\Predictable; +use Phpml\Helper\Trainable; use Phpml\Math\Distance; use Phpml\Math\Distance\Euclidean; diff --git a/src/Phpml/Classification/NaiveBayes.php b/src/Phpml/Classification/NaiveBayes.php index ae98e1d..9726b40 100644 --- a/src/Phpml/Classification/NaiveBayes.php +++ b/src/Phpml/Classification/NaiveBayes.php @@ -4,8 +4,8 @@ declare (strict_types = 1); namespace Phpml\Classification; -use Phpml\Classification\Traits\Predictable; -use Phpml\Classification\Traits\Trainable; +use Phpml\Helper\Predictable; +use Phpml\Helper\Trainable; class NaiveBayes implements Classifier { diff --git a/src/Phpml/Classification/SVC.php b/src/Phpml/Classification/SVC.php index 8dcb28f..2350d5d 100644 --- a/src/Phpml/Classification/SVC.php +++ b/src/Phpml/Classification/SVC.php @@ -4,6 +4,7 @@ declare (strict_types = 1); namespace Phpml\Classification; +use Phpml\SupportVectorMachine\Kernel; use Phpml\SupportVectorMachine\SupportVectorMachine; use Phpml\SupportVectorMachine\Type; @@ -21,7 +22,7 @@ class SVC extends SupportVectorMachine implements Classifier * @param bool $probabilityEstimates */ public function __construct( - int $kernel, float $cost = 1.0, int $degree = 3, float $gamma = null, float $coef0 = 0.0, + int $kernel = Kernel::LINEAR, float $cost = 1.0, int $degree = 3, float $gamma = null, float $coef0 = 0.0, float $tolerance = 0.001, int $cacheSize = 100, bool $shrinking = true, bool $probabilityEstimates = false ) { diff --git a/src/Phpml/Classification/Traits/Predictable.php b/src/Phpml/Helper/Predictable.php similarity index 94% rename from src/Phpml/Classification/Traits/Predictable.php rename to src/Phpml/Helper/Predictable.php index 804b54a..4bf2a2e 100644 --- a/src/Phpml/Classification/Traits/Predictable.php +++ b/src/Phpml/Helper/Predictable.php @@ -2,7 +2,7 @@ declare (strict_types = 1); -namespace Phpml\Classification\Traits; +namespace Phpml\Helper; trait Predictable { diff --git a/src/Phpml/Classification/Traits/Trainable.php b/src/Phpml/Helper/Trainable.php similarity index 90% rename from src/Phpml/Classification/Traits/Trainable.php rename to src/Phpml/Helper/Trainable.php index 8fa97f2..36b8993 100644 --- a/src/Phpml/Classification/Traits/Trainable.php +++ b/src/Phpml/Helper/Trainable.php @@ -2,7 +2,7 @@ declare (strict_types = 1); -namespace Phpml\Classification\Traits; +namespace Phpml\Helper; trait Trainable { diff --git a/src/Phpml/Regression/LeastSquares.php b/src/Phpml/Regression/LeastSquares.php index cd0251f..83a6a65 100644 --- a/src/Phpml/Regression/LeastSquares.php +++ b/src/Phpml/Regression/LeastSquares.php @@ -4,10 +4,12 @@ declare (strict_types = 1); namespace Phpml\Regression; +use Phpml\Helper\Predictable; use Phpml\Math\Matrix; class LeastSquares implements Regression { + use Predictable; /** * @var array */ @@ -45,7 +47,7 @@ class LeastSquares implements Regression * * @return mixed */ - public function predict($sample) + public function predictSample(array $sample) { $result = $this->intercept; foreach ($this->coefficients as $index => $coefficient) { diff --git a/src/Phpml/Regression/Regression.php b/src/Phpml/Regression/Regression.php index a7837d4..12d2f52 100644 --- a/src/Phpml/Regression/Regression.php +++ b/src/Phpml/Regression/Regression.php @@ -13,9 +13,9 @@ interface Regression public function train(array $samples, array $targets); /** - * @param float $sample + * @param array $samples * * @return mixed */ - public function predict($sample); + public function predict(array $samples); } diff --git a/src/Phpml/Regression/SVR.php b/src/Phpml/Regression/SVR.php new file mode 100644 index 0000000..07b1459 --- /dev/null +++ b/src/Phpml/Regression/SVR.php @@ -0,0 +1,31 @@ + $label) { - $set .= sprintf('%s %s %s', $numericLabels[$label], self::sampleRow($samples[$index]), PHP_EOL); + $set .= sprintf('%s %s %s', ($targets ? $label : $numericLabels[$label]), self::sampleRow($samples[$index]), PHP_EOL); } return $set; diff --git a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php index 7a47db7..ef52d29 100644 --- a/src/Phpml/SupportVectorMachine/SupportVectorMachine.php +++ b/src/Phpml/SupportVectorMachine/SupportVectorMachine.php @@ -131,7 +131,7 @@ class SupportVectorMachine public function train(array $samples, array $labels) { $this->labels = $labels; - $trainingSet = DataTransformer::trainingSet($samples, $labels); + $trainingSet = DataTransformer::trainingSet($samples, $labels, in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR])); file_put_contents($trainingSetFileName = $this->varPath.uniqid(), $trainingSet); $modelFileName = $trainingSetFileName.'-model'; @@ -169,13 +169,17 @@ class SupportVectorMachine $output = ''; exec(escapeshellcmd($command), $output); - $rawPredictions = file_get_contents($outputFileName); + $predictions = file_get_contents($outputFileName); unlink($testSetFileName); unlink($modelFileName); unlink($outputFileName); - $predictions = DataTransformer::predictions($rawPredictions, $this->labels); + if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) { + $predictions = DataTransformer::predictions($predictions, $this->labels); + } else { + $predictions = explode(PHP_EOL, trim($predictions)); + } if (!is_array($samples[0])) { return $predictions[0]; diff --git a/tests/Phpml/Regression/SVRTest.php b/tests/Phpml/Regression/SVRTest.php new file mode 100644 index 0000000..d794062 --- /dev/null +++ b/tests/Phpml/Regression/SVRTest.php @@ -0,0 +1,50 @@ +train($samples, $targets); + + $this->assertEquals(4.03, $regression->predict([64]), '', $delta); + + $samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]]; + $targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025]; + + $regression = new SVR(Kernel::LINEAR); + $regression->train($samples, $targets); + + $this->assertEquals(6236.12, $regression->predict([9300]), '', $delta); + $this->assertEquals(4718.29, $regression->predict([57000]), '', $delta); + $this->assertEquals(4081.69, $regression->predict([77006]), '', $delta); + $this->assertEquals(6236.12, $regression->predict([9300]), '', $delta); + $this->assertEquals(1655.26, $regression->predict([153260]), '', $delta); + } + + public function testPredictMultiFeaturesSamples() + { + $delta = 0.01; + + $samples = [[73676, 1996], [77006, 1998], [10565, 2000], [146088, 1995], [15000, 2001], [65940, 2000], [9300, 2000], [93739, 1996], [153260, 1994], [17764, 2002], [57000, 1998], [15000, 2000]]; + $targets = [2000, 2750, 15500, 960, 4400, 8800, 7100, 2550, 1025, 5900, 4600, 4400]; + + $regression = new SVR(Kernel::LINEAR); + $regression->train($samples, $targets); + + $this->assertEquals(4109.82, $regression->predict([60000, 1996]), '', $delta); + $this->assertEquals(4112.28, $regression->predict([60000, 2000]), '', $delta); + } +}