ls reg with error :(

This commit is contained in:
Arkadiusz Kondas 2016-04-29 23:03:08 +02:00
parent 3e4dc3ddf8
commit 9d74174a68
4 changed files with 91 additions and 32 deletions

View File

@ -10,6 +10,11 @@ use Phpml\Math\Statistic\Mean;
class LeastSquares implements Regression class LeastSquares implements Regression
{ {
/**
* @var array
*/
private $samples;
/** /**
* @var array * @var array
*/ */
@ -21,52 +26,86 @@ class LeastSquares implements Regression
private $targets; private $targets;
/** /**
* @var float * @var array
*/ */
private $slope; private $slopes;
/** /**
* @var * @var float
*/ */
private $intercept; private $intercept;
/** /**
* @param array $features * @param array $samples
* @param array $targets * @param array $targets
*/ */
public function train(array $features, array $targets) public function train(array $samples, array $targets)
{ {
$this->features = $features; $this->samples = $samples;
$this->targets = $targets; $this->targets = $targets;
$this->features = [];
$this->computeSlope(); $this->computeSlopes();
$this->computeIntercept(); $this->computeIntercept();
} }
/** /**
* @param float $feature * @param float $sample
* *
* @return mixed * @return mixed
*/ */
public function predict($feature) public function predict($sample)
{ {
return $this->intercept + ($this->slope * $feature); $result = $this->intercept;
foreach ($this->slopes as $index => $slope) {
$result += ($slope * $sample[$index]);
} }
private function computeSlope() return $result;
}
/**
* @return array
*/
public function getSlopes()
{ {
$correlation = Correlation::pearson($this->features, $this->targets); return $this->slopes;
$sdX = StandardDeviation::population($this->features); }
private function computeSlopes()
{
$features = count($this->samples[0]);
$sdY = StandardDeviation::population($this->targets); $sdY = StandardDeviation::population($this->targets);
$this->slope = $correlation * ($sdY / $sdX); for($i=0; $i<$features; $i++) {
$correlation = Correlation::pearson($this->getFeatures($i), $this->targets);
$sdXi = StandardDeviation::population($this->getFeatures($i));
$this->slopes[] = $correlation * ($sdY / $sdXi);
}
} }
private function computeIntercept() private function computeIntercept()
{ {
$meanY = Mean::arithmetic($this->targets); $this->intercept = Mean::arithmetic($this->targets);
$meanX = Mean::arithmetic($this->features); foreach ($this->slopes as $index => $slope) {
$this->intercept -= $slope * Mean::arithmetic($this->getFeatures($index));
}
}
$this->intercept = $meanY - ($this->slope * $meanX); /**
* @param $index
*
* @return array
*/
private function getFeatures($index)
{
if(!isset($this->features[$index])) {
$this->features[$index] = [];
foreach ($this->samples as $sample) {
$this->features[$index][] = $sample[$index];
}
}
return $this->features[$index];
} }
} }

View File

@ -7,15 +7,15 @@ namespace Phpml\Regression;
interface Regression interface Regression
{ {
/** /**
* @param array $features * @param array $samples
* @param array $targets * @param array $targets
*/ */
public function train(array $features, array $targets); public function train(array $samples, array $targets);
/** /**
* @param float $feature * @param float $sample
* *
* @return mixed * @return mixed
*/ */
public function predict($feature); public function predict($sample);
} }

View File

@ -21,6 +21,11 @@ class CorrelationTest extends \PHPUnit_Framework_TestCase
$x = [43, 21, 25, 42, 57, 59]; $x = [43, 21, 25, 42, 57, 59];
$y = [99, 65, 79, 75, 87, 82]; $y = [99, 65, 79, 75, 87, 82];
$this->assertEquals(0.549, Correlation::pearson($x, $y), '', $delta); $this->assertEquals(0.549, Correlation::pearson($x, $y), '', $delta);
$delta = 0.001;
$x = [60, 61, 62, 63, 65];
$y = [3.1, 3.6, 3.8, 4, 4.1];
$this->assertEquals(0.911, Correlation::pearson($x, $y), '', $delta);
} }
/** /**

View File

@ -8,30 +8,45 @@ use Phpml\Regression\LeastSquares;
class LeastSquaresTest extends \PHPUnit_Framework_TestCase class LeastSquaresTest extends \PHPUnit_Framework_TestCase
{ {
public function testPredictSingleFeature() public function testPredictSingleFeatureSamples()
{ {
$delta = 0.01; $delta = 0.01;
//https://www.easycalculation.com/analytical/learn-least-square-regression.php //https://www.easycalculation.com/analytical/learn-least-square-regression.php
$features = [60, 61, 62, 63, 65]; $samples = [[60], [61], [62], [63], [65]];
$targets = [3.1, 3.6, 3.8, 4, 4.1]; $targets = [3.1, 3.6, 3.8, 4, 4.1];
$regression = new LeastSquares(); $regression = new LeastSquares();
$regression->train($features, $targets); $regression->train($samples, $targets);
$this->assertEquals(4.06, $regression->predict(64), '', $delta); $this->assertEquals(4.06, $regression->predict([64]), '', $delta);
//http://www.stat.wmich.edu/s216/book/node127.html //http://www.stat.wmich.edu/s216/book/node127.html
$features = [9300, 10565, 15000, 15000, 17764, 57000, 65940, 73676, 77006, 93739, 146088, 153260]; $samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]];
$targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025]; $targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];
$regression = new LeastSquares(); $regression = new LeastSquares();
$regression->train($features, $targets); $regression->train($samples, $targets);
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta); $this->assertEquals(7659.35, $regression->predict([9300]), '', $delta);
$this->assertEquals(5213.81, $regression->predict(57000), '', $delta); $this->assertEquals(5213.81, $regression->predict([57000]), '', $delta);
$this->assertEquals(4188.13, $regression->predict(77006), '', $delta); $this->assertEquals(4188.13, $regression->predict([77006]), '', $delta);
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta); $this->assertEquals(7659.35, $regression->predict([9300]), '', $delta);
$this->assertEquals(278.66, $regression->predict(153260), '', $delta); $this->assertEquals(278.66, $regression->predict([153260]), '', $delta);
} }
public function testPredictMultiFeaturesSamples()
{
$delta = 0.01;
//http://www.stat.wmich.edu/s216/book/node129.html
$samples = [[73676, 1996],[77006,1998],[ 10565, 2000],[146088, 1995],[ 15000, 2001],[ 65940, 2000],[ 9300, 2000],[ 93739, 1996],[153260, 1994],[ 17764, 2002],[ 57000, 1998],[ 15000, 2000]];
$targets = [2000, 2750, 15500, 960, 4400, 8800, 7100, 2550, 1025, 5900, 4600, 4400];
$regression = new LeastSquares();
$regression->train($samples, $targets);
$this->assertEquals(3807, $regression->predict([60000, 1996]), '', $delta);
}
} }