From 9d74174a68f7a272963e2c680d440f4cd3dce9d4 Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Fri, 29 Apr 2016 23:03:08 +0200 Subject: [PATCH] ls reg with error :( --- src/Phpml/Regression/LeastSquares.php | 73 ++++++++++++++----- src/Phpml/Regression/Regression.php | 8 +- .../Phpml/Math/Statistic/CorrelationTest.php | 5 ++ tests/Phpml/Regression/LeastSquaresTest.php | 37 +++++++--- 4 files changed, 91 insertions(+), 32 deletions(-) diff --git a/src/Phpml/Regression/LeastSquares.php b/src/Phpml/Regression/LeastSquares.php index 793ae95..622c187 100644 --- a/src/Phpml/Regression/LeastSquares.php +++ b/src/Phpml/Regression/LeastSquares.php @@ -10,6 +10,11 @@ use Phpml\Math\Statistic\Mean; class LeastSquares implements Regression { + /** + * @var array + */ + private $samples; + /** * @var array */ @@ -21,52 +26,86 @@ class LeastSquares implements Regression private $targets; /** - * @var float + * @var array */ - private $slope; + private $slopes; /** - * @var + * @var float */ private $intercept; /** - * @param array $features + * @param array $samples * @param array $targets */ - public function train(array $features, array $targets) + public function train(array $samples, array $targets) { - $this->features = $features; + $this->samples = $samples; $this->targets = $targets; + $this->features = []; - $this->computeSlope(); + $this->computeSlopes(); $this->computeIntercept(); } /** - * @param float $feature + * @param float $sample * * @return mixed */ - public function predict($feature) + public function predict($sample) { - return $this->intercept + ($this->slope * $feature); + $result = $this->intercept; + foreach ($this->slopes as $index => $slope) { + $result += ($slope * $sample[$index]); + } + + return $result; } - private function computeSlope() + /** + * @return array + */ + public function getSlopes() { - $correlation = Correlation::pearson($this->features, $this->targets); - $sdX = StandardDeviation::population($this->features); + return $this->slopes; + } + + private function computeSlopes() + { + $features = count($this->samples[0]); $sdY = StandardDeviation::population($this->targets); - $this->slope = $correlation * ($sdY / $sdX); + for($i=0; $i<$features; $i++) { + $correlation = Correlation::pearson($this->getFeatures($i), $this->targets); + $sdXi = StandardDeviation::population($this->getFeatures($i)); + $this->slopes[] = $correlation * ($sdY / $sdXi); + } } private function computeIntercept() { - $meanY = Mean::arithmetic($this->targets); - $meanX = Mean::arithmetic($this->features); + $this->intercept = Mean::arithmetic($this->targets); + foreach ($this->slopes as $index => $slope) { + $this->intercept -= $slope * Mean::arithmetic($this->getFeatures($index)); + } + } - $this->intercept = $meanY - ($this->slope * $meanX); + /** + * @param $index + * + * @return array + */ + private function getFeatures($index) + { + if(!isset($this->features[$index])) { + $this->features[$index] = []; + foreach ($this->samples as $sample) { + $this->features[$index][] = $sample[$index]; + } + } + + return $this->features[$index]; } } diff --git a/src/Phpml/Regression/Regression.php b/src/Phpml/Regression/Regression.php index f1f5c8a..a7837d4 100644 --- a/src/Phpml/Regression/Regression.php +++ b/src/Phpml/Regression/Regression.php @@ -7,15 +7,15 @@ namespace Phpml\Regression; interface Regression { /** - * @param array $features + * @param array $samples * @param array $targets */ - public function train(array $features, array $targets); + public function train(array $samples, array $targets); /** - * @param float $feature + * @param float $sample * * @return mixed */ - public function predict($feature); + public function predict($sample); } diff --git a/tests/Phpml/Math/Statistic/CorrelationTest.php b/tests/Phpml/Math/Statistic/CorrelationTest.php index 492d38c..948dc16 100644 --- a/tests/Phpml/Math/Statistic/CorrelationTest.php +++ b/tests/Phpml/Math/Statistic/CorrelationTest.php @@ -21,6 +21,11 @@ class CorrelationTest extends \PHPUnit_Framework_TestCase $x = [43, 21, 25, 42, 57, 59]; $y = [99, 65, 79, 75, 87, 82]; $this->assertEquals(0.549, Correlation::pearson($x, $y), '', $delta); + + $delta = 0.001; + $x = [60, 61, 62, 63, 65]; + $y = [3.1, 3.6, 3.8, 4, 4.1]; + $this->assertEquals(0.911, Correlation::pearson($x, $y), '', $delta); } /** diff --git a/tests/Phpml/Regression/LeastSquaresTest.php b/tests/Phpml/Regression/LeastSquaresTest.php index eed7537..d5975d8 100644 --- a/tests/Phpml/Regression/LeastSquaresTest.php +++ b/tests/Phpml/Regression/LeastSquaresTest.php @@ -8,30 +8,45 @@ use Phpml\Regression\LeastSquares; class LeastSquaresTest extends \PHPUnit_Framework_TestCase { - public function testPredictSingleFeature() + public function testPredictSingleFeatureSamples() { $delta = 0.01; //https://www.easycalculation.com/analytical/learn-least-square-regression.php - $features = [60, 61, 62, 63, 65]; + $samples = [[60], [61], [62], [63], [65]]; $targets = [3.1, 3.6, 3.8, 4, 4.1]; $regression = new LeastSquares(); - $regression->train($features, $targets); + $regression->train($samples, $targets); - $this->assertEquals(4.06, $regression->predict(64), '', $delta); + $this->assertEquals(4.06, $regression->predict([64]), '', $delta); //http://www.stat.wmich.edu/s216/book/node127.html - $features = [9300, 10565, 15000, 15000, 17764, 57000, 65940, 73676, 77006, 93739, 146088, 153260]; + $samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]]; $targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025]; $regression = new LeastSquares(); - $regression->train($features, $targets); + $regression->train($samples, $targets); - $this->assertEquals(7659.35, $regression->predict(9300), '', $delta); - $this->assertEquals(5213.81, $regression->predict(57000), '', $delta); - $this->assertEquals(4188.13, $regression->predict(77006), '', $delta); - $this->assertEquals(7659.35, $regression->predict(9300), '', $delta); - $this->assertEquals(278.66, $regression->predict(153260), '', $delta); + $this->assertEquals(7659.35, $regression->predict([9300]), '', $delta); + $this->assertEquals(5213.81, $regression->predict([57000]), '', $delta); + $this->assertEquals(4188.13, $regression->predict([77006]), '', $delta); + $this->assertEquals(7659.35, $regression->predict([9300]), '', $delta); + $this->assertEquals(278.66, $regression->predict([153260]), '', $delta); } + + public function testPredictMultiFeaturesSamples() + { + $delta = 0.01; + + //http://www.stat.wmich.edu/s216/book/node129.html + $samples = [[73676, 1996],[77006,1998],[ 10565, 2000],[146088, 1995],[ 15000, 2001],[ 65940, 2000],[ 9300, 2000],[ 93739, 1996],[153260, 1994],[ 17764, 2002],[ 57000, 1998],[ 15000, 2000]]; + $targets = [2000, 2750, 15500, 960, 4400, 8800, 7100, 2550, 1025, 5900, 4600, 4400]; + + $regression = new LeastSquares(); + $regression->train($samples, $targets); + + $this->assertEquals(3807, $regression->predict([60000, 1996]), '', $delta); + } + }