Skip to content

Commit

Permalink
fix update of CenterLossMetric
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyetx committed Dec 14, 2016
1 parent 93b629b commit 7e9dfc6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions center_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def __init__(self):
super(CenterLossMetric, self).__init__('center_loss')

def update(self, labels, preds):
self.sum_metric = + preds[1].asnumpy()[0]
self.num_inst = 1
self.sum_metric += preds[1].asnumpy()[0]
self.num_inst += 1


# see details:
Expand All @@ -41,7 +41,7 @@ class CenterLoss(mx.operator.CustomOp):
def __init__(self, ctx, shapes, dtypes, num_class, alpha, scale=1.0):
if not len(shapes[0]) == 2:
raise ValueError('dim for input_data shoudl be 2 for CenterLoss')

self.alpha = alpha
self.batch_size = shapes[0][0]
self.num_class = num_class
Expand All @@ -51,7 +51,7 @@ def forward(self, is_train, req, in_data, out_data, aux):
labels = in_data[1].asnumpy()
diff = aux[0]
center = aux[1]

# store x_i - c_yi
for i in range(self.batch_size):
diff[i] = in_data[0][i] - center[int(labels[i])]
Expand Down

0 comments on commit 7e9dfc6

Please sign in to comment.