mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-22 04:55:10 +00:00
RandomForest::getFeatureImportances() method (#47)
* RandomForest::getFeatureImportances() method * CsvDataset update for column names
This commit is contained in:
parent
240a22788f
commit
a33d5fe9c8
@ -56,6 +56,17 @@ class DecisionTree implements Classifier
|
||||
*/
|
||||
private $numUsableFeatures = 0;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
private $featureImportances = null;
|
||||
|
||||
/**
|
||||
*
|
||||
* @var array
|
||||
*/
|
||||
private $columnNames = null;
|
||||
|
||||
/**
|
||||
* @param int $maxDepth
|
||||
*/
|
||||
@ -76,6 +87,21 @@ class DecisionTree implements Classifier
|
||||
$this->columnTypes = $this->getColumnTypes($this->samples);
|
||||
$this->labels = array_keys(array_count_values($this->targets));
|
||||
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
|
||||
|
||||
// Each time the tree is trained, feature importances are reset so that
|
||||
// we will have to compute it again depending on the new data
|
||||
$this->featureImportances = null;
|
||||
|
||||
// If column names are given or computed before, then there is no
|
||||
// need to init it and accidentally remove the previous given names
|
||||
if ($this->columnNames === null) {
|
||||
$this->columnNames = range(0, $this->featureCount - 1);
|
||||
} elseif (count($this->columnNames) > $this->featureCount) {
|
||||
$this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
|
||||
} elseif (count($this->columnNames) < $this->featureCount) {
|
||||
$this->columnNames = array_merge($this->columnNames,
|
||||
range(count($this->columnNames), $this->featureCount - 1));
|
||||
}
|
||||
}
|
||||
|
||||
protected function getColumnTypes(array $samples)
|
||||
@ -164,6 +190,7 @@ class DecisionTree implements Classifier
|
||||
$split->value = $baseValue;
|
||||
$split->giniIndex = $gini;
|
||||
$split->columnIndex = $i;
|
||||
$split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS;
|
||||
$split->records = $records;
|
||||
$bestSplit = $split;
|
||||
$bestGiniVal = $gini;
|
||||
@ -292,6 +319,25 @@ class DecisionTree implements Classifier
|
||||
}
|
||||
|
||||
$this->numUsableFeatures = $numFeatures;
|
||||
|
||||
return $this;
|
||||
}
|
||||
|
||||
/**
|
||||
* A string array to represent columns. Useful when HTML output or
|
||||
* column importances are desired to be inspected.
|
||||
*
|
||||
* @param array $names
|
||||
* @return $this
|
||||
*/
|
||||
public function setColumnNames(array $names)
|
||||
{
|
||||
if ($this->featureCount != 0 && count($names) != $this->featureCount) {
|
||||
throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)");
|
||||
}
|
||||
|
||||
$this->columnNames = $names;
|
||||
|
||||
return $this;
|
||||
}
|
||||
|
||||
@ -300,7 +346,80 @@ class DecisionTree implements Classifier
|
||||
*/
|
||||
public function getHtml()
|
||||
{
|
||||
return $this->tree->__toString();
|
||||
return $this->tree->getHTML($this->columnNames);
|
||||
}
|
||||
|
||||
/**
|
||||
* This will return an array including an importance value for
|
||||
* each column in the given dataset. The importance values are
|
||||
* normalized and their total makes 1.<br/>
|
||||
*
|
||||
* @param array $labels
|
||||
* @return array
|
||||
*/
|
||||
public function getFeatureImportances()
|
||||
{
|
||||
if ($this->featureImportances !== null) {
|
||||
return $this->featureImportances;
|
||||
}
|
||||
|
||||
$sampleCount = count($this->samples);
|
||||
$this->featureImportances = [];
|
||||
foreach ($this->columnNames as $column => $columnName) {
|
||||
$nodes = $this->getSplitNodesByColumn($column, $this->tree);
|
||||
|
||||
$importance = 0;
|
||||
foreach ($nodes as $node) {
|
||||
$importance += $node->getNodeImpurityDecrease($sampleCount);
|
||||
}
|
||||
|
||||
$this->featureImportances[$columnName] = $importance;
|
||||
}
|
||||
|
||||
// Normalize & sort the importances
|
||||
$total = array_sum($this->featureImportances);
|
||||
if ($total > 0) {
|
||||
foreach ($this->featureImportances as &$importance) {
|
||||
$importance /= $total;
|
||||
}
|
||||
arsort($this->featureImportances);
|
||||
}
|
||||
|
||||
return $this->featureImportances;
|
||||
}
|
||||
|
||||
/**
|
||||
* Collects and returns an array of internal nodes that use the given
|
||||
* column as a split criteron
|
||||
*
|
||||
* @param int $column
|
||||
* @param DecisionTreeLeaf
|
||||
* @param array $collected
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node)
|
||||
{
|
||||
if (!$node || $node->isTerminal) {
|
||||
return [];
|
||||
}
|
||||
|
||||
$nodes = [];
|
||||
if ($node->columnIndex == $column) {
|
||||
$nodes[] = $node;
|
||||
}
|
||||
|
||||
$lNodes = [];
|
||||
$rNodes = [];
|
||||
if ($node->leftLeaf) {
|
||||
$lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
|
||||
}
|
||||
if ($node->rightLeaf) {
|
||||
$rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
|
||||
}
|
||||
$nodes = array_merge($nodes, $lNodes, $rNodes);
|
||||
|
||||
return $nodes;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -6,7 +6,6 @@ namespace Phpml\Classification\DecisionTree;
|
||||
|
||||
class DecisionTreeLeaf
|
||||
{
|
||||
const OPERATOR_EQ = '=';
|
||||
/**
|
||||
* @var string
|
||||
*/
|
||||
@ -45,6 +44,11 @@ class DecisionTreeLeaf
|
||||
*/
|
||||
public $isTerminal = false;
|
||||
|
||||
/**
|
||||
* @var bool
|
||||
*/
|
||||
public $isContinuous = false;
|
||||
|
||||
/**
|
||||
* @var float
|
||||
*/
|
||||
@ -62,7 +66,7 @@ class DecisionTreeLeaf
|
||||
public function evaluate($record)
|
||||
{
|
||||
$recordField = $record[$this->columnIndex];
|
||||
if (is_string($this->value) && preg_match("/^([<>=]{1,2})\s*(.*)/", $this->value, $matches)) {
|
||||
if ($this->isContinuous && preg_match("/^([<>=]{1,2})\s*(.*)/", strval($this->value), $matches)) {
|
||||
$op = $matches[1];
|
||||
$value= floatval($matches[2]);
|
||||
$recordField = strval($recordField);
|
||||
@ -72,13 +76,51 @@ class DecisionTreeLeaf
|
||||
return $recordField == $this->value;
|
||||
}
|
||||
|
||||
public function __toString()
|
||||
/**
|
||||
* Returns Mean Decrease Impurity (MDI) in the node.
|
||||
* For terminal nodes, this value is equal to 0
|
||||
*
|
||||
* @return float
|
||||
*/
|
||||
public function getNodeImpurityDecrease(int $parentRecordCount)
|
||||
{
|
||||
if ($this->isTerminal) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
$nodeSampleCount = (float)count($this->records);
|
||||
$iT = $this->giniIndex;
|
||||
|
||||
if ($this->leftLeaf) {
|
||||
$pL = count($this->leftLeaf->records)/$nodeSampleCount;
|
||||
$iT -= $pL * $this->leftLeaf->giniIndex;
|
||||
}
|
||||
|
||||
if ($this->rightLeaf) {
|
||||
$pR = count($this->rightLeaf->records)/$nodeSampleCount;
|
||||
$iT -= $pR * $this->rightLeaf->giniIndex;
|
||||
}
|
||||
|
||||
return $iT * $nodeSampleCount / $parentRecordCount;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns HTML representation of the node including children nodes
|
||||
*
|
||||
* @param $columnNames
|
||||
* @return string
|
||||
*/
|
||||
public function getHTML($columnNames = null)
|
||||
{
|
||||
if ($this->isTerminal) {
|
||||
$value = "<b>$this->classValue</b>";
|
||||
} else {
|
||||
$value = $this->value;
|
||||
if ($columnNames !== null) {
|
||||
$col = $columnNames[$this->columnIndex];
|
||||
} else {
|
||||
$col = "col_$this->columnIndex";
|
||||
}
|
||||
if (! preg_match("/^[<>=]{1,2}/", $value)) {
|
||||
$value = "=$value";
|
||||
}
|
||||
@ -89,13 +131,13 @@ class DecisionTreeLeaf
|
||||
if ($this->leftLeaf || $this->rightLeaf) {
|
||||
$str .='<tr>';
|
||||
if ($this->leftLeaf) {
|
||||
$str .="<td valign=top><b>| Yes</b><br>$this->leftLeaf</td>";
|
||||
$str .="<td valign=top><b>| Yes</b><br>" . $this->leftLeaf->getHTML($columnNames) . "</td>";
|
||||
} else {
|
||||
$str .='<td></td>';
|
||||
}
|
||||
$str .='<td> </td>';
|
||||
if ($this->rightLeaf) {
|
||||
$str .="<td valign=top align=right><b>No |</b><br>$this->rightLeaf</td>";
|
||||
$str .="<td valign=top align=right><b>No |</b><br>" . $this->rightLeaf->getHTML($columnNames) . "</td>";
|
||||
} else {
|
||||
$str .='<td></td>';
|
||||
}
|
||||
@ -104,4 +146,14 @@ class DecisionTreeLeaf
|
||||
$str .= '</table>';
|
||||
return $str;
|
||||
}
|
||||
|
||||
/**
|
||||
* HTML representation of the tree without column names
|
||||
*
|
||||
* @return string
|
||||
*/
|
||||
public function __toString()
|
||||
{
|
||||
return $this->getHTML();
|
||||
}
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ class Bagging implements Classifier
|
||||
/**
|
||||
* @var float
|
||||
*/
|
||||
protected $subsetRatio = 0.5;
|
||||
protected $subsetRatio = 0.7;
|
||||
|
||||
/**
|
||||
* @var array
|
||||
@ -120,7 +120,7 @@ class Bagging implements Classifier
|
||||
$this->featureCount = count($samples[0]);
|
||||
$this->numSamples = count($this->samples);
|
||||
|
||||
// Init classifiers and train them with random sub-samples
|
||||
// Init classifiers and train them with bootstrap samples
|
||||
$this->classifiers = $this->initClassifiers();
|
||||
$index = 0;
|
||||
foreach ($this->classifiers as $classifier) {
|
||||
@ -134,16 +134,14 @@ class Bagging implements Classifier
|
||||
* @param int $index
|
||||
* @return array
|
||||
*/
|
||||
protected function getRandomSubset($index)
|
||||
protected function getRandomSubset(int $index)
|
||||
{
|
||||
$subsetLength = (int)ceil(sqrt($this->numSamples));
|
||||
$denom = $this->subsetRatio / 2;
|
||||
$subsetLength = $this->numSamples / (1 / $denom);
|
||||
$index = $index * $subsetLength % $this->numSamples;
|
||||
$samples = [];
|
||||
$targets = [];
|
||||
for ($i=0; $i<$subsetLength * 2; $i++) {
|
||||
$rand = rand($index, $this->numSamples - 1);
|
||||
srand($index);
|
||||
$bootstrapSize = $this->subsetRatio * $this->numSamples;
|
||||
for ($i=0; $i < $bootstrapSize; $i++) {
|
||||
$rand = rand(0, $this->numSamples - 1);
|
||||
$samples[] = $this->samples[$rand];
|
||||
$targets[] = $this->targets[$rand];
|
||||
}
|
||||
|
@ -16,6 +16,18 @@ class RandomForest extends Bagging
|
||||
*/
|
||||
protected $featureSubsetRatio = 'log';
|
||||
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $columnNames = null;
|
||||
|
||||
/**
|
||||
* Initializes RandomForest with the given number of trees. More trees
|
||||
* may increase the prediction performance while it will also substantially
|
||||
* increase the processing time and the required memory
|
||||
*
|
||||
* @param type $numClassifier
|
||||
*/
|
||||
public function __construct($numClassifier = 50)
|
||||
{
|
||||
parent::__construct($numClassifier);
|
||||
@ -24,14 +36,13 @@ class RandomForest extends Bagging
|
||||
}
|
||||
|
||||
/**
|
||||
* This method is used to determine how much of the original columns (features)
|
||||
* This method is used to determine how many of the original columns (features)
|
||||
* will be used to construct subsets to train base classifiers.<br>
|
||||
*
|
||||
* Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
|
||||
*
|
||||
* If there are many features that diminishes classification performance, then
|
||||
* small values should be preferred, otherwise, with low number of features,
|
||||
* default value (0.7) will result in satisfactory performance.
|
||||
* Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
|
||||
* features to be taken into consideration while selecting subspace of features
|
||||
*
|
||||
* @param mixed $ratio string or float should be given
|
||||
* @return $this
|
||||
@ -65,6 +76,55 @@ class RandomForest extends Bagging
|
||||
return parent::setClassifer($classifier, $classifierOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* This will return an array including an importance value for
|
||||
* each column in the given dataset. Importance values for a column
|
||||
* is the average importance of that column in all trees in the forest
|
||||
*
|
||||
* @return array
|
||||
*/
|
||||
public function getFeatureImportances()
|
||||
{
|
||||
// Traverse each tree and sum importance of the columns
|
||||
$sum = [];
|
||||
foreach ($this->classifiers as $tree) {
|
||||
/* @var $tree DecisionTree */
|
||||
$importances = $tree->getFeatureImportances();
|
||||
|
||||
foreach ($importances as $column => $importance) {
|
||||
if (array_key_exists($column, $sum)) {
|
||||
$sum[$column] += $importance;
|
||||
} else {
|
||||
$sum[$column] = $importance;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize & sort the importance values
|
||||
$total = array_sum($sum);
|
||||
foreach ($sum as &$importance) {
|
||||
$importance /= $total;
|
||||
}
|
||||
|
||||
arsort($sum);
|
||||
|
||||
return $sum;
|
||||
}
|
||||
|
||||
/**
|
||||
* A string array to represent the columns is given. They are useful
|
||||
* when trying to print some information about the trees such as feature importances
|
||||
*
|
||||
* @param array $names
|
||||
* @return $this
|
||||
*/
|
||||
public function setColumnNames(array $names)
|
||||
{
|
||||
$this->columnNames = $names;
|
||||
|
||||
return $this;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param DecisionTree $classifier
|
||||
* @param int $index
|
||||
@ -84,6 +144,12 @@ class RandomForest extends Bagging
|
||||
$featureCount = $this->featureCount;
|
||||
}
|
||||
|
||||
return $classifier->setNumFeatures($featureCount);
|
||||
if ($this->columnNames === null) {
|
||||
$this->columnNames = range(0, $this->featureCount - 1);
|
||||
}
|
||||
|
||||
return $classifier
|
||||
->setColumnNames($this->columnNames)
|
||||
->setNumFeatures($featureCount);
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,11 @@ use Phpml\Exception\FileException;
|
||||
|
||||
class CsvDataset extends ArrayDataset
|
||||
{
|
||||
/**
|
||||
* @var array
|
||||
*/
|
||||
protected $columnNames;
|
||||
|
||||
/**
|
||||
* @param string $filepath
|
||||
* @param int $features
|
||||
@ -26,7 +31,10 @@ class CsvDataset extends ArrayDataset
|
||||
}
|
||||
|
||||
if ($headingRow) {
|
||||
fgets($handle);
|
||||
$data = fgetcsv($handle, 1000, ',');
|
||||
$this->columnNames = array_slice($data, 0, $features);
|
||||
} else {
|
||||
$this->columnNames = range(0, $features - 1);
|
||||
}
|
||||
|
||||
while (($data = fgetcsv($handle, 1000, ',')) !== false) {
|
||||
@ -35,4 +43,12 @@ class CsvDataset extends ArrayDataset
|
||||
}
|
||||
fclose($handle);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return array
|
||||
*/
|
||||
public function getColumnNames()
|
||||
{
|
||||
return $this->columnNames;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user