Skip to content

Commit

Permalink
Fix logistic regression implementation (#169)
Browse files Browse the repository at this point in the history
* Fix target value of LogisticRegression

* Fix probability calculation in LogisticRegression

* Change the default cost function to log-likelihood

* Remove redundant round function

* Fix for coding standard
  • Loading branch information
y-uti authored and akondas committed Dec 5, 2017
1 parent 946fbbc commit c4f58f7
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 9 deletions.
21 changes: 12 additions & 9 deletions src/Phpml/Classification/Linear/LogisticRegression.php
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class LogisticRegression extends Adaline
*
* @var string
*/
protected $costFunction = 'sse';
protected $costFunction = 'log';

/**
* Regularization term: only 'L2' is supported
Expand Down Expand Up @@ -67,7 +67,7 @@ public function __construct(
int $maxIterations = 500,
bool $normalizeInputs = true,
int $trainingType = self::CONJUGATE_GRAD_TRAINING,
string $cost = 'sse',
string $cost = 'log',
string $penalty = 'L2'
) {
$trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
Expand Down Expand Up @@ -190,6 +190,8 @@ protected function getCostFunction(): Closure
$hX = 1e-10;
}

$y = $y < 0 ? 0 : 1;

$error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
$gradient = $hX - $y;

Expand All @@ -213,6 +215,8 @@ protected function getCostFunction(): Closure
$this->weights = $weights;
$hX = $this->output($sample);

$y = $y < 0 ? 0 : 1;

$error = ($y - $hX) ** 2;
$gradient = -($y - $hX) * $hX * (1 - $hX);

Expand Down Expand Up @@ -243,7 +247,7 @@ protected function outputClass(array $sample): int
{
$output = $this->output($sample);

if (round($output) > 0.5) {
if ($output > 0.5) {
return 1;
}

Expand All @@ -260,14 +264,13 @@ protected function outputClass(array $sample): int
*/
protected function predictProbability(array $sample, $label): float
{
$predicted = $this->predictSampleBinary($sample);

if ((string) $predicted == (string) $label) {
$sample = $this->checkNormalizedSample($sample);
$sample = $this->checkNormalizedSample($sample);
$probability = $this->output($sample);

return (float) abs($this->output($sample) - 0.5);
if (array_search($label, $this->labels, true) > 0) {
return $probability;
}

return 0.0;
return 1 - $probability;
}
}
106 changes: 106 additions & 0 deletions tests/Phpml/Classification/Linear/LogisticRegressionTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
<?php

declare(strict_types=1);

namespace tests\Phpml\Classification\Linear;

use Phpml\Classification\Linear\LogisticRegression;
use PHPUnit\Framework\TestCase;
use ReflectionMethod;
use ReflectionProperty;

class LogisticRegressionTest extends TestCase
{
public function testPredictSingleSample(): void
{
// AND problem
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
$classifier = new LogisticRegression();
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}

public function testPredictMultiClassSample(): void
{
// By use of One-v-Rest, Perceptron can perform multi-class classification
// The samples should be separable by lines perpendicular to the dimensions
$samples = [
[0, 0], [0, 1], [1, 0], [1, 1], // First group : a cluster at bottom-left corner in 2D
[5, 5], [6, 5], [5, 6], [7, 5], // Second group: another cluster at the middle-right
[3, 10], [3, 10], [3, 8], [3, 9], // Third group : cluster at the top-middle
];
$targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];

$classifier = new LogisticRegression();
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.5, 0.5]));
$this->assertEquals(1, $classifier->predict([6.0, 5.0]));
$this->assertEquals(2, $classifier->predict([3.0, 9.5]));
}

public function testPredictProbabilitySingleSample(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
$classifier = new LogisticRegression();
$classifier->train($samples, $targets);

$property = new ReflectionProperty($classifier, 'classifiers');
$property->setAccessible(true);
$predictor = $property->getValue($classifier)[0];
$method = new ReflectionMethod($predictor, 'predictProbability');
$method->setAccessible(true);

$zero = $method->invoke($predictor, [0.1, 0.1], 0);
$one = $method->invoke($predictor, [0.1, 0.1], 1);
$this->assertEquals(1, $zero + $one, null, 1e-6);
$this->assertTrue($zero > $one);

$zero = $method->invoke($predictor, [0.9, 0.9], 0);
$one = $method->invoke($predictor, [0.9, 0.9], 1);
$this->assertEquals(1, $zero + $one, null, 1e-6);
$this->assertTrue($zero < $one);
}

public function testPredictProbabilityMultiClassSample(): void
{
$samples = [
[0, 0], [0, 1], [1, 0], [1, 1],
[5, 5], [6, 5], [5, 6], [6, 6],
[3, 10], [3, 10], [3, 8], [3, 9],
];
$targets = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2];

$classifier = new LogisticRegression();
$classifier->train($samples, $targets);

$property = new ReflectionProperty($classifier, 'classifiers');
$property->setAccessible(true);

$predictor = $property->getValue($classifier)[0];
$method = new ReflectionMethod($predictor, 'predictProbability');
$method->setAccessible(true);
$zero = $method->invoke($predictor, [3.0, 9.5], 0);
$not_zero = $method->invoke($predictor, [3.0, 9.5], 'not_0');

$predictor = $property->getValue($classifier)[1];
$method = new ReflectionMethod($predictor, 'predictProbability');
$method->setAccessible(true);
$one = $method->invoke($predictor, [3.0, 9.5], 1);
$not_one = $method->invoke($predictor, [3.0, 9.5], 'not_1');

$predictor = $property->getValue($classifier)[2];
$method = new ReflectionMethod($predictor, 'predictProbability');
$method->setAccessible(true);
$two = $method->invoke($predictor, [3.0, 9.5], 2);
$not_two = $method->invoke($predictor, [3.0, 9.5], 'not_2');

$this->assertEquals(1, $zero + $not_zero, null, 1e-6);
$this->assertEquals(1, $one + $not_one, null, 1e-6);
$this->assertEquals(1, $two + $not_two, null, 1e-6);
$this->assertTrue($zero < $two);
$this->assertTrue($one < $two);
}
}

0 comments on commit c4f58f7

Please sign in to comment.