Skip to content

Commit

Permalink
lr mult update
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed Nov 7, 2015
1 parent 729e2b5 commit 169a686
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
8 changes: 3 additions & 5 deletions example/autoencoder/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,12 @@ def set_iter_start_callback(self, callback):
self.iter_start_callback = callback

def solve(self, xpu, sym, args, args_grad, input_names,
data_iter, begin_epoch, end_epoch, debug = False, args_lrmult=None):
if args_lrmult is None:
args_lrmult = {}

data_iter, begin_epoch, end_epoch, debug = False, args_lrmult={}):
data_iter.reset()
input_dict = {key: mx.nd.empty(arr.shape, ctx=xpu) for key, arr in zip(input_names, data_iter.next())}
batch_size = input_dict.values()[0].shape[0]
self.optimizer.rescale_grad = 1.0/batch_size
self.optimizer.set_lr_mult(args_lrmult)
args = dict(args, **input_dict)

output_names = sym.list_outputs()
Expand Down Expand Up @@ -108,7 +106,7 @@ def solve(self, xpu, sym, args, args_grad, input_names,
exe.backward()
self.optimizer.begin_epoch(i)
for key, arr in update_dict.items():
self.updater(key, arr, args[key], args_lrmult.get(key, 1.0))
self.updater(key, arr, args[key])

exe.outputs[0].wait_to_read()
if self.metric is not None:
Expand Down
23 changes: 17 additions & 6 deletions python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def create_optimizer(name, rescale_grad=1, **kwargs):
def __init__(self, rescale_grad=1):
    """Initialize base optimizer state.

    Parameters
    ----------
    rescale_grad : float, optional
        factor applied to gradients before the update; callers may set
        it to 1/batch_size to average gradients over a batch.
    """
    self.epoch = 0                     # current epoch counter; see begin_epoch()
    self.rescale_grad = rescale_grad
    self.lr_mult = {}                  # per-parameter lr multipliers; see set_lr_mult()

def begin_epoch(self, epoch):
"""Function called to notify beginning of epoch.
Expand All @@ -66,9 +67,19 @@ def create_state(self, index, weight):
"""Create additional optimizer state such as momentum.
override in implementations."""

def update(self, index, weight, grad, state, lr_mult=1.0):
def update(self, index, weight, grad, state):
    """Update the parameters. Override in implementations.

    This base implementation is a no-op.

    Parameters
    ----------
    index : int
        key identifying the parameter; implementations use it to look
        up per-parameter settings such as the entry in ``self.lr_mult``.
    weight : NDArray
        parameter array; implementations update it in place.
    grad : NDArray
        gradient; implementations scale it by ``self.rescale_grad``.
    state : object
        per-parameter state previously returned by ``create_state``.
    """

def set_lr_mult(self, args_lrmult):
    """Set individual learning rate multipliers for parameters.

    Parameters
    ----------
    args_lrmult : dict of index to float
        maps a parameter index to the learning-rate multiplier applied
        to it during ``update``; parameters absent from the dict use a
        multiplier of 1.0.
    """
    # copy so later mutation of the caller's dict cannot affect the optimizer
    self.lr_mult = args_lrmult.copy()

#convenience wrapper for Optimizer.Register
register = Optimizer.register

Expand Down Expand Up @@ -120,7 +131,7 @@ def create_state(self, index, weight):
else:
return zeros(weight.shape, weight.context)

def update(self, index, weight, grad, state, lr_mult=1.0):
def update(self, index, weight, grad, state):
"""Update the parameters.
Parameters
Expand All @@ -144,7 +155,7 @@ def update(self, index, weight, grad, state, lr_mult=1.0):
lr = self.lr_scheduler(self.epoch)
else:
lr = self.lr
lr *= lr_mult
lr *= self.lr_mult.get(index, default=1.0)

grad = grad * self.rescale_grad
if self.clip_gradient != None:
Expand All @@ -170,7 +181,7 @@ def create_state(self, index, weight):
"""Create a state to duplicate weight"""
return zeros(weight.shape, weight.context)

def update(self, index, weight, grad, state, lr_mult=1.0):
def update(self, index, weight, grad, state):
"""performs w += rescale_grad * grad"""
weight[:] += grad * self.rescale_grad
state[:] = weight
Expand All @@ -192,9 +203,9 @@ def get_updater(optimizer):
The closure of the updater
"""
states = dict()
def updater(index, grad, weight, lr_mult=1.0):
def updater(index, grad, weight):
"""updater for kvstore"""
if index not in states:
states[index] = optimizer.create_state(index, weight)
optimizer.update(index, weight, grad, states[index], lr_mult=1.0)
optimizer.update(index, weight, grad, states[index])
return updater

0 comments on commit 169a686

Please sign in to comment.