forked from keras-team/tf-keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexport_lib.py
593 lines (515 loc) · 23.1 KB
/
export_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for exporting inference-only TF-Keras models/layers."""
import tensorflow.compat.v2 as tf
from tensorflow.python.util.tf_export import keras_export
from tf_keras.engine import base_layer
from tf_keras.engine import functional
from tf_keras.engine import sequential
from tf_keras.utils import io_utils
@keras_export("keras.export.ExportArchive")
class ExportArchive(tf.__internal__.tracking.AutoTrackable):
    """ExportArchive is used to write SavedModel artifacts (e.g. for inference).

    If you have a TF-Keras model or layer that you want to export as SavedModel
    for serving (e.g. via TensorFlow-Serving), you can use `ExportArchive`
    to configure the different serving endpoints you need to make available,
    as well as their signatures. Simply instantiate an `ExportArchive`,
    use `track()` to register the layer(s) or model(s) to be used,
    then use the `add_endpoint()` method to register a new serving endpoint.
    When done, use the `write_out()` method to save the artifact.

    The resulting artifact is a SavedModel and can be reloaded via
    `tf.saved_model.load`.

    Examples:

    Here's how to export a model for inference.

    ```python
    export_archive = ExportArchive()
    export_archive.track(model)
    export_archive.add_endpoint(
        name="serve",
        fn=model.call,
        input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
    )
    export_archive.write_out("path/to/location")

    # Elsewhere, we can reload the artifact and serve it.
    # The endpoint we added is available as a method:
    serving_model = tf.saved_model.load("path/to/location")
    outputs = serving_model.serve(inputs)
    ```

    Here's how to export a model with one endpoint for inference and one
    endpoint for a training-mode forward pass (e.g. with dropout on).

    ```python
    export_archive = ExportArchive()
    export_archive.track(model)
    export_archive.add_endpoint(
        name="call_inference",
        fn=lambda x: model.call(x, training=False),
        input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
    )
    export_archive.add_endpoint(
        name="call_training",
        fn=lambda x: model.call(x, training=True),
        input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
    )
    export_archive.write_out("path/to/location")
    ```

    **Note on resource tracking:**

    `ExportArchive` is able to automatically track all `tf.Variables` used
    by its endpoints, so most of the time calling `.track(model)`
    is not strictly required. However, if your model uses lookup layers such
    as `IntegerLookup`, `StringLookup`, or `TextVectorization`,
    it will need to be tracked explicitly via `.track(model)`.

    Explicit tracking is also required if you need to be able to access
    the properties `variables`, `trainable_variables`, or
    `non_trainable_variables` on the revived archive.
    """

    def __init__(self):
        # Endpoint names in registration order. Unless an endpoint named
        # "serving_default" is registered, `write_out()` also exposes the
        # first one under that alias.
        self._endpoint_names = []
        # Endpoint name -> the explicit `input_signature` it was registered
        # with (only endpoints registered with a signature appear here).
        self._endpoint_signatures = {}
        self.tensorflow_version = tf.__version__
        # These lists are created here, outside any
        # `no_automatic_dependency_tracking` scope, so the variables appended
        # to them by `track()` are saved as part of the archive.
        self.variables = []
        self.trainable_variables = []
        self.non_trainable_variables = []

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def track(self, resource):
        """Track the variables (and other assets) of a layer or model.

        Args:
            resource: A `tf.__internal__.tracking.Trackable` instance,
                typically a TF-Keras `Layer` or `Model`. Layers must already
                be built.

        Raises:
            ValueError: If `resource` is not a `Trackable`, or if it is a
                `Layer` that has not been built yet.
        """
        if not isinstance(resource, tf.__internal__.tracking.Trackable):
            raise ValueError(
                "Invalid resource type. Expected an instance of a "
                "TensorFlow `Trackable` (such as a TF-Keras "
                "`Layer` or `Model`). "
                f"Received instead an object of type '{type(resource)}'. "
                f"Object received: {resource}"
            )
        if isinstance(resource, base_layer.Layer):
            if not resource.built:
                raise ValueError(
                    "The layer provided has not yet been built. "
                    "It must be built before export."
                )

        # Layers in `_tracked` are not part of the trackables that get saved,
        # because we're creating the attribute in a
        # no_automatic_dependency_tracking scope.
        if not hasattr(self, "_tracked"):
            self._tracked = []
        self._tracked.append(resource)

        if isinstance(resource, base_layer.Layer):
            # Variables in the lists below are actually part of the trackables
            # that get saved, because the lists are created in __init__.
            self.variables += resource.variables
            self.trainable_variables += resource.trainable_variables
            self.non_trainable_variables += resource.non_trainable_variables

    def add_endpoint(self, name, fn, input_signature=None):
        """Register a new serving endpoint.

        Arguments:
            name: Str, name of the endpoint.
            fn: A function. It should only leverage resources
                (e.g. `tf.Variable` objects or `tf.lookup.StaticHashTable`
                objects) that are available on the models/layers
                tracked by the `ExportArchive` (you can call `.track(model)`
                to track a new model).
                The shape and dtype of the inputs to the function must be
                known. For that purpose, you can either 1) make sure that
                `fn` is a `tf.function` that has been called at least once, or
                2) provide an `input_signature` argument that specifies the
                shape and dtype of the inputs (see below).
            input_signature: Used to specify the shape and dtype of the
                inputs to `fn`. List of `tf.TensorSpec` objects (one
                per positional input argument of `fn`). Nested arguments are
                allowed (see below for an example showing a Functional model
                with 2 input arguments). An empty list declares an endpoint
                that takes no arguments.

        Raises:
            ValueError: If `name` is already registered, or if no
                `input_signature` is given and `fn` is not a `tf.function`
                that has been traced at least once.

        Example:

        Adding an endpoint using the `input_signature` argument when the
        model has a single input argument:

        ```python
        export_archive = ExportArchive()
        export_archive.track(model)
        export_archive.add_endpoint(
            name="serve",
            fn=model.call,
            input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
        )
        ```

        Adding an endpoint using the `input_signature` argument when the
        model has two positional input arguments:

        ```python
        export_archive = ExportArchive()
        export_archive.track(model)
        export_archive.add_endpoint(
            name="serve",
            fn=model.call,
            input_signature=[
                tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
                tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
            ],
        )
        ```

        Adding an endpoint using the `input_signature` argument when the
        model has one input argument that is a list of 2 tensors (e.g.
        a Functional model with 2 inputs):

        ```python
        model = keras.Model(inputs=[x1, x2], outputs=outputs)

        export_archive = ExportArchive()
        export_archive.track(model)
        export_archive.add_endpoint(
            name="serve",
            fn=model.call,
            input_signature=[
                [
                    tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
                    tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
                ],
            ],
        )
        ```

        This also works with dictionary inputs:

        ```python
        model = keras.Model(inputs={"x1": x1, "x2": x2}, outputs=outputs)

        export_archive = ExportArchive()
        export_archive.track(model)
        export_archive.add_endpoint(
            name="serve",
            fn=model.call,
            input_signature=[
                {
                    "x1": tf.TensorSpec(shape=(None, 3), dtype=tf.float32),
                    "x2": tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
                },
            ],
        )
        ```

        Adding an endpoint that is a `tf.function`:

        ```python
        @tf.function()
        def serving_fn(x):
            return model(x)

        # The function must be traced, i.e. it must be called at least once.
        serving_fn(tf.random.normal(shape=(2, 3)))

        export_archive = ExportArchive()
        export_archive.track(model)
        export_archive.add_endpoint(name="serve", fn=serving_fn)
        ```
        """
        if name in self._endpoint_names:
            raise ValueError(f"Endpoint name '{name}' is already taken.")

        # Compare against None rather than testing truthiness: an explicit
        # empty signature (`[]`) is a valid signature for an endpoint that
        # takes no arguments and must not fall through to the
        # "traced tf.function" branch below.
        if input_signature is not None:
            decorated_fn = tf.function(fn, input_signature=input_signature)
            self._endpoint_signatures[name] = input_signature
        else:
            if isinstance(fn, tf.types.experimental.GenericFunction):
                if not fn._list_all_concrete_functions():
                    raise ValueError(
                        f"The provided tf.function '{fn}' "
                        "has never been called. "
                        "To specify the expected shape and dtype "
                        "of the function's arguments, "
                        "you must either provide a function that "
                        "has been called at least once, or alternatively pass "
                        "an `input_signature` argument in `add_endpoint()`."
                    )
                decorated_fn = fn
            else:
                raise ValueError(
                    "If the `fn` argument provided is not a `tf.function`, "
                    "you must provide an `input_signature` argument to "
                    "specify the shape and dtype of the function arguments. "
                    "Example:\n\n"
                    "export_archive.add_endpoint(\n"
                    "    name='call',\n"
                    "    fn=model.call,\n"
                    "    input_signature=[\n"
                    "        tf.TensorSpec(\n"
                    "            shape=(None, 224, 224, 3),\n"
                    "            dtype=tf.float32,\n"
                    "        )\n"
                    "    ],\n"
                    ")"
                )
        setattr(self, name, decorated_fn)
        self._endpoint_names.append(name)

    def add_variable_collection(self, name, variables):
        """Register a set of variables to be retrieved after reloading.

        Arguments:
            name: The string name for the collection. It must not clash with
                the name of a registered endpoint.
            variables: A tuple/list/set of `tf.Variable` instances.

        Raises:
            ValueError: If `name` is already used by an endpoint, if
                `variables` is not a list/tuple/set, or if any element is not
                a `tf.Variable`.

        Example:

        ```python
        export_archive = ExportArchive()
        export_archive.track(model)
        # Register an endpoint
        export_archive.add_endpoint(
            name="serve",
            fn=model.call,
            input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],
        )
        # Save a variable collection
        export_archive.add_variable_collection(
            name="optimizer_variables", variables=model.optimizer.variables)
        export_archive.write_out("path/to/location")

        # Reload the object
        revived_object = tf.saved_model.load("path/to/location")
        # Retrieve the variables
        optimizer_variables = revived_object.optimizer_variables
        ```
        """
        # The collection is stored via `setattr`, so reusing an endpoint name
        # would silently replace the endpoint and break `write_out()` later.
        if name in self._endpoint_names:
            raise ValueError(
                f"Name '{name}' is already used by an endpoint. "
                "Choose a different name for the variable collection."
            )
        if not isinstance(variables, (list, tuple, set)):
            raise ValueError(
                "Expected `variables` to be a list/tuple/set. "
                f"Received instead object of type '{type(variables)}'."
            )
        if not all(isinstance(v, tf.Variable) for v in variables):
            raise ValueError(
                "Expected all elements in `variables` to be "
                "`tf.Variable` instances. Found instead the following types: "
                f"{list(set(type(v) for v in variables))}"
            )
        setattr(self, name, list(variables))

    def write_out(self, filepath, options=None):
        """Write the corresponding SavedModel to disk.

        Arguments:
            filepath: `str` or `pathlib.Path` object.
                Path where to save the artifact.
            options: `tf.saved_model.SaveOptions` object that specifies
                SavedModel saving options.

        Raises:
            ValueError: If no endpoint has been registered.

        **Note on TF-Serving**: all endpoints registered via `add_endpoint()`
        are made visible for TF-Serving in the SavedModel artifact. In addition,
        the first endpoint registered is made visible under the alias
        `"serving_default"` (unless an endpoint with the name
        `"serving_default"` was already registered manually),
        since TF-Serving requires this endpoint to be set.
        """
        if not self._endpoint_names:
            raise ValueError(
                "No endpoints have been set yet. Call add_endpoint()."
            )
        self._filter_and_track_resources()

        signatures = {}
        for name in self._endpoint_names:
            signatures[name] = self._get_concrete_fn(name)
        # Add "serving_default" signature key for TFServing
        if "serving_default" not in self._endpoint_names:
            signatures["serving_default"] = self._get_concrete_fn(
                self._endpoint_names[0]
            )
        tf.saved_model.save(
            self, filepath, options=options, signatures=signatures
        )
        # Print out available endpoints
        endpoints = "\n\n".join(
            _print_signature(getattr(self, name), name)
            for name in self._endpoint_names
        )
        io_utils.print_msg(
            f"Saved artifact at '{filepath}'. "
            "The following endpoints are available:\n\n"
            f"{endpoints}"
        )

    def _get_concrete_fn(self, endpoint):
        """Workaround for some SavedModel quirks.

        Endpoints registered with an explicit `input_signature` are returned
        as-is (the `tf.function` itself). For the others, the first child
        reported by `_trackable_children("saved_model")` is returned —
        presumably the first traced concrete function; TODO confirm.
        """
        if endpoint in self._endpoint_signatures:
            return getattr(self, endpoint)
        else:
            traces = getattr(self, endpoint)._trackable_children("saved_model")
            return list(traces.values())[0]

    def _get_variables_used_by_endpoints(self):
        """Return `(trainable, non_trainable)` variables used by endpoints."""
        fns = [self._get_concrete_fn(name) for name in self._endpoint_names]
        return _list_variables_used_by_fns(fns)

    def _filter_and_track_resources(self):
        """Track resources used by endpoints / referenced in `track()` calls."""
        # Start by extracting variables from endpoints.
        fns = [self._get_concrete_fn(name) for name in self._endpoint_names]
        tvs, ntvs = _list_variables_used_by_fns(fns)
        self._all_variables = list(tvs + ntvs)

        # Next, track lookup tables.
        # Hopefully, one day this will be automated at the tf.function level.
        self._misc_assets = []
        # isort: off
        from tensorflow.python.ops.lookup_ops import (
            LookupInterface,
        )
        from tensorflow.python.trackable.resource import (
            RestoredResource,
        )
        from tf_keras.layers.preprocessing.index_lookup import IndexLookup

        if hasattr(self, "_tracked"):
            for root in self._tracked:
                descendants = tf.train.TrackableView(root).descendants()
                for trackable in descendants:
                    if isinstance(
                        trackable,
                        (IndexLookup, LookupInterface, RestoredResource),
                    ):
                        self._misc_assets.append(trackable)
def export_model(model, filepath):
    """Export `model` as a SavedModel with a single `"serve"` endpoint.

    For Functional and Sequential models the endpoint's input signature is
    derived from `model.inputs`. For any other model, the save spec recorded
    by a previous call (`model._get_save_spec()`) is used.

    Args:
        model: The TF-Keras model to export.
        filepath: `str` or `pathlib.Path` object. Where to write the artifact.

    Raises:
        ValueError: If `model` is not Functional/Sequential and has never
            been called, so no input signature can be inferred.
    """
    export_archive = ExportArchive()
    export_archive.track(model)

    if isinstance(model, (functional.Functional, sequential.Sequential)):
        input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs)
        if isinstance(input_signature, list) and len(input_signature) > 1:
            # Several inputs are passed to `__call__` as ONE list argument,
            # so wrap them: `input_signature` has one entry per positional
            # argument.
            input_signature = [input_signature]
    else:
        save_spec = model._get_save_spec()
        if not save_spec:
            raise ValueError(
                "The model provided has never been called. "
                "It must be called at least once before export."
            )
        input_signature = [save_spec]

    export_archive.add_endpoint("serve", model.__call__, input_signature)
    export_archive.write_out(filepath)
class ReloadedLayer(base_layer.Layer):
    """Reload a TF-Keras model/layer saved via SavedModel / ExportArchive.

    Arguments:
        filepath: `str` or `pathlib.Path` object. The path to the SavedModel.
        call_endpoint: Name of the endpoint to use as the `call()` method
            of the reloaded layer. If the SavedModel was created
            via `model.export()`,
            then the default endpoint name is `'serve'`. In other cases
            it may be named `'serving_default'`.
        call_training_endpoint: Optional name of the endpoint to use when
            the layer is called with `training=True`. When not provided,
            `call_endpoint` is used regardless of `training`.
        trainable: Standard layer argument, forwarded to `Layer.__init__()`.
        name: Standard layer argument, forwarded to `Layer.__init__()`.
        dtype: Standard layer argument, forwarded to `Layer.__init__()`.

    Raises:
        ValueError: If `call_endpoint` (or `call_training_endpoint`, when
            given) is neither an attribute of the reloaded SavedModel nor an
            entry of its `signatures`.

    Example:

    ```python
    model.export("path/to/artifact")
    reloaded_layer = ReloadedLayer("path/to/artifact")
    outputs = reloaded_layer(inputs)
    ```

    The reloaded object can be used like a regular TF-Keras layer, and supports
    training/fine-tuning of its trainable weights. Note that the reloaded
    object retains none of the internal structure or custom methods of the
    original object -- it's a brand new layer created around the saved
    function.

    **Limitations:**

    * Only call endpoints with a single `inputs` tensor argument
    (which may optionally be a dict/tuple/list of tensors) are supported.
    For endpoints with multiple separate input tensor arguments, consider
    subclassing `ReloadedLayer` and implementing a `call()` method with a
    custom signature.
    * If you need training-time behavior to differ from inference-time behavior
    (i.e. if you need the reloaded object to support a `training=True` argument
    in `__call__()`), make sure that the training-time call function is
    saved as a standalone endpoint in the artifact, and provide its name
    to the `ReloadedLayer` via the `call_training_endpoint` argument.
    """

    def __init__(
        self,
        filepath,
        call_endpoint="serve",
        call_training_endpoint=None,
        trainable=True,
        name=None,
        dtype=None,
    ):
        # Initialize an empty layer, then add_weight() etc. as needed.
        super().__init__(trainable=trainable, name=name, dtype=dtype)

        self._reloaded_obj = tf.saved_model.load(filepath)

        self.filepath = filepath
        self.call_endpoint = call_endpoint
        self.call_training_endpoint = call_training_endpoint

        # Resolve the call function.
        self.call_endpoint_fn = self._resolve_endpoint_fn(
            self._reloaded_obj, call_endpoint
        )
        all_fns = [self.call_endpoint_fn]

        # Resolve the training function (only when one was requested).
        if call_training_endpoint:
            self.call_training_endpoint_fn = self._resolve_endpoint_fn(
                self._reloaded_obj, call_training_endpoint
            )
            all_fns.append(self.call_training_endpoint_fn)

        # Register the variables used by the resolved functions as this
        # layer's trainable / non-trainable weights.
        tvs, ntvs = _list_variables_used_by_fns(all_fns)
        for v in tvs:
            self._add_existing_weight(v, trainable=True)
        for v in ntvs:
            self._add_existing_weight(v, trainable=False)
        self.built = True

    @staticmethod
    def _resolve_endpoint_fn(reloaded_obj, endpoint_name):
        """Look up `endpoint_name` on a reloaded SavedModel object.

        An attribute of that name takes precedence; otherwise the entry in
        `reloaded_obj.signatures` is used.

        Raises:
            ValueError: If the endpoint is found in neither place.
        """
        if hasattr(reloaded_obj, endpoint_name):
            # Case 1: it's set as an attribute.
            return getattr(reloaded_obj, endpoint_name)
        if endpoint_name in reloaded_obj.signatures:
            # Case 2: it's listed in the `signatures` field.
            return reloaded_obj.signatures[endpoint_name]
        raise ValueError(
            f"The endpoint '{endpoint_name}' is neither an "
            "attribute of the reloaded SavedModel, nor an entry "
            "in the `signatures` field of the reloaded SavedModel. "
        )

    def _add_existing_weight(self, weight, trainable):
        """Calls add_weight() to register but not create an existing weight."""
        self.add_weight(
            name=weight.name,
            shape=weight.shape,
            dtype=weight.dtype,
            trainable=trainable,
            # The getter ignores its arguments and hands back the existing
            # variable, so no new variable is created.
            getter=lambda *_, **__: weight,
        )

    def call(self, inputs, training=False, **kwargs):
        """Run the reloaded endpoint on `inputs`.

        The training endpoint is used when `training` is truthy and a
        `call_training_endpoint` was provided; otherwise the call endpoint is
        used. Extra keyword arguments are forwarded to the endpoint.
        """
        if training:
            if self.call_training_endpoint:
                return self.call_training_endpoint_fn(inputs, **kwargs)
        return self.call_endpoint_fn(inputs, **kwargs)

    def get_config(self):
        """Return the layer config, including the SavedModel path/endpoints."""
        base_config = super().get_config()
        config = {
            # Note: this is not intended to be portable.
            "filepath": self.filepath,
            "call_endpoint": self.call_endpoint,
            "call_training_endpoint": self.call_training_endpoint,
        }
        return {**base_config, **config}
def _make_tensor_spec(x):
    """Return a `tf.TensorSpec` with the same shape, dtype and name as `x`."""
    shape, dtype, name = x.shape, x.dtype, x.name
    return tf.TensorSpec(shape, dtype=dtype, name=name)
def _print_signature(fn, name):
concrete_fn = fn._list_all_concrete_functions()[0]
pprinted_signature = concrete_fn.pretty_printed_signature(verbose=True)
lines = pprinted_signature.split("\n")
lines = [f"* Endpoint '{name}'"] + lines[1:]
endpoint = "\n".join(lines)
return endpoint
def _list_variables_used_by_fns(fns):
trainable_variables = []
non_trainable_variables = []
trainable_variables_ids = set()
non_trainable_variables_ids = set()
for fn in fns:
if hasattr(fn, "concrete_functions"):
concrete_functions = fn.concrete_functions
elif hasattr(fn, "get_concrete_function"):
concrete_functions = [fn.get_concrete_function()]
else:
concrete_functions = [fn]
for concrete_fn in concrete_functions:
for v in concrete_fn.trainable_variables:
if id(v) not in trainable_variables_ids:
trainable_variables.append(v)
trainable_variables_ids.add(id(v))
for v in concrete_fn.variables:
if (
id(v) not in trainable_variables_ids
and id(v) not in non_trainable_variables_ids
):
non_trainable_variables.append(v)
non_trainable_variables_ids.add(id(v))
return trainable_variables, non_trainable_variables