From 7435bece34e29429adae5a09de10354511602d05 Mon Sep 17 00:00:00 2001
From: Arkadiusz Kondas
Date: Fri, 12 Jan 2018 10:54:20 +0100
Subject: [PATCH] Add test for Pipeline save and restore with ModelManager (#191)

---
 tests/Phpml/PipelineTest.php | 48 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/tests/Phpml/PipelineTest.php b/tests/Phpml/PipelineTest.php
index e45675d4..86ff2a9f 100644
--- a/tests/Phpml/PipelineTest.php
+++ b/tests/Phpml/PipelineTest.php
@@ -7,6 +7,7 @@
 use Phpml\Classification\SVC;
 use Phpml\FeatureExtraction\TfIdfTransformer;
 use Phpml\FeatureExtraction\TokenCountVectorizer;
+use Phpml\ModelManager;
 use Phpml\Pipeline;
 use Phpml\Preprocessing\Imputer;
 use Phpml\Preprocessing\Imputer\Strategy\MostFrequentStrategy;
@@ -104,4 +105,51 @@ public function testPipelineTransformers(): void
 
         $this->assertEquals($expected, $predicted);
     }
+
+    /**
+     * A trained Pipeline saved with ModelManager must be restored as an equal
+     * object that produces the same predictions as the original.
+     */
+    public function testSaveAndRestore(): void
+    {
+        $pipeline = new Pipeline([
+            new TokenCountVectorizer(new WordTokenizer()),
+            new TfIdfTransformer(),
+        ], new SVC());
+
+        $pipeline->train([
+            'Hello Paul',
+            'Hello Martin',
+            'Goodbye Tom',
+            'Hello John',
+            'Goodbye Alex',
+            'Bye Tony',
+        ], [
+            'greetings',
+            'greetings',
+            'farewell',
+            'greetings',
+            'farewell',
+            'farewell',
+        ]);
+
+        $testSamples = ['Hello Max', 'Goodbye Mark'];
+        $predicted = $pipeline->predict($testSamples);
+
+        // tempnam() returns false on failure; stop before using a non-path below.
+        $filepath = tempnam(sys_get_temp_dir(), uniqid('pipeline-test', true));
+        $this->assertNotFalse($filepath, 'Unable to create a temporary file for the model.');
+
+        try {
+            $modelManager = new ModelManager();
+            $modelManager->saveToFile($pipeline, $filepath);
+
+            $restoredClassifier = $modelManager->restoreFromFile($filepath);
+            $this->assertEquals($pipeline, $restoredClassifier);
+            $this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
+        } finally {
+            // Always remove the temp file, even when an assertion above fails.
+            unlink($filepath);
+        }
+    }
 }