diff --git a/src/Phpml/Classification/DecisionTree.php b/src/Phpml/Classification/DecisionTree.php index 0a70d2f..6e890c9 100644 --- a/src/Phpml/Classification/DecisionTree.php +++ b/src/Phpml/Classification/DecisionTree.php @@ -4,6 +4,7 @@ declare(strict_types=1); namespace Phpml\Classification; +use Phpml\Exception\InvalidArgumentException; use Phpml\Helper\Predictable; use Phpml\Helper\Trainable; use Phpml\Math\Statistic\Mean; @@ -13,7 +14,7 @@ class DecisionTree implements Classifier { use Trainable, Predictable; - const CONTINUOS = 1; + const CONTINUOUS = 1; const NOMINAL = 2; /** @@ -70,7 +71,7 @@ class DecisionTree implements Classifier /** * @param int $maxDepth */ - public function __construct($maxDepth = 10) + public function __construct(int $maxDepth = 10) { $this->maxDepth = $maxDepth; } @@ -85,7 +86,7 @@ class DecisionTree implements Classifier $this->targets = array_merge($this->targets, $targets); $this->featureCount = count($this->samples[0]); - $this->columnTypes = $this->getColumnTypes($this->samples); + $this->columnTypes = self::getColumnTypes($this->samples); $this->labels = array_keys(array_count_values($this->targets)); $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1)); @@ -105,23 +106,29 @@ class DecisionTree implements Classifier } } - public static function getColumnTypes(array $samples) + /** + * @param array $samples + * @return array + */ + public static function getColumnTypes(array $samples) : array { $types = []; $featureCount = count($samples[0]); for ($i=0; $i < $featureCount; $i++) { $values = array_column($samples, $i); $isCategorical = self::isCategoricalColumn($values); - $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOS; + $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOUS; } + return $types; } /** - * @param null|array $records + * @param array $records + * @param int $depth * @return DecisionTreeLeaf */ - protected function getSplitLeaf($records, $depth = 0) + protected function getSplitLeaf(array $records, int $depth = 0) : DecisionTreeLeaf { $split = $this->getBestSplit($records); $split->level = $depth; @@ -163,7 +170,7 @@ class DecisionTree implements Classifier } } - if (count($remainingTargets) == 1 || $allSame || $depth >= $this->maxDepth) { + if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) { $split->isTerminal = 1; arsort($remainingTargets); $split->classValue = key($remainingTargets); @@ -175,14 +182,15 @@ class DecisionTree implements Classifier $split->rightLeaf= $this->getSplitLeaf($rightRecords, $depth + 1); } } + return $split; } /** * @param array $records - * @return DecisionTreeLeaf[] + * @return DecisionTreeLeaf */ - protected function getBestSplit($records) + protected function getBestSplit(array $records) : DecisionTreeLeaf { $targets = array_intersect_key($this->targets, array_flip($records)); $samples = array_intersect_key($this->samples, array_flip($records)); @@ -199,18 +207,18 @@ class DecisionTree implements Classifier arsort($counts); $baseValue = key($counts); $gini = $this->getGiniIndex($baseValue, $colValues, $targets); - if ($bestSplit == null || $bestGiniVal > $gini) { + if ($bestSplit === null || $bestGiniVal > $gini) { $split = new DecisionTreeLeaf(); $split->value = $baseValue; $split->giniIndex = $gini; $split->columnIndex = $i; - $split->isContinuous = $this->columnTypes[$i] == self::CONTINUOS; + $split->isContinuous = $this->columnTypes[$i] == self::CONTINUOUS; $split->records = $records; // If a numeric column is to be selected, then // the original numeric value and the selected operator // will also be saved into the leaf for future access - if ($this->columnTypes[$i] == self::CONTINUOS) { + if ($this->columnTypes[$i] == self::CONTINUOUS) { $matches = []; preg_match("/^([<>=]{1,2})\s*(.*)/", strval($split->value), $matches); $split->operator = $matches[1]; @@ -221,6 +229,7 @@ class DecisionTree implements Classifier $bestGiniVal = $gini; } } + return $bestSplit; } @@ -239,10 +248,10 @@ class DecisionTree implements Classifier * * @return array */ - protected function getSelectedFeatures() + protected function getSelectedFeatures() : array { $allFeatures = range(0, $this->featureCount - 1); - if ($this->numUsableFeatures == 0 && ! $this->selectedFeatures) { + if ($this->numUsableFeatures === 0 && ! $this->selectedFeatures) { return $allFeatures; } @@ -262,11 +271,12 @@ class DecisionTree implements Classifier } /** - * @param string $baseValue + * @param $baseValue * @param array $colValues * @param array $targets + * @return float */ - public function getGiniIndex($baseValue, $colValues, $targets) + public function getGiniIndex($baseValue, array $colValues, array $targets) : float { $countMatrix = []; foreach ($this->labels as $label) { @@ -274,7 +284,7 @@ class DecisionTree implements Classifier } foreach ($colValues as $index => $value) { $label = $targets[$index]; - $rowIndex = $value == $baseValue ? 0 : 1; + $rowIndex = $value === $baseValue ? 0 : 1; $countMatrix[$label][$rowIndex]++; } $giniParts = [0, 0]; @@ -288,6 +298,7 @@ class DecisionTree implements Classifier } $giniParts[$i] = (1 - $part) * $sum; } + return array_sum($giniParts) / count($colValues); } @@ -295,14 +306,14 @@ class DecisionTree implements Classifier * @param array $samples * @return array */ - protected function preprocess(array $samples) + protected function preprocess(array $samples) : array { // Detect and convert continuous data column values into // discrete values by using the median as a threshold value $columns = []; for ($i=0; $i<$this->featureCount; $i++) { $values = array_column($samples, $i); - if ($this->columnTypes[$i] == self::CONTINUOS) { + if ($this->columnTypes[$i] == self::CONTINUOUS) { $median = Mean::median($values); foreach ($values as &$value) { if ($value <= $median) { @@ -323,7 +334,7 @@ class DecisionTree implements Classifier * @param array $columnValues * @return bool */ - protected static function isCategoricalColumn(array $columnValues) + protected static function isCategoricalColumn(array $columnValues) : bool { $count = count($columnValues); @@ -337,15 +348,13 @@ class DecisionTree implements Classifier if ($floatValues) { return false; } - if (count($numericValues) != $count) { + if (count($numericValues) !== $count) { return true; } $distinctValues = array_count_values($columnValues); - if (count($distinctValues) <= $count / 5) { - return true; - } - return false; + + return count($distinctValues) <= $count / 5; } /** @@ -357,12 +366,12 @@ class DecisionTree implements Classifier * * @param int $numFeatures * @return $this - * @throws Exception + * @throws InvalidArgumentException */ public function setNumFeatures(int $numFeatures) { if ($numFeatures < 0) { - throw new \Exception("Selected column count should be greater or equal to zero"); + throw new InvalidArgumentException('Selected column count should be greater or equal to zero'); } $this->numUsableFeatures = $numFeatures; @@ -386,11 +395,12 @@ class DecisionTree implements Classifier * * @param array $names * @return $this + * @throws InvalidArgumentException */ public function setColumnNames(array $names) { - if ($this->featureCount != 0 && count($names) != $this->featureCount) { - throw new \Exception("Length of the given array should be equal to feature count ($this->featureCount)"); + if ($this->featureCount !== 0 && count($names) !== $this->featureCount) { + throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount)); } $this->columnNames = $names; @@ -411,7 +421,6 @@ class DecisionTree implements Classifier * each column in the given dataset. The importance values are * normalized and their total makes 1.
* - * @param array $labels * @return array */ public function getFeatureImportances() @@ -447,22 +456,20 @@ class DecisionTree implements Classifier /** * Collects and returns an array of internal nodes that use the given - * column as a split criteron + * column as a split criterion * * @param int $column - * @param DecisionTreeLeaf - * @param array $collected - * + * @param DecisionTreeLeaf $node * @return array */ - protected function getSplitNodesByColumn($column, DecisionTreeLeaf $node) + protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node) : array { if (!$node || $node->isTerminal) { return []; } $nodes = []; - if ($node->columnIndex == $column) { + if ($node->columnIndex === $column) { $nodes[] = $node; } diff --git a/src/Phpml/Classification/Linear/DecisionStump.php b/src/Phpml/Classification/Linear/DecisionStump.php index de86fe9..8287bbc 100644 --- a/src/Phpml/Classification/Linear/DecisionStump.php +++ b/src/Phpml/Classification/Linear/DecisionStump.php @@ -135,7 +135,7 @@ class DecisionStump extends WeightedClassifier 'prob' => [], 'column' => 0, 'trainingErrorRate' => 1.0]; foreach ($columns as $col) { - if ($this->columnTypes[$col] == DecisionTree::CONTINUOS) { + if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) { $split = $this->getBestNumericalSplit($col); } else { $split = $this->getBestNominalSplit($col); diff --git a/src/Phpml/Classification/WeightedClassifier.php b/src/Phpml/Classification/WeightedClassifier.php index c0ec045..4af3de4 100644 --- a/src/Phpml/Classification/WeightedClassifier.php +++ b/src/Phpml/Classification/WeightedClassifier.php @@ -6,7 +6,10 @@ namespace Phpml\Classification; abstract class WeightedClassifier implements Classifier { - protected $weights = null; + /** + * @var array + */ + protected $weights; /** * Sets the array including a weight for each sample