Skip to content

Commit

Permalink
Add tests for LogisticRegression (#248)
Browse files Browse the repository at this point in the history
  • Loading branch information
y-uti authored and akondas committed Mar 3, 2018
1 parent 9c19555 commit af9ccfe
Showing 1 changed file with 110 additions and 0 deletions.
110 changes: 110 additions & 0 deletions tests/Classification/Linear/LogisticRegressionTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,49 @@
use PHPUnit\Framework\TestCase;
use ReflectionMethod;
use ReflectionProperty;
use Throwable;

class LogisticRegressionTest extends TestCase
{
public function testConstructorThrowWhenInvalidTrainingType(): void
{
$this->expectException(Throwable::class);

$classifier = new LogisticRegression(
500,
true,
-1,
'log',
'L2'
);
}

public function testConstructorThrowWhenInvalidCost(): void
{
$this->expectException(Throwable::class);

$classifier = new LogisticRegression(
500,
true,
LogisticRegression::CONJUGATE_GRAD_TRAINING,
'invalid',
'L2'
);
}

public function testConstructorThrowWhenInvalidPenalty(): void
{
$this->expectException(Throwable::class);

$classifier = new LogisticRegression(
500,
true,
LogisticRegression::CONJUGATE_GRAD_TRAINING,
'log',
'invalid'
);
}

public function testPredictSingleSample(): void
{
// AND problem
Expand All @@ -22,6 +62,76 @@ public function testPredictSingleSample(): void
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}

public function testPredictSingleSampleWithBatchTraining(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];

// $maxIterations is set to 10000 as batch training needs more
// iteration to converge than CG method in general.
$classifier = new LogisticRegression(
10000,
true,
LogisticRegression::BATCH_TRAINING,
'log',
'L2'
);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}

public function testPredictSingleSampleWithOnlineTraining(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];

// $penalty is set to empty (no penalty) because L2 penalty seems to
// prevent convergence in online training for this dataset.
$classifier = new LogisticRegression(
10000,
true,
LogisticRegression::ONLINE_TRAINING,
'log',
''
);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}

public function testPredictSingleSampleWithSSECost(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
$classifier = new LogisticRegression(
500,
true,
LogisticRegression::CONJUGATE_GRAD_TRAINING,
'sse',
'L2'
);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}

public function testPredictSingleSampleWithoutPenalty(): void
{
$samples = [[0, 0], [1, 0], [0, 1], [1, 1], [0.4, 0.4], [0.6, 0.6]];
$targets = [0, 0, 0, 1, 0, 1];
$classifier = new LogisticRegression(
500,
true,
LogisticRegression::CONJUGATE_GRAD_TRAINING,
'log',
''
);
$classifier->train($samples, $targets);
$this->assertEquals(0, $classifier->predict([0.1, 0.1]));
$this->assertEquals(1, $classifier->predict([0.9, 0.9]));
}

public function testPredictMultiClassSample(): void
{
// By use of One-v-Rest, Perceptron can perform multi-class classification
Expand Down

0 comments on commit af9ccfe

Please sign in to comment.