Skip to content

Commit

Permalink
refine sparse momentum api and unittest (PaddlePaddle#126)
Browse files Browse the repository at this point in the history
* refine sparse momentum api and unittest
* fix unittests bug
  • Loading branch information
luotao1 authored and reyoung committed Sep 29, 2016
1 parent 6decbdf commit 1fc4352
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 55 deletions.
1 change: 1 addition & 0 deletions doc/algorithm/rnn/rnn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ We also project the encoder vector to :code:`decoder_size` dimensional space, ge
The decoder uses :code:`recurrent_group` to define the recurrent neural network. The step and output functions are defined in :code:`gru_decoder_with_attention`:

.. code-block:: python
group_inputs=[StaticInput(input=encoded_vector,is_seq=True),
StaticInput(input=encoded_proj,is_seq=True)]
trg_embedding = embedding_layer(
Expand Down
6 changes: 6 additions & 0 deletions doc/ui/api/trainer_config_helpers/optimizers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ BaseSGDOptimizer
:members: BaseSGDOptimizer
:noindex:

MomentumOptimizer
=================
.. automodule:: paddle.trainer_config_helpers.optimizers
:members: MomentumOptimizer
:noindex:

AdamOptimizer
=============
.. automodule:: paddle.trainer_config_helpers.optimizers
Expand Down
2 changes: 1 addition & 1 deletion paddle/trainer/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ add_test(NAME test_CompareTwoOpts
COMMAND ${PROJ_ROOT}/paddle/.set_python_path.sh -d ${PROJ_ROOT}/python/
${CMAKE_CURRENT_BINARY_DIR}/test_CompareTwoOpts
--config_file_a=trainer/tests/sample_trainer_config_opt_a.conf --config_file_b=trainer/tests/sample_trainer_config_opt_b.conf
--num_passes=1 --need_high_accuracy=1
--num_passes=1 --need_high_accuracy=0
WORKING_DIRECTORY ${PROJ_ROOT}/paddle/)

################# test_CompareSparse ##################
Expand Down
1 change: 1 addition & 0 deletions paddle/trainer/tests/mnist.list
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
trainer/tests/mnist_bin_part
Binary file added paddle/trainer/tests/mnist_bin_part
Binary file not shown.
47 changes: 22 additions & 25 deletions paddle/trainer/tests/sample_trainer_config_opt_a.conf
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#Todo(luotao02) This config is only used for unitest. It is out of date now, and will be updated later.
from paddle.trainer_config_helpers import *

################################### Data Configuration ###################################
TrainData(ProtoData(files = "train.list"))
TrainData(ProtoData(files = "trainer/tests/mnist.list"))
################################### Algorithm Configuration ###################################
Settings(
learning_rate_decay_a = 0.0,
learning_rate_decay_b = 0.0,
learning_rate = 1e-03,
batch_size = 1000,
algorithm = 'sgd',
num_batches_per_send_parameter = 1,
num_batches_per_get_parameter = 1,
learning_method='sparse_momentum',
)
default_momentum(0.5)
settings(batch_size = 1000,
learning_method = MomentumOptimizer(momentum=0.5, sparse=False))
################################### Network Configuration ###################################
Layer(type = "data", name = "input", size = 784)
Layer(inputs = [Input("input", parameter_name = "_layer1.w")], name = "layer1", bias = Bias(parameter_name = "_layer1.bias"), active_type = "sigmoid", type = "fc", size = 800)
Layer(inputs = [Input("layer1", parameter_name = "_layer2.w")], name = "layer2", bias = Bias(parameter_name = "_layer2.bias"), active_type = "sigmoid", type = "fc", size = 800)
#Layer(inputs = [Input("layer2", parameter_name = "_layer_output.w", decay_rate = 0.02)], name = "output", bias = Bias(parameter_name = "_layer_output.bias"), active_type = "margin", type = "fc", size = 10)
#Layer(inputs = [Input("layer2", parameter_name = "_layer_output.w", decay_rate = 0.02)], name = "output", bias = Bias(parameter_name = "_layer_output.bias"), type = "fc", size = 10)
Layer(inputs = [Input("layer2", parameter_name = "_layer_output.w")], name = "output", bias = Bias(parameter_name = "_layer_output.bias"), active_type = "softmax", type = "fc", size = 10)
Layer(type = "data", name = "label", size = 1)
Layer(inputs = [Input("output"), Input("label")], type = "multi-class-cross-entropy", name = "cost")
#Layer(inputs = [Input("output"), Input("label")], type = "huber", name = "cost")
Evaluator(inputs=["output", "label"], type = "classification_error", name = "classification_error")
Inputs("input", "label")
Outputs("cost")
data = data_layer(name ="input", size=784)

fc1 = fc_layer(input=data, size=800,
bias_attr=True,
act=SigmoidActivation())

fc2 = fc_layer(input=fc1, size=800,
bias_attr=True,
act=SigmoidActivation())

output = fc_layer(input=[fc1, fc2], size=10,
bias_attr=True,
act=SoftmaxActivation())

lbl = data_layer(name ="label", size=1)

cost = classification_cost(input=output, label=lbl)
outputs(cost)
47 changes: 22 additions & 25 deletions paddle/trainer/tests/sample_trainer_config_opt_b.conf
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#Todo(luotao02) This config is only used for unitest. It is out of date now, and will be updated later.
from paddle.trainer_config_helpers import *

################################### Data Configuration ###################################
TrainData(ProtoData(files = "train.list"))
TrainData(ProtoData(files = "trainer/tests/mnist.list"))
################################### Algorithm Configuration ###################################
Settings(
learning_rate_decay_a = 0.0,
learning_rate_decay_b = 0.0,
learning_rate = 1e-03,
batch_size = 1000,
algorithm = 'sgd',
num_batches_per_send_parameter = 1,
num_batches_per_get_parameter = 1,
learning_method='momentum',
)
default_momentum(0.5)
settings(batch_size = 1000,
learning_method = MomentumOptimizer(momentum=0.5, sparse=False))
################################### Network Configuration ###################################
Layer(type = "data", name = "input", size = 784)
Layer(inputs = [Input("input", parameter_name = "_layer1.w")], name = "layer1", bias = Bias(parameter_name = "_layer1.bias"), active_type = "sigmoid", type = "fc", size = 800)
Layer(inputs = [Input("layer1", parameter_name = "_layer2.w")], name = "layer2", bias = Bias(parameter_name = "_layer2.bias"), active_type = "sigmoid", type = "fc", size = 800)
#Layer(inputs = [Input("layer2", parameter_name = "_layer_output.w", decay_rate = 0.02)], name = "output", bias = Bias(parameter_name = "_layer_output.bias"), active_type = "margin", type = "fc", size = 10)
#Layer(inputs = [Input("layer2", parameter_name = "_layer_output.w", decay_rate = 0.02)], name = "output", bias = Bias(parameter_name = "_layer_output.bias"), type = "fc", size = 10)
Layer(inputs = [Input("layer2", parameter_name = "_layer_output.w")], name = "output", bias = Bias(parameter_name = "_layer_output.bias"), active_type = "softmax", type = "fc", size = 10)
Layer(type = "data", name = "label", size = 1)
Layer(inputs = [Input("output"), Input("label")], type = "multi-class-cross-entropy", name = "cost")
#Layer(inputs = [Input("output"), Input("label")], type = "huber", name = "cost")
Evaluator(inputs=["output", "label"], type = "classification_error", name = "classification_error")
Inputs("input", "label")
Outputs("cost")
data = data_layer(name ="input", size=784)

fc1 = fc_layer(input=data, size=800,
bias_attr=True,
act=SigmoidActivation())

fc2 = fc_layer(input=fc1, size=800,
bias_attr=True,
act=SigmoidActivation())

output = fc_layer(input=[fc1, fc2], size=10,
bias_attr=True,
act=SoftmaxActivation())

lbl = data_layer(name ="label", size=1)

cost = classification_cost(input=output, label=lbl)
outputs(cost)
33 changes: 29 additions & 4 deletions python/paddle/trainer_config_helpers/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,41 @@ def to_setting_kwargs(self):


class MomentumOptimizer(BaseSGDOptimizer):
"""
MomentumOptimizer.
When sparse=True, the update scheme:
.. math::
\\alpha_t &= \\alpha_{t-1} / k \\\\
\\beta_t &= \\beta_{t-1} / (1 + \\lambda \\gamma_t) \\\\
u_t &= u_{t-1} - \\alpha_t \\gamma_t g_t \\\\
v_t &= v_{t-1} + \\tau_{t-1} \\alpha_t \\gamma_t g_t \\\\
\\tau_t &= \\tau_{t-1} + \\beta_t / \\alpha_t
where :math:`k` is momentum, :math:`\\lambda` is decay rate,
:math:`\\gamma_t` is learning rate at the t'th step.
:param sparse: with sparse support or not.
:type sparse: bool
"""
def extra_settings(self):
default_momentum(self.momentum)

def to_setting_kwargs(self):
return {
'learning_method': 'momentum'
}
if self.sparse:
return {
'learning_method': 'sparse_momentum'
}
else:
return {
'learning_method': 'momentum'
}

def __init__(self, momentum=None):
def __init__(self, momentum=None, sparse=False):
self.momentum = momentum
self.sparse = sparse


class AdamOptimizer(BaseSGDOptimizer):
Expand Down

0 comments on commit 1fc4352

Please sign in to comment.