From 8f122fde90dd099ff9566c0cf9ae34d63b3d6b2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Monlla=C3=B3?= Date: Thu, 2 Feb 2017 09:03:09 +0100 Subject: [PATCH] Persistence class to save and restore models (#37) * Models manager with save/restore capabilities * Refactoring dataset exceptions * Persistency layer docs * New tests for serializable estimators * ModelManager static methods to instance methods --- docs/index.md | 2 + .../model-manager/persistency.md | 24 +++++++++ mkdocs.yml | 2 + src/Phpml/Dataset/CsvDataset.php | 8 +-- src/Phpml/Exception/DatasetException.php | 19 ------- src/Phpml/Exception/FileException.php | 39 ++++++++++++++ src/Phpml/Exception/SerializeException.php | 30 +++++++++++ src/Phpml/ModelManager.php | 52 +++++++++++++++++++ tests/Phpml/Association/AprioriTest.php | 19 +++++++ .../Phpml/Classification/DecisionTreeTest.php | 21 ++++++++ .../Classification/KNearestNeighborsTest.php | 24 +++++++++ tests/Phpml/Classification/NaiveBayesTest.php | 24 +++++++++ tests/Phpml/Classification/SVCTest.php | 23 ++++++++ tests/Phpml/Dataset/CsvDatasetTest.php | 2 +- tests/Phpml/ModelManagerTest.php | 47 +++++++++++++++++ tests/Phpml/Regression/LeastSquaresTest.php | 25 +++++++++ tests/Phpml/Regression/SVRTest.php | 24 +++++++++ 17 files changed, 361 insertions(+), 24 deletions(-) create mode 100644 docs/machine-learning/model-manager/persistency.md create mode 100644 src/Phpml/Exception/FileException.php create mode 100644 src/Phpml/Exception/SerializeException.php create mode 100644 src/Phpml/ModelManager.php create mode 100644 tests/Phpml/ModelManagerTest.php diff --git a/docs/index.md b/docs/index.md index 423877f..156acb2 100644 --- a/docs/index.md +++ b/docs/index.md @@ -84,6 +84,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples]( * [Iris](machine-learning/datasets/demo/iris/) * [Wine](machine-learning/datasets/demo/wine/) * [Glass](machine-learning/datasets/demo/glass/) +* Models management + * [Persistency](machine-learning/model-manager/persistency/) * Math * [Distance](math/distance/) * [Matrix](math/matrix/) diff --git a/docs/machine-learning/model-manager/persistency.md b/docs/machine-learning/model-manager/persistency.md new file mode 100644 index 0000000..626ae42 --- /dev/null +++ b/docs/machine-learning/model-manager/persistency.md @@ -0,0 +1,24 @@ +# Persistency + +You can save trained models for future use. Persistency across requests achieved by saving and restoring serialized estimators into files. + +### Example + +``` +use Phpml\Classification\KNearestNeighbors; +use Phpml\ModelManager; + +$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; +$labels = ['a', 'a', 'a', 'b', 'b', 'b']; + +$classifier = new KNearestNeighbors(); +$classifier->train($samples, $labels); + +$filepath = '/path/to/store/the/model'; +$modelManager = new ModelManager(); +$modelManager->saveToFile($classifier, $filepath); + +$restoredClassifier = $modelManager->restoreFromFile($filepath); +$restoredClassifier->predict([3, 2]); +// return 'b' +``` diff --git a/mkdocs.yml b/mkdocs.yml index b404e28..433cc3e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -40,6 +40,8 @@ pages: - Iris: machine-learning/datasets/demo/iris.md - Wine: machine-learning/datasets/demo/wine.md - Glass: machine-learning/datasets/demo/glass.md + - Models management: + - Persistency: machine-learning/model-manager/persistency.md - Math: - Distance: math/distance.md - Matrix: math/matrix.md diff --git a/src/Phpml/Dataset/CsvDataset.php b/src/Phpml/Dataset/CsvDataset.php index ab9a2b7..dd722d4 100644 --- a/src/Phpml/Dataset/CsvDataset.php +++ b/src/Phpml/Dataset/CsvDataset.php @@ -4,7 +4,7 @@ declare(strict_types=1); namespace Phpml\Dataset; -use Phpml\Exception\DatasetException; +use Phpml\Exception\FileException; class CsvDataset extends ArrayDataset { @@ -13,16 +13,16 @@ class CsvDataset extends ArrayDataset * @param int $features * @param bool $headingRow * - * @throws DatasetException + * @throws FileException */ public function __construct(string $filepath, int $features, bool $headingRow = true) { if (!file_exists($filepath)) { - throw DatasetException::missingFile(basename($filepath)); + throw FileException::missingFile(basename($filepath)); } if (false === $handle = fopen($filepath, 'rb')) { - throw DatasetException::cantOpenFile(basename($filepath)); + throw FileException::cantOpenFile(basename($filepath)); } if ($headingRow) { diff --git a/src/Phpml/Exception/DatasetException.php b/src/Phpml/Exception/DatasetException.php index 85f911f..6092053 100644 --- a/src/Phpml/Exception/DatasetException.php +++ b/src/Phpml/Exception/DatasetException.php @@ -6,15 +6,6 @@ namespace Phpml\Exception; class DatasetException extends \Exception { - /** - * @param string $filepath - * - * @return DatasetException - */ - public static function missingFile(string $filepath) - { - return new self(sprintf('Dataset file "%s" missing.', $filepath)); - } /** * @param string $path @@ -25,14 +16,4 @@ class DatasetException extends \Exception { return new self(sprintf('Dataset root folder "%s" missing.', $path)); } - - /** - * @param string $filepath - * - * @return DatasetException - */ - public static function cantOpenFile(string $filepath) - { - return new self(sprintf('Dataset file "%s" can\'t be open.', $filepath)); - } } diff --git a/src/Phpml/Exception/FileException.php b/src/Phpml/Exception/FileException.php new file mode 100644 index 0000000..558ae48 --- /dev/null +++ b/src/Phpml/Exception/FileException.php @@ -0,0 +1,39 @@ +invokeArgs($object, $params); } + + public function testSaveAndRestore() + { + $classifier = new Apriori(0.5, 0.5); + $classifier->train($this->sampleGreek, []); + + $testSamples = [['alpha', 'epsilon'], ['beta', 'theta']]; + $predicted = $classifier->predict($testSamples); + + $filename = 'apriori-test-'.rand(100, 999).'-'.uniqid(); + $filepath = tempnam(sys_get_temp_dir(), $filename); + $modelManager = new ModelManager(); + $modelManager->saveToFile($classifier, $filepath); + + $restoredClassifier = $modelManager->restoreFromFile($filepath); + $this->assertEquals($classifier, $restoredClassifier); + $this->assertEquals($predicted, $restoredClassifier->predict($testSamples)); + } } diff --git a/tests/Phpml/Classification/DecisionTreeTest.php b/tests/Phpml/Classification/DecisionTreeTest.php index 25fb94c..7ec25ba 100644 --- a/tests/Phpml/Classification/DecisionTreeTest.php +++ b/tests/Phpml/Classification/DecisionTreeTest.php @@ -5,6 +5,7 @@ declare(strict_types=1); namespace tests\Classification; use Phpml\Classification\DecisionTree; +use Phpml\ModelManager; class DecisionTreeTest extends \PHPUnit_Framework_TestCase { @@ -55,6 +56,26 @@ class DecisionTreeTest extends \PHPUnit_Framework_TestCase return $classifier; } + public function testSaveAndRestore() + { + list($data, $targets) = $this->getData($this->data); + $classifier = new DecisionTree(5); + $classifier->train($data, $targets); + + $testSamples = [['sunny', 78, 72, 'false'], ['overcast', 60, 60, 'false']]; + $predicted = $classifier->predict($testSamples); + + $filename = 'decision-tree-test-'.rand(100, 999).'-'.uniqid(); + $filepath = tempnam(sys_get_temp_dir(), $filename); + $modelManager = new ModelManager(); + $modelManager->saveToFile($classifier, $filepath); + + $restoredClassifier = $modelManager->restoreFromFile($filepath); + $this->assertEquals($classifier, $restoredClassifier); + $this->assertEquals($predicted, $restoredClassifier->predict($testSamples)); + + } + public function testTreeDepth() { list($data, $targets) = $this->getData($this->data); diff --git a/tests/Phpml/Classification/KNearestNeighborsTest.php b/tests/Phpml/Classification/KNearestNeighborsTest.php index 7c5a75a..9824e5d 100644 --- a/tests/Phpml/Classification/KNearestNeighborsTest.php +++ b/tests/Phpml/Classification/KNearestNeighborsTest.php @@ -6,6 +6,7 @@ namespace tests\Classification; use Phpml\Classification\KNearestNeighbors; use Phpml\Math\Distance\Chebyshev; +use Phpml\ModelManager; class KNearestNeighborsTest extends \PHPUnit_Framework_TestCase { @@ -57,4 +58,27 @@ class KNearestNeighborsTest extends \PHPUnit_Framework_TestCase $this->assertEquals($testLabels, $predicted); } + + public function testSaveAndRestore() + { + $trainSamples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; + $trainLabels = ['a', 'a', 'a', 'b', 'b', 'b']; + + $testSamples = [[3, 2], [5, 1], [4, 3], [4, -5], [2, 3], [1, 2], [1, 5], [3, 10]]; + $testLabels = ['b', 'b', 'b', 'b', 'a', 'a', 'a', 'a']; + + // Using non-default constructor parameters to check that their values are restored. + $classifier = new KNearestNeighbors(3, new Chebyshev()); + $classifier->train($trainSamples, $trainLabels); + $predicted = $classifier->predict($testSamples); + + $filename = 'knearest-neighbors-test-'.rand(100, 999).'-'.uniqid(); + $filepath = tempnam(sys_get_temp_dir(), $filename); + $modelManager = new ModelManager(); + $modelManager->saveToFile($classifier, $filepath); + + $restoredClassifier = $modelManager->restoreFromFile($filepath); + $this->assertEquals($classifier, $restoredClassifier); + $this->assertEquals($predicted, $restoredClassifier->predict($testSamples)); + } } diff --git a/tests/Phpml/Classification/NaiveBayesTest.php b/tests/Phpml/Classification/NaiveBayesTest.php index 1a0aa1f..1f66ec8 100644 --- a/tests/Phpml/Classification/NaiveBayesTest.php +++ b/tests/Phpml/Classification/NaiveBayesTest.php @@ -5,6 +5,7 @@ declare(strict_types=1); namespace tests\Classification; use Phpml\Classification\NaiveBayes; +use Phpml\ModelManager; class NaiveBayesTest extends \PHPUnit_Framework_TestCase { @@ -45,4 +46,27 @@ class NaiveBayesTest extends \PHPUnit_Framework_TestCase $this->assertEquals($testLabels, $classifier->predict($testSamples)); } + + public function testSaveAndRestore() + { + $trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]]; + $trainLabels = ['a', 'b', 'c']; + + $testSamples = [[3, 1, 1], [5, 1, 1], [4, 3, 8]]; + $testLabels = ['a', 'a', 'c']; + + $classifier = new NaiveBayes(); + $classifier->train($trainSamples, $trainLabels); + $predicted = $classifier->predict($testSamples); + + $filename = 'naive-bayes-test-'.rand(100, 999).'-'.uniqid(); + $filepath = tempnam(sys_get_temp_dir(), $filename); + $modelManager = new ModelManager(); + $modelManager->saveToFile($classifier, $filepath); + + $restoredClassifier = $modelManager->restoreFromFile($filepath); + $this->assertEquals($classifier, $restoredClassifier); + $this->assertEquals($predicted, $restoredClassifier->predict($testSamples)); + } + } diff --git a/tests/Phpml/Classification/SVCTest.php b/tests/Phpml/Classification/SVCTest.php index 8b17b53..34111b4 100644 --- a/tests/Phpml/Classification/SVCTest.php +++ b/tests/Phpml/Classification/SVCTest.php @@ -6,6 +6,7 @@ namespace tests\Classification; use Phpml\Classification\SVC; use Phpml\SupportVectorMachine\Kernel; +use Phpml\ModelManager; class SVCTest extends \PHPUnit_Framework_TestCase { @@ -42,4 +43,26 @@ class SVCTest extends \PHPUnit_Framework_TestCase $this->assertEquals($testLabels, $predictions); } + + public function testSaveAndRestore() + { + $trainSamples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]]; + $trainLabels = ['a', 'a', 'a', 'b', 'b', 'b']; + + $testSamples = [[3, 2], [5, 1], [4, 3]]; + $testLabels = ['b', 'b', 'b']; + + $classifier = new SVC(Kernel::LINEAR, $cost = 1000); + $classifier->train($trainSamples, $trainLabels); + $predicted = $classifier->predict($testSamples); + + $filename = 'svc-test-'.rand(100, 999).'-'.uniqid(); + $filepath = tempnam(sys_get_temp_dir(), $filename); + $modelManager = new ModelManager(); + $modelManager->saveToFile($classifier, $filepath); + + $restoredClassifier = $modelManager->restoreFromFile($filepath); + $this->assertEquals($classifier, $restoredClassifier); + $this->assertEquals($predicted, $restoredClassifier->predict($testSamples)); + } } diff --git a/tests/Phpml/Dataset/CsvDatasetTest.php b/tests/Phpml/Dataset/CsvDatasetTest.php index 44e745a..65c589a 100644 --- a/tests/Phpml/Dataset/CsvDatasetTest.php +++ b/tests/Phpml/Dataset/CsvDatasetTest.php @@ -9,7 +9,7 @@ use Phpml\Dataset\CsvDataset; class CsvDatasetTest extends \PHPUnit_Framework_TestCase { /** - * @expectedException \Phpml\Exception\DatasetException + * @expectedException \Phpml\Exception\FileException */ public function testThrowExceptionOnMissingFile() { diff --git a/tests/Phpml/ModelManagerTest.php b/tests/Phpml/ModelManagerTest.php new file mode 100644 index 0000000..75044e9 --- /dev/null +++ b/tests/Phpml/ModelManagerTest.php @@ -0,0 +1,47 @@ +saveToFile($obj, $filepath); + + $restored = $modelManager->restoreFromFile($filepath); + $this->assertEquals($obj, $restored); + } + + /** + * @expectedException \Phpml\Exception\FileException + */ + public function testSaveToWrongFile() + { + $filepath = sys_get_temp_dir() . DIRECTORY_SEPARATOR . 'unexisting'; + + $obj = new LeastSquares(); + $modelManager = new ModelManager(); + $modelManager->saveToFile($obj, $filepath); + } + + /** + * @expectedException \Phpml\Exception\FileException + */ + public function testRestoreWrongFile() + { + $filepath = sys_get_temp_dir() . DIRECTORY_SEPARATOR . 'unexisting'; + $modelManager = new ModelManager(); + $modelManager->restoreFromFile($filepath); + } +} diff --git a/tests/Phpml/Regression/LeastSquaresTest.php b/tests/Phpml/Regression/LeastSquaresTest.php index c668b88..2cd3885 100644 --- a/tests/Phpml/Regression/LeastSquaresTest.php +++ b/tests/Phpml/Regression/LeastSquaresTest.php @@ -5,6 +5,7 @@ declare(strict_types=1); namespace tests\Regression; use Phpml\Regression\LeastSquares; +use Phpml\ModelManager; class LeastSquaresTest extends \PHPUnit_Framework_TestCase { @@ -65,4 +66,28 @@ class LeastSquaresTest extends \PHPUnit_Framework_TestCase $this->assertEquals(4094.82, $regression->predict([60000, 1996]), '', $delta); $this->assertEquals(5711.40, $regression->predict([60000, 2000]), '', $delta); } + + public function testSaveAndRestore() + { + //https://www.easycalculation.com/analytical/learn-least-square-regression.php + $samples = [[60], [61], [62], [63], [65]]; + $targets = [[3.1], [3.6], [3.8], [4], [4.1]]; + + $regression = new LeastSquares(); + $regression->train($samples, $targets); + + //http://www.stat.wmich.edu/s216/book/node127.html + $testSamples = [[9300], [10565], [15000]]; + $predicted = $regression->predict($testSamples); + + $filename = 'least-squares-test-'.rand(100, 999).'-'.uniqid(); + $filepath = tempnam(sys_get_temp_dir(), $filename); + $modelManager = new ModelManager(); + $modelManager->saveToFile($regression, $filepath); + + $restoredRegression = $modelManager->restoreFromFile($filepath); + $this->assertEquals($regression, $restoredRegression); + $this->assertEquals($predicted, $restoredRegression->predict($testSamples)); + } + } diff --git a/tests/Phpml/Regression/SVRTest.php b/tests/Phpml/Regression/SVRTest.php index 5f8bec9..17bb66e 100644 --- a/tests/Phpml/Regression/SVRTest.php +++ b/tests/Phpml/Regression/SVRTest.php @@ -6,6 +6,7 @@ namespace tests\Regression; use Phpml\Regression\SVR; use Phpml\SupportVectorMachine\Kernel; +use Phpml\ModelManager; class SVRTest extends \PHPUnit_Framework_TestCase { @@ -34,4 +35,27 @@ class SVRTest extends \PHPUnit_Framework_TestCase $this->assertEquals([4109.82, 4112.28], $regression->predict([[60000, 1996], [60000, 2000]]), '', $delta); } + + public function testSaveAndRestore() + { + + $samples = [[60], [61], [62], [63], [65]]; + $targets = [3.1, 3.6, 3.8, 4, 4.1]; + + $regression = new SVR(Kernel::LINEAR); + $regression->train($samples, $targets); + + $testSamples = [64]; + $predicted = $regression->predict($testSamples); + + $filename = 'svr-test'.rand(100, 999).'-'.uniqid(); + $filepath = tempnam(sys_get_temp_dir(), $filename); + $modelManager = new ModelManager(); + $modelManager->saveToFile($regression, $filepath); + + $restoredRegression = $modelManager->restoreFromFile($filepath); + $this->assertEquals($regression, $restoredRegression); + $this->assertEquals($predicted, $restoredRegression->predict($testSamples)); + } + }