diff --git a/src/Regression/DecisionTreeRegressor.php b/src/Regression/DecisionTreeRegressor.php
new file mode 100644
index 0000000..6260a03
--- /dev/null
+++ b/src/Regression/DecisionTreeRegressor.php
@@ -0,0 +1,144 @@
+columns = range(0, $features - 1);
+        $this->maxFeatures = $this->maxFeatures ?? (int) round(sqrt($features));
+
+        $this->grow($samples, $targets);
+
+        $this->columns = [];
+    }
+
+    /**
+     * Predicts one value per sample by walking the tree with search().
+     *
+     * A sample whose walk does not end on an AverageNode yields null.
+     *
+     * @return array one entry per input sample, in input order
+     *
+     * @throws InvalidOperationException when bare() reports that no tree has been grown
+     */
+    public function predict(array $samples)
+    {
+        if ($this->bare()) {
+            throw new InvalidOperationException('Regressor must be trained first');
+        }
+
+        $predictions = [];
+
+        foreach ($samples as $sample) {
+            $node = $this->search($sample);
+
+            $predictions[] = $node instanceof AverageNode
+                ? $node->outcome()
+                : null;
+        }
+
+        return $predictions;
+    }
+
+    /**
+     * Picks the split with the lowest weighted variance.
+     *
+     * The column list is shuffled and only the first maxFeatures columns are tried; every
+     * distinct value found in a column is tried as the threshold.
+     */
+    protected function split(array $samples, array $targets): DecisionNode
+    {
+        $bestVariance = INF;
+        $bestColumn = $bestValue = null;
+        $bestGroups = [];
+
+        shuffle($this->columns);
+
+        foreach (array_slice($this->columns, 0, $this->maxFeatures) as $column) {
+            $values = array_unique(array_column($samples, $column));
+
+            foreach ($values as $value) {
+                $groups = $this->partition($column, $value, $samples, $targets);
+
+                $variance = $this->splitImpurity($groups);
+
+                if ($variance < $bestVariance) {
+                    $bestColumn = $column;
+                    $bestValue = $value;
+                    $bestGroups = $groups;
+                    $bestVariance = $variance;
+                }
+
+                // Good enough: stop scanning both the values and the remaining columns.
+                if ($variance <= $this->tolerance) {
+                    break 2;
+                }
+            }
+        }
+
+        return new DecisionNode($bestColumn, $bestValue, $bestGroups, $bestVariance);
+    }
+
+    /**
+     * Builds the terminal node for a set of targets: their arithmetic mean as the outcome,
+     * their population variance as the impurity, and their count.
+     */
+    protected function terminate(array $targets): BinaryNode
+    {
+        return new AverageNode(Mean::arithmetic($targets), Variance::population($targets), count($targets));
+    }
+
+    /**
+     * Returns the population variance of each group's targets, weighted by the group's share
+     * of all samples and summed over the groups.
+     *
+     * @param array $groups list of [samples, targets] pairs
+     */
+    protected function splitImpurity(array $groups): float
+    {
+        $samplesCount = (int) array_sum(array_map(static function (array $group) {
+            return count($group[0]);
+        }, $groups));
+
+        $impurity = 0.;
+
+        foreach ($groups as $group) {
+            $k = count($group[1]);
+
+            // Groups with fewer than two targets add nothing to the sum.
+            if ($k < 2) {
+                continue 1;
+            }
+
+            $variance = Variance::population($group[1]);
+
+            $impurity += ($k / $samplesCount) * $variance;
+        }
+
+        return $impurity;
+    }
+
+    /**
+     * Splits samples and their targets in two: rows whose value in $column is strictly less
+     * than $value go left, all others go right.
+     *
+     * @param int|float $value
+     *
+     * @return array [[leftSamples, leftTargets], [rightSamples, rightTargets]]
+     */
+    private function partition(int $column, $value, array $samples, array $targets): array
+    {
+        $leftSamples = $leftTargets = $rightSamples = $rightTargets = [];
+        foreach ($samples as $index => $sample) {
+            if ($sample[$column] < $value) {
+                $leftSamples[] = $sample;
+                $leftTargets[] = $targets[$index];
+            } else {
+                $rightSamples[] = $sample;
+                $rightTargets[] = $targets[$index];
+            }
+        }
+
+        return [
+            [$leftSamples, $leftTargets],
+            [$rightSamples, $rightTargets],
+        ];
+    }
+}
diff --git a/src/Tree/CART.php b/src/Tree/CART.php
new file mode 100644
index 0000000..5ed1504
--- /dev/null
+++ b/src/Tree/CART.php
@@ -0,0 +1,176 @@
+maxDepth = $maxDepth;
+        $this->maxLeafSize = $maxLeafSize;
+        $this->minPurityIncrease = $minPurityIncrease;
+    }
+
+    /**
+     * Returns the root decision node, or null while no tree has been grown.
+     */
+    public function root(): ?DecisionNode
+    {
+        return $this->root;
+    }
+
+    /**
+     * Returns the height reported by the root node, or 0 when there is no root.
+     */
+    public function height(): int
+    {
+        return $this->root !== null ? $this->root->height() : 0;
+    }
+
+    /**
+     * Returns the balance reported by the root node, or 0 when there is no root.
+     */
+    public function balance(): int
+    {
+        return $this->root !== null ? $this->root->balance() : 0;
+    }
+
+    /**
+     * Tells whether the tree has no root yet.
+     */
+    public function bare(): bool
+    {
+        return $this->root === null;
+    }
+
+    /**
+     * Grows the tree from the training data using an explicit stack instead of recursion.
+     *
+     * A side of a split becomes a terminal node when the depth limit is reached, when it
+     * holds no more than maxLeafSize targets, or when splitting it further does not raise
+     * purity above minPurityIncrease. When one side of a split is empty, both branches share
+     * a single terminal node built from all of the targets.
+     */
+    public function grow(array $samples, array $targets): void
+    {
+        $this->featureCount = count($samples[0]);
+        $depth = 1;
+        $this->root = $this->split($samples, $targets);
+        $stack = [[$this->root, $depth]];
+
+        while ($stack) {
+            [$current, $depth] = array_pop($stack) ?? [];
+
+            [$left, $right] = $current->groups();
+
+            // The groups were copied out above; let the node drop its reference to them.
+            $current->cleanup();
+
+            $depth++;
+
+            // Each side is a [samples, targets] pair, so a side without rows is [[], []]
+            // rather than []. Emptiness therefore has to be tested on the targets list.
+            $leftTargets = $left[1] ?? [];
+            $rightTargets = $right[1] ?? [];
+
+            if ($leftTargets === [] || $rightTargets === []) {
+                $node = $this->terminate(array_merge($leftTargets, $rightTargets));
+
+                $current->attachLeft($node);
+                $current->attachRight($node);
+
+                continue 1;
+            }
+
+            if ($depth >= $this->maxDepth) {
+                $current->attachLeft($this->terminate($leftTargets));
+                $current->attachRight($this->terminate($rightTargets));
+
+                continue 1;
+            }
+
+            if (count($leftTargets) > $this->maxLeafSize) {
+                $node = $this->split($left[0], $leftTargets);
+
+                if ($node->purityIncrease() + 1e-8 > $this->minPurityIncrease) {
+                    $current->attachLeft($node);
+
+                    $stack[] = [$node, $depth];
+                } else {
+                    $current->attachLeft($this->terminate($leftTargets));
+                }
+            } else {
+                $current->attachLeft($this->terminate($leftTargets));
+            }
+
+            if (count($rightTargets) > $this->maxLeafSize) {
+                $node = $this->split($right[0], $rightTargets);
+
+                if ($node->purityIncrease() + 1e-8 > $this->minPurityIncrease) {
+                    $current->attachRight($node);
+
+                    $stack[] = [$node, $depth];
+                } else {
+                    $current->attachRight($this->terminate($rightTargets));
+                }
+            } else {
+                $current->attachRight($this->terminate($rightTargets));
+            }
+        }
+    }
+
+    /**
+     * Walks the tree from the root for one sample and returns the node where the walk ends.
+     *
+     * At a decision node with a string value the sample goes left when its column value is
+     * identical to it; otherwise it goes left when its column value is strictly smaller.
+     *
+     * @return BinaryNode|null the first node reached that is not a decision node, or null
+     *                         when there is no root or a decision node lacks the needed child
+     */
+    public function search(array $sample): ?BinaryNode
+    {
+        $current = $this->root;
+
+        // Only decision nodes can be descended through. Stopping on anything else keeps the
+        // walk from spinning forever on a node that is neither a decision node nor a leaf.
+        while ($current instanceof DecisionNode) {
+            $value = $current->value();
+
+            if (is_string($value)) {
+                if ($sample[$current->column()] === $value) {
+                    $current = $current->left();
+                } else {
+                    $current = $current->right();
+                }
+            } else {
+                if ($sample[$current->column()] < $value) {
+                    $current = $current->left();
+                } else {
+                    $current = $current->right();
+                }
+            }
+        }
+
+        return $current;
+    }
+
+    /**
+     * Chooses how to split the given samples and returns the resulting decision node.
+     */
+    abstract protected function split(array $samples, array $targets): DecisionNode;
+
+    /**
+     * Builds the terminal node that represents the given targets.
+     */
+    abstract protected function terminate(array $targets): BinaryNode;
+}
diff --git a/src/Tree/Node.php b/src/Tree/Node.php
new file mode 100644
index 0000000..3176b62
--- /dev/null
+++ b/src/Tree/Node.php
@@ -0,0 +1,9 @@
+outcome = $outcome;
+        $this->impurity = $impurity;
+        $this->samplesCount = $samplesCount;
+    }
+
+    /**
+     * Returns the outcome value stored in this node.
+     */
+    public function outcome(): float
+    {
+        return $this->outcome;
+    }
+
+    /**
+     * Returns the impurity stored in this node.
+     */
+    public function impurity(): float
+    {
+        return $this->impurity;
+    }
+
+    /**
+     * Returns the number of samples stored in this node.
+     */
+    public function samplesCount(): int
+    {
+        return $this->samplesCount;
+    }
+}
diff --git a/src/Tree/Node/BinaryNode.php b/src/Tree/Node/BinaryNode.php
new file mode 100644
index 0000000..c6797b5
--- /dev/null
+++ b/src/Tree/Node/BinaryNode.php
@@ -0,0 +1,83 @@
+parent;
+    }
+
+    /**
+     * Returns the left child, or null when none is attached.
+     */
+    public function left(): ?self
+    {
+        return $this->left;
+    }
+
+    /**
+     * Returns the right child, or null when none is attached.
+     */
+    public function right(): ?self
+    {
+        return $this->right;
+    }
+
+    /**
+     * Returns 1 plus the height of the taller child; a node without children has height 1.
+     */
+    public function height(): int
+    {
+        return 1 + max($this->left !== null ? $this->left->height() : 0, $this->right !== null ? $this->right->height() : 0);
+    }
+
+    /**
+     * Returns the right child's height minus the left child's height; a missing child counts as 0.
+     */
+    public function balance(): int
+    {
+        return ($this->right !== null ? $this->right->height() : 0) - ($this->left !== null ? $this->left->height() : 0);
+    }
+
+    /**
+     * Sets the parent of this node; calling it without an argument clears the parent.
+     */
+    public function setParent(?self $node = null): void
+    {
+        $this->parent = $node;
+    }
+
+    /**
+     * Makes $node the left child and points its parent at this node.
+     */
+    public function attachLeft(self $node): void
+    {
+        $node->setParent($this);
+        $this->left = $node;
+    }
+
+    /**
+     * Removes the left child, if any, and clears that child's parent.
+     */
+    public function detachLeft(): void
+    {
+        if ($this->left !== null) {
+            $this->left->setParent();
+            $this->left = null;
+        }
+    }
+
+    /**
+     * Makes $node the right child and points its parent at this node.
+     */
+    public function attachRight(self $node): void
+    {
+        $node->setParent($this);
+        $this->right = $node;
+    }
+
+    /**
+     * Removes the right child, if any, and clears that child's parent.
+     */
+    public function detachRight(): void
+    {
+        if ($this->right !== null) {
+            $this->right->setParent();
+            $this->right = null;
+        }
+    }
+}
diff --git a/src/Tree/Node/DecisionNode.php b/src/Tree/Node/DecisionNode.php
new file mode 100644
index 0000000..f621fed
--- /dev/null
+++ b/src/Tree/Node/DecisionNode.php
@@ -0,0 +1,107 @@
+column = $column;
+        $this->value = $value;
+        $this->groups = $groups;
+        $this->impurity = $impurity;
+        // The sample count is the total number of sample rows across all groups.
+        $this->samplesCount = (int) array_sum(array_map(static function (array $group) {
+            return count($group[0]);
+        }, $groups));
+    }
+
+    /**
+     * Returns the index of the feature column this node splits on.
+     */
+    public function column(): int
+    {
+        return $this->column;
+    }
+
+    /**
+     * Returns the value this node splits on.
+     *
+     * @return mixed
+     */
+    public function value()
+    {
+        return $this->value;
+    }
+
+    /**
+     * Returns the groups currently held by this node; after cleanup() this is [[], []].
+     */
+    public function groups(): array
+    {
+        return $this->groups;
+    }
+
+    /**
+     * Returns the impurity given to this node at construction.
+     */
+    public function impurity(): float
+    {
+        return $this->impurity;
+    }
+
+    /**
+     * Returns the number of samples counted at construction.
+     */
+    public function samplesCount(): int
+    {
+        return $this->samplesCount;
+    }
+
+    /**
+     * Returns this node's impurity minus the impurity of each child, where each child's
+     * impurity is weighted by its share of this node's samples.
+     *
+     * A child that is missing or is not a PurityNode is left out of the subtraction.
+     */
+    public function purityIncrease(): float
+    {
+        $impurity = $this->impurity;
+
+        if ($this->left() instanceof PurityNode) {
+            $impurity -= $this->left()->impurity()
+                * ($this->left()->samplesCount() / $this->samplesCount);
+        }
+
+        if ($this->right() instanceof PurityNode) {
+            $impurity -= $this->right()->impurity()
+                * ($this->right()->samplesCount() / $this->samplesCount);
+        }
+
+        return $impurity;
+    }
+
+    /**
+     * Replaces the stored groups with two empty ones so the node no longer holds the data.
+     */
+    public function cleanup(): void
+    {
+        $this->groups = [[], []];
+    }
+}
diff --git a/src/Tree/Node/LeafNode.php b/src/Tree/Node/LeafNode.php
new file mode 100644
index 0000000..ebb848e
--- /dev/null
+++ b/src/Tree/Node/LeafNode.php
@@ -0,0 +1,9 @@
+train($samples, $targets);
+
+        self::assertEqualsWithDelta([4.05], $regression->predict([[64]]), $delta);
+
+        $samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]];
+        $targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];
+
+        $regression = new DecisionTreeRegressor();
+        $regression->train($samples, $targets);
+
+        self::assertEqualsWithDelta([11300.0], $regression->predict([[9300]]), $delta);
+        self::assertEqualsWithDelta([5250.0], $regression->predict([[57000]]), $delta);
+        self::assertEqualsWithDelta([2433.33], $regression->predict([[77006]]), $delta);
+        self::assertEqualsWithDelta([11300.0], $regression->predict([[9300]]), $delta);
+        self::assertEqualsWithDelta([992.5], $regression->predict([[153260]]), $delta);
+    }
+
+    /**
+     * Predicting with a regressor that was never trained must throw.
+     */
+    public function testPreventPredictWhenNotTrained(): void
+    {
+        $regression = new DecisionTreeRegressor();
+
+        $this->expectException(InvalidOperationException::class);
+
+        $regression->predict([[1]]);
+    }
+
+    /**
+     * A regressor saved to disk and restored must equal the original and predict the same values.
+     */
+    public function testSaveAndRestore(): void
+    {
+        $samples = [[60], [61], [62], [63], [65]];
+        $targets = [3.1, 3.6, 3.8, 4, 4.1];
+
+        $regression = new DecisionTreeRegressor(4);
+        $regression->train($samples, $targets);
+
+        $testSamples = [[9300], [10565], [15000]];
+        $predicted = $regression->predict($testSamples);
+
+        $filename = 'least-squares-test-'.random_int(100, 999).'-'.uniqid('', false);
+        $filepath = (string) tempnam(sys_get_temp_dir(), $filename);
+
+        try {
+            $modelManager = new ModelManager();
+            $modelManager->saveToFile($regression, $filepath);
+
+            $restoredRegression = $modelManager->restoreFromFile($filepath);
+            self::assertEquals($regression, $restoredRegression);
+            self::assertEquals($predicted, $restoredRegression->predict($testSamples));
+        } finally {
+            // tempnam() creates the file on disk, so remove it even when an assertion fails.
+            if (is_file($filepath)) {
+                unlink($filepath);
+            }
+        }
+    }
+}
diff --git a/tests/Tree/Node/BinaryNodeTest.php b/tests/Tree/Node/BinaryNodeTest.php
new file mode 100644
index 0000000..43db418
--- /dev/null
+++ b/tests/Tree/Node/BinaryNodeTest.php
@@ -0,0 +1,47 @@
+height());
+        self::assertEquals(0, $node->balance());
+    }
+
+    /**
+     * Attaching a left child raises the height to 2 and tilts the balance to -1;
+     * detaching it restores height 1 and balance 0.
+     */
+    public function testAttachDetachLeft(): void
+    {
+        $node = new BinaryNode();
+        $node->attachLeft(new BinaryNode());
+
+        self::assertEquals(2, $node->height());
+        self::assertEquals(-1, $node->balance());
+
+        $node->detachLeft();
+
+        self::assertEquals(1, $node->height());
+        self::assertEquals(0, $node->balance());
+    }
+
+    /**
+     * Attaching a right child raises the height to 2 and tilts the balance to 1;
+     * detaching it restores height 1 and balance 0.
+     */
+    public function testAttachDetachRight(): void
+    {
+        $node = new BinaryNode();
+        $node->attachRight(new BinaryNode());
+
+        self::assertEquals(2, $node->height());
+        self::assertEquals(1, $node->balance());
+
+        $node->detachRight();
+
+        self::assertEquals(1, $node->height());
+        self::assertEquals(0, $node->balance());
+    }
+}
diff --git a/tests/Tree/Node/DecisionNodeTest.php b/tests/Tree/Node/DecisionNodeTest.php
new file mode 100644
index 0000000..2db3482
--- /dev/null
+++ b/tests/Tree/Node/DecisionNodeTest.php
@@ -0,0 +1,57 @@
+column());
+        self::assertEquals(2, $node->samplesCount());
+    }
+
+    /**
+     * With a parent impurity of 400 and children of impurity 200 and 100 that each hold as
+     * many samples as the parent, the purity increase is 400 - 200 - 100 = 100.
+     */
+    public function testImpurityIncrease(): void
+    {
+        $node = new DecisionNode(2, 4, [
+            [[[1, 2, 3]], [1]],
+            [[[2, 3, 4]], [2]],
+        ], 400);
+
+        $node->attachRight(new DecisionNode(2, 4, [
+            [[[1, 2, 3]], [1]],
+            [[[2, 3, 4]], [2]],
+        ], 200));
+
+        $node->attachLeft(new DecisionNode(2, 4, [
+            [[[1, 2, 3]], [1]],
+            [[[2, 3, 4]], [2]],
+        ], 100));
+
+        self::assertEquals(100, $node->purityIncrease());
+    }
+
+    /**
+     * Constructing a node with an empty groups array must be rejected.
+     */
+    public function testThrowExceptionOnInvalidGroupsCount(): void
+    {
+        $this->expectException(InvalidArgumentException::class);
+
+        new DecisionNode(2, 3, [], 200);
+    }
+
+    /**
+     * Constructing a node with a negative impurity must be rejected.
+     */
+    public function testThrowExceptionOnInvalidImpurity(): void
+    {
+        $this->expectException(InvalidArgumentException::class);
+
+        new DecisionNode(2, 3, [[], []], -2);
+    }
+}