Commit 6565aeb: fix more lint errors
pluskid committed Apr 12, 2016
1 parent 5362a0a commit 6565aeb
Showing 1 changed file with 87 additions and 13 deletions.
100 changes: 87 additions & 13 deletions python/mxnet/module.py
@@ -1,7 +1,12 @@
+# pylint: disable=too-many-lines, too-many-arguments
 """A module is like a FeedForward model, but we would like to make it
 easier to compose. So it is more like the Torch modules.
 """

+import logging
+import time
+import numpy as np
+
 from . import context as ctx
 from . import ndarray as nd
 from . import optimizer as opt
@@ -13,10 +18,6 @@
 from .model import BatchEndParam
 from .initializer import Uniform

-import logging
-import time
-import numpy as np
-
 def _as_list(obj):
     """A utility function that treats the argument as a list.
@@ -166,8 +167,8 @@ def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
                 nbatch=nbatch,
                 eval_metric=eval_metric,
                 locals=locals())
-            for cb in _as_list(batch_end_callback):
-                cb(batch_end_params)
+            for callback in _as_list(batch_end_callback):
+                callback(batch_end_params)

     def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
                 always_output_list=False):
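The cb-to-callback rename satisfies pylint's invalid-name check, whose default variable regex rejects names shorter than three characters. For illustration only (not part of this commit), a minimal sketch of a batch_end_callback as consumed by the loop in score() above, using the BatchEndParam fields visible there (epoch, nbatch, eval_metric):

    def log_progress(param):
        # param is a BatchEndParam; eval_metric carries the running scores
        if param.nbatch % 100 == 0:
            print('epoch %d, batch %d: %s' % (
                param.epoch, param.nbatch, param.eval_metric.get_name_value()))

    # Thanks to _as_list, a single callable or a list of them works
    # (module, val_iter, and metric below are hypothetical):
    # module.score(val_iter, metric, batch_end_callback=[log_progress])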
@@ -210,18 +211,19 @@ def predict(self, eval_data, num_batch=None, merge_batches=True, reset=True,
                 break
             self.forward(eval_batch, is_train=False)
             pad = eval_batch.pad
+            outputs = [out[0:out.shape[0]-pad] for out in self.get_outputs()]
             output_list.append(self.get_outputs())

         if len(output_list) == 0:
             return output_list

         if merge_batches:
             num_outputs = len(output_list[0])
-            for o in output_list:
-                assert len(o) == num_outputs, \
+            for out in output_list:
+                assert len(out) == num_outputs, \
                     'Cannot merge batches, as num of outputs is not the same ' + \
                     'in mini-batches. Maybe bucketing is used?'
-            output_list2 = [np.concatenate([o[i] for o in output_list])
+            output_list2 = [np.concatenate([out[i] for out in output_list])
                             for i in range(num_outputs)]

             if num_outputs == 1:
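The added outputs line strips padding: an MXNet DataIter pads its final mini-batch up to the full batch size and records the count in batch.pad, so those rows must be dropped before predictions are merged. A worked sketch of just that slicing, with illustrative numbers:

    import numpy as np

    batch_size, num_samples = 32, 100
    pad = batch_size - (num_samples % batch_size)  # final batch has 28 padded rows
    out = np.zeros((batch_size, 10))               # raw output of the final batch
    valid = out[0:out.shape[0] - pad]              # keep only the 4 real rows
    assert valid.shape == (4, 10)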
@@ -309,8 +311,8 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
                 batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                  eval_metric=eval_metric,
                                                  locals=locals())
-                for cb in _as_list(batch_end_callback):
-                    cb(batch_end_params)
+                for callback in _as_list(batch_end_callback):
+                    callback(batch_end_params)

             # one epoch of training is finished
             for name, val in eval_metric.get_name_value():
@@ -323,8 +325,8 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
             self.sync_params_from_devices()

             if epoch_end_callback is not None:
-                for cb in _as_list(epoch_end_callback):
-                    cb(epoch, self.symbol, self.arg_params, self.aux_params)
+                for callback in _as_list(epoch_end_callback):
+                    callback(epoch, self.symbol, self.arg_params, self.aux_params)

             #----------------------------------------
             # evaluation on validation set
@@ -398,6 +400,73 @@ def update_metric(self, eval_metric, labels):
"""
raise NotImplementedError()

def bind(self, data_shapes, label_shapes=None, for_training=True,
inputs_need_grad=False, force_rebind=False, shared_module=None):
"""Bind the symbols to construct executors. This is necessary before one
can perform computation with the module.
Parameters
----------
data_shapes : list of (str, tuple)
Typically is `data_iter.provide_data`.
label_shapes : list of (str, tuple)
Typically is `data_iter.provide_label`.
for_training : bool
Default is `True`. Whether the executors should be bind for training.
inputs_need_grad : bool
Default is `False`. Whether the gradients to the input data need to be computed.
Typically this is not needed. But this might be needed when implementing composition
of modules.
force_rebind : bool
Default is `False`. This function does nothing if the executors are already
binded. But with this `True`, the executors will be forced to rebind.
shared_module : Module
Default is `None`. This is used in bucketing. When not `None`, the shared module
essentially corresponds to a different bucket -- a module with different symbol
but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
"""
raise NotImplementedError()

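A hedged usage sketch of the bind() contract documented above (not part of the commit; net and train_iter are hypothetical stand-ins for a Symbol and a DataIter):

    mod = Module(symbol=net)
    mod.bind(data_shapes=train_iter.provide_data,
             label_shapes=train_iter.provide_label,
             for_training=True)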
+    def init_params(self, initializer=Uniform(0.01), arg_params=None, aux_params=None,
+                    allow_missing=False, force_init=False):
+        """Initialize the parameters and auxiliary states.
+        Parameters
+        ----------
+        initializer : Initializer
+            Called to initialize parameters if needed.
+        arg_params : dict
+            If not None, should be a dictionary of existing arg_params. Initialization
+            will be copied from that.
+        aux_params : dict
+            If not None, should be a dictionary of existing aux_params. Initialization
+            will be copied from that.
+        allow_missing : bool
+            If true, params could contain missing values, and the initializer will be
+            called to fill those missing params.
+        force_init : bool
+            If true, will force re-initialize even if already initialized.
+        """
+        raise NotImplementedError()
+
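For illustration only, the two initialization paths this docstring describes, assuming mod is a bound Module and old_arg_params/old_aux_params are hypothetical dicts saved from an earlier run:

    from mxnet.initializer import Uniform

    mod.init_params(initializer=Uniform(0.01))   # fresh random initialization
    mod.init_params(arg_params=old_arg_params,   # or copy existing values and let
                    aux_params=old_aux_params,   # the initializer fill anything
                    allow_missing=True)          # missing from the dicts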
+    def init_optimizer(self, kvstore='local', optimizer='sgd', optimizer_params={},
+                       force_init=False):
+        """Install and initialize optimizers.
+        Parameters
+        ----------
+        kvstore : str or KVStore
+            Default `'local'`.
+        optimizer : str or Optimizer
+            Default `'sgd'`.
+        optimizer_params : dict
+            Default `{}`.
+        force_init : bool
+            Default `False`, indicating whether to force re-initializing the
+            optimizer in the case an optimizer is already installed.
+        """
+        raise NotImplementedError()
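A minimal sketch of the call as documented above; 'learning_rate' is a standard MXNet SGD option, though the exact keys accepted are an assumption here:

    mod.init_optimizer(kvstore='local', optimizer='sgd',
                       optimizer_params={'learning_rate': 0.1})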


 class Module(BaseModule):
     """Module is a basic module that wraps a `Symbol`. It is functionally the same
@@ -437,6 +506,11 @@ def __init__(self, symbol, input_names=['data', 'softmax_label'], logger=logging
         self.arg_params = None
         self.aux_params = None

+        self.optimizer = None
+        self.kvstore = None
+        self.update_on_kvstore = None
+        self.updater = None
+
         self._reset_bind()

     def _reset_bind(self):
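Declaring these four attributes in __init__ is presumably aimed at pylint's attribute-defined-outside-init warning (W0201), which fires when an attribute is first assigned outside __init__. A minimal illustration of the pattern:

    class Example(object):
        def __init__(self):
            self.optimizer = None    # declared up front: no warning

        def init_optimizer(self):
            self.optimizer = 'sgd'   # fine: attribute already exists
            self.updater = 'sgd'     # W0201: first assigned outside __init__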