mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-09-22 12:19:02 +00:00
Add typehints to DecisionTree
This commit is contained in:
parent
01bb82a2a7
commit
c6fbb83573
@ -4,6 +4,7 @@ declare(strict_types=1);
|
|||||||
|
|
||||||
namespace Phpml\Classification;
|
namespace Phpml\Classification;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
use Phpml\Helper\Predictable;
|
use Phpml\Helper\Predictable;
|
||||||
use Phpml\Helper\Trainable;
|
use Phpml\Helper\Trainable;
|
||||||
use Phpml\Math\Statistic\Mean;
|
use Phpml\Math\Statistic\Mean;
|
||||||
@ -13,7 +14,7 @@ class DecisionTree implements Classifier
|
|||||||
{
|
{
|
||||||
use Trainable, Predictable;
|
use Trainable, Predictable;
|
||||||
|
|
||||||
const CONTINUOS = 1;
|
const CONTINUOUS = 1;
|
||||||
const NOMINAL = 2;
|
const NOMINAL = 2;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -70,7 +71,7 @@ class DecisionTree implements Classifier
|
|||||||
/**
|
/**
|
||||||
* @param int $maxDepth
|
* @param int $maxDepth
|
||||||
*/
|
*/
|
||||||
public function __construct($maxDepth = 10)
|
public function __construct(int $maxDepth = 10)
|
||||||
{
|
{
|
||||||
$this->maxDepth = $maxDepth;
|
$this->maxDepth = $maxDepth;
|
||||||
}
|
}
|
||||||
@ -85,7 +86,7 @@ class DecisionTree implements Classifier
|
|||||||
$this->targets = array_merge($this->targets, $targets);
|
$this->targets = array_merge($this->targets, $targets);
|
||||||
|
|
||||||
$this->featureCount = count($this->samples[0]);
|
$this->featureCount = count($this->samples[0]);
|
||||||
$this->columnTypes = $this->getColumnTypes($this->samples);
|
$this->columnTypes = self::getColumnTypes($this->samples);
|
||||||
$this->labels = array_keys(array_count_values($this->targets));
|
$this->labels = array_keys(array_count_values($this->targets));
|
||||||
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
|
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
|
||||||
|
|
||||||
@ -105,23 +106,29 @@ class DecisionTree implements Classifier
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static function getColumnTypes(array $samples)
|
/**
|
||||||
|
* @param array $samples
|
||||||
|
* @return array
|
||||||
|
*/
|
||||||
|
public static function getColumnTypes(array $samples) : array
|
||||||
{
|
{
|
||||||
$types = [];
|
$types = [];
|
||||||
$featureCount = count($samples[0]);
|
$featureCount = count($samples[0]);
|
||||||
for ($i=0; $i < $featureCount; $i++) {
|
for ($i=0; $i < $featureCount; $i++) {
|
||||||
$values = array_column($samples, $i);
|
$values = array_column($samples, $i);
|
||||||
$isCategorical = self::isCategoricalColumn($values);
|
$isCategorical = self::isCategoricalColumn($values);
|
||||||
$types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS;
|
$types[] = $isCategorical ? self::NOMINAL : self::CONTINUOUS;
|
||||||
}
|
}
|
||||||
|
|
||||||
return $types;
|
return $types;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param null|array $records
|
* @param array $records
|
||||||
|
* @param int $depth
|
||||||
* @return DecisionTreeLeaf
|
* @return DecisionTreeLeaf
|
||||||
*/
|
*/
|
||||||
protected function getSplitLeaf($records, $depth = 0)
|
protected function getSplitLeaf(array $records, int $depth = 0) : DecisionTreeLeaf
|
||||||
{
|
{
|
||||||
$split = $this->getBestSplit($records);
|
$split = $this->getBestSplit($records);
|
||||||
$split->level = $depth;
|
$split->level = $depth;
|
||||||
@ -163,7 +170,7 @@ class DecisionTree implements Classifier
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (count($remainingTargets) == 1 || $allSame || $depth >= $this->maxDepth) {
|
if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
|
||||||
$split->isTerminal = 1;
|
$split->isTerminal = 1;
|
||||||
arsort($remainingTargets);
|
arsort($remainingTargets);
|
||||||
$split->classValue = key($remainingTargets);
|
$split->classValue = key($remainingTargets);
|
||||||
@ -175,14 +182,15 @@ class DecisionTree implements Classifier
|
|||||||
$split->rightLeaf= $this->getSplitLeaf($rightRecords, $depth + 1);
|
$split->rightLeaf= $this->getSplitLeaf($rightRecords, $depth + 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return $split;
|
return $split;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param array $records
|
* @param array $records
|
||||||
* @return DecisionTreeLeaf[]
|
* @return DecisionTreeLeaf
|
||||||
*/
|
*/
|
||||||
protected function getBestSplit($records)
|
protected function getBestSplit(array $records) : DecisionTreeLeaf
|
||||||
{
|
{
|
||||||
$targets = array_intersect_key($this->targets, array_flip($records));
|
$targets = array_intersect_key($this->targets, array_flip($records));
|
||||||
$samples = array_intersect_key($this->samples, array_flip($records));
|
$samples = array_intersect_key($this->samples, array_flip($records));
|
||||||
@ -199,18 +207,18 @@ class DecisionTree implements Classifier
|
|||||||
arsort($counts);
|
arsort($counts);
|
||||||
$baseValue = key($counts);
|
$baseValue = key($counts);
|
||||||
$gini = $this->getGiniIndex($baseValue, $colValues, $targets);
|
$gini = $this->getGiniIndex($baseValue, $colValues, $targets);
|
||||||
if ($bestSplit == null || $bestGiniVal > $gini) {
|
if ($bestSplit === null || $bestGiniVal > $gini) {
|
||||||
$split = new DecisionTreeLeaf();
|
$split = new DecisionTreeLeaf();
|
||||||
$split->value = $baseValue;
|
$split->value = $baseValue;
|
||||||
$split->giniIndex = $gini;
|
$split->giniIndex = $gini;
|
||||||
$split->columnIndex = $i;
|
$split->columnIndex = $i;
|
||||||
$split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS;
|
$split->isContinuous = $this->columnTypes[$i] == self::CONTINUOUS;
|
||||||
$split->records = $records;
|
$split->records = $records;
|
||||||
|
|
||||||
// If a numeric column is to be selected, then
|
// If a numeric column is to be selected, then
|
||||||
// the original numeric value and the selected operator
|
// the original numeric value and the selected operator
|
||||||
// will also be saved into the leaf for future access
|
// will also be saved into the leaf for future access
|
||||||
if ($this->columnTypes[$i] == self::CONTINUOS) {
|
if ($this->columnTypes[$i] == self::CONTINUOUS) {
|
||||||
$matches = [];
|
$matches = [];
|
||||||
preg_match("/^([<>=]{1,2})\s*(.*)/", strval($split->value), $matches);
|
preg_match("/^([<>=]{1,2})\s*(.*)/", strval($split->value), $matches);
|
||||||
$split->operator = $matches[1];
|
$split->operator = $matches[1];
|
||||||
@ -221,6 +229,7 @@ class DecisionTree implements Classifier
|
|||||||
$bestGiniVal = $gini;
|
$bestGiniVal = $gini;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return $bestSplit;
|
return $bestSplit;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -239,10 +248,10 @@ class DecisionTree implements Classifier
|
|||||||
*
|
*
|
||||||
* @return array
|
* @return array
|
||||||
*/
|
*/
|
||||||
protected function getSelectedFeatures()
|
protected function getSelectedFeatures() : array
|
||||||
{
|
{
|
||||||
$allFeatures = range(0, $this->featureCount - 1);
|
$allFeatures = range(0, $this->featureCount - 1);
|
||||||
if ($this->numUsableFeatures == 0 && ! $this->selectedFeatures) {
|
if ($this->numUsableFeatures === 0 && ! $this->selectedFeatures) {
|
||||||
return $allFeatures;
|
return $allFeatures;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -262,11 +271,12 @@ class DecisionTree implements Classifier
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param string $baseValue
|
* @param $baseValue
|
||||||
* @param array $colValues
|
* @param array $colValues
|
||||||
* @param array $targets
|
* @param array $targets
|
||||||
|
* @return float
|
||||||
*/
|
*/
|
||||||
public function getGiniIndex($baseValue, $colValues, $targets)
|
public function getGiniIndex($baseValue, array $colValues, array $targets) : float
|
||||||
{
|
{
|
||||||
$countMatrix = [];
|
$countMatrix = [];
|
||||||
foreach ($this->labels as $label) {
|
foreach ($this->labels as $label) {
|
||||||
@ -274,7 +284,7 @@ class DecisionTree implements Classifier
|
|||||||
}
|
}
|
||||||
foreach ($colValues as $index => $value) {
|
foreach ($colValues as $index => $value) {
|
||||||
$label = $targets[$index];
|
$label = $targets[$index];
|
||||||
$rowIndex = $value == $baseValue ? 0 : 1;
|
$rowIndex = $value === $baseValue ? 0 : 1;
|
||||||
$countMatrix[$label][$rowIndex]++;
|
$countMatrix[$label][$rowIndex]++;
|
||||||
}
|
}
|
||||||
$giniParts = [0, 0];
|
$giniParts = [0, 0];
|
||||||
@ -288,6 +298,7 @@ class DecisionTree implements Classifier
|
|||||||
}
|
}
|
||||||
$giniParts[$i] = (1 - $part) * $sum;
|
$giniParts[$i] = (1 - $part) * $sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
return array_sum($giniParts) / count($colValues);
|
return array_sum($giniParts) / count($colValues);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -295,14 +306,14 @@ class DecisionTree implements Classifier
|
|||||||
* @param array $samples
|
* @param array $samples
|
||||||
* @return array
|
* @return array
|
||||||
*/
|
*/
|
||||||
protected function preprocess(array $samples)
|
protected function preprocess(array $samples) : array
|
||||||
{
|
{
|
||||||
// Detect and convert continuous data column values into
|
// Detect and convert continuous data column values into
|
||||||
// discrete values by using the median as a threshold value
|
// discrete values by using the median as a threshold value
|
||||||
$columns = [];
|
$columns = [];
|
||||||
for ($i=0; $i<$this->featureCount; $i++) {
|
for ($i=0; $i<$this->featureCount; $i++) {
|
||||||
$values = array_column($samples, $i);
|
$values = array_column($samples, $i);
|
||||||
if ($this->columnTypes[$i] == self::CONTINUOS) {
|
if ($this->columnTypes[$i] == self::CONTINUOUS) {
|
||||||
$median = Mean::median($values);
|
$median = Mean::median($values);
|
||||||
foreach ($values as &$value) {
|
foreach ($values as &$value) {
|
||||||
if ($value <= $median) {
|
if ($value <= $median) {
|
||||||
@ -323,7 +334,7 @@ class DecisionTree implements Classifier
|
|||||||
* @param array $columnValues
|
* @param array $columnValues
|
||||||
* @return bool
|
* @return bool
|
||||||
*/
|
*/
|
||||||
protected static function isCategoricalColumn(array $columnValues)
|
protected static function isCategoricalColumn(array $columnValues) : bool
|
||||||
{
|
{
|
||||||
$count = count($columnValues);
|
$count = count($columnValues);
|
||||||
|
|
||||||
@ -337,15 +348,13 @@ class DecisionTree implements Classifier
|
|||||||
if ($floatValues) {
|
if ($floatValues) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (count($numericValues) != $count) {
|
if (count($numericValues) !== $count) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
$distinctValues = array_count_values($columnValues);
|
$distinctValues = array_count_values($columnValues);
|
||||||
if (count($distinctValues) <= $count / 5) {
|
|
||||||
return true;
|
return count($distinctValues) <= $count / 5;
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -357,12 +366,12 @@ class DecisionTree implements Classifier
|
|||||||
*
|
*
|
||||||
* @param int $numFeatures
|
* @param int $numFeatures
|
||||||
* @return $this
|
* @return $this
|
||||||
* @throws Exception
|
* @throws InvalidArgumentException
|
||||||
*/
|
*/
|
||||||
public function setNumFeatures(int $numFeatures)
|
public function setNumFeatures(int $numFeatures)
|
||||||
{
|
{
|
||||||
if ($numFeatures < 0) {
|
if ($numFeatures < 0) {
|
||||||
throw new \Exception("Selected column count should be greater or equal to zero");
|
throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
|
||||||
}
|
}
|
||||||
|
|
||||||
$this->numUsableFeatures = $numFeatures;
|
$this->numUsableFeatures = $numFeatures;
|
||||||
@ -386,11 +395,12 @@ class DecisionTree implements Classifier
|
|||||||
*
|
*
|
||||||
* @param array $names
|
* @param array $names
|
||||||
* @return $this
|
* @return $this
|
||||||
|
* @throws InvalidArgumentException
|
||||||
*/
|
*/
|
||||||
public function setColumnNames(array $names)
|
public function setColumnNames(array $names)
|
||||||
{
|
{
|
||||||
if ($this->featureCount != 0 && count($names) != $this->featureCount) {
|
if ($this->featureCount !== 0 && count($names) !== $this->featureCount) {
|
||||||
throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)");
|
throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
|
||||||
}
|
}
|
||||||
|
|
||||||
$this->columnNames = $names;
|
$this->columnNames = $names;
|
||||||
@ -411,7 +421,6 @@ class DecisionTree implements Classifier
|
|||||||
* each column in the given dataset. The importance values are
|
* each column in the given dataset. The importance values are
|
||||||
* normalized and their total makes 1.<br/>
|
* normalized and their total makes 1.<br/>
|
||||||
*
|
*
|
||||||
* @param array $labels
|
|
||||||
* @return array
|
* @return array
|
||||||
*/
|
*/
|
||||||
public function getFeatureImportances()
|
public function getFeatureImportances()
|
||||||
@ -447,22 +456,20 @@ class DecisionTree implements Classifier
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Collects and returns an array of internal nodes that use the given
|
* Collects and returns an array of internal nodes that use the given
|
||||||
* column as a split criteron
|
* column as a split criterion
|
||||||
*
|
*
|
||||||
* @param int $column
|
* @param int $column
|
||||||
* @param DecisionTreeLeaf
|
* @param DecisionTreeLeaf $node
|
||||||
* @param array $collected
|
|
||||||
*
|
|
||||||
* @return array
|
* @return array
|
||||||
*/
|
*/
|
||||||
protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node)
|
protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node) : array
|
||||||
{
|
{
|
||||||
if (!$node || $node->isTerminal) {
|
if (!$node || $node->isTerminal) {
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
$nodes = [];
|
$nodes = [];
|
||||||
if ($node->columnIndex == $column) {
|
if ($node->columnIndex === $column) {
|
||||||
$nodes[] = $node;
|
$nodes[] = $node;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -135,7 +135,7 @@ class DecisionStump extends WeightedClassifier
|
|||||||
'prob' => [], 'column' => 0,
|
'prob' => [], 'column' => 0,
|
||||||
'trainingErrorRate' => 1.0];
|
'trainingErrorRate' => 1.0];
|
||||||
foreach ($columns as $col) {
|
foreach ($columns as $col) {
|
||||||
if ($this->columnTypes[$col] == DecisionTree::CONTINUOS) {
|
if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
|
||||||
$split = $this->getBestNumericalSplit($col);
|
$split = $this->getBestNumericalSplit($col);
|
||||||
} else {
|
} else {
|
||||||
$split = $this->getBestNominalSplit($col);
|
$split = $this->getBestNominalSplit($col);
|
||||||
|
@ -6,7 +6,10 @@ namespace Phpml\Classification;
|
|||||||
|
|
||||||
abstract class WeightedClassifier implements Classifier
|
abstract class WeightedClassifier implements Classifier
|
||||||
{
|
{
|
||||||
protected $weights = null;
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
protected $weights;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sets the array including a weight for each sample
|
* Sets the array including a weight for each sample
|
||||||
|
Loading…
Reference in New Issue
Block a user