Add typehints to DecisionTree

This commit is contained in:
Arkadiusz Kondas 2017-03-05 16:25:01 +01:00
parent 01bb82a2a7
commit c6fbb83573
3 changed files with 49 additions and 39 deletions

View File

@ -4,6 +4,7 @@ declare(strict_types=1);
namespace Phpml\Classification; namespace Phpml\Classification;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable; use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable; use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean; use Phpml\Math\Statistic\Mean;
@ -13,7 +14,7 @@ class DecisionTree implements Classifier
{ {
use Trainable, Predictable; use Trainable, Predictable;
const CONTINUOS = 1; const CONTINUOUS = 1;
const NOMINAL = 2; const NOMINAL = 2;
/** /**
@ -70,7 +71,7 @@ class DecisionTree implements Classifier
/** /**
* @param int $maxDepth * @param int $maxDepth
*/ */
public function __construct($maxDepth = 10) public function __construct(int $maxDepth = 10)
{ {
$this->maxDepth = $maxDepth; $this->maxDepth = $maxDepth;
} }
@ -85,7 +86,7 @@ class DecisionTree implements Classifier
$this->targets = array_merge($this->targets, $targets); $this->targets = array_merge($this->targets, $targets);
$this->featureCount = count($this->samples[0]); $this->featureCount = count($this->samples[0]);
$this->columnTypes = $this->getColumnTypes($this->samples); $this->columnTypes = self::getColumnTypes($this->samples);
$this->labels = array_keys(array_count_values($this->targets)); $this->labels = array_keys(array_count_values($this->targets));
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1)); $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
@ -105,23 +106,29 @@ class DecisionTree implements Classifier
} }
} }
public static function getColumnTypes(array $samples) /**
* @param array $samples
* @return array
*/
public static function getColumnTypes(array $samples) : array
{ {
$types = []; $types = [];
$featureCount = count($samples[0]); $featureCount = count($samples[0]);
for ($i=0; $i < $featureCount; $i++) { for ($i=0; $i < $featureCount; $i++) {
$values = array_column($samples, $i); $values = array_column($samples, $i);
$isCategorical = self::isCategoricalColumn($values); $isCategorical = self::isCategoricalColumn($values);
$types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS; $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOUS;
} }
return $types; return $types;
} }
/** /**
* @param null|array $records * @param array $records
* @param int $depth
* @return DecisionTreeLeaf * @return DecisionTreeLeaf
*/ */
protected function getSplitLeaf($records, $depth = 0) protected function getSplitLeaf(array $records, int $depth = 0) : DecisionTreeLeaf
{ {
$split = $this->getBestSplit($records); $split = $this->getBestSplit($records);
$split->level = $depth; $split->level = $depth;
@ -163,7 +170,7 @@ class DecisionTree implements Classifier
} }
} }
if (count($remainingTargets) == 1 || $allSame || $depth >= $this->maxDepth) { if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
$split->isTerminal = 1; $split->isTerminal = 1;
arsort($remainingTargets); arsort($remainingTargets);
$split->classValue = key($remainingTargets); $split->classValue = key($remainingTargets);
@ -175,14 +182,15 @@ class DecisionTree implements Classifier
$split->rightLeaf= $this->getSplitLeaf($rightRecords, $depth + 1); $split->rightLeaf= $this->getSplitLeaf($rightRecords, $depth + 1);
} }
} }
return $split; return $split;
} }
/** /**
* @param array $records * @param array $records
* @return DecisionTreeLeaf[] * @return DecisionTreeLeaf
*/ */
protected function getBestSplit($records) protected function getBestSplit(array $records) : DecisionTreeLeaf
{ {
$targets = array_intersect_key($this->targets, array_flip($records)); $targets = array_intersect_key($this->targets, array_flip($records));
$samples = array_intersect_key($this->samples, array_flip($records)); $samples = array_intersect_key($this->samples, array_flip($records));
@ -199,18 +207,18 @@ class DecisionTree implements Classifier
arsort($counts); arsort($counts);
$baseValue = key($counts); $baseValue = key($counts);
$gini = $this->getGiniIndex($baseValue, $colValues, $targets); $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
if ($bestSplit == null || $bestGiniVal > $gini) { if ($bestSplit === null || $bestGiniVal > $gini) {
$split = new DecisionTreeLeaf(); $split = new DecisionTreeLeaf();
$split->value = $baseValue; $split->value = $baseValue;
$split->giniIndex = $gini; $split->giniIndex = $gini;
$split->columnIndex = $i; $split->columnIndex = $i;
$split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS; $split->isContinuous = $this->columnTypes[$i] == self::CONTINUOUS;
$split->records = $records; $split->records = $records;
// If a numeric column is to be selected, then // If a numeric column is to be selected, then
// the original numeric value and the selected operator // the original numeric value and the selected operator
// will also be saved into the leaf for future access // will also be saved into the leaf for future access
if ($this->columnTypes[$i] == self::CONTINUOS) { if ($this->columnTypes[$i] == self::CONTINUOUS) {
$matches = []; $matches = [];
preg_match("/^([<>=]{1,2})\s*(.*)/", strval($split->value), $matches); preg_match("/^([<>=]{1,2})\s*(.*)/", strval($split->value), $matches);
$split->operator = $matches[1]; $split->operator = $matches[1];
@ -221,6 +229,7 @@ class DecisionTree implements Classifier
$bestGiniVal = $gini; $bestGiniVal = $gini;
} }
} }
return $bestSplit; return $bestSplit;
} }
@ -239,10 +248,10 @@ class DecisionTree implements Classifier
* *
* @return array * @return array
*/ */
protected function getSelectedFeatures() protected function getSelectedFeatures() : array
{ {
$allFeatures = range(0, $this->featureCount - 1); $allFeatures = range(0, $this->featureCount - 1);
if ($this->numUsableFeatures == 0 && ! $this->selectedFeatures) { if ($this->numUsableFeatures === 0 && ! $this->selectedFeatures) {
return $allFeatures; return $allFeatures;
} }
@ -262,11 +271,12 @@ class DecisionTree implements Classifier
} }
/** /**
* @param string $baseValue * @param $baseValue
* @param array $colValues * @param array $colValues
* @param array $targets * @param array $targets
* @return float
*/ */
public function getGiniIndex($baseValue, $colValues, $targets) public function getGiniIndex($baseValue, array $colValues, array $targets) : float
{ {
$countMatrix = []; $countMatrix = [];
foreach ($this->labels as $label) { foreach ($this->labels as $label) {
@ -274,7 +284,7 @@ class DecisionTree implements Classifier
} }
foreach ($colValues as $index => $value) { foreach ($colValues as $index => $value) {
$label = $targets[$index]; $label = $targets[$index];
$rowIndex = $value == $baseValue ? 0 : 1; $rowIndex = $value === $baseValue ? 0 : 1;
$countMatrix[$label][$rowIndex]++; $countMatrix[$label][$rowIndex]++;
} }
$giniParts = [0, 0]; $giniParts = [0, 0];
@ -288,6 +298,7 @@ class DecisionTree implements Classifier
} }
$giniParts[$i] = (1 - $part) * $sum; $giniParts[$i] = (1 - $part) * $sum;
} }
return array_sum($giniParts) / count($colValues); return array_sum($giniParts) / count($colValues);
} }
@ -295,14 +306,14 @@ class DecisionTree implements Classifier
* @param array $samples * @param array $samples
* @return array * @return array
*/ */
protected function preprocess(array $samples) protected function preprocess(array $samples) : array
{ {
// Detect and convert continuous data column values into // Detect and convert continuous data column values into
// discrete values by using the median as a threshold value // discrete values by using the median as a threshold value
$columns = []; $columns = [];
for ($i=0; $i<$this->featureCount; $i++) { for ($i=0; $i<$this->featureCount; $i++) {
$values = array_column($samples, $i); $values = array_column($samples, $i);
if ($this->columnTypes[$i] == self::CONTINUOS) { if ($this->columnTypes[$i] == self::CONTINUOUS) {
$median = Mean::median($values); $median = Mean::median($values);
foreach ($values as &$value) { foreach ($values as &$value) {
if ($value <= $median) { if ($value <= $median) {
@ -323,7 +334,7 @@ class DecisionTree implements Classifier
* @param array $columnValues * @param array $columnValues
* @return bool * @return bool
*/ */
protected static function isCategoricalColumn(array $columnValues) protected static function isCategoricalColumn(array $columnValues) : bool
{ {
$count = count($columnValues); $count = count($columnValues);
@ -337,15 +348,13 @@ class DecisionTree implements Classifier
if ($floatValues) { if ($floatValues) {
return false; return false;
} }
if (count($numericValues) != $count) { if (count($numericValues) !== $count) {
return true; return true;
} }
$distinctValues = array_count_values($columnValues); $distinctValues = array_count_values($columnValues);
if (count($distinctValues) <= $count / 5) {
return true; return count($distinctValues) <= $count / 5;
}
return false;
} }
/** /**
@ -357,12 +366,12 @@ class DecisionTree implements Classifier
* *
* @param int $numFeatures * @param int $numFeatures
* @return $this * @return $this
* @throws Exception * @throws InvalidArgumentException
*/ */
public function setNumFeatures(int $numFeatures) public function setNumFeatures(int $numFeatures)
{ {
if ($numFeatures < 0) { if ($numFeatures < 0) {
throw new \Exception("Selected column count should be greater or equal to zero"); throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
} }
$this->numUsableFeatures = $numFeatures; $this->numUsableFeatures = $numFeatures;
@ -386,11 +395,12 @@ class DecisionTree implements Classifier
* *
* @param array $names * @param array $names
* @return $this * @return $this
* @throws InvalidArgumentException
*/ */
public function setColumnNames(array $names) public function setColumnNames(array $names)
{ {
if ($this->featureCount != 0 && count($names) != $this->featureCount) { if ($this->featureCount !== 0 && count($names) !== $this->featureCount) {
throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)"); throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
} }
$this->columnNames = $names; $this->columnNames = $names;
@ -411,7 +421,6 @@ class DecisionTree implements Classifier
* each column in the given dataset. The importance values are * each column in the given dataset. The importance values are
* normalized and their total makes 1.<br/> * normalized and their total makes 1.<br/>
* *
* @param array $labels
* @return array * @return array
*/ */
public function getFeatureImportances() public function getFeatureImportances()
@ -447,22 +456,20 @@ class DecisionTree implements Classifier
/** /**
* Collects and returns an array of internal nodes that use the given * Collects and returns an array of internal nodes that use the given
* column as a split criteron * column as a split criterion
* *
* @param int $column * @param int $column
* @param DecisionTreeLeaf * @param DecisionTreeLeaf $node
* @param array $collected
*
* @return array * @return array
*/ */
protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node) protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node) : array
{ {
if (!$node || $node->isTerminal) { if (!$node || $node->isTerminal) {
return []; return [];
} }
$nodes = []; $nodes = [];
if ($node->columnIndex == $column) { if ($node->columnIndex === $column) {
$nodes[] = $node; $nodes[] = $node;
} }

View File

@ -135,7 +135,7 @@ class DecisionStump extends WeightedClassifier
'prob' => [], 'column' => 0, 'prob' => [], 'column' => 0,
'trainingErrorRate' => 1.0]; 'trainingErrorRate' => 1.0];
foreach ($columns as $col) { foreach ($columns as $col) {
if ($this->columnTypes[$col] == DecisionTree::CONTINUOS) { if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
$split = $this->getBestNumericalSplit($col); $split = $this->getBestNumericalSplit($col);
} else { } else {
$split = $this->getBestNominalSplit($col); $split = $this->getBestNominalSplit($col);

View File

@ -6,7 +6,10 @@ namespace Phpml\Classification;
abstract class WeightedClassifier implements Classifier abstract class WeightedClassifier implements Classifier
{ {
protected $weights = null; /**
* @var array
*/
protected $weights;
/** /**
* Sets the array including a weight for each sample * Sets the array including a weight for each sample