Add test for Pipeline save and restore with ModelManager (#191)

This commit is contained in:
Arkadiusz Kondas 2018-01-12 10:54:20 +01:00 committed by GitHub
parent d953ef6bfc
commit 7435bece34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,6 +7,7 @@ namespace Phpml\Tests;
use Phpml\Classification\SVC;
use Phpml\FeatureExtraction\TfIdfTransformer;
use Phpml\FeatureExtraction\TokenCountVectorizer;
use Phpml\ModelManager;
use Phpml\Pipeline;
use Phpml\Preprocessing\Imputer;
use Phpml\Preprocessing\Imputer\Strategy\MostFrequentStrategy;
@ -104,4 +105,40 @@ class PipelineTest extends TestCase
$this->assertEquals($expected, $predicted);
}
public function testSaveAndRestore(): void
{
$pipeline = new Pipeline([
new TokenCountVectorizer(new WordTokenizer()),
new TfIdfTransformer(),
], new SVC());
$pipeline->train([
'Hello Paul',
'Hello Martin',
'Goodbye Tom',
'Hello John',
'Goodbye Alex',
'Bye Tony',
], [
'greetings',
'greetings',
'farewell',
'greetings',
'farewell',
'farewell',
]);
$testSamples = ['Hello Max', 'Goodbye Mark'];
$predicted = $pipeline->predict($testSamples);
$filepath = tempnam(sys_get_temp_dir(), uniqid('pipeline-test', true));
$modelManager = new ModelManager();
$modelManager->saveToFile($pipeline, $filepath);
$restoredClassifier = $modelManager->restoreFromFile($filepath);
$this->assertEquals($pipeline, $restoredClassifier);
$this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
unlink($filepath);
}
}