Skip to content

Commit

Permalink
Support probability estimation in SVC (#218)
Browse files Browse the repository at this point in the history
* Add test for svm model with probability estimation

* Extract buildPredictCommand method

* Fix test to use PHP_EOL

* Add predictProbability method (not completed)

* Add test for DataTransformer::predictions

* Fix SVM to use PHP_EOL

* Support probability estimation in SVM

* Add documentation

* Add InvalidOperationException class

* Throw InvalidOperationException before executing libsvm if probability estimation is not supported
  • Loading branch information
y-uti authored and akondas committed Feb 6, 2018
1 parent ed775fb commit ec091b5
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 11 deletions.
39 changes: 39 additions & 0 deletions docs/machine-learning/classification/svc.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,42 @@ $classifier->predict([3, 2]);
$classifier->predict([[3, 2], [1, 5]]);
// return ['b', 'a']
```

### Probability estimation

To predict probabilities, you must build a classifier with `$probabilityEstimates` set to `true`. Example:

```
use Phpml\Classification\SVC;
use Phpml\SupportVectorMachine\Kernel;
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
$labels = ['a', 'a', 'a', 'b', 'b', 'b'];
$classifier = new SVC(
Kernel::LINEAR, // $kernel
1.0, // $cost
3, // $degree
null, // $gamma
0.0, // $coef0
0.001, // $tolerance
100, // $cacheSize
true, // $shrinking
true // $probabilityEstimates, set to true
);
$classifier->train($samples, $labels);
```

Then use the `predictProbability` method instead of `predict`:

```
$classifier->predictProbability([3, 2]);
// return ['a' => 0.349833, 'b' => 0.650167]
$classifier->predictProbability([[3, 2], [1, 5]]);
// return [
// ['a' => 0.349833, 'b' => 0.650167],
// ['a' => 0.922664, 'b' => 0.0773364],
// ]
```
11 changes: 11 additions & 0 deletions src/Phpml/Exception/InvalidOperationException.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
<?php

declare(strict_types=1);

namespace Phpml\Exception;

use Exception;

/**
 * Exception signalling that a requested operation is not valid for the object
 * it was invoked on.
 *
 * The class adds nothing to the base Exception; it exists so callers can catch
 * this failure category by type.
 */
class InvalidOperationException extends Exception
{
}
31 changes: 31 additions & 0 deletions src/Phpml/SupportVectorMachine/DataTransformer.php
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,37 @@ public static function predictions(string $rawPredictions, array $labels): array
return $results;
}

/**
 * Converts raw probability output into one "label => probability" map per sample.
 *
 * The first line of $rawPredictions is treated as a header: its first token is
 * dropped and the remaining tokens are numeric labels naming the columns. Every
 * following line has its first token dropped as well, and each remaining value
 * is cast to float and keyed by the original label of its column.
 *
 * @param string $rawPredictions lines separated by PHP_EOL, header first
 * @param array  $labels         original labels, translated through numericLabels()
 *
 * @return array[] list of maps, in the order of the prediction lines
 */
public static function probabilities(string $rawPredictions, array $labels): array
{
    $labelIndex = self::numericLabels($labels);

    $lines = explode(PHP_EOL, trim($rawPredictions));

    // Header row: skip the leading token, keep the numeric label of each column.
    $columnIds = array_slice(explode(' ', array_shift($lines)), 1);

    // Translate every numeric column id back to the original label.
    $columnLabels = array_map(
        static function ($columnId) use ($labelIndex) {
            return array_search($columnId, $labelIndex);
        },
        $columnIds
    );

    return array_map(
        static function ($line) use ($columnLabels): array {
            // Skip the leading token of the row; the rest are per-column probabilities.
            $row = [];
            foreach (array_slice(explode(' ', $line), 1) as $column => $value) {
                $row[$columnLabels[$column]] = (float) $value;
            }

            return $row;
        },
        $lines
    );
}

public static function numericLabels(array $labels): array
{
$numericLabels = [];
Expand Down
78 changes: 67 additions & 11 deletions src/Phpml/SupportVectorMachine/SupportVectorMachine.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Phpml\SupportVectorMachine;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Exception\InvalidOperationException;
use Phpml\Exception\LibsvmCommandException;
use Phpml\Helper\Trainable;

Expand Down Expand Up @@ -178,13 +179,61 @@ public function getModel(): string
* @throws LibsvmCommandException
*/
public function predict(array $samples)
{
    $rawOutput = $this->runSvmPredict($samples, false);

    // For the two SVC types the raw output is translated back to the training
    // targets; for every other type the trimmed output lines are used as they are.
    $predictions = in_array($this->type, [Type::C_SVC, Type::NU_SVC])
        ? DataTransformer::predictions($rawOutput, $this->targets)
        : explode(PHP_EOL, trim($rawOutput));

    // A flat (non-nested) sample yields a single prediction instead of a list.
    return is_array($samples[0]) ? $predictions : $predictions[0];
}

/**
 * Predicts probabilities for a single sample or for a list of samples.
 *
 * When the first element of $samples is not an array, the input is treated as
 * one sample and only the first result is returned.
 *
 * @return array|string
 *
 * @throws InvalidOperationException when probability estimates are not enabled
 * @throws LibsvmCommandException
 */
public function predictProbability(array $samples)
{
    // Reject the call up front, before any libsvm process is started.
    if (!$this->probabilityEstimates) {
        throw new InvalidOperationException('Model does not support probabiliy estimates');
    }

    $rawOutput = $this->runSvmPredict($samples, true);

    // The SVC types get their output parsed into label => probability maps;
    // any other type falls back to the trimmed output lines.
    $probabilities = in_array($this->type, [Type::C_SVC, Type::NU_SVC])
        ? DataTransformer::probabilities($rawOutput, $this->targets)
        : explode(PHP_EOL, trim($rawOutput));

    return is_array($samples[0]) ? $probabilities : $probabilities[0];
}

private function runSvmPredict(array $samples, bool $probabilityEstimates): string
{
$testSet = DataTransformer::testSet($samples);
file_put_contents($testSetFileName = $this->varPath.uniqid('phpml', true), $testSet);
file_put_contents($modelFileName = $testSetFileName.'-model', $this->model);
$outputFileName = $testSetFileName.'-output';

$command = sprintf('%ssvm-predict%s %s %s %s', $this->binPath, $this->getOSExtension(), $testSetFileName, $modelFileName, $outputFileName);
$command = $this->buildPredictCommand(
$testSetFileName,
$modelFileName,
$outputFileName,
$probabilityEstimates
);
$output = [];
exec(escapeshellcmd($command).' 2>&1', $output, $return);

Expand All @@ -198,16 +247,6 @@ public function predict(array $samples)
throw LibsvmCommandException::failedToRun($command, array_pop($output));
}

if (in_array($this->type, [Type::C_SVC, Type::NU_SVC])) {
$predictions = DataTransformer::predictions($predictions, $this->targets);
} else {
$predictions = explode(PHP_EOL, trim($predictions));
}

if (!is_array($samples[0])) {
return $predictions[0];
}

return $predictions;
}

Expand Down Expand Up @@ -246,6 +285,23 @@ private function buildTrainCommand(string $trainingSetFileName, string $modelFil
);
}

/**
 * Assembles the svm-predict command line.
 *
 * The executable is the configured bin path followed by "svm-predict" and the
 * OS-specific extension. The "-b" option is 1 when probability estimates are
 * requested and 0 otherwise. The three file names are shell-escaped one by one
 * and appended in the order test set, model, output.
 */
private function buildPredictCommand(
    string $testSetFileName,
    string $modelFileName,
    string $outputFileName,
    bool $probabilityEstimates
): string {
    $executable = $this->binPath.'svm-predict'.$this->getOSExtension();
    $probabilityFlag = $probabilityEstimates ? 1 : 0;

    $escapedFiles = array_map(
        'escapeshellarg',
        [$testSetFileName, $modelFileName, $outputFileName]
    );

    return sprintf('%s -b %d %s', $executable, $probabilityFlag, implode(' ', $escapedFiles));
}

private function ensureDirectorySeparator(string &$path): void
{
if (substr($path, -1) !== DIRECTORY_SEPARATOR) {
Expand Down
41 changes: 41 additions & 0 deletions tests/Phpml/SupportVectorMachine/DataTransformerTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,45 @@ public function testTransformSamplesToTestSet(): void

$this->assertEquals($testSet, DataTransformer::testSet($samples));
}

public function testPredictions(): void
{
    // One numeric label per line, joined with the platform line ending.
    $rawPredictions = implode(PHP_EOL, [0, 1, 0, 0]);
    $trainingLabels = ['a', 'a', 'b', 'b'];

    $actual = DataTransformer::predictions($rawPredictions, $trainingLabels);

    $this->assertEquals(['a', 'b', 'a', 'a'], $actual);
}

public function testProbabilities(): void
{
    // Header line first, then one row per sample: a leading token followed by
    // one probability per header column.
    $rawPredictions = 'labels 0 1 2'.PHP_EOL
        .'1 0.1 0.7 0.2'.PHP_EOL
        .'2 0.2 0.3 0.5'.PHP_EOL
        .'0 0.6 0.1 0.3';

    $expected = [
        ['a' => 0.1, 'b' => 0.7, 'c' => 0.2],
        ['a' => 0.2, 'b' => 0.3, 'c' => 0.5],
        ['a' => 0.6, 'b' => 0.1, 'c' => 0.3],
    ];

    $actual = DataTransformer::probabilities($rawPredictions, ['a', 'b', 'c']);

    $this->assertEquals($expected, $actual);
}
}
79 changes: 79 additions & 0 deletions tests/Phpml/SupportVectorMachine/SupportVectorMachineTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Phpml\Tests\SupportVectorMachine;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Exception\InvalidOperationException;
use Phpml\Exception\LibsvmCommandException;
use Phpml\SupportVectorMachine\Kernel;
use Phpml\SupportVectorMachine\SupportVectorMachine;
Expand Down Expand Up @@ -37,6 +38,31 @@ public function testTrainCSVCModelWithLinearKernel(): void
$this->assertEquals($model, $svm->getModel());
}

public function testTrainCSVCModelWithProbabilityEstimate(): void
{
    // NOTE(review): the final constructor argument presumably switches on
    // probability estimates; confirm against the constructor signature.
    $svm = new SupportVectorMachine(
        Type::C_SVC,
        Kernel::LINEAR,
        100.0,
        0.5,
        3,
        null,
        0.0,
        0.1,
        0.01,
        100,
        true,
        true
    );

    $svm->train(
        [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]],
        ['a', 'a', 'a', 'b', 'b', 'b']
    );

    // The trained model must carry both probability parameters, each at the
    // start of its own line.
    $model = $svm->getModel();
    foreach (['probA ', 'probB '] as $parameter) {
        $this->assertContains(PHP_EOL.$parameter, $model);
    }
}

public function testPredictSampleWithLinearKernel(): void
{
$samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
Expand Down Expand Up @@ -83,6 +109,41 @@ public function testPredictSampleFromMultipleClassWithRbfKernel(): void
$this->assertEquals('c', $predictions[2]);
}

public function testPredictProbability(): void
{
    $svm = new SupportVectorMachine(
        Type::C_SVC,
        Kernel::LINEAR,
        100.0,
        0.5,
        3,
        null,
        0.0,
        0.1,
        0.01,
        100,
        true,
        true
    );

    $svm->train(
        [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]],
        ['a', 'a', 'a', 'b', 'b', 'b']
    );

    [$first, $second, $third] = $svm->predictProbability([
        [3, 2],
        [2, 3],
        [4, -5],
    ]);

    // Samples one and three must lean towards 'b', sample two towards 'a'.
    $this->assertLessThan($first['b'], $first['a']);
    $this->assertGreaterThan($second['b'], $second['a']);
    $this->assertLessThan($third['b'], $third['a']);

    // Should be true because the latter is farther from the decision boundary
    $this->assertLessThan($third['b'], $first['b']);
}

public function testThrowExceptionWhenVarPathIsNotWritable(): void
{
$this->expectException(InvalidArgumentException::class);
Expand Down Expand Up @@ -124,4 +185,22 @@ public function testThrowExceptionWhenLibsvmFailsDuringPredict(): void
$svm = new SupportVectorMachine(Type::C_SVC, Kernel::RBF);
$svm->predict([1]);
}

public function testThrowExceptionWhenPredictProbabilityCalledWithoutProperModel(): void
{
    $this->expectException(InvalidOperationException::class);
    $this->expectExceptionMessage('Model does not support probabiliy estimates');

    $samples = [[1, 3], [1, 4], [2, 4], [3, 1], [4, 1], [4, 2]];
    $labels = ['a', 'a', 'a', 'b', 'b', 'b'];

    // No probability-estimates argument is passed, so the machine is built
    // with the constructor's default for it.
    $svm = new SupportVectorMachine(Type::C_SVC, Kernel::LINEAR, 100.0);
    $svm->train($samples, $labels);

    // The call is expected to throw, so its return value is deliberately not captured.
    $svm->predictProbability([
        [3, 2],
        [2, 3],
        [4, -5],
    ]);
}
}

0 comments on commit ec091b5

Please sign in to comment.