Cleanup examples
fchollet committed Dec 9, 2015
1 parent 96f3404 commit 81787dd
Showing 15 changed files with 227 additions and 248 deletions.
26 changes: 13 additions & 13 deletions examples/addition_rnn.py
@@ -1,13 +1,5 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
from keras.models import Sequential, slice_X
from keras.layers.core import Activation, TimeDistributedDense, RepeatVector
from keras.layers import recurrent
import numpy as np
from six.moves import range

"""
An implementation of sequence to sequence learning for performing addition
'''An implementation of sequence to sequence learning for performing addition
Input: "535+61"
Output: "596"
Padding is handled by using a repeated sentinel character (space)
@@ -32,16 +24,23 @@
Five digits inverted:
+ One layer LSTM (128 HN), 550k training examples = 99% train/test accuracy in 30 epochs
"""
'''

from __future__ import print_function
from keras.models import Sequential, slice_X
from keras.layers.core import Activation, TimeDistributedDense, RepeatVector
from keras.layers import recurrent
import numpy as np
from six.moves import range


class CharacterTable(object):
"""
'''
Given a set of characters:
+ Encode them to a one hot integer representation
+ Decode the one hot integer representation to their character output
+ Decode a vector of probabilities to their character output
"""
'''
def __init__(self, chars, maxlen):
self.chars = sorted(set(chars))
self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
@@ -150,7 +149,8 @@ class colors:
print()
print('-' * 50)
print('Iteration', iteration)
model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=1, validation_data=(X_val, y_val), show_accuracy=True)
model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=1,
validation_data=(X_val, y_val), show_accuracy=True)
###
# Select 10 samples from the validation set at random so we can visualize errors
for i in range(10):
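The CharacterTable class above is only partially visible in the diff; as a rough, self-contained sketch of what such a one-hot encoder/decoder does (the class below is a hypothetical stand-in, not the example's exact code), something like this behaves the same way:

import numpy as np

class SimpleCharTable(object):
    '''Hypothetical stand-in for CharacterTable: one-hot encode fixed-length
    strings over a known alphabet and decode probability vectors back.'''
    def __init__(self, chars, maxlen):
        self.chars = sorted(set(chars))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        self.maxlen = maxlen

    def encode(self, s):
        # one row per time step, one column per character in the alphabet
        X = np.zeros((self.maxlen, len(self.chars)))
        for i, c in enumerate(s):
            X[i, self.char_indices[c]] = 1.0
        return X

    def decode(self, X):
        # take the most probable character at each time step
        return ''.join(self.indices_char[i] for i in X.argmax(axis=-1))

# usage sketch
table = SimpleCharTable('0123456789+ ', maxlen=7)
encoded = table.encode('535+61 ')
print(table.decode(encoded))  # '535+61 '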
30 changes: 14 additions & 16 deletions examples/babi_memnn.py
@@ -1,17 +1,4 @@
from __future__ import print_function
from keras.models import Sequential
from keras.layers.embeddings import Embedding
from keras.layers.core import Activation, Dense, Merge, Permute, Dropout
from keras.layers.recurrent import LSTM
from keras.datasets.data_utils import get_file
from keras.preprocessing.sequence import pad_sequences
from functools import reduce
import tarfile
import numpy as np
import re

"""
Train a memory network on the bAbI dataset.
'''Train a memory network on the bAbI dataset.
References:
- Jason Weston, Antoine Bordes, Sumit Chopra, Tomas Mikolov, Alexander M. Rush,
@@ -24,7 +11,19 @@
Reaches 93% accuracy on task 'single_supporting_fact_10k' after 70 epochs.
Time per epoch: 3s on CPU (core i7).
"""
'''

from __future__ import print_function
from keras.models import Sequential
from keras.layers.embeddings import Embedding
from keras.layers.core import Activation, Dense, Merge, Permute, Dropout
from keras.layers.recurrent import LSTM
from keras.datasets.data_utils import get_file
from keras.preprocessing.sequence import pad_sequences
from functools import reduce
import tarfile
import numpy as np
import re


def tokenize(sent):
@@ -200,4 +199,3 @@ def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
nb_epoch=70,
show_accuracy=True,
validation_data=([inputs_test, queries_test, inputs_test], answers_test))

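The body of vectorize_stories is collapsed in this view; the sketch below illustrates the general idea it implements — word-index lookup, left-padding to fixed story and query lengths, and a one-hot answer vector. The helper names and the tiny word index are illustrative assumptions, not the example's actual code:

import numpy as np

def _pad(seqs, maxlen):
    # left-pad with the reserved index 0 up to a fixed length
    return np.array([[0] * (maxlen - len(s)) + list(s) for s in seqs])

def vectorize_sketch(data, word_idx, story_maxlen, query_maxlen):
    '''Turn (story, query, answer) token triples into padded index arrays.'''
    X, Xq, Y = [], [], []
    for story, query, answer in data:
        X.append([word_idx[w] for w in story])
        Xq.append([word_idx[w] for w in query])
        y = np.zeros(len(word_idx) + 1)  # index 0 is reserved for padding
        y[word_idx[answer]] = 1
        Y.append(y)
    return _pad(X, story_maxlen), _pad(Xq, query_maxlen), np.array(Y)

word_idx = {'Mary': 1, 'moved': 2, 'bathroom': 3, 'Where': 4, 'is': 5}
data = [(['Mary', 'moved', 'bathroom'], ['Where', 'is', 'Mary'], 'bathroom')]
X, Xq, Y = vectorize_sketch(data, word_idx, story_maxlen=5, query_maxlen=3)
print(X.shape, Xq.shape, Y.shape)  # (1, 5) (1, 3) (1, 6)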
34 changes: 16 additions & 18 deletions examples/babi_rnn.py
@@ -1,21 +1,4 @@
from __future__ import absolute_import
from __future__ import print_function
from functools import reduce
import re
import tarfile

import numpy as np
np.random.seed(1337) # for reproducibility

from keras.datasets.data_utils import get_file
from keras.layers.embeddings import Embedding
from keras.layers.core import Dense, Merge
from keras.layers import recurrent
from keras.models import Sequential
from keras.preprocessing.sequence import pad_sequences

'''
Trains two recurrent neural networks based upon a story and a question.
'''Trains two recurrent neural networks based upon a story and a question.
The resulting merged vector is then queried to answer a range of bAbI tasks.
The results are comparable to those for an LSTM model provided in Weston et al.:
@@ -73,6 +56,21 @@
This becomes especially obvious on QA2 and QA3, both far longer than QA1.
'''

from __future__ import print_function
from functools import reduce
import re
import tarfile

import numpy as np
np.random.seed(1337) # for reproducibility

from keras.datasets.data_utils import get_file
from keras.layers.embeddings import Embedding
from keras.layers.core import Dense, Merge
from keras.layers import recurrent
from keras.models import Sequential
from keras.preprocessing.sequence import pad_sequences


def tokenize(sent):
'''Return the tokens of a sentence including punctuation.
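The tokenize helper shown above is truncated in the diff; a minimal sketch of a tokenizer that returns "the tokens of a sentence including punctuation" (a regex-based stand-in, assumed rather than copied from the example) could be:

import re

def tokenize_sketch(sent):
    '''Return the tokens of a sentence, keeping punctuation as separate tokens.'''
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]

print(tokenize_sketch('Mary moved to the bathroom.'))
# ['Mary', 'moved', 'to', 'the', 'bathroom', '.']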
44 changes: 21 additions & 23 deletions examples/cifar10_cnn.py
@@ -1,4 +1,16 @@
from __future__ import absolute_import
'''Train a simple deep CNN on the CIFAR10 small images dataset.
GPU run command:
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cifar10_cnn.py
It gets down to 0.65 test logloss in 25 epochs, and down to 0.55 after 50 epochs.
(it's still underfitting at that point, though).
Note: the data was pickled with Python 2, and some encoding issues might prevent you
from loading it in Python 3. You might have to load it in Python 2,
save it in a different format, load it in Python 3 and repickle it.
'''

from __future__ import print_function
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
@@ -9,20 +21,6 @@
from keras.utils import np_utils, generic_utils
from six.moves import range

'''
Train a (fairly simple) deep CNN on the CIFAR10 small images dataset.
GPU run command:
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python cifar10_cnn.py
It gets down to 0.65 test logloss in 25 epochs, and down to 0.55 after 50 epochs.
(it's still underfitting at that point, though).
Note: the data was pickled with Python 2, and some encoding issues might prevent you
from loading it in Python 3. You might have to load it in Python 2,
save it in a different format, load it in Python 3 and repickle it.
'''

batch_size = 32
nb_classes = 10
nb_epoch = 200
@@ -71,19 +69,19 @@
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd)

X_train = X_train.astype("float32")
X_test = X_test.astype("float32")
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

if not data_augmentation:
print("Not using data augmentation or normalization")
print('Not using data augmentation or normalization')
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch)
score = model.evaluate(X_test, Y_test, batch_size=batch_size)
print('Test score:', score)

else:
print("Using real time data augmentation")
print('Using real time data augmentation')

# this will do preprocessing and realtime data augmentation
datagen = ImageDataGenerator(
@@ -106,16 +104,16 @@
print('-'*40)
print('Epoch', e)
print('-'*40)
print("Training...")
print('Training...')
# batch train with realtime data augmentation
progbar = generic_utils.Progbar(X_train.shape[0])
for X_batch, Y_batch in datagen.flow(X_train, Y_train):
loss = model.train_on_batch(X_batch, Y_batch)
progbar.add(X_batch.shape[0], values=[("train loss", loss[0])])
progbar.add(X_batch.shape[0], values=[('train loss', loss[0])])

print("Testing...")
print('Testing...')
# test time!
progbar = generic_utils.Progbar(X_test.shape[0])
for X_batch, Y_batch in datagen.flow(X_test, Y_test):
score = model.test_on_batch(X_batch, Y_batch)
progbar.add(X_batch.shape[0], values=[("test loss", score[0])])
progbar.add(X_batch.shape[0], values=[('test loss', score[0])])
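The Python 2 pickling note in the docstring can usually be handled without the round trip it describes: Python 3's pickle.load accepts an encoding argument. The snippet below is a general-purpose sketch with a placeholder file path, not part of the Keras example itself:

import pickle

# CIFAR10 batches were pickled under Python 2; passing encoding='latin1'
# (or 'bytes') lets Python 3 read them without re-pickling in Python 2.
with open('cifar-10-batches-py/data_batch_1', 'rb') as f:  # placeholder path
    batch = pickle.load(f, encoding='latin1')
print(batch.keys())  # dict with 'data', 'labels', ...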
23 changes: 11 additions & 12 deletions examples/imdb_bidirectional_lstm.py
@@ -1,4 +1,12 @@
from __future__ import absolute_import
'''Train a Bidirectional LSTM on the IMDB sentiment classification task.
GPU command:
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python imdb_bidirectional_lstm.py
Output after 4 epochs on CPU: ~0.8146
Time per epoch on CPU (Core i7): ~150s.
'''

from __future__ import print_function
import numpy as np
np.random.seed(1337) # for reproducibility
@@ -11,21 +19,12 @@
from keras.layers.recurrent import LSTM
from keras.datasets import imdb

'''
Train a Bidirectional LSTM on the IMDB sentiment classification task.
GPU command:
THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python imdb_bidirectional_lstm.py
Output after 4 epochs on CPU: ~0.8146
Time per epoch on CPU (Core i7): ~150s.
'''

max_features = 20000
maxlen = 100 # cut texts after this number of words (among top max_features most common words)
batch_size = 32

print("Loading data...")
print('Loading data...')
(X_train, y_train), (X_test, y_test) = imdb.load_data(nb_words=max_features,
test_split=0.2)
print(len(X_train), 'train sequences')
@@ -53,7 +52,7 @@
# try using different optimizers and different optimizer configs
model.compile('adam', {'output': 'binary_crossentropy'})

print("Train...")
print('Train...')
model.fit({'input': X_train, 'output': y_train},
batch_size=batch_size,
nb_epoch=4)
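Conceptually, the bidirectional model runs one recurrent pass over the review and a second pass over the reversed review, then concatenates the two final states before the output layer. The numpy sketch below illustrates that merge with a plain tanh RNN standing in for the LSTMs (weights, sizes, and the shared-weight shortcut are illustrative assumptions; the example itself uses two separate LSTM layers):

import numpy as np

def simple_rnn(x, W, U):
    '''A plain tanh RNN over a (timesteps, features) sequence; returns the final state.'''
    h = np.zeros(W.shape[1])
    for x_t in x:
        h = np.tanh(x_t @ W + h @ U)
    return h

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 8))               # one padded review of 100 word vectors
W, U = rng.normal(size=(8, 16)), rng.normal(size=(16, 16))

forward = simple_rnn(x, W, U)               # reads the review left to right
backward = simple_rnn(x[::-1], W, U)        # same weights here for brevity only
merged = np.concatenate([forward, backward])  # what the 'concat' merge produces
print(merged.shape)                         # (32,)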
26 changes: 12 additions & 14 deletions examples/imdb_cnn.py
@@ -1,4 +1,10 @@
from __future__ import absolute_import
'''This example demonstrates the use of Convolution1D for text classification.
Run on GPU: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python imdb_cnn.py
Get to 0.835 test accuracy after 2 epochs. 100s/epoch on K520 GPU.
'''

from __future__ import print_function
import numpy as np
np.random.seed(1337) # for reproducibility
@@ -10,14 +16,6 @@
from keras.layers.convolutional import Convolution1D, MaxPooling1D
from keras.datasets import imdb

'''
This example demonstrates the use of Convolution1D
for text classification.
Run on GPU: THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python imdb_cnn.py
Get to 0.835 test accuracy after 2 epochs. 100s/epoch on K520 GPU.
'''

# set parameters:
max_features = 5000
@@ -29,13 +27,13 @@
hidden_dims = 250
nb_epoch = 2

print("Loading data...")
print('Loading data...')
(X_train, y_train), (X_test, y_test) = imdb.load_data(nb_words=max_features,
test_split=0.2)
print(len(X_train), 'train sequences')
print(len(X_test), 'test sequences')

print("Pad sequences (samples x time)")
print('Pad sequences (samples x time)')
X_train = sequence.pad_sequences(X_train, maxlen=maxlen)
X_test = sequence.pad_sequences(X_test, maxlen=maxlen)
print('X_train shape:', X_train.shape)
@@ -53,8 +51,8 @@
# word group filters of size filter_length:
model.add(Convolution1D(nb_filter=nb_filter,
filter_length=filter_length,
border_mode="valid",
activation="relu",
border_mode='valid',
activation='relu',
subsample_length=1))
# we use standard max pooling (halving the output of the previous layer):
model.add(MaxPooling1D(pool_length=2))
@@ -74,7 +72,7 @@

model.compile(loss='binary_crossentropy',
optimizer='rmsprop',
class_mode="binary")
class_mode='binary')
model.fit(X_train, y_train, batch_size=batch_size,
nb_epoch=nb_epoch, show_accuracy=True,
validation_data=(X_test, y_test))
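The Convolution1D layer slides nb_filter detectors of width filter_length over groups of adjacent word embeddings, using 'valid' border mode and a relu activation, and the MaxPooling1D that follows halves the number of time steps. A minimal numpy sketch of those two operations (shapes are illustrative assumptions, not the example's exact dimensions):

import numpy as np

def conv1d_valid(x, filters):
    '''x: (timesteps, embed_dim); filters: (nb_filter, filter_length, embed_dim).'''
    nb_filter, filter_length, _ = filters.shape
    steps = x.shape[0] - filter_length + 1          # 'valid' border mode
    out = np.zeros((steps, nb_filter))
    for t in range(steps):
        window = x[t:t + filter_length]             # one group of adjacent words
        out[t] = np.maximum((filters * window).sum(axis=(1, 2)), 0)  # relu
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 50))                      # maxlen=100 words, 50-dim embeddings
filters = rng.normal(size=(250, 3, 50))             # nb_filter=250, filter_length=3
conv = conv1d_valid(x, filters)                     # (98, 250)
pooled = conv[:conv.shape[0] // 2 * 2].reshape(-1, 2, 250).max(axis=1)  # pool_length=2
print(conv.shape, pooled.shape)                     # (98, 250) (49, 250)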