import argparse

import tensorflow as tf

from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError


# NOTE The enclosing definition is not shown in this excerpt; the name and
# signature are assumed from how the body uses graph_def, inputs and outputs.
def wrap_frozen_graph(graph_def, inputs, outputs):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    # Prune the imported graph down to a ConcreteFunction with the requested
    # input and output tensors.
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs))
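

# Minimal usage sketch for the helper above (the file name and tensor names
# are illustrative assumptions, not taken from this tool):
#
#   graph_def = tf.compat.v1.GraphDef()
#   graph_def.ParseFromString(open("frozen_graph.pb", "rb").read())
#   concrete_func = wrap_frozen_graph(
#       graph_def, inputs=["input:0"], outputs=["output:0"])
#   print(concrete_func.inputs, concrete_func.outputs)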


# NOTE The function name, the long option names, and the action/dest/const/
# type keywords below are reconstructed from how the parsed flags are used
# later in this file; only the short names and help strings appear verbatim
# in the excerpt.
def _get_parser():
    """Returns an ArgumentParser for TensorFlow Lite Converter."""
    parser = argparse.ArgumentParser(
        description=("Command line tool to run TensorFlow Lite Converter."))

    # Verbose mode.
    parser.add_argument(
        "-V",
        "--verbose",
        action="store_true",
        help="output additional information to stdout or stderr")

    # Converter version.
    converter_version = parser.add_mutually_exclusive_group(required=False)
    converter_version.add_argument(
        "--v1", action="store_true", help="Use TensorFlow Lite Converter 1.x")
    converter_version.add_argument(
        "--v2", action="store_true", help="Use TensorFlow Lite Converter 2.x")

    # Input model format; the chosen option is stored in flags.model_format.
    model_format_arg = parser.add_mutually_exclusive_group()
    model_format_arg.add_argument(
        "--graph_def",
        action="store_const",
        dest="model_format",
        const="graph_def",
        help="Use graph def file(default)")
    model_format_arg.add_argument(
        "--saved_model",
        action="store_const",
        dest="model_format",
        const="saved_model",
        help="Use saved model")
    model_format_arg.add_argument(
        "--keras_model",
        action="store_const",
        dest="model_format",
        const="keras_model",
        help="Use keras model")

    # Input and output paths.
    parser.add_argument(
        "-i",
        "--input_path",
        type=str,
        help="Full filepath of the input file.",
        required=True)
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        help="Full filepath of the output file.",
        required=True)

    # Input and output arrays.
    parser.add_argument(
        "-I",
        "--input_arrays",
        type=str,
        help="Names of the input arrays, comma-separated.")
    parser.add_argument(
        "--input_shapes",
        type=str,
        help=
        "Shapes corresponding to --input_arrays, colon-separated.(ex:\"1,4,4,3:1,20,20,3\")"
    )
    parser.add_argument(
        "-O",
        "--output_arrays",
        type=str,
        help="Names of the output arrays, comma-separated.")

    # Experimental options.
    parser.add_argument(
        "--experimental_disable_batchmatmul_unfold",
        action="store_true",
        help="Experimental disable BatchMatMul unfold")

    # Default input model format.
    parser.set_defaults(model_format="graph_def")
    return parser


# NOTE The name of the conversion routine is assumed; flags is the
# argparse.Namespace produced by the parser above.
def _v2_convert(flags):
    if flags.model_format == "graph_def":
        if not flags.input_arrays:
            raise ValueError("--input_arrays must be provided")
        if not flags.output_arrays:
            raise ValueError("--output_arrays must be provided")

        # Shapes are optional; when given, there must be one per input array.
        input_shapes = []
        if flags.input_shapes:
            # Each colon-separated entry is one comma-separated shape, e.g.
            # "1,4,4,3". Parsing each entry into a list of ints is assumed;
            # the original per-entry helper is not shown here.
            input_shapes = [[int(dim) for dim in shape.split(",")]
                            for shape in flags.input_shapes.split(":")]
            if len(input_shapes) != len(_parse_array(flags.input_arrays)):
                raise ValueError(
                    "--input_shapes and --input_arrays must have the same length")

        file_content = open(flags.input_path, 'rb').read()
        try:
            # First try the binary protobuf format.
            graph_def = tf.compat.v1.GraphDef()
            graph_def.ParseFromString(file_content)
        except (_text_format.ParseError, DecodeError):
            try:
                # Fall back to the text protobuf format.
                _text_format.Merge(file_content, graph_def)
            except (_text_format.ParseError, DecodeError):
                raise IOError(
                    "Unable to parse input file '{}'.".format(flags.input_path))

        # NOTE The call below is reconstructed around the two list
        # comprehensions shown in the source: tensor names without an explicit
        # output index get ":0" appended before the graph is wrapped.
        wrap_func = wrap_frozen_graph(
            graph_def,
            inputs=[
                _str + ":0" if len(_str.split(":")) == 1 else _str
                for _str in _parse_array(flags.input_arrays)
            ],
            outputs=[
                _str + ":0" if len(_str.split(":")) == 1 else _str
                for _str in _parse_array(flags.output_arrays)
            ])
        for i in range(len(input_shapes)):
            wrap_func.inputs[i].set_shape(input_shapes[i])
        converter = tf.lite.TFLiteConverter.from_concrete_functions([wrap_func])

    if flags.model_format == "saved_model":
        converter = tf.lite.TFLiteConverter.from_saved_model(flags.input_path)

    if flags.model_format == "keras_model":
        keras_model = tf.keras.models.load_model(flags.input_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)

    if flags.experimental_disable_batchmatmul_unfold:
        converter._experimental_disable_batchmatmul_unfold = True

    converter.allow_custom_ops = True
    converter.experimental_new_converter = True

    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]

    tflite_model = converter.convert()
    open(flags.output_path, "wb").write(tflite_model)