forked from apache/mxnet
Commit
Showing 4 changed files with 380 additions and 0 deletions.
@@ -0,0 +1,182 @@
# pylint: skip-file
import mxnet as mx
import numpy as np
import data
import model
import logging
from solver import Solver, Monitor
try:
    import cPickle as pickle
except ImportError:
    import pickle


class AutoEncoderModel(model.MXModel):
    def setup(self, dims, pt_dropout=None, ft_dropout=None, input_act=None,
              internal_act='relu', output_act=None):
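        """Build a stacked autoencoder.

        dims[0] is the input dimension and dims[1:] are the sizes of the
        successive hidden layers; pt_dropout is the dropout used during
        layerwise pretraining, ft_dropout the dropout used during finetuning.
        """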
        self.N = len(dims) - 1
        self.dims = dims
        self.stacks = []
        self.pt_dropout = pt_dropout
        self.ft_dropout = ft_dropout
        self.input_act = input_act
        self.internal_act = internal_act
        self.output_act = output_act

        self.data = mx.symbol.Variable('data')
        for i in range(self.N):
            if i == 0:
                decoder_act = input_act
                idropout = None
            else:
                decoder_act = internal_act
                idropout = pt_dropout
            if i == self.N-1:
                encoder_act = output_act
                odropout = None
            else:
                encoder_act = internal_act
                odropout = pt_dropout
            istack, iargs, iargs_grad, iargs_mult = self.make_stack(
                i, self.data, dims[i], dims[i+1],
                idropout, odropout, encoder_act, decoder_act)
            self.stacks.append(istack)
            self.args.update(iargs)
            self.args_grad.update(iargs_grad)
            self.args_mult.update(iargs_mult)

        self.encoder, self.internals = self.make_encoder(
            self.data, dims, ft_dropout, internal_act, output_act)
        self.decoder = self.make_decoder(
            self.encoder, dims, ft_dropout, internal_act, input_act)
        if input_act == 'softmax':
            self.loss = self.decoder
        else:
            self.loss = mx.symbol.LinearRegressionOutput(data=self.decoder, label=self.data)

    def make_stack(self, istack, data, num_input, num_hidden, idropout=None,
                   odropout=None, encoder_act='relu', decoder_act='relu'):
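        """Build the istack-th pretraining stack: one encoder FC layer plus
        a mirrored decoder FC layer trained to reconstruct the stack's own
        input."""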
        x = data
        if idropout:
            x = mx.symbol.Dropout(data=x, p=idropout)
        x = mx.symbol.FullyConnected(name='encoder_%d'%istack, data=x, num_hidden=num_hidden)
        if encoder_act:
            x = mx.symbol.Activation(data=x, act_type=encoder_act)
        if odropout:
            x = mx.symbol.Dropout(data=x, p=odropout)
        x = mx.symbol.FullyConnected(name='decoder_%d'%istack, data=x, num_hidden=num_input)
        if decoder_act == 'softmax':
            x = mx.symbol.Softmax(data=x, label=data, prob_label=True, act_type=decoder_act)
        elif decoder_act:
            x = mx.symbol.Activation(data=x, act_type=decoder_act)
            x = mx.symbol.LinearRegressionOutput(data=x, label=data)
        else:
            x = mx.symbol.LinearRegressionOutput(data=x, label=data)

        args = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu),
                'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu),
                'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu),
                'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu)}
        args_grad = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu),
                     'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu),
                     'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu),
                     'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu)}
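        # args_mult feeds Solver.solve's args_lrmult: biases train at twice
        # the learning rate of weights, a common convention.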
        args_mult = {'encoder_%d_weight'%istack: 1.0,
                     'encoder_%d_bias'%istack: 2.0,
                     'decoder_%d_weight'%istack: 1.0,
                     'decoder_%d_bias'%istack: 2.0}
        init = mx.initializer.Normal(0.01)
        for k, v in args.items():
            init(k, v)

        return x, args, args_grad, args_mult

    def make_encoder(self, data, dims, dropout=None, internal_act='relu', output_act=None):
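        """Chain all encoder layers; `internals` collects each layer's
        output symbol so pretraining can extract intermediate features."""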
        x = data
        internals = []
        N = len(dims) - 1
        for i in range(N):
            x = mx.symbol.FullyConnected(name='encoder_%d'%i, data=x, num_hidden=dims[i+1])
            if internal_act and i < N-1:
                x = mx.symbol.Activation(data=x, act_type=internal_act)
            elif output_act and i == N-1:
                x = mx.symbol.Activation(data=x, act_type=output_act)
            if dropout:
                x = mx.symbol.Dropout(data=x, p=dropout)
            internals.append(x)
        return x, internals

    def make_decoder(self, feature, dims, dropout=None, internal_act='relu', input_act=None):
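        """Mirror the encoder: walk the dims backwards through the
        decoder_%d layers to map the code back to input space."""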
        x = feature
        N = len(dims) - 1
        for i in reversed(range(N)):
            x = mx.symbol.FullyConnected(name='decoder_%d'%i, data=x, num_hidden=dims[i])
            if internal_act and i > 0:
                x = mx.symbol.Activation(data=x, act_type=internal_act)
            elif input_act and i == 0:
                x = mx.symbol.Activation(data=x, act_type=input_act)
            if dropout and i > 0:
                x = mx.symbol.Dropout(data=x, p=dropout)
        return x

    def layerwise_pretrain(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None):
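        """Greedily pretrain stack i on features extracted from the first
        i encoder layers (raw input for stack 0)."""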
        def l2_norm(label, pred):
            return np.mean(np.square(label-pred))/2.0
        solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate,
                        lr_scheduler=lr_scheduler)
        solver.set_metric(mx.metric.CustomMetric(l2_norm))
        solver.set_monitor(Monitor(1000))
        data_iter = mx.io.NDArrayIter([X], batch_size=batch_size, shuffle=False,
                                      last_batch_handle='roll_over')
        for i in range(self.N):
            if i == 0:
                data_iter_i = data_iter
            else:
                X_i = list(model.extract_feature(self.internals[i-1], self.args, ['data'],
                                                 data_iter, X.shape[0], self.xpu).values())[0]
                data_iter_i = mx.io.NDArrayIter([X_i], batch_size=batch_size,
                                                last_batch_handle='roll_over')
            solver.solve(self.xpu, self.stacks[i], self.args, self.args_grad, ['data'],
                         data_iter_i, 0, n_iter, self.args_mult)

    def finetune(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None):
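        """Fine-tune the whole encoder-decoder reconstruction loss end to
        end, starting from the pretrained weights."""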
        def l2_norm(label, pred):
            return np.mean(np.square(label-pred))/2.0
        solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate,
                        lr_scheduler=lr_scheduler)
        solver.set_metric(mx.metric.CustomMetric(l2_norm))
        solver.set_monitor(Monitor(1000))
        data_iter = mx.io.NDArrayIter([X], batch_size=batch_size, shuffle=False,
                                      last_batch_handle='roll_over')
        solver.solve(self.xpu, self.loss, self.args, self.args_grad, ['data'],
                     data_iter, 0, n_iter, self.args_mult)

    def eval(self, X):
        batch_size = 100
        data_iter = mx.io.NDArrayIter([X], batch_size=batch_size, shuffle=False,
                                      last_batch_handle='pad')
        Y = list(model.extract_feature(self.loss, self.args, ['data'], data_iter,
                                       X.shape[0], self.xpu).values())[0]
        return np.mean(np.square(Y-X))/2.0


if __name__ == '__main__':
    # set to INFO to see less information during training
    logging.basicConfig(level=logging.DEBUG)
    ae_model = AutoEncoderModel(mx.gpu(0), [784, 500, 500, 2000, 10], pt_dropout=0.2)
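    # dims [784, 500, 500, 2000, 10]: 784 MNIST pixels compressed to a
    # 10-dimensional code through three hidden layers.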

    X, _ = data.get_mnist()
    train_X = X[:60000]
    val_X = X[60000:]

    ae_model.layerwise_pretrain(train_X, 256, 50000, 'sgd', l_rate=0.1, decay=0.0,
                                lr_scheduler=mx.misc.FactorScheduler(20000, 0.1))
    ae_model.finetune(train_X, 256, 100000, 'sgd', l_rate=0.1, decay=0.0,
                      lr_scheduler=mx.misc.FactorScheduler(20000, 0.1))
    ae_model.save('mnist_pt.arg')
    ae_model.load('mnist_pt.arg')
    print("Training error: %f" % ae_model.eval(train_X))
    print("Validation error: %f" % ae_model.eval(val_X))
@@ -0,0 +1,10 @@
import numpy as np
from sklearn.datasets import fetch_mldata


def get_mnist():
    np.random.seed(1234)  # set seed for deterministic ordering
    mnist = fetch_mldata('MNIST original', data_home='../../data')
    p = np.random.permutation(mnist.data.shape[0])
    X = mnist.data[p].astype(np.float32)*0.02
    Y = mnist.target[p]
    return X, Y
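
Note: fetch_mldata was removed in later scikit-learn releases after mldata.org shut down. A minimal sketch of an equivalent loader using fetch_openml, assuming scikit-learn >= 0.22; the seeding and 0.02 scaling mirror get_mnist above:

# hypothetical Python 3 replacement for get_mnist, using fetch_openml
import numpy as np
from sklearn.datasets import fetch_openml


def get_mnist():
    np.random.seed(1234)  # set seed for deterministic ordering
    mnist = fetch_openml('mnist_784', data_home='../../data', as_frame=False)
    p = np.random.permutation(mnist.data.shape[0])
    X = mnist.data[p].astype(np.float32) * 0.02
    Y = mnist.target[p].astype(np.int64)  # openml targets are strings
    return X, Y
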
@@ -0,0 +1,59 @@
# pylint: skip-file
import mxnet as mx
import numpy as np
import logging
from solver import Solver, Monitor
try:
    import cPickle as pickle
except ImportError:
    import pickle


def extract_feature(sym, args, input_names, data_iter, N, xpu=mx.cpu()):
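    """Run `sym` forward over `data_iter` in inference mode and return a
    dict mapping each output name to the first N rows of the stacked
    per-batch outputs."""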
    data_iter.reset()
    input_buffs = [mx.nd.empty(i.shape, ctx=xpu) for i in data_iter.next()]
    args = dict(args, **dict(zip(input_names, input_buffs)))
    exe = sym.bind(xpu, args=args)
    outputs = [[] for i in exe.outputs]
    output_buffs = None
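    # Double buffering: copy the previous batch's outputs to CPU while the
    # current batch computes; the last buffered batch is drained after the loop.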

    data_iter.hard_reset()
    for datas in data_iter:
        for data, buff in zip(datas, input_buffs):
            data.copyto(buff)
        exe.forward(is_train=False)
        if output_buffs is None:
            output_buffs = [mx.nd.empty(i.shape, ctx=mx.cpu()) for i in exe.outputs]
        else:
            for out, buff in zip(outputs, output_buffs):
                out.append(buff.asnumpy())
        for out, buff in zip(exe.outputs, output_buffs):
            out.copyto(buff)
    for out, buff in zip(outputs, output_buffs):
        out.append(buff.asnumpy())
    outputs = [np.concatenate(i, axis=0)[:N] for i in outputs]
    return dict(zip(sym.list_outputs(), outputs))


class MXModel(object):
    def __init__(self, xpu=mx.cpu(), *args, **kwargs):
        self.xpu = xpu
        self.loss = None
        self.args = {}
        self.args_grad = {}
        self.args_mult = {}
        self.setup(*args, **kwargs)

    def save(self, fname):
        args_save = {key: v.asnumpy() for key, v in self.args.items()}
        with open(fname, 'wb') as fout:
            pickle.dump(args_save, fout)

    def load(self, fname):
        with open(fname, 'rb') as fin:
            args_save = pickle.load(fin)
        for key, v in args_save.items():
            if key in self.args:
                self.args[key][:] = v

    def setup(self, *args, **kwargs):
        raise NotImplementedError("must override this")

@@ -0,0 +1,129 @@
# pylint: skip-file
import mxnet as mx
import numpy as np
import logging


class Monitor(object):
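    """Logs a statistic (mean absolute value by default) of internal
    activations, weights, and gradients every `interval` iterations."""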
    def __init__(self, interval, level=logging.DEBUG, stat=None):
        self.interval = interval
        self.level = level
        if stat is None:
            def mean_abs(x):
                return np.fabs(x).mean()
            self.stat = mean_abs
        else:
            self.stat = stat

    def forward_end(self, i, internals):
        if i % self.interval == 0 and logging.getLogger().isEnabledFor(self.level):
            for key in sorted(internals.keys()):
                arr = internals[key]
                logging.log(self.level, 'iter:%d param:%s\t\tstat(%s):%s' % (
                    i, key, self.stat.__name__, str(self.stat(arr.asnumpy()))))

    def backward_end(self, i, weights, grads, metric=None):
        if i % self.interval == 0 and logging.getLogger().isEnabledFor(self.level):
            for key in sorted(grads.keys()):
                arr = grads[key]
                logging.log(self.level, 'iter:%d param:%s\t\tstat(%s):%s\t\tgrad_stat:%s' % (
                    i, key, self.stat.__name__,
                    str(self.stat(weights[key].asnumpy())),
                    str(self.stat(arr.asnumpy()))))
        if metric is not None:
            logging.info('Iter:%d metric:%f' % (i, metric.get()[1]))


class Solver(object):
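    """Minimal training loop: binds a symbol, feeds batches from a DataIter,
    and applies optimizer updates with optional per-argument LR multipliers."""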
    def __init__(self, optimizer, **kwargs):
        if isinstance(optimizer, str):
            self.optimizer = mx.optimizer.create(optimizer, **kwargs)
        else:
            self.optimizer = optimizer
        self.updater = mx.optimizer.get_updater(self.optimizer)
        self.monitor = None
        self.metric = None
        self.iter_end_callback = None
        self.iter_start_callback = None

    def set_metric(self, metric):
        self.metric = metric

    def set_monitor(self, monitor):
        self.monitor = monitor

    def set_iter_end_callback(self, callback):
        self.iter_end_callback = callback

    def set_iter_start_callback(self, callback):
        self.iter_start_callback = callback

    def solve(self, xpu, sym, args, args_grad, input_names,
              data_iter, begin_epoch, end_epoch, args_lrmult=None, debug=False):
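        """Bind `sym` on `xpu` and run `end_epoch - begin_epoch` update steps
        over `data_iter`; `args_lrmult` maps argument names to per-argument
        learning-rate multipliers (default 1.0)."""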
        if args_lrmult is None:
            args_lrmult = {}

        data_iter.reset()
        input_dict = {key: mx.nd.empty(arr.shape, ctx=xpu)
                      for key, arr in zip(input_names, data_iter.next())}
        batch_size = list(input_dict.values())[0].shape[0]
        self.optimizer.rescale_grad = 1.0/batch_size
        args = dict(args, **input_dict)

        output_names = sym.list_outputs()
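        # With debug=True, expose every internal blob as an extra output
        # (gradient-blocked) so the Monitor can inspect activations.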
        if debug:
            sym = sym.get_internals()
            blob_names = sym.list_outputs()
            sym_group = []
            for i in range(len(blob_names)):
                if blob_names[i] not in args:
                    x = sym[i]
                    if blob_names[i] not in output_names:
                        x = mx.symbol.BlockGrad(x, name=blob_names[i])
                    sym_group.append(x)
            sym = mx.symbol.Group(sym_group)
        exe = sym.bind(xpu, args=args, args_grad=args_grad)

        update_dict = {name: args_grad[name]
                       for name, nd in zip(sym.list_arguments(), exe.grad_arrays)
                       if nd is not None}

        output_dict = {}
        output_buff = {}
        internal_dict = {}
        for key, arr in zip(sym.list_outputs(), exe.outputs):
            if key in output_names:
                output_dict[key] = arr
                output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu())
            else:
                internal_dict[key] = arr

        data_iter.reset()
        for i in range(begin_epoch, end_epoch):
            if self.iter_start_callback is not None:
                self.iter_start_callback(i)
            try:
                data_list = data_iter.next()
            except StopIteration:
                data_iter.reset()
                data_list = data_iter.next()
            for data, key in zip(data_list, input_names):
                data.copyto(input_dict[key])
            exe.forward(is_train=True)
            if self.monitor is not None:
                self.monitor.forward_end(i, internal_dict)
            for key in output_dict:
                output_dict[key].copyto(output_buff[key])

            exe.backward()
            self.optimizer.begin_epoch(i)
            for key, arr in update_dict.items():
                self.updater(key, arr, args[key], args_lrmult.get(key, 1.0))

            exe.outputs[0].wait_to_read()
            if self.metric is not None:
                self.metric.update(input_dict[input_names[-1]].asnumpy(),
                                   output_buff[output_names[0]].asnumpy())

            if self.monitor is not None:
                self.monitor.backward_end(i, args, update_dict, self.metric)

            if self.iter_end_callback is not None:
                self.iter_end_callback(i)