mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2025-01-10 00:37:55 +00:00
ls reg with error :(
This commit is contained in:
parent
3e4dc3ddf8
commit
9d74174a68
@ -10,6 +10,11 @@ use Phpml\Math\Statistic\Mean;
|
||||
|
||||
class LeastSquares implements Regression
|
||||
{
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $samples;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
@ -21,52 +26,86 @@ class LeastSquares implements Regression
|
||||
private $targets;
|
||||
|
||||
/**
|
||||
* @var float
|
||||
* @var array
|
||||
*/
|
||||
private $slope;
|
||||
private $slopes;
|
||||
|
||||
/**
|
||||
* @var
|
||||
* @var float
|
||||
*/
|
||||
private $intercept;
|
||||
|
||||
/**
|
||||
* @param array $features
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
*/
|
||||
public function train(array $features, array $targets)
|
||||
public function train(array $samples, array $targets)
|
||||
{
|
||||
$this->features = $features;
|
||||
$this->samples = $samples;
|
||||
$this->targets = $targets;
|
||||
$this->features = [];
|
||||
|
||||
$this->computeSlope();
|
||||
$this->computeSlopes();
|
||||
$this->computeIntercept();
|
||||
}
|
||||
|
||||
/**
|
||||
* @param float $feature
|
||||
* @param float $sample
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
public function predict($feature)
|
||||
public function predict($sample)
|
||||
{
|
||||
return $this->intercept + ($this->slope * $feature);
|
||||
$result = $this->intercept;
|
||||
foreach ($this->slopes as $index => $slope) {
|
||||
$result += ($slope * $sample[$index]);
|
||||
}
|
||||
|
||||
private function computeSlope()
|
||||
return $result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getSlopes()
|
||||
{
|
||||
$correlation = Correlation::pearson($this->features, $this->targets);
|
||||
$sdX = StandardDeviation::population($this->features);
|
||||
return $this->slopes;
|
||||
}
|
||||
|
||||
private function computeSlopes()
|
||||
{
|
||||
$features = count($this->samples[0]);
|
||||
$sdY = StandardDeviation::population($this->targets);
|
||||
|
||||
$this->slope = $correlation * ($sdY / $sdX);
|
||||
for($i=0; $i<$features; $i++) {
|
||||
$correlation = Correlation::pearson($this->getFeatures($i), $this->targets);
|
||||
$sdXi = StandardDeviation::population($this->getFeatures($i));
|
||||
$this->slopes[] = $correlation * ($sdY / $sdXi);
|
||||
}
|
||||
}
|
||||
|
||||
private function computeIntercept()
|
||||
{
|
||||
$meanY = Mean::arithmetic($this->targets);
|
||||
$meanX = Mean::arithmetic($this->features);
|
||||
$this->intercept = Mean::arithmetic($this->targets);
|
||||
foreach ($this->slopes as $index => $slope) {
|
||||
$this->intercept -= $slope * Mean::arithmetic($this->getFeatures($index));
|
||||
}
|
||||
}
|
||||
|
||||
$this->intercept = $meanY - ($this->slope * $meanX);
|
||||
/**
|
||||
* @param $index
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
private function getFeatures($index)
|
||||
{
|
||||
if(!isset($this->features[$index])) {
|
||||
$this->features[$index] = [];
|
||||
foreach ($this->samples as $sample) {
|
||||
$this->features[$index][] = $sample[$index];
|
||||
}
|
||||
}
|
||||
|
||||
return $this->features[$index];
|
||||
}
|
||||
}
|
||||
|
@ -7,15 +7,15 @@ namespace Phpml\Regression;
|
||||
interface Regression
|
||||
{
|
||||
/**
|
||||
* @param array $features
|
||||
* @param array $samples
|
||||
* @param array $targets
|
||||
*/
|
||||
public function train(array $features, array $targets);
|
||||
public function train(array $samples, array $targets);
|
||||
|
||||
/**
|
||||
* @param float $feature
|
||||
* @param float $sample
|
||||
*
|
||||
* @return mixed
|
||||
*/
|
||||
public function predict($feature);
|
||||
public function predict($sample);
|
||||
}
|
||||
|
@ -21,6 +21,11 @@ class CorrelationTest extends \PHPUnit_Framework_TestCase
|
||||
$x = [43, 21, 25, 42, 57, 59];
|
||||
$y = [99, 65, 79, 75, 87, 82];
|
||||
$this->assertEquals(0.549, Correlation::pearson($x, $y), '', $delta);
|
||||
|
||||
$delta = 0.001;
|
||||
$x = [60, 61, 62, 63, 65];
|
||||
$y = [3.1, 3.6, 3.8, 4, 4.1];
|
||||
$this->assertEquals(0.911, Correlation::pearson($x, $y), '', $delta);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -8,30 +8,45 @@ use Phpml\Regression\LeastSquares;
|
||||
|
||||
class LeastSquaresTest extends \PHPUnit_Framework_TestCase
|
||||
{
|
||||
public function testPredictSingleFeature()
|
||||
public function testPredictSingleFeatureSamples()
|
||||
{
|
||||
$delta = 0.01;
|
||||
|
||||
//https://www.easycalculation.com/analytical/learn-least-square-regression.php
|
||||
$features = [60, 61, 62, 63, 65];
|
||||
$samples = [[60], [61], [62], [63], [65]];
|
||||
$targets = [3.1, 3.6, 3.8, 4, 4.1];
|
||||
|
||||
$regression = new LeastSquares();
|
||||
$regression->train($features, $targets);
|
||||
$regression->train($samples, $targets);
|
||||
|
||||
$this->assertEquals(4.06, $regression->predict(64), '', $delta);
|
||||
$this->assertEquals(4.06, $regression->predict([64]), '', $delta);
|
||||
|
||||
//http://www.stat.wmich.edu/s216/book/node127.html
|
||||
$features = [9300, 10565, 15000, 15000, 17764, 57000, 65940, 73676, 77006, 93739, 146088, 153260];
|
||||
$samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]];
|
||||
$targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];
|
||||
|
||||
$regression = new LeastSquares();
|
||||
$regression->train($features, $targets);
|
||||
$regression->train($samples, $targets);
|
||||
|
||||
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta);
|
||||
$this->assertEquals(5213.81, $regression->predict(57000), '', $delta);
|
||||
$this->assertEquals(4188.13, $regression->predict(77006), '', $delta);
|
||||
$this->assertEquals(7659.35, $regression->predict(9300), '', $delta);
|
||||
$this->assertEquals(278.66, $regression->predict(153260), '', $delta);
|
||||
$this->assertEquals(7659.35, $regression->predict([9300]), '', $delta);
|
||||
$this->assertEquals(5213.81, $regression->predict([57000]), '', $delta);
|
||||
$this->assertEquals(4188.13, $regression->predict([77006]), '', $delta);
|
||||
$this->assertEquals(7659.35, $regression->predict([9300]), '', $delta);
|
||||
$this->assertEquals(278.66, $regression->predict([153260]), '', $delta);
|
||||
}
|
||||
|
||||
public function testPredictMultiFeaturesSamples()
|
||||
{
|
||||
$delta = 0.01;
|
||||
|
||||
//http://www.stat.wmich.edu/s216/book/node129.html
|
||||
$samples = [[73676, 1996],[77006,1998],[ 10565, 2000],[146088, 1995],[ 15000, 2001],[ 65940, 2000],[ 9300, 2000],[ 93739, 1996],[153260, 1994],[ 17764, 2002],[ 57000, 1998],[ 15000, 2000]];
|
||||
$targets = [2000, 2750, 15500, 960, 4400, 8800, 7100, 2550, 1025, 5900, 4600, 4400];
|
||||
|
||||
$regression = new LeastSquares();
|
||||
$regression->train($samples, $targets);
|
||||
|
||||
$this->assertEquals(3807, $regression->predict([60000, 1996]), '', $delta);
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user