mirror of
https://github.com/Llewellynvdm/php-ml.git
synced 2024-11-25 06:17:34 +00:00
Support for multiple training datasets (#38)
* Multiple training data sets allowed * Tests with multiple training data sets * Updating docs according to #38 Documenting all models whose predictions will be based on all training data provided. Some models already supported multiple training data sets.
This commit is contained in:
parent
6281da280f
commit
c1b1a5d6ac
@ -27,6 +27,8 @@ $associator = new Apriori($support = 0.5, $confidence = 0.5);
|
|||||||
$associator->train($samples, $labels);
|
$associator->train($samples, $labels);
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can train the associator using multiple data sets; predictions will be based on all the training data.
|
||||||
|
|
||||||
### Predict
|
### Predict
|
||||||
|
|
||||||
To predict sample label use `predict` method. You can provide one sample or array of samples:
|
To predict sample label use `predict` method. You can provide one sample or array of samples:
|
||||||
|
@ -24,6 +24,8 @@ $classifier = new KNearestNeighbors();
|
|||||||
$classifier->train($samples, $labels);
|
$classifier->train($samples, $labels);
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can train the classifier using multiple data sets; predictions will be based on all the training data.
|
||||||
|
|
||||||
## Predict
|
## Predict
|
||||||
|
|
||||||
To predict sample label use `predict` method. You can provide one sample or array of samples:
|
To predict sample label use `predict` method. You can provide one sample or array of samples:
|
||||||
|
@ -14,6 +14,8 @@ $classifier = new NaiveBayes();
|
|||||||
$classifier->train($samples, $labels);
|
$classifier->train($samples, $labels);
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can train the classifier using multiple data sets; predictions will be based on all the training data.
|
||||||
|
|
||||||
### Predict
|
### Predict
|
||||||
|
|
||||||
To predict sample label use `predict` method. You can provide one sample or array of samples:
|
To predict sample label use `predict` method. You can provide one sample or array of samples:
|
||||||
|
@ -34,6 +34,8 @@ $classifier = new SVC(Kernel::LINEAR, $cost = 1000);
|
|||||||
$classifier->train($samples, $labels);
|
$classifier->train($samples, $labels);
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can train the classifier using multiple data sets; predictions will be based on all the training data.
|
||||||
|
|
||||||
### Predict
|
### Predict
|
||||||
|
|
||||||
To predict sample label use `predict` method. You can provide one sample or array of samples:
|
To predict sample label use `predict` method. You can provide one sample or array of samples:
|
||||||
|
@ -27,3 +27,4 @@ $training->train(
|
|||||||
$maxIteraions = 30000
|
$maxIteraions = 30000
|
||||||
);
|
);
|
||||||
```
|
```
|
||||||
|
You can train the neural network using multiple data sets; predictions will be based on all the training data.
|
||||||
|
@ -14,6 +14,8 @@ $regression = new LeastSquares();
|
|||||||
$regression->train($samples, $targets);
|
$regression->train($samples, $targets);
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can train the model using multiple data sets; predictions will be based on all the training data.
|
||||||
|
|
||||||
### Predict
|
### Predict
|
||||||
|
|
||||||
To predict sample target value use `predict` method with sample to check (as `array`). Example:
|
To predict sample target value use `predict` method with sample to check (as `array`). Example:
|
||||||
|
@ -34,6 +34,8 @@ $regression = new SVR(Kernel::LINEAR);
|
|||||||
$regression->train($samples, $targets);
|
$regression->train($samples, $targets);
|
||||||
```
|
```
|
||||||
|
|
||||||
|
You can train the model using multiple data sets; predictions will be based on all the training data.
|
||||||
|
|
||||||
### Predict
|
### Predict
|
||||||
|
|
||||||
To predict sample target value use `predict` method. You can provide one sample or array of samples:
|
To predict sample target value use `predict` method. You can provide one sample or array of samples:
|
||||||
|
@ -64,12 +64,13 @@ class DecisionTree implements Classifier
|
|||||||
*/
|
*/
|
||||||
public function train(array $samples, array $targets)
|
public function train(array $samples, array $targets)
|
||||||
{
|
{
|
||||||
$this->featureCount = count($samples[0]);
|
$this->samples = array_merge($this->samples, $samples);
|
||||||
$this->columnTypes = $this->getColumnTypes($samples);
|
$this->targets = array_merge($this->targets, $targets);
|
||||||
$this->samples = $samples;
|
|
||||||
$this->targets = $targets;
|
$this->featureCount = count($this->samples[0]);
|
||||||
$this->labels = array_keys(array_count_values($targets));
|
$this->columnTypes = $this->getColumnTypes($this->samples);
|
||||||
$this->tree = $this->getSplitLeaf(range(0, count($samples) - 1));
|
$this->labels = array_keys(array_count_values($this->targets));
|
||||||
|
$this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
protected function getColumnTypes(array $samples)
|
protected function getColumnTypes(array $samples)
|
||||||
|
@ -63,12 +63,12 @@ class NaiveBayes implements Classifier
|
|||||||
*/
|
*/
|
||||||
public function train(array $samples, array $targets)
|
public function train(array $samples, array $targets)
|
||||||
{
|
{
|
||||||
$this->samples = $samples;
|
$this->samples = array_merge($this->samples, $samples);
|
||||||
$this->targets = $targets;
|
$this->targets = array_merge($this->targets, $targets);
|
||||||
$this->sampleCount = count($samples);
|
$this->sampleCount = count($this->samples);
|
||||||
$this->featureCount = count($samples[0]);
|
$this->featureCount = count($this->samples[0]);
|
||||||
|
|
||||||
$labelCounts = array_count_values($targets);
|
$labelCounts = array_count_values($this->targets);
|
||||||
$this->labels = array_keys($labelCounts);
|
$this->labels = array_keys($labelCounts);
|
||||||
foreach ($this->labels as $label) {
|
foreach ($this->labels as $label) {
|
||||||
$samples = $this->getSamplesByLabel($label);
|
$samples = $this->getSamplesByLabel($label);
|
||||||
|
@ -9,12 +9,12 @@ trait Trainable
|
|||||||
/**
|
/**
|
||||||
* @var array
|
* @var array
|
||||||
*/
|
*/
|
||||||
private $samples;
|
private $samples = [];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @var array
|
* @var array
|
||||||
*/
|
*/
|
||||||
private $targets;
|
private $targets = [];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param array $samples
|
* @param array $samples
|
||||||
@ -22,7 +22,7 @@ trait Trainable
|
|||||||
*/
|
*/
|
||||||
public function train(array $samples, array $targets)
|
public function train(array $samples, array $targets)
|
||||||
{
|
{
|
||||||
$this->samples = $samples;
|
$this->samples = array_merge($this->samples, $samples);
|
||||||
$this->targets = $targets;
|
$this->targets = array_merge($this->targets, $targets);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -13,12 +13,12 @@ class LeastSquares implements Regression
|
|||||||
/**
|
/**
|
||||||
* @var array
|
* @var array
|
||||||
*/
|
*/
|
||||||
private $samples;
|
private $samples = [];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @var array
|
* @var array
|
||||||
*/
|
*/
|
||||||
private $targets;
|
private $targets = [];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @var float
|
* @var float
|
||||||
@ -36,8 +36,8 @@ class LeastSquares implements Regression
|
|||||||
*/
|
*/
|
||||||
public function train(array $samples, array $targets)
|
public function train(array $samples, array $targets)
|
||||||
{
|
{
|
||||||
$this->samples = $samples;
|
$this->samples = array_merge($this->samples, $samples);
|
||||||
$this->targets = $targets;
|
$this->targets = array_merge($this->targets, $targets);
|
||||||
|
|
||||||
$this->computeCoefficients();
|
$this->computeCoefficients();
|
||||||
}
|
}
|
||||||
|
@ -8,7 +8,7 @@ use Phpml\Classification\DecisionTree;
|
|||||||
|
|
||||||
class DecisionTreeTest extends \PHPUnit_Framework_TestCase
|
class DecisionTreeTest extends \PHPUnit_Framework_TestCase
|
||||||
{
|
{
|
||||||
public $data = [
|
private $data = [
|
||||||
['sunny', 85, 85, 'false', 'Dont_play' ],
|
['sunny', 85, 85, 'false', 'Dont_play' ],
|
||||||
['sunny', 80, 90, 'true', 'Dont_play' ],
|
['sunny', 80, 90, 'true', 'Dont_play' ],
|
||||||
['overcast', 83, 78, 'false', 'Play' ],
|
['overcast', 83, 78, 'false', 'Play' ],
|
||||||
@ -25,34 +25,39 @@ class DecisionTreeTest extends \PHPUnit_Framework_TestCase
|
|||||||
['rain', 71, 80, 'true', 'Dont_play' ]
|
['rain', 71, 80, 'true', 'Dont_play' ]
|
||||||
];
|
];
|
||||||
|
|
||||||
public function getData()
|
private $extraData = [
|
||||||
|
['scorching', 90, 95, 'false', 'Dont_play'],
|
||||||
|
['scorching', 100, 93, 'true', 'Dont_play'],
|
||||||
|
];
|
||||||
|
|
||||||
|
private function getData($input)
|
||||||
{
|
{
|
||||||
static $data = null, $targets = null;
|
$targets = array_column($input, 4);
|
||||||
if ($data == null) {
|
array_walk($input, function (&$v) {
|
||||||
$data = $this->data;
|
|
||||||
$targets = array_column($data, 4);
|
|
||||||
array_walk($data, function (&$v) {
|
|
||||||
array_splice($v, 4, 1);
|
array_splice($v, 4, 1);
|
||||||
});
|
});
|
||||||
}
|
return [$input, $targets];
|
||||||
return [$data, $targets];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public function testPredictSingleSample()
|
public function testPredictSingleSample()
|
||||||
{
|
{
|
||||||
list($data, $targets) = $this->getData();
|
list($data, $targets) = $this->getData($this->data);
|
||||||
$classifier = new DecisionTree(5);
|
$classifier = new DecisionTree(5);
|
||||||
$classifier->train($data, $targets);
|
$classifier->train($data, $targets);
|
||||||
$this->assertEquals('Dont_play', $classifier->predict(['sunny', 78, 72, 'false']));
|
$this->assertEquals('Dont_play', $classifier->predict(['sunny', 78, 72, 'false']));
|
||||||
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
|
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
|
||||||
$this->assertEquals('Dont_play', $classifier->predict(['rain', 60, 60, 'true']));
|
$this->assertEquals('Dont_play', $classifier->predict(['rain', 60, 60, 'true']));
|
||||||
|
|
||||||
|
list($data, $targets) = $this->getData($this->extraData);
|
||||||
|
$classifier->train($data, $targets);
|
||||||
|
$this->assertEquals('Dont_play', $classifier->predict(['scorching', 95, 90, 'true']));
|
||||||
|
$this->assertEquals('Play', $classifier->predict(['overcast', 60, 60, 'false']));
|
||||||
return $classifier;
|
return $classifier;
|
||||||
}
|
}
|
||||||
|
|
||||||
public function testTreeDepth()
|
public function testTreeDepth()
|
||||||
{
|
{
|
||||||
list($data, $targets) = $this->getData();
|
list($data, $targets) = $this->getData($this->data);
|
||||||
$classifier = new DecisionTree(5);
|
$classifier = new DecisionTree(5);
|
||||||
$classifier->train($data, $targets);
|
$classifier->train($data, $targets);
|
||||||
$this->assertTrue(5 >= $classifier->actualDepth);
|
$this->assertTrue(5 >= $classifier->actualDepth);
|
||||||
|
@ -34,5 +34,15 @@ class NaiveBayesTest extends \PHPUnit_Framework_TestCase
|
|||||||
$predicted = $classifier->predict($testSamples);
|
$predicted = $classifier->predict($testSamples);
|
||||||
|
|
||||||
$this->assertEquals($testLabels, $predicted);
|
$this->assertEquals($testLabels, $predicted);
|
||||||
|
|
||||||
|
// Feed an extra set of training data.
|
||||||
|
$samples = [[1, 1, 6]];
|
||||||
|
$labels = ['d'];
|
||||||
|
$classifier->train($samples, $labels);
|
||||||
|
|
||||||
|
$testSamples = [[1, 1, 6], [5, 1, 1]];
|
||||||
|
$testLabels = ['d', 'a'];
|
||||||
|
$this->assertEquals($testLabels, $classifier->predict($testSamples));
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user