Skip to content

Commit

Permalink
Update documentation to new API.
Browse files Browse the repository at this point in the history
  • Loading branch information
matsuyamax committed Oct 5, 2015
1 parent cb77f7d commit 0b8a52e
Show file tree
Hide file tree
Showing 17 changed files with 281 additions and 232 deletions.
128 changes: 75 additions & 53 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD

model = Sequential()
model.add(Dense(20, 64, init='uniform'))
# Dense(64) is a fully-connected layer with 64 hidden units.
# in the first layer, you must specify the expected input data shape:
# here, 20-dimensional vectors.
model.add(Dense(64, input_dim=20, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(64, 64, init='uniform'))
model.add(Dense(64, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(64, 2, init='uniform'))
model.add(Dense(2, init='uniform'))
model.add(Activation('softmax'))

sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
Expand All @@ -54,11 +57,11 @@ score = model.evaluate(X_test, y_test, batch_size=16)

```python
model = Sequential()
model.add(Dense(20, 64, init='uniform', activation='tanh'))
model.add(Dense(64, input_dim=20, init='uniform', activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(64, 64, init='uniform', activation='tanh'))
model.add(Dense(64, init='uniform', activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(64, 2, init='uniform', activation='softmax'))
model.add(Dense(2, init='uniform', activation='softmax'))

sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)
Expand All @@ -73,26 +76,29 @@ from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.optimizers import SGD

model = Sequential()
model.add(Convolution2D(32, 3, 3, 3, border_mode='full'))
# input: 100x100 images with 3 channels -> (3, 100, 100) tensors.
# this applies 32 convolution filters of size 3x3 each.
model.add(Convolution2D(32, 3, 3, border_mode='full', input_shape=(3, 100, 100)))
model.add(Activation('relu'))
model.add(Convolution2D(32, 32, 3, 3))
model.add(Convolution2D(32, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(poolsize=(2, 2)))
model.add(Dropout(0.25))

model.add(Convolution2D(64, 32, 3, 3, border_mode='full'))
model.add(Convolution2D(64, 3, 3, border_mode='valid'))
model.add(Activation('relu'))
model.add(Convolution2D(64, 64, 3, 3))
model.add(Convolution2D(64, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(poolsize=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(64*8*8, 256))
# Note: Keras does automatic shape inference.
model.add(Dense(256))
model.add(Activation('relu'))
model.add(Dropout(0.5))

model.add(Dense(256, 10))
model.add(Dense(10))
model.add(Activation('softmax'))

sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
Expand All @@ -112,9 +118,9 @@ from keras.layers.recurrent import LSTM

model = Sequential()
model.add(Embedding(max_features, 256))
model.add(LSTM(256, 128, activation='sigmoid', inner_activation='hard_sigmoid'))
model.add(LSTM(output_dim=128, activation='sigmoid', inner_activation='hard_sigmoid'))
model.add(Dropout(0.5))
model.add(Dense(128, 1))
model.add(Dense(1))
model.add(Activation('sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='rmsprop')
Expand All @@ -126,51 +132,67 @@ score = model.evaluate(X_test, Y_test, batch_size=16)
### Architecture for learning image captions with a convnet and a Gated Recurrent Unit:
(word-level embedding, caption of maximum length 16 words).

Note that getting this to actually "work" will require using a bigger convnet, initialized with pre-trained weights.
Displaying readable results will also require an embedding decoder.
Note that getting this to work well will require using a bigger convnet, initialized with pre-trained weights.

```python
max_caption_len = 16
vocab_size = 10000

# first, let's define an image model that
# will encode pictures into 128-dimensional vectors.
# it should be initialized with pre-trained weights.
image_model = Sequential()
image_model.add(Convolution2D(32, 3, 3, border_mode='full', input_shape=(3, 100, 100)))
image_model.add(Activation('relu'))
image_model.add(Convolution2D(32, 3, 3))
image_model.add(Activation('relu'))
image_model.add(MaxPooling2D(poolsize=(2, 2)))

image_model.add(Convolution2D(64, 3, 3, border_mode='full'))
image_model.add(Activation('relu'))
image_model.add(Convolution2D(64, 3, 3))
image_model.add(Activation('relu'))
image_model.add(MaxPooling2D(poolsize=(2, 2)))

image_model.add(Flatten())
image_model.add(Dense(128))

# let's load the weights from a save file.
image_model.load_weights('weight_file.h5')

# next, let's define a RNN model that encodes sequences of words
# into sequences of 128-dimensional word vectors.
language_model = Sequential()
language_model.add(Embedding(vocab_size, 256, input_length=max_caption_len))
language_model.add(GRU(output_dim=128, return_sequences=True))
language_model.add(Dense(128))

# let's repeat the image vector to turn it into a sequence.
image_model.add(RepeatVector(max_caption_len))

# the output of both models will be tensors of shape (samples, max_caption_len, 128).
# let's concatenate these 2 vector sequences.
model = Merge([image_model, language_model], mode='concat', concat_axis=-1)
# let's encode this vector sequence into a single vector
model.add(GRU(256, 256, return_sequences=False))
# which will be used to compute a probability
# distribution over what the next word in the caption should be!
model.add(Dense(vocab_size))
model.add(Activation('softmax'))

model = Sequential()
model.add(Convolution2D(32, 3, 3, 3, border_mode='full'))
model.add(Activation('relu'))
model.add(Convolution2D(32, 32, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(poolsize=(2, 2)))

model.add(Convolution2D(64, 32, 3, 3, border_mode='full'))
model.add(Activation('relu'))
model.add(Convolution2D(64, 64, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(poolsize=(2, 2)))

model.add(Convolution2D(128, 64, 3, 3, border_mode='full'))
model.add(Activation('relu'))
model.add(Convolution2D(128, 128, 3, 3))
model.add(Activation('relu'))
model.add(MaxPooling2D(poolsize=(2, 2)))

model.add(Flatten())
model.add(Dense(128*4*4, 256))
model.add(Activation('relu'))
model.add(Dropout(0.5))

model.add(RepeatVector(max_caption_len))
# the GRU below returns sequences of max_caption_len vectors of size 256 (our word embedding size)
model.add(GRU(256, 256, return_sequences=True))

model.compile(loss='mean_squared_error', optimizer='rmsprop')
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# "images" is a numpy array of shape (nb_samples, nb_channels=3, width, height)
# "captions" is a numpy array of shape (nb_samples, max_caption_len=16, embedding_dim=256)
# captions are supposed already embedded (dense vectors).
model.fit(images, captions, batch_size=16, nb_epoch=100)

# "images" is a numpy float array of shape (nb_samples, nb_channels=3, width, height).
# "captions" is a numpy integer array of shape (nb_samples, max_caption_len)
# containing word index sequences representing partial captions.
# "next_words" is a numpy float array of shape (nb_samples, vocab_size)
# containing a categorical encoding (0s and 1s) of the next word in the corresponding
# partial caption.
model.fit([images, partial_captions], next_words, batch_size=16, nb_epoch=100)
```

In the examples folder, you will find example models for real datasets:
- CIFAR10 small images classification: Convnet with realtime data augmentation
- CIFAR10 small images classification: Convolutional Neural Network (CNN) with realtime data augmentation
- IMDB movie review sentiment classification: LSTM over sequences of words
- Reuters newswires topic classification: Multilayer Perceptron (MLP)
- MNIST handwritten digits classification: MLP & CNN
Expand All @@ -183,7 +205,7 @@ In the examples folder, you will find example models for real datasets:

For complete coverage of the API, check out [the Keras documentation](http://keras.io).

A few highlights: convnets, LSTM, GRU, word2vec-style embeddings, PReLU, batch normalization...
A few highlights: convnets, LSTM, GRU, word2vec-style embeddings, PReLU, BatchNormalization...

## Installation

Expand All @@ -196,7 +218,7 @@ Keras uses the following dependencies:
- HDF5 and h5py (optional, required if you use model saving/loading functions)
- Optional but recommended if you use CNNs: cuDNN.

Once you have the dependencies installed, cd to the Keras folder and run the install command:
To install, `cd` to the Keras folder and run the install command:
```
sudo python setup.py install
```
Expand Down
6 changes: 3 additions & 3 deletions docs/sources/activations.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ Activations can either be used through an `Activation` layer, or through the `ac
```python
from keras.layers.core import Activation, Dense

model.add(Dense(64, 64, init='uniform'))
model.add(Dense(64))
model.add(Activation('tanh'))
```
is equivalent to:
```python
model.add(Dense(20, 64, init='uniform', activation='tanh'))
model.add(Dense(64, activation='tanh'))
```

You can also pass an element-wise Theano function as an activation:
Expand All @@ -20,7 +20,7 @@ You can also pass an element-wise Theano function as an activation:
def tanh(x):
return theano.tensor.tanh(x)

model.add(Dense(20, 64, init='uniform', activation=tanh))
model.add(Dense(64, activation=tanh))
model.add(Activation(tanh))
```

Expand Down
4 changes: 2 additions & 2 deletions docs/sources/callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class LossHistory(keras.callbacks.Callback):
self.losses.append(logs.get('loss'))

model = Sequential()
model.add(Dense(784, 10, init='uniform'))
model.add(Dense(10, input_dim=784, init='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

Expand All @@ -97,7 +97,7 @@ print history.losses
from keras.callbacks import ModelCheckpoint

model = Sequential()
model.add(Dense(784, 10, init='uniform'))
model.add(Dense(10, input_dim=784, init='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

Expand Down
2 changes: 1 addition & 1 deletion docs/sources/constraints.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ These layers expose 2 keyword arguments:

```python
from keras.constraints import maxnorm
model.add(Dense(64, 64, W_constraint = maxnorm(2)))
model.add(Dense(64, W_constraint = maxnorm(2)))
```

## Available constraints
Expand Down
Loading

0 comments on commit 0b8a52e

Please sign in to comment.