Skip to content


a basic SymbolModule
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Apr 12, 2016
1 parent 34c9e3f commit 64e7a54
Showing 1 changed file with 286 additions and 0 deletions.
286 changes: 286 additions & 0 deletions python/mxnet/
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
# A module is like a FeedForward model, but designed to be easier to
# compose, so it is closer in spirit to Torch modules.

from . import context as ctx
from . import symbol as sym
from . import ndarray as nd

from .executor_manager import _split_input_slice, _load_data, _load_label
from .base import mx_real_t
from .initializer import Uniform

from collections import namedtuple
import logging

class DataParallelExecutorGroup(object):
    """A group of executors, one per device, that jointly process one batch.

    The batch is split along its first axis into one slice per device
    (see `decide_slices`), and one executor is bound per device for the
    shape of its slice (see `bind_exec`).

    Parameters
    ----------
    symbol : Symbol
        The symbol to bind.
    context : list of Context
        The devices to bind on; one executor is created per entry.
    workload : list of number
        Per-device workload, forwarded to `_split_input_slice`.
    data_shapes : list of (str, tuple)
        `(name, shape)` pairs for the data inputs. The first dimension of
        every shape is the batch size.
    label_shapes : list of (str, tuple) or None
        `(name, shape)` pairs for the label inputs, or None if there are none.
    param_names : list of str
        Names of the arguments that are model parameters.
    for_training : bool
        Whether gradients for the parameters should be allocated.
    inputs_need_grad : bool
        Whether gradients with respect to the data inputs are needed.
    shared_group : DataParallelExecutorGroup or None
        If given, parameter, gradient, auxiliary and input arrays are
        borrowed from the executors of this group instead of allocated.
    input_types : dict or None
        Mapping from input name to dtype. If None, every input is assumed
        to be `mx_real_t`.
    """

    def __init__(self, symbol, context, workload, data_shapes, label_shapes, param_names,
                 for_training, inputs_need_grad, shared_group=None, input_types=None):
        self.param_names = param_names
        self.arg_names = symbol.list_arguments()
        self.aux_names = symbol.list_auxiliary_states()

        self.symbol = symbol
        self.context = context
        self.workload = workload

        self.for_training = for_training
        self.inputs_need_grad = inputs_need_grad

        self.input_types = input_types

        # One dict per device mapping input name -> NDArray. Reuse the ones
        # of the shared group when there is one, otherwise start empty.
        if shared_group is not None:
            self.shared_data_arrays = shared_group.shared_data_arrays
        else:
            self.shared_data_arrays = [{} for _ in context]

        self.decide_slices(data_shapes, label_shapes)
        self.bind_exec(data_shapes, label_shapes, shared_group)

    def decide_slices(self, data_shapes, label_shapes):
        """Record the batch size and compute the per-device batch slices.

        `label_shapes` is accepted for symmetry with `bind_exec` but is not
        inspected here.
        """
        assert len(data_shapes) > 0
        self.batch_size = data_shapes[0][1][0]
        for _, shape in data_shapes:
            assert shape[0] == self.batch_size, "all the data must have the same batch size"

        self.slices = _split_input_slice(self.batch_size, self.workload)

    def bind_exec(self, data_shapes, label_shapes, shared_group):
        """Bind one executor per device and build the array lookup tables."""
        self.execs = [self._bind_ith_exec(i, data_shapes, label_shapes, shared_group)
                      for i in range(len(self.context))]

        # convenient data structures: for every input, a list with one
        # (batch slice, device array) pair per executor
        self.data_arrays = [[(self.slices[i], e.arg_dict[name])
                             for i, e in enumerate(self.execs)]
                            for name, _ in data_shapes]
        if label_shapes is not None:
            self.label_arrays = [[(self.slices[i], e.arg_dict[name])
                                  for i, e in enumerate(self.execs)]
                                 for name, _ in label_shapes]
        else:
            self.label_arrays = None

        # for every parameter, the list of its per-device arrays / gradients
        self.param_arrays = [[e.arg_arrays[i] for e in self.execs]
                             for i, name in enumerate(self.arg_names)
                             if name in self.param_names]
        self.grad_arrays = [[e.grad_arrays[i] for e in self.execs]
                            for i, name in enumerate(self.arg_names)
                            if name in self.param_names]

        self.aux_arrays = [[e.aux_arrays[i] for e in self.execs]
                           for i in range(len(self.aux_names))]

    def set_params(self, arg_params, aux_params):
        """Copy the given parameters and auxiliary states into every executor."""
        for texec in self.execs:
            texec.copy_params_from(arg_params, aux_params)

    def forward(self, data_batch):
        """Load `data_batch` onto the devices and run forward on every executor.

        Labels are only loaded when the group was bound with
        `for_training=True`.
        """
        _load_data(data_batch, self.data_arrays)
        if self.for_training:
            _load_label(data_batch, self.label_arrays)
        for texec in self.execs:
            texec.forward(is_train=self.for_training)

    def backward(self):
        """Run backward on every executor; requires `for_training=True`."""
        assert self.for_training, 're-bind with for_training=True to run backward'
        for texec in self.execs:
            texec.backward()

    def _grad_requirements(self, data_names):
        """Return a dict mapping every argument name to its gradient request.

        Parameters always get 'write' when training; data inputs get 'write'
        only if `inputs_need_grad` is set; everything else (and every
        argument when not training) gets 'null'.
        """
        grad_req = {}
        for name in self.arg_names:
            if not self.for_training:
                grad_req[name] = 'null'
            elif name in self.param_names:
                grad_req[name] = 'write'
            elif name in data_names:
                grad_req[name] = 'write' if self.inputs_need_grad else 'null'
            else:
                grad_req[name] = 'null'
        return grad_req

    def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
        """Bind and return the executor for the `i`-th device.

        Arrays are allocated on the device unless a `shared_group` is given,
        in which case parameter, gradient and auxiliary arrays are borrowed
        from the `i`-th executor of that group. Input (data/label) arrays are
        cached per device in `self.shared_data_arrays` and reused by name.
        """
        data_shapes = self._sliced_shape(data_shapes, i)
        if label_shapes is not None:
            label_shapes = self._sliced_shape(label_shapes, i)
        shared_exec = None if shared_group is None else shared_group.execs[i]
        context = self.context[i]
        shared_data_arrays = self.shared_data_arrays[i]

        input_shapes = dict(data_shapes)
        if label_shapes is not None:
            input_shapes.update(dict(label_shapes))

        arg_shapes, _, aux_shapes = self.symbol.infer_shape(**input_shapes)
        assert arg_shapes is not None, "shape inference failed"

        if self.input_types is None:
            input_types = {k: mx_real_t for k in input_shapes}
        else:
            input_types = self.input_types
        arg_types, _, aux_types = self.symbol.infer_type(**input_types)
        assert arg_types is not None, "type inference failed"

        data_names = [x[0] for x in data_shapes]

        arg_arrays = []
        grad_arrays = {} if self.for_training else None
        grad_req = self._grad_requirements(data_names)

        # create or borrow arguments and gradients
        for j, name in enumerate(self.arg_names):
            if name in self.param_names:  # model parameter
                if shared_exec is None:
                    arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
                    if grad_req[name] != 'null':
                        grad_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
                        grad_arrays[name] = grad_arr
                else:
                    arg_arr = shared_exec.arg_dict[name]
                    assert arg_arr.shape == arg_shapes[j]
                    assert arg_arr.dtype == arg_types[j]
                    if grad_req[name] != 'null':
                        grad_arrays[name] = shared_exec.grad_dict[name]
            else:  # data or label
                if name in shared_data_arrays:
                    arg_arr = shared_data_arrays[name]
                    assert arg_arr.shape == arg_shapes[j]
                    assert arg_arr.dtype == arg_types[j]
                else:
                    arg_arr = nd.zeros(arg_shapes[j], context, dtype=arg_types[j])
                    shared_data_arrays[name] = arg_arr

            arg_arrays.append(arg_arr)

        # create or borrow aux variables
        if shared_exec is None:
            # allocate on this executor's device (`context`), not the `ctx` module
            aux_arrays = [nd.zeros(s, context, dtype=t) for s, t in zip(aux_shapes, aux_types)]
        else:
            for j, a in enumerate(shared_exec.aux_arrays):
                assert aux_shapes[j] == a.shape
                assert aux_types[j] == a.dtype
            aux_arrays = shared_exec.aux_arrays[:]

        executor = self.symbol.bind(ctx=context, args=arg_arrays,
                                    args_grad=grad_arrays, aux_states=aux_arrays,
                                    grad_req=grad_req, shared_exec=shared_exec)
        return executor

    def _sliced_shape(self, shapes, i):
        """Return `shapes` with the batch dimension replaced by the size of slice `i`."""
        slice_len = self.slices[i].stop - self.slices[i].start
        return [(k, tuple([slice_len] + list(v[1:]))) for k, v in shapes]

class BaseModule(object):
    """Common base class for modules; it currently defines no behavior."""

class SymbolModule(BaseModule):
    """A module wrapping a single symbol, executed data-parallel on devices.

    Parameters
    ----------
    symbol : Symbol
        The symbol to wrap.
    input_names : list of str
        Names of the arguments of `symbol` that are inputs. Every other
        argument is treated as a model parameter.
    context : Context or list of Context
        Device(s) to run on. A single context is wrapped into a list.
    work_load_list : list of number or None
        Per-device workload, one entry per context. Defaults to equal
        workload (1 for every device).
    """

    def __init__(self, symbol, input_names, context=ctx.cpu(), work_load_list=None):
        if isinstance(context, ctx.Context):
            context = [context]
        self.context = context
        if work_load_list is None:
            work_load_list = [1] * len(self.context)
        assert len(work_load_list) == len(self.context)
        self.work_load_list = work_load_list

        self.symbol = symbol

        arg_names = symbol.list_arguments()
        self.param_names = [x for x in arg_names if x not in input_names]
        self.aux_names = symbol.list_auxiliary_states()

        self.param_initialized = False
        # make sure `binded` / `exec_group` exist before bind() reads them
        self._reset_bind()

    def _reset_bind(self):
        """Mark the module as not bound and drop the executor group."""
        self.binded = False
        self.exec_group = None

    # === bind the module ===
    # Binding a module allocates the memory required to carry out the
    # computation on the specified devices (contexts).
    def bind(self, data_shapes, label_shapes=None, for_training=True,
             inputs_need_grad=False, force_rebind=False, shared_group=None):
        """Bind the symbol to the module's devices.

        Parameters
        ----------
        data_shapes : list of (str, tuple)
            `(name, shape)` pairs for the data inputs.
        label_shapes : list of (str, tuple) or None
            `(name, shape)` pairs for the labels. Required when
            `for_training` is true.
        for_training : bool
            Whether the module is bound for training.
        inputs_need_grad : bool
            Whether gradients with respect to the inputs are needed. Only
            allowed when `for_training` is true.
        force_rebind : bool
            If true, discard any existing binding and bind again. Typically
            used to switch from the training to the prediction phase.
        shared_group : DataParallelExecutorGroup or None
            Forwarded to `DataParallelExecutorGroup`.
        """
        if force_rebind:
            self._reset_bind()

        if self.binded:
            logging.warning('Already binded, ignoring bind()')
            return

        if not for_training:
            assert not inputs_need_grad
        else:
            assert label_shapes is not None

        self.for_training = for_training

        self.exec_group = DataParallelExecutorGroup(self.symbol, self.context, self.work_load_list,
                                                    data_shapes, label_shapes, self.param_names,
                                                    for_training, inputs_need_grad, shared_group)
        # only flag the module as bound once the executors actually exist
        self.binded = True

        if self.param_initialized:
            # if the parameters are already initialized, we are re-binding
            # so automatically copy the already initialized params
            self.exec_group.set_params(self.arg_params, self.aux_params)

    def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
                    allow_missing=False, force_init=False):
        """Initialize the parameters and auxiliary states.

        Parameters
        ----------
        initializer : Initializer
            Called to initialize parameters if needed.
        arg_params : dict
            If not None, should be a dictionary of existing arg_params. Initialization
            will be copied from that.
        aux_params : dict
            If not None, should be a dictionary of existing aux_params. Initialization
            will be copied from that.
        allow_missing : bool
            If true, params could contain missing values, and the initializer will be
            called to fill those missing params.
        force_init : bool
            If true, will force re-initialize even if already initialized.
        """
        if self.param_initialized and not force_init:
            return
        assert self.binded, 'call bind before initializing the parameters'

        # allocate host-side copies shaped like the first device's arrays
        param_arrays = [nd.zeros(x[0].shape) for x in self.exec_group.param_arrays]
        self.arg_params = {name: arr for name, arr in zip(self.param_names, param_arrays)}

        aux_arrays = [nd.zeros(x[0].shape) for x in self.exec_group.aux_arrays]
        self.aux_params = {name: arr for name, arr in zip(self.aux_names, aux_arrays)}

        def _init_one(name, arr, cache):
            """Fill `arr` from `cache[name]` if available, else via `initializer`."""
            if cache is not None:
                if name in cache:
                    cache[name].copyto(arr)
                else:
                    assert allow_missing
                    initializer(name, arr)
            else:
                initializer(name, arr)

        for name, arr in self.arg_params.items():
            _init_one(name, arr, arg_params)

        for name, arr in self.aux_params.items():
            _init_one(name, arr, aux_params)

        self.param_initialized = True

        # copy the initialized parameters to devices
        self.exec_group.set_params(self.arg_params, self.aux_params)

0 comments on commit 64e7a54

Please sign in to comment.