Skip to content

Commit

Permalink
Add method to export the "final" metadata string in MetadataWriter
Browse files Browse the repository at this point in the history
Extra fields could be filled by MetadataPopulator, such as min_parser_version. Added the method `get_populated_metadata_json` to return the final metadata string.

PiperOrigin-RevId: 406744227
  • Loading branch information
lu-wang-g authored and tflite-support-robot committed Nov 1, 2021
1 parent 3e1653d commit c1cde40
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(self,
self._model_buffer = model_buffer
self._metadata_buffer = metadata_buffer
self._associated_files = associated_files if associated_files else []
self._populated_model_buffer = None

@classmethod
def create_from_metadata_info(
Expand Down Expand Up @@ -156,19 +157,45 @@ def populate(self) -> bytearray:
Returns:
A new model buffer with the metadata and associated files.
"""
if self._populated_model_buffer:
return self._populated_model_buffer

populator = _metadata.MetadataPopulator.with_model_buffer(
self._model_buffer)
if self._model_buffer is not None:
populator.load_metadata_buffer(self._metadata_buffer)
if self._associated_files:
populator.load_associated_files(self._associated_files)
populator.populate()
return populator.get_model_buffer()
self._populated_model_buffer = populator.get_model_buffer()
return self._populated_model_buffer

def get_metadata_json(self) -> str:
"""Gets the generated metadata string in JSON format."""
"""Gets the generated JSON metadata string before populated into model.
This method returns the metadata buffer before populated into the model.
More fields could be filled by MetadataPopulator, such as
min_parser_version. Use get_populated_metadata_json() if you want to get the
final metadata string.
Returns:
The generated JSON metadata string before populated into model.
"""
return _metadata.convert_to_json(bytes(self._metadata_buffer))

def get_populated_metadata_json(self) -> str:
"""Gets the generated JSON metadata string after populated into model.
More fields could be filled by MetadataPopulator, such as
min_parser_version. Use get_metadata_json() if you want to get the
original metadata string.
Returns:
The generated JSON metadata string after populated into model.
"""
displayer = _metadata.MetadataDisplayer.with_model_buffer(self.populate())
return displayer.get_metadata_json()


# If tensor name in metadata is empty, default to the tensor name saved in
# the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,19 @@ def test_get_metadata_json_should_succeed(self):
expected_json = test_utils.load_file(_EXPECTED_DUMMY_NO_VERSION_JSON, "r")
self.assertEqual(metadata_json, expected_json)

def test_get_populated_metadata_json_should_succeed(self):
model_buffer = test_utils.load_file(_MODEL)
model_metadata, input_metadata, output_metadata = (
self._create_dummy_metadata())

writer = metadata_writer.MetadataWriter.create_from_metadata(
model_buffer, model_metadata, [input_metadata], [output_metadata],
[_LABEL_FILE])
metadata_json = writer.get_populated_metadata_json()

expected_json = test_utils.load_file(_EXPECTED_DUMMY_JSON, "r")
self.assertEqual(metadata_json, expected_json)

def _assert_correct_metadata(self,
model_with_metadata,
expected_json_file,
Expand Down

0 comments on commit c1cde40

Please sign in to comment.