Skip to content

Commit

Permalink
Minor updates to flatbuffer utilities
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 307732210
Change-Id: I6b97ccdff0323dbf0fd20fc20d6bc7e49d5e08ad
  • Loading branch information
MeghnaNatraj authored and tensorflower-gardener committed Apr 22, 2020
1 parent 4b70075 commit 47ea7ee
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
6 changes: 3 additions & 3 deletions tensorflow/lite/tools/flatbuffer_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class WriteReadModelTest(test_util.TensorFlowTestCase):
def testWriteReadModel(self):
# 1. SETUP
# Define the initial model
initial_model = test_utils.build_mock_model_python_object()
initial_model = test_utils.build_mock_model()
# Define temporary files
tmp_dir = self.get_temp_dir()
model_filename = os.path.join(tmp_dir, 'model.tflite')
Expand Down Expand Up @@ -76,7 +76,7 @@ class StripStringsTest(test_util.TensorFlowTestCase):
def testStripStrings(self):
# 1. SETUP
# Define the initial model
initial_model = test_utils.build_mock_model_python_object()
initial_model = test_utils.build_mock_model()
final_model = copy.deepcopy(initial_model)

# 2. INVOKE
Expand Down Expand Up @@ -121,7 +121,7 @@ class RandomizeWeightsTest(test_util.TensorFlowTestCase):
def testRandomizeWeights(self):
# 1. SETUP
# Define the initial model
initial_model = test_utils.build_mock_model_python_object()
initial_model = test_utils.build_mock_model()
final_model = copy.deepcopy(initial_model)

# 2. INVOKE
Expand Down
20 changes: 12 additions & 8 deletions tensorflow/lite/tools/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Utility functions that support testing.
All functions that can be commonly used by various tests are in this file.
All functions that can be commonly used by various tests.
"""

from __future__ import absolute_import
Expand All @@ -25,7 +25,7 @@
from tensorflow.lite.python import schema_py_generated as schema_fb


def build_mock_model():
def build_mock_flatbuffer_model():
"""Creates a flatbuffer containing an example model."""
builder = flatbuffers.Builder(1024)

Expand Down Expand Up @@ -205,10 +205,14 @@ def build_mock_model():
return model


def build_mock_model_python_object():
"""Creates a python flatbuffer object containing an example model."""
model_mock = build_mock_model()
model_obj = schema_fb.Model.GetRootAsModel(model_mock, 0)
model = schema_fb.ModelT.InitFromObj(model_obj)

def load_model_from_flatbuffer(flatbuffer_model):
"""Loads a model as a python object from a flatbuffer model."""
model = schema_fb.Model.GetRootAsModel(flatbuffer_model, 0)
model = schema_fb.ModelT.InitFromObj(model)
return model


def build_mock_model():
"""Creates an object containing an example model."""
model = build_mock_flatbuffer_model()
return load_model_from_flatbuffer(model)
9 changes: 4 additions & 5 deletions tensorflow/lite/tools/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def testBuiltinCodeToName(self):
self.assertEqual('HASHTABLE_LOOKUP', visualize.BuiltinCodeToName(10))

def testFlatbufferToDict(self):
model_data = test_utils.build_mock_model()
model_dict = visualize.CreateDictFromFlatbuffer(model_data)
model = test_utils.build_mock_flatbuffer_model()
model_dict = visualize.CreateDictFromFlatbuffer(model)
self.assertEqual(0, model_dict['version'])
self.assertEqual(1, len(model_dict['subgraphs']))
self.assertEqual(1, len(model_dict['operator_codes']))
Expand All @@ -45,12 +45,11 @@ def testFlatbufferToDict(self):
self.assertEqual(0, model_dict['subgraphs'][0]['tensors'][0]['buffer'])

def testVisualize(self):
model_data = test_utils.build_mock_model()

model = test_utils.build_mock_flatbuffer_model()
tmp_dir = self.get_temp_dir()
model_filename = os.path.join(tmp_dir, 'model.tflite')
with open(model_filename, 'wb') as model_file:
model_file.write(model_data)
model_file.write(model)
html_filename = os.path.join(tmp_dir, 'visualization.html')

visualize.CreateHtmlFile(model_filename, html_filename)
Expand Down

0 comments on commit 47ea7ee

Please sign in to comment.