ONE - On-device Neural Engine
Loading...
Searching...
No Matches
tf2tfliteV2 Namespace Reference

Functions

 wrap_frozen_graph (graph_def, inputs, outputs)
 
 _get_parser ()
 
 _check_flags (flags)
 
 _parse_array (arrays, type_fn=str)
 
 _v2_convert (flags)
 
 _apply_verbosity (verbosity)
 
 _convert (flags)
 
 main ()
 

Variables

 prog_name = os.path.basename(__file__)
 
 file
 

Function Documentation

◆ _apply_verbosity()

tf2tfliteV2._apply_verbosity (   verbosity)
protected

Definition at line 199 of file tf2tfliteV2.py.

199def _apply_verbosity(verbosity):
200 # NOTE
201 # TF_CPP_MIN_LOG_LEVEL
202 # 0 : INFO + WARNING + ERROR + FATAL
203 # 1 : WARNING + ERROR + FATAL
204 # 2 : ERROR + FATAL
205 # 3 : FATAL
206 #
207 # TODO Find better way to suppress trackback on error
208 # tracebacklimit
209 # The default is 1000.
210 # When set to 0 or less, all traceback information is suppressed
211 if verbosity:
212 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
213 sys.tracebacklimit = 1000
214 else:
215 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
216 sys.tracebacklimit = 0
217
218

Referenced by _convert().

◆ _check_flags()

tf2tfliteV2._check_flags (   flags)
protected
Checks the parsed flags to ensure they are valid.

Definition at line 119 of file tf2tfliteV2.py.

119def _check_flags(flags):
120 """
121 Checks the parsed flags to ensure they are valid.
122 """
123 if flags.v1:
124 print("Warning: option --v1 is deprecated", file=sys.stderr)
125
126 if flags.v2:
127 print("Warning: option --v2 is deprecated", file=sys.stderr)
128
129 if flags.input_shapes:
130 if not flags.input_arrays:
131 raise ValueError("--input_shapes must be used with --input_arrays")
132 if flags.input_shapes.count(":") != flags.input_arrays.count(","):
133 raise ValueError("--input_shapes and --input_arrays must have the same "
134 "number of items")
135
136

Referenced by main().

◆ _convert()

tf2tfliteV2._convert (   flags)
protected

Definition at line 219 of file tf2tfliteV2.py.

def _convert(flags):
    """
    Applies the verbosity requested by the flags, then runs the conversion.
    """
    verbose = flags.verbose
    _apply_verbosity(verbose)

    _v2_convert(flags)
223
224
225"""
226Input frozen graph must be from TensorFlow 1.13.1
227"""
228
229

References _apply_verbosity(), and _v2_convert().

Referenced by main().

◆ _get_parser()

tf2tfliteV2._get_parser ( )
protected
Returns an ArgumentParser for TensorFlow Lite Converter.

Definition at line 38 of file tf2tfliteV2.py.

38def _get_parser():
39 """
40 Returns an ArgumentParser for TensorFlow Lite Converter.
41 """
42 parser = argparse.ArgumentParser(
43 description=("Command line tool to run TensorFlow Lite Converter."))
44
45 # Verbose
46 parser.add_argument("-V",
47 "--verbose",
48 action="store_true",
49 help="output additional information to stdout or stderr")
50
51 # TODO deprecate options v1 and v2
52 # TODO add "deprecated" for v1, v2 when python >= 3.13
53 # Converter version.
54 converter_version = parser.add_mutually_exclusive_group(required=False)
55 converter_version.add_argument("--v1",
56 action="store_true",
57 help="Use TensorFlow Lite Converter 1.x")
58 converter_version.add_argument("--v2",
59 action="store_true",
60 help="Use TensorFlow Lite Converter 2.x")
61
62 # Input model format
63 model_format_arg = parser.add_mutually_exclusive_group()
64 model_format_arg.add_argument("--graph_def",
65 action="store_const",
66 dest="model_format",
67 const="graph_def",
68 help="Use graph def file(default)")
69 model_format_arg.add_argument("--saved_model",
70 action="store_const",
71 dest="model_format",
72 const="saved_model",
73 help="Use saved model")
74 model_format_arg.add_argument("--keras_model",
75 action="store_const",
76 dest="model_format",
77 const="keras_model",
78 help="Use keras model")
79
80 # Input and output path.
81 parser.add_argument("-i",
82 "--input_path",
83 type=str,
84 help="Full filepath of the input file.",
85 required=True)
86 parser.add_argument("-o",
87 "--output_path",
88 type=str,
89 help="Full filepath of the output file.",
90 required=True)
91
92 # Input and output arrays.
93 parser.add_argument("-I",
94 "--input_arrays",
95 type=str,
96 help="Names of the input arrays, comma-separated.")
97 parser.add_argument(
98 "-s",
99 "--input_shapes",
100 type=str,
101 help=
102 "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
103 )
104 parser.add_argument("-O",
105 "--output_arrays",
106 type=str,
107 help="Names of the output arrays, comma-separated.")
108
109 # experimental options
110 parser.add_argument("--experimental_disable_batchmatmul_unfold",
111 action="store_true",
112 help="Experimental disable BatchMatMul unfold")
113
114 # Set default value
115 parser.set_defaults(model_format="graph_def")
116 return parser
117
118

Referenced by main().

◆ _parse_array()

tf2tfliteV2._parse_array (   arrays,
  type_fn = str 
)
protected

Definition at line 137 of file tf2tfliteV2.py.

137def _parse_array(arrays, type_fn=str):
138 return list(map(type_fn, arrays.split(",")))
139
140

Referenced by _v2_convert().

◆ _v2_convert()

tf2tfliteV2._v2_convert (   flags)
protected

Definition at line 141 of file tf2tfliteV2.py.

def _v2_convert(flags):
    """
    Converts the model at flags.input_path to a TensorFlow Lite model and
    writes the result to flags.output_path.

    The converter is built according to flags.model_format:
      - "graph_def": the file is parsed as a binary GraphDef, falling back to
        text format, then wrapped into a concrete function pruned to the
        requested input and output arrays.
      - "saved_model": the converter is created from the saved model path.
      - "keras_model": the model is loaded with Keras and then converted.

    Raises:
        ValueError: for "graph_def" when --input_arrays or --output_arrays is
            missing, or when --input_shapes and --input_arrays differ in length.
        IOError: when the graph def file can be parsed neither as binary nor
            as text.
    """

    def _with_default_index(name):
        # Names given without an explicit ":<index>" suffix get ":0" appended.
        return name + ":0" if len(name.split(":")) == 1 else name

    if flags.model_format == "graph_def":
        if not flags.input_arrays:
            raise ValueError("--input_arrays must be provided")
        if not flags.output_arrays:
            raise ValueError("--output_arrays must be provided")
        input_shapes = []
        if flags.input_shapes:
            input_shapes = [
                _parse_array(shape, type_fn=int)
                for shape in flags.input_shapes.split(":")
            ]
            if len(input_shapes) != len(_parse_array(flags.input_arrays)):
                raise ValueError(
                    "--input_shapes and --input_arrays must have the same length")

        # Read through a context manager so the handle is closed right away
        # instead of whenever the file object happens to be collected.
        with open(flags.input_path, 'rb') as model_file:
            file_content = model_file.read()

        # Created outside the try block so it is always bound when the text
        # format fallback below needs it.
        graph_def = tf.compat.v1.GraphDef()
        try:
            graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
            try:
                _text_format.Merge(file_content, graph_def)
            except (_text_format.ParseError, DecodeError) as err:
                raise IOError("Unable to parse input file '{}'.".format(
                    flags.input_path)) from err

        wrap_func = wrap_frozen_graph(
            graph_def,
            inputs=[
                _with_default_index(name)
                for name in _parse_array(flags.input_arrays)
            ],
            outputs=[
                _with_default_index(name)
                for name in _parse_array(flags.output_arrays)
            ])
        for index, shape in enumerate(input_shapes):
            wrap_func.inputs[index].set_shape(shape)
        converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])

    if flags.model_format == "saved_model":
        converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)

    if flags.model_format == "keras_model":
        keras_model = tf.keras.models.load_model(flags.input_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    if flags.experimental_disable_batchmatmul_unfold:
        converter._experimental_disable_batchmatmul_unfold = True

    converter.allow_custom_ops = True
    converter.experimental_new_converter = True

    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    # The context manager guarantees the output is flushed and closed even if
    # the write fails part-way.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
197
198

References _parse_array(), and wrap_frozen_graph().

Referenced by _convert().

◆ main()

tf2tfliteV2.main ( void  )

Definition at line 230 of file tf2tfliteV2.py.

def main():
    """
    Parses the command line, validates the flags and runs the conversion.

    Arguments the parser does not recognise are collected separately by
    parse_known_args and are not used here.
    """
    # Parse argument.
    cli = _get_parser()
    flags, _unrecognized = cli.parse_known_args(args=sys.argv[1:])

    # Check if the flags are valid.
    _check_flags(flags)

    # Convert
    _convert(flags)
240
241
int main(void)

References _check_flags(), _convert(), _get_parser(), and main().

Referenced by main().

◆ wrap_frozen_graph()

tf2tfliteV2.wrap_frozen_graph (   graph_def,
  inputs,
  outputs 
)

Definition at line 27 of file tf2tfliteV2.py.

def wrap_frozen_graph(graph_def, inputs, outputs):
    """
    Imports graph_def into a wrapped function and prunes it to the given
    inputs and outputs.

    The entries of inputs and outputs are resolved to graph elements of the
    imported graph before pruning; the pruned function is returned.
    """

    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    # Resolver that maps an entry to the matching element of the imported graph.
    to_graph_element = wrapped_import.graph.as_graph_element
    feeds = tf.nest.map_structure(to_graph_element, inputs)
    fetches = tf.nest.map_structure(to_graph_element, outputs)
    return wrapped_import.prune(feeds, fetches)
36
37

Referenced by _v2_convert().

Variable Documentation

◆ file

tf2tfliteV2.file

Definition at line 247 of file tf2tfliteV2.py.

◆ prog_name

tf2tfliteV2.prog_name = os.path.basename(__file__)

Definition at line 246 of file tf2tfliteV2.py.