diff --git a/src/Phpml/Classification/DecisionTree.php b/src/Phpml/Classification/DecisionTree.php
new file mode 100644
index 0000000..033b22b
--- /dev/null
+++ b/src/Phpml/Classification/DecisionTree.php
@@ -0,0 +1,274 @@
+maxDepth = $maxDepth;
+ }
+ /**
+  * Trains the decision tree on the given samples and their class labels.
+  *
+  * Samples and targets are re-indexed to consecutive integer keys (0..n-1)
+  * before use, because the tree-building code addresses records by position.
+  *
+  * @param array $samples List of feature rows; the column count is taken from the first row
+  * @param array $targets Class label of each sample, in the same order as $samples
+  *
+  * @throws \InvalidArgumentException If $samples is empty or its size differs from that of $targets
+  */
+ public function train(array $samples, array $targets)
+ {
+     if (count($samples) === 0) {
+         // Without this guard range(0, -1) below would yield the bogus record numbers [0, -1].
+         throw new \InvalidArgumentException('Cannot train a decision tree without samples');
+     }
+     if (count($samples) !== count($targets)) {
+         throw new \InvalidArgumentException('Number of samples and targets must be equal');
+     }
+
+     // Record numbers are generated with range(), so both arrays must be plain lists.
+     $samples = array_values($samples);
+     $targets = array_values($targets);
+
+     $this->featureCount = count($samples[0]);
+     $this->columnTypes = $this->getColumnTypes($samples);
+     $this->samples = $samples;
+     $this->targets = $targets;
+     $this->labels = array_keys(array_count_values($targets));
+     $this->tree = $this->getSplitLeaf(range(0, count($samples) - 1));
+ }
+
+ /**
+  * Classifies every feature column of the given samples as nominal or continuous.
+  *
+  * @param array $samples Rows whose columns 0..featureCount-1 are inspected
+  *
+  * @return array One self::NOMINAL or self::CONTINUOS entry per feature column
+  */
+ protected function getColumnTypes(array $samples)
+ {
+     $columnTypes = [];
+     $column = 0;
+     while ($column < $this->featureCount) {
+         $columnValues = array_column($samples, $column);
+         if ($this->isCategoricalColumn($columnValues)) {
+             $columnTypes[] = self::NOMINAL;
+         } else {
+             $columnTypes[] = self::CONTINUOS;
+         }
+         $column++;
+     }
+     return $columnTypes;
+ }
+
+ /**
+  * Recursively builds the (sub)tree for the given training records.
+  *
+  * The best split for the records is chosen first. The node becomes terminal
+  * when all records share one class, when all records are identical, or when
+  * the maximum depth is reached; otherwise the records are distributed to the
+  * left (split condition holds) and right (it does not) children.
+  *
+  * @param array $records Indices into the training samples/targets handled by this node
+  * @param int   $depth   Depth of the node being created (0 for the root)
+  *
+  * @return DecisionTreeLeaf
+  */
+ protected function getSplitLeaf($records, $depth = 0)
+ {
+     $split = $this->getBestSplit($records);
+     $split->level = $depth;
+     if ($this->actualDepth < $depth) {
+         $this->actualDepth = $depth;
+     }
+     $leftRecords = [];
+     $rightRecords = [];
+     // Number of records per class label at this node (label => count).
+     $targetCounts = [];
+     $prevRecord = null;
+     $allSame = true;
+     foreach ($records as $recordNo) {
+         $record = $this->samples[$recordNo];
+         // Explicit null check: a falsy (e.g. empty) previous record must still be compared.
+         if ($prevRecord !== null && $prevRecord != $record) {
+             $allSame = false;
+         }
+         $prevRecord = $record;
+         if ($split->evaluate($record)) {
+             $leftRecords[] = $recordNo;
+         } else {
+             $rightRecords[] = $recordNo;
+         }
+         $target = $this->targets[$recordNo];
+         if (isset($targetCounts[$target])) {
+             $targetCounts[$target]++;
+         } else {
+             $targetCounts[$target] = 1;
+         }
+     }
+
+     if (count($targetCounts) == 1 || $allSame || $depth >= $this->maxDepth) {
+         $split->isTerminal = 1;
+         // Majority vote: the previous code counted a de-duplicated label list, so every
+         // count was 1 and the first label seen won regardless of its real frequency.
+         arsort($targetCounts);
+         $split->classValue = key($targetCounts);
+     } else {
+         if ($leftRecords) {
+             $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
+         }
+         if ($rightRecords) {
+             $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
+         }
+     }
+     return $split;
+ }
+
+ /**
+  * Picks the feature column whose split yields the lowest Gini index for the given records.
+  *
+  * The records are preprocessed first, then each column is split on its first
+  * non-null value. On equal Gini values the column with the lower index wins.
+  *
+  * @param array $records Indices into the training samples/targets to consider
+  *
+  * @return DecisionTreeLeaf|null The best split, or null when there are no feature columns
+  */
+ protected function getBestSplit($records)
+ {
+     $recordKeys = array_flip($records);
+     $nodeTargets = array_intersect_key($this->targets, $recordKeys);
+     $nodeSamples = array_intersect_key($this->samples, $recordKeys);
+     // preprocess() returns a 0-indexed list; restore the record numbers as keys.
+     $nodeSamples = array_combine($records, $this->preprocess($nodeSamples));
+
+     $lowestGini = 1;
+     $winner = null;
+     $column = 0;
+     while ($column < $this->featureCount) {
+         $columnValues = [];
+         $splitValue = null;
+         foreach ($nodeSamples as $recordNo => $row) {
+             $columnValues[$recordNo] = $row[$column];
+             if (null === $splitValue) {
+                 $splitValue = $row[$column];
+             }
+         }
+
+         $gini = $this->getGiniIndex($splitValue, $columnValues, $nodeTargets);
+         if (null === $winner || $gini < $lowestGini) {
+             $candidate = new DecisionTreeLeaf();
+             $candidate->value = $splitValue;
+             $candidate->giniIndex = $gini;
+             $candidate->columnIndex = $column;
+             $candidate->records = $records;
+             $winner = $candidate;
+             $lowestGini = $gini;
+         }
+         $column++;
+     }
+     return $winner;
+ }
+
+ /**
+  * Computes the weighted Gini index of splitting a column on a single value.
+  *
+  * Rows are divided into those whose value equals $baseValue (loose comparison)
+  * and all others. The impurity of each group is weighted by the group size and
+  * the sum is divided by the number of rows, so $colValues must not be empty.
+  *
+  * @param mixed $baseValue Value the column is split on
+  * @param array $colValues Column values, keyed by record number
+  * @param array $targets   Class labels, keyed by the same record numbers
+  *
+  * @return float
+  */
+ public function getGiniIndex($baseValue, $colValues, $targets)
+ {
+     // One [matching, non-matching] counter pair per known class label.
+     $tally = [];
+     foreach ($this->labels as $label) {
+         $tally[$label] = [0, 0];
+     }
+     foreach ($colValues as $rowKey => $cell) {
+         $side = 1;
+         if ($cell == $baseValue) {
+             $side = 0;
+         }
+         $tally[$targets[$rowKey]][$side]++;
+     }
+
+     $weighted = [0, 0];
+     foreach ([0, 1] as $side) {
+         $sideSize = array_sum(array_column($tally, $side));
+         $purity = 0;
+         if ($sideSize > 0) {
+             $divisor = floatval($sideSize);
+             foreach ($this->labels as $label) {
+                 $purity += pow($tally[$label][$side] / $divisor, 2);
+             }
+         }
+         $weighted[$side] = (1 - $purity) * $sideSize;
+     }
+     return array_sum($weighted) / count($colValues);
+ }
+
+ /**
+  * Converts the given samples into purely discrete rows.
+  *
+  * Columns flagged as continuous are binarised around the median of the values
+  * passed in: each value becomes the string "<= {median}" or "> {median}".
+  * All other columns are copied unchanged.
+  *
+  * @param array $samples Rows to convert
+  *
+  * @return array 0-indexed list of rows, each holding one value per feature column
+  */
+ protected function preprocess(array $samples)
+ {
+     // Detect and convert continuous data column values into
+     // discrete values by using the median as a threshold value
+     $columns = [];
+     for ($i = 0; $i < $this->featureCount; $i++) {
+         $values = array_column($samples, $i);
+         if ($this->columnTypes[$i] == self::CONTINUOS) {
+             $median = Mean::median($values);
+             // Values are written back by key instead of through a by-reference loop
+             // variable, which would stay bound to the last element after the loop.
+             foreach ($values as $key => $value) {
+                 $values[$key] = $value <= $median ? "<= $median" : "> $median";
+             }
+         }
+         $columns[] = $values;
+     }
+
+     // Transpose the columns back into rows. array_map(null, ...$columns) is avoided
+     // because with a single column it returns that column unchanged instead of a
+     // list of one-element rows, which breaks callers indexing $row[0].
+     $rowCount = $columns ? max(array_map('count', $columns)) : 0;
+     $rows = [];
+     for ($rowNo = 0; $rowNo < $rowCount; $rowNo++) {
+         $row = [];
+         foreach ($columns as $values) {
+             // Pad short columns with null, as the zip form of array_map() would.
+             $row[] = array_key_exists($rowNo, $values) ? $values[$rowNo] : null;
+         }
+         $rows[] = $row;
+     }
+     return $rows;
+ }
+
+ /**
+  * Guesses whether a column holds categorical (discrete) values.
+  *
+  * @param array $columnValues All values of one feature column
+  *
+  * @return bool True when the column contains a non-numeric value or when its number
+  *              of distinct values is at most 20% of its number of values
+  */
+ protected function isCategoricalColumn(array $columnValues)
+ {
+     $count = count($columnValues);
+     // There are two main indicators that *may* show whether a
+     // column is composed of discrete set of values:
+     // 1- Column may contain string values
+     // 2- Number of unique values in the column is only a small fraction of
+     //    all values in that column (Lower than or equal to %20 of all values)
+     $numericValues = array_filter($columnValues, 'is_numeric');
+     if (count($numericValues) != $count) {
+         return true;
+     }
+     // array_unique() is used instead of array_count_values(): the latter only counts
+     // integers and strings and skips floats, so an all-float column appeared to have
+     // zero distinct values and was always reported as categorical.
+     $distinctValues = array_unique($columnValues);
+     return count($distinctValues) <= $count / 5;
+ }
+
+ /**
+  * Returns the string produced by the root node's __toString() method.
+  *
+  * @return string
+  */
+ public function getHtml()
+ {
+     $root = $this->tree;
+     return $root->__toString();
+ }
+
+ /**
+  * Predicts the class of a single sample by walking the tree from the root.
+  *
+  * At each non-terminal node the sample goes to the left child when the node's
+  * split condition holds and to the right child otherwise.
+  *
+  * @param array $sample Feature values of the sample
+  *
+  * @return mixed Class value of the terminal node reached, or null when no terminal
+  *               node can be reached (no tree, or the chosen branch has no child)
+  */
+ protected function predictSample(array $sample)
+ {
+     $node = $this->tree;
+     while ($node && !$node->isTerminal) {
+         $node = $node->evaluate($sample) ? $node->leftLeaf : $node->rightLeaf;
+     }
+     if (!$node) {
+         // The walk fell off the tree; previously this dereferenced a null node.
+         return null;
+     }
+     return $node->classValue;
+ }
+}
diff --git a/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php b/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
new file mode 100644
index 0000000..220f876
--- /dev/null
+++ b/src/Phpml/Classification/DecisionTree/DecisionTreeLeaf.php
@@ -0,0 +1,106 @@
+columnIndex];
+ if (preg_match("/^([<>=]{1,2})\s*(.*)/", $this->value, $matches)) {
+ $op = $matches[1];
+ $value= floatval($matches[2]);
+ $recordField = strval($recordField);
+ eval("\$result = $recordField $op $value;");
+ return $result;
+ }
+ return $recordField == $this->value;
+ }
+
+ public function __toString()
+ {
+ if ($this->isTerminal) {
+ $value = "$this->classValue";
+ } else {
+ $value = $this->value;
+ $col = "col_$this->columnIndex";
+ if (! preg_match("/^[<>=]{1,2}/", $value)) {
+ $value = "=$value";
+ }
+ $value = "$col $value
Gini: ". number_format($this->giniIndex, 2);
+ }
+ $str = "
+ $value | ||||
| Yes $this->leftLeaf | ";
+ } else {
+ $str .=''; + } + $str .=' | '; + if ($this->rightLeaf) { + $str .=" | No | $this->rightLeaf | ";
+ } else {
+ $str .=''; + } + $str .= ' |