From c1b1a5d6ac368ad9b71240058596b356d8e71d55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Monlla=C3=B3?= Date: Wed, 1 Feb 2017 19:06:38 +0100 Subject: [PATCH] Support for multiple training datasets (#38) * Multiple training data sets allowed * Tests with multiple training data sets * Updating docs according to #38 Documenting all models which predictions will be based on all training data provided. Some models already supported multiple training data sets. --- docs/machine-learning/association/apriori.md | 2 ++ .../classification/k-nearest-neighbors.md | 2 ++ .../classification/naive-bayes.md | 2 ++ docs/machine-learning/classification/svc.md | 2 ++ .../neural-network/backpropagation.md | 1 + .../regression/least-squares.md | 2 ++ docs/machine-learning/regression/svr.md | 2 ++ src/Phpml/Classification/DecisionTree.php | 13 ++++---- src/Phpml/Classification/NaiveBayes.php | 10 +++--- src/Phpml/Helper/Trainable.php | 8 ++--- src/Phpml/Regression/LeastSquares.php | 8 ++--- .../Phpml/Classification/DecisionTreeTest.php | 31 +++++++++++-------- tests/Phpml/Classification/NaiveBayesTest.php | 10 ++++++ 13 files changed, 61 insertions(+), 32 deletions(-) diff --git a/docs/machine-learning/association/apriori.md b/docs/machine-learning/association/apriori.md index 544406e..e6685af 100644 --- a/docs/machine-learning/association/apriori.md +++ b/docs/machine-learning/association/apriori.md @@ -27,6 +27,8 @@ $associator = new Apriori($support = 0.5, $confidence = 0.5); $associator->train($samples, $labels); ``` +You can train the associator using multiple data sets, predictions will be based on all the training data. + ### Predict To predict sample label use `predict` method. You can provide one sample or array of samples: diff --git a/docs/machine-learning/classification/k-nearest-neighbors.md b/docs/machine-learning/classification/k-nearest-neighbors.md index 6e70c61..a4eb96c 100644 --- a/docs/machine-learning/classification/k-nearest-neighbors.md +++ b/docs/machine-learning/classification/k-nearest-neighbors.md @@ -24,6 +24,8 @@ $classifier = new KNearestNeighbors(); $classifier->train($samples, $labels); ``` +You can train the classifier using multiple data sets, predictions will be based on all the training data. + ## Predict To predict sample label use `predict` method. You can provide one sample or array of samples: diff --git a/docs/machine-learning/classification/naive-bayes.md b/docs/machine-learning/classification/naive-bayes.md index e990321..410fd45 100644 --- a/docs/machine-learning/classification/naive-bayes.md +++ b/docs/machine-learning/classification/naive-bayes.md @@ -14,6 +14,8 @@ $classifier = new NaiveBayes(); $classifier->train($samples, $labels); ``` +You can train the classifier using multiple data sets, predictions will be based on all the training data. + ### Predict To predict sample label use `predict` method. You can provide one sample or array of samples: diff --git a/docs/machine-learning/classification/svc.md b/docs/machine-learning/classification/svc.md index d502dac..62da509 100644 --- a/docs/machine-learning/classification/svc.md +++ b/docs/machine-learning/classification/svc.md @@ -34,6 +34,8 @@ $classifier = new SVC(Kernel::LINEAR, $cost = 1000); $classifier->train($samples, $labels); ``` +You can train the classifier using multiple data sets, predictions will be based on all the training data. + ### Predict To predict sample label use `predict` method. You can provide one sample or array of samples: diff --git a/docs/machine-learning/neural-network/backpropagation.md b/docs/machine-learning/neural-network/backpropagation.md index 8c9b560..0582351 100644 --- a/docs/machine-learning/neural-network/backpropagation.md +++ b/docs/machine-learning/neural-network/backpropagation.md @@ -27,3 +27,4 @@ $training->train( $maxIteraions = 30000 ); ``` +You can train the neural network using multiple data sets, predictions will be based on all the training data. diff --git a/docs/machine-learning/regression/least-squares.md b/docs/machine-learning/regression/least-squares.md index 4a00bcd..84a3279 100644 --- a/docs/machine-learning/regression/least-squares.md +++ b/docs/machine-learning/regression/least-squares.md @@ -14,6 +14,8 @@ $regression = new LeastSquares(); $regression->train($samples, $targets); ``` +You can train the model using multiple data sets, predictions will be based on all the training data. + ### Predict To predict sample target value use `predict` method with sample to check (as `array`). Example: diff --git a/docs/machine-learning/regression/svr.md b/docs/machine-learning/regression/svr.md index ed2d10f..ba6bd74 100644 --- a/docs/machine-learning/regression/svr.md +++ b/docs/machine-learning/regression/svr.md @@ -34,6 +34,8 @@ $regression = new SVR(Kernel::LINEAR); $regression->train($samples, $targets); ``` +You can train the model using multiple data sets, predictions will be based on all the training data. + ### Predict To predict sample target value use `predict` method. You can provide one sample or array of samples: diff --git a/src/Phpml/Classification/DecisionTree.php b/src/Phpml/Classification/DecisionTree.php index 45b6329..1a39cbe 100644 --- a/src/Phpml/Classification/DecisionTree.php +++ b/src/Phpml/Classification/DecisionTree.php @@ -64,12 +64,13 @@ class DecisionTree implements Classifier */ public function train(array $samples, array $targets) { - $this->featureCount = count($samples[0]); - $this->columnTypes = $this->getColumnTypes($samples); - $this->samples = $samples; - $this->targets = $targets; - $this->labels = array_keys(array_count_values($targets)); - $this->tree = $this->getSplitLeaf(range(0, count($samples) - 1)); + $this->samples = array_merge($this->samples, $samples); + $this->targets = array_merge($this->targets, $targets); + + $this->featureCount = count($this->samples[0]); + $this->columnTypes = $this->getColumnTypes($this->samples); + $this->labels = array_keys(array_count_values($this->targets)); + $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1)); } protected function getColumnTypes(array $samples) diff --git a/src/Phpml/Classification/NaiveBayes.php b/src/Phpml/Classification/NaiveBayes.php index 2596ada..af81b00 100644 --- a/src/Phpml/Classification/NaiveBayes.php +++ b/src/Phpml/Classification/NaiveBayes.php @@ -63,12 +63,12 @@ class NaiveBayes implements Classifier */ public function train(array $samples, array $targets) { - $this->samples = $samples; - $this->targets = $targets; - $this->sampleCount = count($samples); - $this->featureCount = count($samples[0]); + $this->samples = array_merge($this->samples, $samples); + $this->targets = array_merge($this->targets, $targets); + $this->sampleCount = count($this->samples); + $this->featureCount = count($this->samples[0]); - $labelCounts = array_count_values($targets); + $labelCounts = array_count_values($this->targets); $this->labels = array_keys($labelCounts); foreach ($this->labels as $label) { $samples = $this->getSamplesByLabel($label); diff --git a/src/Phpml/Helper/Trainable.php b/src/Phpml/Helper/Trainable.php index e58a1da..3d011ac 100644 --- a/src/Phpml/Helper/Trainable.php +++ b/src/Phpml/Helper/Trainable.php @@ -9,12 +9,12 @@ trait Trainable /** * @var array */ - private $samples; + private $samples = []; /** * @var array */ - private $targets; + private $targets = []; /** * @param array $samples @@ -22,7 +22,7 @@ trait Trainable */ public function train(array $samples, array $targets) { - $this->samples = $samples; - $this->targets = $targets; + $this->samples = array_merge($this->samples, $samples); + $this->targets = array_merge($this->targets, $targets); } } diff --git a/src/Phpml/Regression/LeastSquares.php b/src/Phpml/Regression/LeastSquares.php index 19609fb..1b664ed 100644 --- a/src/Phpml/Regression/LeastSquares.php +++ b/src/Phpml/Regression/LeastSquares.php @@ -13,12 +13,12 @@ class LeastSquares implements Regression /** * @var array */ - private $samples; + private $samples = []; /** * @var array */ - private $targets; + private $targets = []; /** * @var float @@ -36,8 +36,8 @@ class LeastSquares implements Regression */ public function train(array $samples, array $targets) { - $this->samples = $samples; - $this->targets = $targets; + $this->samples = array_merge($this->samples, $samples); + $this->targets = array_merge($this->targets, $targets); $this->computeCoefficients(); } diff --git a/tests/Phpml/Classification/DecisionTreeTest.php b/tests/Phpml/Classification/DecisionTreeTest.php index c6f307d..25fb94c 100644 --- a/tests/Phpml/Classification/DecisionTreeTest.php +++ b/tests/Phpml/Classification/DecisionTreeTest.php @@ -8,7 +8,7 @@ use Phpml\Classification\DecisionTree; class DecisionTreeTest extends \PHPUnit_Framework_TestCase { - public $data = [ + private $data = [ ['sunny', 85, 85, 'false', 'Dont_play' ], ['sunny', 80, 90, 'true', 'Dont_play' ], ['overcast', 83, 78, 'false', 'Play' ], @@ -25,34 +25,39 @@ class DecisionTreeTest extends \PHPUnit_Framework_TestCase ['rain', 71, 80, 'true', 'Dont_play' ] ]; - public function getData() + private $extraData = [ + ['scorching', 90, 95, 'false', 'Dont_play'], + ['scorching', 100, 93, 'true', 'Dont_play'], + ]; + + private function getData($input) { - static $data = null, $targets = null; - if ($data == null) { - $data = $this->data; - $targets = array_column($data, 4); - array_walk($data, function (&$v) { - array_splice($v, 4, 1); - }); - } - return [$data, $targets]; + $targets = array_column($input, 4); + array_walk($input, function (&$v) { + array_splice($v, 4, 1); + }); + return [$input, $targets]; } public function testPredictSingleSample() { - list($data, $targets) = $this->getData(); + list($data, $targets) = $this->getData($this->data); $classifier = new DecisionTree(5); $classifier->train($data, $targets); $this->assertEquals('Dont_play', $classifier->predict(['sunny', 78, 72, 'false'])); $this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false'])); $this->assertEquals('Dont_play', $classifier->predict(['rain', 60, 60, 'true'])); + list($data, $targets) = $this->getData($this->extraData); + $classifier->train($data, $targets); + $this->assertEquals('Dont_play', $classifier->predict(['scorching', 95, 90, 'true'])); + $this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false'])); return $classifier; } public function testTreeDepth() { - list($data, $targets) = $this->getData(); + list($data, $targets) = $this->getData($this->data); $classifier = new DecisionTree(5); $classifier->train($data, $targets); $this->assertTrue(5 >= $classifier->actualDepth); diff --git a/tests/Phpml/Classification/NaiveBayesTest.php b/tests/Phpml/Classification/NaiveBayesTest.php index f2edb02..1a0aa1f 100644 --- a/tests/Phpml/Classification/NaiveBayesTest.php +++ b/tests/Phpml/Classification/NaiveBayesTest.php @@ -34,5 +34,15 @@ class NaiveBayesTest extends \PHPUnit_Framework_TestCase $predicted = $classifier->predict($testSamples); $this->assertEquals($testLabels, $predicted); + + // Feed an extra set of training data. + $samples = [[1, 1, 6]]; + $labels = ['d']; + $classifier->train($samples, $labels); + + $testSamples = [[1, 1, 6], [5, 1, 1]]; + $testLabels = ['d', 'a']; + $this->assertEquals($testLabels, $classifier->predict($testSamples)); + } }