Skip to content

Commit

Permalink
sort dataset by length to speed up error computation
Browse files Browse the repository at this point in the history
  • Loading branch information
nouiz committed Jan 31, 2015
1 parent ec112fa commit 6a32f03
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
27 changes: 25 additions & 2 deletions code/imdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ def get_dataset_file(dataset, default_dataset, origin):
return dataset


def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None):
''' Loads the dataset
def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None,
sort_by_len=True):
'''Loads the dataset
:type path: String
:param path: The path to the dataset (here IMDB)
Expand All @@ -87,6 +88,12 @@ def load_data(path="imdb.pkl", n_words=100000, valid_portion=0.1, maxlen=None):
the validation set.
:type maxlen: None or positive int
:param maxlen: the max sequence length we use in the train/valid set.
:type sort_by_len: bool
:param sort_by_len: Sort by the sequence length for the train,
valid and test set. This allows faster execution as it causes
less padding per minibatch. Another mechanism must be used to
shuffle the train set at each epoch.
'''

#############
Expand Down Expand Up @@ -140,6 +147,22 @@ def remove_unk(x):
valid_set_x = remove_unk(valid_set_x)
test_set_x = remove_unk(test_set_x)

def len_argsort(seq):
    """Return the indices of ``seq`` ordered by ascending element length.

    The sort is stable, so elements of equal length keep their original
    relative order.
    """
    def length_at(index):
        return len(seq[index])

    indices = list(range(len(seq)))
    indices.sort(key=length_at)
    return indices

if sort_by_len:
sorted_index = len_argsort(test_set_x)
test_set_x = [test_set_x[i] for i in sorted_index]
test_set_y = [test_set_y[i] for i in sorted_index]

sorted_index = len_argsort(valid_set_x)
valid_set_x = [valid_set_x[i] for i in sorted_index]
valid_set_y = [valid_set_y[i] for i in sorted_index]

sorted_index = len_argsort(train_set_x)
train_set_x = [train_set_x[i] for i in sorted_index]
train_set_y = [train_set_y[i] for i in sorted_index]

train = (train_set_x, train_set_y)
valid = (valid_set_x, valid_set_y)
test = (test_set_x, test_set_y)
Expand Down
9 changes: 4 additions & 5 deletions code/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,8 @@ def train_lstm(

print 'Optimization'

kf_valid = get_minibatches_idx(len(valid[0]), valid_batch_size,
shuffle=True)
kf_test = get_minibatches_idx(len(test[0]), valid_batch_size,
shuffle=True)
kf_valid = get_minibatches_idx(len(valid[0]), valid_batch_size)
kf_test = get_minibatches_idx(len(test[0]), valid_batch_size)

print "%d train examples" % len(train[0])
print "%d valid examples" % len(valid[0])
Expand Down Expand Up @@ -561,7 +559,8 @@ def train_lstm(
best_p = unzip(tparams)

use_noise.set_value(0.)
train_err = pred_error(f_pred, prepare_data, train, kf)
kf_train_sorted = get_minibatches_idx(len(train[0]), batch_size)
train_err = pred_error(f_pred, prepare_data, train, kf_train_sorted)
valid_err = pred_error(f_pred, prepare_data, valid, kf_valid)
test_err = pred_error(f_pred, prepare_data, test, kf_test)

Expand Down

0 comments on commit 6a32f03

Please sign in to comment.