mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-10 00:37:55 +00:00
implement Least Squares Regression
This commit is contained in:
parent
cbec77d247
commit
80a712e8a8
18
src/Phpml/Math/Statistic/Mean.php
Normal file
18
src/Phpml/Math/Statistic/Mean.php
Normal file
@ -0,0 +1,18 @@
|
||||
<?php
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Math\Statistic;
|
||||
|
||||
class Mean
|
||||
{
|
||||
/**
|
||||
* @param array $a
|
||||
*
|
||||
* @return float
|
||||
*/
|
||||
public static function arithmetic(array $a)
|
||||
{
|
||||
return array_sum($a) / count($a);
|
||||
}
|
||||
}
|
@ -4,6 +4,10 @@ declare (strict_types = 1);
|
||||
|
||||
namespace Phpml\Regression;
|
||||
|
||||
use Phpml\Math\Statistic\Correlation;
|
||||
use Phpml\Math\Statistic\StandardDeviation;
|
||||
use Phpml\Math\Statistic\Mean;
|
||||
|
||||
class LeastSquares implements Regression
|
||||
{
|
||||
/**
|
||||
@ -34,14 +38,35 @@ class LeastSquares implements Regression
|
||||
{
|
||||
$this->features = $features;
|
||||
$this->targets = $targets;
|
||||
|
||||
$this->computeSlope();
|
||||
$this->computeIntercept();
|
||||
}
|
||||
|
||||
/**
|
||||
* @param array $features
|
||||
* @param float $feature
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
public function predict(array $features)
|
||||
public function predict($feature)
|
||||
{
|
||||
return $this->intercept + ($this->slope * $feature);
|
||||
}
|
||||
|
||||
private function computeSlope()
|
||||
{
|
||||
$correlation = Correlation::pearson($this->features, $this->targets);
|
||||
$sdX = StandardDeviation::population($this->features);
|
||||
$sdY = StandardDeviation::population($this->targets);
|
||||
|
||||
$this->slope = $correlation * ($sdY / $sdX);
|
||||
}
|
||||
|
||||
private function computeIntercept()
|
||||
{
|
||||
$meanY = Mean::arithmetic($this->targets);
|
||||
$meanX = Mean::arithmetic($this->features);
|
||||
|
||||
$this->intercept = $meanY - ($this->slope * $meanX);
|
||||
}
|
||||
}
|
||||
|
@ -13,9 +13,9 @@ interface Regression
|
||||
public function train(array $features, array $targets);
|
||||
|
||||
/**
|
||||
* @param array $features
|
||||
* @param float $feature
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
public function predict(array $features);
|
||||
public function predict($feature);
|
||||
}
|
||||
|
37
tests/Phpml/Regression/LeastSquaresTest.php
Normal file
37
tests/Phpml/Regression/LeastSquaresTest.php
Normal file
@ -0,0 +1,37 @@
|
||||
<?php
|
||||
|
||||
declare (strict_types = 1);
|
||||
|
||||
namespace tests\Regression;
|
||||
|
||||
use Phpml\Regression\LeastSquares;
|
||||
|
||||
class LeastSquaresTest extends \PHPUnit_Framework_TestCase
|
||||
{
|
||||
public function testPredictSingleFeature()
|
||||
{
|
||||
$delta = 0.01;
|
||||
|
||||
//https://www.easycalculation.com/analytical/learn-least-square-regression.php
|
||||
$features = [60, 61, 62, 63, 65];
|
||||
$targets = [3.1, 3.6, 3.8, 4, 4.1];
|
||||
|
||||
$regression = new LeastSquares();
|
||||
$regression->train($features, $targets);
|
||||
|
||||
$this->assertEquals(4.06, $regression->predict(64), '', $delta);
|
||||
|
||||
//http://www.stat.wmich.edu/s216/book/node127.html
|
||||
$features = [9300, 10565, 15000, 15000, 17764, 57000, 65940, 73676, 77006, 93739, 146088, 153260];
|
||||
$targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];
|
||||
|
||||
$regression = new LeastSquares();
|
||||
$regression->train($features, $targets);
|
||||
|
||||
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta);
|
||||
$this->assertEquals(5213.81, $regression->predict(57000), '', $delta);
|
||||
$this->assertEquals(4188.13, $regression->predict(77006), '', $delta);
|
||||
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta);
|
||||
$this->assertEquals(278.66, $regression->predict(153260), '', $delta);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user