mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-21 20:45:10 +00:00
Implement DecisionTreeRegressor (#375)
This commit is contained in:
parent
8544cf7083
commit
91812f4c4a
144
src/Regression/DecisionTreeRegressor.php
Normal file
144
src/Regression/DecisionTreeRegressor.php
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Regression;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidOperationException;
|
||||||
|
use Phpml\Math\Statistic\Mean;
|
||||||
|
use Phpml\Math\Statistic\Variance;
|
||||||
|
use Phpml\Tree\CART;
|
||||||
|
use Phpml\Tree\Node\AverageNode;
|
||||||
|
use Phpml\Tree\Node\BinaryNode;
|
||||||
|
use Phpml\Tree\Node\DecisionNode;
|
||||||
|
|
||||||
|
final class DecisionTreeRegressor extends CART implements Regression
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @var int|null
|
||||||
|
*/
|
||||||
|
protected $maxFeatures;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var float
|
||||||
|
*/
|
||||||
|
protected $tolerance;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
protected $columns = [];
|
||||||
|
|
||||||
|
public function train(array $samples, array $targets): void
|
||||||
|
{
|
||||||
|
$features = count($samples[0]);
|
||||||
|
|
||||||
|
$this->columns = range(0, $features - 1);
|
||||||
|
$this->maxFeatures = $this->maxFeatures ?? (int) round(sqrt($features));
|
||||||
|
|
||||||
|
$this->grow($samples, $targets);
|
||||||
|
|
||||||
|
$this->columns = [];
|
||||||
|
}
|
||||||
|
|
||||||
|
public function predict(array $samples)
|
||||||
|
{
|
||||||
|
if ($this->bare()) {
|
||||||
|
throw new InvalidOperationException('Regressor must be trained first');
|
||||||
|
}
|
||||||
|
|
||||||
|
$predictions = [];
|
||||||
|
|
||||||
|
foreach ($samples as $sample) {
|
||||||
|
$node = $this->search($sample);
|
||||||
|
|
||||||
|
$predictions[] = $node instanceof AverageNode
|
||||||
|
? $node->outcome()
|
||||||
|
: null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return $predictions;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected function split(array $samples, array $targets): DecisionNode
|
||||||
|
{
|
||||||
|
$bestVariance = INF;
|
||||||
|
$bestColumn = $bestValue = null;
|
||||||
|
$bestGroups = [];
|
||||||
|
|
||||||
|
shuffle($this->columns);
|
||||||
|
|
||||||
|
foreach (array_slice($this->columns, 0, $this->maxFeatures) as $column) {
|
||||||
|
$values = array_unique(array_column($samples, $column));
|
||||||
|
|
||||||
|
foreach ($values as $value) {
|
||||||
|
$groups = $this->partition($column, $value, $samples, $targets);
|
||||||
|
|
||||||
|
$variance = $this->splitImpurity($groups);
|
||||||
|
|
||||||
|
if ($variance < $bestVariance) {
|
||||||
|
$bestColumn = $column;
|
||||||
|
$bestValue = $value;
|
||||||
|
$bestGroups = $groups;
|
||||||
|
$bestVariance = $variance;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($variance <= $this->tolerance) {
|
||||||
|
break 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new DecisionNode($bestColumn, $bestValue, $bestGroups, $bestVariance);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected function terminate(array $targets): BinaryNode
|
||||||
|
{
|
||||||
|
return new AverageNode(Mean::arithmetic($targets), Variance::population($targets), count($targets));
|
||||||
|
}
|
||||||
|
|
||||||
|
protected function splitImpurity(array $groups): float
|
||||||
|
{
|
||||||
|
$samplesCount = (int) array_sum(array_map(static function (array $group) {
|
||||||
|
return count($group[0]);
|
||||||
|
}, $groups));
|
||||||
|
|
||||||
|
$impurity = 0.;
|
||||||
|
|
||||||
|
foreach ($groups as $group) {
|
||||||
|
$k = count($group[1]);
|
||||||
|
|
||||||
|
if ($k < 2) {
|
||||||
|
continue 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
$variance = Variance::population($group[1]);
|
||||||
|
|
||||||
|
$impurity += ($k / $samplesCount) * $variance;
|
||||||
|
}
|
||||||
|
|
||||||
|
return $impurity;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param int|float $value
|
||||||
|
*/
|
||||||
|
private function partition(int $column, $value, array $samples, array $targets): array
|
||||||
|
{
|
||||||
|
$leftSamples = $leftTargets = $rightSamples = $rightTargets = [];
|
||||||
|
foreach ($samples as $index => $sample) {
|
||||||
|
if ($sample[$column] < $value) {
|
||||||
|
$leftSamples[] = $sample;
|
||||||
|
$leftTargets[] = $targets[$index];
|
||||||
|
} else {
|
||||||
|
$rightSamples[] = $sample;
|
||||||
|
$rightTargets[] = $targets[$index];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return [
|
||||||
|
[$leftSamples, $leftTargets],
|
||||||
|
[$rightSamples, $rightTargets],
|
||||||
|
];
|
||||||
|
}
|
||||||
|
}
|
176
src/Tree/CART.php
Normal file
176
src/Tree/CART.php
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tree;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
|
use Phpml\Tree\Node\BinaryNode;
|
||||||
|
use Phpml\Tree\Node\DecisionNode;
|
||||||
|
use Phpml\Tree\Node\LeafNode;
|
||||||
|
|
||||||
|
abstract class CART
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @var DecisionNode|null
|
||||||
|
*/
|
||||||
|
protected $root;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
protected $maxDepth;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
protected $maxLeafSize;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var float
|
||||||
|
*/
|
||||||
|
protected $minPurityIncrease;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
protected $featureCount;
|
||||||
|
|
||||||
|
public function __construct(int $maxDepth = PHP_INT_MAX, int $maxLeafSize = 3, float $minPurityIncrease = 0.)
|
||||||
|
{
|
||||||
|
if ($maxDepth < 1) {
|
||||||
|
throw new InvalidArgumentException('Max depth must be greater than 0');
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($maxLeafSize < 1) {
|
||||||
|
throw new InvalidArgumentException('Max leaf size must be greater than 0');
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($minPurityIncrease < 0.) {
|
||||||
|
throw new InvalidArgumentException('Min purity increase must be equal or greater than 0');
|
||||||
|
}
|
||||||
|
|
||||||
|
$this->maxDepth = $maxDepth;
|
||||||
|
$this->maxLeafSize = $maxLeafSize;
|
||||||
|
$this->minPurityIncrease = $minPurityIncrease;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function root(): ?DecisionNode
|
||||||
|
{
|
||||||
|
return $this->root;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function height(): int
|
||||||
|
{
|
||||||
|
return $this->root !== null ? $this->root->height() : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function balance(): int
|
||||||
|
{
|
||||||
|
return $this->root !== null ? $this->root->balance() : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function bare(): bool
|
||||||
|
{
|
||||||
|
return $this->root === null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function grow(array $samples, array $targets): void
|
||||||
|
{
|
||||||
|
$this->featureCount = count($samples[0]);
|
||||||
|
$depth = 1;
|
||||||
|
$this->root = $this->split($samples, $targets);
|
||||||
|
$stack = [[$this->root, $depth]];
|
||||||
|
|
||||||
|
while ($stack) {
|
||||||
|
[$current, $depth] = array_pop($stack) ?? [];
|
||||||
|
|
||||||
|
[$left, $right] = $current->groups();
|
||||||
|
|
||||||
|
$current->cleanup();
|
||||||
|
|
||||||
|
$depth++;
|
||||||
|
|
||||||
|
if ($left === [] || $right === []) {
|
||||||
|
$node = $this->terminate(array_merge($left[1], $right[1]));
|
||||||
|
|
||||||
|
$current->attachLeft($node);
|
||||||
|
$current->attachRight($node);
|
||||||
|
|
||||||
|
continue 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($depth >= $this->maxDepth) {
|
||||||
|
$current->attachLeft($this->terminate($left[1]));
|
||||||
|
$current->attachRight($this->terminate($right[1]));
|
||||||
|
|
||||||
|
continue 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (count($left[1]) > $this->maxLeafSize) {
|
||||||
|
$node = $this->split($left[0], $left[1]);
|
||||||
|
|
||||||
|
if ($node->purityIncrease() + 1e-8 > $this->minPurityIncrease) {
|
||||||
|
$current->attachLeft($node);
|
||||||
|
|
||||||
|
$stack[] = [$node, $depth];
|
||||||
|
} else {
|
||||||
|
$current->attachLeft($this->terminate($left[1]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
$current->attachLeft($this->terminate($left[1]));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (count($right[1]) > $this->maxLeafSize) {
|
||||||
|
$node = $this->split($right[0], $right[1]);
|
||||||
|
|
||||||
|
if ($node->purityIncrease() + 1e-8 > $this->minPurityIncrease) {
|
||||||
|
$current->attachRight($node);
|
||||||
|
|
||||||
|
$stack[] = [$node, $depth];
|
||||||
|
} else {
|
||||||
|
$current->attachRight($this->terminate($right[1]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
$current->attachRight($this->terminate($right[1]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public function search(array $sample): ?BinaryNode
|
||||||
|
{
|
||||||
|
$current = $this->root;
|
||||||
|
|
||||||
|
while ($current) {
|
||||||
|
if ($current instanceof DecisionNode) {
|
||||||
|
$value = $current->value();
|
||||||
|
|
||||||
|
if (is_string($value)) {
|
||||||
|
if ($sample[$current->column()] === $value) {
|
||||||
|
$current = $current->left();
|
||||||
|
} else {
|
||||||
|
$current = $current->right();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if ($sample[$current->column()] < $value) {
|
||||||
|
$current = $current->left();
|
||||||
|
} else {
|
||||||
|
$current = $current->right();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
continue 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($current instanceof LeafNode) {
|
||||||
|
break 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return $current;
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract protected function split(array $samples, array $targets): DecisionNode;
|
||||||
|
|
||||||
|
abstract protected function terminate(array $targets): BinaryNode;
|
||||||
|
}
|
9
src/Tree/Node.php
Normal file
9
src/Tree/Node.php
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tree;
|
||||||
|
|
||||||
|
interface Node
|
||||||
|
{
|
||||||
|
}
|
45
src/Tree/Node/AverageNode.php
Normal file
45
src/Tree/Node/AverageNode.php
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tree\Node;
|
||||||
|
|
||||||
|
class AverageNode extends BinaryNode implements PurityNode, LeafNode
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @var float
|
||||||
|
*/
|
||||||
|
private $outcome;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var float
|
||||||
|
*/
|
||||||
|
private $impurity;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
private $samplesCount;
|
||||||
|
|
||||||
|
public function __construct(float $outcome, float $impurity, int $samplesCount)
|
||||||
|
{
|
||||||
|
$this->outcome = $outcome;
|
||||||
|
$this->impurity = $impurity;
|
||||||
|
$this->samplesCount = $samplesCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function outcome(): float
|
||||||
|
{
|
||||||
|
return $this->outcome;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function impurity(): float
|
||||||
|
{
|
||||||
|
return $this->impurity;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function samplesCount(): int
|
||||||
|
{
|
||||||
|
return $this->samplesCount;
|
||||||
|
}
|
||||||
|
}
|
83
src/Tree/Node/BinaryNode.php
Normal file
83
src/Tree/Node/BinaryNode.php
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tree\Node;
|
||||||
|
|
||||||
|
use Phpml\Tree\Node;
|
||||||
|
|
||||||
|
class BinaryNode implements Node
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @var self|null
|
||||||
|
*/
|
||||||
|
private $parent;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var self|null
|
||||||
|
*/
|
||||||
|
private $left;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var self|null
|
||||||
|
*/
|
||||||
|
private $right;
|
||||||
|
|
||||||
|
public function parent(): ?self
|
||||||
|
{
|
||||||
|
return $this->parent;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function left(): ?self
|
||||||
|
{
|
||||||
|
return $this->left;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function right(): ?self
|
||||||
|
{
|
||||||
|
return $this->right;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function height(): int
|
||||||
|
{
|
||||||
|
return 1 + max($this->left !== null ? $this->left->height() : 0, $this->right !== null ? $this->right->height() : 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function balance(): int
|
||||||
|
{
|
||||||
|
return ($this->right !== null ? $this->right->height() : 0) - ($this->left !== null ? $this->left->height() : 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function setParent(?self $node = null): void
|
||||||
|
{
|
||||||
|
$this->parent = $node;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function attachLeft(self $node): void
|
||||||
|
{
|
||||||
|
$node->setParent($this);
|
||||||
|
$this->left = $node;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function detachLeft(): void
|
||||||
|
{
|
||||||
|
if ($this->left !== null) {
|
||||||
|
$this->left->setParent();
|
||||||
|
$this->left = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public function attachRight(self $node): void
|
||||||
|
{
|
||||||
|
$node->setParent($this);
|
||||||
|
$this->right = $node;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function detachRight(): void
|
||||||
|
{
|
||||||
|
if ($this->right !== null) {
|
||||||
|
$this->right->setParent();
|
||||||
|
$this->right = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
107
src/Tree/Node/DecisionNode.php
Normal file
107
src/Tree/Node/DecisionNode.php
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tree\Node;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
|
|
||||||
|
class DecisionNode extends BinaryNode implements PurityNode
|
||||||
|
{
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
private $column;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var mixed
|
||||||
|
*/
|
||||||
|
private $value;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var array
|
||||||
|
*/
|
||||||
|
private $groups = [];
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var float
|
||||||
|
*/
|
||||||
|
private $impurity;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @var int
|
||||||
|
*/
|
||||||
|
private $samplesCount;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param mixed $value
|
||||||
|
*/
|
||||||
|
public function __construct(int $column, $value, array $groups, float $impurity)
|
||||||
|
{
|
||||||
|
if (count($groups) !== 2) {
|
||||||
|
throw new InvalidArgumentException('The number of groups must be exactly two.');
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($impurity < 0.) {
|
||||||
|
throw new InvalidArgumentException('Impurity cannot be less than 0.');
|
||||||
|
}
|
||||||
|
|
||||||
|
$this->column = $column;
|
||||||
|
$this->value = $value;
|
||||||
|
$this->groups = $groups;
|
||||||
|
$this->impurity = $impurity;
|
||||||
|
$this->samplesCount = (int) array_sum(array_map(function (array $group) {
|
||||||
|
return count($group[0]);
|
||||||
|
}, $groups));
|
||||||
|
}
|
||||||
|
|
||||||
|
public function column(): int
|
||||||
|
{
|
||||||
|
return $this->column;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return mixed
|
||||||
|
*/
|
||||||
|
public function value()
|
||||||
|
{
|
||||||
|
return $this->value;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function groups(): array
|
||||||
|
{
|
||||||
|
return $this->groups;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function impurity(): float
|
||||||
|
{
|
||||||
|
return $this->impurity;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function samplesCount(): int
|
||||||
|
{
|
||||||
|
return $this->samplesCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function purityIncrease(): float
|
||||||
|
{
|
||||||
|
$impurity = $this->impurity;
|
||||||
|
|
||||||
|
if ($this->left() instanceof PurityNode) {
|
||||||
|
$impurity -= $this->left()->impurity()
|
||||||
|
* ($this->left()->samplesCount() / $this->samplesCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
if ($this->right() instanceof PurityNode) {
|
||||||
|
$impurity -= $this->right()->impurity()
|
||||||
|
* ($this->right()->samplesCount() / $this->samplesCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
return $impurity;
|
||||||
|
}
|
||||||
|
|
||||||
|
public function cleanup(): void
|
||||||
|
{
|
||||||
|
$this->groups = [[], []];
|
||||||
|
}
|
||||||
|
}
|
9
src/Tree/Node/LeafNode.php
Normal file
9
src/Tree/Node/LeafNode.php
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tree\Node;
|
||||||
|
|
||||||
|
interface LeafNode
|
||||||
|
{
|
||||||
|
}
|
14
src/Tree/Node/PurityNode.php
Normal file
14
src/Tree/Node/PurityNode.php
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tree\Node;
|
||||||
|
|
||||||
|
use Phpml\Tree\Node;
|
||||||
|
|
||||||
|
interface PurityNode extends Node
|
||||||
|
{
|
||||||
|
public function impurity(): float;
|
||||||
|
|
||||||
|
public function samplesCount(): int;
|
||||||
|
}
|
68
tests/Regression/DecisionTreeRegressorTest.php
Normal file
68
tests/Regression/DecisionTreeRegressorTest.php
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tests\Regression;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidOperationException;
|
||||||
|
use Phpml\ModelManager;
|
||||||
|
use Phpml\Regression\DecisionTreeRegressor;
|
||||||
|
use PHPUnit\Framework\TestCase;
|
||||||
|
|
||||||
|
class DecisionTreeRegressorTest extends TestCase
|
||||||
|
{
|
||||||
|
public function testPredictSingleFeatureSamples(): void
|
||||||
|
{
|
||||||
|
$delta = 0.01;
|
||||||
|
|
||||||
|
$samples = [[60], [61], [62], [63], [65]];
|
||||||
|
$targets = [3.1, 3.6, 3.8, 4, 4.1];
|
||||||
|
|
||||||
|
$regression = new DecisionTreeRegressor(4);
|
||||||
|
$regression->train($samples, $targets);
|
||||||
|
|
||||||
|
self::assertEqualsWithDelta([4.05], $regression->predict([[64]]), $delta);
|
||||||
|
|
||||||
|
$samples = [[9300], [10565], [15000], [15000], [17764], [57000], [65940], [73676], [77006], [93739], [146088], [153260]];
|
||||||
|
$targets = [7100, 15500, 4400, 4400, 5900, 4600, 8800, 2000, 2750, 2550, 960, 1025];
|
||||||
|
|
||||||
|
$regression = new DecisionTreeRegressor();
|
||||||
|
$regression->train($samples, $targets);
|
||||||
|
|
||||||
|
self::assertEqualsWithDelta([11300.0], $regression->predict([[9300]]), $delta);
|
||||||
|
self::assertEqualsWithDelta([5250.0], $regression->predict([[57000]]), $delta);
|
||||||
|
self::assertEqualsWithDelta([2433.33], $regression->predict([[77006]]), $delta);
|
||||||
|
self::assertEqualsWithDelta([11300.0], $regression->predict([[9300]]), $delta);
|
||||||
|
self::assertEqualsWithDelta([992.5], $regression->predict([[153260]]), $delta);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testPreventPredictWhenNotTrained(): void
|
||||||
|
{
|
||||||
|
$regression = new DecisionTreeRegressor();
|
||||||
|
|
||||||
|
$this->expectException(InvalidOperationException::class);
|
||||||
|
|
||||||
|
$regression->predict([[1]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testSaveAndRestore(): void
|
||||||
|
{
|
||||||
|
$samples = [[60], [61], [62], [63], [65]];
|
||||||
|
$targets = [3.1, 3.6, 3.8, 4, 4.1];
|
||||||
|
|
||||||
|
$regression = new DecisionTreeRegressor(4);
|
||||||
|
$regression->train($samples, $targets);
|
||||||
|
|
||||||
|
$testSamples = [[9300], [10565], [15000]];
|
||||||
|
$predicted = $regression->predict($testSamples);
|
||||||
|
|
||||||
|
$filename = 'least-squares-test-'.random_int(100, 999).'-'.uniqid('', false);
|
||||||
|
$filepath = (string) tempnam(sys_get_temp_dir(), $filename);
|
||||||
|
$modelManager = new ModelManager();
|
||||||
|
$modelManager->saveToFile($regression, $filepath);
|
||||||
|
|
||||||
|
$restoredRegression = $modelManager->restoreFromFile($filepath);
|
||||||
|
self::assertEquals($regression, $restoredRegression);
|
||||||
|
self::assertEquals($predicted, $restoredRegression->predict($testSamples));
|
||||||
|
}
|
||||||
|
}
|
47
tests/Tree/Node/BinaryNodeTest.php
Normal file
47
tests/Tree/Node/BinaryNodeTest.php
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tests\Tree\Node;
|
||||||
|
|
||||||
|
use Phpml\Tree\Node\BinaryNode;
|
||||||
|
use PHPUnit\Framework\TestCase;
|
||||||
|
|
||||||
|
final class BinaryNodeTest extends TestCase
|
||||||
|
{
|
||||||
|
public function testSimpleNode(): void
|
||||||
|
{
|
||||||
|
$node = new BinaryNode();
|
||||||
|
|
||||||
|
self::assertEquals(1, $node->height());
|
||||||
|
self::assertEquals(0, $node->balance());
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testAttachDetachLeft(): void
|
||||||
|
{
|
||||||
|
$node = new BinaryNode();
|
||||||
|
$node->attachLeft(new BinaryNode());
|
||||||
|
|
||||||
|
self::assertEquals(2, $node->height());
|
||||||
|
self::assertEquals(-1, $node->balance());
|
||||||
|
|
||||||
|
$node->detachLeft();
|
||||||
|
|
||||||
|
self::assertEquals(1, $node->height());
|
||||||
|
self::assertEquals(0, $node->balance());
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testAttachDetachRight(): void
|
||||||
|
{
|
||||||
|
$node = new BinaryNode();
|
||||||
|
$node->attachRight(new BinaryNode());
|
||||||
|
|
||||||
|
self::assertEquals(2, $node->height());
|
||||||
|
self::assertEquals(1, $node->balance());
|
||||||
|
|
||||||
|
$node->detachRight();
|
||||||
|
|
||||||
|
self::assertEquals(1, $node->height());
|
||||||
|
self::assertEquals(0, $node->balance());
|
||||||
|
}
|
||||||
|
}
|
57
tests/Tree/Node/DecisionNodeTest.php
Normal file
57
tests/Tree/Node/DecisionNodeTest.php
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
<?php
|
||||||
|
|
||||||
|
declare(strict_types=1);
|
||||||
|
|
||||||
|
namespace Phpml\Tests\Tree\Node;
|
||||||
|
|
||||||
|
use Phpml\Exception\InvalidArgumentException;
|
||||||
|
use Phpml\Tree\Node\DecisionNode;
|
||||||
|
use PHPUnit\Framework\TestCase;
|
||||||
|
|
||||||
|
final class DecisionNodeTest extends TestCase
|
||||||
|
{
|
||||||
|
public function testSimpleNode(): void
|
||||||
|
{
|
||||||
|
$node = new DecisionNode(2, 4, [
|
||||||
|
[[[1, 2, 3]], [1]],
|
||||||
|
[[[2, 3, 4]], [2]],
|
||||||
|
], 400);
|
||||||
|
|
||||||
|
self::assertEquals(2, $node->column());
|
||||||
|
self::assertEquals(2, $node->samplesCount());
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testImpurityIncrease(): void
|
||||||
|
{
|
||||||
|
$node = new DecisionNode(2, 4, [
|
||||||
|
[[[1, 2, 3]], [1]],
|
||||||
|
[[[2, 3, 4]], [2]],
|
||||||
|
], 400);
|
||||||
|
|
||||||
|
$node->attachRight(new DecisionNode(2, 4, [
|
||||||
|
[[[1, 2, 3]], [1]],
|
||||||
|
[[[2, 3, 4]], [2]],
|
||||||
|
], 200));
|
||||||
|
|
||||||
|
$node->attachLeft(new DecisionNode(2, 4, [
|
||||||
|
[[[1, 2, 3]], [1]],
|
||||||
|
[[[2, 3, 4]], [2]],
|
||||||
|
], 100));
|
||||||
|
|
||||||
|
self::assertEquals(100, $node->purityIncrease());
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testThrowExceptionOnInvalidGroupsCount(): void
|
||||||
|
{
|
||||||
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
|
||||||
|
new DecisionNode(2, 3, [], 200);
|
||||||
|
}
|
||||||
|
|
||||||
|
public function testThrowExceptionOnInvalidImpurity(): void
|
||||||
|
{
|
||||||
|
$this->expectException(InvalidArgumentException::class);
|
||||||
|
|
||||||
|
new DecisionNode(2, 3, [[], []], -2);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user