Skip to content

Commit

Permalink
drcn
Browse files Browse the repository at this point in the history
  • Loading branch information
ghif committed Mar 2, 2016
1 parent 0eb8aef commit 0ecb608
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 37 deletions.
10 changes: 6 additions & 4 deletions check_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
src = 'svhn'
tgt = 'mnist'
datapair = src+'-'+tgt
model = 'drcn'
# model = 'convnet'
model = 'drcn-st'
dropout_rate = 0.5
is_aug = 0
denoising = 0.4
is_aug = 1
denoising = 0.5

RESFILE = 'results/'+datapair+'_'+model+'_results_drop%.1f_aug%d_denoise%.1f.pkl.gz' % (dropout_rate, is_aug, denoising)
# RESFILE = 'results/'+datapair+'_'+model+'_results_drop%.1f_aug%d.pkl.gz' % (dropout_rate, is_aug)
# RESFILE = 'results/'+datapair+'_'+model+'_results_drop%.1f_aug%d_h300.pkl.gz' % (dropout_rate, is_aug)

print(RESFILE)
res = pickle.load(gzip.open(RESFILE,'rb'))

Expand Down
10 changes: 7 additions & 3 deletions loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,18 @@ def load_mnist(mode=0):
train_set_x = train_set_x.reshape(train_set_x.shape[0], 1, 28, 28).astype('float32')
valid_set_x = valid_set_x.reshape(valid_set_x.shape[0], 1, 28, 28).astype('float32')
test_set_x = test_set_x.reshape(test_set_x.shape[0], 1, 28, 28).astype('float32')
train_set_y = train_set_y.astype('uint8')
valid_set_y = valid_set_y.astype('uint8')
test_set_y = test_set_y.astype('uint8')

rval = [(train_set_x, train_set_y), (valid_set_x, valid_set_y),
(test_set_x, test_set_y)]
rval = (train_set_x, train_set_y), (valid_set_x, valid_set_y), (test_set_x, test_set_y)
return rval

def load_mnist32x32(mode=0):
# dataset='I:\Data\PhD Life\Tutorial\Python\data\mnist.pkl.gz' # linux
# dataset='/u/students/gif/Desktop/PhD/Tutorial/dataset/MNIST/mnist32x32.pkl.gz' # linux
dataset = '/local/scratch/gif/dataset/MNIST/mnist32x32.pkl.gz' #the-villa
# dataset = 'I:\Data\PhD Life\Tutorial\dataset\MNIST\mnist32x32.pkl.gz' # laptop
f = gzip.open(dataset,'rb')
train_set, valid_set, test_set = pickle.load(f)
f.close()
Expand Down Expand Up @@ -222,12 +225,13 @@ def load_usps(mode=0):
train_set_x = train_set_x.reshape(train_set_x.shape[0], 1, 28, 28).astype('float32')
test_set_x = test_set_x.reshape(test_set_x.shape[0], 1, 28, 28).astype('float32')

rval = [(train_set_x, train_set_y), (test_set_x, test_set_y)]
rval = (train_set_x, train_set_y), (test_set_x, test_set_y)
return rval

def load_svhn():
# dataset = '/u/students/gif/Desktop/PhD/Tutorial/dataset/SVHN/svhn_gray.pkl.gz' #linux
dataset = '/local/scratch/gif/dataset/SVHN/svhn_gray.pkl.gz' #the-villa
# dataset = 'I:\Data\PhD Life\Tutorial\dataset\SVHN\svhn_gray.pkl.gz' # laptop
f = gzip.open(dataset,'rb')
(X_train, y_train), (X_test, y_test) = pickle.load(f)
f.close()
Expand Down
19 changes: 15 additions & 4 deletions main_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,38 @@

# hyperparameters
learning_rate = 1e-4
batch_size = 128
batch_size = 32
nb_epoch = 50
augmentation = True
dropout = 0.5
dense_dim = 300


src = 'svhn'
src = 'usps'
tgt = 'mnist'

RESFILE = 'results/'+src+'-'+tgt+'_convnet_results_drop%.1f_aug%d_h%d.pkl.gz' % (dropout, augmentation, dense_dim)
PARAMFILE = 'results/'+src+'-'+tgt+'_convnet_weights_drop%.1f_aug%d_h%d.pkl.gz' % (dropout, augmentation, dense_dim)
print(PARAMFILE)

# Load data

if src == 'svhn':
(X_train, Y_train), (X_test, Y_test) = load_svhn()
(_, _), (_, _), (X_tgt_test, Y_tgt_test) = load_mnist32x32()
elif src == 'mnist':
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist32x32()
(_, _), (X_tgt_test, Y_tgt_test) = load_svhn()
if tgt == 'svhn':
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist32x32()
(_, _), (X_tgt_test, Y_tgt_test) = load_svhn()
else:
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist()
(_, _), (X_tgt_test, Y_tgt_test) = load_usps()

elif src == 'usps':
(X_train, Y_train), (X_test, Y_test) = load_usps()
(_, _), (_, _), (X_tgt_test, Y_tgt_test) = load_mnist()




print('Preprocess data ...')
Expand Down
28 changes: 19 additions & 9 deletions main_drcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,25 @@

net_config = {
'lr': 1e-4,
'batch_size': 128,
'batch_size': 32,
'nb_epoch': 50,
'augmentation': False,
'augmentation': True,
'dropout': 0.5,
'shuffle': False,
'shuffle': True,
'dense_dim': 300,
'loss': 'categorical_crossentropy'
}

ae_config = {
'lr': 1e-4,
'batch_size': 128,
'batch_size': 32,
'denoising': 0.4,
'shuffle': False,
'loss': 'squared_error'
'shuffle': True,
'loss': 'squared_error',
'input': 't'
}

src = 'svhn'
src = 'usps'
tgt = 'mnist'

RESFILE = 'results/'+src+'-'+tgt+'_drcn_results_drop%.1f_aug%d_denoise%.1f.pkl.gz' % (net_config['dropout'], net_config['augmentation'], ae_config['denoising'])
Expand All @@ -34,8 +35,17 @@
(X_train, Y_train), (X_test, Y_test) = load_svhn()
(_, _), (_, _), (X_tgt_test, Y_tgt_test) = load_mnist32x32()
elif src == 'mnist':
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist32x32()
(_, _), (X_tgt_test, Y_tgt_test) = load_svhn()
if tgt == 'svhn':
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist32x32()
(_, _), (X_tgt_test, Y_tgt_test) = load_svhn()
else:
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist()
(_, _), (X_tgt_test, Y_tgt_test) = load_usps()

elif src == 'usps':
(X_train, Y_train), (X_test, Y_test) = load_usps()
(_, _), (_, _), (X_tgt_test, Y_tgt_test) = load_mnist()



print('Preprocess data ...')
Expand Down
22 changes: 16 additions & 6 deletions main_drcn_s.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
'nb_epoch': 50,
'augmentation': True,
'dropout': 0.5,
'shuffle': False,
'shuffle': True,
'dense_dim': 300,
'loss': 'categorical_crossentropy'
}
Expand All @@ -17,26 +17,36 @@
'lr': 1e-4,
'batch_size': 128,
'denoising': 0.5,
'shuffle': False,
'shuffle': True,
'loss': 'squared_error',
'input': 'src'
'input': 'st'
}

src = 'svhn'
tgt = 'mnist'
model = 'drcn-s'
model = 'drcn-'+ae_config['input']

RESFILE = 'results/'+src+'-'+tgt+'_'+model+'_results_drop%.1f_aug%d_denoise%.1f.pkl.gz' % (net_config['dropout'], net_config['augmentation'], ae_config['denoising'])
PARAMFILE = 'results/'+src+'-'+tgt+'_'+model+'_weights_drop%.1f_aug%d_denoise%.1f.pkl.gz' % (net_config['dropout'], net_config['augmentation'], ae_config['denoising'])
PREDICTPREFIX = src+'-'+tgt+'_'+model+'_'
print(PARAMFILE)


# Load data
if src == 'svhn':
(X_train, Y_train), (X_test, Y_test) = load_svhn()
(_, _), (_, _), (X_tgt_test, Y_tgt_test) = load_mnist32x32()
elif src == 'mnist':
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist32x32()
(_, _), (X_tgt_test, Y_tgt_test) = load_svhn()
if tgt == 'svhn':
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist32x32()
(_, _), (X_tgt_test, Y_tgt_test) = load_svhn()
else:
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist()
(_, _), (X_tgt_test, Y_tgt_test) = load_usps()

elif src == 'usps':
(X_train, Y_train), (X_test, Y_test) = load_usps()
(_, _), (_, _), (X_tgt_test, Y_tgt_test) = load_mnist()


print('Preprocess data ...')
Expand Down
13 changes: 11 additions & 2 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,14 +508,23 @@ def train(self, X, Y,

total_batches_ae = X_tgt.shape[0] / self.ae_config['batch_size']

if self.ae_config['input'] == 's':
X_ae = np.copy(X)
elif self.ae_config['input'] == 't':
X_ae = np.copy(X_tgt)
elif self.ae_config['input'] == 'st':
X_ae = np.concatenate([X, X_tgt])


for epoch in range(self.net_config['nb_epoch']):

# ========== CONVAE =======
start_time = time.time()
loss_ae = 0
nbatch = 0
for X_batch, Y_batch in gdatagen.flow(X_tgt, np.copy(X_tgt), batch_size=self.ae_config['batch_size'], shuffle=self.ae_config['shuffle']):

# print('AE training : ',X_ae.shape)
for X_batch, Y_batch in gdatagen.flow(X_ae, np.copy(X_ae), batch_size=self.ae_config['batch_size'], shuffle=self.ae_config['shuffle']):
if self.ae_config['denoising'] > 0.:
X_batch = get_corrupted_output(X_batch, corruption_level=self.ae_config['denoising']).astype('float32')
else:
Expand Down Expand Up @@ -601,7 +610,7 @@ def train(self, X, Y,
if RESFILE is not None:
pickle.dump(self.res, gzip.open(RESFILE,'wb'))

if epoch % 5 == 0:
if epoch % 10 == 0:
print('=== > Save weights !')
self.save_weights(PARAMFILE)
# end epoch
Expand Down
28 changes: 19 additions & 9 deletions test_drcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,34 @@

src = 'svhn'
tgt = 'mnist'

RESFILE = 'results/'+src+'-'+tgt+'_drcn_results_drop0.5_aug0_denoise0.4.pkl.gz'
PARAMFILE = 'results/'+src+'-'+tgt+'_drcn_weights_drop0.5_aug0_denoise0.4.pkl.gz'
model = 'drcn-st'
RESFILE = 'results/'+src+'-'+tgt+'_'+model+'_results_drop0.5_aug1_denoise0.5.pkl.gz'
PARAMFILE = 'results/'+src+'-'+tgt+'_'+model+'_weights_drop0.5_aug1_denoise0.5.pkl.gz'


print('Load data...')
inds = pickle.load(gzip.open('data_indices100.pkl.gz','rb'))


if src == 'svhn':
(X_train, Y_train), (X_test, Y_test) = load_svhn()
(_, _), (_, _), (X_tgt_test, Y_tgt_test) = load_mnist32x32()
idx_src = inds['svhn_train']
idx_tgt = inds['mnist_test']

elif src == 'mnist':
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist32x32()
(_, _), (X_tgt_test, Y_tgt_test) = load_svhn()
idx_src = inds['mnist_train']
idx_tgt = inds['svhn_test']
if tgt == 'svhn':
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist32x32()
(_, _), (X_tgt_test, Y_tgt_test) = load_svhn()
idx_src = inds['mnist_train']
idx_tgt = inds['svhn_test']
else:
(X_train, Y_train), (_, _), (X_test, Y_test) = load_mnist()
(_, _), (X_tgt_test, Y_tgt_test) = load_usps()

elif src == 'usps':
(X_train, Y_train), (X_test, Y_test) = load_usps()
(_, _), (_, _), (X_tgt_test, Y_tgt_test) = load_mnist()

print('Preprocess data ...')
# X_train, scaler = remove_mean(X_train)
Expand All @@ -46,7 +56,7 @@
model.load_weights(PARAMFILE)


show_filter(X_train[idx_src], grayscale=True, filename='viz/'+src+'_'+tgt+'_drcn-fixed_X100-src-orig_denoise0.4.png')
show_filter(model.convae_.predict(X_train[idx_src]), grayscale=True, filename='viz/'+src+'_'+tgt+'_drcn-fixed_X100-src-pred_denoise0.4.png')
show_filter(X_train[idx_src], grayscale=True, filename='viz/'+src+'_'+tgt+'_'+model+'_X100-src-orig.png')
show_filter(model.convae_.predict(X_train[idx_src]), grayscale=True, filename='viz/'+src+'_'+tgt+'_'+model+'_X100-src-pred.png')
# show_filter(X_tgt_test[idx_tgt], grayscale=True, filename='viz/'+src+'_'+tgt+'_drcn-fixed_X100-tgt-orig.png')
# show_filter(model.convae_.predict(X_tgt_test[idx_tgt]), grayscale=True, filename='viz/'+src+'_'+tgt+'_drcn-fixed_X100-tgt-pred.png')

0 comments on commit 0ecb608

Please sign in to comment.