2017-04-25 06:58:02 +00:00
|
|
|
<?php
|
|
|
|
|
|
|
|
declare(strict_types=1);
|
|
|
|
|
|
|
|
namespace Phpml\DimensionReduction;
|
|
|
|
|
2017-11-22 21:16:10 +00:00
|
|
|
use Exception;
|
2017-04-25 06:58:02 +00:00
|
|
|
use Phpml\Math\Matrix;
|
|
|
|
|
|
|
|
/**
 * Linear Discriminant Analysis (LDA).
 *
 * Supervised dimensionality reduction: fit() builds the within-class scatter
 * matrix S_W and the between-class scatter matrix S_B, eigen-decomposes
 * S_W^-1 * S_B and projects the data onto the resulting eigenvectors.
 */
class LDA extends EigenTransformerBase
{
    /**
     * Whether fit() has completed successfully; transform() refuses to run otherwise.
     *
     * @var bool
     */
    public $fit = false;

    /**
     * Unique class labels of the last fit(), in order of first appearance.
     * A label's position in this list is the class index used as the key of
     * $means and $counts.
     *
     * @var array
     */
    public $labels = [];

    /**
     * Per-class column means, keyed by class index.
     *
     * @var array
     */
    public $means = [];

    /**
     * Number of training samples per class, keyed by class index.
     *
     * @var array
     */
    public $counts = [];

    /**
     * Column means over the whole training set.
     *
     * @var float[]
     */
    public $overallMean = [];

    /**
     * Linear Discriminant Analysis (LDA) is used to reduce the dimensionality
     * of the data. Unlike Principal Component Analysis (PCA), it is a supervised
     * technique that requires the class labels in order to fit the data to a
     * lower dimensional space. <br><br>
     * The algorithm can be initialized by specifying either totalVariance
     * (a value between 0.1 and 0.99) or numFeatures (the number of features
     * to preserve), but not both. A parameter left as null keeps the
     * property's inherited initial value.
     *
     * @param float|null $totalVariance Total explained variance to be preserved
     * @param int|null   $numFeatures   Number of features to be preserved
     *
     * @throws \Exception when a parameter is out of range or both are given
     */
    public function __construct(?float $totalVariance = null, ?int $numFeatures = null)
    {
        if ($totalVariance !== null && ($totalVariance < 0.1 || $totalVariance > 0.99)) {
            throw new Exception('Total variance can be a value between 0.1 and 0.99');
        }

        if ($numFeatures !== null && $numFeatures <= 0) {
            throw new Exception('Number of features to be preserved should be greater than 0');
        }

        if ($totalVariance !== null && $numFeatures !== null) {
            throw new Exception('Either totalVariance or numFeatures should be specified in order to run the algorithm');
        }

        if ($numFeatures !== null) {
            $this->numFeatures = $numFeatures;
        }

        if ($totalVariance !== null) {
            $this->totalVariance = $totalVariance;
        }
    }

    /**
     * Trains the algorithm to transform the given data to a lower dimensional space.
     *
     * @param array $data    samples (rows of numeric column values)
     * @param array $classes class label of each sample, keyed like $data
     *
     * @return array the training data projected onto the fitted eigenvectors
     *
     * @throws \Exception when $data is empty or a sample has no known class
     */
    public function fit(array $data, array $classes): array
    {
        if ($data === []) {
            throw new Exception('LDA cannot be fitted on an empty dataset');
        }

        // Invalidate any previous fit until this one completes, so a failure
        // part-way through cannot leave stale eigenvectors marked as usable.
        $this->fit = false;

        $this->labels = $this->getLabels($classes);
        $this->means = $this->calculateMeans($data, $classes);

        $sW = $this->calculateClassVar($data, $classes);
        $sB = $this->calculateClassCov();

        $S = $sW->inverse()->multiply($sB);
        $this->eigenDecomposition($S->toArray());

        $this->fit = true;

        return $this->reduce($data);
    }

    /**
     * Transforms the given sample to a lower dimensional vector by using
     * the eigenVectors obtained in the last run of <code>fit</code>.
     *
     * A single flat sample is accepted and wrapped into a one-row dataset.
     *
     * @throws \Exception when fit() has not completed or the sample is empty
     */
    public function transform(array $sample): array
    {
        if (!$this->fit) {
            throw new Exception('LDA has not been fitted with respect to original dataset, please run LDA::fit() first');
        }

        if ($sample === []) {
            throw new Exception('Cannot transform an empty sample');
        }

        // Wrap a flat sample so reduce() gets rows, as it does from fit().
        if (!is_array(reset($sample))) {
            $sample = [$sample];
        }

        return $this->reduce($sample);
    }

    /**
     * Returns unique labels in the dataset, in order of first appearance.
     *
     * array_count_values() applies PHP's array-key coercion (e.g. '1' becomes 1),
     * which is the same coercion array_flip() uses when labels are looked up.
     */
    protected function getLabels(array $classes): array
    {
        $counts = array_count_values($classes);

        return array_keys($counts);
    }

    /**
     * Calculates the mean of each column for each class and returns an
     * n by m array, where n is the number of labels and m the number of columns
     * (keyed by class index). Also stores the per-class sample counts in
     * $this->counts and the overall column means in $this->overallMean.
     *
     * @throws \Exception when a sample has no known class
     */
    protected function calculateMeans(array $data, array $classes): array
    {
        $labelIndex = array_flip($this->labels);
        $numColumns = count(reset($data));

        $means = [];
        $counts = [];
        $overallSum = array_fill(0, $numColumns, 0.0);

        foreach ($data as $key => $row) {
            $label = $this->classIndex($labelIndex, $classes, $key);

            foreach ($row as $col => $val) {
                $means[$label][$col] = ($means[$label][$col] ?? 0.0) + $val;
                $overallSum[$col] += $val;
            }

            $counts[$label] = ($counts[$label] ?? 0) + 1;
        }

        foreach ($means as $label => $sums) {
            foreach ($sums as $col => $sum) {
                $means[$label][$col] = $sum / $counts[$label];
            }
        }

        // Overall mean of the dataset for each column
        $numSamples = array_sum($counts);
        $this->overallMean = array_map(static function ($sum) use ($numSamples) {
            return $sum / $numSamples;
        }, $overallSum);
        $this->counts = $counts;

        return $means;
    }

    /**
     * Returns the within-class scatter matrix S_W: the sum over all samples of
     * (x - m_c)T.(x - m_c), where m_c is the mean of the sample's class.
     * The result is an m by m matrix, m being the number of columns.
     *
     * @throws \Exception when a sample has no known class
     */
    protected function calculateClassVar(array $data, array $classes): Matrix
    {
        $labelIndex = array_flip($this->labels);
        $numColumns = count(reset($data));

        // Start from an m x m zero matrix.
        $sW = new Matrix(array_fill(0, $numColumns, array_fill(0, $numColumns, 0)), false);

        foreach ($data as $key => $row) {
            $label = $this->classIndex($labelIndex, $classes, $key);
            $sW = $sW->add($this->calculateVar($row, $this->means[$label]));
        }

        return $sW;
    }

    /**
     * Returns the between-class scatter matrix S_B: the sum over all classes of
     * N_c * (m_c - m)T.(m_c - m), where N_c is the class size, m_c the class
     * mean and m the overall mean. The result is an m by m matrix, m being the
     * number of columns.
     */
    protected function calculateClassCov(): Matrix
    {
        $numColumns = count($this->overallMean);

        // Start from an m x m zero matrix.
        $sB = new Matrix(array_fill(0, $numColumns, array_fill(0, $numColumns, 0)), false);

        foreach ($this->means as $label => $classMeans) {
            $scatter = $this->calculateVar($classMeans, $this->overallMean);
            $sB = $sB->add($scatter->multiplyByScalar($this->counts[$label]));
        }

        return $sB;
    }

    /**
     * Returns the result of the calculation (x - m)T.(x - m), i.e. the m by m
     * outer product of the row difference between $row and $means.
     */
    protected function calculateVar(array $row, array $means): Matrix
    {
        $x = new Matrix($row, false);
        $m = new Matrix($means, false);
        $diff = $x->subtract($m);

        return $diff->transpose()->multiply($diff);
    }

    /**
     * Maps the class of the sample stored under $key to its class index.
     *
     * Lookup goes through a label => index map (array_flip($this->labels))
     * rather than a loose array_search(). A loose search treats numeric strings
     * such as '10' and '1e1' as equal, so distinct labels would collapse.
     *
     * @param array      $labelIndex label => class index map
     * @param array      $classes    class label per sample
     * @param int|string $key        key of the sample in the dataset
     *
     * @throws \Exception when the sample has no class or its class is not a known label
     */
    private function classIndex(array $labelIndex, array $classes, $key): int
    {
        if (!array_key_exists($key, $classes)) {
            throw new Exception(sprintf('No class given for sample %s', $key));
        }

        $class = $classes[$key];
        if (!(is_int($class) || is_string($class)) || !isset($labelIndex[$class])) {
            throw new Exception(sprintf('Unknown class label for sample %s', $key));
        }

        return $labelIndex[$class];
    }
}
|