Commit

auto encoder example
piiswrong committed Nov 7, 2015
1 parent a8bafe8 commit 729e2b5
Showing 4 changed files with 380 additions and 0 deletions.
182 changes: 182 additions & 0 deletions example/autoencoder/autoencoder.py
@@ -0,0 +1,182 @@
# pylint: skip-file
import mxnet as mx
import numpy as np
import data
import model
import logging
from solver import Solver, Monitor
try:
    import cPickle as pickle
except:
    import pickle

class AutoEncoderModel(model.MXModel):
    def setup(self, dims, pt_dropout=None, ft_dropout=None, input_act=None, internal_act='relu', output_act=None):
        self.N = len(dims) - 1
        self.dims = dims
        self.stacks = []
        self.pt_dropout = pt_dropout
        self.ft_dropout = ft_dropout
        self.input_act = input_act
        self.internal_act = internal_act
        self.output_act = output_act

        self.data = mx.symbol.Variable('data')
        for i in range(self.N):
            if i == 0:
                decoder_act = input_act
                idropout = None
            else:
                decoder_act = internal_act
                idropout = pt_dropout
            if i == self.N-1:
                encoder_act = output_act
                odropout = None
            else:
                encoder_act = internal_act
                odropout = pt_dropout
            istack, iargs, iargs_grad, iargs_mult = self.make_stack(i, self.data, dims[i], dims[i+1],
                                                                    idropout, odropout, encoder_act, decoder_act)
            self.stacks.append(istack)
            self.args.update(iargs)
            self.args_grad.update(iargs_grad)
            self.args_mult.update(iargs_mult)

        self.encoder, self.internals = self.make_encoder(self.data, dims, ft_dropout, internal_act, output_act)
        self.decoder = self.make_decoder(self.encoder, dims, ft_dropout, internal_act, input_act)
        if input_act == 'softmax':
            self.loss = self.decoder
        else:
            self.loss = mx.symbol.LinearRegressionOutput(data=self.decoder, label=self.data)

    def make_stack(self, istack, data, num_input, num_hidden, idropout=None,
                   odropout=None, encoder_act='relu', decoder_act='relu'):
        x = data
        if idropout:
            x = mx.symbol.Dropout(data=x, p=idropout)
        x = mx.symbol.FullyConnected(name='encoder_%d'%istack, data=x, num_hidden=num_hidden)
        if encoder_act:
            x = mx.symbol.Activation(data=x, act_type=encoder_act)
        if odropout:
            x = mx.symbol.Dropout(data=x, p=odropout)
        x = mx.symbol.FullyConnected(name='decoder_%d'%istack, data=x, num_hidden=num_input)
        if decoder_act == 'softmax':
            x = mx.symbol.Softmax(data=x, label=data, prob_label=True, act_type=decoder_act)
        elif decoder_act:
            x = mx.symbol.Activation(data=x, act_type=decoder_act)
            x = mx.symbol.LinearRegressionOutput(data=x, label=data)
        else:
            x = mx.symbol.LinearRegressionOutput(data=x, label=data)

        args = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu),
                'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu),
                'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu),
                'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu),}
        args_grad = {'encoder_%d_weight'%istack: mx.nd.empty((num_hidden, num_input), self.xpu),
                     'encoder_%d_bias'%istack: mx.nd.empty((num_hidden,), self.xpu),
                     'decoder_%d_weight'%istack: mx.nd.empty((num_input, num_hidden), self.xpu),
                     'decoder_%d_bias'%istack: mx.nd.empty((num_input,), self.xpu),}
        args_mult = {'encoder_%d_weight'%istack: 1.0,
                     'encoder_%d_bias'%istack: 2.0,
                     'decoder_%d_weight'%istack: 1.0,
                     'decoder_%d_bias'%istack: 2.0,}
        init = mx.initializer.Normal(0.01)
        for k, v in args.items():
            init(k, v)

        return x, args, args_grad, args_mult

    def make_encoder(self, data, dims, dropout=None, internal_act='relu', output_act=None):
        x = data
        internals = []
        N = len(dims) - 1
        for i in range(N):
            x = mx.symbol.FullyConnected(name='encoder_%d'%i, data=x, num_hidden=dims[i+1])
            if internal_act and i < N-1:
                x = mx.symbol.Activation(data=x, act_type=internal_act)
            elif output_act and i == N-1:
                x = mx.symbol.Activation(data=x, act_type=output_act)
            if dropout:
                x = mx.symbol.Dropout(data=x, p=dropout)
            internals.append(x)
        return x, internals

    def make_decoder(self, feature, dims, dropout=None, internal_act='relu', input_act=None):
        x = feature
        N = len(dims) - 1
        for i in reversed(range(N)):
            x = mx.symbol.FullyConnected(name='decoder_%d'%i, data=x, num_hidden=dims[i])
            if internal_act and i > 0:
                x = mx.symbol.Activation(data=x, act_type=internal_act)
            elif input_act and i == 0:
                x = mx.symbol.Activation(data=x, act_type=input_act)
            if dropout and i > 0:
                x = mx.symbol.Dropout(data=x, p=dropout)
        return x

    def layerwise_pretrain(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None):
        def l2_norm(label, pred):
            return np.mean(np.square(label-pred))/2.0
        solver = Solver('sgd', momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler)
        solver.set_metric(mx.metric.CustomMetric(l2_norm))
        solver.set_monitor(Monitor(1000))
        data_iter = mx.io.NDArrayIter([X], batch_size=batch_size, shuffle=False,
                                      last_batch_handle='roll_over')
        for i in range(self.N):
            if i == 0:
                data_iter_i = data_iter
            else:
                # Feed the previous stack's hidden representation in as the next stack's input.
                X_i = model.extract_feature(self.internals[i-1], self.args, ['data'],
                                            data_iter, X.shape[0], self.xpu).values()[0]
                data_iter_i = mx.io.NDArrayIter([X_i], batch_size=batch_size,
                                                last_batch_handle='roll_over')
            # Pass the per-parameter learning-rate multipliers by keyword so they are not
            # swallowed by the `debug` positional argument of Solver.solve.
            solver.solve(self.xpu, self.stacks[i], self.args, self.args_grad, ['data'], data_iter_i,
                         0, n_iter, args_lrmult=self.args_mult)

    def finetune(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None):
        def l2_norm(label, pred):
            return np.mean(np.square(label-pred))/2.0
        solver = Solver('sgd', momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler)
        solver.set_metric(mx.metric.CustomMetric(l2_norm))
        solver.set_monitor(Monitor(1000))
        data_iter = mx.io.NDArrayIter([X], batch_size=batch_size, shuffle=False,
                                      last_batch_handle='roll_over')
        solver.solve(self.xpu, self.loss, self.args, self.args_grad, ['data'], data_iter,
                     0, n_iter, args_lrmult=self.args_mult)

    def eval(self, X):
        batch_size = 100
        data_iter = mx.io.NDArrayIter([X], batch_size=batch_size, shuffle=False,
                                      last_batch_handle='pad')
        Y = model.extract_feature(self.loss, self.args, ['data'], data_iter,
                                  X.shape[0], self.xpu).values()[0]
        return np.mean(np.square(Y-X))/2.0



if __name__ == '__main__':
    # set to INFO to see less information during training
    logging.basicConfig(level=logging.DEBUG)
    ae_model = AutoEncoderModel(mx.gpu(0), [784,500,500,2000,10], pt_dropout=0.2)

    X, _ = data.get_mnist()
    train_X = X[:60000]
    val_X = X[60000:]

    ae_model.layerwise_pretrain(train_X, 256, 50000, 'sgd', l_rate=0.1, decay=0.0,
                                lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
    ae_model.finetune(train_X, 256, 100000, 'sgd', l_rate=0.1, decay=0.0,
                      lr_scheduler=mx.misc.FactorScheduler(20000,0.1))
    ae_model.save('mnist_pt.arg')
    ae_model.load('mnist_pt.arg')
    print "Training error:", ae_model.eval(train_X)
    print "Validation error:", ae_model.eval(val_X)

10 changes: 10 additions & 0 deletions example/autoencoder/data.py
@@ -0,0 +1,10 @@
import numpy as np
from sklearn.datasets import fetch_mldata

def get_mnist():
    np.random.seed(1234)  # set seed for deterministic ordering
    mnist = fetch_mldata('MNIST original', data_home='../../data')
    p = np.random.permutation(mnist.data.shape[0])
    X = mnist.data[p].astype(np.float32)*0.02
    Y = mnist.target[p]
    return X, Y
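
For reference (an addition, not part of the commit): fetch_mldata returns the full 70,000-image MNIST set, so get_mnist() yields a shuffled X of shape (70000, 784) with pixel values scaled by 0.02 and labels Y of shape (70000,). A quick sanity check could look like this:

X, Y = get_mnist()
assert X.shape == (70000, 784) and X.dtype == np.float32
assert Y.shape == (70000,)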
59 changes: 59 additions & 0 deletions example/autoencoder/model.py
@@ -0,0 +1,59 @@
# pylint: skip-file
import mxnet as mx
import numpy as np
import logging
from solver import Solver, Monitor
try:
    import cPickle as pickle
except:
    import pickle


def extract_feature(sym, args, input_names, data_iter, N, xpu=mx.cpu()):
    data_iter.reset()
    input_buffs = [mx.nd.empty(i.shape, ctx=xpu) for i in data_iter.next()]
    args = dict(args, **dict(zip(input_names, input_buffs)))
    exe = sym.bind(xpu, args=args)
    outputs = [[] for i in exe.outputs]
    output_buffs = None

    data_iter.hard_reset()
    for datas in data_iter:
        for data, buff in zip(datas, input_buffs):
            data.copyto(buff)
        exe.forward(is_train=False)
        if output_buffs is None:
            output_buffs = [mx.nd.empty(i.shape, ctx=mx.cpu()) for i in exe.outputs]
        else:
            # Flush the outputs buffered from the previous batch before overwriting them.
            for out, buff in zip(outputs, output_buffs):
                out.append(buff.asnumpy())
        for out, buff in zip(exe.outputs, output_buffs):
            out.copyto(buff)
    # Flush the final buffered batch.
    for out, buff in zip(outputs, output_buffs):
        out.append(buff.asnumpy())
    outputs = [np.concatenate(i, axis=0)[:N] for i in outputs]
    return dict(zip(sym.list_outputs(), outputs))

class MXModel(object):
    def __init__(self, xpu=mx.cpu(), *args, **kwargs):
        self.xpu = xpu
        self.loss = None
        self.args = {}
        self.args_grad = {}
        self.args_mult = {}
        self.setup(*args, **kwargs)

    def save(self, fname):
        args_save = {key: v.asnumpy() for key, v in self.args.items()}
        with open(fname, 'wb') as fout:
            pickle.dump(args_save, fout)

    def load(self, fname):
        with open(fname, 'rb') as fin:
            args_save = pickle.load(fin)
            for key, v in args_save.items():
                if key in self.args:
                    self.args[key][:] = v

    def setup(self, *args, **kwargs):
        raise NotImplementedError("must override this")
129 changes: 129 additions & 0 deletions example/autoencoder/solver.py
@@ -0,0 +1,129 @@
# pylint: skip-file
import mxnet as mx
import numpy as np
import logging

class Monitor(object):
    def __init__(self, interval, level=logging.DEBUG, stat=None):
        self.interval = interval
        self.level = level
        if stat is None:
            def mean_abs(x):
                return np.fabs(x).mean()
            self.stat = mean_abs
        else:
            self.stat = stat

    def forward_end(self, i, internals):
        if i%self.interval == 0 and logging.getLogger().isEnabledFor(self.level):
            for key in sorted(internals.keys()):
                arr = internals[key]
                logging.log(self.level, 'iter:%d param:%s\t\tstat(%s):%s'%(i, key, self.stat.__name__, str(self.stat(arr.asnumpy()))))

    def backward_end(self, i, weights, grads, metric=None):
        if i%self.interval == 0 and logging.getLogger().isEnabledFor(self.level):
            for key in sorted(grads.keys()):
                arr = grads[key]
                logging.log(self.level, 'iter:%d param:%s\t\tstat(%s):%s\t\tgrad_stat:%s'%(i, key, self.stat.__name__, str(self.stat(weights[key].asnumpy())), str(self.stat(arr.asnumpy()))))
        if metric is not None:
            logging.info('Iter:%d metric:%f'%(i, metric.get()[1]))

class Solver(object):
    def __init__(self, optimizer, **kwargs):
        if isinstance(optimizer, str):
            self.optimizer = mx.optimizer.create(optimizer, **kwargs)
        else:
            self.optimizer = optimizer
        self.updater = mx.optimizer.get_updater(self.optimizer)
        self.monitor = None
        self.metric = None
        self.iter_end_callback = None
        self.iter_start_callback = None

    def set_metric(self, metric):
        self.metric = metric

    def set_monitor(self, monitor):
        self.monitor = monitor

    def set_iter_end_callback(self, callback):
        self.iter_end_callback = callback

    def set_iter_start_callback(self, callback):
        self.iter_start_callback = callback

    def solve(self, xpu, sym, args, args_grad, input_names,
              data_iter, begin_epoch, end_epoch, debug=False, args_lrmult=None):
        if args_lrmult is None:
            args_lrmult = {}

        data_iter.reset()
        input_dict = {key: mx.nd.empty(arr.shape, ctx=xpu) for key, arr in zip(input_names, data_iter.next())}
        batch_size = input_dict.values()[0].shape[0]
        self.optimizer.rescale_grad = 1.0/batch_size
        args = dict(args, **input_dict)

        output_names = sym.list_outputs()
        if debug:
            # In debug mode expose every internal blob as a (gradient-blocked) output
            # so that the Monitor can report statistics on it.
            sym = sym.get_internals()
            blob_names = sym.list_outputs()
            sym_group = []
            for i in range(len(blob_names)):
                if blob_names[i] not in args:
                    x = sym[i]
                    if blob_names[i] not in output_names:
                        x = mx.symbol.BlockGrad(x, name=blob_names[i])
                    sym_group.append(x)
            sym = mx.symbol.Group(sym_group)
        exe = sym.bind(xpu, args=args, args_grad=args_grad)

        update_dict = {name: args_grad[name] for name, nd in zip(sym.list_arguments(), exe.grad_arrays)
                       if nd is not None}

        output_dict = {}
        output_buff = {}
        internal_dict = {}
        for key, arr in zip(sym.list_outputs(), exe.outputs):
            if key in output_names:
                output_dict[key] = arr
                output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu())
            else:
                internal_dict[key] = arr

        data_iter.reset()
        for i in range(begin_epoch, end_epoch):
            if self.iter_start_callback is not None:
                self.iter_start_callback(i)
            try:
                data_list = data_iter.next()
            except StopIteration:
                # Start a new pass over the data once the iterator is exhausted.
                data_iter.reset()
                data_list = data_iter.next()
            for data, key in zip(data_list, input_names):
                data.copyto(input_dict[key])
            exe.forward(is_train=True)
            if self.monitor is not None:
                self.monitor.forward_end(i, internal_dict)
            for key in output_dict:
                output_dict[key].copyto(output_buff[key])

            exe.backward()
            self.optimizer.begin_epoch(i)
            for key, arr in update_dict.items():
                self.updater(key, arr, args[key], args_lrmult.get(key, 1.0))

            exe.outputs[0].wait_to_read()
            if self.metric is not None:
                self.metric.update(input_dict[input_names[-1]].asnumpy(),
                                   output_buff[output_names[0]].asnumpy())

            if self.monitor is not None:
                self.monitor.backward_end(i, args, update_dict, self.metric)

            if self.iter_end_callback is not None:
                self.iter_end_callback(i)
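
To make the calling convention concrete, here is a minimal, hypothetical use of Solver outside the autoencoder (an illustrative addition, not part of the commit; the layer sizes, iteration counts and random data are arbitrary): a single fully connected layer trained to reconstruct its own input.

# Sketch only: train one FullyConnected layer to reconstruct random data with Solver.
import mxnet as mx
import numpy as np
from solver import Solver, Monitor

def l2_norm(label, pred):
    return np.mean(np.square(label - pred))/2.0

X = np.random.rand(1000, 32).astype(np.float32)
data = mx.symbol.Variable('data')
net = mx.symbol.FullyConnected(data=data, name='fc', num_hidden=32)
net = mx.symbol.LinearRegressionOutput(data=net, label=data)

xpu = mx.cpu()
args = {'fc_weight': mx.nd.empty((32, 32), xpu),
        'fc_bias': mx.nd.empty((32,), xpu)}
args_grad = {k: mx.nd.empty(v.shape, xpu) for k, v in args.items()}
init = mx.initializer.Normal(0.01)
for k, v in args.items():
    init(k, v)

solver = Solver('sgd', momentum=0.9, wd=0.0, learning_rate=0.1)
solver.set_metric(mx.metric.CustomMetric(l2_norm))
solver.set_monitor(Monitor(100))
data_iter = mx.io.NDArrayIter([X], batch_size=100, shuffle=False,
                              last_batch_handle='roll_over')
solver.solve(xpu, net, args, args_grad, ['data'], data_iter, 0, 500)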