2017-02-07 11:37:56 +00:00
|
|
|
<?php
|
|
|
|
|
|
|
|
declare(strict_types=1);
|
|
|
|
|
|
|
|
namespace Phpml\Classification\Ensemble;
|
|
|
|
|
2017-11-22 21:16:10 +00:00
|
|
|
use Phpml\Classification\Classifier;
|
2017-02-07 11:37:56 +00:00
|
|
|
use Phpml\Classification\DecisionTree;
|
2018-03-04 16:02:36 +00:00
|
|
|
use Phpml\Exception\InvalidArgumentException;
|
2017-02-07 11:37:56 +00:00
|
|
|
|
|
|
|
/**
 * Bagging ensemble that only accepts DecisionTree as its base classifier and
 * hands every tree a limited number of features to work with.
 */
class RandomForest extends Bagging
{
    /**
     * How many features each tree may use: either the string 'sqrt' or 'log',
     * or a float ratio between 0.1 and 1.0 of the total feature count.
     *
     * @var float|string
     */
    protected $featureSubsetRatio = 'log';

    /**
     * Column labels forwarded to every tree. When still null at the time a
     * tree is initialised, it is filled with the indices 0..featureCount-1.
     *
     * @var array|null
     */
    protected $columnNames;

    /**
     * Initializes RandomForest with the given number of trees. More trees
     * may increase the prediction performance while it will also substantially
     * increase the processing time and the required memory
     */
    public function __construct(int $numClassifier = 50)
    {
        parent::__construct($numClassifier);

        // Inherited setter from Bagging; the forest always configures it with 1.0.
        $this->setSubsetRatio(1.0);
    }

    /**
     * This method is used to determine how many of the original columns (features)
     * will be used to construct subsets to train base classifiers.
     *
     * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0
     *
     * Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
     * features to be taken into consideration while selecting subspace of features
     *
     * @param string|float $ratio
     *
     * @throws InvalidArgumentException when $ratio is neither a string nor a float,
     *                                  is a float outside [0.1, 1.0], or is a string
     *                                  other than 'sqrt' / 'log'
     */
    public function setFeatureSubsetRatio($ratio): self
    {
        if (!is_string($ratio) && !is_float($ratio)) {
            throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
        }

        if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
            throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
        }

        if (is_string($ratio) && $ratio !== 'sqrt' && $ratio !== 'log') {
            throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
        }

        $this->featureSubsetRatio = $ratio;

        return $this;
    }

    /**
     * RandomForest algorithm is usable *only* with DecisionTree
     *
     * @throws InvalidArgumentException when $classifier is not DecisionTree::class
     *
     * @return $this
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        if ($classifier !== DecisionTree::class) {
            throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
        }

        return parent::setClassifer($classifier, $classifierOptions);
    }

    /**
     * This will return an array including an importance value for
     * each column in the given dataset. Importance values for a column
     * is the average importance of that column in all trees in the forest
     *
     * The result is keyed by column and sorted in descending order of
     * importance. When the summed importance over all trees is zero the
     * values are returned un-normalised instead of being divided by zero.
     */
    public function getFeatureImportances(): array
    {
        // Traverse each tree and sum importance of the columns
        $sum = [];
        foreach ($this->classifiers as $tree) {
            /** @var DecisionTree $tree */
            $importances = $tree->getFeatureImportances();

            foreach ($importances as $column => $importance) {
                if (array_key_exists($column, $sum)) {
                    $sum[$column] += $importance;
                } else {
                    $sum[$column] = $importance;
                }
            }
        }

        // Normalize & sort the importance values. Normalisation is skipped for
        // a zero total: dividing would raise DivisionByZeroError on PHP 8.
        $total = array_sum($sum);
        if ((float) $total !== 0.0) {
            array_walk($sum, function (&$importance) use ($total): void {
                $importance /= $total;
            });
        }

        arsort($sum);

        return $sum;
    }

    /**
     * A string array to represent the columns is given. They are useful
     * when trying to print some information about the trees such as feature importances
     *
     * @return $this
     */
    public function setColumnNames(array $names)
    {
        $this->columnNames = $names;

        return $this;
    }

    /**
     * Configures a single tree with the forest's column names and with the
     * number of features derived from the configured feature subset ratio.
     *
     * @param DecisionTree $classifier
     *
     * @return DecisionTree
     */
    protected function initSingleClassifier(Classifier $classifier): Classifier
    {
        if (is_float($this->featureSubsetRatio)) {
            $featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
        } elseif ($this->featureSubsetRatio === 'sqrt') {
            $featureCount = (int) sqrt($this->featureCount) + 1;
        } else {
            $featureCount = (int) log($this->featureCount, 2) + 1;
        }

        // A small ratio on a narrow dataset truncates to 0 (e.g. 0.1 * 5), so
        // ask for at least one feature, but never more than actually exist.
        $featureCount = min(max(1, $featureCount), $this->featureCount);

        if ($this->columnNames === null) {
            $this->columnNames = range(0, $this->featureCount - 1);
        }

        return $classifier
            ->setColumnNames($this->columnNames)
            ->setNumFeatures($featureCount);
    }
}
|