Neural networks partial training and persistency (#91)

* Neural networks partial training and persistency

* cs fixes

* Add partialTrain to nn docs

* Test for invalid partial training classes provided
This commit is contained in:
David Monllaó 2017-05-23 15:03:05 +08:00 committed by Arkadiusz Kondas
parent 3dff40ea1d
commit de50490154
7 changed files with 175 additions and 25 deletions

View File

@ -29,6 +29,19 @@ $mlp->train(
$samples = [[1, 0, 0, 0], [0, 1, 1, 0], [1, 1, 1, 1], [0, 0, 0, 0]], $samples = [[1, 0, 0, 0], [0, 1, 1, 0], [1, 1, 1, 1], [0, 0, 0, 0]],
$targets = ['a', 'a', 'b', 'c'] $targets = ['a', 'a', 'b', 'c']
); );
```
Use partialTrain method to train in batches. Example:
```
$mlp->partialTrain(
$samples = [[1, 0, 0, 0], [0, 1, 1, 0]],
$targets = ['a', 'a']
);
$mlp->partialTrain(
$samples = [[1, 1, 1, 1], [0, 0, 0, 0]],
$targets = ['b', 'c']
);
``` ```

View File

@ -4,17 +4,8 @@ declare(strict_types=1);
namespace Phpml\Classification; namespace Phpml\Classification;
use Phpml\Classification\Classifier;
use Phpml\Exception\InvalidArgumentException; use Phpml\Exception\InvalidArgumentException;
use Phpml\NeuralNetwork\Network\MultilayerPerceptron; use Phpml\NeuralNetwork\Network\MultilayerPerceptron;
use Phpml\NeuralNetwork\Training\Backpropagation;
use Phpml\NeuralNetwork\ActivationFunction;
use Phpml\NeuralNetwork\Layer;
use Phpml\NeuralNetwork\Node\Bias;
use Phpml\NeuralNetwork\Node\Input;
use Phpml\NeuralNetwork\Node\Neuron;
use Phpml\NeuralNetwork\Node\Neuron\Synapse;
use Phpml\Helper\Predictable;
class MLPClassifier extends MultilayerPerceptron implements Classifier class MLPClassifier extends MultilayerPerceptron implements Classifier
{ {

View File

@ -108,4 +108,8 @@ class InvalidArgumentException extends \Exception
return new self('Provide at least 2 different classes'); return new self('Provide at least 2 different classes');
} }
public static function inconsistentClasses()
{
return new self('The provided classes don\'t match the classes provided in the constructor');
}
} }

View File

@ -32,6 +32,14 @@ abstract class LayeredNetwork implements Network
return $this->layers; return $this->layers;
} }
/**
* @return void
*/
public function removeLayers()
{
unset($this->layers);
}
/** /**
* @return Layer * @return Layer
*/ */

View File

@ -5,6 +5,7 @@ declare(strict_types=1);
namespace Phpml\NeuralNetwork\Network; namespace Phpml\NeuralNetwork\Network;
use Phpml\Estimator; use Phpml\Estimator;
use Phpml\IncrementalEstimator;
use Phpml\Exception\InvalidArgumentException; use Phpml\Exception\InvalidArgumentException;
use Phpml\NeuralNetwork\Training\Backpropagation; use Phpml\NeuralNetwork\Training\Backpropagation;
use Phpml\NeuralNetwork\ActivationFunction; use Phpml\NeuralNetwork\ActivationFunction;
@ -15,10 +16,20 @@ use Phpml\NeuralNetwork\Node\Neuron;
use Phpml\NeuralNetwork\Node\Neuron\Synapse; use Phpml\NeuralNetwork\Node\Neuron\Synapse;
use Phpml\Helper\Predictable; use Phpml\Helper\Predictable;
abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator, IncrementalEstimator
{ {
use Predictable; use Predictable;
/**
* @var int
*/
private $inputLayerFeatures;
/**
* @var array
*/
private $hiddenLayers;
/** /**
* @var array * @var array
*/ */
@ -29,6 +40,16 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
*/ */
private $iterations; private $iterations;
/**
* @var ActivationFunction
*/
protected $activationFunction;
/**
* @var int
*/
private $theta;
/** /**
* @var Backpropagation * @var Backpropagation
*/ */
@ -50,22 +71,33 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
throw InvalidArgumentException::invalidLayersNumber(); throw InvalidArgumentException::invalidLayersNumber();
} }
$nClasses = count($classes); if (count($classes) < 2) {
if ($nClasses < 2) {
throw InvalidArgumentException::invalidClassesNumber(); throw InvalidArgumentException::invalidClassesNumber();
} }
$this->classes = array_values($classes); $this->classes = array_values($classes);
$this->iterations = $iterations; $this->iterations = $iterations;
$this->inputLayerFeatures = $inputLayerFeatures;
$this->hiddenLayers = $hiddenLayers;
$this->activationFunction = $activationFunction;
$this->theta = $theta;
$this->addInputLayer($inputLayerFeatures); $this->initNetwork();
$this->addNeuronLayers($hiddenLayers, $activationFunction); }
$this->addNeuronLayers([$nClasses], $activationFunction);
/**
* @return void
*/
private function initNetwork()
{
$this->addInputLayer($this->inputLayerFeatures);
$this->addNeuronLayers($this->hiddenLayers, $this->activationFunction);
$this->addNeuronLayers([count($this->classes)], $this->activationFunction);
$this->addBiasNodes(); $this->addBiasNodes();
$this->generateSynapses(); $this->generateSynapses();
$this->backpropagation = new Backpropagation($theta); $this->backpropagation = new Backpropagation($this->theta);
} }
/** /**
@ -74,6 +106,22 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
*/ */
public function train(array $samples, array $targets) public function train(array $samples, array $targets)
{ {
$this->reset();
$this->initNetwork();
$this->partialTrain($samples, $targets, $this->classes);
}
/**
* @param array $samples
* @param array $targets
*/
public function partialTrain(array $samples, array $targets, array $classes = [])
{
if (!empty($classes) && array_values($classes) !== $this->classes) {
// We require the list of classes in the constructor.
throw InvalidArgumentException::inconsistentClasses();
}
for ($i = 0; $i < $this->iterations; ++$i) { for ($i = 0; $i < $this->iterations; ++$i) {
$this->trainSamples($samples, $targets); $this->trainSamples($samples, $targets);
} }
@ -83,13 +131,21 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
* @param array $sample * @param array $sample
* @param mixed $target * @param mixed $target
*/ */
protected abstract function trainSample(array $sample, $target); abstract protected function trainSample(array $sample, $target);
/** /**
* @param array $sample * @param array $sample
* @return mixed * @return mixed
*/ */
protected abstract function predictSample(array $sample); abstract protected function predictSample(array $sample);
/**
* @return void
*/
protected function reset()
{
$this->removeLayers();
}
/** /**
* @param int $nodes * @param int $nodes

View File

@ -17,12 +17,12 @@ class Backpropagation
/** /**
* @var array * @var array
*/ */
private $sigmas; private $sigmas = null;
/** /**
* @var array * @var array
*/ */
private $prevSigmas; private $prevSigmas = null;
/** /**
* @param int $theta * @param int $theta
@ -38,14 +38,12 @@ class Backpropagation
*/ */
public function backpropagate(array $layers, $targetClass) public function backpropagate(array $layers, $targetClass)
{ {
$layersNumber = count($layers); $layersNumber = count($layers);
// Backpropagation. // Backpropagation.
for ($i = $layersNumber; $i > 1; --$i) { for ($i = $layersNumber; $i > 1; --$i) {
$this->sigmas = []; $this->sigmas = [];
foreach ($layers[$i - 1]->getNodes() as $key => $neuron) { foreach ($layers[$i - 1]->getNodes() as $key => $neuron) {
if ($neuron instanceof Neuron) { if ($neuron instanceof Neuron) {
$sigma = $this->getSigma($neuron, $targetClass, $key, $i == $layersNumber); $sigma = $this->getSigma($neuron, $targetClass, $key, $i == $layersNumber);
foreach ($neuron->getSynapses() as $synapse) { foreach ($neuron->getSynapses() as $synapse) {
@ -55,6 +53,10 @@ class Backpropagation
} }
$this->prevSigmas = $this->sigmas; $this->prevSigmas = $this->sigmas;
} }
// Clean some memory (also it helps make MLP persistency & children more maintainable).
$this->sigmas = null;
$this->prevSigmas = null;
} }
/** /**

View File

@ -5,8 +5,8 @@ declare(strict_types=1);
namespace tests\Phpml\Classification; namespace tests\Phpml\Classification;
use Phpml\Classification\MLPClassifier; use Phpml\Classification\MLPClassifier;
use Phpml\NeuralNetwork\Training\Backpropagation;
use Phpml\NeuralNetwork\Node\Neuron; use Phpml\NeuralNetwork\Node\Neuron;
use Phpml\ModelManager;
use PHPUnit\Framework\TestCase; use PHPUnit\Framework\TestCase;
class MLPClassifierTest extends TestCase class MLPClassifierTest extends TestCase
@ -53,7 +53,7 @@ class MLPClassifierTest extends TestCase
public function testBackpropagationLearning() public function testBackpropagationLearning()
{ {
// Single layer 2 classes. // Single layer 2 classes.
$network = new MLPClassifier(2, [2], ['a', 'b'], 1000); $network = new MLPClassifier(2, [2], ['a', 'b']);
$network->train( $network->train(
[[1, 0], [0, 1], [1, 1], [0, 0]], [[1, 0], [0, 1], [1, 1], [0, 0]],
['a', 'b', 'a', 'b'] ['a', 'b', 'a', 'b']
@ -65,6 +65,50 @@ class MLPClassifierTest extends TestCase
$this->assertEquals('b', $network->predict([0, 0])); $this->assertEquals('b', $network->predict([0, 0]));
} }
public function testBackpropagationTrainingReset()
{
// Single layer 2 classes.
$network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
$network->train(
[[1, 0], [0, 1]],
['a', 'b']
);
$this->assertEquals('a', $network->predict([1, 0]));
$this->assertEquals('b', $network->predict([0, 1]));
$network->train(
[[1, 0], [0, 1]],
['b', 'a']
);
$this->assertEquals('b', $network->predict([1, 0]));
$this->assertEquals('a', $network->predict([0, 1]));
}
public function testBackpropagationPartialTraining()
{
// Single layer 2 classes.
$network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
$network->partialTrain(
[[1, 0], [0, 1]],
['a', 'b']
);
$this->assertEquals('a', $network->predict([1, 0]));
$this->assertEquals('b', $network->predict([0, 1]));
$network->partialTrain(
[[1, 1], [0, 0]],
['a', 'b']
);
$this->assertEquals('a', $network->predict([1, 0]));
$this->assertEquals('b', $network->predict([0, 1]));
$this->assertEquals('a', $network->predict([1, 1]));
$this->assertEquals('b', $network->predict([0, 0]));
}
public function testBackpropagationLearningMultilayer() public function testBackpropagationLearningMultilayer()
{ {
// Multi-layer 2 classes. // Multi-layer 2 classes.
@ -96,6 +140,26 @@ class MLPClassifierTest extends TestCase
$this->assertEquals(4, $network->predict([0, 0, 0, 0, 0])); $this->assertEquals(4, $network->predict([0, 0, 0, 0, 0]));
} }
public function testSaveAndRestore()
{
// Instantinate new Percetron trained for OR problem
$samples = [[0, 0], [1, 0], [0, 1], [1, 1]];
$targets = [0, 1, 1, 1];
$classifier = new MLPClassifier(2, [2], [0, 1]);
$classifier->train($samples, $targets);
$testSamples = [[0, 0], [1, 0], [0, 1], [1, 1]];
$predicted = $classifier->predict($testSamples);
$filename = 'perceptron-test-'.rand(100, 999).'-'.uniqid();
$filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath);
$restoredClassifier = $modelManager->restoreFromFile($filepath);
$this->assertEquals($classifier, $restoredClassifier);
$this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
}
/** /**
* @expectedException \Phpml\Exception\InvalidArgumentException * @expectedException \Phpml\Exception\InvalidArgumentException
*/ */
@ -104,6 +168,18 @@ class MLPClassifierTest extends TestCase
new MLPClassifier(2, [], [0, 1]); new MLPClassifier(2, [], [0, 1]);
} }
/**
* @expectedException \Phpml\Exception\InvalidArgumentException
*/
public function testThrowExceptionOnInvalidPartialTrainingClasses()
{
$classifier = new MLPClassifier(2, [2], [0, 1]);
$classifier->partialTrain(
[[0, 1], [1, 0]],
[0, 2],
[0, 1, 2]
);
}
/** /**
* @expectedException \Phpml\Exception\InvalidArgumentException * @expectedException \Phpml\Exception\InvalidArgumentException
*/ */