Skip to content

Commit

Permalink
Persistence class to save and restore models (#37)
Browse files Browse the repository at this point in the history
* Models manager with save/restore capabilities

* Refactoring dataset exceptions

* Persistency layer docs

* New tests for serializable estimators

* ModelManager static methods to instance methods
  • Loading branch information
dmonllao authored and akondas committed Feb 2, 2017
1 parent c1b1a5d commit 8f122fd
Show file tree
Hide file tree
Showing 17 changed files with 361 additions and 24 deletions.
2 changes: 2 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ Example scripts are available in a separate repository [php-ai/php-ml-examples](
* [Iris](machine-learning/datasets/demo/iris/)
* [Wine](machine-learning/datasets/demo/wine/)
* [Glass](machine-learning/datasets/demo/glass/)
* Models management
* [Persistency](machine-learning/model-manager/persistency/)
* Math
* [Distance](math/distance/)
* [Matrix](math/matrix/)
Expand Down
24 changes: 24 additions & 0 deletions docs/machine-learning/model-manager/persistency.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Persistency

You can save trained models for future use. Persistency across requests achieved by saving and restoring serialized estimators into files.

### Example

```
use Phpml\Classification\KNearestNeighbors;
use Phpml\ModelManager;
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
$classifier = new KNearestNeighbors();
$classifier->train($samples, $labels);
$filepath = '/path/to/store/the/model';
$modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath);
$restoredClassifier = $modelManager->restoreFromFile($filepath);
$restoredClassifier->predict([3, 2]);
// return 'b'
```
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ pages:
- Iris: machine-learning/datasets/demo/iris.md
- Wine: machine-learning/datasets/demo/wine.md
- Glass: machine-learning/datasets/demo/glass.md
- Models management:
- Persistency: machine-learning/model-manager/persistency.md
- Math:
- Distance: math/distance.md
- Matrix: math/matrix.md
Expand Down
8 changes: 4 additions & 4 deletions src/Phpml/Dataset/CsvDataset.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Phpml\Dataset;

use Phpml\Exception\DatasetException;
use Phpml\Exception\FileException;

class CsvDataset extends ArrayDataset
{
Expand All @@ -13,16 +13,16 @@ class CsvDataset extends ArrayDataset
* @param int $features
* @param bool $headingRow
*
* @throws DatasetException
* @throws FileException
*/
public function __construct(string $filepath, int $features, bool $headingRow = true)
{
if (!file_exists($filepath)) {
throw DatasetException::missingFile(basename($filepath));
throw FileException::missingFile(basename($filepath));
}

if (false === $handle = fopen($filepath, 'rb')) {
throw DatasetException::cantOpenFile(basename($filepath));
throw FileException::cantOpenFile(basename($filepath));
}

if ($headingRow) {
Expand Down
19 changes: 0 additions & 19 deletions src/Phpml/Exception/DatasetException.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,6 @@

class DatasetException extends \Exception
{
/**
* @param string $filepath
*
* @return DatasetException
*/
public static function missingFile(string $filepath)
{
return new self(sprintf('Dataset file "%s" missing.', $filepath));
}

/**
* @param string $path
Expand All @@ -25,14 +16,4 @@ public static function missingFolder(string $path)
{
return new self(sprintf('Dataset root folder "%s" missing.', $path));
}

/**
* @param string $filepath
*
* @return DatasetException
*/
public static function cantOpenFile(string $filepath)
{
return new self(sprintf('Dataset file "%s" can\'t be open.', $filepath));
}
}
39 changes: 39 additions & 0 deletions src/Phpml/Exception/FileException.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
<?php

declare(strict_types=1);

namespace Phpml\Exception;

class FileException extends \Exception
{

/**
* @param string $filepath
*
* @return FileException
*/
public static function missingFile(string $filepath)
{
return new self(sprintf('File "%s" missing.', $filepath));
}

/**
* @param string $filepath
*
* @return FileException
*/
public static function cantOpenFile(string $filepath)
{
return new self(sprintf('File "%s" can\'t be open.', $filepath));
}

/**
* @param string $filepath
*
* @return FileException
*/
public static function cantSaveFile(string $filepath)
{
return new self(sprintf('File "%s" can\'t be saved.', $filepath));
}
}
30 changes: 30 additions & 0 deletions src/Phpml/Exception/SerializeException.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<?php

declare(strict_types=1);

namespace Phpml\Exception;

class SerializeException extends \Exception
{

/**
* @param string $filepath
*
* @return SerializeException
*/
public static function cantUnserialize(string $filepath)
{
return new self(sprintf('"%s" can not be unserialized.', $filepath));
}

/**
* @param string $classname
*
* @return SerializeException
*/
public static function cantSerialize(string $classname)
{
return new self(sprintf('Class "%s" can not be serialized.', $classname));
}

}
52 changes: 52 additions & 0 deletions src/Phpml/ModelManager.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
<?php

declare(strict_types=1);

namespace Phpml;

use Phpml\Estimator;
use Phpml\Exception\SerializeException;
use Phpml\Exception\FileException;

class ModelManager
{
/**
* @param Estimator $object
* @param string $filepath
*/
public function saveToFile(Estimator $object, string $filepath)
{
if (!file_exists($filepath) || !is_writable(dirname($filepath))) {
throw FileException::cantSaveFile(basename($filepath));
}

$serialized = serialize($object);
if (empty($serialized)) {
throw SerializeException::cantSerialize(get_type($object));
}

$result = file_put_contents($filepath, $serialized, LOCK_EX);
if ($result === false) {
throw FileException::cantSaveFile(basename($filepath));
}
}

/**
* @param string $filepath
*
* @return Estimator
*/
public function restoreFromFile(string $filepath)
{
if (!file_exists($filepath) || !is_readable($filepath)) {
throw FileException::cantOpenFile(basename($filepath));
}

$object = unserialize(file_get_contents($filepath));
if ($object === false) {
throw SerializeException::cantUnserialize(basename($filepath));
}

return $object;
}
}
19 changes: 19 additions & 0 deletions tests/Phpml/Association/AprioriTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace tests\Classification;

use Phpml\Association\Apriori;
use Phpml\ModelManager;

class AprioriTest extends \PHPUnit_Framework_TestCase
{
Expand Down Expand Up @@ -184,4 +185,22 @@ public function invoke(&$object, $method, array $params = [])

return $method->invokeArgs($object, $params);
}

public function testSaveAndRestore()
{
$classifier = new Apriori(0.5, 0.5);
$classifier->train($this->sampleGreek, []);

$testSamples = [['alpha', 'epsilon'], ['beta', 'theta']];
$predicted = $classifier->predict($testSamples);

$filename = 'apriori-test-'.rand(100, 999).'-'.uniqid();
$filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath);

$restoredClassifier = $modelManager->restoreFromFile($filepath);
$this->assertEquals($classifier, $restoredClassifier);
$this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
}
}
21 changes: 21 additions & 0 deletions tests/Phpml/Classification/DecisionTreeTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace tests\Classification;

use Phpml\Classification\DecisionTree;
use Phpml\ModelManager;

class DecisionTreeTest extends \PHPUnit_Framework_TestCase
{
Expand Down Expand Up @@ -55,6 +56,26 @@ public function testPredictSingleSample()
return $classifier;
}

public function testSaveAndRestore()
{
list($data, $targets) = $this->getData($this->data);
$classifier = new DecisionTree(5);
$classifier->train($data, $targets);

$testSamples = [['sunny', 78, 72, 'false'], ['overcast', 60, 60, 'false']];
$predicted = $classifier->predict($testSamples);

$filename = 'decision-tree-test-'.rand(100, 999).'-'.uniqid();
$filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath);

$restoredClassifier = $modelManager->restoreFromFile($filepath);
$this->assertEquals($classifier, $restoredClassifier);
$this->assertEquals($predicted, $restoredClassifier->predict($testSamples));

}

public function testTreeDepth()
{
list($data, $targets) = $this->getData($this->data);
Expand Down
24 changes: 24 additions & 0 deletions tests/Phpml/Classification/KNearestNeighborsTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

use Phpml\Classification\KNearestNeighbors;
use Phpml\Math\Distance\Chebyshev;
use Phpml\ModelManager;

class KNearestNeighborsTest extends \PHPUnit_Framework_TestCase
{
Expand Down Expand Up @@ -57,4 +58,27 @@ public function testPredictArrayOfSamplesUsingChebyshevDistanceMetric()

$this->assertEquals($testLabels, $predicted);
}

public function testSaveAndRestore()
{
$trainSamples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$trainLabels = ['a', 'a', 'a', 'b', 'b', 'b'];

$testSamples = [[3, 2], [5, 1], [4, 3], [4, -5], [2, 3], [1, 2], [1, 5], [3, 10]];
$testLabels = ['b', 'b', 'b', 'b', 'a', 'a', 'a', 'a'];

// Using non-default constructor parameters to check that their values are restored.
$classifier = new KNearestNeighbors(3, new Chebyshev());
$classifier->train($trainSamples, $trainLabels);
$predicted = $classifier->predict($testSamples);

$filename = 'knearest-neighbors-test-'.rand(100, 999).'-'.uniqid();
$filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath);

$restoredClassifier = $modelManager->restoreFromFile($filepath);
$this->assertEquals($classifier, $restoredClassifier);
$this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
}
}
24 changes: 24 additions & 0 deletions tests/Phpml/Classification/NaiveBayesTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace tests\Classification;

use Phpml\Classification\NaiveBayes;
use Phpml\ModelManager;

class NaiveBayesTest extends \PHPUnit_Framework_TestCase
{
Expand Down Expand Up @@ -45,4 +46,27 @@ public function testPredictArrayOfSamples()
$this->assertEquals($testLabels, $classifier->predict($testSamples));

}

public function testSaveAndRestore()
{
$trainSamples = [[5, 1, 1], [1, 5, 1], [1, 1, 5]];
$trainLabels = ['a', 'b', 'c'];

$testSamples = [[3, 1, 1], [5, 1, 1], [4, 3, 8]];
$testLabels = ['a', 'a', 'c'];

$classifier = new NaiveBayes();
$classifier->train($trainSamples, $trainLabels);
$predicted = $classifier->predict($testSamples);

$filename = 'naive-bayes-test-'.rand(100, 999).'-'.uniqid();
$filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath);

$restoredClassifier = $modelManager->restoreFromFile($filepath);
$this->assertEquals($classifier, $restoredClassifier);
$this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
}

}
23 changes: 23 additions & 0 deletions tests/Phpml/Classification/SVCTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

use Phpml\Classification\SVC;
use Phpml\SupportVectorMachine\Kernel;
use Phpml\ModelManager;

class SVCTest extends \PHPUnit_Framework_TestCase
{
Expand Down Expand Up @@ -42,4 +43,26 @@ public function testPredictArrayOfSamplesWithLinearKernel()

$this->assertEquals($testLabels, $predictions);
}

public function testSaveAndRestore()
{
$trainSamples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$trainLabels = ['a', 'a', 'a', 'b', 'b', 'b'];

$testSamples = [[3, 2], [5, 1], [4, 3]];
$testLabels = ['b', 'b', 'b'];

$classifier = new SVC(Kernel::LINEAR, $cost = 1000);
$classifier->train($trainSamples, $trainLabels);
$predicted = $classifier->predict($testSamples);

$filename = 'svc-test-'.rand(100, 999).'-'.uniqid();
$filepath = tempnam(sys_get_temp_dir(), $filename);
$modelManager = new ModelManager();
$modelManager->saveToFile($classifier, $filepath);

$restoredClassifier = $modelManager->restoreFromFile($filepath);
$this->assertEquals($classifier, $restoredClassifier);
$this->assertEquals($predicted, $restoredClassifier->predict($testSamples));
}
}
Loading

0 comments on commit 8f122fd

Please sign in to comment.