From ff79de7e14cc4d009e488fddfaa8e3add9193ed8 Mon Sep 17 00:00:00 2001 From: Arkadiusz Kondas Date: Sat, 30 Apr 2016 13:54:01 +0200 Subject: [PATCH] better arguments format for regression --- src/Phpml/Math/Matrix.php | 25 ++++++++++++---- src/Phpml/Regression/LeastSquares.php | 32 +++++++++++++++++++-- tests/Phpml/Regression/LeastSquaresTest.php | 12 ++++---- 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/src/Phpml/Math/Matrix.php b/src/Phpml/Math/Matrix.php index a071fd9..aeab3bc 100644 --- a/src/Phpml/Math/Matrix.php +++ b/src/Phpml/Math/Matrix.php @@ -51,6 +51,21 @@ class Matrix $this->matrix = $matrix; } + /** + * @param array $array + * + * @return Matrix + */ + public static function fromFlatArray(array $array) + { + $matrix = []; + foreach ($array as $value) { + $matrix[] = [$value]; + } + + return new self($matrix); + } + /** * @return array */ @@ -115,16 +130,14 @@ class Matrix if ($this->rows == 1 && $this->columns == 1) { $determinant = $this->matrix[0][0]; } elseif ($this->rows == 2 && $this->columns == 2) { - $determinant = $this->matrix[0][0] * $this->matrix[1][1] - + $determinant = + $this->matrix[0][0] * $this->matrix[1][1] - $this->matrix[0][1] * $this->matrix[1][0]; } else { for ($j = 0; $j < $this->columns; ++$j) { $subMatrix = $this->crossOut(0, $j); - if (fmod($j, 2) == 0) { - $determinant += $this->matrix[0][$j] * $subMatrix->getDeterminant(); - } else { - $determinant -= $this->matrix[0][$j] * $subMatrix->getDeterminant(); - } + $minor = $this->matrix[0][$j] * $subMatrix->getDeterminant(); + $determinant += fmod($j, 2) == 0 ? $minor : -$minor; } } diff --git a/src/Phpml/Regression/LeastSquares.php b/src/Phpml/Regression/LeastSquares.php index af755c5..cd0251f 100644 --- a/src/Phpml/Regression/LeastSquares.php +++ b/src/Phpml/Regression/LeastSquares.php @@ -76,8 +76,8 @@ class LeastSquares implements Regression */ private function computeCoefficients() { - $samplesMatrix = new Matrix($this->samples); - $targetsMatrix = new Matrix($this->targets); + $samplesMatrix = $this->getSamplesMatrix(); + $targetsMatrix = $this->getTargetsMatrix(); $ts = $samplesMatrix->transpose()->multiply($samplesMatrix)->inverse(); $tf = $samplesMatrix->transpose()->multiply($targetsMatrix); @@ -85,4 +85,32 @@ class LeastSquares implements Regression $this->coefficients = $ts->multiply($tf)->getColumnValues(0); $this->intercept = array_shift($this->coefficients); } + + /** + * Add one dimension for intercept calculation. + * + * @return Matrix + */ + private function getSamplesMatrix() + { + $samples = []; + foreach ($this->samples as $sample) { + array_unshift($sample, 1); + $samples[] = $sample; + } + + return new Matrix($samples); + } + + /** + * @return Matrix + */ + private function getTargetsMatrix() + { + if (is_array($this->targets[0])) { + return new Matrix($this->targets); + } + + return Matrix::fromFlatArray($this->targets); + } } diff --git a/tests/Phpml/Regression/LeastSquaresTest.php b/tests/Phpml/Regression/LeastSquaresTest.php index 7544417..a9b4882 100644 --- a/tests/Phpml/Regression/LeastSquaresTest.php +++ b/tests/Phpml/Regression/LeastSquaresTest.php @@ -13,8 +13,8 @@ class LeastSquaresTest extends \PHPUnit_Framework_TestCase $delta = 0.01; //https://www.easycalculation.com/analytical/learn-least-square-regression.php - $samples = [[1, 60], [1, 61], [1, 62], [1, 63], [1, 65]]; - $targets = [[3.1], [3.6], [3.8], [4], [4.1]]; + $samples = [[60], [61], [62], [63], [65]]; + $targets = [3.1, 3.6, 3.8, 4, 4.1]; $regression = new LeastSquares(); $regression->train($samples, $targets); @@ -22,8 +22,8 @@ class LeastSquaresTest extends \PHPUnit_Framework_TestCase $this->assertEquals(4.06, $regression->predict([64]), '', $delta); //http://www.stat.wmich.edu/s216/book/node127.html - $samples = [[1, 9300], [1, 10565], [1, 15000], [1, 15000], [1, 17764], [1, 57000], [1, 65940], [1, 73676], [1, 77006], [1, 93739], [1, 146088], [1, 153260]]; - $targets = [[7100], [15500], [4400], [4400], [5900], [4600], [8800], [2000], [2750], [2550], [960], [1025]]; + $samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]]; + $targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025]; $regression = new LeastSquares(); $regression->train($samples, $targets); @@ -40,8 +40,8 @@ class LeastSquaresTest extends \PHPUnit_Framework_TestCase $delta = 0.01; //http://www.stat.wmich.edu/s216/book/node129.html - $samples = [[1, 73676, 1996], [1, 77006, 1998], [1, 10565, 2000], [1, 146088, 1995], [1, 15000, 2001], [1, 65940, 2000], [1, 9300, 2000], [1, 93739, 1996], [1, 153260, 1994], [1, 17764, 2002], [1, 57000, 1998], [1, 15000, 2000]]; - $targets = [[2000], [2750], [15500], [960], [4400], [8800], [7100], [2550], [1025], [5900], [4600], [4400]]; + $samples = [[73676, 1996], [77006, 1998], [10565, 2000], [146088, 1995], [15000, 2001], [65940, 2000], [9300, 2000], [93739, 1996], [153260, 1994], [17764, 2002], [57000, 1998], [15000, 2000]]; + $targets = [2000, 2750, 15500, 960, 4400, 8800, 7100, 2550, 1025, 5900, 4600, 4400]; $regression = new LeastSquares(); $regression->train($samples, $targets);