mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-15 17:57:11 +00:00
implement Least Squares Regression
This commit is contained in:
parent
cbec77d247
commit
80a712e8a8
18
src/Phpml/Math/Statistic/Mean.php
Normal file
18
src/Phpml/Math/Statistic/Mean.php
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare (strict_types = 1);
|
||||||
|
|
||||||
|
namespace Phpml\Math\Statistic;
|
||||||
|
|
||||||
|
class Mean
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @param array $a
|
||||||
|
*
|
||||||
|
* @return float
|
||||||
|
*/
|
||||||
|
public static function arithmetic(array $a)
|
||||||
|
{
|
||||||
|
return array_sum($a) / count($a);
|
||||||
|
}
|
||||||
|
}
|
@ -4,6 +4,10 @@ declare (strict_types = 1);
|
|||||||
|
|
||||||
namespace Phpml\Regression;
|
namespace Phpml\Regression;
|
||||||
|
|
||||||
|
use Phpml\Math\Statistic\Correlation;
|
||||||
|
use Phpml\Math\Statistic\StandardDeviation;
|
||||||
|
use Phpml\Math\Statistic\Mean;
|
||||||
|
|
||||||
class LeastSquares implements Regression
|
class LeastSquares implements Regression
|
||||||
{
|
{
|
||||||
/**
|
/**
|
||||||
@ -34,14 +38,35 @@ class LeastSquares implements Regression
|
|||||||
{
|
{
|
||||||
$this->features = $features;
|
$this->features = $features;
|
||||||
$this->targets = $targets;
|
$this->targets = $targets;
|
||||||
|
|
||||||
|
$this->computeSlope();
|
||||||
|
$this->computeIntercept();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param array $features
|
* @param float $feature
|
||||||
*
|
*
|
||||||
* @return mixed
|
* @return mixed
|
||||||
*/
|
*/
|
||||||
public function predict(array $features)
|
public function predict($feature)
|
||||||
{
|
{
|
||||||
|
return $this->intercept + ($this->slope * $feature);
|
||||||
|
}
|
||||||
|
|
||||||
|
private function computeSlope()
|
||||||
|
{
|
||||||
|
$correlation = Correlation::pearson($this->features, $this->targets);
|
||||||
|
$sdX = StandardDeviation::population($this->features);
|
||||||
|
$sdY = StandardDeviation::population($this->targets);
|
||||||
|
|
||||||
|
$this->slope = $correlation * ($sdY / $sdX);
|
||||||
|
}
|
||||||
|
|
||||||
|
private function computeIntercept()
|
||||||
|
{
|
||||||
|
$meanY = Mean::arithmetic($this->targets);
|
||||||
|
$meanX = Mean::arithmetic($this->features);
|
||||||
|
|
||||||
|
$this->intercept = $meanY - ($this->slope * $meanX);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,9 +13,9 @@ interface Regression
|
|||||||
public function train(array $features, array $targets);
|
public function train(array $features, array $targets);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param array $features
|
* @param float $feature
|
||||||
*
|
*
|
||||||
* @return mixed
|
* @return mixed
|
||||||
*/
|
*/
|
||||||
public function predict(array $features);
|
public function predict($feature);
|
||||||
}
|
}
|
||||||
|
37
tests/Phpml/Regression/LeastSquaresTest.php
Normal file
37
tests/Phpml/Regression/LeastSquaresTest.php
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare (strict_types = 1);
|
||||||
|
|
||||||
|
namespace tests\Regression;
|
||||||
|
|
||||||
|
use Phpml\Regression\LeastSquares;
|
||||||
|
|
||||||
|
class LeastSquaresTest extends \PHPUnit_Framework_TestCase
|
||||||
|
{
|
||||||
|
public function testPredictSingleFeature()
|
||||||
|
{
|
||||||
|
$delta = 0.01;
|
||||||
|
|
||||||
|
//https://www.easycalculation.com/analytical/learn-least-square-regression.php
|
||||||
|
$features = [60, 61, 62, 63, 65];
|
||||||
|
$targets = [3.1, 3.6, 3.8, 4, 4.1];
|
||||||
|
|
||||||
|
$regression = new LeastSquares();
|
||||||
|
$regression->train($features, $targets);
|
||||||
|
|
||||||
|
$this->assertEquals(4.06, $regression->predict(64), '', $delta);
|
||||||
|
|
||||||
|
//http://www.stat.wmich.edu/s216/book/node127.html
|
||||||
|
$features = [9300, 10565, 15000, 15000, 17764, 57000, 65940, 73676, 77006, 93739, 146088, 153260];
|
||||||
|
$targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];
|
||||||
|
|
||||||
|
$regression = new LeastSquares();
|
||||||
|
$regression->train($features, $targets);
|
||||||
|
|
||||||
|
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta);
|
||||||
|
$this->assertEquals(5213.81, $regression->predict(57000), '', $delta);
|
||||||
|
$this->assertEquals(4188.13, $regression->predict(77006), '', $delta);
|
||||||
|
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta);
|
||||||
|
$this->assertEquals(278.66, $regression->predict(153260), '', $delta);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user