diff --git a/src/Phpml/Classification/DecisionTree.php b/src/Phpml/Classification/DecisionTree.php
index 1a04802..6a860eb 100644
--- a/src/Phpml/Classification/DecisionTree.php
+++ b/src/Phpml/Classification/DecisionTree.php
@@ -56,6 +56,17 @@ class DecisionTree implements Classifier
*/
private $numUsableFeatures = 0;
+ /**
+ * @var array
+ */
+ private $featureImportances = null;
+
+ /**
+ *
+ * @var array
+ */
+ private $columnNames = null;
+
/**
* @param int $maxDepth
*/
@@ -76,6 +87,21 @@ class DecisionTree implements Classifier
$this->columnTypes = $this->getColumnTypes($this->samples);
$this->labels = array_keys(array_count_values($this->targets));
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
+
+ // Each time the tree is trained, feature importances are reset so that
+ // we will have to compute it again depending on the new data
+ $this->featureImportances = null;
+
+ // If column names are given or computed before, then there is no
+ // need to init it and accidentally remove the previous given names
+ if ($this->columnNames === null) {
+ $this->columnNames = range(0, $this->featureCount - 1);
+ } elseif (count($this->columnNames) > $this->featureCount) {
+ $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
+ } elseif (count($this->columnNames) < $this->featureCount) {
+ $this->columnNames = array_merge($this->columnNames,
+ range(count($this->columnNames), $this->featureCount - 1));
+ }
}
protected function getColumnTypes(array $samples)
@@ -164,6 +190,7 @@ class DecisionTree implements Classifier
$split->value = $baseValue;
$split->giniIndex = $gini;
$split->columnIndex = $i;
+ $split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS;
$split->records = $records;
$bestSplit = $split;
$bestGiniVal = $gini;
@@ -292,6 +319,25 @@ class DecisionTree implements Classifier
}
$this->numUsableFeatures = $numFeatures;
+
+ return $this;
+ }
+
+ /**
+ * A string array to represent columns. Useful when HTML output or
+ * column importances are desired to be inspected.
+ *
+ * @param array $names
+ * @return $this
+ */
+ public function setColumnNames(array $names)
+ {
+ if ($this->featureCount != 0 && count($names) != $this->featureCount) {
+ throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)");
+ }
+
+ $this->columnNames = $names;
+
return $this;
}
@@ -300,7 +346,80 @@ class DecisionTree implements Classifier
*/
public function getHtml()
{
- return $this->tree->__toString();
+ return $this->tree->getHTML($this->columnNames);
+ }
+
+ /**
+ * This will return an array including an importance value for
+ * each column in the given dataset. The importance values are
+ * normalized and their total makes 1.
+ *
+ * @param array $labels
+ * @return array
+ */
+ public function getFeatureImportances()
+ {
+ if ($this->featureImportances !== null) {
+ return $this->featureImportances;
+ }
+
+ $sampleCount = count($this->samples);
+ $this->featureImportances = [];
+ foreach ($this->columnNames as $column => $columnName) {
+ $nodes = $this->getSplitNodesByColumn($column, $this->tree);
+
+ $importance = 0;
+ foreach ($nodes as $node) {
+ $importance += $node->getNodeImpurityDecrease($sampleCount);
+ }
+
+ $this->featureImportances[$columnName] = $importance;
+ }
+
+ // Normalize & sort the importances
+ $total = array_sum($this->featureImportances);
+ if ($total > 0) {
+ foreach ($this->featureImportances as &$importance) {
+ $importance /= $total;
+ }
+ arsort($this->featureImportances);
+ }
+
+ return $this->featureImportances;
+ }
+
+ /**
+ * Collects and returns an array of internal nodes that use the given
+ * column as a split criteron
+ *
+ * @param int $column
+ * @param DecisionTreeLeaf
+ * @param array $collected
+ *
+ * @return array
+ */
+ protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node)
+ {
+ if (!$node || $node->isTerminal) {
+ return [];
+ }
+
+ $nodes = [];
+ if ($node->columnIndex == $column) {
+ $nodes[] = $node;
+ }
+
+ $lNodes = [];
+ $rNodes = [];
+ if ($node->leftLeaf) {
+ $lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
+ }
+ if ($node->rightLeaf) {
+ $rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
+ }
+ $nodes = array_merge($nodes, $lNodes, $rNodes);
+
+ return $nodes;
}
/**
diff --git a/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php b/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
index 1993864..e30fc10 100644
--- a/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
+++ b/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
@@ -6,7 +6,6 @@ namespace Phpml\Classification\DecisionTree;
class DecisionTreeLeaf
{
- const OPERATOR_EQ = '=';
/**
* @var string
*/
@@ -45,6 +44,11 @@ class DecisionTreeLeaf
*/
public $isTerminal = false;
+ /**
+ * @var bool
+ */
+ public $isContinuous = false;
+
/**
* @var float
*/
@@ -62,7 +66,7 @@ class DecisionTreeLeaf
public function evaluate($record)
{
$recordField = $record[$this->columnIndex];
- if (is_string($this->value) && preg_match("/^([<>=]{1,2})\s*(.*)/", $this->value, $matches)) {
+ if ($this->isContinuous && preg_match("/^([<>=]{1,2})\s*(.*)/", strval($this->value), $matches)) {
$op = $matches[1];
$value= floatval($matches[2]);
$recordField = strval($recordField);
@@ -72,13 +76,51 @@ class DecisionTreeLeaf
return $recordField == $this->value;
}
- public function __toString()
+ /**
+ * Returns Mean Decrease Impurity (MDI) in the node.
+ * For terminal nodes, this value is equal to 0
+ *
+ * @return float
+ */
+ public function getNodeImpurityDecrease(int $parentRecordCount)
+ {
+ if ($this->isTerminal) {
+ return 0.0;
+ }
+
+ $nodeSampleCount = (float)count($this->records);
+ $iT = $this->giniIndex;
+
+ if ($this->leftLeaf) {
+ $pL = count($this->leftLeaf->records)/$nodeSampleCount;
+ $iT -= $pL * $this->leftLeaf->giniIndex;
+ }
+
+ if ($this->rightLeaf) {
+ $pR = count($this->rightLeaf->records)/$nodeSampleCount;
+ $iT -= $pR * $this->rightLeaf->giniIndex;
+ }
+
+ return $iT * $nodeSampleCount / $parentRecordCount;
+ }
+
+ /**
+ * Returns HTML representation of the node including children nodes
+ *
+ * @param $columnNames
+ * @return string
+ */
+ public function getHTML($columnNames = null)
{
if ($this->isTerminal) {
$value = "$this->classValue";
} else {
$value = $this->value;
- $col = "col_$this->columnIndex";
+ if ($columnNames !== null) {
+ $col = $columnNames[$this->columnIndex];
+ } else {
+ $col = "col_$this->columnIndex";
+ }
if (! preg_match("/^[<>=]{1,2}/", $value)) {
$value = "=$value";
}
@@ -89,13 +131,13 @@ class DecisionTreeLeaf
if ($this->leftLeaf || $this->rightLeaf) {
$str .='