ONE - On-device Neural Engine
Loading...
Searching...
No Matches
tf2tfliteV2.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4# Copyright (C) 2018 The TensorFlow Authors
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18import os
19import tensorflow as tf
20import argparse
21import sys
22
23from google.protobuf.message import DecodeError
24from google.protobuf import text_format as _text_format
25
26
def wrap_frozen_graph(graph_def, inputs, outputs):
    """Wrap a frozen GraphDef into a callable pruned to the given endpoints.

    Imports *graph_def* into a TF v1-style wrapped function, then prunes it
    down to a ConcreteFunction whose inputs/outputs are the tensors named by
    *inputs* and *outputs* (tensor names such as "input:0").
    """
    def _do_import():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_do_import, [])
    graph = wrapped.graph

    def _lookup(names):
        # Resolve tensor names to tensor objects inside the imported graph.
        return tf.nest.map_structure(graph.as_graph_element, names)

    return wrapped.prune(_lookup(inputs), _lookup(outputs))
36
37
39 """
40 Returns an ArgumentParser for TensorFlow Lite Converter.
41 """
42 parser = argparse.ArgumentParser(
43 description=("Command line tool to run TensorFlow Lite Converter."))
44
45 # Verbose
46 parser.add_argument("-V",
47 "--verbose",
48 action="store_true",
49 help="output additional information to stdout or stderr")
50
51 # Converter version.
52 converter_version = parser.add_mutually_exclusive_group(required=True)
53 converter_version.add_argument("--v1",
54 action="store_true",
55 help="Use TensorFlow Lite Converter 1.x")
56 converter_version.add_argument("--v2",
57 action="store_true",
58 help="Use TensorFlow Lite Converter 2.x")
59
60 # Input model format
61 model_format_arg = parser.add_mutually_exclusive_group()
62 model_format_arg.add_argument("--graph_def",
63 action="store_const",
64 dest="model_format",
65 const="graph_def",
66 help="Use graph def file(default)")
67 model_format_arg.add_argument("--saved_model",
68 action="store_const",
69 dest="model_format",
70 const="saved_model",
71 help="Use saved model")
72 model_format_arg.add_argument("--keras_model",
73 action="store_const",
74 dest="model_format",
75 const="keras_model",
76 help="Use keras model")
77
78 # Input and output path.
79 parser.add_argument("-i",
80 "--input_path",
81 type=str,
82 help="Full filepath of the input file.",
83 required=True)
84 parser.add_argument("-o",
85 "--output_path",
86 type=str,
87 help="Full filepath of the output file.",
88 required=True)
89
90 # Input and output arrays.
91 parser.add_argument("-I",
92 "--input_arrays",
93 type=str,
94 help="Names of the input arrays, comma-separated.")
95 parser.add_argument(
96 "-s",
97 "--input_shapes",
98 type=str,
99 help=
100 "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
101 )
102 parser.add_argument("-O",
103 "--output_arrays",
104 type=str,
105 help="Names of the output arrays, comma-separated.")
106
107 # experimental options
108 parser.add_argument("--experimental_disable_batchmatmul_unfold",
109 action="store_true",
110 help="Experimental disable BatchMatMul unfold")
111
112 # Set default value
113 parser.set_defaults(model_format="graph_def")
114 return parser
115
116
117def _check_flags(flags):
118 """
119 Checks the parsed flags to ensure they are valid.
120 """
121 if flags.v1:
122 invalid = ""
123 # To be filled
124
125 if invalid:
126 raise ValueError(invalid + " options must be used with v2")
127
128 if flags.v2:
129 if tf.__version__.find("2.") != 0:
130 raise ValueError(
131 "Imported TensorFlow should have version >= 2.0 but you have " +
132 tf.__version__)
133
134 invalid = ""
135 # To be filled
136
137 if invalid:
138 raise ValueError(invalid + " options must be used with v1")
139
140 if flags.input_shapes:
141 if not flags.input_arrays:
142 raise ValueError("--input_shapes must be used with --input_arrays")
143 if flags.input_shapes.count(":") != flags.input_arrays.count(","):
144 raise ValueError("--input_shapes and --input_arrays must have the same "
145 "number of items")
146
147
148def _parse_array(arrays, type_fn=str):
149 return list(map(type_fn, arrays.split(",")))
150
151
def _v1_convert(flags):
    """
    Converts the input model with the TensorFlow Lite Converter 1.x API.

    Picks the converter entry point based on flags.model_format
    (graph_def / saved_model / keras_model), converts, and writes the
    resulting flatbuffer to flags.output_path.

    Raises:
      ValueError: graph_def conversion requested without --input_arrays
        or --output_arrays.
    """
    if flags.model_format == "graph_def":
        if not flags.input_arrays:
            raise ValueError("--input_arrays must be provided")
        if not flags.output_arrays:
            raise ValueError("--output_arrays must be provided")
        input_shapes = None
        if flags.input_shapes:
            input_arrays = _parse_array(flags.input_arrays)
            input_shapes_list = [
                _parse_array(shape, type_fn=int)
                for shape in flags.input_shapes.split(":")
            ]
            # Map each input array name to its shape, e.g. {"input": [1, 4, 4, 3]}.
            input_shapes = dict(zip(input_arrays, input_shapes_list))

        converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
            flags.input_path, _parse_array(flags.input_arrays),
            _parse_array(flags.output_arrays), input_shapes)

    if flags.model_format == "saved_model":
        converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(flags.input_path)

    if flags.model_format == "keras_model":
        converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
            flags.input_path)

    # Keep ops the converter does not support as custom ops instead of failing.
    converter.allow_custom_ops = True

    tflite_model = converter.convert()
    # Use a context manager so the output file is flushed and closed reliably
    # (the original left the file handle to the garbage collector).
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
182
183
def _v2_convert(flags):
    """
    Converts the input model with the TensorFlow Lite Converter 2.x API.

    For graph_def input, parses the frozen graph (binary protobuf first,
    then text protobuf as a fallback), wraps it into a concrete function,
    optionally pins the input shapes, converts, and writes the flatbuffer
    to flags.output_path.

    Raises:
      ValueError: graph_def conversion requested without --input_arrays
        or --output_arrays, or shape/array count mismatch.
      IOError: the input file cannot be parsed as a GraphDef.
    """
    if flags.model_format == "graph_def":
        if not flags.input_arrays:
            raise ValueError("--input_arrays must be provided")
        if not flags.output_arrays:
            raise ValueError("--output_arrays must be provided")
        input_shapes = []
        if flags.input_shapes:
            input_shapes = [
                _parse_array(shape, type_fn=int)
                for shape in flags.input_shapes.split(":")
            ]
            if len(input_shapes) != len(_parse_array(flags.input_arrays)):
                raise ValueError(
                    "--input_shapes and --input_arrays must have the same length")
        # Close the input file deterministically instead of leaving it to GC.
        with open(flags.input_path, 'rb') as input_file:
            file_content = input_file.read()
        graph_def = tf.compat.v1.GraphDef()
        try:
            # Try binary protobuf first.
            graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
            try:
                # Fall back to text-format protobuf. Merge expects str, not
                # bytes, so decode first; a non-UTF-8 file cannot be a text
                # protobuf and is reported as unparsable.
                _text_format.Merge(file_content.decode("utf-8"), graph_def)
            except (_text_format.ParseError, DecodeError, UnicodeDecodeError):
                raise IOError("Unable to parse input file '{}'.".format(flags.input_path))

        # Tensor names without an explicit output index refer to output 0,
        # so bare node names are normalized to "name:0".
        wrap_func = wrap_frozen_graph(
            graph_def,
            inputs=[
                _str + ":0" if len(_str.split(":")) == 1 else _str
                for _str in _parse_array(flags.input_arrays)
            ],
            outputs=[
                _str + ":0" if len(_str.split(":")) == 1 else _str
                for _str in _parse_array(flags.output_arrays)
            ])
        # Pin user-provided shapes onto the concrete function's inputs.
        for i in range(len(input_shapes)):
            wrap_func.inputs[i].set_shape(input_shapes[i])
        converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])

    if flags.model_format == "saved_model":
        converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)

    if flags.model_format == "keras_model":
        keras_model = tf.keras.models.load_model(flags.input_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    if flags.experimental_disable_batchmatmul_unfold:
        # NOTE(review): private converter attribute; may break across TF
        # versions — confirm against the TF release in use.
        converter._experimental_disable_batchmatmul_unfold = True

    converter.allow_custom_ops = True
    converter.experimental_new_converter = True

    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    # Context manager guarantees the output file is flushed and closed.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
240
241
242def _apply_verbosity(verbosity):
243 # NOTE
244 # TF_CPP_MIN_LOG_LEVEL
245 # 0 : INFO + WARNING + ERROR + FATAL
246 # 1 : WARNING + ERROR + FATAL
247 # 2 : ERROR + FATAL
248 # 3 : FATAL
249 #
250 # TODO Find better way to suppress trackback on error
251 # tracebacklimit
252 # The default is 1000.
253 # When set to 0 or less, all traceback information is suppressed
254 if verbosity:
255 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
256 sys.tracebacklimit = 1000
257 else:
258 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
259 sys.tracebacklimit = 0
260
261
def _convert(flags):
    """Apply verbosity settings, then run the converter chosen by --v1/--v2."""
    _apply_verbosity(flags.verbose)

    # --v1 and --v2 are mutually exclusive and one is required, so a simple
    # two-way dispatch is sufficient here.
    convert_fn = _v1_convert if flags.v1 else _v2_convert
    convert_fn(flags)
269
270
271"""
272Input frozen graph must be from TensorFlow 1.13.1
273"""
274
275
def main():
    """Entry point: parse CLI flags, validate them, and run the conversion."""
    parser = _get_parser()

    # parse_known_args tolerates extra arguments; only the recognized
    # namespace is validated and used.
    flags, _unparsed = parser.parse_known_args(args=sys.argv[1:])
    _check_flags(flags)

    # Convert
    _convert(flags)
286
287
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Report a single "<program>: <ExceptionType>: <message>" line on
        # stderr (traceback depth is controlled by _apply_verbosity) and
        # exit with a distinctive nonzero status.
        prog_name = os.path.basename(__file__)
        print(f"{prog_name}: {type(e).__name__}: {e}", file=sys.stderr)
        sys.exit(255)
wrap_frozen_graph(graph_def, inputs, outputs)
_apply_verbosity(verbosity)
_parse_array(arrays, type_fn=str)
_check_flags(flags)
_convert(flags)
_v2_convert(flags)
_v1_convert(flags)