Example: Sequence to sequence learning for addition using RNNs
Showing 1 changed file with 152 additions and 0 deletions.
# -*- coding: utf-8 -*-
from __future__ import print_function
from keras.models import Sequential, slice_X
from keras.layers.core import Activation, Dense, RepeatVector
from keras.layers import recurrent
from sklearn.utils import shuffle
import numpy as np

""" | ||
An implementation of sequence to sequence learning for performing addition | ||
Input: "535+61" | ||
Output: "596" | ||
Padding is handled by using a repeated sentinel character (space) | ||
By default, the JZS1 recurrent neural network is used | ||
JZS1 was an "evolved" recurrent neural network performing well on arithmetic benchmark in: | ||
"An Empirical Exploration of Recurrent Network Architectures" | ||
http://jmlr.org/proceedings/papers/v37/jozefowicz15.pdf | ||
Input may optionally be inverted, shown to increase performance in many tasks in: | ||
"Learning to Execute" | ||
http://arxiv.org/abs/1410.4615 | ||
and | ||
"Sequence to Sequence Learning with Neural Networks" | ||
http://papers.nips.cc/paper/5346-sequence-to-sequence-learning-with-neural-networks.pdf | ||
Theoretically it introduces shorter term dependencies between source and target. | ||
Two digits inverted: | ||
+ One layer JZS1 (128 HN) with 55 iterations = 99% train/test accuracy | ||
Three digits inverted: | ||
+ One layer JZS1 (128 HN) with 19 iterations = 99% train/test accuracy | ||
Four digits inverted: | ||
+ One layer JZS1 (128 HN) with 20 iterations = 99% train/test accuracy | ||
Five digits inverted: | ||
+ One layer JZS1 (128 HN) with 28 iterations = 99% train/test accuracy | ||
""" | ||
|
||
|
||
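# Worked example of the data format: with DIGITS = 3, the query "535+61" is
# padded to "535+61 " (MAXLEN = 7) and, when INVERT is True, fed to the
# network reversed as " 16+535"; the target "596" is padded to "596 "
# (DIGITS + 1 = 4 characters)
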
class CharacterTable(object):
    """
    Given a set of characters:
    + Encode them to a one hot integer representation
    + Decode the one hot integer representation to their character output
    + Decode a vector of probabilities to their character output
    """
    def __init__(self, chars, maxlen):
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        self.maxlen = maxlen

    def encode(self, C, maxlen=None):
        maxlen = maxlen if maxlen else self.maxlen
        X = np.zeros((maxlen, len(self.chars)))
        for i, c in enumerate(C):
            X[i, self.char_indices[c]] = 1
        return X

    def decode(self, X, calc_argmax=True):
        if calc_argmax:
            X = X.argmax(axis=-1)
        return ''.join(self.indices_char[x] for x in X)

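# Illustrative round trip: with chars = '0123456789+ ' (12 distinct characters)
# and maxlen = 7, ctable.encode('12+3   ') yields a 7 x 12 one-hot matrix, and
# ctable.decode(ctable.encode('12+3   ')) returns '12+3   ' again
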
# Parameters for the model and dataset
# Note: TRAINING_SIZE is the number of queries to generate, not the final number of unique queries
TRAINING_SIZE = 800000
DIGITS = 3
INVERT = True
# Try replacing JZS1 with LSTM, GRU, or SimpleRNN
RNN = recurrent.JZS1
HIDDEN_SIZE = 128
BATCH_SIZE = 128
LAYERS = 1
# Longest possible query: DIGITS characters, '+', DIGITS characters
MAXLEN = DIGITS + 1 + DIGITS

chars = '0123456789+ '
ctable = CharacterTable(chars, MAXLEN)

questions = []
expected = []
seen = set()
print('Generating data...')
for i in xrange(TRAINING_SIZE):
    f = lambda: int(''.join(np.random.choice(list('0123456789')) for _ in xrange(np.random.randint(1, DIGITS + 1))))
    a, b = f(), f()
    # Skip any addition questions we've already seen
    # Also skip the mirrored question, since a+b and b+a have the same answer (hence the sorted key)
    key = tuple(sorted((a, b)))
    if key in seen:
        continue
    seen.add(key)
    # Pad the query with spaces so that it is always MAXLEN characters long
    q = '{}+{}'.format(a, b)
    query = q + ' ' * (MAXLEN - len(q))
    ans = str(a + b)
    # Answers can be of maximum size DIGITS + 1
    ans += ' ' * (DIGITS + 1 - len(ans))
    if INVERT:
        query = query[::-1]
    questions.append(query)
    expected.append(ans)
print('Total addition questions:', len(questions))

print('Vectorization...')
X = np.zeros((len(questions), MAXLEN, len(chars)), dtype=np.bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=np.bool)
for i, sentence in enumerate(questions):
    X[i] = ctable.encode(sentence, maxlen=MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, maxlen=DIGITS + 1)

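# Illustrative sanity check: with DIGITS = 3 and the 12-character vocabulary,
# X.shape == (len(questions), 7, 12) and y.shape == (len(questions), 4, 12)
print('X shape:', X.shape)
print('y shape:', y.shape)
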
# Shuffle (X, y) in unison, as the later parts of X will almost all be larger digits
X, y = shuffle(X, y)
# Explicitly set apart 10% as validation data that we never train over
split_at = len(X) - len(X) // 10
(X_train, X_val) = (slice_X(X, 0, split_at), slice_X(X, split_at))
(y_train, y_val) = (y[:split_at], y[split_at:])

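# Note: slice_X (imported from keras.models in this Keras version) slices either
# a single array or each array in a list of arrays; for the plain ndarray X here
# it is equivalent to X[:split_at] and X[split_at:]
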
print('Build model...')
model = Sequential()
# "Encode" the input sequence using an RNN, producing a single vector of size HIDDEN_SIZE
model.add(RNN(len(chars), HIDDEN_SIZE))
# For the decoder's input, repeat that encoded vector once per output time step
model.add(RepeatVector(DIGITS + 1))
# The decoder RNN could be multiple layers stacked or a single layer
for _ in xrange(LAYERS):
    model.add(RNN(HIDDEN_SIZE, HIDDEN_SIZE, return_sequences=True))
# For each step of the output sequence, decide which character should be chosen
model.add(Dense(HIDDEN_SIZE, len(chars)))
model.add(Activation('softmax'))

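# Illustrative shape walk-through: a one-hot query of shape (7, 12) is encoded
# into a single vector of size 128, repeated 4 times to shape (4, 128), decoded
# to (4, 128) by the sequence-returning RNN, and projected to (4, 12) per-step
# character probabilities by the softmax output layer
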
model.compile(loss='categorical_crossentropy', optimizer='adam')

# Train the model each generation and show predictions against the validation dataset
for iteration in range(1, 60):
    print()
    print('-' * 50)
    print('Iteration', iteration)
    model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=1,
              validation_data=(X_val, y_val), show_accuracy=True)
    ###
    # Select 10 samples from the validation set at random so we can visualize errors
    for i in xrange(10):
        ind = np.random.randint(0, len(X_val))
        rowX, rowy = X_val[np.array([ind])], y_val[np.array([ind])]
        preds = model.predict_classes(rowX, verbose=0)
        q = ctable.decode(rowX[0])
        correct = ctable.decode(rowy[0])
        guess = ctable.decode(preds[0], calc_argmax=False)
        print('Q', q[::-1] if INVERT else q)
        print('T', correct)
        print('☑' if correct == guess else '☒', guess)
        print('---')
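# A typical prediction printout looks like the following (illustrative):
#   Q 535+61
#   T 596
#   ☑ 596
#   ---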