Implement DecisionTreeRegressor (#375)

This commit is contained in:
Arkadiusz Kondas 2019-05-12 20:04:39 +02:00 committed by GitHub
parent 8544cf7083
commit 91812f4c4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 759 additions and 0 deletions

View File

@ -0,0 +1,144 @@
<?php
declare(strict_types=1);
namespace Phpml\Regression;
use Phpml\Exception\InvalidOperationException;
use Phpml\Math\Statistic\Mean;
use Phpml\Math\Statistic\Variance;
use Phpml\Tree\CART;
use Phpml\Tree\Node\AverageNode;
use Phpml\Tree\Node\BinaryNode;
use Phpml\Tree\Node\DecisionNode;
final class DecisionTreeRegressor extends CART implements Regression
{
/**
* @var int|null
*/
protected $maxFeatures;
/**
* @var float
*/
protected $tolerance;
/**
* @var array
*/
protected $columns = [];
public function train(array $samples, array $targets): void
{
$features = count($samples[0]);
$this->columns = range(0, $features - 1);
$this->maxFeatures = $this->maxFeatures ?? (int) round(sqrt($features));
$this->grow($samples, $targets);
$this->columns = [];
}
public function predict(array $samples)
{
if ($this->bare()) {
throw new InvalidOperationException('Regressor must be trained first');
}
$predictions = [];
foreach ($samples as $sample) {
$node = $this->search($sample);
$predictions[] = $node instanceof AverageNode
? $node->outcome()
: null;
}
return $predictions;
}
protected function split(array $samples, array $targets): DecisionNode
{
$bestVariance = INF;
$bestColumn = $bestValue = null;
$bestGroups = [];
shuffle($this->columns);
foreach (array_slice($this->columns, 0, $this->maxFeatures) as $column) {
$values = array_unique(array_column($samples, $column));
foreach ($values as $value) {
$groups = $this->partition($column, $value, $samples, $targets);
$variance = $this->splitImpurity($groups);
if ($variance < $bestVariance) {
$bestColumn = $column;
$bestValue = $value;
$bestGroups = $groups;
$bestVariance = $variance;
}
if ($variance <= $this->tolerance) {
break 2;
}
}
}
return new DecisionNode($bestColumn, $bestValue, $bestGroups, $bestVariance);
}
protected function terminate(array $targets): BinaryNode
{
return new AverageNode(Mean::arithmetic($targets), Variance::population($targets), count($targets));
}
protected function splitImpurity(array $groups): float
{
$samplesCount = (int) array_sum(array_map(static function (array $group) {
return count($group[0]);
}, $groups));
$impurity = 0.;
foreach ($groups as $group) {
$k = count($group[1]);
if ($k < 2) {
continue 1;
}
$variance = Variance::population($group[1]);
$impurity += ($k / $samplesCount) * $variance;
}
return $impurity;
}
/**
* @param int|float $value
*/
private function partition(int $column, $value, array $samples, array $targets): array
{
$leftSamples = $leftTargets = $rightSamples = $rightTargets = [];
foreach ($samples as $index => $sample) {
if ($sample[$column] < $value) {
$leftSamples[] = $sample;
$leftTargets[] = $targets[$index];
} else {
$rightSamples[] = $sample;
$rightTargets[] = $targets[$index];
}
}
return [
[$leftSamples, $leftTargets],
[$rightSamples, $rightTargets],
];
}
}

176
src/Tree/CART.php Normal file
View File

@ -0,0 +1,176 @@
<?php
declare(strict_types=1);
namespace Phpml\Tree;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Tree\Node\BinaryNode;
use Phpml\Tree\Node\DecisionNode;
use Phpml\Tree\Node\LeafNode;
abstract class CART
{
/**
* @var DecisionNode|null
*/
protected $root;
/**
* @var int
*/
protected $maxDepth;
/**
* @var int
*/
protected $maxLeafSize;
/**
* @var float
*/
protected $minPurityIncrease;
/**
* @var int
*/
protected $featureCount;
public function __construct(int $maxDepth = PHP_INT_MAX, int $maxLeafSize = 3, float $minPurityIncrease = 0.)
{
if ($maxDepth < 1) {
throw new InvalidArgumentException('Max depth must be greater than 0');
}
if ($maxLeafSize < 1) {
throw new InvalidArgumentException('Max leaf size must be greater than 0');
}
if ($minPurityIncrease < 0.) {
throw new InvalidArgumentException('Min purity increase must be equal or greater than 0');
}
$this->maxDepth = $maxDepth;
$this->maxLeafSize = $maxLeafSize;
$this->minPurityIncrease = $minPurityIncrease;
}
public function root(): ?DecisionNode
{
return $this->root;
}
public function height(): int
{
return $this->root !== null ? $this->root->height() : 0;
}
public function balance(): int
{
return $this->root !== null ? $this->root->balance() : 0;
}
public function bare(): bool
{
return $this->root === null;
}
public function grow(array $samples, array $targets): void
{
$this->featureCount = count($samples[0]);
$depth = 1;
$this->root = $this->split($samples, $targets);
$stack = [[$this->root, $depth]];
while ($stack) {
[$current, $depth] = array_pop($stack) ?? [];
[$left, $right] = $current->groups();
$current->cleanup();
$depth++;
if ($left === [] || $right === []) {
$node = $this->terminate(array_merge($left[1], $right[1]));
$current->attachLeft($node);
$current->attachRight($node);
continue 1;
}
if ($depth >= $this->maxDepth) {
$current->attachLeft($this->terminate($left[1]));
$current->attachRight($this->terminate($right[1]));
continue 1;
}
if (count($left[1]) > $this->maxLeafSize) {
$node = $this->split($left[0], $left[1]);
if ($node->purityIncrease() + 1e-8 > $this->minPurityIncrease) {
$current->attachLeft($node);
$stack[] = [$node, $depth];
} else {
$current->attachLeft($this->terminate($left[1]));
}
} else {
$current->attachLeft($this->terminate($left[1]));
}
if (count($right[1]) > $this->maxLeafSize) {
$node = $this->split($right[0], $right[1]);
if ($node->purityIncrease() + 1e-8 > $this->minPurityIncrease) {
$current->attachRight($node);
$stack[] = [$node, $depth];
} else {
$current->attachRight($this->terminate($right[1]));
}
} else {
$current->attachRight($this->terminate($right[1]));
}
}
}
public function search(array $sample): ?BinaryNode
{
$current = $this->root;
while ($current) {
if ($current instanceof DecisionNode) {
$value = $current->value();
if (is_string($value)) {
if ($sample[$current->column()] === $value) {
$current = $current->left();
} else {
$current = $current->right();
}
} else {
if ($sample[$current->column()] < $value) {
$current = $current->left();
} else {
$current = $current->right();
}
}
continue 1;
}
if ($current instanceof LeafNode) {
break 1;
}
}
return $current;
}
abstract protected function split(array $samples, array $targets): DecisionNode;
abstract protected function terminate(array $targets): BinaryNode;
}

9
src/Tree/Node.php Normal file
View File

@ -0,0 +1,9 @@
<?php
declare(strict_types=1);
namespace Phpml\Tree;
interface Node
{
}

View File

@ -0,0 +1,45 @@
<?php
declare(strict_types=1);
namespace Phpml\Tree\Node;
class AverageNode extends BinaryNode implements PurityNode, LeafNode
{
/**
* @var float
*/
private $outcome;
/**
* @var float
*/
private $impurity;
/**
* @var int
*/
private $samplesCount;
public function __construct(float $outcome, float $impurity, int $samplesCount)
{
$this->outcome = $outcome;
$this->impurity = $impurity;
$this->samplesCount = $samplesCount;
}
public function outcome(): float
{
return $this->outcome;
}
public function impurity(): float
{
return $this->impurity;
}
public function samplesCount(): int
{
return $this->samplesCount;
}
}

View File

@ -0,0 +1,83 @@
<?php
declare(strict_types=1);
namespace Phpml\Tree\Node;
use Phpml\Tree\Node;
class BinaryNode implements Node
{
/**
* @var self|null
*/
private $parent;
/**
* @var self|null
*/
private $left;
/**
* @var self|null
*/
private $right;
public function parent(): ?self
{
return $this->parent;
}
public function left(): ?self
{
return $this->left;
}
public function right(): ?self
{
return $this->right;
}
public function height(): int
{
return 1 + max($this->left !== null ? $this->left->height() : 0, $this->right !== null ? $this->right->height() : 0);
}
public function balance(): int
{
return ($this->right !== null ? $this->right->height() : 0) - ($this->left !== null ? $this->left->height() : 0);
}
public function setParent(?self $node = null): void
{
$this->parent = $node;
}
public function attachLeft(self $node): void
{
$node->setParent($this);
$this->left = $node;
}
public function detachLeft(): void
{
if ($this->left !== null) {
$this->left->setParent();
$this->left = null;
}
}
public function attachRight(self $node): void
{
$node->setParent($this);
$this->right = $node;
}
public function detachRight(): void
{
if ($this->right !== null) {
$this->right->setParent();
$this->right = null;
}
}
}

View File

@ -0,0 +1,107 @@
<?php
declare(strict_types=1);
namespace Phpml\Tree\Node;
use Phpml\Exception\InvalidArgumentException;
class DecisionNode extends BinaryNode implements PurityNode
{
/**
* @var int
*/
private $column;
/**
* @var mixed
*/
private $value;
/**
* @var array
*/
private $groups = [];
/**
* @var float
*/
private $impurity;
/**
* @var int
*/
private $samplesCount;
/**
* @param mixed $value
*/
public function __construct(int $column, $value, array $groups, float $impurity)
{
if (count($groups) !== 2) {
throw new InvalidArgumentException('The number of groups must be exactly two.');
}
if ($impurity < 0.) {
throw new InvalidArgumentException('Impurity cannot be less than 0.');
}
$this->column = $column;
$this->value = $value;
$this->groups = $groups;
$this->impurity = $impurity;
$this->samplesCount = (int) array_sum(array_map(function (array $group) {
return count($group[0]);
}, $groups));
}
public function column(): int
{
return $this->column;
}
/**
* @return mixed
*/
public function value()
{
return $this->value;
}
public function groups(): array
{
return $this->groups;
}
public function impurity(): float
{
return $this->impurity;
}
public function samplesCount(): int
{
return $this->samplesCount;
}
public function purityIncrease(): float
{
$impurity = $this->impurity;
if ($this->left() instanceof PurityNode) {
$impurity -= $this->left()->impurity()
* ($this->left()->samplesCount() / $this->samplesCount);
}
if ($this->right() instanceof PurityNode) {
$impurity -= $this->right()->impurity()
* ($this->right()->samplesCount() / $this->samplesCount);
}
return $impurity;
}
public function cleanup(): void
{
$this->groups = [[], []];
}
}

View File

@ -0,0 +1,9 @@
<?php
declare(strict_types=1);
namespace Phpml\Tree\Node;
interface LeafNode
{
}

View File

@ -0,0 +1,14 @@
<?php
declare(strict_types=1);
namespace Phpml\Tree\Node;
use Phpml\Tree\Node;
interface PurityNode extends Node
{
public function impurity(): float;
public function samplesCount(): int;
}

View File

@ -0,0 +1,68 @@
<?php
declare(strict_types=1);
namespace Phpml\Tests\Regression;
use Phpml\Exception\InvalidOperationException;
use Phpml\ModelManager;
use Phpml\Regression\DecisionTreeRegressor;
use PHPUnit\Framework\TestCase;
class DecisionTreeRegressorTest extends TestCase
{
public function testPredictSingleFeatureSamples(): void
{
$delta = 0.01;
$samples = [[60], [61], [62], [63], [65]];
$targets = [3.1, 3.6, 3.8, 4, 4.1];
$regression = new DecisionTreeRegressor(4);
$regression->train($samples, $targets);
self::assertEqualsWithDelta([4.05], $regression->predict([[64]]), $delta);
$samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]];
$targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];
$regression = new DecisionTreeRegressor();
$regression->train($samples, $targets);
self::assertEqualsWithDelta([11300.0], $regression->predict([[9300]]), $delta);
self::assertEqualsWithDelta([5250.0], $regression->predict([[57000]]), $delta);
self::assertEqualsWithDelta([2433.33], $regression->predict([[77006]]), $delta);
self::assertEqualsWithDelta([11300.0], $regression->predict([[9300]]), $delta);
self::assertEqualsWithDelta([992.5], $regression->predict([[153260]]), $delta);
}
public function testPreventPredictWhenNotTrained(): void
{
$regression = new DecisionTreeRegressor();
$this->expectException(InvalidOperationException::class);
$regression->predict([[1]]);
}
public function testSaveAndRestore(): void
{
$samples = [[60], [61], [62], [63], [65]];
$targets = [3.1, 3.6, 3.8, 4, 4.1];
$regression = new DecisionTreeRegressor(4);
$regression->train($samples, $targets);
$testSamples = [[9300], [10565], [15000]];
$predicted = $regression->predict($testSamples);
$filename = 'least-squares-test-'.random_int(100, 999).'-'.uniqid('', false);
$filepath = (string) tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager();
$modelManager->saveToFile($regression, $filepath);
$restoredRegression = $modelManager->restoreFromFile($filepath);
self::assertEquals($regression, $restoredRegression);
self::assertEquals($predicted, $restoredRegression->predict($testSamples));
}
}

View File

@ -0,0 +1,47 @@
<?php
declare(strict_types=1);
namespace Phpml\Tests\Tree\Node;
use Phpml\Tree\Node\BinaryNode;
use PHPUnit\Framework\TestCase;
final class BinaryNodeTest extends TestCase
{
public function testSimpleNode(): void
{
$node = new BinaryNode();
self::assertEquals(1, $node->height());
self::assertEquals(0, $node->balance());
}
public function testAttachDetachLeft(): void
{
$node = new BinaryNode();
$node->attachLeft(new BinaryNode());
self::assertEquals(2, $node->height());
self::assertEquals(-1, $node->balance());
$node->detachLeft();
self::assertEquals(1, $node->height());
self::assertEquals(0, $node->balance());
}
public function testAttachDetachRight(): void
{
$node = new BinaryNode();
$node->attachRight(new BinaryNode());
self::assertEquals(2, $node->height());
self::assertEquals(1, $node->balance());
$node->detachRight();
self::assertEquals(1, $node->height());
self::assertEquals(0, $node->balance());
}
}

View File

@ -0,0 +1,57 @@
<?php
declare(strict_types=1);
namespace Phpml\Tests\Tree\Node;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Tree\Node\DecisionNode;
use PHPUnit\Framework\TestCase;
final class DecisionNodeTest extends TestCase
{
public function testSimpleNode(): void
{
$node = new DecisionNode(2, 4, [
[[[1, 2, 3]], [1]],
[[[2, 3, 4]], [2]],
], 400);
self::assertEquals(2, $node->column());
self::assertEquals(2, $node->samplesCount());
}
public function testImpurityIncrease(): void
{
$node = new DecisionNode(2, 4, [
[[[1, 2, 3]], [1]],
[[[2, 3, 4]], [2]],
], 400);
$node->attachRight(new DecisionNode(2, 4, [
[[[1, 2, 3]], [1]],
[[[2, 3, 4]], [2]],
], 200));
$node->attachLeft(new DecisionNode(2, 4, [
[[[1, 2, 3]], [1]],
[[[2, 3, 4]], [2]],
], 100));
self::assertEquals(100, $node->purityIncrease());
}
public function testThrowExceptionOnInvalidGroupsCount(): void
{
$this->expectException(InvalidArgumentException::class);
new DecisionNode(2, 3, [], 200);
}
public function testThrowExceptionOnInvalidImpurity(): void
{
$this->expectException(InvalidArgumentException::class);
new DecisionNode(2, 3, [[], []], -2);
}
}