ONE - On-device Neural Engine
Loading...
Searching...
No Matches
tf2tfliteV2.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4# Copyright (C) 2018 The TensorFlow Authors
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18import os
19import tensorflow as tf
20import argparse
21import sys
22
23from google.protobuf.message import DecodeError
24from google.protobuf import text_format as _text_format
25
26
def wrap_frozen_graph(graph_def, inputs, outputs):
    """Wrap a frozen TF1 GraphDef as a callable TF2 ConcreteFunction.

    The graph is imported into a wrapped tf.function with an empty name
    scope, then pruned down to the requested input/output tensors.
    `inputs`/`outputs` are tensor-name structures (e.g. "op:0").
    """
    def _do_import():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_do_import, [])
    graph = wrapped.graph

    def _resolve(names):
        # Map tensor names to the actual tensors inside the imported graph.
        return tf.nest.map_structure(graph.as_graph_element, names)

    return wrapped.prune(_resolve(inputs), _resolve(outputs))
36
37
39 """
40 Returns an ArgumentParser for TensorFlow Lite Converter.
41 """
42 parser = argparse.ArgumentParser(
43 description=("Command line tool to run TensorFlow Lite Converter."))
44
45 # Verbose
46 parser.add_argument("-V",
47 "--verbose",
48 action="store_true",
49 help="output additional information to stdout or stderr")
50
51 # TODO deprecate options v1 and v2
52 # TODO add "deprecated" for v1, v2 when python >= 3.13
53 # Converter version.
54 converter_version = parser.add_mutually_exclusive_group(required=False)
55 converter_version.add_argument("--v1",
56 action="store_true",
57 help="Use TensorFlow Lite Converter 1.x")
58 converter_version.add_argument("--v2",
59 action="store_true",
60 help="Use TensorFlow Lite Converter 2.x")
61
62 # Input model format
63 model_format_arg = parser.add_mutually_exclusive_group()
64 model_format_arg.add_argument("--graph_def",
65 action="store_const",
66 dest="model_format",
67 const="graph_def",
68 help="Use graph def file(default)")
69 model_format_arg.add_argument("--saved_model",
70 action="store_const",
71 dest="model_format",
72 const="saved_model",
73 help="Use saved model")
74 model_format_arg.add_argument("--keras_model",
75 action="store_const",
76 dest="model_format",
77 const="keras_model",
78 help="Use keras model")
79
80 # Input and output path.
81 parser.add_argument("-i",
82 "--input_path",
83 type=str,
84 help="Full filepath of the input file.",
85 required=True)
86 parser.add_argument("-o",
87 "--output_path",
88 type=str,
89 help="Full filepath of the output file.",
90 required=True)
91
92 # Input and output arrays.
93 parser.add_argument("-I",
94 "--input_arrays",
95 type=str,
96 help="Names of the input arrays, comma-separated.")
97 parser.add_argument(
98 "-s",
99 "--input_shapes",
100 type=str,
101 help=
102 "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
103 )
104 parser.add_argument("-O",
105 "--output_arrays",
106 type=str,
107 help="Names of the output arrays, comma-separated.")
108
109 # experimental options
110 parser.add_argument("--experimental_disable_batchmatmul_unfold",
111 action="store_true",
112 help="Experimental disable BatchMatMul unfold")
113
114 # Set default value
115 parser.set_defaults(model_format="graph_def")
116 return parser
117
118
119def _check_flags(flags):
120 """
121 Checks the parsed flags to ensure they are valid.
122 """
123 if flags.v1:
124 print("Warning: option --v1 is deprecated", file=sys.stderr)
125
126 if flags.v2:
127 print("Warning: option --v2 is deprecated", file=sys.stderr)
128
129 if flags.input_shapes:
130 if not flags.input_arrays:
131 raise ValueError("--input_shapes must be used with --input_arrays")
132 if flags.input_shapes.count(":") != flags.input_arrays.count(","):
133 raise ValueError("--input_shapes and --input_arrays must have the same "
134 "number of items")
135
136
137def _parse_array(arrays, type_fn=str):
138 return list(map(type_fn, arrays.split(",")))
139
140
def _v2_convert(flags):
    """Convert the input model to a TFLite flatbuffer with the TF 2.x converter.

    Dispatches on flags.model_format ("graph_def", "saved_model" or
    "keras_model"), builds the matching TFLiteConverter, and writes the
    converted flatbuffer to flags.output_path.

    Raises:
        ValueError: required graph_def flags are missing or inconsistent.
        IOError: the graph_def input file cannot be parsed.
    """
    if flags.model_format == "graph_def":
        if not flags.input_arrays:
            raise ValueError("--input_arrays must be provided")
        if not flags.output_arrays:
            raise ValueError("--output_arrays must be provided")
        input_shapes = []
        if flags.input_shapes:
            input_shapes = [
                _parse_array(shape, type_fn=int)
                for shape in flags.input_shapes.split(":")
            ]
            if len(input_shapes) != len(_parse_array(flags.input_arrays)):
                raise ValueError(
                    "--input_shapes and --input_arrays must have the same length")
        # Fix: use a context manager so the input file handle is closed
        # deterministically (the original relied on garbage collection).
        with open(flags.input_path, 'rb') as input_file:
            file_content = input_file.read()
        try:
            # Try binary protobuf first ...
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
            # ... then fall back to text-format protobuf.
            try:
                _text_format.Merge(file_content, graph_def)
            except (_text_format.ParseError, DecodeError):
                raise IOError("Unable to parse input file '{}'.".format(flags.input_path))

        # Bare operation names get the default output-tensor suffix ":0".
        wrap_func = wrap_frozen_graph(
            graph_def,
            inputs=[
                _str + ":0" if len(_str.split(":")) == 1 else _str
                for _str in _parse_array(flags.input_arrays)
            ],
            outputs=[
                _str + ":0" if len(_str.split(":")) == 1 else _str
                for _str in _parse_array(flags.output_arrays)
            ])
        for i in range(len(input_shapes)):
            wrap_func.inputs[i].set_shape(input_shapes[i])
        converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])

    if flags.model_format == "saved_model":
        converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)

    if flags.model_format == "keras_model":
        keras_model = tf.keras.models.load_model(flags.input_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    if flags.experimental_disable_batchmatmul_unfold:
        converter._experimental_disable_batchmatmul_unfold = True

    converter.allow_custom_ops = True
    converter.experimental_new_converter = True

    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    # Fix: close the output handle promptly instead of leaking it.
    with open(flags.output_path, "wb") as output_file:
        output_file.write(tflite_model)
197
198
199def _apply_verbosity(verbosity):
200 # NOTE
201 # TF_CPP_MIN_LOG_LEVEL
202 # 0 : INFO + WARNING + ERROR + FATAL
203 # 1 : WARNING + ERROR + FATAL
204 # 2 : ERROR + FATAL
205 # 3 : FATAL
206 #
207 # TODO Find better way to suppress trackback on error
208 # tracebacklimit
209 # The default is 1000.
210 # When set to 0 or less, all traceback information is suppressed
211 if verbosity:
212 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
213 sys.tracebacklimit = 1000
214 else:
215 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
216 sys.tracebacklimit = 0
217
218
def _convert(flags):
    """Apply the requested verbosity, then run the TF 2.x conversion."""
    _apply_verbosity(flags.verbose)
    _v2_convert(flags)
223
224
225"""
226Input frozen graph must be from TensorFlow 1.13.1
227"""
228
229
def main():
    """Entry point: parse CLI arguments, validate them, and convert."""
    parser = _get_parser()

    # parse_known_args returns (namespace, unparsed-rest); only the
    # namespace is used, unknown arguments are deliberately tolerated.
    flags, _ = parser.parse_known_args(args=sys.argv[1:])
    _check_flags(flags)

    _convert(flags)
240
241
242if __name__ == "__main__":
243 try:
244 main()
245 except Exception as e:
246 prog_name = os.path.basename(__file__)
247 print(f"{prog_name}: {type(e).__name__}: " + str(e), file=sys.stderr)
248 sys.exit(255)
wrap_frozen_graph(graph_def, inputs, outputs)
_apply_verbosity(verbosity)
_parse_array(arrays, type_fn=str)
_check_flags(flags)
_convert(flags)
_v2_convert(flags)