Add settings documentation
Also fix link in quick_start

ISSUE=4611783



git-svn-id: https://svn.baidu.com/idl/trunk/paddle@1474 1ad973e4-5ce8-4261-8a94-b56d1f490c56
yuyang18 committed Sep 1, 2016
1 parent 7f1d6c5 commit 200dfa1
Showing 2 changed files with 30 additions and 20 deletions.
2 changes: 1 addition & 1 deletion doc_cn/demo/quick_start/index.md
@@ -4,7 +4,7 @@

## Install

- First, please refer to the <a href = "../../build/index.html">installation tutorial</a> to install PaddlePaddle.
+ First, please refer to the <a href = "../../build_and_install/install/index.html">installation tutorial</a> to install PaddlePaddle.

## Usage Overview

48 changes: 29 additions & 19 deletions python/paddle/trainer_config_helpers/optimizers.py
@@ -64,14 +64,6 @@ class BaseSGDOptimizer(Optimizer):
        w = w - \\eta \\nabla Q(w) = w - \\eta \\sum_{i=1}^{n} \\nabla Q_i(w)

    where :math:`\\eta` is the learning rate and :math:`n` is the batch size.

-    The SGD method is implemented by Paddle with multiple extensions, such
-    as momentum, adagrad, rmsprop, and adam. Please use the corresponding
-    'use_xxx' method, such as use_adam, to enhance the SGD method.
-
-    WARNING: IN PADDLE'S IMPLEMENTATION, BATCH_SIZE IS SET FOR ONE COMPUTE
-    PROCESS (NODE). IF YOU USE MULTIPLE MACHINES TO TRAIN YOUR NETWORK, THE
-    GLOBAL BATCH SIZE WILL BE (BATCH_SIZE * MACHINE_COUNT).
    """

def to_setting_kwargs(self):
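
For reference, the update rule quoted in the docstring above is plain mini-batch SGD: each step moves the weights against the sum of the per-example gradients. A minimal NumPy sketch of one such step (illustrative only, not PaddlePaddle's actual implementation):

```python
import numpy as np

def sgd_step(w, grads, eta=0.01):
    """One SGD update: w <- w - eta * sum_i grad Q_i(w).

    w     -- parameter vector, shape (d,)
    grads -- per-example gradients, shape (n, d), where n is the batch size
    eta   -- learning rate
    """
    return w - eta * grads.sum(axis=0)

# Example: a batch of 4 gradients for a 3-dimensional parameter vector.
w = np.zeros(3)
grads = np.full((4, 3), 0.5)
w = sgd_step(w, grads, eta=0.1)  # each weight moves by -0.1 * (4 * 0.5) = -0.2
```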
@@ -352,17 +344,35 @@ def settings(batch_size,
gradient_clipping_threshold=None
):
"""
-    TODO(yuyang18): Complete docs.

-    :param batch_size:
-    :param learning_rate:
-    :param learning_method:
-    :param regularization:
-    :param is_async:
-    :param model_average:
-    :param gradient_clipping_threshold:
-    :return:
+    Set the optimization method, learning rate, batch size, and other training
+    settings. The currently supported algorithms are SGD and Async-SGD.
+
+    ..  warning::
+
+        Note that 'batch_size' in PaddlePaddle is not the global training
+        batch size; it is the batch size of a single training process. If
+        you use N processes to train one model, for example three GPU
+        machines, the global batch size is N * 'batch_size'.
+
+    :param batch_size: batch size for one training process.
+    :type batch_size: int
+    :param learning_rate: learning rate for SGD.
+    :type learning_rate: float
+    :param learning_method: The extension optimization algorithm of gradient
+                            descent, such as momentum, adagrad, rmsprop, etc.
+                            Note that it should be an instance of a subclass
+                            of BaseSGDOptimizer.
+    :type learning_method: BaseSGDOptimizer
+    :param regularization: The regularization method.
+    :type regularization: BaseRegularization
+    :param is_async: Whether to use Async-SGD. Default is False.
+    :type is_async: bool
+    :param model_average: Model average settings.
+    :type model_average: ModelAverage
+    :param gradient_clipping_threshold: Gradient clipping threshold. Gradients
+                                        whose values exceed this threshold
+                                        will be clipped.
+    :type gradient_clipping_threshold: float
    """
if isinstance(regularization, BaseRegularization):
regularization = [regularization]
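
For context, here is a sketch of how the newly documented settings() call is typically used in a trainer config file. The optimizer and regularization class names (MomentumOptimizer, L2Regularization) and their argument names are assumptions based on this module, not taken from this commit:

```python
# Hypothetical trainer_config.py fragment; MomentumOptimizer and
# L2Regularization are assumed to be exported by trainer_config_helpers.
from paddle.trainer_config_helpers import *

settings(
    batch_size=128,  # per-process batch size; with 4 trainer processes
                     # the global batch size would be 4 * 128 = 512
    learning_rate=1e-3,
    learning_method=MomentumOptimizer(momentum=0.9),
    regularization=L2Regularization(rate=8e-4),
    gradient_clipping_threshold=25.0,
)
```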
