Commit 93fb66c: first commit
yuanyang committed Oct 10, 2016 (0 parents)
Showing 7 changed files with 519 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
result
.DS_Store
data
*.pyc
log.txt
2 changes: 2 additions & 0 deletions README.md
@@ -0,0 +1,2 @@
# mxnet_center_loss
An MXNet implementation of the center loss, demonstrated on MNIST.
123 changes: 123 additions & 0 deletions center_loss.py
@@ -0,0 +1,123 @@
import os

# MXNET_CPU_WORKER_NTHREADS must be greater than 1 for custom op to work on CPU
os.environ['MXNET_CPU_WORKER_NTHREADS'] = '2'
import mxnet as mx

# accuracy metric computed on the softmax output
class accuracy(mx.metric.EvalMetric):
def __init__(self, num=None):
super(accuracy, self).__init__('accuracy', num)

def update(self, labels, preds):
mx.metric.check_label_shapes(labels, preds)

        if self.num is not None:
assert len(labels) == self.num

pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
label = labels[0].asnumpy().astype('int32')

mx.metric.check_label_shapes(label, pred_label)

self.sum_metric += (pred_label.flat == label.flat).sum()
self.num_inst += len(pred_label.flat)


# metric that reports the center loss of the most recent batch
class center_loss_metric(mx.metric.EvalMetric):
def __init__(self):
super(center_loss_metric, self).__init__('center_loss')

    def update(self, labels, preds):
        # preds[1] is the scalar center loss coming out of MakeLoss
        self.sum_metric = preds[1].asnumpy()[0]
        self.num_inst = 1


# For details, see Wen et al., "A Discriminative Feature Learning Approach
# for Deep Face Recognition", ECCV 2016.
class CenterLoss(mx.operator.CustomOp):
def __init__(self, ctx, shapes, dtypes, num_class, alpha, scale=1.0):
        if len(shapes[0]) != 2:
            raise ValueError('dim for input_data should be 2 for CenterLoss')

self.alpha = alpha
self.batch_size = shapes[0][0]
self.num_class = num_class
self.scale = scale

def forward(self, is_train, req, in_data, out_data, aux):
        # pull the labels to numpy; an NDArray cannot be indexed with (i, j)
labels = in_data[1].asnumpy()
diff = aux[0]
center = aux[1]

# store x_i - c_yi
for i in range(self.batch_size):
diff[i] = in_data[0][i] - center[int(labels[i])]

loss = mx.nd.sum(mx.nd.square(diff)) / self.batch_size / 2
self.assign(out_data[0], req[0], loss)

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
diff = aux[0]
center = aux[1]
sum_ = aux[2]

        # the gradient w.r.t. x_i is just scale * (x_i - c_yi)
grad_scale = float(self.scale)
self.assign(in_grad[0], req[0], diff * grad_scale)

# update the center
labels = in_data[1].asnumpy()
label_occur = dict()
for i, label in enumerate(labels):
label_occur.setdefault(int(label), []).append(i)

for label, sample_index in label_occur.items():
sum_[:] = 0
for i in sample_index:
sum_ = sum_ + diff[i]
            # damped per-class mean of the diffs (the paper's center update rule)
            delta_c = sum_ / (1 + len(sample_index))
center[label] += self.alpha * delta_c


@mx.operator.register("centerloss")
class CenterLossProp(mx.operator.CustomOpProp):
def __init__(self, num_class, alpha, scale=1.0, batchsize=64):
super(CenterLossProp, self).__init__(need_top_grad=False)

# convert it to numbers
self.num_class = int(num_class)
self.alpha = float(alpha)
self.scale = float(scale)
self.batchsize = int(batchsize)

def list_arguments(self):
return ['data', 'label']

def list_outputs(self):
return ['output']

def list_auxiliary_states(self):
        # name them with a 'bias' suffix so the default initializer zero-fills them
return ['diff_bias', 'center_bias', 'sum_bias']

def infer_shape(self, in_shape):
data_shape = in_shape[0]
label_shape = (in_shape[0][0],)

        # stores x_i - c_yi; same shape as an input batch
        diff_shape = [self.batchsize, data_shape[1]]

        # stores the center of each class; shape (num_class, d)
        center_shape = [self.num_class, diff_shape[1]]

        # scratch buffer for the per-class sums
        sum_shape = [diff_shape[1],]

output_shape = [1, ]
return [data_shape, label_shape], [output_shape], [diff_shape, center_shape, sum_shape]

def create_operator(self, ctx, shapes, dtypes):
return CenterLoss(ctx, shapes, dtypes, self.num_class, self.alpha, self.scale)
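
For reference, a sketch of the math the operator implements, as read off the code above (notation follows the paper; m is the batch size, alpha the center learning rate, and scale the loss weight):

    L_C = \frac{1}{2m} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2

    \frac{\partial L_C}{\partial x_i} = \mathrm{scale} \cdot (x_i - c_{y_i})

    \Delta c_j = \frac{\sum_{i:\, y_i = j} (x_i - c_j)}{1 + |\{\, i : y_i = j \,\}|},
    \qquad c_j \leftarrow c_j + \alpha \, \Delta c_j

Note that the 1/m normalization appears only in forward() and the scale factor only in backward(), mirroring how the code splits the two.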
64 changes: 64 additions & 0 deletions data.py
@@ -0,0 +1,64 @@
# pylint: skip-file
""" data iterator for mnist """
import sys
import os
# get_data (from the MXNet source tree) downloads the dataset automatically;
# note this sys.path entry is hard-coded to the author's local checkout
sys.path.append(os.path.join("/Users/yuanyang/Build/mxnet/tests/python/common"))
import get_data
import mxnet as mx


class custom_mnist_iter(mx.io.DataIter):
def __init__(self, mnist_iter):
super(custom_mnist_iter,self).__init__()
self.data_iter = mnist_iter
self.batch_size = self.data_iter.batch_size

@property
def provide_data(self):
return self.data_iter.provide_data

@property
def provide_label(self):
provide_label = self.data_iter.provide_label[0]
        # in a real application the two heads may take different labels;
        # here both losses share the MNIST label
return [('softmax_label', provide_label[1]), \
('center_label', provide_label[1])]

def hard_reset(self):
self.data_iter.hard_reset()

def reset(self):
self.data_iter.reset()

def next(self):
batch = self.data_iter.next()
label = batch.label[0]

return mx.io.DataBatch(data=batch.data, label=[label,label], \
pad=batch.pad, index=batch.index)



def mnist_iterator(batch_size, input_shape):
"""return train and val iterators for mnist"""
# download data
get_data.GetMNIST_ubyte()
    flat = len(input_shape) != 3

train_dataiter = mx.io.MNISTIter(
image="data/train-images-idx3-ubyte",
label="data/train-labels-idx1-ubyte",
input_shape=input_shape,
batch_size=batch_size,
shuffle=True,
flat=flat)

val_dataiter = mx.io.MNISTIter(
image="data/t10k-images-idx3-ubyte",
label="data/t10k-labels-idx1-ubyte",
input_shape=input_shape,
batch_size=batch_size,
flat=flat)

return (custom_mnist_iter(train_dataiter), custom_mnist_iter(val_dataiter))
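
A minimal smoke test for the custom iterator (a sketch, not part of the commit; it assumes the get_data helper is importable, since the sys.path entry above is hard-coded to the author's machine, and that the MNIST ubyte files can be downloaded into ./data):

from data import mnist_iterator

train, val = mnist_iterator(batch_size=100, input_shape=(1, 28, 28))
print train.provide_data   # [('data', (100, 1, 28, 28))]
print train.provide_label  # [('softmax_label', (100,)), ('center_label', (100,))]
batch = train.next()
print len(batch.label)     # 2: the same labels, fed once to each loss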
97 changes: 97 additions & 0 deletions train.py
@@ -0,0 +1,97 @@
import mxnet as mx
import numpy as np
from center_loss import *
from data import mnist_iterator
import logging
import train_model
import argparse

parser = argparse.ArgumentParser(description='train MNIST with softmax and center loss')
parser.add_argument('--gpus', type=str, default='2,3',
help='the gpus will be used, e.g "0,1,2,3"')
parser.add_argument('--batch-size', type=int, default=100,
help='the batch size')
parser.add_argument('--num-examples', type=int, default=60000,
help='the number of training examples')
parser.add_argument('--lr', type=float, default=.01,
help='the initial learning rate')
parser.add_argument('--lr-factor', type=float, default=0.5,
                    help='multiply the lr by this factor every lr-factor-epoch epochs')
parser.add_argument('--lr-factor-epoch', type=float, default=20,
                    help='the number of epochs between lr reductions; can be fractional, e.g. .5')
parser.add_argument('--model-prefix', type=str,
help='the prefix of the model to load')
parser.add_argument('--save-model-prefix', type=str,default='center_loss',
help='the prefix of the model to save')
parser.add_argument('--num-epochs', type=int, default=20,
help='the number of training epochs')
parser.add_argument('--load-epoch', type=int,
help="load the model on an epoch using the model-prefix")
parser.add_argument('--kv-store', type=str, default='local',
help='the kvstore type')
parser.add_argument('--log_file', type=str, default='log.txt',
help='log file')
parser.add_argument('--log_dir', type=str, default='.',
help='log dir')
args = parser.parse_args()

# mnist input shape
data_shape = (1,28,28)

def get_symbol(batchsize=64):
"""
LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
Haffner. "Gradient-based learning applied to document recognition."
Proceedings of the IEEE (1998)
"""
data = mx.symbol.Variable('data')

softmax_label = mx.symbol.Variable('softmax_label')
center_label = mx.symbol.Variable('center_label')

# first conv
conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",kernel=(2,2), stride=(2,2))

# second conv
conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",kernel=(2,2), stride=(2,2))

# first fullc
flatten = mx.symbol.Flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")

embedding = mx.symbol.FullyConnected(data=tanh3, num_hidden=2, name='embedding')

# second fullc
fc2 = mx.symbol.FullyConnected(data=embedding, num_hidden=10, name='fc2')


ce_loss = mx.symbol.SoftmaxOutput(data=fc2, label=softmax_label, name='softmax')
    # apply the center loss to the embedding features, as in the paper
    # (fc2 holds the class logits and feeds only the softmax)
    center_loss_ = mx.symbol.Custom(data=embedding, label=center_label, name='center_loss_', op_type='centerloss',\
                                    num_class=10, alpha=0.5, scale=0.01, batchsize=batchsize)
center_loss = mx.symbol.MakeLoss(name='center_loss', data=center_loss_)
mlp = mx.symbol.Group([ce_loss, center_loss])

return mlp

def main():
    batchsize = args.batch_size if args.gpus == '' else \
        args.batch_size / len(args.gpus.split(','))
print 'batchsize is ', batchsize

# define network structure
net = get_symbol(batchsize)

# load data
train, val = mnist_iterator(batch_size=args.batch_size, input_shape=data_shape)

# train
print 'training model ...'
train_model.fit(args, net, (train, val), data_shape)

if __name__ == "__main__":
main()
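
A quick shape sanity check for the grouped symbol (a sketch, not part of the commit; it assumes train_model.py from this commit is importable, since train.py imports it at module level, and that the module-level argparse sees no unknown flags):

import mxnet as mx
from train import get_symbol

net = get_symbol(batchsize=100)
print net.list_outputs()   # ['softmax_output', 'center_loss_output']
arg_shapes, out_shapes, aux_shapes = net.infer_shape(
    data=(100, 1, 28, 28), softmax_label=(100,), center_label=(100,))
print out_shapes           # [(100, 10), (1,)]: class posteriors and the scalar loss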
(Diffs for the two remaining files did not load.)