mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-22 13:05:10 +00:00
RandomForest::getFeatureImportances() method (#47)
* RandomForest::getFeatureImportances() method * CsvDataset update for column names
This commit is contained in:
parent
240a22788f
commit
a33d5fe9c8
@ -56,6 +56,17 @@ class DecisionTree implements Classifier
|
|||||||
*/
|
*/
|
||||||
private $numUsableFeatures = 0;
|
private $numUsableFeatures = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $featureImportances = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
*
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $columnNames = null;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param int $maxDepth
|
* @param int $maxDepth
|
||||||
*/
|
*/
|
||||||
@ -76,6 +87,21 @@ class DecisionTree implements Classifier
|
|||||||
$this->columnTypes = $this->getColumnTypes($this->samples);
|
$this->columnTypes = $this->getColumnTypes($this->samples);
|
||||||
$this->labels = array_keys(array_count_values($this->targets));
|
$this->labels = array_keys(array_count_values($this->targets));
|
||||||
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
|
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
|
||||||
|
|
||||||
|
// Each time the tree is trained, feature importances are reset so that
|
||||||
|
// we will have to compute it again depending on the new data
|
||||||
|
$this->featureImportances = null;
|
||||||
|
|
||||||
|
// If column names are given or computed before, then there is no
|
||||||
|
// need to init it and accidentally remove the previous given names
|
||||||
|
if ($this->columnNames === null) {
|
||||||
|
$this->columnNames = range(0, $this->featureCount - 1);
|
||||||
|
} elseif (count($this->columnNames) > $this->featureCount) {
|
||||||
|
$this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
|
||||||
|
} elseif (count($this->columnNames) < $this->featureCount) {
|
||||||
|
$this->columnNames = array_merge($this->columnNames,
|
||||||
|
range(count($this->columnNames), $this->featureCount - 1));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected function getColumnTypes(array $samples)
|
protected function getColumnTypes(array $samples)
|
||||||
@ -164,6 +190,7 @@ class DecisionTree implements Classifier
|
|||||||
$split->value = $baseValue;
|
$split->value = $baseValue;
|
||||||
$split->giniIndex = $gini;
|
$split->giniIndex = $gini;
|
||||||
$split->columnIndex = $i;
|
$split->columnIndex = $i;
|
||||||
|
$split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS;
|
||||||
$split->records = $records;
|
$split->records = $records;
|
||||||
$bestSplit = $split;
|
$bestSplit = $split;
|
||||||
$bestGiniVal = $gini;
|
$bestGiniVal = $gini;
|
||||||
@ -292,6 +319,25 @@ class DecisionTree implements Classifier
|
|||||||
}
|
}
|
||||||
|
|
||||||
$this->numUsableFeatures = $numFeatures;
|
$this->numUsableFeatures = $numFeatures;
|
||||||
|
|
||||||
|
return $this;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A string array to represent columns. Useful when HTML output or
|
||||||
|
* column importances are desired to be inspected.
|
||||||
|
*
|
||||||
|
* @param array $names
|
||||||
|
* @return $this
|
||||||
|
*/
|
||||||
|
public function setColumnNames(array $names)
|
||||||
|
{
|
||||||
|
if ($this->featureCount != 0 && count($names) != $this->featureCount) {
|
||||||
|
throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)");
|
||||||
|
}
|
||||||
|
|
||||||
|
$this->columnNames = $names;
|
||||||
|
|
||||||
return $this;
|
return $this;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -300,7 +346,80 @@ class DecisionTree implements Classifier
|
|||||||
*/
|
*/
|
||||||
public function getHtml()
|
public function getHtml()
|
||||||
{
|
{
|
||||||
return $this->tree->__toString();
|
return $this->tree->getHTML($this->columnNames);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This will return an array including an importance value for
|
||||||
|
* each column in the given dataset. The importance values are
|
||||||
|
* normalized and their total makes 1.<br/>
|
||||||
|
*
|
||||||
|
* @param array $labels
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
public function getFeatureImportances()
|
||||||
|
{
|
||||||
|
if ($this->featureImportances !== null) {
|
||||||
|
return $this->featureImportances;
|
||||||
|
}
|
||||||
|
|
||||||
|
$sampleCount = count($this->samples);
|
||||||
|
$this->featureImportances = [];
|
||||||
|
foreach ($this->columnNames as $column => $columnName) {
|
||||||
|
$nodes = $this->getSplitNodesByColumn($column, $this->tree);
|
||||||
|
|
||||||
|
$importance = 0;
|
||||||
|
foreach ($nodes as $node) {
|
||||||
|
$importance += $node->getNodeImpurityDecrease($sampleCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
$this->featureImportances[$columnName] = $importance;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize & sort the importances
|
||||||
|
$total = array_sum($this->featureImportances);
|
||||||
|
if ($total > 0) {
|
||||||
|
foreach ($this->featureImportances as &$importance) {
|
||||||
|
$importance /= $total;
|
||||||
|
}
|
||||||
|
arsort($this->featureImportances);
|
||||||
|
}
|
||||||
|
|
||||||
|
return $this->featureImportances;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Collects and returns an array of internal nodes that use the given
|
||||||
|
* column as a split criteron
|
||||||
|
*
|
||||||
|
* @param int $column
|
||||||
|
* @param DecisionTreeLeaf
|
||||||
|
* @param array $collected
|
||||||
|
*
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node)
|
||||||
|
{
|
||||||
|
if (!$node || $node->isTerminal) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
$nodes = [];
|
||||||
|
if ($node->columnIndex == $column) {
|
||||||
|
$nodes[] = $node;
|
||||||
|
}
|
||||||
|
|
||||||
|
$lNodes = [];
|
||||||
|
$rNodes = [];
|
||||||
|
if ($node->leftLeaf) {
|
||||||
|
$lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
|
||||||
|
}
|
||||||
|
if ($node->rightLeaf) {
|
||||||
|
$rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
|
||||||
|
}
|
||||||
|
$nodes = array_merge($nodes, $lNodes, $rNodes);
|
||||||
|
|
||||||
|
return $nodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -6,7 +6,6 @@ namespace Phpml\Classification\DecisionTree;
|
|||||||
|
|
||||||
class DecisionTreeLeaf
|
class DecisionTreeLeaf
|
||||||
{
|
{
|
||||||
const OPERATOR_EQ = '=';
|
|
||||||
/**
|
/**
|
||||||
* @var string
|
* @var string
|
||||||
*/
|
*/
|
||||||
@ -45,6 +44,11 @@ class DecisionTreeLeaf
|
|||||||
*/
|
*/
|
||||||
public $isTerminal = false;
|
public $isTerminal = false;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var bool
|
||||||
|
*/
|
||||||
|
public $isContinuous = false;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @var float
|
* @var float
|
||||||
*/
|
*/
|
||||||
@ -62,7 +66,7 @@ class DecisionTreeLeaf
|
|||||||
public function evaluate($record)
|
public function evaluate($record)
|
||||||
{
|
{
|
||||||
$recordField = $record[$this->columnIndex];
|
$recordField = $record[$this->columnIndex];
|
||||||
if (is_string($this->value) && preg_match("/^([<>=]{1,2})\s*(.*)/", $this->value, $matches)) {
|
if ($this->isContinuous && preg_match("/^([<>=]{1,2})\s*(.*)/", strval($this->value), $matches)) {
|
||||||
$op = $matches[1];
|
$op = $matches[1];
|
||||||
$value= floatval($matches[2]);
|
$value= floatval($matches[2]);
|
||||||
$recordField = strval($recordField);
|
$recordField = strval($recordField);
|
||||||
@ -72,13 +76,51 @@ class DecisionTreeLeaf
|
|||||||
return $recordField == $this->value;
|
return $recordField == $this->value;
|
||||||
}
|
}
|
||||||
|
|
||||||
public function __toString()
|
/**
|
||||||
|
* Returns Mean Decrease Impurity (MDI) in the node.
|
||||||
|
* For terminal nodes, this value is equal to 0
|
||||||
|
*
|
||||||
|
* @return float
|
||||||
|
*/
|
||||||
|
public function getNodeImpurityDecrease(int $parentRecordCount)
|
||||||
|
{
|
||||||
|
if ($this->isTerminal) {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
$nodeSampleCount = (float)count($this->records);
|
||||||
|
$iT = $this->giniIndex;
|
||||||
|
|
||||||
|
if ($this->leftLeaf) {
|
||||||
|
$pL = count($this->leftLeaf->records)/$nodeSampleCount;
|
||||||
|
$iT -= $pL * $this->leftLeaf->giniIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($this->rightLeaf) {
|
||||||
|
$pR = count($this->rightLeaf->records)/$nodeSampleCount;
|
||||||
|
$iT -= $pR * $this->rightLeaf->giniIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
return $iT * $nodeSampleCount / $parentRecordCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns HTML representation of the node including children nodes
|
||||||
|
*
|
||||||
|
* @param $columnNames
|
||||||
|
* @return string
|
||||||
|
*/
|
||||||
|
public function getHTML($columnNames = null)
|
||||||
{
|
{
|
||||||
if ($this->isTerminal) {
|
if ($this->isTerminal) {
|
||||||
$value = "<b>$this->classValue</b>";
|
$value = "<b>$this->classValue</b>";
|
||||||
} else {
|
} else {
|
||||||
$value = $this->value;
|
$value = $this->value;
|
||||||
$col = "col_$this->columnIndex";
|
if ($columnNames !== null) {
|
||||||
|
$col = $columnNames[$this->columnIndex];
|
||||||
|
} else {
|
||||||
|
$col = "col_$this->columnIndex";
|
||||||
|
}
|
||||||
if (! preg_match("/^[<>=]{1,2}/", $value)) {
|
if (! preg_match("/^[<>=]{1,2}/", $value)) {
|
||||||
$value = "=$value";
|
$value = "=$value";
|
||||||
}
|
}
|
||||||
@ -89,13 +131,13 @@ class DecisionTreeLeaf
|
|||||||
if ($this->leftLeaf || $this->rightLeaf) {
|
if ($this->leftLeaf || $this->rightLeaf) {
|
||||||
$str .='<tr>';
|
$str .='<tr>';
|
||||||
if ($this->leftLeaf) {
|
if ($this->leftLeaf) {
|
||||||
$str .="<td valign=top><b>| Yes</b><br>$this->leftLeaf</td>";
|
$str .="<td valign=top><b>| Yes</b><br>" . $this->leftLeaf->getHTML($columnNames) . "</td>";
|
||||||
} else {
|
} else {
|
||||||
$str .='<td></td>';
|
$str .='<td></td>';
|
||||||
}
|
}
|
||||||
$str .='<td> </td>';
|
$str .='<td> </td>';
|
||||||
if ($this->rightLeaf) {
|
if ($this->rightLeaf) {
|
||||||
$str .="<td valign=top align=right><b>No |</b><br>$this->rightLeaf</td>";
|
$str .="<td valign=top align=right><b>No |</b><br>" . $this->rightLeaf->getHTML($columnNames) . "</td>";
|
||||||
} else {
|
} else {
|
||||||
$str .='<td></td>';
|
$str .='<td></td>';
|
||||||
}
|
}
|
||||||
@ -104,4 +146,14 @@ class DecisionTreeLeaf
|
|||||||
$str .= '</table>';
|
$str .= '</table>';
|
||||||
return $str;
|
return $str;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* HTML representation of the tree without column names
|
||||||
|
*
|
||||||
|
* @return string
|
||||||
|
*/
|
||||||
|
public function __toString()
|
||||||
|
{
|
||||||
|
return $this->getHTML();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -53,7 +53,7 @@ class Bagging implements Classifier
|
|||||||
/**
|
/**
|
||||||
* @var float
|
* @var float
|
||||||
*/
|
*/
|
||||||
protected $subsetRatio = 0.5;
|
protected $subsetRatio = 0.7;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @var array
|
* @var array
|
||||||
@ -120,7 +120,7 @@ class Bagging implements Classifier
|
|||||||
$this->featureCount = count($samples[0]);
|
$this->featureCount = count($samples[0]);
|
||||||
$this->numSamples = count($this->samples);
|
$this->numSamples = count($this->samples);
|
||||||
|
|
||||||
// Init classifiers and train them with random sub-samples
|
// Init classifiers and train them with bootstrap samples
|
||||||
$this->classifiers = $this->initClassifiers();
|
$this->classifiers = $this->initClassifiers();
|
||||||
$index = 0;
|
$index = 0;
|
||||||
foreach ($this->classifiers as $classifier) {
|
foreach ($this->classifiers as $classifier) {
|
||||||
@ -134,16 +134,14 @@ class Bagging implements Classifier
|
|||||||
* @param int $index
|
* @param int $index
|
||||||
* @return array
|
* @return array
|
||||||
*/
|
*/
|
||||||
protected function getRandomSubset($index)
|
protected function getRandomSubset(int $index)
|
||||||
{
|
{
|
||||||
$subsetLength = (int)ceil(sqrt($this->numSamples));
|
|
||||||
$denom = $this->subsetRatio / 2;
|
|
||||||
$subsetLength = $this->numSamples / (1 / $denom);
|
|
||||||
$index = $index * $subsetLength % $this->numSamples;
|
|
||||||
$samples = [];
|
$samples = [];
|
||||||
$targets = [];
|
$targets = [];
|
||||||
for ($i=0; $i<$subsetLength * 2; $i++) {
|
srand($index);
|
||||||
$rand = rand($index, $this->numSamples - 1);
|
$bootstrapSize = $this->subsetRatio * $this->numSamples;
|
||||||
|
for ($i=0; $i < $bootstrapSize; $i++) {
|
||||||
|
$rand = rand(0, $this->numSamples - 1);
|
||||||
$samples[] = $this->samples[$rand];
|
$samples[] = $this->samples[$rand];
|
||||||
$targets[] = $this->targets[$rand];
|
$targets[] = $this->targets[$rand];
|
||||||
}
|
}
|
||||||
|
@ -16,6 +16,18 @@ class RandomForest extends Bagging
|
|||||||
*/
|
*/
|
||||||
protected $featureSubsetRatio = 'log';
|
protected $featureSubsetRatio = 'log';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
protected $columnNames = null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Initializes RandomForest with the given number of trees. More trees
|
||||||
|
* may increase the prediction performance while it will also substantially
|
||||||
|
* increase the processing time and the required memory
|
||||||
|
*
|
||||||
|
* @param type $numClassifier
|
||||||
|
*/
|
||||||
public function __construct($numClassifier = 50)
|
public function __construct($numClassifier = 50)
|
||||||
{
|
{
|
||||||
parent::__construct($numClassifier);
|
parent::__construct($numClassifier);
|
||||||
@ -24,14 +36,13 @@ class RandomForest extends Bagging
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method is used to determine how much of the original columns (features)
|
* This method is used to determine how many of the original columns (features)
|
||||||
* will be used to construct subsets to train base classifiers.<br>
|
* will be used to construct subsets to train base classifiers.<br>
|
||||||
*
|
*
|
||||||
* Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
|
* Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
|
||||||
*
|
*
|
||||||
* If there are many features that diminishes classification performance, then
|
* Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
|
||||||
* small values should be preferred, otherwise, with low number of features,
|
* features to be taken into consideration while selecting subspace of features
|
||||||
* default value (0.7) will result in satisfactory performance.
|
|
||||||
*
|
*
|
||||||
* @param mixed $ratio string or float should be given
|
* @param mixed $ratio string or float should be given
|
||||||
* @return $this
|
* @return $this
|
||||||
@ -65,6 +76,55 @@ class RandomForest extends Bagging
|
|||||||
return parent::setClassifer($classifier, $classifierOptions);
|
return parent::setClassifer($classifier, $classifierOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This will return an array including an importance value for
|
||||||
|
* each column in the given dataset. Importance values for a column
|
||||||
|
* is the average importance of that column in all trees in the forest
|
||||||
|
*
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
public function getFeatureImportances()
|
||||||
|
{
|
||||||
|
// Traverse each tree and sum importance of the columns
|
||||||
|
$sum = [];
|
||||||
|
foreach ($this->classifiers as $tree) {
|
||||||
|
/* @var $tree DecisionTree */
|
||||||
|
$importances = $tree->getFeatureImportances();
|
||||||
|
|
||||||
|
foreach ($importances as $column => $importance) {
|
||||||
|
if (array_key_exists($column, $sum)) {
|
||||||
|
$sum[$column] += $importance;
|
||||||
|
} else {
|
||||||
|
$sum[$column] = $importance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize & sort the importance values
|
||||||
|
$total = array_sum($sum);
|
||||||
|
foreach ($sum as &$importance) {
|
||||||
|
$importance /= $total;
|
||||||
|
}
|
||||||
|
|
||||||
|
arsort($sum);
|
||||||
|
|
||||||
|
return $sum;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A string array to represent the columns is given. They are useful
|
||||||
|
* when trying to print some information about the trees such as feature importances
|
||||||
|
*
|
||||||
|
* @param array $names
|
||||||
|
* @return $this
|
||||||
|
*/
|
||||||
|
public function setColumnNames(array $names)
|
||||||
|
{
|
||||||
|
$this->columnNames = $names;
|
||||||
|
|
||||||
|
return $this;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param DecisionTree $classifier
|
* @param DecisionTree $classifier
|
||||||
* @param int $index
|
* @param int $index
|
||||||
@ -84,6 +144,12 @@ class RandomForest extends Bagging
|
|||||||
$featureCount = $this->featureCount;
|
$featureCount = $this->featureCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
return $classifier->setNumFeatures($featureCount);
|
if ($this->columnNames === null) {
|
||||||
|
$this->columnNames = range(0, $this->featureCount - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return $classifier
|
||||||
|
->setColumnNames($this->columnNames)
|
||||||
|
->setNumFeatures($featureCount);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,11 @@ use Phpml\Exception\FileException;
|
|||||||
|
|
||||||
class CsvDataset extends ArrayDataset
|
class CsvDataset extends ArrayDataset
|
||||||
{
|
{
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
protected $columnNames;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param string $filepath
|
* @param string $filepath
|
||||||
* @param int $features
|
* @param int $features
|
||||||
@ -26,7 +31,10 @@ class CsvDataset extends ArrayDataset
|
|||||||
}
|
}
|
||||||
|
|
||||||
if ($headingRow) {
|
if ($headingRow) {
|
||||||
fgets($handle);
|
$data = fgetcsv($handle, 1000, ',');
|
||||||
|
$this->columnNames = array_slice($data, 0, $features);
|
||||||
|
} else {
|
||||||
|
$this->columnNames = range(0, $features - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
while (($data = fgetcsv($handle, 1000, ',')) !== false) {
|
while (($data = fgetcsv($handle, 1000, ',')) !== false) {
|
||||||
@ -35,4 +43,12 @@ class CsvDataset extends ArrayDataset
|
|||||||
}
|
}
|
||||||
fclose($handle);
|
fclose($handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
public function getColumnNames()
|
||||||
|
{
|
||||||
|
return $this->columnNames;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user