an (untested) example of lstm-bucketing using module
pluskid committed Apr 12, 2016
1 parent bd0c9d6 commit 05ec8fb
Showing 2 changed files with 93 additions and 6 deletions.
74 changes: 74 additions & 0 deletions example/module/lstm_bucketing.py
@@ -0,0 +1,74 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
sys.path.insert(0, "../rnn")
import numpy as np
import mxnet as mx

from lstm import lstm_unroll
from bucket_io import BucketSentenceIter, default_build_vocab

def Perplexity(label, pred):
    loss = 0.
    for i in range(pred.shape[0]):
        loss += -np.log(max(1e-10, pred[i][int(label[i])]))
    return np.exp(loss / label.size)

if __name__ == '__main__':
    batch_size = 32
    buckets = [10, 20, 30, 40, 50, 60]
    #buckets = [32]
    num_hidden = 200
    num_embed = 200
    num_lstm_layer = 2

    num_epoch = 25
    learning_rate = 0.01
    momentum = 0.0

    # dummy data is used to test speed without IO
    dummy_data = False

    contexts = [mx.context.gpu(i) for i in range(1)]

    vocab = default_build_vocab("./data/ptb.train.txt")

    def sym_gen(seq_len):
        return lstm_unroll(num_lstm_layer, seq_len, len(vocab),
                           num_hidden=num_hidden, num_embed=num_embed,
                           num_label=len(vocab))

    init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
    init_states = init_c + init_h

    data_train = BucketSentenceIter("./data/ptb.train.txt", vocab,
                                    buckets, batch_size, init_states)
    data_val = BucketSentenceIter("./data/ptb.valid.txt", vocab,
                                  buckets, batch_size, init_states)

    if dummy_data:
        # NOTE: DummyIter is not imported above; it would have to come from the
        # rnn example directory added to sys.path for this branch to run
        data_train = DummyIter(data_train)
        data_val = DummyIter(data_val)

    default_input_names = [x[0] for x in (data_train.provide_data + data_train.provide_label)]
    if len(buckets) == 1:
        mod = mx.mod.Module(sym_gen(buckets[0]), input_names=default_input_names,
                            context=contexts)
    else:
        mod = mx.mod.BucketingModule(sym_gen, default_bucket_key=buckets[0],
                                     default_input_names=default_input_names,
                                     context=contexts)

    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)

    mod.fit(data_train, eval_data=data_val,
            eval_metric=mx.metric.np(Perplexity),
            batch_end_callback=mx.callback.Speedometer(batch_size, 50),
            initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
            optimizer='sgd',
            optimizer_params={'learning_rate': 0.01, 'momentum': 0.9, 'wd': 0.00001})
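
The script above relies on the new BucketingModule to build one unrolled network per bucket (sequence length) through sym_gen and to pick the right one for each batch. Below is a minimal, framework-free sketch of that per-bucket caching idea; BucketCache and make_unrolled_net are illustrative names for this sketch only, not MXNet APIs.

# Hypothetical, self-contained illustration of the bucketing pattern used above.
# `make_unrolled_net` stands in for sym_gen; `BucketCache` mirrors what
# BucketingModule.switch_bucket does with per-bucket Modules.
def make_unrolled_net(seq_len):
    """Stand-in for sym_gen: build something specific to one unrolled length."""
    return "lstm_unrolled_%d_steps" % seq_len

class BucketCache(object):
    def __init__(self, sym_gen, default_bucket_key):
        self.sym_gen = sym_gen
        # the default bucket is built eagerly, the way bind() does below
        self.buckets = {default_bucket_key: sym_gen(default_bucket_key)}

    def switch_bucket(self, bucket_key):
        # other buckets are built lazily on first use and cached afterwards
        if bucket_key not in self.buckets:
            self.buckets[bucket_key] = self.sym_gen(bucket_key)
        return self.buckets[bucket_key]

cache = BucketCache(make_unrolled_net, default_bucket_key=10)
print(cache.switch_bucket(30))  # first length-30 batch builds the 30-step net
print(cache.switch_bucket(30))  # later length-30 batches reuse the cached net

With one cached network per bucket length, a sentence only pays for padding up to its own bucket rather than up to the longest sentence in the corpus.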

25 changes: 19 additions & 6 deletions python/mxnet/module.py
@@ -329,7 +329,8 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',


class Module(BaseModule):
    def __init__(self, symbol, input_names, logger=logging, context=ctx.cpu(), work_load_list=None):
    def __init__(self, symbol, input_names=['data', 'softmax_label'], logger=logging,
                 context=ctx.cpu(), work_load_list=None):
        super(Module, self).__init__(logger=logger)

        if isinstance(context, ctx.Context):
@@ -544,12 +545,14 @@ def sync_params_from_devices(self):


class BucketingModule(BaseModule):
    def __init__(self, sym_gen, default_bucket_key=None, logger=logging,
                 context=ctx.cpu(), work_load_list=None):
    def __init__(self, sym_gen, default_bucket_key=None, default_input_names=None,
                 logger=logging, context=ctx.cpu(), work_load_list=None):
        super(BucketingModule, self).__init__(logger=logger)

        assert default_bucket_key is not None
        assert default_input_names is not None, 'please specify input names for the default bucket'
        self.default_bucket_key = default_bucket_key
        self.default_input_names = default_input_names

        self.sym_gen = sym_gen
        self.context = context
@@ -562,6 +565,16 @@ def _reset_bind(self):
        self.buckets = {}
        self.curr_module = None

    def _gen_symbol(self, key):
        assert self.binded
        symbol = self.sym_gen(key)
        arg_names = symbol.list_arguments()

        # we assume in the bucketing case, all symbols have the same set of parameters,
        # and all the rest of the arguments are considered as input names
        input_names = [x for x in arg_names if x not in self.curr_module.param_names]
        return symbol, input_names

    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False):
        # force rebinding is typically used when one wants to switch from
@@ -574,8 +587,8 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
            return

        self.binded = True
        symbol, input_names = self.sym_gen(self.default_bucket_key)
        module = Module(symbol, input_names, logger=self.logger, context=self.context,
        symbol = self.sym_gen(self.default_bucket_key)
        module = Module(symbol, self.default_input_names, logger=self.logger, context=self.context,
                        work_load_list=self.work_load_list)
        module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
                    force_rebind=False, shared_module=None)
@@ -595,7 +608,7 @@ def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=Non
    def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
        assert self.binded, 'call bind before switching bucket'
        if bucket_key not in self.buckets:
            symbol, input_names = self.sym_gen(bucket_key)
            symbol, input_names = self._gen_symbol(bucket_key)
            module = Module(symbol, input_names, logger=self.logger, context=self.context,
                            work_load_list=self.work_load_list)
            module.bind(data_shapes, label_shapes, self.curr_module.for_training,
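The _gen_symbol helper added above leans on the assumption spelled out in its comment: every bucket's symbol shares the parameter set of the default bucket's module, so any argument that is not a known parameter is treated as an input. A small sketch of that filtering with a made-up argument list (the names below are illustrative, not taken from the real PTB model):

# Hypothetical output of symbol.list_arguments() for one unrolled bucket.
arg_names = ['data', 'l0_init_c', 'l0_init_h', 'embed_weight',
             'l0_i2h_weight', 'l0_i2h_bias', 'softmax_label']
# Parameters already owned by the default bucket's module (curr_module.param_names).
param_names = ['embed_weight', 'l0_i2h_weight', 'l0_i2h_bias']

# Everything that is not a learned parameter counts as an input: the data,
# the label, and the per-layer initial states.
input_names = [name for name in arg_names if name not in param_names]
print(input_names)  # -> ['data', 'l0_init_c', 'l0_init_h', 'softmax_label']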
