Skip to content

Commit

Permalink
Ability to update learningRate in MLP (#160)
Browse files Browse the repository at this point in the history
* Allow people to update the learning rate

* Test for learning rate setter
  • Loading branch information
dmonllao authored and akondas committed Dec 5, 2017
1 parent c4f58f7 commit c4ad117
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ $mlp->partialTrain(
```

You can update the learning rate between partialTrain runs:

```
$mlp->setLearningRate(0.1);
```

## Predict

To predict sample label use predict method. You can provide one sample or array of samples:
Expand Down
6 changes: 6 additions & 0 deletions src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ public function partialTrain(array $samples, array $targets, array $classes = []
}
}

public function setLearningRate(float $learningRate): void
{
$this->learningRate = $learningRate;
$this->backpropagation->setLearningRate($this->learningRate);
}

/**
* @param mixed $target
*/
Expand Down
5 changes: 5 additions & 0 deletions src/Phpml/NeuralNetwork/Training/Backpropagation.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ class Backpropagation
private $prevSigmas = null;

public function __construct(float $learningRate)
{
$this->setLearningRate($learningRate);
}

public function setLearningRate(float $learningRate): void
{
$this->learningRate = $learningRate;
}
Expand Down
28 changes: 28 additions & 0 deletions tests/Phpml/NeuralNetwork/Network/MultilayerPerceptronTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
<?php

declare(strict_types=1);

namespace tests\Phpml\NeuralNetwork\Network;

use Phpml\NeuralNetwork\Network\MultilayerPerceptron;
use PHPUnit\Framework\TestCase;

class MultilayerPerceptronTest extends TestCase
{
public function testLearningRateSetter(): void
{
$mlp = $this->getMockForAbstractClass(
MultilayerPerceptron::class,
[5, [3], [0, 1], 1000, null, 0.42]
);

$this->assertEquals(0.42, $this->readAttribute($mlp, 'learningRate'));
$backprop = $this->readAttribute($mlp, 'backpropagation');
$this->assertEquals(0.42, $this->readAttribute($backprop, 'learningRate'));

$mlp->setLearningRate(0.24);
$this->assertEquals(0.24, $this->readAttribute($mlp, 'learningRate'));
$backprop = $this->readAttribute($mlp, 'backpropagation');
$this->assertEquals(0.24, $this->readAttribute($backprop, 'learningRate'));
}
}

0 comments on commit c4ad117

Please sign in to comment.