diff --git a/src/Phpml/Classification/DecisionTree.php b/src/Phpml/Classification/DecisionTree.php index 1a04802..6a860eb 100644 --- a/src/Phpml/Classification/DecisionTree.php +++ b/src/Phpml/Classification/DecisionTree.php @@ -56,6 +56,17 @@ class DecisionTree implements Classifier */ private $numUsableFeatures = 0; + /** + * @var array + */ + private $featureImportances = null; + + /** + * + * @var array + */ + private $columnNames = null; + /** * @param int $maxDepth */ @@ -76,6 +87,21 @@ class DecisionTree implements Classifier $this->columnTypes = $this->getColumnTypes($this->samples); $this->labels = array_keys(array_count_values($this->targets)); $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1)); + + // Each time the tree is trained, feature importances are reset so that + // we will have to compute it again depending on the new data + $this->featureImportances = null; + + // If column names are given or computed before, then there is no + // need to init it and accidentally remove the previous given names + if ($this->columnNames === null) { + $this->columnNames = range(0, $this->featureCount - 1); + } elseif (count($this->columnNames) > $this->featureCount) { + $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount); + } elseif (count($this->columnNames) < $this->featureCount) { + $this->columnNames = array_merge($this->columnNames, + range(count($this->columnNames), $this->featureCount - 1)); + } } protected function getColumnTypes(array $samples) @@ -164,6 +190,7 @@ class DecisionTree implements Classifier $split->value = $baseValue; $split->giniIndex = $gini; $split->columnIndex = $i; + $split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS; $split->records = $records; $bestSplit = $split; $bestGiniVal = $gini; @@ -292,6 +319,25 @@ class DecisionTree implements Classifier } $this->numUsableFeatures = $numFeatures; + + return $this; + } + + /** + * A string array to represent columns. Useful when HTML output or + * column importances are desired to be inspected. + * + * @param array $names + * @return $this + */ + public function setColumnNames(array $names) + { + if ($this->featureCount != 0 && count($names) != $this->featureCount) { + throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)"); + } + + $this->columnNames = $names; + return $this; } @@ -300,7 +346,80 @@ class DecisionTree implements Classifier */ public function getHtml() { - return $this->tree->__toString(); + return $this->tree->getHTML($this->columnNames); + } + + /** + * This will return an array including an importance value for + * each column in the given dataset. The importance values are + * normalized and their total makes 1.
+ * + * @param array $labels + * @return array + */ + public function getFeatureImportances() + { + if ($this->featureImportances !== null) { + return $this->featureImportances; + } + + $sampleCount = count($this->samples); + $this->featureImportances = []; + foreach ($this->columnNames as $column => $columnName) { + $nodes = $this->getSplitNodesByColumn($column, $this->tree); + + $importance = 0; + foreach ($nodes as $node) { + $importance += $node->getNodeImpurityDecrease($sampleCount); + } + + $this->featureImportances[$columnName] = $importance; + } + + // Normalize & sort the importances + $total = array_sum($this->featureImportances); + if ($total > 0) { + foreach ($this->featureImportances as &$importance) { + $importance /= $total; + } + arsort($this->featureImportances); + } + + return $this->featureImportances; + } + + /** + * Collects and returns an array of internal nodes that use the given + * column as a split criteron + * + * @param int $column + * @param DecisionTreeLeaf + * @param array $collected + * + * @return array + */ + protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node) + { + if (!$node || $node->isTerminal) { + return []; + } + + $nodes = []; + if ($node->columnIndex == $column) { + $nodes[] = $node; + } + + $lNodes = []; + $rNodes = []; + if ($node->leftLeaf) { + $lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf); + } + if ($node->rightLeaf) { + $rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf); + } + $nodes = array_merge($nodes, $lNodes, $rNodes); + + return $nodes; } /** diff --git a/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php b/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php index 1993864..e30fc10 100644 --- a/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php +++ b/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php @@ -6,7 +6,6 @@ namespace Phpml\Classification\DecisionTree; class DecisionTreeLeaf { - const OPERATOR_EQ = '='; /** * @var string */ @@ -45,6 +44,11 @@ class DecisionTreeLeaf */ public $isTerminal = false; + /** + * @var bool + */ + public $isContinuous = false; + /** * @var float */ @@ -62,7 +66,7 @@ class DecisionTreeLeaf public function evaluate($record) { $recordField = $record[$this->columnIndex]; - if (is_string($this->value) && preg_match("/^([<>=]{1,2})\s*(.*)/", $this->value, $matches)) { + if ($this->isContinuous && preg_match("/^([<>=]{1,2})\s*(.*)/", strval($this->value), $matches)) { $op = $matches[1]; $value= floatval($matches[2]); $recordField = strval($recordField); @@ -72,13 +76,51 @@ class DecisionTreeLeaf return $recordField == $this->value; } - public function __toString() + /** + * Returns Mean Decrease Impurity (MDI) in the node. + * For terminal nodes, this value is equal to 0 + * + * @return float + */ + public function getNodeImpurityDecrease(int $parentRecordCount) + { + if ($this->isTerminal) { + return 0.0; + } + + $nodeSampleCount = (float)count($this->records); + $iT = $this->giniIndex; + + if ($this->leftLeaf) { + $pL = count($this->leftLeaf->records)/$nodeSampleCount; + $iT -= $pL * $this->leftLeaf->giniIndex; + } + + if ($this->rightLeaf) { + $pR = count($this->rightLeaf->records)/$nodeSampleCount; + $iT -= $pR * $this->rightLeaf->giniIndex; + } + + return $iT * $nodeSampleCount / $parentRecordCount; + } + + /** + * Returns HTML representation of the node including children nodes + * + * @param $columnNames + * @return string + */ + public function getHTML($columnNames = null) { if ($this->isTerminal) { $value = "$this->classValue"; } else { $value = $this->value; - $col = "col_$this->columnIndex"; + if ($columnNames !== null) { + $col = $columnNames[$this->columnIndex]; + } else { + $col = "col_$this->columnIndex"; + } if (! preg_match("/^[<>=]{1,2}/", $value)) { $value = "=$value"; } @@ -89,13 +131,13 @@ class DecisionTreeLeaf if ($this->leftLeaf || $this->rightLeaf) { $str .=''; if ($this->leftLeaf) { - $str .="| Yes
$this->leftLeaf"; + $str .="| Yes
" . $this->leftLeaf->getHTML($columnNames) . ""; } else { $str .=''; } $str .=' '; if ($this->rightLeaf) { - $str .="No |
$this->rightLeaf"; + $str .="No |
" . $this->rightLeaf->getHTML($columnNames) . ""; } else { $str .=''; } @@ -104,4 +146,14 @@ class DecisionTreeLeaf $str .= ''; return $str; } + + /** + * HTML representation of the tree without column names + * + * @return string + */ + public function __toString() + { + return $this->getHTML(); + } } diff --git a/src/Phpml/Classification/Ensemble/Bagging.php b/src/Phpml/Classification/Ensemble/Bagging.php index 817869e..d579b24 100644 --- a/src/Phpml/Classification/Ensemble/Bagging.php +++ b/src/Phpml/Classification/Ensemble/Bagging.php @@ -53,7 +53,7 @@ class Bagging implements Classifier /** * @var float */ - protected $subsetRatio = 0.5; + protected $subsetRatio = 0.7; /** * @var array @@ -120,7 +120,7 @@ class Bagging implements Classifier $this->featureCount = count($samples[0]); $this->numSamples = count($this->samples); - // Init classifiers and train them with random sub-samples + // Init classifiers and train them with bootstrap samples $this->classifiers = $this->initClassifiers(); $index = 0; foreach ($this->classifiers as $classifier) { @@ -134,16 +134,14 @@ class Bagging implements Classifier * @param int $index * @return array */ - protected function getRandomSubset($index) + protected function getRandomSubset(int $index) { - $subsetLength = (int)ceil(sqrt($this->numSamples)); - $denom = $this->subsetRatio / 2; - $subsetLength = $this->numSamples / (1 / $denom); - $index = $index * $subsetLength % $this->numSamples; $samples = []; $targets = []; - for ($i=0; $i<$subsetLength * 2; $i++) { - $rand = rand($index, $this->numSamples - 1); + srand($index); + $bootstrapSize = $this->subsetRatio * $this->numSamples; + for ($i=0; $i < $bootstrapSize; $i++) { + $rand = rand(0, $this->numSamples - 1); $samples[] = $this->samples[$rand]; $targets[] = $this->targets[$rand]; } diff --git a/src/Phpml/Classification/Ensemble/RandomForest.php b/src/Phpml/Classification/Ensemble/RandomForest.php index 37df7ae..025badf 100644 --- a/src/Phpml/Classification/Ensemble/RandomForest.php +++ b/src/Phpml/Classification/Ensemble/RandomForest.php @@ -16,6 +16,18 @@ class RandomForest extends Bagging */ protected $featureSubsetRatio = 'log'; + /** + * @var array + */ + protected $columnNames = null; + + /** + * Initializes RandomForest with the given number of trees. More trees + * may increase the prediction performance while it will also substantially + * increase the processing time and the required memory + * + * @param type $numClassifier + */ public function __construct($numClassifier = 50) { parent::__construct($numClassifier); @@ -24,14 +36,13 @@ class RandomForest extends Bagging } /** - * This method is used to determine how much of the original columns (features) + * This method is used to determine how many of the original columns (features) * will be used to construct subsets to train base classifiers.
* * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0
* - * If there are many features that diminishes classification performance, then - * small values should be preferred, otherwise, with low number of features, - * default value (0.7) will result in satisfactory performance. + * Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1 + * features to be taken into consideration while selecting subspace of features * * @param mixed $ratio string or float should be given * @return $this @@ -65,6 +76,55 @@ class RandomForest extends Bagging return parent::setClassifer($classifier, $classifierOptions); } + /** + * This will return an array including an importance value for + * each column in the given dataset. Importance values for a column + * is the average importance of that column in all trees in the forest + * + * @return array + */ + public function getFeatureImportances() + { + // Traverse each tree and sum importance of the columns + $sum = []; + foreach ($this->classifiers as $tree) { + /* @var $tree DecisionTree */ + $importances = $tree->getFeatureImportances(); + + foreach ($importances as $column => $importance) { + if (array_key_exists($column, $sum)) { + $sum[$column] += $importance; + } else { + $sum[$column] = $importance; + } + } + } + + // Normalize & sort the importance values + $total = array_sum($sum); + foreach ($sum as &$importance) { + $importance /= $total; + } + + arsort($sum); + + return $sum; + } + + /** + * A string array to represent the columns is given. They are useful + * when trying to print some information about the trees such as feature importances + * + * @param array $names + * @return $this + */ + public function setColumnNames(array $names) + { + $this->columnNames = $names; + + return $this; + } + /** * @param DecisionTree $classifier * @param int $index @@ -84,6 +144,12 @@ class RandomForest extends Bagging $featureCount = $this->featureCount; } - return $classifier->setNumFeatures($featureCount); + if ($this->columnNames === null) { + $this->columnNames = range(0, $this->featureCount - 1); + } + + return $classifier + ->setColumnNames($this->columnNames) + ->setNumFeatures($featureCount); } } diff --git a/src/Phpml/Dataset/CsvDataset.php b/src/Phpml/Dataset/CsvDataset.php index dd722d4..483b1af 100644 --- a/src/Phpml/Dataset/CsvDataset.php +++ b/src/Phpml/Dataset/CsvDataset.php @@ -8,6 +8,11 @@ use Phpml\Exception\FileException; class CsvDataset extends ArrayDataset { + /** + * @var array + */ + protected $columnNames; + /** * @param string $filepath * @param int $features @@ -26,7 +31,10 @@ class CsvDataset extends ArrayDataset } if ($headingRow) { - fgets($handle); + $data = fgetcsv($handle, 1000, ','); + $this->columnNames = array_slice($data, 0, $features); + } else { + $this->columnNames = range(0, $features - 1); } while (($data = fgetcsv($handle, 1000, ',')) !== false) { @@ -35,4 +43,12 @@ class CsvDataset extends ArrayDataset } fclose($handle); } + + /** + * @return array + */ + public function getColumnNames() + { + return $this->columnNames; + } }