RandomForest::getFeatureImportances() method (#47)

* RandomForest::getFeatureImportances() method

* CsvDataset update for column names
This commit is contained in:
Mustafa Karabulut 2017-02-13 23:23:18 +03:00 committed by Arkadiusz Kondas
parent 240a22788f
commit a33d5fe9c8
5 changed files with 273 additions and 22 deletions

View File

@ -56,6 +56,17 @@ class DecisionTree implements Classifier
*/
private $numUsableFeatures = 0;
/**
* @var array
*/
private $featureImportances = null;
/**
*
* @var array
*/
private $columnNames = null;
/**
* @param int $maxDepth
*/
@ -76,6 +87,21 @@ class DecisionTree implements Classifier
$this->columnTypes = $this->getColumnTypes($this->samples);
$this->labels = array_keys(array_count_values($this->targets));
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
// Each time the tree is trained, feature importances are reset so that
// we will have to compute it again depending on the new data
$this->featureImportances = null;
// If column names are given or computed before, then there is no
// need to init it and accidentally remove the previous given names
if ($this->columnNames === null) {
$this->columnNames = range(0, $this->featureCount - 1);
} elseif (count($this->columnNames) > $this->featureCount) {
$this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
} elseif (count($this->columnNames) < $this->featureCount) {
$this->columnNames = array_merge($this->columnNames,
range(count($this->columnNames), $this->featureCount - 1));
}
}
protected function getColumnTypes(array $samples)
@ -164,6 +190,7 @@ class DecisionTree implements Classifier
$split->value = $baseValue;
$split->giniIndex = $gini;
$split->columnIndex = $i;
$split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS;
$split->records = $records;
$bestSplit = $split;
$bestGiniVal = $gini;
@ -292,6 +319,25 @@ class DecisionTree implements Classifier
}
$this->numUsableFeatures = $numFeatures;
return $this;
}
/**
* A string array to represent columns. Useful when HTML output or
* column importances are desired to be inspected.
*
* @param array $names
* @return $this
*/
public function setColumnNames(array $names)
{
if ($this->featureCount != 0 && count($names) != $this->featureCount) {
throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)");
}
$this->columnNames = $names;
return $this;
}
@ -300,7 +346,80 @@ class DecisionTree implements Classifier
*/
public function getHtml()
{
return $this->tree->__toString();
return $this->tree->getHTML($this->columnNames);
}
/**
* This will return an array including an importance value for
* each column in the given dataset. The importance values are
* normalized and their total makes 1.<br/>
*
* @param array $labels
* @return array
*/
public function getFeatureImportances()
{
if ($this->featureImportances !== null) {
return $this->featureImportances;
}
$sampleCount = count($this->samples);
$this->featureImportances = [];
foreach ($this->columnNames as $column => $columnName) {
$nodes = $this->getSplitNodesByColumn($column, $this->tree);
$importance = 0;
foreach ($nodes as $node) {
$importance += $node->getNodeImpurityDecrease($sampleCount);
}
$this->featureImportances[$columnName] = $importance;
}
// Normalize & sort the importances
$total = array_sum($this->featureImportances);
if ($total > 0) {
foreach ($this->featureImportances as &$importance) {
$importance /= $total;
}
arsort($this->featureImportances);
}
return $this->featureImportances;
}
/**
* Collects and returns an array of internal nodes that use the given
* column as a split criteron
*
* @param int $column
* @param DecisionTreeLeaf
* @param array $collected
*
* @return array
*/
protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node)
{
if (!$node || $node->isTerminal) {
return [];
}
$nodes = [];
if ($node->columnIndex == $column) {
$nodes[] = $node;
}
$lNodes = [];
$rNodes = [];
if ($node->leftLeaf) {
$lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
}
if ($node->rightLeaf) {
$rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
}
$nodes = array_merge($nodes, $lNodes, $rNodes);
return $nodes;
}
/**

View File

@ -6,7 +6,6 @@ namespace Phpml\Classification\DecisionTree;
class DecisionTreeLeaf
{
const OPERATOR_EQ = '=';
/**
* @var string
*/
@ -45,6 +44,11 @@ class DecisionTreeLeaf
*/
public $isTerminal = false;
/**
* @var bool
*/
public $isContinuous = false;
/**
* @var float
*/
@ -62,7 +66,7 @@ class DecisionTreeLeaf
public function evaluate($record)
{
$recordField = $record[$this->columnIndex];
if (is_string($this->value) && preg_match("/^([<>=]{1,2})\s*(.*)/", $this->value, $matches)) {
if ($this->isContinuous && preg_match("/^([<>=]{1,2})\s*(.*)/", strval($this->value), $matches)) {
$op = $matches[1];
$value= floatval($matches[2]);
$recordField = strval($recordField);
@ -72,13 +76,51 @@ class DecisionTreeLeaf
return $recordField == $this->value;
}
public function __toString()
/**
* Returns Mean Decrease Impurity (MDI) in the node.
* For terminal nodes, this value is equal to 0
*
* @return float
*/
public function getNodeImpurityDecrease(int $parentRecordCount)
{
if ($this->isTerminal) {
return 0.0;
}
$nodeSampleCount = (float)count($this->records);
$iT = $this->giniIndex;
if ($this->leftLeaf) {
$pL = count($this->leftLeaf->records)/$nodeSampleCount;
$iT -= $pL * $this->leftLeaf->giniIndex;
}
if ($this->rightLeaf) {
$pR = count($this->rightLeaf->records)/$nodeSampleCount;
$iT -= $pR * $this->rightLeaf->giniIndex;
}
return $iT * $nodeSampleCount / $parentRecordCount;
}
/**
* Returns HTML representation of the node including children nodes
*
* @param $columnNames
* @return string
*/
public function getHTML($columnNames = null)
{
if ($this->isTerminal) {
$value = "<b>$this->classValue</b>";
} else {
$value = $this->value;
$col = "col_$this->columnIndex";
if ($columnNames !== null) {
$col = $columnNames[$this->columnIndex];
} else {
$col = "col_$this->columnIndex";
}
if (! preg_match("/^[<>=]{1,2}/", $value)) {
$value = "=$value";
}
@ -89,13 +131,13 @@ class DecisionTreeLeaf
if ($this->leftLeaf || $this->rightLeaf) {
$str .='<tr>';
if ($this->leftLeaf) {
$str .="<td valign=top><b>| Yes</b><br>$this->leftLeaf</td>";
$str .="<td valign=top><b>| Yes</b><br>" . $this->leftLeaf->getHTML($columnNames) . "</td>";
} else {
$str .='<td></td>';
}
$str .='<td>&nbsp;</td>';
if ($this->rightLeaf) {
$str .="<td valign=top align=right><b>No |</b><br>$this->rightLeaf</td>";
$str .="<td valign=top align=right><b>No |</b><br>" . $this->rightLeaf->getHTML($columnNames) . "</td>";
} else {
$str .='<td></td>';
}
@ -104,4 +146,14 @@ class DecisionTreeLeaf
$str .= '</table>';
return $str;
}
/**
* HTML representation of the tree without column names
*
* @return string
*/
public function __toString()
{
return $this->getHTML();
}
}

View File

@ -53,7 +53,7 @@ class Bagging implements Classifier
/**
* @var float
*/
protected $subsetRatio = 0.5;
protected $subsetRatio = 0.7;
/**
* @var array
@ -120,7 +120,7 @@ class Bagging implements Classifier
$this->featureCount = count($samples[0]);
$this->numSamples = count($this->samples);
// Init classifiers and train them with random sub-samples
// Init classifiers and train them with bootstrap samples
$this->classifiers = $this->initClassifiers();
$index = 0;
foreach ($this->classifiers as $classifier) {
@ -134,16 +134,14 @@ class Bagging implements Classifier
* @param int $index
* @return array
*/
protected function getRandomSubset($index)
protected function getRandomSubset(int $index)
{
$subsetLength = (int)ceil(sqrt($this->numSamples));
$denom = $this->subsetRatio / 2;
$subsetLength = $this->numSamples / (1 / $denom);
$index = $index * $subsetLength % $this->numSamples;
$samples = [];
$targets = [];
for ($i=0; $i<$subsetLength * 2; $i++) {
$rand = rand($index, $this->numSamples - 1);
srand($index);
$bootstrapSize = $this->subsetRatio * $this->numSamples;
for ($i=0; $i < $bootstrapSize; $i++) {
$rand = rand(0, $this->numSamples - 1);
$samples[] = $this->samples[$rand];
$targets[] = $this->targets[$rand];
}

View File

@ -16,6 +16,18 @@ class RandomForest extends Bagging
*/
protected $featureSubsetRatio = 'log';
/**
* @var array
*/
protected $columnNames = null;
/**
* Initializes RandomForest with the given number of trees. More trees
* may increase the prediction performance while it will also substantially
* increase the processing time and the required memory
*
* @param type $numClassifier
*/
public function __construct($numClassifier = 50)
{
parent::__construct($numClassifier);
@ -24,14 +36,13 @@ class RandomForest extends Bagging
}
/**
* This method is used to determine how much of the original columns (features)
* This method is used to determine how many of the original columns (features)
* will be used to construct subsets to train base classifiers.<br>
*
* Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
*
* If there are many features that diminishes classification performance, then
* small values should be preferred, otherwise, with low number of features,
* default value (0.7) will result in satisfactory performance.
* Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
* features to be taken into consideration while selecting subspace of features
*
* @param mixed $ratio string or float should be given
* @return $this
@ -65,6 +76,55 @@ class RandomForest extends Bagging
return parent::setClassifer($classifier, $classifierOptions);
}
/**
* This will return an array including an importance value for
* each column in the given dataset. Importance values for a column
* is the average importance of that column in all trees in the forest
*
* @return array
*/
public function getFeatureImportances()
{
// Traverse each tree and sum importance of the columns
$sum = [];
foreach ($this->classifiers as $tree) {
/* @var $tree DecisionTree */
$importances = $tree->getFeatureImportances();
foreach ($importances as $column => $importance) {
if (array_key_exists($column, $sum)) {
$sum[$column] += $importance;
} else {
$sum[$column] = $importance;
}
}
}
// Normalize & sort the importance values
$total = array_sum($sum);
foreach ($sum as &$importance) {
$importance /= $total;
}
arsort($sum);
return $sum;
}
/**
* A string array to represent the columns is given. They are useful
* when trying to print some information about the trees such as feature importances
*
* @param array $names
* @return $this
*/
public function setColumnNames(array $names)
{
$this->columnNames = $names;
return $this;
}
/**
* @param DecisionTree $classifier
* @param int $index
@ -84,6 +144,12 @@ class RandomForest extends Bagging
$featureCount = $this->featureCount;
}
return $classifier->setNumFeatures($featureCount);
if ($this->columnNames === null) {
$this->columnNames = range(0, $this->featureCount - 1);
}
return $classifier
->setColumnNames($this->columnNames)
->setNumFeatures($featureCount);
}
}

View File

@ -8,6 +8,11 @@ use Phpml\Exception\FileException;
class CsvDataset extends ArrayDataset
{
/**
* @var array
*/
protected $columnNames;
/**
* @param string $filepath
* @param int $features
@ -26,7 +31,10 @@ class CsvDataset extends ArrayDataset
}
if ($headingRow) {
fgets($handle);
$data = fgetcsv($handle, 1000, ',');
$this->columnNames = array_slice($data, 0, $features);
} else {
$this->columnNames = range(0, $features - 1);
}
while (($data = fgetcsv($handle, 1000, ',')) !== false) {
@ -35,4 +43,12 @@ class CsvDataset extends ArrayDataset
}
fclose($handle);
}
/**
* @return array
*/
public function getColumnNames()
{
return $this->columnNames;
}
}