-
Notifications
You must be signed in to change notification settings - Fork 19.6k
/
Copy pathmaking_new_layers_and_models_via_subclassing.py
679 lines (512 loc) · 19.8 KB
/
making_new_layers_and_models_via_subclassing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
"""
Title: Making new layers and models via subclassing
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/01
Last modified: 2023/06/25
Description: Complete guide to writing `Layer` and `Model` objects from scratch.
Accelerator: None
"""
"""
## Introduction
This guide will cover everything you need to know to build your own
subclassed layers and models. In particular, you'll learn about the following features:
- The `Layer` class
- The `add_weight()` method
- Trainable and non-trainable weights
- The `build()` method
- Making sure your layers can be used with any backend
- The `add_loss()` method
- The `training` argument in `call()`
- The `mask` argument in `call()`
- Making sure your layers can be serialized
Let's dive in.
"""
"""
## Setup
"""
import numpy as np
import keras
from keras import ops
from keras import layers
"""
## The `Layer` class: the combination of state (weights) and some computation
One of the central abstractions in Keras is the `Layer` class. A layer
encapsulates both a state (the layer's "weights") and a transformation from
inputs to outputs (a "call", the layer's forward pass).
Here's a densely-connected layer. It has two state variables:
the variables `w` and `b`.
"""
class Linear(keras.layers.Layer):
    """Densely-connected layer computing `matmul(inputs, w) + b`.

    Both variables are allocated directly in the constructor, so the size of
    the input features has to be given when the layer is instantiated.

    Args:
        units: Size of the output feature axis (second dimension of `w`).
        input_dim: Size of the input feature axis (first dimension of `w`).
    """

    def __init__(self, units=32, input_dim=32):
        super().__init__()
        kernel_shape = (input_dim, units)
        bias_shape = (units,)
        # Kernel starts from random normal values, bias starts at zero.
        self.w = self.add_weight(
            shape=kernel_shape, initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(
            shape=bias_shape, initializer="zeros", trainable=True
        )

    def call(self, inputs):
        """Return the affine transformation of `inputs`."""
        projected = ops.matmul(inputs, self.w)
        return projected + self.b
"""
You would use a layer by calling it on some tensor input(s), much like a Python
function.
"""
# A 2x2 tensor of ones, pushed through a layer with 2 input features and
# 4 output units.
x = ops.ones((2, 2))
linear_layer = Linear(units=4, input_dim=2)
y = linear_layer(x)
print(y)
"""
Note that the weights `w` and `b` are automatically tracked by the layer upon
being set as layer attributes:
"""
# The tracked weights are listed in creation order: `w` first, then `b`.
assert [linear_layer.w, linear_layer.b] == linear_layer.weights
"""
## Layers can have non-trainable weights
Besides trainable weights, you can add non-trainable weights to a layer as
well. Such weights are meant not to be taken into account during
backpropagation, when you are training the layer.
Here's how to add and use a non-trainable weight:
"""
class ComputeSum(keras.layers.Layer):
    """Layer that accumulates the axis-0 sum of every input it receives.

    The accumulator is a non-trainable weight, so it is part of the layer's
    state but is registered with `trainable=False`.

    Args:
        input_dim: Length of the accumulator vector `total`.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.total = self.add_weight(
            shape=(input_dim,), initializer="zeros", trainable=False
        )

    def call(self, inputs):
        """Add the axis-0 sum of `inputs` to the running total and return it."""
        batch_sum = ops.sum(inputs, axis=0)
        self.total.assign_add(batch_sum)
        return self.total
x = ops.ones((2, 2))
my_sum = ComputeSum(2)
# Call the layer twice on the same input: the total keeps growing because
# the accumulator persists between calls.
for _ in range(2):
    y = my_sum(x)
    print(y.numpy())
"""
It's part of `layer.weights`, but it gets categorized as a non-trainable weight:
"""
all_weights = my_sum.weights
frozen_weights = my_sum.non_trainable_weights
print("weights:", len(all_weights))
print("non-trainable weights:", len(frozen_weights))
# The accumulator does not show up among the trainable weights:
print("trainable_weights:", my_sum.trainable_weights)
"""
## Best practice: deferring weight creation until the shape of the inputs is known
Our `Linear` layer above took an `input_dim` argument that was used to compute
the shape of the weights `w` and `b` in `__init__()`:
"""
class Linear(keras.layers.Layer):
    """Affine layer whose weight shapes are fixed at construction time.

    Args:
        units: Number of output features.
        input_dim: Number of input features; needed here because `w` is
            created inside `__init__()`.
    """

    def __init__(self, units=32, input_dim=32):
        super().__init__()
        # The kernel shape depends on `input_dim`, the bias shape only on
        # `units`.
        self.w = self.add_weight(
            trainable=True,
            initializer="random_normal",
            shape=(input_dim, units),
        )
        self.b = self.add_weight(
            trainable=True,
            initializer="zeros",
            shape=(units,),
        )

    def call(self, inputs):
        """Compute `matmul(inputs, w) + b`."""
        product = ops.matmul(inputs, self.w)
        return product + self.b
"""
In many cases, you may not know in advance the size of your inputs, and you
would like to lazily create weights when that value becomes known, some time
after instantiating the layer.
In the Keras API, we recommend creating layer weights in the
`build(self, input_shape)` method of your layer. Like this:
"""
class Linear(keras.layers.Layer):
    """Affine layer that creates its weights lazily in `build()`.

    Args:
        units: Number of output features.
    """

    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        """Create `w` and `b` using the last axis of `input_shape`."""
        in_features = input_shape[-1]
        self.w = self.add_weight(
            shape=(in_features, self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        """Compute `matmul(inputs, w) + b`."""
        projected = ops.matmul(inputs, self.w)
        return projected + self.b
"""
The `__call__()` method of your layer will automatically run `build()` the first time
it is called. You now have a layer that's lazy and thus easier to use:
"""
# Only the output width is fixed here; the input size is still unknown.
linear_layer = Linear(units=32)
# The first call creates the weights, using the shape of `x`.
y = linear_layer(x)
"""
Implementing `build()` separately as shown above nicely separates creating weights
only once from using weights in every call.
"""
"""
## Layers are recursively composable
If you assign a Layer instance as an attribute of another Layer, the outer layer
will start tracking the weights created by the inner layer.
We recommend creating such sublayers in the `__init__()` method and leave it to
the first `__call__()` to trigger building their weights.
"""
class MLPBlock(keras.layers.Layer):
    """Stack of three `Linear` layers with ReLU after the first two.

    The sublayers are created in the constructor and assigned as attributes.
    """

    def __init__(self):
        super().__init__()
        self.linear_1 = Linear(units=32)
        self.linear_2 = Linear(units=32)
        self.linear_3 = Linear(units=1)

    def call(self, inputs):
        """Run `inputs` through the three sublayers in order."""
        hidden = keras.activations.relu(self.linear_1(inputs))
        hidden = keras.activations.relu(self.linear_2(hidden))
        return self.linear_3(hidden)
mlp = MLPBlock()
# The weights are created by this first call to `mlp`.
sample_batch = ops.ones(shape=(3, 64))
y = mlp(sample_batch)
print("weights:", len(mlp.weights))
print("trainable weights:", len(mlp.trainable_weights))
"""
## Backend-agnostic layers and backend-specific layers
As long as a layer only uses APIs from the `keras.ops` namespace
(or other Keras namespaces such as `keras.activations`, `keras.random`, or `keras.layers`),
then it can be used with any backend -- TensorFlow, JAX, or PyTorch.
All layers you've seen so far in this guide work with all Keras backends.
The `keras.ops` namespace gives you access to:
- The NumPy API, e.g. `ops.matmul`, `ops.sum`, `ops.reshape`, `ops.stack`, etc.
- Neural networks-specific APIs such as `ops.softmax`, `ops.conv`, `ops.binary_crossentropy`, `ops.relu`, etc.
You can also use backend-native APIs in your layers (such as `tf.nn` functions),
but if you do this, then your layer will only be usable with the backend in question.
For instance, you could write the following JAX-specific layer using `jax.numpy`:
```python
import jax
class Linear(keras.layers.Layer):
...
def call(self, inputs):
return jax.numpy.matmul(inputs, self.w) + self.b
```
This would be the equivalent TensorFlow-specific layer:
```python
import tensorflow as tf
class Linear(keras.layers.Layer):
...
def call(self, inputs):
return tf.matmul(inputs, self.w) + self.b
```
And this would be the equivalent PyTorch-specific layer:
```python
import torch
class Linear(keras.layers.Layer):
...
def call(self, inputs):
return torch.matmul(inputs, self.w) + self.b
```
Because cross-backend compatibility is a tremendously useful property, we strongly
recommend that you seek to always make your layers backend-agnostic by leveraging
only Keras APIs.
"""
"""
## The `add_loss()` method
When writing the `call()` method of a layer, you can create loss tensors that
you will want to use later, when writing your training loop. This is doable by
calling `self.add_loss(value)`:
"""
# Identity layer that records an activity regularization loss on each call.
class ActivityRegularizationLayer(keras.layers.Layer):
    """Pass inputs through unchanged while adding `rate * mean(inputs)` as a loss.

    Args:
        rate: Multiplier applied to the mean of the inputs.
    """

    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        """Register the loss for this call and return `inputs` untouched."""
        penalty = self.rate * ops.mean(inputs)
        self.add_loss(penalty)
        return inputs
"""
These losses (including those created by any inner layer) can be retrieved via
`layer.losses`. This property is reset at the start of every `__call__()` to
the top-level layer, so that `layer.losses` always contains the loss values
created during the last forward pass.
"""
class OuterLayer(keras.layers.Layer):
    """Wrapper layer that delegates to an inner `ActivityRegularizationLayer`."""

    def __init__(self):
        super().__init__()
        self.activity_reg = ActivityRegularizationLayer(rate=1e-2)

    def call(self, inputs):
        """Forward `inputs` to the inner regularization layer."""
        regularized = self.activity_reg(inputs)
        return regularized
layer = OuterLayer()
# The layer has never been called, so no loss has been recorded yet.
assert not layer.losses
dummy = ops.zeros((1, 1))
_ = layer(dummy)
# One forward pass produced exactly one loss value.
assert len(layer.losses) == 1
# A second call does not accumulate: `layer.losses` is reset at the start of
# each __call__, so only the loss from this latest call is present.
_ = layer(dummy)
assert len(layer.losses) == 1
"""
In addition, the `losses` property also contains regularization losses created
for the weights of any inner layer:
"""
class OuterLayerWithKernelRegularizer(keras.layers.Layer):
    """Wrapper around a 32-unit `Dense` layer that has an L2 kernel regularizer."""

    def __init__(self):
        super().__init__()
        l2_penalty = keras.regularizers.l2(1e-3)
        self.dense = keras.layers.Dense(32, kernel_regularizer=l2_penalty)

    def call(self, inputs):
        """Forward `inputs` to the inner dense layer."""
        projected = self.dense(inputs)
        return projected
layer = OuterLayerWithKernelRegularizer()
zero_input = ops.zeros((1, 1))
_ = layer(zero_input)
# The printed value comes from the `kernel_regularizer` configured on the
# inner dense layer: `1e-3 * sum(layer.dense.kernel ** 2)`.
print(layer.losses)
"""
These losses are meant to be taken into account when writing custom training loops.
They also work seamlessly with `fit()` (they get automatically summed and added to the main loss, if any):
"""
inputs = keras.Input(shape=(3,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs=inputs, outputs=outputs)

# With a loss given to `compile`, the losses registered through `add_loss`
# are added on top of it.
model.compile(optimizer="adam", loss="mse")
model.fit(x=np.random.random((2, 3)), y=np.random.random((2, 3)))

# `compile` can also be called without any loss: the forward pass already
# registers one via `add_loss`, so the model still has something to minimize.
model.compile(optimizer="adam")
model.fit(x=np.random.random((2, 3)), y=np.random.random((2, 3)))
"""
## You can optionally enable serialization on your layers
If you need your custom layers to be serializable as part of a
[Functional model](/guides/functional_api/), you can optionally implement a `get_config()`
method:
"""
class Linear(keras.layers.Layer):
    """Lazily-built affine layer that can report its own configuration.

    Args:
        units: Number of output features.
    """

    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        """Create `w` and `b` using the last axis of `input_shape`."""
        in_features = input_shape[-1]
        self.w = self.add_weight(
            shape=(in_features, self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        """Compute `matmul(inputs, w) + b`."""
        projected = ops.matmul(inputs, self.w)
        return projected + self.b

    def get_config(self):
        """Return the constructor arguments as a dict (only `units` here)."""
        config = {"units": self.units}
        return config
# Round-trip: export the config of one layer and build a new one from it.
layer = Linear(units=64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
"""
Note that the `__init__()` method of the base `Layer` class takes some keyword
arguments, in particular a `name` and a `dtype`. It's good practice to pass
these arguments to the parent class in `__init__()` and to include them in the
layer config:
"""
class Linear(keras.layers.Layer):
    """Lazily-built affine layer that forwards base-class keyword arguments.

    Args:
        units: Number of output features.
        **kwargs: Passed through to the parent `Layer` constructor.
    """

    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create `w` and `b` using the last axis of `input_shape`."""
        in_features = input_shape[-1]
        self.w = self.add_weight(
            shape=(in_features, self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        """Compute `matmul(inputs, w) + b`."""
        projected = ops.matmul(inputs, self.w)
        return projected + self.b

    def get_config(self):
        """Return the parent config extended with this layer's `units`."""
        return {**super().get_config(), "units": self.units}
# The config now also carries the entries produced by the parent class.
layer = Linear(units=64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
"""
If you need more flexibility when deserializing the layer from its config, you
can also override the `from_config()` class method. This is the base
implementation of `from_config()`:
```python
@classmethod
def from_config(cls, config):
return cls(**config)
```
To learn more about serialization and saving, see the complete
[guide to saving and serializing models](/guides/serialization_and_saving/).
"""
"""
## Privileged `training` argument in the `call()` method
Some layers, in particular the `BatchNormalization` layer and the `Dropout`
layer, have different behaviors during training and inference. For such
layers, it is standard practice to expose a `training` (boolean) argument in
the `call()` method.
By exposing this argument in `call()`, you enable the built-in training and
evaluation loops (e.g. `fit()`) to correctly use the layer in training and
inference.
"""
class CustomDropout(keras.layers.Layer):
    """Dropout layer that is only active when `training` is truthy.

    Args:
        rate: Dropout rate forwarded to `keras.random.dropout`.
        seed: Optional seed for the layer's `SeedGenerator`. `None` (the
            default) lets Keras choose the seed.
        **kwargs: Passed through to the parent `Layer` constructor.
    """

    def __init__(self, rate, seed=None, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate
        # Random ops inside `call()` should draw from a layer-owned seed
        # generator. Without one, an unseeded random op can yield the same
        # mask on every step when the call is traced/compiled.
        self.seed_generator = keras.random.SeedGenerator(seed)

    def call(self, inputs, training=None):
        """Apply dropout when `training` is truthy, else return `inputs`."""
        if training:
            return keras.random.dropout(
                inputs, rate=self.rate, seed=self.seed_generator
            )
        return inputs
"""
## Privileged `mask` argument in the `call()` method
The other privileged argument supported by `call()` is the `mask` argument.
You will find it in all Keras RNN layers. A mask is a boolean tensor (one
boolean value per timestep in the input) used to skip certain input timesteps
when processing timeseries data.
Keras will automatically pass the correct `mask` argument to `__call__()` for
layers that support it, when a mask is generated by a prior layer.
Mask-generating layers are the `Embedding`
layer configured with `mask_zero=True`, and the `Masking` layer.
"""
"""
## The `Model` class
In general, you will use the `Layer` class to define inner computation blocks,
and will use the `Model` class to define the outer model -- the object you
will train.
For instance, in a ResNet50 model, you would have several ResNet blocks
subclassing `Layer`, and a single `Model` encompassing the entire ResNet50
network.
The `Model` class has the same API as `Layer`, with the following differences:
- It exposes built-in training, evaluation, and prediction loops
(`model.fit()`, `model.evaluate()`, `model.predict()`).
- It exposes the list of its inner layers, via the `model.layers` property.
- It exposes saving and serialization APIs (`save()`, `save_weights()`...)
Effectively, the `Layer` class corresponds to what we refer to in the
literature as a "layer" (as in "convolution layer" or "recurrent layer") or as
a "block" (as in "ResNet block" or "Inception block").
Meanwhile, the `Model` class corresponds to what is referred to in the
literature as a "model" (as in "deep learning model") or as a "network" (as in
"deep neural network").
So if you're wondering, "should I use the `Layer` class or the `Model` class?",
ask yourself: will I need to call `fit()` on it? Will I need to call `save()`
on it? If so, go with `Model`. If not (either because your class is just a block
in a bigger system, or because you are writing training & saving code yourself),
use `Layer`.
For instance, we could take our mini-resnet example above, and use it to build
a `Model` that we could train with `fit()`, and that we could save with
`save_weights()`:
"""
"""
```python
class ResNet(keras.Model):
def __init__(self, num_classes=1000):
super().__init__()
self.block_1 = ResNetBlock()
self.block_2 = ResNetBlock()
self.global_pool = layers.GlobalAveragePooling2D()
self.classifier = layers.Dense(num_classes)
def call(self, inputs):
x = self.block_1(inputs)
x = self.block_2(x)
x = self.global_pool(x)
return self.classifier(x)
resnet = ResNet()
dataset = ...
resnet.fit(dataset, epochs=10)
resnet.save("filepath.keras")
```
"""
"""
## Putting it all together: an end-to-end example
Here's what you've learned so far:
- A `Layer` encapsulates a state (created in `__init__()` or `build()`) and some
computation (defined in `call()`).
- Layers can be recursively nested to create new, bigger computation blocks.
- Layers are backend-agnostic as long as they only use Keras APIs. You can use
backend-native APIs (such as `jax.numpy`, `torch.nn` or `tf.nn`), but then
your layer will only be usable with that specific backend.
- Layers can create and track losses (typically regularization losses)
via `add_loss()`.
- The outer container, the thing you want to train, is a `Model`. A `Model` is
just like a `Layer`, but with added training and serialization utilities.
Let's put all of these things together into an end-to-end example: we're going
to implement a Variational AutoEncoder (VAE) in a backend-agnostic fashion
-- so that it runs the same with TensorFlow, JAX, and PyTorch.
We'll train it on MNIST digits.
Our VAE will be a subclass of `Model`, built as a nested composition of layers
that subclass `Layer`. It will feature a regularization loss (KL divergence).
"""
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.

    The sample is `z_mean + exp(0.5 * z_log_var) * epsilon`, where `epsilon`
    is drawn from `keras.random.normal`.

    Args:
        seed: Optional seed for the layer's `SeedGenerator`. `None` (the
            default) lets Keras choose the seed.
        **kwargs: Passed through to the parent `Layer` constructor.
    """

    def __init__(self, seed=None, **kwargs):
        super().__init__(**kwargs)
        # Draw the noise from a layer-owned seed generator. Without one, an
        # unseeded random op can return the same noise on every step when
        # the call is traced/compiled.
        self.seed_generator = keras.random.SeedGenerator(seed)

    def call(self, inputs):
        """Return one sample of z for a `(z_mean, z_log_var)` pair."""
        z_mean, z_log_var = inputs
        batch = ops.shape(z_mean)[0]
        dim = ops.shape(z_mean)[1]
        epsilon = keras.random.normal(
            shape=(batch, dim), seed=self.seed_generator
        )
        return z_mean + ops.exp(0.5 * z_log_var) * epsilon
class Encoder(layers.Layer):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z).

    Args:
        latent_dim: Units of the mean and log-variance dense layers.
        intermediate_dim: Units of the ReLU projection layer.
        name: Layer name forwarded to the parent constructor.
        **kwargs: Passed through to the parent `Layer` constructor.
    """

    def __init__(
        self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        """Return `(z_mean, z_log_var, z)` computed from `inputs`."""
        hidden = self.dense_proj(inputs)
        z_mean = self.dense_mean(hidden)
        z_log_var = self.dense_log_var(hidden)
        # Both heads read the same projection; z is sampled from their output.
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z
class Decoder(layers.Layer):
    """Converts z, the encoded digit vector, back into a readable digit.

    Args:
        original_dim: Units of the sigmoid output layer.
        intermediate_dim: Units of the ReLU projection layer.
        name: Layer name forwarded to the parent constructor.
        **kwargs: Passed through to the parent `Layer` constructor.
    """

    def __init__(
        self, original_dim, intermediate_dim=64, name="decoder", **kwargs
    ):
        super().__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        """Project `inputs` and map the result through the output layer."""
        hidden = self.dense_proj(inputs)
        decoded = self.dense_output(hidden)
        return decoded
class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training.

    Args:
        original_dim: Output size of the decoder; also stored on the model.
        intermediate_dim: Projection size shared by encoder and decoder.
        latent_dim: Size of the encoder's latent outputs.
        name: Model name forwarded to the parent constructor.
        **kwargs: Passed through to the parent `Model` constructor.
    """

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs,
    ):
        super().__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(
            latent_dim=latent_dim,
            intermediate_dim=intermediate_dim,
        )
        self.decoder = Decoder(
            original_dim,
            intermediate_dim=intermediate_dim,
        )

    def call(self, inputs):
        """Encode, decode, register the KL loss, and return the reconstruction."""
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # KL divergence regularization term, registered through `add_loss`.
        kl_terms = z_log_var - ops.square(z_mean) - ops.exp(z_log_var) + 1
        self.add_loss(-0.5 * ops.mean(kl_terms))
        return reconstructed
"""
Let's train it on MNIST using the `fit()` API:
"""
# Flattened size of one MNIST image (28 * 28). Used for both the reshape and
# the model so the two cannot drift apart.
original_dim = 784

(x_train, _), _ = keras.datasets.mnist.load_data()
# Flatten each image and scale the pixel values by 1/255. The sample count is
# taken from the data instead of being hard-coded.
x_train = x_train.reshape(len(x_train), original_dim).astype("float32") / 255

vae = VariationalAutoEncoder(original_dim, intermediate_dim=64, latent_dim=32)
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# The KL term registered via `add_loss` in the model's `call()` is added to
# this reconstruction loss.
vae.compile(optimizer, loss=keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=2, batch_size=64)