diff --git a/docs/machine-learning/neural-network/multilayer-perceptron-classifier.md b/docs/machine-learning/neural-network/multilayer-perceptron-classifier.md
index 6f11f68..d2f746d 100644
--- a/docs/machine-learning/neural-network/multilayer-perceptron-classifier.md
+++ b/docs/machine-learning/neural-network/multilayer-perceptron-classifier.md
@@ -29,6 +29,19 @@
 $mlp->train(
     $samples = [[1, 0, 0, 0], [0, 1, 1, 0], [1, 1, 1, 1], [0, 0, 0, 0]],
     $targets = ['a', 'a', 'b', 'c']
 );
+```
+
+Use the `partialTrain` method to train the network in batches. Example:
+
+```
+$mlp->partialTrain(
+    $samples = [[1, 0, 0, 0], [0, 1, 1, 0]],
+    $targets = ['a', 'a']
+);
+$mlp->partialTrain(
+    $samples = [[1, 1, 1, 1], [0, 0, 0, 0]],
+    $targets = ['b', 'c']
+);
 ```
diff --git a/src/Phpml/Classification/MLPClassifier.php b/src/Phpml/Classification/MLPClassifier.php
index c5d00bf..bde49a2 100644
--- a/src/Phpml/Classification/MLPClassifier.php
+++ b/src/Phpml/Classification/MLPClassifier.php
@@ -4,17 +4,8 @@ declare(strict_types=1);
 
 namespace Phpml\Classification;
 
-use Phpml\Classification\Classifier;
 use Phpml\Exception\InvalidArgumentException;
 use Phpml\NeuralNetwork\Network\MultilayerPerceptron;
-use Phpml\NeuralNetwork\Training\Backpropagation;
-use Phpml\NeuralNetwork\ActivationFunction;
-use Phpml\NeuralNetwork\Layer;
-use Phpml\NeuralNetwork\Node\Bias;
-use Phpml\NeuralNetwork\Node\Input;
-use Phpml\NeuralNetwork\Node\Neuron;
-use Phpml\NeuralNetwork\Node\Neuron\Synapse;
-use Phpml\Helper\Predictable;
 
 class MLPClassifier extends MultilayerPerceptron implements Classifier
 {
diff --git a/src/Phpml/Exception/InvalidArgumentException.php b/src/Phpml/Exception/InvalidArgumentException.php
index 3e2bff5..277aecd 100644
--- a/src/Phpml/Exception/InvalidArgumentException.php
+++ b/src/Phpml/Exception/InvalidArgumentException.php
@@ -108,4 +108,8 @@ class InvalidArgumentException extends \Exception
         return new self('Provide at least 2 different classes');
     }
 
+    public static function inconsistentClasses()
+    {
+        return new self('The provided classes don\'t match the classes provided in the constructor');
+    }
 }
diff --git a/src/Phpml/NeuralNetwork/Network/LayeredNetwork.php b/src/Phpml/NeuralNetwork/Network/LayeredNetwork.php
index cd90e3f..b20f6bb 100644
--- a/src/Phpml/NeuralNetwork/Network/LayeredNetwork.php
+++ b/src/Phpml/NeuralNetwork/Network/LayeredNetwork.php
@@ -32,6 +32,14 @@ abstract class LayeredNetwork implements Network
         return $this->layers;
     }
 
+    /**
+     * @return void
+     */
+    public function removeLayers()
+    {
+        unset($this->layers);
+    }
+
     /**
      * @return Layer
      */
diff --git a/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php b/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php
index 5d7f94e..2503774 100644
--- a/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php
+++ b/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php
@@ -5,6 +5,7 @@ declare(strict_types=1);
 namespace Phpml\NeuralNetwork\Network;
 
 use Phpml\Estimator;
+use Phpml\IncrementalEstimator;
 use Phpml\Exception\InvalidArgumentException;
 use Phpml\NeuralNetwork\Training\Backpropagation;
 use Phpml\NeuralNetwork\ActivationFunction;
@@ -15,10 +16,20 @@ use Phpml\NeuralNetwork\Node\Neuron;
 use Phpml\NeuralNetwork\Node\Neuron\Synapse;
 use Phpml\Helper\Predictable;
 
-abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
+abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator, IncrementalEstimator
 {
     use Predictable;
 
+    /**
+     * @var int
+     */
+    private $inputLayerFeatures;
+
+    /**
+     * @var array
+     */
+    private $hiddenLayers;
+
     /**
      * @var array
      */
@@ -29,6 +40,16 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
      */
     private $iterations;
 
+    /**
+     * @var ActivationFunction
+     */
+    protected $activationFunction;
+
+    /**
+     * @var int
+     */
+    private $theta;
+
     /**
      * @var Backpropagation
      */
@@ -50,22 +71,33 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
             throw InvalidArgumentException::invalidLayersNumber();
         }
 
-        $nClasses = count($classes);
-        if ($nClasses < 2) {
+        if (count($classes) < 2) {
             throw InvalidArgumentException::invalidClassesNumber();
         }
+
         $this->classes = array_values($classes);
-
         $this->iterations = $iterations;
+        $this->inputLayerFeatures = $inputLayerFeatures;
+        $this->hiddenLayers = $hiddenLayers;
+        $this->activationFunction = $activationFunction;
+        $this->theta = $theta;
 
-        $this->addInputLayer($inputLayerFeatures);
-        $this->addNeuronLayers($hiddenLayers, $activationFunction);
-        $this->addNeuronLayers([$nClasses], $activationFunction);
+        $this->initNetwork();
+    }
+
+    /**
+     * @return void
+     */
+    private function initNetwork()
+    {
+        $this->addInputLayer($this->inputLayerFeatures);
+        $this->addNeuronLayers($this->hiddenLayers, $this->activationFunction);
+        $this->addNeuronLayers([count($this->classes)], $this->activationFunction);
 
         $this->addBiasNodes();
         $this->generateSynapses();
 
-        $this->backpropagation = new Backpropagation($theta);
+        $this->backpropagation = new Backpropagation($this->theta);
     }
 
     /**
@@ -74,6 +106,22 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
      */
    public function train(array $samples, array $targets)
    {
+        $this->reset();
+        $this->initNetwork();
+        $this->partialTrain($samples, $targets, $this->classes);
+    }
+
+    /**
+     * @param array $samples
+     * @param array $targets
+     */
+    public function partialTrain(array $samples, array $targets, array $classes = [])
+    {
+        if (!empty($classes) && array_values($classes) !== $this->classes) {
+            // We require the list of classes in the constructor.
+            throw InvalidArgumentException::inconsistentClasses();
+        }
+
         for ($i = 0; $i < $this->iterations; ++$i) {
             $this->trainSamples($samples, $targets);
         }
@@ -83,13 +131,21 @@ abstract class MultilayerPerceptron extends LayeredNetwork implements Estimator
      * @param array $sample
      * @param mixed $target
      */
-    protected abstract function trainSample(array $sample, $target);
+    abstract protected function trainSample(array $sample, $target);
 
     /**
      * @param array $sample
      * @return mixed
      */
-    protected abstract function predictSample(array $sample);
+    abstract protected function predictSample(array $sample);
+
+    /**
+     * @return void
+     */
+    protected function reset()
+    {
+        $this->removeLayers();
+    }
 
     /**
      * @param int $nodes
diff --git a/src/Phpml/NeuralNetwork/Training/Backpropagation.php b/src/Phpml/NeuralNetwork/Training/Backpropagation.php
index 741db2b..ba90b45 100644
--- a/src/Phpml/NeuralNetwork/Training/Backpropagation.php
+++ b/src/Phpml/NeuralNetwork/Training/Backpropagation.php
@@ -17,12 +17,12 @@
     /**
      * @var array
      */
-    private $sigmas;
+    private $sigmas = null;
 
     /**
      * @var array
      */
-    private $prevSigmas;
+    private $prevSigmas = null;
 
     /**
      * @param int $theta
      */
@@ -38,14 +38,12 @@
      */
     public function backpropagate(array $layers, $targetClass)
     {
-
         $layersNumber = count($layers);
 
         // Backpropagation.
         for ($i = $layersNumber; $i > 1; --$i) {
             $this->sigmas = [];
             foreach ($layers[$i - 1]->getNodes() as $key => $neuron) {
-
                 if ($neuron instanceof Neuron) {
                     $sigma = $this->getSigma($neuron, $targetClass, $key, $i == $layersNumber);
                     foreach ($neuron->getSynapses() as $synapse) {
@@ -55,6 +53,10 @@
             }
             $this->prevSigmas = $this->sigmas;
         }
+
+        // Clean up some memory (it also makes MLP persistence and its subclasses easier to maintain).
+        $this->sigmas = null;
+        $this->prevSigmas = null;
     }
 
     /**
diff --git a/tests/Phpml/Classification/MLPClassifierTest.php b/tests/Phpml/Classification/MLPClassifierTest.php
index 9f8b3fc..3a009c3 100644
--- a/tests/Phpml/Classification/MLPClassifierTest.php
+++ b/tests/Phpml/Classification/MLPClassifierTest.php
@@ -5,8 +5,8 @@ declare(strict_types=1);
 namespace tests\Phpml\Classification;
 
 use Phpml\Classification\MLPClassifier;
-use Phpml\NeuralNetwork\Training\Backpropagation;
 use Phpml\NeuralNetwork\Node\Neuron;
+use Phpml\ModelManager;
 use PHPUnit\Framework\TestCase;
 
 class MLPClassifierTest extends TestCase
@@ -53,7 +53,7 @@ class MLPClassifierTest extends TestCase
     public function testBackpropagationLearning()
     {
         // Single layer 2 classes.
-        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
+        $network = new MLPClassifier(2, [2], ['a', 'b']);
         $network->train(
             [[1, 0], [0, 1], [1, 1], [0, 0]],
             ['a', 'b', 'a', 'b']
         );
@@ -65,6 +65,50 @@
         $this->assertEquals('b', $network->predict([0, 0]));
     }
 
+    public function testBackpropagationTrainingReset()
+    {
+        // Single layer 2 classes.
+        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
+        $network->train(
+            [[1, 0], [0, 1]],
+            ['a', 'b']
+        );
+
+        $this->assertEquals('a', $network->predict([1, 0]));
+        $this->assertEquals('b', $network->predict([0, 1]));
+
+        $network->train(
+            [[1, 0], [0, 1]],
+            ['b', 'a']
+        );
+
+        $this->assertEquals('b', $network->predict([1, 0]));
+        $this->assertEquals('a', $network->predict([0, 1]));
+    }
+
+    public function testBackpropagationPartialTraining()
+    {
+        // Single layer 2 classes.
+        $network = new MLPClassifier(2, [2], ['a', 'b'], 1000);
+        $network->partialTrain(
+            [[1, 0], [0, 1]],
+            ['a', 'b']
+        );
+
+        $this->assertEquals('a', $network->predict([1, 0]));
+        $this->assertEquals('b', $network->predict([0, 1]));
+
+        $network->partialTrain(
+            [[1, 1], [0, 0]],
+            ['a', 'b']
+        );
+
+        $this->assertEquals('a', $network->predict([1, 0]));
+        $this->assertEquals('b', $network->predict([0, 1]));
+        $this->assertEquals('a', $network->predict([1, 1]));
+        $this->assertEquals('b', $network->predict([0, 0]));
+    }
+
     public function testBackpropagationLearningMultilayer()
     {
         // Multi-layer 2 classes.
@@ -96,6 +140,26 @@ class MLPClassifierTest extends TestCase
         $this->assertEquals(4, $network->predict([0, 0, 0, 0, 0]));
     }
 
+    public function testSaveAndRestore()
+    {
+        // Instantiate a new Perceptron trained on the OR problem
+        $samples = [[0, 0], [1, 0], [0, 1], [1, 1]];
+        $targets = [0, 1, 1, 1];
+        $classifier = new MLPClassifier(2, [2], [0, 1]);
+        $classifier->train($samples, $targets);
+        $testSamples = [[0, 0], [1, 0], [0, 1], [1, 1]];
+        $predicted = $classifier->predict($testSamples);
+
+        $filename = 'perceptron-test-'.rand(100, 999).'-'.uniqid();
+        $filepath = tempnam(sys_get_temp_dir(), $filename);
+        $modelManager = new ModelManager();
+        $modelManager->saveToFile($classifier, $filepath);
+
+        $restoredClassifier = $modelManager->restoreFromFile($filepath);
+        $this->assertEquals($classifier, $restoredClassifier);
+        $this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
+    }
+
     /**
      * @expectedException \Phpml\Exception\InvalidArgumentException
      */
@@ -104,6 +168,18 @@
         new MLPClassifier(2, [], [0, 1]);
     }
 
+    /**
+     * @expectedException \Phpml\Exception\InvalidArgumentException
+     */
+    public function testThrowExceptionOnInvalidPartialTrainingClasses()
+    {
+        $classifier = new MLPClassifier(2, [2], [0, 1]);
+        $classifier->partialTrain(
+            [[0, 1], [1, 0]],
+            [0, 2],
+            [0, 1, 2]
+        );
+    }
     /**
      * @expectedException \Phpml\Exception\InvalidArgumentException
      */
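For reviewers who want to try the branch, the snippet below is a minimal usage sketch (not part of the patch) that combines the new `partialTrain()` with the `ModelManager` persistence exercised in `testSaveAndRestore()`, so training batches can be fed to the same model across separate runs. The temp-file path, batch data, and variable names are illustrative assumptions; the calls themselves (`MLPClassifier::__construct()`, `partialTrain()`, `ModelManager::saveToFile()` / `restoreFromFile()`, `predict()`) are the ones added or used in this diff.

```php
<?php

use Phpml\Classification\MLPClassifier;
use Phpml\ModelManager;

// Illustrative path only; any writable location works.
$modelPath = sys_get_temp_dir().'/mlp-partial.model';
$manager   = new ModelManager();

// First run: the full class list must be declared in the constructor,
// because partialTrain() rejects classes that were not declared there.
$mlp = new MLPClassifier(2, [2], ['a', 'b'], 1000);
$mlp->partialTrain(
    [[1, 0], [0, 1]],
    ['a', 'b']
);
$manager->saveToFile($mlp, $modelPath);

// Later run (possibly another process): restore and keep training.
// Passing the classes explicitly is optional; a mismatch with the
// constructor triggers InvalidArgumentException::inconsistentClasses().
$mlp = $manager->restoreFromFile($modelPath);
$mlp->partialTrain(
    [[1, 1], [0, 0]],
    ['a', 'b'],
    ['a', 'b']
);

echo $mlp->predict([1, 0]); // expected 'a' (cf. testBackpropagationPartialTraining())
```

This also shows why `partialTrain()` treats its `$classes` argument purely as a consistency check: the output layer is sized from `count($this->classes)` in `initNetwork()`, so new classes cannot be introduced after construction, only further samples for the classes declared up front.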