ls reg with error :(

This commit is contained in:
Arkadiusz Kondas 2016-04-29 23:03:08 +02:00
parent 3e4dc3ddf8
commit 9d74174a68
4 changed files with 91 additions and 32 deletions

View File

@ -10,6 +10,11 @@ use Phpml\Math\Statistic\Mean;
class LeastSquares implements Regression
{
/**
* @var array
*/
private $samples;
/**
* @var array
*/
@ -21,52 +26,86 @@ class LeastSquares implements Regression
private $targets;
/**
* @var float
* @var array
*/
private $slope;
private $slopes;
/**
* @var
* @var float
*/
private $intercept;
/**
* @param array $features
* @param array $samples
* @param array $targets
*/
public function train(array $features, array $targets)
public function train(array $samples, array $targets)
{
$this->features = $features;
$this->samples = $samples;
$this->targets = $targets;
$this->features = [];
$this->computeSlope();
$this->computeSlopes();
$this->computeIntercept();
}
/**
* @param float $feature
* @param float $sample
*
* @return mixed
*/
public function predict($feature)
public function predict($sample)
{
return $this->intercept + ($this->slope * $feature);
$result = $this->intercept;
foreach ($this->slopes as $index => $slope) {
$result += ($slope * $sample[$index]);
}
return $result;
}
private function computeSlope()
/**
* @return array
*/
public function getSlopes()
{
$correlation = Correlation::pearson($this->features, $this->targets);
$sdX = StandardDeviation::population($this->features);
return $this->slopes;
}
private function computeSlopes()
{
$features = count($this->samples[0]);
$sdY = StandardDeviation::population($this->targets);
$this->slope = $correlation * ($sdY / $sdX);
for($i=0; $i<$features; $i++) {
$correlation = Correlation::pearson($this->getFeatures($i), $this->targets);
$sdXi = StandardDeviation::population($this->getFeatures($i));
$this->slopes[] = $correlation * ($sdY / $sdXi);
}
}
private function computeIntercept()
{
$meanY = Mean::arithmetic($this->targets);
$meanX = Mean::arithmetic($this->features);
$this->intercept = Mean::arithmetic($this->targets);
foreach ($this->slopes as $index => $slope) {
$this->intercept -= $slope * Mean::arithmetic($this->getFeatures($index));
}
}
$this->intercept = $meanY - ($this->slope * $meanX);
/**
* @param $index
*
* @return array
*/
private function getFeatures($index)
{
if(!isset($this->features[$index])) {
$this->features[$index] = [];
foreach ($this->samples as $sample) {
$this->features[$index][] = $sample[$index];
}
}
return $this->features[$index];
}
}

View File

@ -7,15 +7,15 @@ namespace Phpml\Regression;
interface Regression
{
/**
* @param array $features
* @param array $samples
* @param array $targets
*/
public function train(array $features, array $targets);
public function train(array $samples, array $targets);
/**
* @param float $feature
* @param float $sample
*
* @return mixed
*/
public function predict($feature);
public function predict($sample);
}

View File

@ -21,6 +21,11 @@ class CorrelationTest extends \PHPUnit_Framework_TestCase
$x = [43, 21, 25, 42, 57, 59];
$y = [99, 65, 79, 75, 87, 82];
$this->assertEquals(0.549, Correlation::pearson($x, $y), '', $delta);
$delta = 0.001;
$x = [60, 61, 62, 63, 65];
$y = [3.1, 3.6, 3.8, 4, 4.1];
$this->assertEquals(0.911, Correlation::pearson($x, $y), '', $delta);
}
/**

View File

@ -8,30 +8,45 @@ use Phpml\Regression\LeastSquares;
class LeastSquaresTest extends \PHPUnit_Framework_TestCase
{
public function testPredictSingleFeature()
public function testPredictSingleFeatureSamples()
{
$delta = 0.01;
//https://www.easycalculation.com/analytical/learn-least-square-regression.php
$features = [60, 61, 62, 63, 65];
$samples = [[60], [61], [62], [63], [65]];
$targets = [3.1, 3.6, 3.8, 4, 4.1];
$regression = new LeastSquares();
$regression->train($features, $targets);
$regression->train($samples, $targets);
$this->assertEquals(4.06, $regression->predict(64), '', $delta);
$this->assertEquals(4.06, $regression->predict([64]), '', $delta);
//http://www.stat.wmich.edu/s216/book/node127.html
$features = [9300, 10565, 15000, 15000, 17764, 57000, 65940, 73676, 77006, 93739, 146088, 153260];
$samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]];
$targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];
$regression = new LeastSquares();
$regression->train($features, $targets);
$regression->train($samples, $targets);
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta);
$this->assertEquals(5213.81, $regression->predict(57000), '', $delta);
$this->assertEquals(4188.13, $regression->predict(77006), '', $delta);
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta);
$this->assertEquals(278.66, $regression->predict(153260), '', $delta);
$this->assertEquals(7659.35, $regression->predict([9300]), '', $delta);
$this->assertEquals(5213.81, $regression->predict([57000]), '', $delta);
$this->assertEquals(4188.13, $regression->predict([77006]), '', $delta);
$this->assertEquals(7659.35, $regression->predict([9300]), '', $delta);
$this->assertEquals(278.66, $regression->predict([153260]), '', $delta);
}
public function testPredictMultiFeaturesSamples()
{
$delta = 0.01;
//http://www.stat.wmich.edu/s216/book/node129.html
$samples = [[73676, 1996],[77006,1998],[ 10565, 2000],[146088, 1995],[ 15000, 2001],[ 65940, 2000],[ 9300, 2000],[ 93739, 1996],[153260, 1994],[ 17764, 2002],[ 57000, 1998],[ 15000, 2000]];
$targets = [2000, 2750, 15500, 960, 4400, 8800, 7100, 2550, 1025, 5900, 4600, 4400];
$regression = new LeastSquares();
$regression->train($samples, $targets);
$this->assertEquals(3807, $regression->predict([60000, 1996]), '', $delta);
}
}