mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-11 08:10:56 +00:00
DecisionTree and Fuzzy C Means classifiers (#35)
* Fuzzy C-Means implementation * Update FuzzyCMeans * Rename FuzzyCMeans to FuzzyCMeans.php * Update NaiveBayes.php * Small fix applied to improve training performance array_unique is replaced with array_count_values+array_keys which is way faster * Revert "Small fix applied to improve training performance" This reverts commit c20253f16ac3e8c37d33ecaee28a87cc767e3b7f. * Revert "Revert "Small fix applied to improve training performance"" This reverts commit ea10e136c4c11b71609ccdcaf9999067e4be473e. * Revert "Small fix applied to improve training performance" This reverts commit c20253f16ac3e8c37d33ecaee28a87cc767e3b7f. * DecisionTree * FCM Test * FCM Test * DecisionTree Test
This commit is contained in:
parent
95fc139170
commit
87396ebe58
274
src/Phpml/Classification/DecisionTree.php
Normal file
274
src/Phpml/Classification/DecisionTree.php
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Classification;
|
||||||
|
|
||||||
|
use Phpml\Helper\Predictable;
|
||||||
|
use Phpml\Helper\Trainable;
|
||||||
|
use Phpml\Math\Statistic\Mean;
|
||||||
|
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
|
||||||
|
|
||||||
|
class DecisionTree implements Classifier
|
||||||
|
{
|
||||||
|
use Trainable, Predictable;
|
||||||
|
|
||||||
|
const CONTINUOS = 1;
|
||||||
|
const NOMINAL = 2;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $samples = array();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $columnTypes;
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $labels = array();
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
private $featureCount = 0;
|
||||||
|
/**
|
||||||
|
* @var DecisionTreeLeaf
|
||||||
|
*/
|
||||||
|
private $tree = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
private $maxDepth;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
public $actualDepth = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param int $maxDepth
|
||||||
|
*/
|
||||||
|
public function __construct($maxDepth = 10)
|
||||||
|
{
|
||||||
|
$this->maxDepth = $maxDepth;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* @param array $samples
|
||||||
|
* @param array $targets
|
||||||
|
*/
|
||||||
|
public function train(array $samples, array $targets)
|
||||||
|
{
|
||||||
|
$this->featureCount = count($samples[0]);
|
||||||
|
$this->columnTypes = $this->getColumnTypes($samples);
|
||||||
|
$this->samples = $samples;
|
||||||
|
$this->targets = $targets;
|
||||||
|
$this->labels = array_keys(array_count_values($targets));
|
||||||
|
$this->tree = $this->getSplitLeaf(range(0, count($samples) - 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected function getColumnTypes(array $samples)
|
||||||
|
{
|
||||||
|
$types = [];
|
||||||
|
for ($i=0; $i<$this->featureCount; $i++) {
|
||||||
|
$values = array_column($samples, $i);
|
||||||
|
$isCategorical = $this->isCategoricalColumn($values);
|
||||||
|
$types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS;
|
||||||
|
}
|
||||||
|
return $types;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param null|array $records
|
||||||
|
* @return DecisionTreeLeaf
|
||||||
|
*/
|
||||||
|
protected function getSplitLeaf($records, $depth = 0)
|
||||||
|
{
|
||||||
|
$split = $this->getBestSplit($records);
|
||||||
|
$split->level = $depth;
|
||||||
|
if ($this->actualDepth < $depth) {
|
||||||
|
$this->actualDepth = $depth;
|
||||||
|
}
|
||||||
|
$leftRecords = [];
|
||||||
|
$rightRecords= [];
|
||||||
|
$remainingTargets = [];
|
||||||
|
$prevRecord = null;
|
||||||
|
$allSame = true;
|
||||||
|
foreach ($records as $recordNo) {
|
||||||
|
$record = $this->samples[$recordNo];
|
||||||
|
if ($prevRecord && $prevRecord != $record) {
|
||||||
|
$allSame = false;
|
||||||
|
}
|
||||||
|
$prevRecord = $record;
|
||||||
|
if ($split->evaluate($record)) {
|
||||||
|
$leftRecords[] = $recordNo;
|
||||||
|
} else {
|
||||||
|
$rightRecords[]= $recordNo;
|
||||||
|
}
|
||||||
|
$target = $this->targets[$recordNo];
|
||||||
|
if (! in_array($target, $remainingTargets)) {
|
||||||
|
$remainingTargets[] = $target;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (count($remainingTargets) == 1 || $allSame || $depth >= $this->maxDepth) {
|
||||||
|
$split->isTerminal = 1;
|
||||||
|
$classes = array_count_values($remainingTargets);
|
||||||
|
arsort($classes);
|
||||||
|
$split->classValue = key($classes);
|
||||||
|
} else {
|
||||||
|
if ($leftRecords) {
|
||||||
|
$split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
|
||||||
|
}
|
||||||
|
if ($rightRecords) {
|
||||||
|
$split->rightLeaf= $this->getSplitLeaf($rightRecords, $depth + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return $split;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array $records
|
||||||
|
* @return DecisionTreeLeaf[]
|
||||||
|
*/
|
||||||
|
protected function getBestSplit($records)
|
||||||
|
{
|
||||||
|
$targets = array_intersect_key($this->targets, array_flip($records));
|
||||||
|
$samples = array_intersect_key($this->samples, array_flip($records));
|
||||||
|
$samples = array_combine($records, $this->preprocess($samples));
|
||||||
|
$bestGiniVal = 1;
|
||||||
|
$bestSplit = null;
|
||||||
|
for ($i=0; $i<$this->featureCount; $i++) {
|
||||||
|
$colValues = [];
|
||||||
|
$baseValue = null;
|
||||||
|
foreach ($samples as $index => $row) {
|
||||||
|
$colValues[$index] = $row[$i];
|
||||||
|
if ($baseValue === null) {
|
||||||
|
$baseValue = $row[$i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
$gini = $this->getGiniIndex($baseValue, $colValues, $targets);
|
||||||
|
if ($bestSplit == null || $bestGiniVal > $gini) {
|
||||||
|
$split = new DecisionTreeLeaf();
|
||||||
|
$split->value = $baseValue;
|
||||||
|
$split->giniIndex = $gini;
|
||||||
|
$split->columnIndex = $i;
|
||||||
|
$split->records = $records;
|
||||||
|
$bestSplit = $split;
|
||||||
|
$bestGiniVal = $gini;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return $bestSplit;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param string $baseValue
|
||||||
|
* @param array $colValues
|
||||||
|
* @param array $targets
|
||||||
|
*/
|
||||||
|
public function getGiniIndex($baseValue, $colValues, $targets)
|
||||||
|
{
|
||||||
|
$countMatrix = [];
|
||||||
|
foreach ($this->labels as $label) {
|
||||||
|
$countMatrix[$label] = [0, 0];
|
||||||
|
}
|
||||||
|
foreach ($colValues as $index => $value) {
|
||||||
|
$label = $targets[$index];
|
||||||
|
$rowIndex = $value == $baseValue ? 0 : 1;
|
||||||
|
$countMatrix[$label][$rowIndex]++;
|
||||||
|
}
|
||||||
|
$giniParts = [0, 0];
|
||||||
|
for ($i=0; $i<=1; $i++) {
|
||||||
|
$part = 0;
|
||||||
|
$sum = array_sum(array_column($countMatrix, $i));
|
||||||
|
if ($sum > 0) {
|
||||||
|
foreach ($this->labels as $label) {
|
||||||
|
$part += pow($countMatrix[$label][$i] / floatval($sum), 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
$giniParts[$i] = (1 - $part) * $sum;
|
||||||
|
}
|
||||||
|
return array_sum($giniParts) / count($colValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array $samples
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
protected function preprocess(array $samples)
|
||||||
|
{
|
||||||
|
// Detect and convert continuous data column values into
|
||||||
|
// discrete values by using the median as a threshold value
|
||||||
|
$columns = array();
|
||||||
|
for ($i=0; $i<$this->featureCount; $i++) {
|
||||||
|
$values = array_column($samples, $i);
|
||||||
|
if ($this->columnTypes[$i] == self::CONTINUOS) {
|
||||||
|
$median = Mean::median($values);
|
||||||
|
foreach ($values as &$value) {
|
||||||
|
if ($value <= $median) {
|
||||||
|
$value = "<= $median";
|
||||||
|
} else {
|
||||||
|
$value = "> $median";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
$columns[] = $values;
|
||||||
|
}
|
||||||
|
// Below method is a strange yet very simple & efficient method
|
||||||
|
// to get the transpose of a 2D array
|
||||||
|
return array_map(null, ...$columns);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array $columnValues
|
||||||
|
* @return bool
|
||||||
|
*/
|
||||||
|
protected function isCategoricalColumn(array $columnValues)
|
||||||
|
{
|
||||||
|
$count = count($columnValues);
|
||||||
|
// There are two main indicators that *may* show whether a
|
||||||
|
// column is composed of discrete set of values:
|
||||||
|
// 1- Column may contain string values
|
||||||
|
// 2- Number of unique values in the column is only a small fraction of
|
||||||
|
// all values in that column (Lower than or equal to %20 of all values)
|
||||||
|
$numericValues = array_filter($columnValues, 'is_numeric');
|
||||||
|
if (count($numericValues) != $count) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
$distinctValues = array_count_values($columnValues);
|
||||||
|
if (count($distinctValues) <= $count / 5) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return string
|
||||||
|
*/
|
||||||
|
public function getHtml()
|
||||||
|
{
|
||||||
|
return $this->tree->__toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array $sample
|
||||||
|
* @return mixed
|
||||||
|
*/
|
||||||
|
protected function predictSample(array $sample)
|
||||||
|
{
|
||||||
|
$node = $this->tree;
|
||||||
|
do {
|
||||||
|
if ($node->isTerminal) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if ($node->evaluate($sample)) {
|
||||||
|
$node = $node->leftLeaf;
|
||||||
|
} else {
|
||||||
|
$node = $node->rightLeaf;
|
||||||
|
}
|
||||||
|
} while ($node);
|
||||||
|
return $node->classValue;
|
||||||
|
}
|
||||||
|
}
|
106
src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
Normal file
106
src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
<?php
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Classification\DecisionTree;
|
||||||
|
|
||||||
|
class DecisionTreeLeaf
|
||||||
|
{
|
||||||
|
const OPERATOR_EQ = '=';
|
||||||
|
/**
|
||||||
|
* @var string
|
||||||
|
*/
|
||||||
|
public $value;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
public $columnIndex;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var DecisionTreeLeaf
|
||||||
|
*/
|
||||||
|
public $leftLeaf = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var DecisionTreeLeaf
|
||||||
|
*/
|
||||||
|
public $rightLeaf= null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
public $records = [];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Class value represented by the leaf, this value is non-empty
|
||||||
|
* only for terminal leaves
|
||||||
|
*
|
||||||
|
* @var string
|
||||||
|
*/
|
||||||
|
public $classValue = '';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var bool
|
||||||
|
*/
|
||||||
|
public $isTerminal = false;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var float
|
||||||
|
*/
|
||||||
|
public $giniIndex = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
public $level = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array $record
|
||||||
|
* @return bool
|
||||||
|
*/
|
||||||
|
public function evaluate($record)
|
||||||
|
{
|
||||||
|
$recordField = $record[$this->columnIndex];
|
||||||
|
if (preg_match("/^([<>=]{1,2})\s*(.*)/", $this->value, $matches)) {
|
||||||
|
$op = $matches[1];
|
||||||
|
$value= floatval($matches[2]);
|
||||||
|
$recordField = strval($recordField);
|
||||||
|
eval("\$result = $recordField $op $value;");
|
||||||
|
return $result;
|
||||||
|
}
|
||||||
|
return $recordField == $this->value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function __toString()
|
||||||
|
{
|
||||||
|
if ($this->isTerminal) {
|
||||||
|
$value = "<b>$this->classValue</b>";
|
||||||
|
} else {
|
||||||
|
$value = $this->value;
|
||||||
|
$col = "col_$this->columnIndex";
|
||||||
|
if (! preg_match("/^[<>=]{1,2}/", $value)) {
|
||||||
|
$value = "=$value";
|
||||||
|
}
|
||||||
|
$value = "<b>$col $value</b><br>Gini: ". number_format($this->giniIndex, 2);
|
||||||
|
}
|
||||||
|
$str = "<table ><tr><td colspan=3 align=center style='border:1px solid;'>
|
||||||
|
$value</td></tr>";
|
||||||
|
if ($this->leftLeaf || $this->rightLeaf) {
|
||||||
|
$str .='<tr>';
|
||||||
|
if ($this->leftLeaf) {
|
||||||
|
$str .="<td valign=top><b>| Yes</b><br>$this->leftLeaf</td>";
|
||||||
|
} else {
|
||||||
|
$str .='<td></td>';
|
||||||
|
}
|
||||||
|
$str .='<td> </td>';
|
||||||
|
if ($this->rightLeaf) {
|
||||||
|
$str .="<td valign=top align=right><b>No |</b><br>$this->rightLeaf</td>";
|
||||||
|
} else {
|
||||||
|
$str .='<td></td>';
|
||||||
|
}
|
||||||
|
$str .= '</tr>';
|
||||||
|
}
|
||||||
|
$str .= '</table>';
|
||||||
|
return $str;
|
||||||
|
}
|
||||||
|
}
|
@ -68,8 +68,8 @@ class NaiveBayes implements Classifier
|
|||||||
$this->sampleCount = count($samples);
|
$this->sampleCount = count($samples);
|
||||||
$this->featureCount = count($samples[0]);
|
$this->featureCount = count($samples[0]);
|
||||||
|
|
||||||
$this->labels = $targets;
|
$labelCounts = array_count_values($targets);
|
||||||
array_unique($this->labels);
|
$this->labels = array_keys($labelCounts);
|
||||||
foreach ($this->labels as $label) {
|
foreach ($this->labels as $label) {
|
||||||
$samples = $this->getSamplesByLabel($label);
|
$samples = $this->getSamplesByLabel($label);
|
||||||
$this->p[$label] = count($samples) / $this->sampleCount;
|
$this->p[$label] = count($samples) / $this->sampleCount;
|
||||||
@ -165,13 +165,6 @@ class NaiveBayes implements Classifier
|
|||||||
*/
|
*/
|
||||||
protected function predictSample(array $sample)
|
protected function predictSample(array $sample)
|
||||||
{
|
{
|
||||||
$isArray = is_array($sample[0]);
|
|
||||||
$samples = $sample;
|
|
||||||
if (!$isArray) {
|
|
||||||
$samples = array($sample);
|
|
||||||
}
|
|
||||||
$samplePredictions = array();
|
|
||||||
foreach ($samples as $sample) {
|
|
||||||
// Use NaiveBayes assumption for each label using:
|
// Use NaiveBayes assumption for each label using:
|
||||||
// P(label|features) = P(label) * P(feature0|label) * P(feature1|label) .... P(featureN|label)
|
// P(label|features) = P(label) * P(feature0|label) * P(feature1|label) .... P(featureN|label)
|
||||||
// Then compare probability for each class to determine which label is most likely
|
// Then compare probability for each class to determine which label is most likely
|
||||||
@ -186,11 +179,6 @@ class NaiveBayes implements Classifier
|
|||||||
}
|
}
|
||||||
arsort($predictions, SORT_NUMERIC);
|
arsort($predictions, SORT_NUMERIC);
|
||||||
reset($predictions);
|
reset($predictions);
|
||||||
$samplePredictions[] = key($predictions);
|
return key($predictions);
|
||||||
}
|
|
||||||
if (! $isArray) {
|
|
||||||
return $samplePredictions[0];
|
|
||||||
}
|
|
||||||
return $samplePredictions;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
242
src/Phpml/Clustering/FuzzyCMeans.php
Normal file
242
src/Phpml/Clustering/FuzzyCMeans.php
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
<?php
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Clustering;
|
||||||
|
|
||||||
|
use Phpml\Clustering\KMeans\Point;
|
||||||
|
use Phpml\Clustering\KMeans\Cluster;
|
||||||
|
use Phpml\Clustering\KMeans\Space;
|
||||||
|
use Phpml\Math\Distance\Euclidean;
|
||||||
|
|
||||||
|
class FuzzyCMeans implements Clusterer
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
private $clustersNumber;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array|Cluster[]
|
||||||
|
*/
|
||||||
|
private $clusters = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var Space
|
||||||
|
*/
|
||||||
|
private $space;
|
||||||
|
/**
|
||||||
|
* @var array|float[][]
|
||||||
|
*/
|
||||||
|
private $membership;
|
||||||
|
/**
|
||||||
|
* @var float
|
||||||
|
*/
|
||||||
|
private $fuzziness;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var float
|
||||||
|
*/
|
||||||
|
private $epsilon;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
private $maxIterations;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
private $sampleCount;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $samples;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param int $clustersNumber
|
||||||
|
*
|
||||||
|
* @throws InvalidArgumentException
|
||||||
|
*/
|
||||||
|
public function __construct(int $clustersNumber, float $fuzziness = 2.0, float $epsilon = 1e-2, int $maxIterations = 100)
|
||||||
|
{
|
||||||
|
if ($clustersNumber <= 0) {
|
||||||
|
throw InvalidArgumentException::invalidClustersNumber();
|
||||||
|
}
|
||||||
|
$this->clustersNumber = $clustersNumber;
|
||||||
|
$this->fuzziness = $fuzziness;
|
||||||
|
$this->epsilon = $epsilon;
|
||||||
|
$this->maxIterations = $maxIterations;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected function initClusters()
|
||||||
|
{
|
||||||
|
// Membership array is a matrix of cluster number by sample counts
|
||||||
|
// We initilize the membership array with random values
|
||||||
|
$dim = $this->space->getDimension();
|
||||||
|
$this->generateRandomMembership($dim, $this->sampleCount);
|
||||||
|
$this->updateClusters();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param int $rows
|
||||||
|
* @param int $cols
|
||||||
|
*/
|
||||||
|
protected function generateRandomMembership(int $rows, int $cols)
|
||||||
|
{
|
||||||
|
$this->membership = [];
|
||||||
|
for ($i=0; $i < $rows; $i++) {
|
||||||
|
$row = [];
|
||||||
|
$total = 0.0;
|
||||||
|
for ($k=0; $k < $cols; $k++) {
|
||||||
|
$val = rand(1, 5) / 10.0;
|
||||||
|
$row[] = $val;
|
||||||
|
$total += $val;
|
||||||
|
}
|
||||||
|
$this->membership[] = array_map(function ($val) use ($total) {
|
||||||
|
return $val / $total;
|
||||||
|
}, $row);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected function updateClusters()
|
||||||
|
{
|
||||||
|
$dim = $this->space->getDimension();
|
||||||
|
if (! $this->clusters) {
|
||||||
|
$this->clusters = [];
|
||||||
|
for ($i=0; $i<$this->clustersNumber; $i++) {
|
||||||
|
$this->clusters[] = new Cluster($this->space, array_fill(0, $dim, 0.0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for ($i=0; $i<$this->clustersNumber; $i++) {
|
||||||
|
$cluster = $this->clusters[$i];
|
||||||
|
$center = $cluster->getCoordinates();
|
||||||
|
for ($k=0; $k<$dim; $k++) {
|
||||||
|
$a = $this->getMembershipRowTotal($i, $k, true);
|
||||||
|
$b = $this->getMembershipRowTotal($i, $k, false);
|
||||||
|
$center[$k] = $a / $b;
|
||||||
|
}
|
||||||
|
$cluster->setCoordinates($center);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected function getMembershipRowTotal(int $row, int $col, bool $multiply)
|
||||||
|
{
|
||||||
|
$sum = 0.0;
|
||||||
|
for ($k = 0; $k < $this->sampleCount; $k++) {
|
||||||
|
$val = pow($this->membership[$row][$k], $this->fuzziness);
|
||||||
|
if ($multiply) {
|
||||||
|
$val *= $this->samples[$k][$col];
|
||||||
|
}
|
||||||
|
$sum += $val;
|
||||||
|
}
|
||||||
|
return $sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected function updateMembershipMatrix()
|
||||||
|
{
|
||||||
|
for ($i = 0; $i < $this->clustersNumber; $i++) {
|
||||||
|
for ($k = 0; $k < $this->sampleCount; $k++) {
|
||||||
|
$distCalc = $this->getDistanceCalc($i, $k);
|
||||||
|
$this->membership[$i][$k] = 1.0 / $distCalc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @param int $row
|
||||||
|
* @param int $col
|
||||||
|
* @return float
|
||||||
|
*/
|
||||||
|
protected function getDistanceCalc(int $row, int $col)
|
||||||
|
{
|
||||||
|
$sum = 0.0;
|
||||||
|
$distance = new Euclidean();
|
||||||
|
$dist1 = $distance->distance(
|
||||||
|
$this->clusters[$row]->getCoordinates(),
|
||||||
|
$this->samples[$col]);
|
||||||
|
for ($j = 0; $j < $this->clustersNumber; $j++) {
|
||||||
|
$dist2 = $distance->distance(
|
||||||
|
$this->clusters[$j]->getCoordinates(),
|
||||||
|
$this->samples[$col]);
|
||||||
|
$val = pow($dist1 / $dist2, 2.0 / ($this->fuzziness - 1));
|
||||||
|
$sum += $val;
|
||||||
|
}
|
||||||
|
return $sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The objective is to minimize the distance between all data points
|
||||||
|
* and all cluster centers. This method returns the summation of all
|
||||||
|
* these distances
|
||||||
|
*/
|
||||||
|
protected function getObjective()
|
||||||
|
{
|
||||||
|
$sum = 0.0;
|
||||||
|
$distance = new Euclidean();
|
||||||
|
for ($i = 0; $i < $this->clustersNumber; $i++) {
|
||||||
|
$clust = $this->clusters[$i]->getCoordinates();
|
||||||
|
for ($k = 0; $k < $this->sampleCount; $k++) {
|
||||||
|
$point = $this->samples[$k];
|
||||||
|
$sum += $distance->distance($clust, $point);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return $sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
public function getMembershipMatrix()
|
||||||
|
{
|
||||||
|
return $this->membership;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param array|Point[] $samples
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
public function cluster(array $samples)
|
||||||
|
{
|
||||||
|
// Initialize variables, clusters and membership matrix
|
||||||
|
$this->sampleCount = count($samples);
|
||||||
|
$this->samples =& $samples;
|
||||||
|
$this->space = new Space(count($samples[0]));
|
||||||
|
$this->initClusters();
|
||||||
|
|
||||||
|
// Our goal is minimizing the objective value while
|
||||||
|
// executing the clustering steps at a maximum number of iterations
|
||||||
|
$lastObjective = 0.0;
|
||||||
|
$difference = 0.0;
|
||||||
|
$iterations = 0;
|
||||||
|
do {
|
||||||
|
// Update the membership matrix and cluster centers, respectively
|
||||||
|
$this->updateMembershipMatrix();
|
||||||
|
$this->updateClusters();
|
||||||
|
|
||||||
|
// Calculate the new value of the objective function
|
||||||
|
$objectiveVal = $this->getObjective();
|
||||||
|
$difference = abs($lastObjective - $objectiveVal);
|
||||||
|
$lastObjective = $objectiveVal;
|
||||||
|
} while ($difference > $this->epsilon && $iterations++ <= $this->maxIterations);
|
||||||
|
|
||||||
|
// Attach (hard cluster) each data point to the nearest cluster
|
||||||
|
for ($k=0; $k<$this->sampleCount; $k++) {
|
||||||
|
$column = array_column($this->membership, $k);
|
||||||
|
arsort($column);
|
||||||
|
reset($column);
|
||||||
|
$i = key($column);
|
||||||
|
$cluster = $this->clusters[$i];
|
||||||
|
$cluster->attach(new Point($this->samples[$k]));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return grouped samples
|
||||||
|
$grouped = [];
|
||||||
|
foreach ($this->clusters as $cluster) {
|
||||||
|
$grouped[] = $cluster->getPoints();
|
||||||
|
}
|
||||||
|
return $grouped;
|
||||||
|
}
|
||||||
|
}
|
60
tests/Phpml/Classification/DecisionTreeTest.php
Normal file
60
tests/Phpml/Classification/DecisionTreeTest.php
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace tests\Classification;
|
||||||
|
|
||||||
|
use Phpml\Classification\DecisionTree;
|
||||||
|
|
||||||
|
class DecisionTreeTest extends \PHPUnit_Framework_TestCase
|
||||||
|
{
|
||||||
|
public $data = [
|
||||||
|
['sunny', 85, 85, 'false', 'Dont_play' ],
|
||||||
|
['sunny', 80, 90, 'true', 'Dont_play' ],
|
||||||
|
['overcast', 83, 78, 'false', 'Play' ],
|
||||||
|
['rain', 70, 96, 'false', 'Play' ],
|
||||||
|
['rain', 68, 80, 'false', 'Play' ],
|
||||||
|
['rain', 65, 70, 'true', 'Dont_play' ],
|
||||||
|
['overcast', 64, 65, 'true', 'Play' ],
|
||||||
|
['sunny', 72, 95, 'false', 'Dont_play' ],
|
||||||
|
['sunny', 69, 70, 'false', 'Play' ],
|
||||||
|
['rain', 75, 80, 'false', 'Play' ],
|
||||||
|
['sunny', 75, 70, 'true', 'Play' ],
|
||||||
|
['overcast', 72, 90, 'true', 'Play' ],
|
||||||
|
['overcast', 81, 75, 'false', 'Play' ],
|
||||||
|
['rain', 71, 80, 'true', 'Dont_play' ]
|
||||||
|
];
|
||||||
|
|
||||||
|
public function getData()
|
||||||
|
{
|
||||||
|
static $data = null, $targets = null;
|
||||||
|
if ($data == null) {
|
||||||
|
$data = $this->data;
|
||||||
|
$targets = array_column($data, 4);
|
||||||
|
array_walk($data, function (&$v) {
|
||||||
|
array_splice($v, 4, 1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return [$data, $targets];
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testPredictSingleSample()
|
||||||
|
{
|
||||||
|
list($data, $targets) = $this->getData();
|
||||||
|
$classifier = new DecisionTree(5);
|
||||||
|
$classifier->train($data, $targets);
|
||||||
|
$this->assertEquals('Dont_play', $classifier->predict(['sunny', 78, 72, 'false']));
|
||||||
|
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
|
||||||
|
$this->assertEquals('Dont_play', $classifier->predict(['rain', 60, 60, 'true']));
|
||||||
|
|
||||||
|
return $classifier;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testTreeDepth()
|
||||||
|
{
|
||||||
|
list($data, $targets) = $this->getData();
|
||||||
|
$classifier = new DecisionTree(5);
|
||||||
|
$classifier->train($data, $targets);
|
||||||
|
$this->assertTrue(5 >= $classifier->actualDepth);
|
||||||
|
}
|
||||||
|
}
|
43
tests/Phpml/Clustering/FuzzyCMeansTest.php
Normal file
43
tests/Phpml/Clustering/FuzzyCMeansTest.php
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
<?php
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace tests\Clustering;
|
||||||
|
|
||||||
|
use Phpml\Clustering\FuzzyCMeans;
|
||||||
|
|
||||||
|
class FuzzyCMeansTest extends \PHPUnit_Framework_TestCase
|
||||||
|
{
|
||||||
|
public function testFCMSamplesClustering()
|
||||||
|
{
|
||||||
|
$samples = [[1, 1], [8, 7], [1, 2], [7, 8], [2, 1], [8, 9]];
|
||||||
|
$fcm = new FuzzyCMeans(2);
|
||||||
|
$clusters = $fcm->cluster($samples);
|
||||||
|
$this->assertCount(2, $clusters);
|
||||||
|
foreach ($samples as $index => $sample) {
|
||||||
|
if (in_array($sample, $clusters[0]) || in_array($sample, $clusters[1])) {
|
||||||
|
unset($samples[$index]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
$this->assertCount(0, $samples);
|
||||||
|
return $fcm;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testMembershipMatrix()
|
||||||
|
{
|
||||||
|
$fcm = $this->testFCMSamplesClustering();
|
||||||
|
$clusterCount = 2;
|
||||||
|
$sampleCount = 6;
|
||||||
|
$matrix = $fcm->getMembershipMatrix();
|
||||||
|
$this->assertCount($clusterCount, $matrix);
|
||||||
|
foreach ($matrix as $row) {
|
||||||
|
$this->assertCount($sampleCount, $row);
|
||||||
|
}
|
||||||
|
// Transpose of the matrix
|
||||||
|
array_unshift($matrix, null);
|
||||||
|
$matrix = call_user_func_array('array_map', $matrix);
|
||||||
|
// All column totals should be equal to 1 (100% membership)
|
||||||
|
foreach ($matrix as $col) {
|
||||||
|
$this->assertEquals(1, array_sum($col));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user