2017-01-31 19:27:15 +00:00
|
|
|
<?php
|
|
|
|
|
|
|
|
declare(strict_types=1);
|
|
|
|
|
|
|
|
namespace Phpml\Classification;
|
|
|
|
|
|
|
|
use Phpml\Helper\Predictable;
|
|
|
|
use Phpml\Helper\Trainable;
|
|
|
|
use Phpml\Math\Statistic\Mean;
|
|
|
|
use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
|
|
|
|
|
|
|
|
class DecisionTree implements Classifier
|
|
|
|
{
|
|
|
|
use Trainable, Predictable;
|
|
|
|
|
|
|
|
const CONTINUOS = 1;
|
|
|
|
const NOMINAL = 2;
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @var array
|
|
|
|
*/
|
2017-01-31 19:33:08 +00:00
|
|
|
private $samples = [];
|
2017-01-31 19:27:15 +00:00
|
|
|
|
|
|
|
/**
|
|
|
|
* @var array
|
|
|
|
*/
|
|
|
|
private $columnTypes;
|
2017-01-31 19:33:08 +00:00
|
|
|
|
2017-01-31 19:27:15 +00:00
|
|
|
/**
|
|
|
|
* @var array
|
|
|
|
*/
|
2017-01-31 19:33:08 +00:00
|
|
|
private $labels = [];
|
|
|
|
|
2017-01-31 19:27:15 +00:00
|
|
|
/**
|
|
|
|
* @var int
|
|
|
|
*/
|
|
|
|
private $featureCount = 0;
|
2017-01-31 19:33:08 +00:00
|
|
|
|
2017-01-31 19:27:15 +00:00
|
|
|
/**
|
|
|
|
* @var DecisionTreeLeaf
|
|
|
|
*/
|
|
|
|
private $tree = null;
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @var int
|
|
|
|
*/
|
|
|
|
private $maxDepth;
|
|
|
|
|
|
|
|
/**
|
|
|
|
* @var int
|
|
|
|
*/
|
|
|
|
public $actualDepth = 0;
|
|
|
|
|
2017-02-07 11:37:56 +00:00
|
|
|
/**
|
|
|
|
* @var int
|
|
|
|
*/
|
|
|
|
private $numUsableFeatures = 0;
|
|
|
|
|
2017-02-13 20:23:18 +00:00
|
|
|
/**
|
|
|
|
* @var array
|
|
|
|
*/
|
|
|
|
private $featureImportances = null;
|
|
|
|
|
|
|
|
/**
|
|
|
|
*
|
|
|
|
* @var array
|
|
|
|
*/
|
|
|
|
private $columnNames = null;
|
|
|
|
|
2017-01-31 19:27:15 +00:00
|
|
|
/**
 * Creates an untrained decision tree.
 *
 * @param int $maxDepth Depth at which getSplitLeaf() stops splitting and
 *                      turns the node it is building into a terminal leaf
 */
public function __construct($maxDepth = 10)
{
    // Only the depth limit is configured here; the tree is built by train()
    $this->maxDepth = $maxDepth;
}
|
|
|
|
/**
 * Adds the given samples to the training set and rebuilds the tree.
 *
 * Samples and targets are appended to those passed in earlier calls, and
 * the tree is then rebuilt from the root over the whole accumulated set.
 *
 * @param array $samples List of feature rows
 * @param array $targets Class label of each row, in the same order as $samples
 */
public function train(array $samples, array $targets)
{
    $this->samples = array_merge($this->samples, $samples);
    $this->targets = array_merge($this->targets, $targets);

    // The first row defines how many features every row is expected to have
    $this->featureCount = count($this->samples[0]);
    $this->columnTypes = $this->getColumnTypes($this->samples);
    $this->labels = array_keys(array_count_values($this->targets));

    // getSplitLeaf() only ever raises actualDepth, so it has to be reset
    // here; otherwise it would keep reporting the depth of a previous,
    // deeper tree after retraining
    $this->actualDepth = 0;
    $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

    // Each time the tree is trained, feature importances are reset so that
    // they are computed again from the new tree on the next request
    $this->featureImportances = null;

    // Column names given (or generated) earlier are kept; they are only
    // generated, truncated or padded so that there is one per feature
    if ($this->columnNames === null) {
        $this->columnNames = range(0, $this->featureCount - 1);
    } elseif (count($this->columnNames) > $this->featureCount) {
        $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
    } elseif (count($this->columnNames) < $this->featureCount) {
        // Missing names are filled in with the numeric column indices
        $this->columnNames = array_merge(
            $this->columnNames,
            range(count($this->columnNames), $this->featureCount - 1)
        );
    }
}
|
|
|
|
|
|
|
|
/**
 * Determines, for each of the first $this->featureCount columns, whether
 * it is treated as nominal or continuous.
 *
 * @param array $samples List of feature rows
 * @return array One of self::NOMINAL / self::CONTINUOS per column
 */
protected function getColumnTypes(array $samples)
{
    $columnTypes = [];
    $column = 0;
    while ($column < $this->featureCount) {
        // The decision is made by isCategoricalColumn() on the column's values
        if ($this->isCategoricalColumn(array_column($samples, $column))) {
            $columnTypes[] = self::NOMINAL;
        } else {
            $columnTypes[] = self::CONTINUOS;
        }
        $column++;
    }

    return $columnTypes;
}
|
|
|
|
|
|
|
|
/**
 * Recursively builds the (sub)tree for the given training records.
 *
 * The node returned by getBestSplit() becomes terminal when all of its
 * records carry the same target, when all of its records are identical,
 * or when the depth limit is reached. A terminal node is labelled with the
 * most frequent target among its records; otherwise the records are
 * divided by $split->evaluate() and each non-empty side is built
 * recursively.
 *
 * @param array $records Indices into $this->samples / $this->targets
 * @param int   $depth   Depth of the node being built (0 for the root)
 * @return DecisionTreeLeaf
 */
protected function getSplitLeaf($records, $depth = 0)
{
    $split = $this->getBestSplit($records);
    $split->level = $depth;
    if ($this->actualDepth < $depth) {
        $this->actualDepth = $depth;
    }

    $leftRecords = [];
    $rightRecords = [];
    // target value => number of records in this node that carry it.
    // Real counts are needed: a list of distinct targets alone cannot tell
    // which class is the majority when the node is cut off by the depth limit
    $targetCounts = [];
    $prevRecord = null;
    $allSame = true;

    foreach ($records as $recordNo) {
        $record = $this->samples[$recordNo];
        // Identical records cannot be separated by any split
        if ($prevRecord && $prevRecord != $record) {
            $allSame = false;
        }
        $prevRecord = $record;

        if ($split->evaluate($record)) {
            $leftRecords[] = $recordNo;
        } else {
            $rightRecords[] = $recordNo;
        }

        $target = $this->targets[$recordNo];
        $targetCounts[$target] = ($targetCounts[$target] ?? 0) + 1;
    }

    if (count($targetCounts) == 1 || $allSame || $depth >= $this->maxDepth) {
        $split->isTerminal = 1;
        // Highest count first, so key() yields the majority class
        arsort($targetCounts);
        $split->classValue = key($targetCounts);
    } else {
        if ($leftRecords) {
            $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
        }
        if ($rightRecords) {
            $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
        }
    }

    return $split;
}
|
|
|
|
|
|
|
|
/**
 * Picks the best single-column split for the given records.
 *
 * For every candidate feature the most frequent preprocessed value is used
 * as the split value; the candidate with the lowest Gini index wins, and
 * on a tie the earlier feature is kept.
 *
 * @param array $records Indices into $this->samples / $this->targets
 * @return DecisionTreeLeaf|null Null when no feature was available to try
 */
protected function getBestSplit($records)
{
    $recordKeys = array_flip($records);
    $targets = array_intersect_key($this->targets, $recordKeys);
    // Preprocessed rows are re-keyed by their record number
    $samples = array_combine(
        $records,
        $this->preprocess(array_intersect_key($this->samples, $recordKeys))
    );

    $lowestGini = 1;
    $winner = null;
    foreach ($this->getSelectedFeatures() as $featureIndex) {
        $colValues = [];
        foreach ($samples as $recordNo => $row) {
            $colValues[$recordNo] = $row[$featureIndex];
        }

        // After arsort() the first key is the most frequent value
        $frequencies = array_count_values($colValues);
        arsort($frequencies);
        $baseValue = key($frequencies);

        $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
        if ($winner != null && !($lowestGini > $gini)) {
            // Not better than what we already have
            continue;
        }

        $candidate = new DecisionTreeLeaf();
        $candidate->value = $baseValue;
        $candidate->giniIndex = $gini;
        $candidate->columnIndex = $featureIndex;
        $candidate->isContinuous = $this->columnTypes[$featureIndex] == self::CONTINUOS;
        $candidate->records = $records;

        $winner = $candidate;
        $lowestGini = $gini;
    }

    return $winner;
}
|
|
|
|
|
2017-02-07 11:37:56 +00:00
|
|
|
/**
 * Returns the column indices to be tried for the next split.
 *
 * With $this->numUsableFeatures equal to 0 every column is returned.
 * Otherwise a random subset of at most that many columns is drawn and
 * returned in ascending order.
 *
 * @return array
 */
protected function getSelectedFeatures()
{
    $candidates = range(0, $this->featureCount - 1);
    if ($this->numUsableFeatures == 0) {
        return $candidates;
    }

    // Never ask for more columns than there are
    $limit = min($this->numUsableFeatures, $this->featureCount);

    shuffle($candidates);
    $chosen = array_slice($candidates, 0, $limit);
    sort($chosen);

    return $chosen;
}
|
|
|
|
|
2017-01-31 19:27:15 +00:00
|
|
|
/**
 * Computes the Gini index of splitting a column into "equal to
 * $baseValue" and "everything else".
 *
 * The impurity of each side is weighted by the number of values that fall
 * on that side, and the sum is divided by the total number of values.
 *
 * @param string $baseValue Value that defines the first side of the split
 * @param array  $colValues Column values, keyed like $targets
 * @param array  $targets   Label of each value, looked up by the same key
 * @return float|int
 */
public function getGiniIndex($baseValue, $colValues, $targets)
{
    // Per label: [count on the $baseValue side, count on the other side]
    $countMatrix = array_fill_keys($this->labels, [0, 0]);
    foreach ($colValues as $index => $value) {
        $side = ($value == $baseValue) ? 0 : 1;
        $countMatrix[$targets[$index]][$side]++;
    }

    $weightedImpurities = [0, 0];
    foreach ([0, 1] as $side) {
        $sideTotal = array_sum(array_column($countMatrix, $side));
        $squaredShares = 0;
        if ($sideTotal > 0) {
            foreach ($this->labels as $label) {
                $squaredShares += ($countMatrix[$label][$side] / floatval($sideTotal)) ** 2;
            }
        }
        // An empty side contributes nothing because its weight is zero
        $weightedImpurities[$side] = (1 - $squaredShares) * $sideTotal;
    }

    return array_sum($weightedImpurities) / count($colValues);
}
|
|
|
|
|
|
|
|
/**
 * Discretizes continuous columns and returns the samples row by row.
 *
 * Values of a continuous column are replaced with the strings
 * "<= $median" or "> $median", the median being taken over the given
 * samples only. Other columns are passed through unchanged. The result is
 * re-indexed from 0 regardless of the keys of $samples.
 *
 * @param array $samples List of feature rows
 * @return array List of rows, each holding one value per column
 */
protected function preprocess(array $samples)
{
    $columns = [];
    for ($i = 0; $i < $this->featureCount; $i++) {
        $values = array_column($samples, $i);
        if ($this->columnTypes[$i] == self::CONTINUOS) {
            $median = Mean::median($values);
            foreach ($values as &$value) {
                $value = $value <= $median ? "<= $median" : "> $median";
            }
            // Drop the reference so later writes to $value cannot leak
            // into the column that has just been stored
            unset($value);
        }
        $columns[] = $values;
    }

    // array_map(null, ...) only zips when it is given at least two arrays;
    // with a single array it returns that array unchanged, which would
    // hand back scalars instead of one-element rows for single-feature data
    if (count($columns) === 1) {
        return array_map(
            static function ($value) {
                return [$value];
            },
            $columns[0]
        );
    }

    // Zipping the columns transposes them back into rows
    return array_map(null, ...$columns);
}
|
|
|
|
|
|
|
|
/**
 * Guesses whether a column holds a discrete set of values.
 *
 * @param array $columnValues All values of one column
 * @return bool True when the column should be treated as nominal
 */
protected function isCategoricalColumn(array $columnValues)
{
    $count = count($columnValues);

    // There are two main indicators that *may* show whether a
    // column is composed of discrete set of values:
    // 1- Column may contain string values
    // 2- Number of unique values in the column is only a small fraction of
    //    all values in that column (lower than or equal to 20% of all values)
    $numericValues = array_filter($columnValues, 'is_numeric');
    if (count($numericValues) != $count) {
        return true;
    }

    // array_count_values() can only count strings and integers: it skips
    // floats with a warning, so an all-float column would appear to have
    // zero distinct values and be reported as nominal. Columns carrying
    // float values are therefore always treated as continuous.
    foreach ($columnValues as $value) {
        if (is_float($value)) {
            return false;
        }
    }

    $distinctValues = array_count_values($columnValues);

    return count($distinctValues) <= $count / 5;
}
|
|
|
|
|
2017-02-07 11:37:56 +00:00
|
|
|
/**
 * Sets how many columns are considered when deciding a split at an
 * internal node of the tree.
 *
 * With 0 (the default) all features are used; any other value is the
 * maximum number of columns randomly selected for each split operation.
 *
 * @param int $numFeatures
 * @return $this
 * @throws \Exception When $numFeatures is negative
 */
public function setNumFeatures(int $numFeatures)
{
    if ($numFeatures >= 0) {
        $this->numUsableFeatures = $numFeatures;

        return $this;
    }

    throw new \Exception("Selected column count should be greater or equal to zero");
}
|
|
|
|
|
|
|
|
/**
 * Sets the names used to represent the columns, e.g. in the HTML output
 * or as keys of the feature importances.
 *
 * Once the feature count is known (non-zero), exactly one name per
 * feature must be given.
 *
 * @param array $names
 * @return $this
 * @throws \Exception When the number of names differs from the feature count
 */
public function setColumnNames(array $names)
{
    $featureCountKnown = $this->featureCount != 0;
    if ($featureCountKnown && count($names) != $this->featureCount) {
        throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)");
    }

    $this->columnNames = $names;

    return $this;
}
|
|
|
|
|
2017-01-31 19:27:15 +00:00
|
|
|
/**
 * Renders the trained tree by delegating to the root node's getHTML(),
 * passing it the current column names.
 *
 * @return string
 */
public function getHtml()
{
    $root = $this->tree;

    return $root->getHTML($this->columnNames);
}
|
|
|
|
|
|
|
|
/**
 * Returns an importance value for each column, keyed by column name.
 *
 * The importance of a column is the sum of the impurity decreases of all
 * internal nodes that split on it. When the total is positive the values
 * are normalized so that they add up to 1 and sorted in descending order.
 * The result is cached on the instance and reused by later calls.
 *
 * @return array
 */
public function getFeatureImportances()
{
    if ($this->featureImportances !== null) {
        return $this->featureImportances;
    }

    $sampleCount = count($this->samples);
    $this->featureImportances = [];
    foreach ($this->columnNames as $column => $columnName) {
        $score = 0;
        foreach ($this->getSplitNodesByColumn($column, $this->tree) as $node) {
            $score += $node->getNodeImpurityDecrease($sampleCount);
        }
        $this->featureImportances[$columnName] = $score;
    }

    // Normalize & sort the importances
    $total = array_sum($this->featureImportances);
    if ($total > 0) {
        $this->featureImportances = array_map(
            function ($score) use ($total) {
                return $score / $total;
            },
            $this->featureImportances
        );
        arsort($this->featureImportances);
    }

    return $this->featureImportances;
}
|
|
|
|
|
|
|
|
/**
 * Collects the internal (non-terminal) nodes of a subtree that use the
 * given column as their split criterion.
 *
 * The subtree root comes first when it matches, followed by the matches
 * from its left subtree and then those from its right subtree.
 *
 * @param int              $column Column index to look for
 * @param DecisionTreeLeaf $node   Root of the subtree to search
 *
 * @return array
 */
protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node)
{
    if (!$node || $node->isTerminal) {
        return [];
    }

    $matches = $node->columnIndex == $column ? [$node] : [];
    foreach ([$node->leftLeaf, $node->rightLeaf] as $child) {
        // A side without records has no child node
        if ($child) {
            $matches = array_merge($matches, $this->getSplitNodesByColumn($column, $child));
        }
    }

    return $matches;
}
|
|
|
|
|
|
|
|
/**
 * Predicts the class of a single sample by walking the tree from the
 * root: the left child is taken when the node's evaluate() accepts the
 * sample, the right child otherwise, until a terminal node is reached.
 *
 * @param array $sample
 * @return mixed Class value of the terminal node reached, or the first
 *               known label when the walk ends on a missing child
 */
protected function predictSample(array $sample)
{
    $current = $this->tree;
    while (!$current->isTerminal) {
        $current = $current->evaluate($sample) ? $current->leftLeaf : $current->rightLeaf;
        if (!$current) {
            // Walked off the tree: this side had no records during training
            break;
        }
    }

    return $current ? $current->classValue : $this->labels[0];
}
|
|
|
|
}
|