Skip to content

Commit

Permalink
Fix Optimizer initial theta randomization (#239)
Browse files Browse the repository at this point in the history
* Fix Optimizer initial theta randomization

* Add more tests for LUDecomposition and FuzzyCMeans
  • Loading branch information
akondas authored Feb 23, 2018
1 parent 83f3e8d commit a96f03e
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 23 deletions.
17 changes: 7 additions & 10 deletions src/Helper/Optimizer/Optimizer.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
namespace Phpml\Helper\Optimizer;

use Closure;
use Exception;
use Phpml\Exception\InvalidArgumentException;

abstract class Optimizer
{
public $initialTheta;

/**
* Unknown variables to be found
*
Expand All @@ -33,21 +35,16 @@ public function __construct(int $dimensions)
// Inits the weights randomly
$this->theta = [];
for ($i = 0; $i < $this->dimensions; ++$i) {
$this->theta[] = random_int(0, getrandmax()) / (float) getrandmax();
$this->theta[] = (random_int(0, PHP_INT_MAX) / PHP_INT_MAX) + 0.1;
}

$this->initialTheta = $this->theta;
}

/**
* Sets the weights manually
*
* @return $this
*
* @throws \Exception
*/
public function setInitialTheta(array $theta)
{
if (count($theta) != $this->dimensions) {
throw new Exception("Number of values in the weights array should be ${this}->dimensions");
throw new InvalidArgumentException(sprintf('Number of values in the weights array should be %s', $this->dimensions));
}

$this->theta = $theta;
Expand Down
12 changes: 12 additions & 0 deletions src/Helper/Optimizer/StochasticGD.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Phpml\Helper\Optimizer;

use Closure;
use Phpml\Exception\InvalidArgumentException;

/**
* Stochastic Gradient Descent optimization method
Expand Down Expand Up @@ -88,6 +89,17 @@ public function __construct(int $dimensions)
$this->dimensions = $dimensions;
}

public function setInitialTheta(array $theta)
{
if (count($theta) != $this->dimensions + 1) {
throw new InvalidArgumentException(sprintf('Number of values in the weights array should be %s', $this->dimensions + 1));
}

$this->theta = $theta;

return $this;
}

/**
* Sets minimum value for the change in the theta values
* between iterations to continue the iterations.<br>
Expand Down
15 changes: 2 additions & 13 deletions src/Math/LinearAlgebra/LUDecomposition.php
Original file line number Diff line number Diff line change
Expand Up @@ -225,25 +225,14 @@ public function isNonsingular(): bool
return true;
}

/**
* Count determinants
*
* @return float|int d matrix determinant
*
* @throws MatrixException
*/
public function det()
public function det(): float
{
if ($this->m !== $this->n) {
throw MatrixException::notSquareMatrix();
}

$d = $this->pivsign;
for ($j = 0; $j < $this->n; ++$j) {
$d *= $this->LU[$j][$j];
}

return $d;
return (float) $d;
}

/**
Expand Down
16 changes: 16 additions & 0 deletions tests/Clustering/FuzzyCMeansTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Phpml\Tests\Clustering;

use Phpml\Clustering\FuzzyCMeans;
use Phpml\Exception\InvalidArgumentException;
use PHPUnit\Framework\TestCase;

class FuzzyCMeansTest extends TestCase
Expand Down Expand Up @@ -45,4 +46,19 @@ public function testMembershipMatrix(): void
$this->assertEquals(1, array_sum($col));
}
}

/**
* @dataProvider invalidClusterNumberProvider
*/
public function testInvalidClusterNumber(int $clusters): void
{
$this->expectException(InvalidArgumentException::class);

new FuzzyCMeans($clusters);
}

public function invalidClusterNumberProvider(): array
{
return [[0], [-1]];
}
}
37 changes: 37 additions & 0 deletions tests/Helper/Optimizer/ConjugateGradientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

namespace Phpml\Tests\Helper\Optimizer;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Optimizer\ConjugateGradient;
use PHPUnit\Framework\TestCase;

Expand Down Expand Up @@ -35,6 +36,34 @@ public function testRunOptimization(): void
$this->assertEquals([-1, 2], $theta, '', 0.1);
}

public function testRunOptimizationWithCustomInitialTheta(): void
{
// 200 samples from y = -1 + 2x (i.e. theta = [-1, 2])
$samples = [];
$targets = [];
for ($i = -100; $i <= 100; ++$i) {
$x = $i / 100;
$samples[] = [$x];
$targets[] = -1 + 2 * $x;
}

$callback = function ($theta, $sample, $target) {
$y = $theta[0] + $theta[1] * $sample[0];
$cost = ($y - $target) ** 2 / 2;
$grad = $y - $target;

return [$cost, $grad];
};

$optimizer = new ConjugateGradient(1);
// set very weak theta to trigger very bad result
$optimizer->setInitialTheta([0.0000001, 0.0000001]);

$theta = $optimizer->runOptimization($samples, $targets, $callback);

$this->assertEquals([-1.087708, 2.212034], $theta, '', 0.000001);
}

public function testRunOptimization2Dim(): void
{
// 100 samples from y = -1 + 2x0 - 3x1 (i.e. theta = [-1, 2, -3])
Expand Down Expand Up @@ -62,4 +91,12 @@ public function testRunOptimization2Dim(): void

$this->assertEquals([-1, 2, -3], $theta, '', 0.1);
}

public function testThrowExceptionOnInvalidTheta(): void
{
$opimizer = new ConjugateGradient(2);

$this->expectException(InvalidArgumentException::class);
$opimizer->setInitialTheta([0.15]);
}
}
31 changes: 31 additions & 0 deletions tests/Math/LinearAlgebra/LUDecompositionTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<?php

declare(strict_types=1);

namespace Phpml\Tests\Math\LinearAlgebra;

use Phpml\Exception\MatrixException;
use Phpml\Math\LinearAlgebra\LUDecomposition;
use Phpml\Math\Matrix;
use PHPUnit\Framework\TestCase;

/**
* LUDecomposition is used and tested in Matrix::inverse method so not all tests are required
*/
final class LUDecompositionTest extends TestCase
{
public function testNotSquareMatrix(): void
{
$this->expectException(MatrixException::class);

new LUDecomposition(new Matrix([1, 2, 3, 4, 5]));
}

public function testSolveWithInvalidMatrix(): void
{
$this->expectException(MatrixException::class);

$lu = new LUDecomposition(new Matrix([[1, 2], [3, 4]]));
$lu->solve(new Matrix([1, 2, 3]));
}
}

0 comments on commit a96f03e

Please sign in to comment.