ONE - On-device Neural Engine
Loading...
Searching...
No Matches
tf2tfliteV2 Namespace Reference

Functions

 wrap_frozen_graph (graph_def, inputs, outputs)
 
 _get_parser ()
 
 _check_flags (flags)
 
 _parse_array (arrays, type_fn=str)
 
 _v1_convert (flags)
 
 _v2_convert (flags)
 
 _apply_verbosity (verbosity)
 
 _convert (flags)
 
 main ()
 

Variables

 prog_name = os.path.basename(__file__)
 
 file
 

Function Documentation

◆ _apply_verbosity()

tf2tfliteV2._apply_verbosity (   verbosity)
protected

Definition at line 242 of file tf2tfliteV2.py.

def _apply_verbosity(verbosity):
    """
    Configures TensorFlow C++ log filtering and Python traceback output.

    A truthy `verbosity` sets TF_CPP_MIN_LOG_LEVEL to '0' and
    sys.tracebacklimit to 1000; otherwise they are set to '2' and 0.
    """
    # NOTE
    # TF_CPP_MIN_LOG_LEVEL
    #   0 : INFO + WARNING + ERROR + FATAL
    #   1 : WARNING + ERROR + FATAL
    #   2 : ERROR + FATAL
    #   3 : FATAL
    #
    # TODO Find better way to suppress traceback on error
    # tracebacklimit
    #   The default is 1000.
    #   When set to 0 or less, all traceback information is suppressed
    log_level, traceback_limit = ('0', 1000) if verbosity else ('2', 0)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = log_level
    sys.tracebacklimit = traceback_limit
260
261

Referenced by _convert().

◆ _check_flags()

tf2tfliteV2._check_flags (   flags)
protected
Checks the parsed flags to ensure they are valid.

Definition at line 117 of file tf2tfliteV2.py.

def _check_flags(flags):
    """
    Validates the parsed command line flags.

    Raises ValueError when --v2 is requested but the imported TensorFlow
    version string does not start with "2.", when --input_shapes is given
    without --input_arrays, or when the number of ":" separators in
    --input_shapes differs from the number of "," separators in
    --input_arrays.
    """
    if flags.v1:
        # Placeholder: names of v2-only options used together with --v1.
        invalid = ""
        # To be filled

        if invalid:
            raise ValueError(invalid + " options must be used with v2")

    if flags.v2:
        if not tf.__version__.startswith("2."):
            raise ValueError(
                "Imported TensorFlow should have version >= 2.0 but you have " +
                tf.__version__)

        # Placeholder: names of v1-only options used together with --v2.
        invalid = ""
        # To be filled

        if invalid:
            raise ValueError(invalid + " options must be used with v1")

    if not flags.input_shapes:
        return

    if not flags.input_arrays:
        raise ValueError("--input_shapes must be used with --input_arrays")

    # N items are joined by N-1 separators, so comparing separator counts
    # compares the item counts of the two options.
    shape_separators = flags.input_shapes.count(":")
    array_separators = flags.input_arrays.count(",")
    if shape_separators != array_separators:
        raise ValueError("--input_shapes and --input_arrays must have the same "
                         "number of items")
146
147

Referenced by main().

◆ _convert()

tf2tfliteV2._convert (   flags)
protected

Definition at line 262 of file tf2tfliteV2.py.

def _convert(flags):
    """
    Applies the verbosity setting, then runs the v1 converter when
    flags.v1 is set and the v2 converter otherwise.
    """
    _apply_verbosity(flags.verbose)

    run_converter = _v1_convert if flags.v1 else _v2_convert
    run_converter(flags)
269
270
271"""
272Input frozen graph must be from TensorFlow 1.13.1
273"""
274
275

References _apply_verbosity(), _v1_convert(), and _v2_convert().

Referenced by main().

◆ _get_parser()

tf2tfliteV2._get_parser ( )
protected
Returns an ArgumentParser for TensorFlow Lite Converter.

Definition at line 38 of file tf2tfliteV2.py.

def _get_parser():
    """
    Builds the ArgumentParser for the TensorFlow Lite Converter tool.

    Exactly one of --v1 / --v2 is required, the model format options are
    mutually exclusive and default to "graph_def", and both the input and
    the output path are required.
    """
    parser = argparse.ArgumentParser(
        description=("Command line tool to run TensorFlow Lite Converter."))

    # Verbose
    parser.add_argument(
        "-V",
        "--verbose",
        action="store_true",
        help="output additional information to stdout or stderr")

    # Converter version.
    converter_version = parser.add_mutually_exclusive_group(required=True)
    for version_flag, version_help in (
        ("--v1", "Use TensorFlow Lite Converter 1.x"),
        ("--v2", "Use TensorFlow Lite Converter 2.x"),
    ):
        converter_version.add_argument(version_flag,
                                       action="store_true",
                                       help=version_help)

    # Input model format: every option stores its own constant into the
    # shared "model_format" destination.
    model_format_arg = parser.add_mutually_exclusive_group()
    for format_flag, format_name, format_help in (
        ("--graph_def", "graph_def", "Use graph def file(default)"),
        ("--saved_model", "saved_model", "Use saved model"),
        ("--keras_model", "keras_model", "Use keras model"),
    ):
        model_format_arg.add_argument(format_flag,
                                      action="store_const",
                                      dest="model_format",
                                      const=format_name,
                                      help=format_help)

    # Input and output path.
    parser.add_argument(
        "-i",
        "--input_path",
        type=str,
        help="Full filepath of the input file.",
        required=True)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        help="Full filepath of the output file.",
        required=True)

    # Input and output arrays.
    parser.add_argument(
        "-I",
        "--input_arrays",
        type=str,
        help="Names of the input arrays, comma-separated.")
    parser.add_argument(
        "-s",
        "--input_shapes",
        type=str,
        help=
        "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
    )
    parser.add_argument(
        "-O",
        "--output_arrays",
        type=str,
        help="Names of the output arrays, comma-separated.")

    # experimental options
    parser.add_argument(
        "--experimental_disable_batchmatmul_unfold",
        action="store_true",
        help="Experimental disable BatchMatMul unfold")

    # Set default value
    parser.set_defaults(model_format="graph_def")
    return parser
115
116

Referenced by main().

◆ _parse_array()

tf2tfliteV2._parse_array (   arrays,
  type_fn = str 
)
protected

Definition at line 148 of file tf2tfliteV2.py.

def _parse_array(arrays, type_fn=str):
    """
    Splits a comma-separated string and converts every item with `type_fn`.

    Returns the converted items as a list.
    """
    return [type_fn(item) for item in arrays.split(",")]
150
151

Referenced by _v1_convert(), and _v2_convert().

◆ _v1_convert()

tf2tfliteV2._v1_convert (   flags)
protected

Definition at line 152 of file tf2tfliteV2.py.

def _v1_convert(flags):
    """
    Converts the input model with TensorFlow Lite Converter 1.x and writes
    the converted model to flags.output_path.

    For the "graph_def" format, --input_arrays and --output_arrays are
    required; when --input_shapes is given, each input array name is mapped
    to its shape in order.

    Raises ValueError when --input_arrays or --output_arrays is missing for
    the "graph_def" format.
    """
    if flags.model_format == "graph_def":
        if not flags.input_arrays:
            raise ValueError("--input_arrays must be provided")
        if not flags.output_arrays:
            raise ValueError("--output_arrays must be provided")

        # Parse the array names once and reuse them below.
        input_arrays = _parse_array(flags.input_arrays)
        output_arrays = _parse_array(flags.output_arrays)

        input_shapes = None
        if flags.input_shapes:
            input_shapes_list = [
                _parse_array(shape, type_fn=int)
                for shape in flags.input_shapes.split(":")
            ]
            input_shapes = dict(zip(input_arrays, input_shapes_list))

        converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
            flags.input_path, input_arrays, output_arrays, input_shapes)

    if flags.model_format == "saved_model":
        converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(flags.input_path)

    if flags.model_format == "keras_model":
        converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
            flags.input_path)

    converter.allow_custom_ops = True

    tflite_model = converter.convert()
    # Use a context manager so the output file is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
182
183

References _parse_array().

Referenced by _convert().

◆ _v2_convert()

tf2tfliteV2._v2_convert (   flags)
protected

Definition at line 184 of file tf2tfliteV2.py.

def _v2_convert(flags):
    """
    Converts the input model with TensorFlow Lite Converter 2.x and writes
    the converted model to flags.output_path.

    For the "graph_def" format the input file is first parsed as a binary
    GraphDef and, if that fails, as a text-format GraphDef. The graph is
    then wrapped into a function pruned to the requested input and output
    tensors; names without an explicit ":<index>" suffix get ":0" appended.

    Raises ValueError when --input_arrays or --output_arrays is missing for
    the "graph_def" format, or when --input_shapes and --input_arrays have
    different lengths. Raises IOError when the input file cannot be parsed
    as a GraphDef.
    """
    if flags.model_format == "graph_def":
        if not flags.input_arrays:
            raise ValueError("--input_arrays must be provided")
        if not flags.output_arrays:
            raise ValueError("--output_arrays must be provided")

        # Parse the array names once and reuse them below.
        input_arrays = _parse_array(flags.input_arrays)
        output_arrays = _parse_array(flags.output_arrays)

        input_shapes = []
        if flags.input_shapes:
            input_shapes = [
                _parse_array(shape, type_fn=int)
                for shape in flags.input_shapes.split(":")
            ]
            if len(input_shapes) != len(input_arrays):
                raise ValueError(
                    "--input_shapes and --input_arrays must have the same length")

        # Close the input file as soon as its content has been read.
        with open(flags.input_path, 'rb') as input_file:
            file_content = input_file.read()

        graph_def = tf.compat.v1.GraphDef()
        try:
            graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
            # Not a binary protobuf; fall back to the text format.
            try:
                _text_format.Merge(file_content, graph_def)
            except (_text_format.ParseError, DecodeError) as err:
                raise IOError("Unable to parse input file '{}'.".format(
                    flags.input_path)) from err

        wrap_func = wrap_frozen_graph(
            graph_def,
            inputs=[name if ":" in name else name + ":0" for name in input_arrays],
            outputs=[name if ":" in name else name + ":0" for name in output_arrays])
        for index, shape in enumerate(input_shapes):
            wrap_func.inputs[index].set_shape(shape)
        converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])

    if flags.model_format == "saved_model":
        converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)

    if flags.model_format == "keras_model":
        keras_model = tf.keras.models.load_model(flags.input_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    if flags.experimental_disable_batchmatmul_unfold:
        converter._experimental_disable_batchmatmul_unfold = True

    converter.allow_custom_ops = True
    converter.experimental_new_converter = True

    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    # Use a context manager so the output file is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
240
241

References _parse_array(), and wrap_frozen_graph().

Referenced by _convert().

◆ main()

tf2tfliteV2.main ( void  )

Definition at line 276 of file tf2tfliteV2.py.

def main():
    """
    Command line entry point: parses the arguments, validates them and
    runs the conversion.
    """
    parser = _get_parser()

    # parse_known_args returns (namespace, leftover arguments); only the
    # namespace is used and unrecognized arguments are ignored.
    flags, _unrecognized = parser.parse_known_args(args=sys.argv[1:])

    # Check if the flags are valid.
    _check_flags(flags)

    # Convert
    _convert(flags)
286
287
int main(void)

References _check_flags(), _convert(), _get_parser(), and main().

Referenced by main().

◆ wrap_frozen_graph()

tf2tfliteV2.wrap_frozen_graph (   graph_def,
  inputs,
  outputs 
)

Definition at line 27 of file tf2tfliteV2.py.

def wrap_frozen_graph(graph_def, inputs, outputs):
    """
    Imports `graph_def` through tf.compat.v1.wrap_function and returns the
    wrapped function pruned to the given inputs and outputs.

    `inputs` and `outputs` are structures of names that are resolved to
    graph elements of the imported graph.
    """
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    to_graph_element = wrapped_import.graph.as_graph_element

    feeds = tf.nest.map_structure(to_graph_element, inputs)
    fetches = tf.nest.map_structure(to_graph_element, outputs)
    return wrapped_import.prune(feeds, fetches)
36
37

Referenced by _v2_convert().

Variable Documentation

◆ file

tf2tfliteV2.file

Definition at line 293 of file tf2tfliteV2.py.

◆ prog_name

tf2tfliteV2.prog_name = os.path.basename(__file__)

Definition at line 292 of file tf2tfliteV2.py.