diff --git a/src/Phpml/Math/Statistic/Mean.php b/src/Phpml/Math/Statistic/Mean.php new file mode 100644 index 0000000..2716b78 --- /dev/null +++ b/src/Phpml/Math/Statistic/Mean.php @@ -0,0 +1,18 @@ +features = $features; $this->targets = $targets; + + $this->computeSlope(); + $this->computeIntercept(); } /** - * @param array $features + * @param float $feature * * @return mixed */ - public function predict(array $features) + public function predict($feature) { + return $this->intercept + ($this->slope * $feature); + } + + private function computeSlope() + { + $correlation = Correlation::pearson($this->features, $this->targets); + $sdX = StandardDeviation::population($this->features); + $sdY = StandardDeviation::population($this->targets); + + $this->slope = $correlation * ($sdY / $sdX); + } + + private function computeIntercept() + { + $meanY = Mean::arithmetic($this->targets); + $meanX = Mean::arithmetic($this->features); + + $this->intercept = $meanY - ($this->slope * $meanX); } } diff --git a/src/Phpml/Regression/Regression.php b/src/Phpml/Regression/Regression.php index 34e0b6d..f1f5c8a 100644 --- a/src/Phpml/Regression/Regression.php +++ b/src/Phpml/Regression/Regression.php @@ -13,9 +13,9 @@ interface Regression public function train(array $features, array $targets); /** - * @param array $features + * @param float $feature * * @return mixed */ - public function predict(array $features); + public function predict($feature); } diff --git a/tests/Phpml/Regression/LeastSquaresTest.php b/tests/Phpml/Regression/LeastSquaresTest.php new file mode 100644 index 0000000..eed7537 --- /dev/null +++ b/tests/Phpml/Regression/LeastSquaresTest.php @@ -0,0 +1,37 @@ +train($features, $targets); + + $this->assertEquals(4.06, $regression->predict(64), '', $delta); + + //http://www.stat.wmich.edu/s216/book/node127.html + $features = [9300, 10565, 15000, 15000, 17764, 57000, 65940, 73676, 77006, 93739, 146088, 153260]; + $targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025]; + + $regression = new LeastSquares(); + $regression->train($features, $targets); + + $this->assertEquals(7659.35, $regression->predict(9300), '', $delta); + $this->assertEquals(5213.81, $regression->predict(57000), '', $delta); + $this->assertEquals(4188.13, $regression->predict(77006), '', $delta); + $this->assertEquals(7659.35, $regression->predict(9300), '', $delta); + $this->assertEquals(278.66, $regression->predict(153260), '', $delta); + } +}