"""
Title: Working with RNNs
Authors: Scott Zhu, Francois Chollet
Date created: 2019/07/08
Last modified: 2023/07/10
Description: Complete guide to using & customizing RNN layers.
Accelerator: GPU
"""
"""
## Introduction
Recurrent neural networks (RNN) are a class of neural networks that is powerful for
modeling sequence data such as time series or natural language.
Schematically, a RNN layer uses a `for` loop to iterate over the timesteps of a
sequence, while maintaining an internal state that encodes information about the
timesteps it has seen so far.
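
As a rough illustration of that loop, here is a minimal pure-Python sketch of a simple recurrent layer processing one sample. It is not how Keras implements its layers: the function name `simple_rnn_forward` and the weight names `w_x`, `w_h` and `b` are made up for this example, and a `tanh` update is assumed.

```python
import math


def simple_rnn_forward(sequence, w_x, w_h, b):
    # Hypothetical helper: processes ONE sample, given as a list of timesteps,
    # where each timestep is a list of `input_dim` floats.
    # w_x has shape (input_dim, units), w_h has shape (units, units),
    # and b has length `units`.
    units = len(b)
    state = [0.0] * units  # the internal state starts at zero
    outputs = []
    for x_t in sequence:  # the `for` loop over timesteps
        # The new state mixes the current input with the previous state.
        state = [
            math.tanh(
                sum(x_t[i] * w_x[i][j] for i in range(len(x_t)))
                + sum(state[k] * w_h[k][j] for k in range(units))
                + b[j]
            )
            for j in range(units)
        ]
        outputs.append(state)
    return outputs, state


sequence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 timesteps, input_dim=2
w_x = [[0.5, -0.5, 0.1], [0.2, 0.3, -0.1]]  # (input_dim=2, units=3)
w_h = [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0], [0.0, 0.0, 0.1]]  # (units, units)
b = [0.0, 0.0, 0.0]
outputs, final_state = simple_rnn_forward(sequence, w_x, w_h, b)
print(len(outputs), len(final_state))
```

The list `outputs` holds one vector per timestep, while `final_state` is the vector produced at the last timestep.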
The Keras RNN API is designed with a focus on:
- **Ease of use**: the built-in `keras.layers.RNN`, `keras.layers.LSTM`,
`keras.layers.GRU` layers enable you to quickly build recurrent models without
having to make difficult configuration choices.
- **Ease of customization**: You can also define your own RNN cell layer (the inner
part of the `for` loop) with custom behavior, and use it with the generic
`keras.layers.RNN` layer (the `for` loop itself). This allows you to quickly
prototype different research ideas in a flexible way with minimal code.
"""
"""
## Setup
"""
import numpy as np
import tensorflow as tf
import keras
from keras import layers
"""
## Built-in RNN layers: a simple example
"""
"""
There are three built-in RNN layers in Keras:
1. `keras.layers.SimpleRNN`, a fully-connected RNN where the output from the previous
timestep is fed to the next timestep.
2. `keras.layers.GRU`, first proposed in
[Cho et al., 2014](https://arxiv.org/abs/1406.1078).
3. `keras.layers.LSTM`, first proposed in
[Hochreiter & Schmidhuber, 1997](https://www.bioinf.jku.at/publications/older/2604.pdf).
In early 2015, Keras had the first reusable open-source Python implementations of LSTM
and GRU.
Here is a simple example of a `Sequential` model that processes sequences of integers,
embeds each integer into a 64-dimensional vector, then processes the sequence of
vectors using a `LSTM` layer.
"""
model = keras.Sequential()
# Add an Embedding layer expecting input vocab of size 1000, and
# output embedding dimension of size 64.
model.add(layers.Embedding(input_dim=1000, output_dim=64))
# Add a LSTM layer with 128 internal units.
model.add(layers.LSTM(128))
# Add a Dense layer with 10 units.
model.add(layers.Dense(10))
model.summary()
"""
Built-in RNNs support a number of useful features:
- Recurrent dropout, via the `dropout` and `recurrent_dropout` arguments
- Ability to process an input sequence in reverse, via the `go_backwards` argument
- Loop unrolling (which can lead to a large speedup when processing short sequences on
CPU), via the `unroll` argument
- ...and more.
For more information, see the
[RNN API documentation](https://keras.io/api/layers/recurrent_layers/).
"""
"""
## Outputs and states
By default, the output of a RNN layer contains a single vector per sample. This vector
is the RNN cell output corresponding to the last timestep, containing information
about the entire input sequence. The shape of this output is `(batch_size, units)`
where `units` corresponds to the `units` argument passed to the layer's constructor.
A RNN layer can also return the entire sequence of outputs for each sample (one vector
per timestep per sample), if you set `return_sequences=True`. The shape of this output
is `(batch_size, timesteps, units)`.
"""
model = keras.Sequential()
model.add(layers.Embedding(input_dim=1000, output_dim=64))
# The output of GRU will be a 3D tensor of shape (batch_size, timesteps, 256)
model.add(layers.GRU(256, return_sequences=True))
# The output of SimpleRNN will be a 2D tensor of shape (batch_size, 128)
model.add(layers.SimpleRNN(128))
model.add(layers.Dense(10))
model.summary()
"""
In addition, a RNN layer can return its final internal state(s). The returned states
can be used to resume the RNN execution later, or
[to initialize another RNN](https://arxiv.org/abs/1409.3215).
This setting is commonly used in the
encoder-decoder sequence-to-sequence model, where the encoder final state is used as
the initial state of the decoder.
To configure a RNN layer to return its internal state, set the `return_state` parameter
to `True` when creating the layer. Note that `LSTM` has 2 state tensors, but `GRU`
only has one.
To configure the initial state of the layer, call the layer with the additional
keyword argument `initial_state`.
Note that the shape of the state needs to match the unit size of the layer, like in the
example below.
"""
encoder_vocab = 1000
decoder_vocab = 2000
encoder_input = layers.Input(shape=(None,))
encoder_embedded = layers.Embedding(input_dim=encoder_vocab, output_dim=64)(
    encoder_input
)
# Return states in addition to output
output, state_h, state_c = layers.LSTM(64, return_state=True, name="encoder")(
    encoder_embedded
)
encoder_state = [state_h, state_c]
decoder_input = layers.Input(shape=(None,))
decoder_embedded = layers.Embedding(input_dim=decoder_vocab, output_dim=64)(
    decoder_input
)
# Pass the 2 states to a new LSTM layer, as initial state
decoder_output = layers.LSTM(64, name="decoder")(
    decoder_embedded, initial_state=encoder_state
)
output = layers.Dense(10)(decoder_output)
model = keras.Model([encoder_input, decoder_input], output)
model.summary()
"""
## RNN layers and RNN cells
In addition to the built-in RNN layers, the RNN API also provides cell-level APIs.
Unlike RNN layers, which process whole batches of input sequences, the RNN cell only
processes a single timestep.
The cell is the inside of the `for` loop of a RNN layer. Wrapping a cell inside a
`keras.layers.RNN` layer gives you a layer capable of processing batches of
sequences, e.g. `RNN(LSTMCell(10))`.
Mathematically, `RNN(LSTMCell(10))` produces the same result as `LSTM(10)`. In fact,
the implementation of this layer in TF v1.x was just creating the corresponding RNN
cell and wrapping it in a RNN layer. However, using the built-in `GRU` and `LSTM`
layers enables the use of CuDNN, and you may see better performance.
There are three built-in RNN cells, each of them corresponding to the matching RNN
layer.
- `keras.layers.SimpleRNNCell` corresponds to the `SimpleRNN` layer.
- `keras.layers.GRUCell` corresponds to the `GRU` layer.
- `keras.layers.LSTMCell` corresponds to the `LSTM` layer.
The cell abstraction, together with the generic `keras.layers.RNN` class, makes it
very easy to implement custom RNN architectures for your research.
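
To make the split between cell and layer concrete, the sketch below mimics it in plain Python. `RunningSumCell` and `run_rnn` are hypothetical names invented for this illustration, not Keras classes; the only thing borrowed from Keras is the cell contract, where `call(inputs, states)` returns `(output, new_states)`.

```python
class RunningSumCell:
    # Toy cell: the state is a running sum of the inputs seen so far.
    def call(self, inputs, states):
        (total,) = states
        new_total = total + inputs
        output = new_total
        return output, [new_total]


def run_rnn(cell, sequence, initial_states):
    # Plays the role of the generic RNN layer: it owns the loop, calls the
    # cell once per timestep, and threads the states through.
    states = initial_states
    outputs = []
    for x_t in sequence:
        output, states = cell.call(x_t, states)
        outputs.append(output)
    return outputs, states


outputs, final_states = run_rnn(RunningSumCell(), [1.0, 2.0, 3.0], [0.0])
print(outputs, final_states)
```

Swapping in a different cell changes the per-timestep math without touching the loop, which is the point of the abstraction.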
"""
"""
## Cross-batch statefulness
When processing very long sequences (possibly infinite), you may want to use the
pattern of **cross-batch statefulness**.
Normally, the internal state of a RNN layer is reset every time it sees a new batch
(i.e. every sample seen by the layer is assumed to be independent of the past). The
layer will only maintain a state while processing a given sample.
If you have very long sequences though, it is useful to break them into shorter
sequences, and to feed these shorter sequences sequentially into a RNN layer without
resetting the layer's state. That way, the layer can retain information about the
entirety of the sequence, even though it's only seeing one sub-sequence at a time.
You can do this by setting `stateful=True` in the constructor.
If you have a sequence `s = [t0, t1, ... t1546, t1547]`, you would split it into e.g.
```
s1 = [t0, t1, ... t100]
s2 = [t101, ... t200]
...
s16 = [t1501, ... t1547]
```
Then you would process it via:
```python
lstm_layer = layers.LSTM(64, stateful=True)
for s in sub_sequences:
    output = lstm_layer(s)
```
When you want to clear the state, you can use `layer.reset_states()`.
> Note: In this setup, sample `i` in a given batch is assumed to be the continuation of
sample `i` in the previous batch. This means that all batches should contain the same
number of samples (batch size). E.g. if a batch contains `[sequence_A_from_t0_to_t100,
sequence_B_from_t0_to_t100]`, the next batch should contain
`[sequence_A_from_t101_to_t200, sequence_B_from_t101_to_t200]`.
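
The following plain-Python sketch illustrates why carrying the state across sub-sequences matters. The recurrence in `run_chunk` (a decaying running sum) and the chunk length of 100 are arbitrary choices for the illustration, and `split_into_chunks` is a hypothetical helper, not part of Keras.

```python
def split_into_chunks(sequence, chunk_len):
    # Cut a long sequence into consecutive sub-sequences; the last one may be shorter.
    return [sequence[i : i + chunk_len] for i in range(0, len(sequence), chunk_len)]


def run_chunk(chunk, state):
    # Toy recurrence standing in for an RNN layer: returns the updated state.
    for x_t in chunk:
        state = 0.5 * state + x_t
    return state


sequence = [float(t % 7) for t in range(1548)]  # t0 ... t1547
chunks = split_into_chunks(sequence, chunk_len=100)

# "Stateful" processing: the state left by one chunk is the start of the next.
state = 0.0
for chunk in chunks:
    state = run_chunk(chunk, state)

# Same result as processing the whole sequence in one go.
full_state = run_chunk(sequence, 0.0)
print(len(chunks), state == full_state)
```

Resetting `state` to zero before each chunk would instead treat every sub-sequence as an independent sample, which is what a non-stateful layer does.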
Here is a complete example:
"""
paragraph1 = np.random.random((20, 10, 50)).astype(np.float32)
paragraph2 = np.random.random((20, 10, 50)).astype(np.float32)
paragraph3 = np.random.random((20, 10, 50)).astype(np.float32)
lstm_layer = layers.LSTM(64, stateful=True)
output = lstm_layer(paragraph1)
output = lstm_layer(paragraph2)
output = lstm_layer(paragraph3)
# reset_states() will reset the cached state to the original initial_state.
# If no initial_state was provided, zero-states will be used by default.
lstm_layer.reset_states()
"""
### RNN State Reuse
<a id="rnn_state_reuse"></a>
"""
"""
The recorded states of the RNN layer are not included in `layer.weights`. If you
would like to reuse the state from a RNN layer, you can retrieve the state values via
`layer.states` and use them as the
initial state for a new layer via the Keras functional API, like `new_layer(inputs,
initial_state=layer.states)`, or via model subclassing.
Please also note that a `Sequential` model cannot be used in this case, since it only
supports layers with a single input and output; the extra initial-state input makes
it impossible to use here.
"""
paragraph1 = np.random.random((20, 10, 50)).astype(np.float32)
paragraph2 = np.random.random((20, 10, 50)).astype(np.float32)
paragraph3 = np.random.random((20, 10, 50)).astype(np.float32)
lstm_layer = layers.LSTM(64, stateful=True)
output = lstm_layer(paragraph1)
output = lstm_layer(paragraph2)
existing_state = lstm_layer.states
new_lstm_layer = layers.LSTM(64)
new_output = new_lstm_layer(paragraph3, initial_state=existing_state)
"""
## Bidirectional RNNs
For sequences other than time series (e.g. text), it is often the case that a RNN model
can perform better if it not only processes the sequence from start to end, but also
backwards. For example, to predict the next word in a sentence, it is often useful to
have the context around the word, not just the words that come before it.
Keras provides an easy API for you to build such bidirectional RNNs: the
`keras.layers.Bidirectional` wrapper.
"""
model = keras.Sequential()
model.add(
    layers.Bidirectional(layers.LSTM(64, return_sequences=True), input_shape=(5, 10))
)
model.add(layers.Bidirectional(layers.LSTM(32)))
model.add(layers.Dense(10))
model.summary()
"""
Under the hood, `Bidirectional` will copy the RNN layer passed in, and flip the
`go_backwards` field of the newly copied layer, so that it will process the inputs in
reverse order.
The output of the `Bidirectional` RNN will be, by default, the concatenation of the forward layer
output and the backward layer output. If you need a different merging behavior, e.g.
summation, change the `merge_mode` parameter in the `Bidirectional` wrapper
constructor. For more details about `Bidirectional`, please check
[the API docs](https://keras.io/api/layers/recurrent_layers/bidirectional/).
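
The effect of `merge_mode` can be illustrated without Keras. In the sketch below, `forward_out` and `backward_out` stand in for the final outputs of the forward and backward layers for one sample, and `merge` is a hypothetical helper that covers only the `concat` and `sum` modes.

```python
def merge(forward_out, backward_out, merge_mode="concat"):
    # Hypothetical helper mimicking two of the merge modes of `Bidirectional`.
    if merge_mode == "concat":
        return forward_out + backward_out  # list concatenation: 2 * units values
    if merge_mode == "sum":
        return [f + b for f, b in zip(forward_out, backward_out)]  # units values
    raise ValueError(f"Unsupported merge_mode in this sketch: {merge_mode}")


forward_out = [0.1, 0.2, 0.3]  # output of the layer reading start -> end
backward_out = [1.0, 2.0, 3.0]  # output of the copy reading end -> start
print(merge(forward_out, backward_out, "concat"))
print(merge(forward_out, backward_out, "sum"))
```

Note that concatenation doubles the size of the last dimension, whereas summation keeps it equal to `units`.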
"""
"""
## Performance optimization and CuDNN kernels
In TensorFlow 2.0, the built-in LSTM and GRU layers have been updated to leverage CuDNN
kernels by default when a GPU is available. With this change, the prior
`keras.layers.CuDNNLSTM/CuDNNGRU` layers have been deprecated, and you can build your
model without worrying about the hardware it will run on.
Since the CuDNN kernel is built with certain assumptions, this means the layer **will
not be able to use the CuDNN kernel if you change the defaults of the built-in LSTM or
GRU layers**. E.g.:
- Changing the `activation` function from `tanh` to something else.
- Changing the `recurrent_activation` function from `sigmoid` to something else.
- Using `recurrent_dropout` > 0.
- Setting `unroll` to True, which forces LSTM/GRU to decompose the inner
`tf.while_loop` into an unrolled `for` loop.
- Setting `use_bias` to False.
- Using masking when the input data is not strictly right padded (if the mask
corresponds to strictly right padded data, CuDNN can still be used. This is the most
common case).
For the detailed list of constraints, please see the documentation for the
[LSTM](https://keras.io/api/layers/recurrent_layers/lstm/) and
[GRU](https://keras.io/api/layers/recurrent_layers/gru/) layers.
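
The conditions listed above can be summarized as a small checklist. The functions below are hypothetical helpers written for illustration, not a Keras API: they only encode the bullet points above and do not inspect an actual layer. `is_right_padded` assumes a per-sample mask given as a list of booleans.

```python
def is_right_padded(mask):
    # True if all `True` entries come before all `False` entries.
    seen_padding = False
    for keep in mask:
        if not keep:
            seen_padding = True
        elif seen_padding:
            return False  # a real timestep appears after padding
    return True


def meets_cudnn_requirements(
    activation="tanh",
    recurrent_activation="sigmoid",
    recurrent_dropout=0.0,
    unroll=False,
    use_bias=True,
    mask=None,
):
    # Each clause corresponds to one bullet point in the list above.
    return (
        activation == "tanh"
        and recurrent_activation == "sigmoid"
        and recurrent_dropout == 0
        and not unroll
        and use_bias
        and (mask is None or is_right_padded(mask))
    )


print(meets_cudnn_requirements())
print(meets_cudnn_requirements(recurrent_dropout=0.2))
print(meets_cudnn_requirements(mask=[True, False, True]))
```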
"""
"""
### Using CuDNN kernels when available
Let's build a simple LSTM model to demonstrate the performance difference.
We'll use as input sequences the sequence of rows of MNIST digits (treating each row of
pixels as a timestep), and we'll predict the digit's label.
"""
batch_size = 64
# Each MNIST image batch is a tensor of shape (batch_size, 28, 28).
# Each input sequence will be of size (28, 28) (height is treated like time).
input_dim = 28
units = 64
output_size = 10 # labels are from 0 to 9
# Build the RNN model
def build_model(allow_cudnn_kernel=True):
    # CuDNN is only available at the layer level, and not at the cell level.
    # This means `LSTM(units)` will use the CuDNN kernel,
    # while RNN(LSTMCell(units)) will run on non-CuDNN kernel.
    if allow_cudnn_kernel:
        # The LSTM layer with default options uses CuDNN.
        lstm_layer = keras.layers.LSTM(units, input_shape=(None, input_dim))
    else:
        # Wrapping a LSTMCell in a RNN layer will not use CuDNN.
        lstm_layer = keras.layers.RNN(
            keras.layers.LSTMCell(units), input_shape=(None, input_dim)
        )
    model = keras.models.Sequential(
        [
            lstm_layer,
            keras.layers.BatchNormalization(),
            keras.layers.Dense(output_size),
        ]
    )
    return model
"""
Let's load the MNIST dataset:
"""
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
sample, sample_label = x_train[0], y_train[0]
"""
Let's create a model instance and train it.
We choose `sparse_categorical_crossentropy` as the loss function for the model. The
output of the model has shape of `[batch_size, 10]`. The target for the model is an
integer vector, where each integer is in the range of 0 to 9.
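
For reference, the per-sample loss with `from_logits=True` can be written out by hand. The sketch below is a plain-Python illustration for a single sample, not the Keras implementation, and the function name is made up for the example.

```python
import math


def sparse_categorical_crossentropy_from_logits(logits, label):
    # Numerically stable log-sum-exp: subtract the max logit before exponentiating.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    # The loss is the negative log-probability assigned to the true class.
    return log_sum_exp - logits[label]


uniform_logits = [0.0] * 10  # a model that is equally unsure about all 10 digits
print(sparse_categorical_crossentropy_from_logits(uniform_logits, label=3))
```

With uniform logits over 10 classes the loss equals `log(10)`, roughly 2.30, which is a useful baseline when reading the first training logs.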
"""
model = build_model(allow_cudnn_kernel=True)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="sgd",
    metrics=["accuracy"],
)
model.fit(
    x_train, y_train, validation_data=(x_test, y_test), batch_size=batch_size, epochs=1
)
"""
Now, let's compare to a model that does not use the CuDNN kernel:
"""
noncudnn_model = build_model(allow_cudnn_kernel=False)
noncudnn_model.set_weights(model.get_weights())
noncudnn_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="sgd",
    metrics=["accuracy"],
)
noncudnn_model.fit(
    x_train, y_train, validation_data=(x_test, y_test), batch_size=batch_size, epochs=1
)
"""
When running on a machine with an NVIDIA GPU and CuDNN installed,
the model built with CuDNN is much faster to train compared to the
model that uses the regular TensorFlow kernel.
The same CuDNN-enabled model can also be used to run inference in a CPU-only
environment. The `tf.device` annotation below is just forcing the device placement.
The model will run on CPU by default if no GPU is available.
You simply don't have to worry about the hardware you're running on anymore. Isn't that
pretty cool?
"""
import matplotlib.pyplot as plt
with tf.device("CPU:0"):
    cpu_model = build_model(allow_cudnn_kernel=True)
    cpu_model.set_weights(model.get_weights())
    result = tf.argmax(cpu_model.predict_on_batch(tf.expand_dims(sample, 0)), axis=1)
    print(
        "Predicted result is: %s, target result is: %s" % (result.numpy(), sample_label)
    )
    plt.imshow(sample, cmap=plt.get_cmap("gray"))
"""
## RNNs with list/dict inputs, or nested inputs
Nested structures allow implementers to include more information within a single
timestep. For example, a video frame could have audio and video input at the same
time. The data shape in this case could be:
`[batch, timestep, {"video": [height, width, channel], "audio": [frequency]}]`
In another example, handwriting data could have both coordinates x and y for the
current position of the pen, as well as pressure information. So the data
representation could be:
`[batch, timestep, {"location": [x, y], "pressure": [force]}]`
The following code provides an example of how to build a custom RNN cell that accepts
such structured inputs.
"""
"""
### Define a custom cell that supports nested input/output
"""
"""
See [Making new Layers & Models via subclassing](/guides/making_new_layers_and_models_via_subclassing/)
for details on writing your own layers.
"""
@keras.saving.register_keras_serializable()
class NestedCell(keras.layers.Layer):
    def __init__(self, unit_1, unit_2, unit_3, **kwargs):
        self.unit_1 = unit_1
        self.unit_2 = unit_2
        self.unit_3 = unit_3
        self.state_size = [tf.TensorShape([unit_1]), tf.TensorShape([unit_2, unit_3])]
        self.output_size = [tf.TensorShape([unit_1]), tf.TensorShape([unit_2, unit_3])]
        super().__init__(**kwargs)

    def build(self, input_shapes):
        # expect input_shape to contain 2 items, [(batch, i1), (batch, i2, i3)]
        i1 = input_shapes[0][1]
        i2 = input_shapes[1][1]
        i3 = input_shapes[1][2]
        self.kernel_1 = self.add_weight(
            shape=(i1, self.unit_1), initializer="uniform", name="kernel_1"
        )
        self.kernel_2_3 = self.add_weight(
            shape=(i2, i3, self.unit_2, self.unit_3),
            initializer="uniform",
            name="kernel_2_3",
        )

    def call(self, inputs, states):
        # inputs should be in [(batch, input_1), (batch, input_2, input_3)]
        # state should be in shape [(batch, unit_1), (batch, unit_2, unit_3)]
        input_1, input_2 = tf.nest.flatten(inputs)
        s1, s2 = states
        output_1 = tf.matmul(input_1, self.kernel_1)
        output_2_3 = tf.einsum("bij,ijkl->bkl", input_2, self.kernel_2_3)
        state_1 = s1 + output_1
        state_2_3 = s2 + output_2_3
        output = (output_1, output_2_3)
        new_states = (state_1, state_2_3)
        return output, new_states

    def get_config(self):
        return {"unit_1": self.unit_1, "unit_2": self.unit_2, "unit_3": self.unit_3}
"""
### Build a RNN model with nested input/output
Let's build a Keras model that uses a `keras.layers.RNN` layer and the custom cell
we just defined.
"""
unit_1 = 10
unit_2 = 20
unit_3 = 30
i1 = 32
i2 = 64
i3 = 32
batch_size = 64
num_batches = 10
timestep = 50
cell = NestedCell(unit_1, unit_2, unit_3)
rnn = keras.layers.RNN(cell)
input_1 = keras.Input((None, i1))
input_2 = keras.Input((None, i2, i3))
outputs = rnn((input_1, input_2))
model = keras.models.Model([input_1, input_2], outputs)
model.compile(optimizer="adam", loss="mse", metrics=["accuracy"])
"""
### Train the model with randomly generated data
Since there isn't a good candidate dataset for this model, we use random Numpy data for
demonstration.
"""
input_1_data = np.random.random((batch_size * num_batches, timestep, i1))
input_2_data = np.random.random((batch_size * num_batches, timestep, i2, i3))
target_1_data = np.random.random((batch_size * num_batches, unit_1))
target_2_data = np.random.random((batch_size * num_batches, unit_2, unit_3))
input_data = [input_1_data, input_2_data]
target_data = [target_1_data, target_2_data]
model.fit(input_data, target_data, batch_size=batch_size)
"""
With the Keras `keras.layers.RNN` layer, you are only expected to define the math
logic for an individual step within the sequence, and the `keras.layers.RNN` layer
will handle the sequence iteration for you. It's an incredibly powerful way to quickly
prototype new kinds of RNNs (e.g. a LSTM variant).
For more details, please visit the [API docs](https://keras.io/api/layers/recurrent_layers/rnn/).
"""