columns = range(0, $features - 1); $this->maxFeatures = $this->maxFeatures ?? (int) round(sqrt($features)); $this->grow($samples, $targets); $this->columns = []; }

    /**
     * Predict a target value for every sample by descending the trained tree.
     *
     * @param array $samples samples to predict, indexed by integer feature column
     *
     * @return array one entry per sample, in iteration order: the leaf's outcome() when
     *               search() ends on an AverageNode, otherwise null
     *
     * @throws InvalidOperationException when bare() reports that the regressor has not been trained
     */
    public function predict(array $samples)
    {
        if ($this->bare()) {
            throw new InvalidOperationException('Regressor must be trained first');
        }

        $predictions = [];
        foreach ($samples as $sample) {
            $node = $this->search($sample);
            $predictions[] = $node instanceof AverageNode ? $node->outcome() : null;
        }

        return $predictions;
    }

    /**
     * Choose the split that minimises the size-weighted target variance.
     *
     * The column list is shuffled and only the first $this->maxFeatures columns are
     * considered. For each of those columns, every distinct feature value is tried as
     * a threshold (see partition()); the candidate with the lowest splitImpurity() wins.
     * The search stops early as soon as a candidate scores at or below $this->tolerance.
     *
     * @param array $samples training samples reaching this node
     * @param array $targets targets aligned with $samples by key
     */
    protected function split(array $samples, array $targets): DecisionNode
    {
        $bestVariance = INF;
        $bestColumn = $bestValue = null;
        $bestGroups = [];

        shuffle($this->columns);

        foreach (array_slice($this->columns, 0, $this->maxFeatures) as $column) {
            // SORT_REGULAR compares the int|float feature values directly. The default
            // SORT_STRING compares their string casts, which round floats to the
            // `precision` ini setting (14 digits by default) and can merge distinct
            // values, silently dropping candidate thresholds.
            $values = array_unique(array_column($samples, $column), SORT_REGULAR);

            foreach ($values as $value) {
                $groups = $this->partition($column, $value, $samples, $targets);
                $variance = $this->splitImpurity($groups);

                if ($variance < $bestVariance) {
                    $bestColumn = $column;
                    $bestValue = $value;
                    $bestGroups = $groups;
                    $bestVariance = $variance;
                }

                // Good enough: stop scanning the remaining values and columns.
                if ($variance <= $this->tolerance) {
                    break 2;
                }
            }
        }

        // NOTE(review): when no candidate is evaluated (no columns selected or no
        // samples), $bestColumn and $bestValue are still null here — confirm that
        // DecisionNode and the caller tolerate that.
        return new DecisionNode($bestColumn, $bestValue, $bestGroups, $bestVariance);
    }

    /**
     * Build a leaf node summarising the targets that reached it.
     *
     * The AverageNode receives the arithmetic mean, the population variance and the
     * number of targets.
     */
    protected function terminate(array $targets): BinaryNode
    {
        return new AverageNode(Mean::arithmetic($targets), Variance::population($targets), count($targets));
    }

    /**
     * Size-weighted mean of the population variance of each group's targets.
     *
     * @param array $groups list of [samples, targets] pairs, as returned by partition()
     *
     * @return float sum over groups of (group size / total samples) * variance(group targets);
     *               0.0 when no group holds at least two targets
     */
    protected function splitImpurity(array $groups): float
    {
        $samplesCount = (int) array_sum(array_map(static function (array $group) {
            return count($group[0]);
        }, $groups));

        $impurity = 0.;
        foreach ($groups as $group) {
            $k = count($group[1]);

            // A single value has zero population variance, so groups with fewer than
            // two targets add nothing; skipping them also keeps Variance::population()
            // from being called with zero or one value.
            if ($k < 2) {
                continue;
            }

            $impurity += ($k / $samplesCount) * Variance::population($group[1]);
        }

        return $impurity;
    }

    /**
     * Route each sample (with its target) left when its $column feature is below
     * $value, and right otherwise.
     *
     * @param int|float $value threshold compared against $sample[$column]
     *
     * @return array [[leftSamples, leftTargets], [rightSamples, rightTargets]]; each inner
     *               array is a list in the original iteration order
     */
    private function partition(int $column, $value, array $samples, array $targets): array
    {
        $leftSamples = $leftTargets = $rightSamples = $rightTargets = [];

        foreach ($samples as $index => $sample) {
            if ($sample[$column] < $value) {
                $leftSamples[] = $sample;
                $leftTargets[] = $targets[$index];
            } else {
                $rightSamples[] = $sample;
                $rightTargets[] = $targets[$index];
            }
        }

        return [
            [$leftSamples, $leftTargets],
            [$rightSamples, $rightTargets],
        ];
    }
}