implement Least Squares Regression

This commit is contained in:
Arkadiusz Kondas 2016-04-27 23:51:14 +02:00
parent cbec77d247
commit 80a712e8a8
4 changed files with 84 additions and 4 deletions

View File

@ -0,0 +1,18 @@
<?php
declare (strict_types = 1);
namespace Phpml\Math\Statistic;
class Mean
{
    /**
     * Computes the arithmetic mean: the sum of the values divided by their count.
     *
     * @param array $a numeric values to average; must contain at least one element
     *
     * @return float|int the mean; an int only when the sum is an integer that is
     *                   evenly divisible by the number of elements
     *
     * @throws \InvalidArgumentException when $a is empty, because the mean is undefined
     */
    public static function arithmetic(array $a)
    {
        $count = count($a);

        // Guard explicitly: an empty array would otherwise lead to a division by
        // zero instead of a meaningful error for the caller.
        if ($count === 0) {
            throw new \InvalidArgumentException('The array must have at least 1 element');
        }

        return array_sum($a) / $count;
    }
}

View File

@ -4,6 +4,10 @@ declare (strict_types = 1);
namespace Phpml\Regression;
use Phpml\Math\Statistic\Correlation;
use Phpml\Math\Statistic\StandardDeviation;
use Phpml\Math\Statistic\Mean;
class LeastSquares implements Regression
{
/**
@ -34,14 +38,35 @@ class LeastSquares implements Regression
{
$this->features = $features;
$this->targets = $targets;
$this->computeSlope();
$this->computeIntercept();
}
/**
* @param array $features
* @param float $feature
*
* @return mixed
*/
public function predict(array $features)
public function predict($feature)
{
    // Evaluate the fitted line at the given feature: intercept + slope * feature.
    $scaled = $this->slope * $feature;

    return $this->intercept + $scaled;
}
/**
 * Sets the slope to the Pearson correlation of features and targets, scaled by
 * the ratio of the targets' to the features' population standard deviation.
 */
private function computeSlope()
{
    $r = Correlation::pearson($this->features, $this->targets);
    $featureSpread = StandardDeviation::population($this->features);
    $targetSpread = StandardDeviation::population($this->targets);

    $scale = $targetSpread / $featureSpread;
    $this->slope = $r * $scale;
}
/**
 * Sets the intercept to mean(targets) - slope * mean(features), so the fitted
 * line passes through the point of means. Reads $this->slope, so the slope has
 * to be computed before this method is called.
 */
private function computeIntercept()
{
    $targetMean = Mean::arithmetic($this->targets);
    $featureMean = Mean::arithmetic($this->features);
    $slopeContribution = $this->slope * $featureMean;

    $this->intercept = $targetMean - $slopeContribution;
}
}

View File

@ -13,9 +13,9 @@ interface Regression
public function train(array $features, array $targets);
/**
* @param array $features
* @param float $feature
*
* @return mixed
*/
public function predict(array $features);
public function predict($feature);
}

View File

@ -0,0 +1,37 @@
<?php
declare (strict_types = 1);
namespace tests\Regression;
use Phpml\Regression\LeastSquares;
class LeastSquaresTest extends \PHPUnit_Framework_TestCase
{
    /**
     * Trains the regression on two single-feature data sets and checks the
     * predictions against the expected values within a small tolerance.
     */
    public function testPredictSingleFeature()
    {
        // Maximum absolute difference tolerated between expected and predicted values.
        $delta = 0.01;

        //https://www.easycalculation.com/analytical/learn-least-square-regression.php
        $features = [60, 61, 62, 63, 65];
        $targets = [3.1, 3.6, 3.8, 4, 4.1];

        $regression = new LeastSquares();
        $regression->train($features, $targets);

        $this->assertEquals(4.06, $regression->predict(64), '', $delta);

        //http://www.stat.wmich.edu/s216/book/node127.html
        $features = [9300, 10565, 15000, 15000, 17764, 57000, 65940, 73676, 77006, 93739, 146088, 153260];
        $targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];

        // A fresh instance keeps the second fit independent of the first one.
        $regression = new LeastSquares();
        $regression->train($features, $targets);

        // Each feature value is checked once; a repeated identical assertion for
        // predict(9300) added no coverage and was dropped.
        $this->assertEquals(7659.35, $regression->predict(9300), '', $delta);
        $this->assertEquals(5213.81, $regression->predict(57000), '', $delta);
        $this->assertEquals(4188.13, $regression->predict(77006), '', $delta);
        $this->assertEquals(278.66, $regression->predict(153260), '', $delta);
    }
}