diff --git a/docs/machine-learning/neural-network/multilayer-perceptron-classifier.md b/docs/machine-learning/neural-network/multilayer-perceptron-classifier.md index a6b060a4..72d0b4be 100644 --- a/docs/machine-learning/neural-network/multilayer-perceptron-classifier.md +++ b/docs/machine-learning/neural-network/multilayer-perceptron-classifier.md @@ -45,6 +45,12 @@ $mlp->partialTrain( ``` +You can update the learning rate between partialTrain runs: + +``` +$mlp->setLearningRate(0.1); +``` + ## Predict To predict sample label use predict method. You can provide one sample or array of samples: diff --git a/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php b/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php index a38e952a..5ace597c 100644 --- a/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php +++ b/src/Phpml/NeuralNetwork/Network/MultilayerPerceptron.php @@ -95,6 +95,12 @@ public function partialTrain(array $samples, array $targets, array $classes = [] } } + public function setLearningRate(float $learningRate): void + { + $this->learningRate = $learningRate; + $this->backpropagation->setLearningRate($this->learningRate); + } + /** * @param mixed $target */ diff --git a/src/Phpml/NeuralNetwork/Training/Backpropagation.php b/src/Phpml/NeuralNetwork/Training/Backpropagation.php index 8382a8e4..df515b21 100644 --- a/src/Phpml/NeuralNetwork/Training/Backpropagation.php +++ b/src/Phpml/NeuralNetwork/Training/Backpropagation.php @@ -25,6 +25,11 @@ class Backpropagation private $prevSigmas = null; public function __construct(float $learningRate) + { + $this->setLearningRate($learningRate); + } + + public function setLearningRate(float $learningRate): void { $this->learningRate = $learningRate; } diff --git a/tests/Phpml/NeuralNetwork/Network/MultilayerPerceptronTest.php b/tests/Phpml/NeuralNetwork/Network/MultilayerPerceptronTest.php new file mode 100644 index 00000000..c244c276 --- /dev/null +++ b/tests/Phpml/NeuralNetwork/Network/MultilayerPerceptronTest.php @@ -0,0 +1,28 @@ +getMockForAbstractClass( + MultilayerPerceptron::class, + [5, [3], [0, 1], 1000, null, 0.42] + ); + + $this->assertEquals(0.42, $this->readAttribute($mlp, 'learningRate')); + $backprop = $this->readAttribute($mlp, 'backpropagation'); + $this->assertEquals(0.42, $this->readAttribute($backprop, 'learningRate')); + + $mlp->setLearningRate(0.24); + $this->assertEquals(0.24, $this->readAttribute($mlp, 'learningRate')); + $backprop = $this->readAttribute($mlp, 'backpropagation'); + $this->assertEquals(0.24, $this->readAttribute($backprop, 'learningRate')); + } +}