-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtflite_model_loader.py
82 lines (69 loc) · 3.43 KB
/
tflite_model_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import logging
import tvm
from typing import Dict, List
from pathlib import Path
from tvm import relay
from tvm.error import OpError
from .convert_layout_mixin import ConvertLayoutMixin
from .abstract_model_loader import AbstractModelLoader
from .helpers.tflite_model_helper import TFLiteModelHelper
from ._base import GraphIR
logger = logging.getLogger(__name__)
class TFLiteModelLoader(AbstractModelLoader, ConvertLayoutMixin):
def __init__(self, model_artifacts: List[str], data_shape: Dict[str, List[int]]) -> None:
super(TFLiteModelLoader, self).__init__(model_artifacts, data_shape)
self.__model_file = None
self.__model = None
self.__data_types = None
@property
def ir_format(self) -> GraphIR:
return GraphIR.relay
@property
def model_objects(self) -> (tvm.IRModule, tvm.nd.NDArray):
return self._relay_module_object, self._params
@property
def aux_files(self) -> List[Path]:
return []
def __extract_model_file_from_model_artifacts(self) -> Path:
model_files = self._get_files_from_model_artifacts_with_extensions([".tflite"])
if len(model_files) > 1:
raise RuntimeError('InputConfiguration: Exactly one .tflite file is allowed for TFLite models.')
elif len(model_files) == 1:
self.__model_file = model_files[0]
else:
raise RuntimeError('InputConfiguration: No valid TFLite model found in input file. Please make sure the framework you selected is correct.')
def __extract_model_data_types_and_metadata_from_model_file(self) -> None:
tflite_model_helper = TFLiteModelHelper(self.__model_file.as_posix())
try:
tflite_model_helper.load_model()
except Exception as ex:
logger.exception("Failed to load TFLite model.%s" % repr(ex))
raise RuntimeError("InputConfiguration: Framework cannot load TFLite model: {}".format(ex))
else:
try:
tflite_model_helper.extract_input_and_output_tensors(user_shape_dict=self.data_shape)
self.__model = tflite_model_helper.tflite_model
self.__data_types = tflite_model_helper.input_dtypes_dict
self._metadata = tflite_model_helper.get_metadata()
except ImportError:
raise
except Exception as ex:
logging.exception("Unable to infer tensor data type for all inputs/outputs %s" % repr(ex))
raise RuntimeError("InputConfiguration: Framework cannot load TFLite model. Unable to infer tensor data type for all inputs/outputs: {}".format(ex))
def load_model(self) -> None:
self.__extract_model_file_from_model_artifacts()
self.__extract_model_data_types_and_metadata_from_model_file()
try:
self._relay_module_object, self._params = relay.frontend.from_tflite(
self.__model,
shape_dict=self.data_shape,
dtype_dict=self.__data_types
)
self._relay_module_object = self.convert_layout(self._relay_module_object)
self.update_missing_metadata()
except OpError:
raise
except Exception as e:
logger.exception("Failed to convert TFLite model. %s" % repr(e))
msg = "InputConfiguration: TVM cannot convert TFLite model. Please make sure the framework you selected is correct. {}".format(e)
raise RuntimeError(msg)