Skip to content

Commit

Permalink
[tflitefile_tools] Introduce tflite_parser (Samsung#8131)
Browse files Browse the repository at this point in the history
Let's introduce tflite_parser. Now model_parser doesn't have dependency
directly on tflite.

Signed-off-by: Yongseop Kim <[email protected]>
  • Loading branch information
YongseopKim authored Dec 9, 2021
1 parent d281975 commit 6e61a6b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 25 deletions.
32 changes: 7 additions & 25 deletions tools/tflitefile_tool/parser/model_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tflite.Model
import tflite.SubGraph
from ir import graph_stats
from .operator_parser import OperatorParser
from .tflite_parser import TFLiteParser
from printer.subgraph_printer import SubgraphPrinter
from printer.graph_stats_printer import PrintGraphStats
from saver.model_saver import ModelSaver


# TODO: Rename it as ModelParser
class TFLiteModelFileParser(object):
def __init__(self, option):
self.option = option
Expand All @@ -45,33 +43,17 @@ def SaveModel(self, model_name, op_parser):
saver.SaveConfigInfo(self.option.save_prefix)

def main(self):
# Generate Model: top structure of tflite model file
buf = self.option.model_file.read()
buf = bytearray(buf)
tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)

stats = graph_stats.GraphStats()
# Model file can have many models
for subgraph_index in range(tf_model.SubgraphsLength()):
tf_subgraph = tf_model.Subgraphs(subgraph_index)
model_name = "#{0} {1}".format(subgraph_index, tf_subgraph.Name())
# 0th subgraph is main subgraph
if (subgraph_index == 0):
model_name += " (MAIN)"

# Parse Operators
op_parser = OperatorParser(tf_model, tf_subgraph)
op_parser.Parse()

stats += graph_stats.CalcGraphStats(op_parser)
parser = TFLiteParser(self.option.model_file)
parser.parse()

for model_name, op_parser in parser.subg_list:
if self.option.save == False:
# print all of operators or requested objects
self.PrintModel(model_name, op_parser)
else:
# save all of operators in this model
self.SaveModel(model_name, op_parser)

print('==== Model Stats ({} Subgraphs) ===='.format(tf_model.SubgraphsLength()))
print('==== Model Stats ({} Subgraphs) ===='.format(len(parser.subg_list)))
print('')
PrintGraphStats(stats, self.option.print_level)
PrintGraphStats(parser.stats, self.option.print_level)
52 changes: 52 additions & 0 deletions tools/tflitefile_tool/parser/tflite_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env python

# Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tflite.Model
import tflite.SubGraph
from ir import graph_stats
from .operator_parser import OperatorParser


class TFLiteParser(object):
def __init__(self, tflite_file):
self.tflite_file = tflite_file
self.subg_list = list()

def parse(self):
# Generate Model: top structure of tflite model file
buf = self.tflite_file.read()
buf = bytearray(buf)
tf_model = tflite.Model.Model.GetRootAsModel(buf, 0)

stats = graph_stats.GraphStats()
# Model file can have many models
for subgraph_index in range(tf_model.SubgraphsLength()):
tf_subgraph = tf_model.Subgraphs(subgraph_index)
model_name = "#{0} {1}".format(subgraph_index, tf_subgraph.Name())
# 0th subgraph is main subgraph
if (subgraph_index == 0):
model_name += " (MAIN)"

# Parse Operators
op_parser = OperatorParser(tf_model, tf_subgraph)
op_parser.Parse()

stats += graph_stats.CalcGraphStats(op_parser)

subg = (model_name, op_parser)
self.subg_list.append(subg)

self.stats = stats

0 comments on commit 6e61a6b

Please sign in to comment.