28 def _imports_graph_def():
29 tf.compat.v1.import_graph_def(graph_def, name=
"")
31 wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
32 import_graph = wrapped_import.graph
33 return wrapped_import.prune(
34 tf.nest.map_structure(import_graph.as_graph_element, inputs),
35 tf.nest.map_structure(import_graph.as_graph_element, outputs))
40 Returns an ArgumentParser for TensorFlow Lite Converter.
42 parser = argparse.ArgumentParser(
43 description=(
"Command line tool to run TensorFlow Lite Converter."))
46 parser.add_argument(
"-V",
49 help=
"output additional information to stdout or stderr")
52 converter_version = parser.add_mutually_exclusive_group(required=
True)
53 converter_version.add_argument(
"--v1",
55 help=
"Use TensorFlow Lite Converter 1.x")
56 converter_version.add_argument(
"--v2",
58 help=
"Use TensorFlow Lite Converter 2.x")
61 model_format_arg = parser.add_mutually_exclusive_group()
62 model_format_arg.add_argument(
"--graph_def",
66 help=
"Use graph def file(default)")
67 model_format_arg.add_argument(
"--saved_model",
71 help=
"Use saved model")
72 model_format_arg.add_argument(
"--keras_model",
76 help=
"Use keras model")
79 parser.add_argument(
"-i",
82 help=
"Full filepath of the input file.",
84 parser.add_argument(
"-o",
87 help=
"Full filepath of the output file.",
91 parser.add_argument(
"-I",
94 help=
"Names of the input arrays, comma-separated.")
100 "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
102 parser.add_argument(
"-O",
105 help=
"Names of the output arrays, comma-separated.")
108 parser.add_argument(
"--experimental_disable_batchmatmul_unfold",
110 help=
"Experimental disable BatchMatMul unfold")
113 parser.set_defaults(model_format=
"graph_def")
def _v1_convert(flags):
  """Convert a model with the TensorFlow Lite 1.x converter.

  Loads the model selected by flags.model_format (graph_def / saved_model /
  keras_model) from flags.input_path, converts it, and writes the .tflite
  flatbuffer to flags.output_path.

  Raises:
    ValueError: graph_def format without --input_arrays or --output_arrays.
  """
  if flags.model_format == "graph_def":
    if not flags.input_arrays:
      raise ValueError("--input_arrays must be provided")
    if not flags.output_arrays:
      raise ValueError("--output_arrays must be provided")
    input_shapes = None
    if flags.input_shapes:
      input_arrays = _parse_array(flags.input_arrays)
      # One comma-separated shape per input array, colon-separated on the
      # command line, e.g. "1,4,4,3:1,20,20,3".
      # NOTE(review): the comprehension's expression line was lost in the
      # garbled source — reconstructed via _parse_array(..., type_fn=int);
      # confirm against the original.
      input_shapes_list = [
          _parse_array(shape, type_fn=int)
          for shape in flags.input_shapes.split(":")
      ]
      input_shapes = dict(list(zip(input_arrays, input_shapes_list)))

    # NOTE(review): argument lines of this call were lost — reconstructed
    # per the from_frozen_graph(graph_def_file, input_arrays, output_arrays,
    # input_shapes) signature; confirm.
    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        flags.input_path, _parse_array(flags.input_arrays),
        _parse_array(flags.output_arrays), input_shapes)
  elif flags.model_format == "saved_model":
    converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(
        flags.input_path)
  elif flags.model_format == "keras_model":
    converter = tf.compat.v1.lite.TFLiteConverter.from_keras_model_file(
        flags.input_path)

  converter.allow_custom_ops = True

  tflite_model = converter.convert()
  # Bug fix: the original `open(...).write(...)` leaked the file handle;
  # a context manager guarantees flush and close.
  with open(flags.output_path, "wb") as output_file:
    output_file.write(tflite_model)
def _v2_convert(flags):
  """Convert a model with the TensorFlow Lite 2.x converter.

  Loads the model selected by flags.model_format (graph_def / saved_model /
  keras_model) from flags.input_path, converts it, and writes the .tflite
  flatbuffer to flags.output_path.

  Raises:
    ValueError: graph_def format without --input_arrays/--output_arrays,
      or --input_shapes count not matching --input_arrays.
    IOError: the graph_def file parses neither as binary nor text protobuf.
  """
  if flags.model_format == "graph_def":
    if not flags.input_arrays:
      raise ValueError("--input_arrays must be provided")
    if not flags.output_arrays:
      raise ValueError("--output_arrays must be provided")
    input_shapes = []
    if flags.input_shapes:
      # NOTE(review): expression line lost in the garbled source —
      # reconstructed via _parse_array(..., type_fn=int); confirm.
      input_shapes = [
          _parse_array(shape, type_fn=int)
          for shape in flags.input_shapes.split(":")
      ]
      if len(input_shapes) != len(_parse_array(flags.input_arrays)):
        raise ValueError(
            "--input_shapes and --input_arrays must have the same length")

    # Bug fix: read via a context manager instead of leaking the handle.
    with open(flags.input_path, 'rb') as input_file:
      file_content = input_file.read()
    try:
      graph_def = tf.compat.v1.GraphDef()
      graph_def.ParseFromString(file_content)
    except (_text_format.ParseError, DecodeError):
      try:
        # Fall back to text-format protobuf. Bug fix: Merge() requires str
        # on Python 3, but file_content is bytes (opened 'rb') — the
        # original call raised TypeError instead of parsing; decode first.
        _text_format.Merge(file_content.decode("utf-8"), graph_def)
      except (_text_format.ParseError, DecodeError, UnicodeDecodeError):
        raise IOError(
            "Unable to parse input file '{}'.".format(flags.input_path))

    # TF2 concrete functions address tensors ("op:0"), not ops; append
    # ":0" wherever only an op name was given.
    # NOTE(review): the wrapper call line was lost in the garbled source —
    # the helper name is reconstructed; confirm against the wrapping
    # function defined earlier in this file.
    wrap_func = _v2_wrap_frozen_graph(
        graph_def,
        inputs=[
            _str + ":0" if len(_str.split(":")) == 1 else _str
            for _str in _parse_array(flags.input_arrays)
        ],
        outputs=[
            _str + ":0" if len(_str.split(":")) == 1 else _str
            for _str in _parse_array(flags.output_arrays)
        ])
    for i in range(len(input_shapes)):
      wrap_func.inputs[i].set_shape(input_shapes[i])
    converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])
  elif flags.model_format == "saved_model":
    converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)
  elif flags.model_format == "keras_model":
    keras_model = tf.keras.models.load_model(flags.input_path)
    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

  if flags.experimental_disable_batchmatmul_unfold:
    # NOTE(review): relies on a private converter attribute — may break on
    # TF upgrades; no public equivalent at the time of writing.
    converter._experimental_disable_batchmatmul_unfold = True

  converter.allow_custom_ops = True
  converter.experimental_new_converter = True

  converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

  tflite_model = converter.convert()
  # Bug fix: the original `open(...).write(...)` leaked the file handle;
  # a context manager guarantees flush and close.
  with open(flags.output_path, "wb") as output_file:
    output_file.write(tflite_model)