forked from YYuanAnyVision/mxnet_center_loss
Commit 93fb66c: yuanyang committed on Oct 10, 2016 (0 parents).
Showing 7 changed files with 519 additions and 0 deletions.
.gitignore
@@ -0,0 +1,5 @@
result
.DS_Store
data
*.pyc
log.txt
README.md
@@ -0,0 +1,2 @@
# center_loss_mxnet
# mxnet_center_loss
center_loss.py
@@ -0,0 +1,123 @@
import os

# MXNET_CPU_WORKER_NTHREADS must be greater than 1 for custom ops to work on CPU
os.environ['MXNET_CPU_WORKER_NTHREADS'] = '2'
import mxnet as mx


# accuracy metric for the softmax head
class accuracy(mx.metric.EvalMetric):
    def __init__(self, num=None):
        super(accuracy, self).__init__('accuracy', num)

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)

        if self.num is not None:
            assert len(labels) == self.num

        pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
        label = labels[0].asnumpy().astype('int32')

        mx.metric.check_label_shapes(label, pred_label)

        self.sum_metric += (pred_label.flat == label.flat).sum()
        self.num_inst += len(pred_label.flat)


# metric reporting the value of the center loss
class center_loss_metric(mx.metric.EvalMetric):
    def __init__(self):
        super(center_loss_metric, self).__init__('center_loss')

    def update(self, labels, preds):
        # preds[1] is the scalar center loss output of the custom op
        self.sum_metric += preds[1].asnumpy()[0]
        self.num_inst += 1


# for details see:
# Wen et al., "A Discriminative Feature Learning Approach for Deep Face Recognition"
class CenterLoss(mx.operator.CustomOp):
    def __init__(self, ctx, shapes, dtypes, num_class, alpha, scale=1.0):
        if not len(shapes[0]) == 2:
            raise ValueError('dim for input_data should be 2 for CenterLoss')

        self.alpha = alpha
        self.batch_size = shapes[0][0]
        self.num_class = num_class
        self.scale = scale

    def forward(self, is_train, req, in_data, out_data, aux):
        # an NDArray cannot be indexed with (i, j), so read the labels via numpy
        labels = in_data[1].asnumpy()
        diff = aux[0]
        center = aux[1]

        # store x_i - c_{y_i}
        for i in range(self.batch_size):
            diff[i] = in_data[0][i] - center[int(labels[i])]

        loss = mx.nd.sum(mx.nd.square(diff)) / self.batch_size / 2
        self.assign(out_data[0], req[0], loss)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        diff = aux[0]
        center = aux[1]
        sum_ = aux[2]

        # the gradient w.r.t. x_i is just scale * (x_i - c_{y_i})
        grad_scale = float(self.scale)
        self.assign(in_grad[0], req[0], diff * grad_scale)

        # update the centers: group the batch indices by label
        labels = in_data[1].asnumpy()
        label_occur = dict()
        for i, label in enumerate(labels):
            label_occur.setdefault(int(label), []).append(i)

        for label, sample_index in label_occur.items():
            sum_[:] = 0
            for i in sample_index:
                sum_ = sum_ + diff[i]
            delta_c = sum_ / (1 + len(sample_index))
            center[label] += self.alpha * delta_c


@mx.operator.register("centerloss")
class CenterLossProp(mx.operator.CustomOpProp):
    def __init__(self, num_class, alpha, scale=1.0, batchsize=64):
        super(CenterLossProp, self).__init__(need_top_grad=False)

        # keyword arguments arrive as strings, so convert them to numbers
        self.num_class = int(num_class)
        self.alpha = float(alpha)
        self.scale = float(scale)
        self.batchsize = int(batchsize)

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def list_auxiliary_states(self):
        # name them like biases so they get zero initialization
        return ['diff_bias', 'center_bias', 'sum_bias']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        label_shape = (in_shape[0][0],)

        # stores diff, same shape as the input batch
        diff_shape = [self.batchsize, data_shape[1]]

        # stores the center of each class, shape (num_class, d)
        center_shape = [self.num_class, diff_shape[1]]

        # computation buffer
        sum_shape = [diff_shape[1],]

        output_shape = [1, ]
        return [data_shape, label_shape], [output_shape], [diff_shape, center_shape, sum_shape]

    def create_operator(self, ctx, shapes, dtypes):
        return CenterLoss(ctx, shapes, dtypes, self.num_class, self.alpha, self.scale)
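For reference, the operator above follows the formulation of Wen et al.; restated in the paper's notation (the forward pass here additionally divides by the batch size m, and the gradient is multiplied by the scale factor):

\[
L_C = \frac{1}{2} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2,
\qquad
\frac{\partial L_C}{\partial x_i} = x_i - c_{y_i},
\qquad
\Delta c_j = \frac{\sum_{i=1}^{m} \delta(y_i = j)\,(c_j - x_i)}{1 + \sum_{i=1}^{m} \delta(y_i = j)},
\quad
c_j \leftarrow c_j - \alpha \, \Delta c_j .
\]

The backward pass folds the minus sign into the update: it accumulates x_i - c_{y_i} directly and adds alpha times the averaged difference to the center, which is equivalent.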
data.py
@@ -0,0 +1,64 @@
# pylint: skip-file
""" data iterator for mnist """
import sys
import os
# code to automatically download the dataset; note this path is machine-specific
sys.path.append(os.path.join("/Users/yuanyang/Build/mxnet/tests/python/common"))
import get_data
import mxnet as mx


class custom_mnist_iter(mx.io.DataIter):
    def __init__(self, mnist_iter):
        super(custom_mnist_iter, self).__init__()
        self.data_iter = mnist_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        provide_label = self.data_iter.provide_label[0]
        # different labels should be used here for an actual application
        return [('softmax_label', provide_label[1]),
                ('center_label', provide_label[1])]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        label = batch.label[0]

        # feed the same label array to both the softmax and the center loss heads
        return mx.io.DataBatch(data=batch.data, label=[label, label],
                               pad=batch.pad, index=batch.index)


def mnist_iterator(batch_size, input_shape):
    """return train and val iterators for mnist"""
    # download data
    get_data.GetMNIST_ubyte()
    flat = False if len(input_shape) == 3 else True

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=flat)

    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        flat=flat)

    return (custom_mnist_iter(train_dataiter), custom_mnist_iter(val_dataiter))
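As a quick sanity check, the wrapper can be driven by hand. A minimal sketch, assuming the MNIST files have already been downloaded into data/ (the shapes in the comments are the expected ones, not output captured from this commit):

from data import mnist_iterator

train, val = mnist_iterator(batch_size=100, input_shape=(1, 28, 28))

# each batch carries the same label array twice, once per loss head,
# matching the names declared in provide_label
batch = train.next()
print [d.shape for d in batch.data]   # expected: [(100, 1, 28, 28)]
print [l.shape for l in batch.label]  # expected: [(100,), (100,)]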
train.py
@@ -0,0 +1,97 @@
import mxnet as mx
import numpy as np
from center_loss import *
from data import mnist_iterator
import logging
import train_model
import argparse

parser = argparse.ArgumentParser(description='train mnist with softmax and center loss')
parser.add_argument('--gpus', type=str, default='2,3',
                    help='the gpus to use, e.g. "0,1,2,3"')
parser.add_argument('--batch-size', type=int, default=100,
                    help='the batch size')
parser.add_argument('--num-examples', type=int, default=60000,
                    help='the number of training examples')
parser.add_argument('--lr', type=float, default=.01,
                    help='the initial learning rate')
parser.add_argument('--lr-factor', type=float, default=0.5,
                    help='multiply the lr by this factor every lr-factor-epoch epochs')
parser.add_argument('--lr-factor-epoch', type=float, default=20,
                    help='the number of epochs between lr reductions, could be .5')
parser.add_argument('--model-prefix', type=str,
                    help='the prefix of the model to load')
parser.add_argument('--save-model-prefix', type=str, default='center_loss',
                    help='the prefix of the model to save')
parser.add_argument('--num-epochs', type=int, default=20,
                    help='the number of training epochs')
parser.add_argument('--load-epoch', type=int,
                    help='load the model saved at this epoch using the model-prefix')
parser.add_argument('--kv-store', type=str, default='local',
                    help='the kvstore type')
parser.add_argument('--log_file', type=str, default='log.txt',
                    help='log file')
parser.add_argument('--log_dir', type=str, default='.',
                    help='log dir')
args = parser.parse_args()

# mnist input shape
data_shape = (1, 28, 28)

def get_symbol(batchsize=64):
    """
    LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
    Haffner. "Gradient-based learning applied to document recognition."
    Proceedings of the IEEE (1998)
    """
    data = mx.symbol.Variable('data')

    softmax_label = mx.symbol.Variable('softmax_label')
    center_label = mx.symbol.Variable('center_label')

    # first conv
    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
    tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))

    # second conv
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
    tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))

    # first fullc
    flatten = mx.symbol.Flatten(data=pool2)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
    tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")

    # 2-d embedding, convenient for visualization
    embedding = mx.symbol.FullyConnected(data=tanh3, num_hidden=2, name='embedding')

    # second fullc
    fc2 = mx.symbol.FullyConnected(data=embedding, num_hidden=10, name='fc2')

    ce_loss = mx.symbol.SoftmaxOutput(data=fc2, label=softmax_label, name='softmax')
    center_loss_ = mx.symbol.Custom(data=fc2, label=center_label, name='center_loss_',
                                    op_type='centerloss', num_class=10, alpha=0.5,
                                    scale=0.01, batchsize=batchsize)
    center_loss = mx.symbol.MakeLoss(name='center_loss', data=center_loss_)
    mlp = mx.symbol.Group([ce_loss, center_loss])

    return mlp

def main():
    # each device gets an equal slice of the global batch
    batchsize = args.batch_size if args.gpus == '' else \
        args.batch_size / len(args.gpus.split(','))
    print 'batchsize is ', batchsize

    # define network structure
    net = get_symbol(batchsize)

    # load data
    train, val = mnist_iterator(batch_size=args.batch_size, input_shape=data_shape)

    # train
    print 'training model ...'
    train_model.fit(args, net, (train, val), data_shape)

if __name__ == "__main__":
    main()
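train_model.py is one of the two files in this commit whose diff did not load on this page, so the fit code itself is not visible here. Below is a minimal sketch of what a compatible train_model.fit could look like, wiring the two metrics from center_loss.py into the old mx.model.FeedForward API; everything in it is an assumption about, not a copy of, the commit's actual file:

# hypothetical sketch of a compatible train_model.fit, NOT the commit's train_model.py
import mxnet as mx
from center_loss import accuracy, center_loss_metric

def fit(args, net, data, data_shape):
    train, val = data
    devs = mx.cpu() if args.gpus == '' else \
        [mx.gpu(int(i)) for i in args.gpus.split(',')]

    # report softmax accuracy and the running center loss side by side
    metrics = mx.metric.CompositeEvalMetric()
    metrics.add(accuracy())
    metrics.add(center_loss_metric())

    model = mx.model.FeedForward(ctx=devs, symbol=net,
                                 num_epoch=args.num_epochs,
                                 learning_rate=args.lr,
                                 momentum=0.9, wd=0.00001)
    model.fit(X=train, eval_data=val, eval_metric=metrics,
              kvstore=args.kv_store,
              batch_end_callback=mx.callback.Speedometer(args.batch_size, 50))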