diff --git a/tests/Classification/MLPClassifierTest.php b/tests/Classification/MLPClassifierTest.php index c4c45c4..d3680b6 100644 --- a/tests/Classification/MLPClassifierTest.php +++ b/tests/Classification/MLPClassifierTest.php @@ -193,6 +193,35 @@ class MLPClassifierTest extends TestCase $this->assertEquals($predicted, $restoredClassifier->predict($testSamples)); } + public function testSaveAndRestoreWithPartialTraining(): void + { + $network = new MLPClassifier(2, [2], ['a', 'b'], 1000); + $network->partialTrain( + [[1, 0], [0, 1]], + ['a', 'b'] + ); + + $this->assertEquals('a', $network->predict([1, 0])); + $this->assertEquals('b', $network->predict([0, 1])); + + $filename = 'perceptron-test-'.random_int(100, 999).'-'.uniqid(); + $filepath = tempnam(sys_get_temp_dir(), $filename); + $modelManager = new ModelManager(); + $modelManager->saveToFile($network, $filepath); + + /** @var MLPClassifier $restoredNetwork */ + $restoredNetwork = $modelManager->restoreFromFile($filepath); + $restoredNetwork->partialTrain( + [[1, 1], [0, 0]], + ['a', 'b'] + ); + + $this->assertEquals('a', $restoredNetwork->predict([1, 0])); + $this->assertEquals('b', $restoredNetwork->predict([0, 1])); + $this->assertEquals('a', $restoredNetwork->predict([1, 1])); + $this->assertEquals('b', $restoredNetwork->predict([0, 0])); + } + public function testThrowExceptionOnInvalidLayersNumber(): void { $this->expectException(InvalidArgumentException::class);