Support for multiple training datasets (#38)

* Multiple training data sets allowed

* Tests with multiple training data sets

* Updating docs according to #38

Documenting all models which predictions will be based on all
training data provided.

Some models already supported multiple training data sets.
This commit is contained in:
David Monllaó 2017-02-01 19:06:38 +01:00 committed by Arkadiusz Kondas
parent 6281da280f
commit c1b1a5d6ac
13 changed files with 61 additions and 32 deletions

View File

@ -27,6 +27,8 @@ $associator = new Apriori($support = 0.5, $confidence = 0.5);
$associator->train($samples, $labels);
```
You can train the associator using multiple data sets, predictions will be based on all the training data.
### Predict
To predict sample label use `predict` method. You can provide one sample or array of samples:

View File

@ -24,6 +24,8 @@ $classifier = new KNearestNeighbors();
$classifier->train($samples, $labels);
```
You can train the classifier using multiple data sets, predictions will be based on all the training data.
## Predict
To predict sample label use `predict` method. You can provide one sample or array of samples:

View File

@ -14,6 +14,8 @@ $classifier = new NaiveBayes();
$classifier->train($samples, $labels);
```
You can train the classifier using multiple data sets, predictions will be based on all the training data.
### Predict
To predict sample label use `predict` method. You can provide one sample or array of samples:

View File

@ -34,6 +34,8 @@ $classifier = new SVC(Kernel::LINEAR, $cost = 1000);
$classifier->train($samples, $labels);
```
You can train the classifier using multiple data sets, predictions will be based on all the training data.
### Predict
To predict sample label use `predict` method. You can provide one sample or array of samples:

View File

@ -27,3 +27,4 @@ $training->train(
$maxIteraions = 30000
);
```
You can train the neural network using multiple data sets, predictions will be based on all the training data.

View File

@ -14,6 +14,8 @@ $regression = new LeastSquares();
$regression->train($samples, $targets);
```
You can train the model using multiple data sets, predictions will be based on all the training data.
### Predict
To predict sample target value use `predict` method with sample to check (as `array`). Example:

View File

@ -34,6 +34,8 @@ $regression = new SVR(Kernel::LINEAR);
$regression->train($samples, $targets);
```
You can train the model using multiple data sets, predictions will be based on all the training data.
### Predict
To predict sample target value use `predict` method. You can provide one sample or array of samples:

View File

@ -64,12 +64,13 @@ class DecisionTree implements Classifier
*/
public function train(array $samples, array $targets)
{
$this->featureCount = count($samples[0]);
$this->columnTypes = $this->getColumnTypes($samples);
$this->samples = $samples;
$this->targets = $targets;
$this->labels = array_keys(array_count_values($targets));
$this->tree = $this->getSplitLeaf(range(0, count($samples) - 1));
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);
$this->featureCount = count($this->samples[0]);
$this->columnTypes = $this->getColumnTypes($this->samples);
$this->labels = array_keys(array_count_values($this->targets));
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
}
protected function getColumnTypes(array $samples)

View File

@ -63,12 +63,12 @@ class NaiveBayes implements Classifier
*/
public function train(array $samples, array $targets)
{
$this->samples = $samples;
$this->targets = $targets;
$this->sampleCount = count($samples);
$this->featureCount = count($samples[0]);
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);
$this->sampleCount = count($this->samples);
$this->featureCount = count($this->samples[0]);
$labelCounts = array_count_values($targets);
$labelCounts = array_count_values($this->targets);
$this->labels = array_keys($labelCounts);
foreach ($this->labels as $label) {
$samples = $this->getSamplesByLabel($label);

View File

@ -9,12 +9,12 @@ trait Trainable
/**
* @var array
*/
private $samples;
private $samples = [];
/**
* @var array
*/
private $targets;
private $targets = [];
/**
* @param array $samples
@ -22,7 +22,7 @@ trait Trainable
*/
public function train(array $samples, array $targets)
{
$this->samples = $samples;
$this->targets = $targets;
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);
}
}

View File

@ -13,12 +13,12 @@ class LeastSquares implements Regression
/**
* @var array
*/
private $samples;
private $samples = [];
/**
* @var array
*/
private $targets;
private $targets = [];
/**
* @var float
@ -36,8 +36,8 @@ class LeastSquares implements Regression
*/
public function train(array $samples, array $targets)
{
$this->samples = $samples;
$this->targets = $targets;
$this->samples = array_merge($this->samples, $samples);
$this->targets = array_merge($this->targets, $targets);
$this->computeCoefficients();
}

View File

@ -8,7 +8,7 @@ use Phpml\Classification\DecisionTree;
class DecisionTreeTest extends \PHPUnit_Framework_TestCase
{
public $data = [
private $data = [
['sunny', 85, 85, 'false', 'Dont_play' ],
['sunny', 80, 90, 'true', 'Dont_play' ],
['overcast', 83, 78, 'false', 'Play' ],
@ -25,34 +25,39 @@ class DecisionTreeTest extends \PHPUnit_Framework_TestCase
['rain', 71, 80, 'true', 'Dont_play' ]
];
public function getData()
private $extraData = [
['scorching', 90, 95, 'false', 'Dont_play'],
['scorching', 100, 93, 'true', 'Dont_play'],
];
private function getData($input)
{
static $data = null, $targets = null;
if ($data == null) {
$data = $this->data;
$targets = array_column($data, 4);
array_walk($data, function (&$v) {
array_splice($v, 4, 1);
});
}
return [$data, $targets];
$targets = array_column($input, 4);
array_walk($input, function (&$v) {
array_splice($v, 4, 1);
});
return [$input, $targets];
}
public function testPredictSingleSample()
{
list($data, $targets) = $this->getData();
list($data, $targets) = $this->getData($this->data);
$classifier = new DecisionTree(5);
$classifier->train($data, $targets);
$this->assertEquals('Dont_play', $classifier->predict(['sunny', 78, 72, 'false']));
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
$this->assertEquals('Dont_play', $classifier->predict(['rain', 60, 60, 'true']));
list($data, $targets) = $this->getData($this->extraData);
$classifier->train($data, $targets);
$this->assertEquals('Dont_play', $classifier->predict(['scorching', 95, 90, 'true']));
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
return $classifier;
}
public function testTreeDepth()
{
list($data, $targets) = $this->getData();
list($data, $targets) = $this->getData($this->data);
$classifier = new DecisionTree(5);
$classifier->train($data, $targets);
$this->assertTrue(5 >= $classifier->actualDepth);

View File

@ -34,5 +34,15 @@ class NaiveBayesTest extends \PHPUnit_Framework_TestCase
$predicted = $classifier->predict($testSamples);
$this->assertEquals($testLabels, $predicted);
// Feed an extra set of training data.
$samples = [[1, 1, 6]];
$labels = ['d'];
$classifier->train($samples, $labels);
$testSamples = [[1, 1, 6], [5, 1, 1]];
$testLabels = ['d', 'a'];
$this->assertEquals($testLabels, $classifier->predict($testSamples));
}
}