Skip to content

Commit

Permalink
move batchsize inside
Browse files Browse the repository at this point in the history
yuanyang committed Nov 22, 2016

Verified

This commit was signed with the committer’s verified signature.
tomlokhorst Tom Lokhorst
1 parent 5f3f565 commit 551bf25
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion center_loss.py
Original file line number Diff line number Diff line change
@@ -65,7 +65,7 @@ def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
sum_ = aux[2]

# back grad is just scale * ( x_i - c_yi)
grad_scale = float(self.scale)
grad_scale = float(self.scale/self.batch_size)
self.assign(in_grad[0], req[0], diff * grad_scale)

# update the center
2 changes: 1 addition & 1 deletion data.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
import os

# code to automatically download dataset
mxnet_root = ''
mxnet_root = '/home/slu/build/mxnet'
sys.path.append(os.path.join( mxnet_root, 'tests/python/common'))
import get_data
import mxnet as mx
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
@@ -69,10 +69,10 @@ def get_symbol(batchsize=64):
# second fullc
fc2 = mx.symbol.FullyConnected(data=embedding, num_hidden=10, name='fc2')


ce_loss = mx.symbol.SoftmaxOutput(data=fc2, label=softmax_label, name='softmax')

center_loss_ = mx.symbol.Custom(data=embedding, label=center_label, name='center_loss_', op_type='centerloss',\
num_class=10, alpha=0.5, scale=0.01, batchsize=batchsize)
num_class=10, alpha=0.5, scale=1.0, batchsize=batchsize)
center_loss = mx.symbol.MakeLoss(name='center_loss', data=center_loss_)
mlp = mx.symbol.Group([ce_loss, center_loss])

0 comments on commit 551bf25

Please sign in to comment.