Skip to content

Commit

Permalink
ServingConfig.extra_trackable_resources is now supported in the Orb…
Browse files Browse the repository at this point in the history
…ax Model path when `support_tf_resources` is `True`.

PiperOrigin-RevId: 699239199
  • Loading branch information
wangpengmit authored and Orbax Authors committed Nov 22, 2024
1 parent 922d408 commit 1cb2786
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 0 deletions.
8 changes: 8 additions & 0 deletions export/orbax/export/modules/obm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,14 @@ def _convert_jax_functions_to_obm_functions(
"""Converts the JAX functions to OrbaxModel functions."""
if serving_config.input_signature is None:
raise ValueError('serving_config.input_signature is required.')
if (
not support_tf_resources
and serving_config.extra_trackable_resources is not None
):
raise ValueError(
'serving_config.extra_trackable_resources can only be set when'
' support_tf_resources is True.'
)

def _maybe_set_orbax_checkpoint_path(self, jax2obm_kwargs):
if constants.CHECKPOINT_PATH not in jax2obm_kwargs:
Expand Down
1 change: 1 addition & 0 deletions export/orbax/export/modules/obm_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from orbax.export import constants
from orbax.export import serving_config as osc
from orbax.export.modules import obm_module
import tensorflow as tf


class ObmModuleTest(parameterized.TestCase):
Expand Down

0 comments on commit 1cb2786

Please sign in to comment.