Skip to content

Commit

Permalink
rename mod to module
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Apr 12, 2016
1 parent 2253cfa commit 94482b2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
3 changes: 2 additions & 1 deletion python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from . import torch
from . import torch as th

from . import mod
from . import module
from . import module as mod

__version__ = base.__version__
11 changes: 9 additions & 2 deletions python/mxnet/mod.py → python/mxnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def backward(self):
for texec in self.execs:
texec.backward()

def update_metric(self, metric, labels):
for texec, islice in zip(self.execs, self.slices):
labels_slice = [label[islice] for label in labels]
metric.update(labels_slice, texec.outputs)

def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
data_shapes = self._sliced_shape(data_shapes, i)
Expand Down Expand Up @@ -182,9 +186,9 @@ def train(self, train_data, valid_data):
pass


class SymbolModule(BaseModule):
class Module(BaseModule):
def __init__(self, symbol, input_names, context=ctx.cpu(), work_load_list=None):
super(SymbolModule, self).__init__()
super(Module, self).__init__()

if isinstance(context, ctx.Context):
context = [context]
Expand Down Expand Up @@ -347,3 +351,6 @@ def update(self):
updater=self.updater,
num_device=len(self.context),
kvstore=self.kvstore)

def update_metric(self, metric, labels):
self.exec_group.update_metric(metric, labels)

0 comments on commit 94482b2

Please sign in to comment.