Support for multiple training datasets (#38)

* Multiple training data sets allowed

* Tests with multiple training data sets

* Updating docs according to #38

Documenting all models which predictions will be based on all
training data provided.

Some models already supported multiple training data sets.
This commit is contained in:
David Monllaó 2017-02-01 19:06:38 +01:00 committed by Arkadiusz Kondas
parent 6281da280f
commit c1b1a5d6ac
13 changed files with 61 additions and 32 deletions

View File

@ -27,6 +27,8 @@ $associator = new Apriori($support = 0.5, $confidence = 0.5);
$associator->train($samples, $labels); $associator->train($samples, $labels);
``` ```
You can train the associator using multiple data sets, predictions will be based on all the training data.
### Predict ### Predict
To predict sample label use `predict` method. You can provide one sample or array of samples: To predict sample label use `predict` method. You can provide one sample or array of samples:

View File

@ -24,6 +24,8 @@ $classifier = new KNearestNeighbors();
$classifier->train($samples, $labels); $classifier->train($samples, $labels);
``` ```
You can train the classifier using multiple data sets, predictions will be based on all the training data.
## Predict ## Predict
To predict sample label use `predict` method. You can provide one sample or array of samples: To predict sample label use `predict` method. You can provide one sample or array of samples:

View File

@ -14,6 +14,8 @@ $classifier = new NaiveBayes();
$classifier->train($samples, $labels); $classifier->train($samples, $labels);
``` ```
You can train the classifier using multiple data sets, predictions will be based on all the training data.
### Predict ### Predict
To predict sample label use `predict` method. You can provide one sample or array of samples: To predict sample label use `predict` method. You can provide one sample or array of samples:

View File

@ -34,6 +34,8 @@ $classifier = new SVC(Kernel::LINEAR, $cost = 1000);
$classifier->train($samples, $labels); $classifier->train($samples, $labels);
``` ```
You can train the classifier using multiple data sets, predictions will be based on all the training data.
### Predict ### Predict
To predict sample label use `predict` method. You can provide one sample or array of samples: To predict sample label use `predict` method. You can provide one sample or array of samples:

View File

@ -27,3 +27,4 @@ $training->train(
$maxIteraions = 30000 $maxIteraions = 30000
); );
``` ```
You can train the neural network using multiple data sets, predictions will be based on all the training data.

View File

@ -14,6 +14,8 @@ $regression = new LeastSquares();
$regression->train($samples, $targets); $regression->train($samples, $targets);
``` ```
You can train the model using multiple data sets, predictions will be based on all the training data.
### Predict ### Predict
To predict sample target value use `predict` method with sample to check (as `array`). Example: To predict sample target value use `predict` method with sample to check (as `array`). Example:

View File

@ -34,6 +34,8 @@ $regression = new SVR(Kernel::LINEAR);
$regression->train($samples, $targets); $regression->train($samples, $targets);
``` ```
You can train the model using multiple data sets, predictions will be based on all the training data.
### Predict ### Predict
To predict sample target value use `predict` method. You can provide one sample or array of samples: To predict sample target value use `predict` method. You can provide one sample or array of samples:

View File

@ -64,12 +64,13 @@ class DecisionTree implements Classifier
*/ */
public function train(array $samples, array $targets) public function train(array $samples, array $targets)
{ {
$this->featureCount = count($samples[0]); $this->samples = array_merge($this->samples, $samples);
$this->columnTypes = $this->getColumnTypes($samples); $this->targets = array_merge($this->targets, $targets);
$this->samples = $samples;
$this->targets = $targets; $this->featureCount = count($this->samples[0]);
$this->labels = array_keys(array_count_values($targets)); $this->columnTypes = $this->getColumnTypes($this->samples);
$this->tree = $this->getSplitLeaf(range(0, count($samples) - 1)); $this->labels = array_keys(array_count_values($this->targets));
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
} }
protected function getColumnTypes(array $samples) protected function getColumnTypes(array $samples)

View File

@ -63,12 +63,12 @@ class NaiveBayes implements Classifier
*/ */
public function train(array $samples, array $targets) public function train(array $samples, array $targets)
{ {
$this->samples = $samples; $this->samples = array_merge($this->samples, $samples);
$this->targets = $targets; $this->targets = array_merge($this->targets, $targets);
$this->sampleCount = count($samples); $this->sampleCount = count($this->samples);
$this->featureCount = count($samples[0]); $this->featureCount = count($this->samples[0]);
$labelCounts = array_count_values($targets); $labelCounts = array_count_values($this->targets);
$this->labels = array_keys($labelCounts); $this->labels = array_keys($labelCounts);
foreach ($this->labels as $label) { foreach ($this->labels as $label) {
$samples = $this->getSamplesByLabel($label); $samples = $this->getSamplesByLabel($label);

View File

@ -9,12 +9,12 @@ trait Trainable
/** /**
* @var array * @var array
*/ */
private $samples; private $samples = [];
/** /**
* @var array * @var array
*/ */
private $targets; private $targets = [];
/** /**
* @param array $samples * @param array $samples
@ -22,7 +22,7 @@ trait Trainable
*/ */
public function train(array $samples, array $targets) public function train(array $samples, array $targets)
{ {
$this->samples = $samples; $this->samples = array_merge($this->samples, $samples);
$this->targets = $targets; $this->targets = array_merge($this->targets, $targets);
} }
} }

View File

@ -13,12 +13,12 @@ class LeastSquares implements Regression
/** /**
* @var array * @var array
*/ */
private $samples; private $samples = [];
/** /**
* @var array * @var array
*/ */
private $targets; private $targets = [];
/** /**
* @var float * @var float
@ -36,8 +36,8 @@ class LeastSquares implements Regression
*/ */
public function train(array $samples, array $targets) public function train(array $samples, array $targets)
{ {
$this->samples = $samples; $this->samples = array_merge($this->samples, $samples);
$this->targets = $targets; $this->targets = array_merge($this->targets, $targets);
$this->computeCoefficients(); $this->computeCoefficients();
} }

View File

@ -8,7 +8,7 @@ use Phpml\Classification\DecisionTree;
class DecisionTreeTest extends \PHPUnit_Framework_TestCase class DecisionTreeTest extends \PHPUnit_Framework_TestCase
{ {
public $data = [ private $data = [
['sunny', 85, 85, 'false', 'Dont_play' ], ['sunny', 85, 85, 'false', 'Dont_play' ],
['sunny', 80, 90, 'true', 'Dont_play' ], ['sunny', 80, 90, 'true', 'Dont_play' ],
['overcast', 83, 78, 'false', 'Play' ], ['overcast', 83, 78, 'false', 'Play' ],
@ -25,34 +25,39 @@ class DecisionTreeTest extends \PHPUnit_Framework_TestCase
['rain', 71, 80, 'true', 'Dont_play' ] ['rain', 71, 80, 'true', 'Dont_play' ]
]; ];
public function getData() private $extraData = [
['scorching', 90, 95, 'false', 'Dont_play'],
['scorching', 100, 93, 'true', 'Dont_play'],
];
private function getData($input)
{ {
static $data = null, $targets = null; $targets = array_column($input, 4);
if ($data == null) { array_walk($input, function (&$v) {
$data = $this->data;
$targets = array_column($data, 4);
array_walk($data, function (&$v) {
array_splice($v, 4, 1); array_splice($v, 4, 1);
}); });
} return [$input, $targets];
return [$data, $targets];
} }
public function testPredictSingleSample() public function testPredictSingleSample()
{ {
list($data, $targets) = $this->getData(); list($data, $targets) = $this->getData($this->data);
$classifier = new DecisionTree(5); $classifier = new DecisionTree(5);
$classifier->train($data, $targets); $classifier->train($data, $targets);
$this->assertEquals('Dont_play', $classifier->predict(['sunny', 78, 72, 'false'])); $this->assertEquals('Dont_play', $classifier->predict(['sunny', 78, 72, 'false']));
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false'])); $this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
$this->assertEquals('Dont_play', $classifier->predict(['rain', 60, 60, 'true'])); $this->assertEquals('Dont_play', $classifier->predict(['rain', 60, 60, 'true']));
list($data, $targets) = $this->getData($this->extraData);
$classifier->train($data, $targets);
$this->assertEquals('Dont_play', $classifier->predict(['scorching', 95, 90, 'true']));
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
return $classifier; return $classifier;
} }
public function testTreeDepth() public function testTreeDepth()
{ {
list($data, $targets) = $this->getData(); list($data, $targets) = $this->getData($this->data);
$classifier = new DecisionTree(5); $classifier = new DecisionTree(5);
$classifier->train($data, $targets); $classifier->train($data, $targets);
$this->assertTrue(5 >= $classifier->actualDepth); $this->assertTrue(5 >= $classifier->actualDepth);

View File

@ -34,5 +34,15 @@ class NaiveBayesTest extends \PHPUnit_Framework_TestCase
$predicted = $classifier->predict($testSamples); $predicted = $classifier->predict($testSamples);
$this->assertEquals($testLabels, $predicted); $this->assertEquals($testLabels, $predicted);
// Feed an extra set of training data.
$samples = [[1, 1, 6]];
$labels = ['d'];
$classifier->train($samples, $labels);
$testSamples = [[1, 1, 6], [5, 1, 1]];
$testLabels = ['d', 'a'];
$this->assertEquals($testLabels, $classifier->predict($testSamples));
} }
} }