2017-01-31 19:27:15 +00:00
|
|
|
<?php
|
2017-01-31 19:33:08 +00:00
|
|
|
|
2017-01-31 19:27:15 +00:00
|
|
|
declare(strict_types=1);
|
|
|
|
|
|
|
|
namespace Phpml\Classification\DecisionTree;
|
|
|
|
|
2017-10-24 16:59:12 +00:00
|
|
|
use Phpml\Math\Comparison;
|
|
|
|
|
2017-01-31 19:27:15 +00:00
|
|
|
class DecisionTreeLeaf
{
    /**
     * Split value for categorical nodes; also the label rendered by getHTML().
     * It may already start with a comparison operator (e.g. "<= 2"), in which
     * case getHTML() does not prepend "=".
     *
     * @var string|int
     */
    public $value;

    /**
     * Threshold the record field is compared against when the node is continuous.
     *
     * @var float
     */
    public $numericValue;

    /**
     * Comparison operator passed to Comparison::compare() for continuous nodes.
     *
     * @var string
     */
    public $operator;

    /**
     * Index of the record column tested by this node.
     *
     * @var int
     */
    public $columnIndex;

    /**
     * Child rendered as the "Yes" branch by getHTML(); null when absent.
     *
     * @var DecisionTreeLeaf|null
     */
    public $leftLeaf = null;

    /**
     * Child rendered as the "No" branch by getHTML(); null when absent.
     *
     * @var DecisionTreeLeaf|null
     */
    public $rightLeaf = null;

    /**
     * Records assigned to this node; this class only uses their count
     * (as the node's sample count in getNodeImpurityDecrease()).
     *
     * @var array
     */
    public $records = [];

    /**
     * Class value represented by the leaf, this value is non-empty
     * only for terminal leaves
     *
     * @var string
     */
    public $classValue = '';

    /**
     * Truthy for terminal nodes (leaves that carry a classValue).
     *
     * @var bool|int
     */
    public $isTerminal = false;

    /**
     * Whether the split is numeric (evaluated with $operator/$numericValue)
     * rather than an equality test against $value.
     *
     * @var bool
     */
    public $isContinuous = false;

    /**
     * Gini impurity of this node, used by getNodeImpurityDecrease() and
     * shown by getHTML().
     *
     * @var float
     */
    public $giniIndex = 0;

    /**
     * Depth of the node. Not read by this class (presumably maintained by the
     * tree builder — verify against caller).
     *
     * @var int
     */
    public $level = 0;

    /**
     * HTML representation of the tree without column names
     */
    public function __toString(): string
    {
        return $this->getHTML();
    }

    /**
     * Tests a single record against this node's split condition.
     *
     * Continuous nodes compare the record field (cast to string) with
     * $numericValue using $operator; categorical nodes use loose equality
     * (==) with $value.
     *
     * @param array $record sample row, indexed by column
     */
    public function evaluate(array $record): bool
    {
        $recordField = $record[$this->columnIndex];

        if ($this->isContinuous) {
            return Comparison::compare((string) $recordField, $this->numericValue, $this->operator);
        }

        // Loose comparison kept on purpose: record fields and split values may be of
        // different scalar types (e.g. '1' vs 1). NOTE(review): confirm before using ===.
        return $recordField == $this->value;
    }

    /**
     * Returns Mean Decrease Impurity (MDI) in the node.
     *
     * Computed as
     *   (gini - pL * gini_left - pR * gini_right) * n_node / $parentRecordCount
     * where n_node is the number of records in this node and pL/pR are the
     * fractions of those records held by the left/right child.
     *
     * The result is 0 for terminal nodes and for nodes without records (their
     * weight n_node / $parentRecordCount is 0). A non-positive
     * $parentRecordCount also yields 0 instead of a division by zero.
     *
     * @param int $parentRecordCount record count used to normalise the node's weight
     */
    public function getNodeImpurityDecrease(int $parentRecordCount): float
    {
        if ($this->isTerminal) {
            return 0.0;
        }

        $nodeSampleCount = (float) count($this->records);

        // Bail out before the child-fraction divisions below could divide by zero.
        if ($nodeSampleCount === 0.0 || $parentRecordCount <= 0) {
            return 0.0;
        }

        $iT = $this->giniIndex;

        if ($this->leftLeaf instanceof self) {
            $pL = count($this->leftLeaf->records) / $nodeSampleCount;
            $iT -= $pL * $this->leftLeaf->giniIndex;
        }

        if ($this->rightLeaf instanceof self) {
            $pR = count($this->rightLeaf->records) / $nodeSampleCount;
            $iT -= $pR * $this->rightLeaf->giniIndex;
        }

        return $iT * $nodeSampleCount / $parentRecordCount;
    }

    /**
     * Returns HTML representation of the node including children nodes
     *
     * @param array|null $columnNames optional column labels indexed like the records;
     *                                a missing entry falls back to "col_<index>"
     */
    public function getHTML($columnNames = null): string
    {
        if ($this->isTerminal) {
            $value = "<b>$this->classValue</b>";
        } else {
            $value = $this->value;
            // `??` covers both "no names given" and "no name for this column".
            $col = $columnNames[$this->columnIndex] ?? "col_$this->columnIndex";

            // Values that already start with a comparison operator (e.g. "<= 2") are shown as-is.
            if (!preg_match('/^[<>=]{1,2}/', (string) $value)) {
                $value = "=$value";
            }

            $value = "<b>$col $value</b><br>Gini: ".number_format($this->giniIndex, 2);
        }

        // NOTE(review): values and column names are interpolated without HTML escaping;
        // escape them if they can contain untrusted or markup-significant text.
        $str = "<table ><tr><td colspan=3 align=center style='border:1px solid;'>$value</td></tr>";

        if ($this->leftLeaf || $this->rightLeaf) {
            $str .= '<tr>';
            if ($this->leftLeaf) {
                $str .= '<td valign=top><b>| Yes</b><br>'.$this->leftLeaf->getHTML($columnNames).'</td>';
            } else {
                $str .= '<td></td>';
            }

            $str .= '<td> </td>';
            if ($this->rightLeaf) {
                $str .= '<td valign=top align=right><b>No |</b><br>'.$this->rightLeaf->getHTML($columnNames).'</td>';
            } else {
                $str .= '<td></td>';
            }

            $str .= '</tr>';
        }

        $str .= '</table>';

        return $str;
    }
}
|