Metadata writers for tokenizers and InputTextTensor
PiperOrigin-RevId: 363753312
lu-wang-g authored and tflite-support-robot committed Mar 18, 2021
1 parent 676dce1 commit 3613853
Showing 7 changed files with 412 additions and 10 deletions.
@@ -24,6 +24,10 @@
_MIN_UINT8 = 0
_MAX_UINT8 = 255

# Default description for vocabulary files.
_VOCAB_FILE_DESCRIPTION = ("Vocabulary file to convert natural language "
"words to embedding vectors.")


class GeneralMd:
"""A container for common metadata information of a model.
@@ -79,7 +83,7 @@ class AssociatedFileMd:

def __init__(
self,
file_path: Optional[str] = None,
file_path: str,
description: Optional[str] = None,
file_type: Optional[_metadata_fb.AssociatedFileType] = _metadata_fb
.AssociatedFileType.UNKNOWN,
@@ -110,9 +114,7 @@ class LabelFileMd(AssociatedFileMd):
"recognize.")
_FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS

def __init__(self,
file_path: Optional[str] = None,
locale: Optional[str] = None):
def __init__(self, file_path: str, locale: Optional[str] = None):
"""Creates a LabelFileMd object.
Args:
@@ -125,6 +127,125 @@ def __init__(self,
locale)


class RegexTokenizerMd:
"""A container for the Regex tokenizer [1] metadata information.
[1]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L459
"""

def __init__(self, delim_regex_pattern: str, vocab_file_path: str):
"""Initializes a RegexTokenizerMd object.
Args:
delim_regex_pattern: the regular expression to segment strings and create
tokens.
vocab_file_path: path to the vocabulary file.
"""
self._delim_regex_pattern = delim_regex_pattern
self._vocab_file_path = vocab_file_path

def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the Bert tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the Bert tokenizer metadata.
"""
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = _VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY

# Create the RegexTokenizer.
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = (
_metadata_fb.ProcessUnitOptions.RegexTokenizerOptions)
tokenizer.options = _metadata_fb.RegexTokenizerOptionsT()
tokenizer.options.delimRegexPattern = self._delim_regex_pattern
tokenizer.options.vocabFile = [vocab]
return tokenizer
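
# A minimal usage sketch of RegexTokenizerMd (not part of this commit); the
# vocabulary path and pattern are illustrative:
#
#   tokenizer_md = RegexTokenizerMd(
#       delim_regex_pattern=r"[^\w\']+", vocab_file_path="vocab.txt")
#   process_unit = tokenizer_md.create_metadata()
#   # process_unit.optionsType is ProcessUnitOptions.RegexTokenizerOptions and
#   # process_unit.options.vocabFile[0].name == "vocab.txt".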


class BertTokenizerMd:
"""A container for the Bert tokenizer [1] metadata information.
[1]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
"""

def __init__(self, vocab_file_path: str):
"""Initializes a BertTokenizerMd object.
Args:
vocab_file_path: path to the vocabulary file.
"""
self._vocab_file_path = vocab_file_path

def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the Bert tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the Bert tokenizer metadata.
"""
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = _VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
tokenizer.options = _metadata_fb.BertTokenizerOptionsT()
tokenizer.options.vocabFile = [vocab]
return tokenizer
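
# A minimal usage sketch of BertTokenizerMd (not part of this commit),
# assuming an illustrative vocabulary path:
#
#   tokenizer_md = BertTokenizerMd(vocab_file_path="vocab.txt")
#   process_unit = tokenizer_md.create_metadata()
#   # process_unit.optionsType is ProcessUnitOptions.BertTokenizerOptions and
#   # the vocabulary is attached as process_unit.options.vocabFile[0].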


class SentencePieceTokenizerMd:
"""A container for the sentence piece tokenizer [1] metadata information.
[1]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
"""

_SP_MODEL_DESCRIPTION = "The sentence piece model file."
_SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + (
" This file is optional during tokenization, while the sentence piece "
"model is mandatory.")

def __init__(self,
sentence_piece_model_path: str,
vocab_file_path: Optional[str] = None):
"""Initializes a SentencePieceTokenizerMd object.
Args:
sentence_piece_model_path: path to the sentence piece model file.
vocab_file_path: path to the vocabulary file.
"""
self._sentence_piece_model_path = sentence_piece_model_path
self._vocab_file_path = vocab_file_path

def create_metadata(self) -> _metadata_fb.ProcessUnitT:
"""Creates the sentence piece tokenizer metadata based on the information.
Returns:
A Flatbuffers Python object of the sentence piece tokenizer metadata.
"""
tokenizer = _metadata_fb.ProcessUnitT()
tokenizer.optionsType = (
_metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
tokenizer.options = _metadata_fb.SentencePieceTokenizerOptionsT()

sp_model = _metadata_fb.AssociatedFileT()
sp_model.name = self._sentence_piece_model_path
sp_model.description = self._SP_MODEL_DESCRIPTION
tokenizer.options.sentencePieceModel = [sp_model]
if self._vocab_file_path:
vocab = _metadata_fb.AssociatedFileT()
vocab.name = self._vocab_file_path
vocab.description = self._SP_VOCAB_FILE_DESCRIPTION
vocab.type = _metadata_fb.AssociatedFileType.VOCABULARY
tokenizer.options.vocabFile = [vocab]
return tokenizer
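
# A minimal usage sketch of SentencePieceTokenizerMd (not part of this
# commit), assuming illustrative file paths; the vocabulary file is optional:
#
#   tokenizer_md = SentencePieceTokenizerMd(
#       sentence_piece_model_path="sp.model", vocab_file_path="vocab.txt")
#   process_unit = tokenizer_md.create_metadata()
#   # process_unit.options.sentencePieceModel[0].name == "sp.model" and
#   # process_unit.options.vocabFile[0].name == "vocab.txt".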


class TensorMd:
"""A container for common tensor metadata information.
@@ -189,7 +310,7 @@ def create_metadata(self) -> _metadata_fb.TensorMetadataT:


class InputImageTensorMd(TensorMd):
"""A container for input tensor metadata information.
"""A container for input image tensor metadata information.
Attributes:
norm_mean: the mean value used in tensor normalization [1].
@@ -284,6 +405,56 @@ def create_metadata(self) -> _metadata_fb.TensorMetadataT:
return tensor_metadata


class InputTextTensorMd(TensorMd):
"""A container for the input text tensor metadata information.
Attributes:
tokenizer_md: information of the tokenizer in the input text tensor, if any.
"""

def __init__(self,
name: Optional[str] = None,
description: Optional[str] = None,
tokenizer_md: Optional[RegexTokenizerMd] = None):
"""Initializes the instance of InputTextTensorMd.
Args:
name: name of the tensor.
description: description of what the tensor is.
tokenizer_md: information of the tokenizer in the input text tensor, if
any. Only `RegexTokenizer` [1] is currently supported. If the tokenizer
is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to
`bert_nl_classifier.MetadataWriter`.
[1]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475
[2]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
[3]:
https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
"""
super().__init__(name, description)
self.tokenizer_md = tokenizer_md

def create_metadata(self) -> _metadata_fb.TensorMetadataT:
"""Creates the input text metadata based on the information.
Returns:
A Flatbuffers Python object of the input text metadata.
Raises:
ValueError: if the type of tokenizer_md is unsupported.
"""
if not isinstance(self.tokenizer_md, (type(None), RegexTokenizerMd)):
raise ValueError(
"The type of tokenizer_options, {}, is unsupported".format(
type(self.tokenizer_md)))

tensor_metadata = super().create_metadata()
if self.tokenizer_md:
tensor_metadata.processUnits = [self.tokenizer_md.create_metadata()]
return tensor_metadata
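
# A minimal usage sketch of InputTextTensorMd with a regex tokenizer (not part
# of this commit), assuming illustrative names and paths:
#
#   tokenizer_md = RegexTokenizerMd(
#       delim_regex_pattern=r"[^\w\']+", vocab_file_path="vocab.txt")
#   tensor_md = InputTextTensorMd(
#       name="input_text", description="Text to be classified.",
#       tokenizer_md=tokenizer_md)
#   tensor_metadata = tensor_md.create_metadata()
#   # tensor_metadata.processUnits holds the regex tokenizer process unit.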


class ClassificationTensorMd(TensorMd):
"""A container for the classification tensor metadata information.
@@ -117,7 +117,7 @@ def test_create_metadata_should_succeed(self, content_type, golden_json):
tensor_metadata = tensor_md.create_metadata()

metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata(tensor_metadata))
_create_dummy_model_metadata_with_tensor(tensor_metadata))
expected_json = test_utils.load_file(golden_json, "r")
self.assertEqual(metadata_json, expected_json)

@@ -158,7 +158,7 @@ def test_create_metadata_should_succeed(self, tensor_type, golden_json):
tensor_metadata = tesnor_md.create_metadata()

metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata(tensor_metadata))
_create_dummy_model_metadata_with_tensor(tensor_metadata))
expected_json = test_utils.load_file(golden_json, "r")
self.assertEqual(metadata_json, expected_json)

@@ -173,6 +173,54 @@ def test_init_should_throw_exception_with_incompatible_mean_and_std(self):
"{} and {}".format(len(norm_mean), len(norm_std)), str(error.exception))


class InputTextTensorMdTest(tf.test.TestCase):

_NAME = "input text"
_DESCRIPTION = "The input string."
_VOCAB_FILE = "vocab.txt"
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_EXPECTED_TENSOR_JSON = "../testdata/input_text_tesnor_meta.json"
_EXPECTED_TENSOR_DEFAULT_JSON = "../testdata/input_text_tesnor_default_meta.json"

def test_create_metadata_should_succeed(self):
regex_tokenizer_md = metadata_info.RegexTokenizerMd(
self._DELIM_REGEX_PATTERN, self._VOCAB_FILE)

text_tensor_md = metadata_info.InputTextTensorMd(self._NAME,
self._DESCRIPTION,
regex_tokenizer_md)

metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_tensor(
text_tensor_md.create_metadata()))
expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r")
self.assertEqual(metadata_json, expected_json)

def test_create_metadata_by_default_should_succeed(self):
text_tensor_md = metadata_info.InputTextTensorMd()

metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_tensor(
text_tensor_md.create_metadata()))
expected_json = test_utils.load_file(self._EXPECTED_TENSOR_DEFAULT_JSON,
"r")
self.assertEqual(metadata_json, expected_json)

def test_create_metadata_throws_exception_with_unsupported_tokenizer(self):
invalid_tokenizer = metadata_info.BertTokenizerMd("vocab.txt")

with self.assertRaises(ValueError) as error:
tensor_md = metadata_info.InputTextTensorMd(
tokenizer_md=invalid_tokenizer)
tensor_md.create_metadata()

# TODO(b/175843689): f-strings cannot be used, because the Python version
# cannot be specified in the Kokoro bazel test.
self.assertEqual(
"The type of tokenizer_options, {}, is unsupported".format(
type(invalid_tokenzier)), str(error.exception))


class ClassificationTensorMdTest(tf.test.TestCase, parameterized.TestCase):

_NAME = "probability"
@@ -210,7 +258,7 @@ def test_create_metadata_should_succeed(self, tensor_type, golden_json):
tensor_metadata = tesnor_md.create_metadata()

metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata(tensor_metadata))
_create_dummy_model_metadata_with_tensor(tensor_metadata))
expected_json = test_utils.load_file(golden_json, "r")
self.assertEqual(metadata_json, expected_json)

@@ -235,12 +283,61 @@ def test_create_metadata_should_succeed(self):
tensor_metadata = tesnor_md.create_metadata()

metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata(tensor_metadata))
_create_dummy_model_metadata_with_tensor(tensor_metadata))
expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r")
self.assertEqual(metadata_json, expected_json)


def _create_dummy_model_metadata(
class RegexTokenizerMdTest(tf.test.TestCase):

_VOCAB_FILE = "vocab.txt"
_DELIM_REGEX_PATTERN = r"[^\w\']+"
_EXPECTED_TENSOR_JSON = "../testdata/regex_tokenizer_meta.json"

def test_create_metadata_should_succeed(self):
tokenizer_md = metadata_info.RegexTokenizerMd(self._DELIM_REGEX_PATTERN,
self._VOCAB_FILE)
tokenizer_metadata = tokenizer_md.create_metadata()

metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_unit(tokenizer_metadata))
expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r")
self.assertEqual(metadata_json, expected_json)


class BertTokenizerMdTest(tf.test.TestCase):

_VOCAB_FILE = "vocab.txt"
_EXPECTED_TENSOR_JSON = "../testdata/bert_tokenizer_meta.json"

def test_create_metadata_should_succeed(self):
tokenizer_md = metadata_info.BertTokenizerMd(self._VOCAB_FILE)
tokenizer_metadata = tokenizer_md.create_metadata()

metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_unit(tokenizer_metadata))
expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r")
self.assertEqual(metadata_json, expected_json)


class SentencePieceTokenizerMdTest(tf.test.TestCase):

_VOCAB_FILE = "vocab.txt"
_SP_MODEL = "sp.model"
_EXPECTED_TENSOR_JSON = "../testdata/sentence_piece_tokenizer_meta.json"

def test_create_metadata_should_succeed(self):
tokenizer_md = metadata_info.SentencePieceTokenizerMd(
self._SP_MODEL, self._VOCAB_FILE)
tokenizer_metadata = tokenizer_md.create_metadata()

metadata_json = _metadata.convert_to_json(
_create_dummy_model_metadata_with_process_unit(tokenizer_metadata))
expected_json = test_utils.load_file(self._EXPECTED_TENSOR_JSON, "r")
self.assertEqual(metadata_json, expected_json)


def _create_dummy_model_metadata_with_tensor(
tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes:
# Create a dummy model using the tensor metadata.
subgraph_metadata = _metadata_fb.SubGraphMetadataT()
@@ -256,5 +353,21 @@ def _create_dummy_model_metadata(
return bytes(builder.Output())


def _create_dummy_model_metadata_with_process_unit(
process_unit_metadata: _metadata_fb.ProcessUnitT) -> bytes:
# Create a dummy model using the process unit metadata.
subgraph_metadata = _metadata_fb.SubGraphMetadataT()
subgraph_metadata.inputProcessUnits = [process_unit_metadata]
model_metadata = _metadata_fb.ModelMetadataT()
model_metadata.subgraphMetadata = [subgraph_metadata]

# Create the Flatbuffers object and convert it to the json format.
builder = flatbuffers.Builder(0)
builder.Finish(
model_metadata.Pack(builder),
_metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
return bytes(builder.Output())


if __name__ == "__main__":
tf.test.main()