Neural networks partial training and persistency (#91)
* Neural networks partial training and persistency
* cs fixes
* Add partialTrain to nn docs
* Test for invalid partial training classes provided
This commit is contained in:
parent 3dff40ea1d
commit de50490154

Neural network docs (multilayer perceptron):

@@ -29,6 +29,19 @@ $mlp->train(
     $samples = [[1, 0, 0, 0], [0, 1, 1, 0], [1, 1, 1, 1], [0, 0, 0, 0]],
     $targets = ['a', 'a', 'b', 'c']
 );
 ```
+
+Use partialTrain method to train in batches. Example:
+
+```
+$mlp->partialTrain(
+    $samples = [[1, 0, 0, 0], [0, 1, 1, 0]],
+    $targets = ['a', 'a']
+);
+$mlp->partialTrain(
+    $samples = [[1, 1, 1, 1], [0, 0, 0, 0]],
+    $targets = ['b', 'c']
+);
+```
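
For orientation (not part of the diff), here is a minimal sketch of how the documented batch training reads end to end, assuming the MLPClassifier constructor used elsewhere in this commit (number of inputs, hidden layer sizes, classes, optional iteration count):

```
use Phpml\Classification\MLPClassifier;

// 4 input features, one hidden layer of 2 neurons, three classes.
$mlp = new MLPClassifier(4, [2], ['a', 'b', 'c']);

// First batch: weights are updated starting from their initial state.
$mlp->partialTrain(
    $samples = [[1, 0, 0, 0], [0, 1, 1, 0]],
    $targets = ['a', 'a']
);

// Second batch: training continues from the weights learned above.
$mlp->partialTrain(
    $samples = [[1, 1, 1, 1], [0, 0, 0, 0]],
    $targets = ['b', 'c']
);

echo $mlp->predict([1, 1, 1, 1]); // expected to print 'b' once trained
```

The optional third argument of partialTrain (introduced in the MultilayerPerceptron hunk below) can either be omitted or must repeat the classes given to the constructor; anything else is rejected.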

Phpml\Classification\MLPClassifier:

@@ -4,17 +4,8 @@ declare(strict_types=1);
 
 namespace Phpml\Classification;
 
-use Phpml\Classification\Classifier;
 use Phpml\Exception\InvalidArgumentException;
 use Phpml\NeuralNetwork\Network\MultilayerPerceptron;
-use Phpml\NeuralNetwork\Training\Backpropagation;
-use Phpml\NeuralNetwork\ActivationFunction;
-use Phpml\NeuralNetwork\Layer;
-use Phpml\NeuralNetwork\Node\Bias;
-use Phpml\NeuralNetwork\Node\Input;
-use Phpml\NeuralNetwork\Node\Neuron;
-use Phpml\NeuralNetwork\Node\Neuron\Synapse;
-use Phpml\Helper\Predictable;
 
 class MLPClassifier extends MultilayerPerceptron implements Classifier
 {

Phpml\Exception\InvalidArgumentException:

@@ -108,4 +108,8 @@ class InvalidArgumentException extends \Exception
         return new self('Provide at least 2 different classes');
     }
 
+    public static function inconsistentClasses()
+    {
+        return new self('The provided classes don\'t match the classes provided in the constructor');
+    }
 }
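
The new inconsistentClasses() factory backs the guard that partialTrain() gains later in this diff. A sketch of the failure mode it reports, mirroring the new test at the bottom of this commit (not part of the diff):

```
use Phpml\Classification\MLPClassifier;
use Phpml\Exception\InvalidArgumentException;

$classifier = new MLPClassifier(2, [2], [0, 1]);

try {
    // The class list [0, 1, 2] differs from the constructor's [0, 1],
    // so partialTrain() throws the new exception.
    $classifier->partialTrain(
        [[0, 1], [1, 0]],
        [0, 2],
        [0, 1, 2]
    );
} catch (InvalidArgumentException $e) {
    echo $e->getMessage(); // The provided classes don't match the classes provided in the constructor
}
```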

Phpml\NeuralNetwork\Network\LayeredNetwork:

@@ -32,6 +32,14 @@ abstract class LayeredNetwork implements Network
         return $this->layers;
     }
 
+    /**
+     * @return void
+     */
+    public function removeLayers()
+    {
+        unset($this->layers);
+    }
+
     /**
      * @return Layer
      */

Phpml\NeuralNetwork\Network\MultilayerPerceptron:

@@ -5,6 +5,7 @@ declare(strict_types=1);
 namespace Phpml\NeuralNetwork\Network;
 
 use Phpml\Estimator;
+use Phpml\IncrementalEstimator;
 use Phpml\Exception\InvalidArgumentException;
 use Phpml\NeuralNetwork\Training\Backpropagation;
 use Phpml\NeuralNetwork\ActivationFunction;

@@ -15,10 +16,20 @@ use Phpml\NeuralNetwork\Node\Neuron;
 use Phpml\NeuralNetwork\Node\Neuron\Synapse;
 use Phpml\Helper\Predictable;
 
-abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
+abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator, IncrementalEstimator
 {
     use Predictable;
 
+    /**
+     * @var int
+     */
+    private $inputLayerFeatures;
+
+    /**
+     * @var array
+     */
+    private $hiddenLayers;
+
     /**
      * @var array
      */

@@ -29,6 +40,16 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
      */
     private $iterations;
 
+    /**
+     * @var ActivationFunction
+     */
+    protected $activationFunction;
+
+    /**
+     * @var int
+     */
+    private $theta;
+
     /**
      * @var Backpropagation
      */

@@ -50,22 +71,33 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
             throw InvalidArgumentException::invalidLayersNumber();
         }
 
-        $nClasses = count($classes);
-        if ($nClasses < 2) {
+        if (count($classes) < 2) {
             throw InvalidArgumentException::invalidClassesNumber();
         }
 
         $this->classes = array_values($classes);
 
         $this->iterations = $iterations;
+        $this->inputLayerFeatures = $inputLayerFeatures;
+        $this->hiddenLayers = $hiddenLayers;
+        $this->activationFunction = $activationFunction;
+        $this->theta = $theta;
 
-        $this->addInputLayer($inputLayerFeatures);
-        $this->addNeuronLayers($hiddenLayers, $activationFunction);
-        $this->addNeuronLayers([$nClasses], $activationFunction);
+        $this->initNetwork();
+    }
+
+    /**
+     * @return void
+     */
+    private function initNetwork()
+    {
+        $this->addInputLayer($this->inputLayerFeatures);
+        $this->addNeuronLayers($this->hiddenLayers, $this->activationFunction);
+        $this->addNeuronLayers([count($this->classes)], $this->activationFunction);
 
         $this->addBiasNodes();
         $this->generateSynapses();
 
-        $this->backpropagation = new Backpropagation($theta);
+        $this->backpropagation = new Backpropagation($this->theta);
     }
 
     /**
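
The constructor now keeps its arguments on the object so that initNetwork() can rebuild the layers later; train() relies on this in the next hunk. As a reminder of what those arguments are (values taken from the tests, purely illustrative):

```
use Phpml\Classification\MLPClassifier;

// 2 input features, one hidden layer with 2 neurons, the classes 'a' and 'b',
// and 1000 training iterations. All four values are now stored so that
// initNetwork() can recreate the layer structure on demand.
$network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
```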

@@ -74,6 +106,22 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
      */
     public function train(array $samples, array $targets)
     {
+        $this->reset();
+        $this->initNetwork();
+        $this->partialTrain($samples, $targets, $this->classes);
+    }
+
+    /**
+     * @param array $samples
+     * @param array $targets
+     */
+    public function partialTrain(array $samples, array $targets, array $classes = [])
+    {
+        if (!empty($classes) && array_values($classes) !== $this->classes) {
+            // We require the list of classes in the constructor.
+            throw InvalidArgumentException::inconsistentClasses();
+        }
+
         for ($i = 0; $i < $this->iterations; ++$i) {
             $this->trainSamples($samples, $targets);
         }
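
train() now resets and rebuilds the network before delegating to partialTrain(), so repeated train() calls start from scratch while repeated partialTrain() calls keep the learned weights. A sketch of the difference, condensed from the new tests further down (not part of the diff):

```
use Phpml\Classification\MLPClassifier;

// train() resets the network: only the most recent call defines the model.
$network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
$network->train([[1, 0], [0, 1]], ['a', 'b']);
$network->train([[1, 0], [0, 1]], ['b', 'a']);
echo $network->predict([1, 0]); // 'b': the first training run was discarded

// partialTrain() keeps the current weights: batches build on each other.
$incremental = new MLPClassifier(2, [2], ['a', 'b'], 1000);
$incremental->partialTrain([[1, 0], [0, 1]], ['a', 'b']);
$incremental->partialTrain([[1, 1], [0, 0]], ['a', 'b']);
echo $incremental->predict([1, 1]); // 'a': learned from the second batch
echo $incremental->predict([1, 0]); // 'a': still remembered from the first batch
```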

@@ -83,13 +131,21 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
      * @param array $sample
      * @param mixed $target
      */
-    protected abstract function trainSample(array $sample, $target);
+    abstract protected function trainSample(array $sample, $target);
 
     /**
      * @param array $sample
      * @return mixed
      */
-    protected abstract function predictSample(array $sample);
+    abstract protected function predictSample(array $sample);
+
+    /**
+     * @return void
+     */
+    protected function reset()
+    {
+        $this->removeLayers();
+    }
 
     /**
      * @param int $nodes

Phpml\NeuralNetwork\Training\Backpropagation:

@@ -17,12 +17,12 @@ class Backpropagation
     /**
      * @var array
      */
-    private $sigmas;
+    private $sigmas = null;
 
     /**
      * @var array
      */
-    private $prevSigmas;
+    private $prevSigmas = null;
 
     /**
      * @param int $theta

@@ -38,14 +38,12 @@ class Backpropagation
      */
     public function backpropagate(array $layers, $targetClass)
     {
-
         $layersNumber = count($layers);
 
         // Backpropagation.
         for ($i = $layersNumber; $i > 1; --$i) {
             $this->sigmas = [];
             foreach ($layers[$i - 1]->getNodes() as $key => $neuron) {
-
                 if ($neuron instanceof Neuron) {
                     $sigma = $this->getSigma($neuron, $targetClass, $key, $i == $layersNumber);
                     foreach ($neuron->getSynapses() as $synapse) {

@@ -55,6 +53,10 @@ class Backpropagation
             }
             $this->prevSigmas = $this->sigmas;
         }
+
+        // Clean some memory (also it helps make MLP persistency & children more maintainable).
+        $this->sigmas = null;
+        $this->prevSigmas = null;
     }
 
     /**
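
The comment in the hunk above ties this cleanup to persistency. The corresponding round trip uses ModelManager, as exercised by the new testSaveAndRestore test below; a condensed sketch (not part of the diff):

```
use Phpml\Classification\MLPClassifier;
use Phpml\ModelManager;

// Train a small network on the OR problem.
$classifier = new MLPClassifier(2, [2], [0, 1]);
$classifier->train([[0, 0], [1, 0], [0, 1], [1, 1]], [0, 1, 1, 1]);

// Persist the trained model to disk and restore it later.
$filepath = tempnam(sys_get_temp_dir(), 'mlp-model');
$modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath);

$restoredClassifier = $modelManager->restoreFromFile($filepath);
echo $restoredClassifier->predict([1, 1]); // matches $classifier->predict([1, 1])
```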

tests\Phpml\Classification\MLPClassifierTest:

@@ -5,8 +5,8 @@ declare(strict_types=1);
 namespace tests\Phpml\Classification;
 
 use Phpml\Classification\MLPClassifier;
-use Phpml\NeuralNetwork\Training\Backpropagation;
 use Phpml\NeuralNetwork\Node\Neuron;
+use Phpml\ModelManager;
 use PHPUnit\Framework\TestCase;
 
 class MLPClassifierTest extends TestCase

@@ -53,7 +53,7 @@ class MLPClassifierTest extends TestCase
     public function testBackpropagationLearning()
     {
         // Single layer 2 classes.
-        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
+        $network = new MLPClassifier(2, [2], ['a', 'b']);
         $network->train(
             [[1, 0], [0, 1], [1, 1], [0, 0]],
             ['a', 'b', 'a', 'b']

@@ -65,6 +65,50 @@ class MLPClassifierTest extends TestCase
         $this->assertEquals('b', $network->predict([0, 0]));
     }
 
+    public function testBackpropagationTrainingReset()
+    {
+        // Single layer 2 classes.
+        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
+        $network->train(
+            [[1, 0], [0, 1]],
+            ['a', 'b']
+        );
+
+        $this->assertEquals('a', $network->predict([1, 0]));
+        $this->assertEquals('b', $network->predict([0, 1]));
+
+        $network->train(
+            [[1, 0], [0, 1]],
+            ['b', 'a']
+        );
+
+        $this->assertEquals('b', $network->predict([1, 0]));
+        $this->assertEquals('a', $network->predict([0, 1]));
+    }
+
+    public function testBackpropagationPartialTraining()
+    {
+        // Single layer 2 classes.
+        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
+        $network->partialTrain(
+            [[1, 0], [0, 1]],
+            ['a', 'b']
+        );
+
+        $this->assertEquals('a', $network->predict([1, 0]));
+        $this->assertEquals('b', $network->predict([0, 1]));
+
+        $network->partialTrain(
+            [[1, 1], [0, 0]],
+            ['a', 'b']
+        );
+
+        $this->assertEquals('a', $network->predict([1, 0]));
+        $this->assertEquals('b', $network->predict([0, 1]));
+        $this->assertEquals('a', $network->predict([1, 1]));
+        $this->assertEquals('b', $network->predict([0, 0]));
+    }
+
     public function testBackpropagationLearningMultilayer()
     {
         // Multi-layer 2 classes.

@@ -96,6 +140,26 @@ class MLPClassifierTest extends TestCase
         $this->assertEquals(4, $network->predict([0, 0, 0, 0, 0]));
     }
 
+    public function testSaveAndRestore()
+    {
+        // Instantiate a new Perceptron trained for the OR problem
+        $samples = [[0, 0], [1, 0], [0, 1], [1, 1]];
+        $targets = [0, 1, 1, 1];
+        $classifier = new MLPClassifier(2, [2], [0, 1]);
+        $classifier->train($samples, $targets);
+        $testSamples = [[0, 0], [1, 0], [0, 1], [1, 1]];
+        $predicted = $classifier->predict($testSamples);
+
+        $filename = 'perceptron-test-'.rand(100, 999).'-'.uniqid();
+        $filepath = tempnam(sys_get_temp_dir(), $filename);
+        $modelManager = new ModelManager();
+        $modelManager->saveToFile($classifier, $filepath);
+
+        $restoredClassifier = $modelManager->restoreFromFile($filepath);
+        $this->assertEquals($classifier, $restoredClassifier);
+        $this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
+    }
+
     /**
      * @expectedException \Phpml\Exception\InvalidArgumentException
      */

@@ -104,6 +168,18 @@ class MLPClassifierTest extends TestCase
         new MLPClassifier(2, [], [0, 1]);
     }
 
+    /**
+     * @expectedException \Phpml\Exception\InvalidArgumentException
+     */
+    public function testThrowExceptionOnInvalidPartialTrainingClasses()
+    {
+        $classifier = new MLPClassifier(2, [2], [0, 1]);
+        $classifier->partialTrain(
+            [[0, 1], [1, 0]],
+            [0, 2],
+            [0, 1, 2]
+        );
+    }
     /**
      * @expectedException \Phpml\Exception\InvalidArgumentException
      */