ONE - On-device Neural Engine
Loading...
Searching...
No Matches
generate_bcq_metadata.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18import numpy as np
19import tensorflow as tf
20
21import argparse
22import os
23import sys
24
# TODO Find a better way to suppress the traceback on error
sys.tracebacklimit = 0

# Sentinel values placed as the first and last element of the BCQ metadata
# array; they are exact negatives of each other.
ONE_START_MAGICNUM = -2_000_000_000 + 27
ONE_END_MAGICNUM = 2_000_000_000 - 27
30
31
33 """
34 Returns an ArgumentParser for generating BCQ metadata.
35 """
36 parser = argparse.ArgumentParser(
37 description=("Command line tool to generate metadata of BCQ nodes"))
38
39 # Input and output path.
40 parser.add_argument("-i",
41 "--input_path",
42 type=str,
43 help="Full filepath of the input file.",
44 required=True)
45 parser.add_argument("-o",
46 "--output_path",
47 type=str,
48 help="Full filepath of the output file.",
49 required=True)
50 parser.add_argument("-O",
51 "--output_arrays",
52 type=str,
53 help="Original model output arrays",
54 required=True)
55
56 return parser
57
58
# Adapted from
# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/examples/label_image/label_image.py#L26
def load_graph(model_file):
    """Read a binary-serialized GraphDef file and import it into a new graph.

    Args:
        model_file: path to a file holding a binary-serialized ``GraphDef``
            (it is parsed with ``ParseFromString``).

    Returns:
        A fresh ``tf.Graph`` containing the file's nodes. They are imported
        with ``name=""``, so no extra name scope is added to node names.
    """
    with open(model_file, "rb") as model_stream:
        serialized_graph = model_stream.read()

    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(serialized_graph)

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")
    return graph
71
72
def generate_metadata_header(original_graph, bcq_version, output_arrays):
    """Build the common leading part of the BCQ metadata array.

    Args:
        original_graph: not used by this function; kept so existing callers
            keep working.
        bcq_version: BCQ format version number to record.
        output_arrays: comma-separated names of the original model outputs.

    Returns:
        A numpy array ``[ONE_START_MAGICNUM, bcq_version, output_count]``,
        where ``output_count`` is the number of commas in ``output_arrays``
        plus one.
    """
    # Output count is derived from the separators: "a" -> 1, "a,b" -> 2.
    output_count = output_arrays.count(',') + 1
    header_fields = [ONE_START_MAGICNUM, bcq_version, output_count]
    return np.array(header_fields)
85
86
88 """
89 BCQv1 contains following metadata.
90 - The number of each BCQ information set
91 """
92
93 is_valid = True
94 allowed_info_names = [
95 "bcqinfo_do_w_x", "bcqinfo_alpha", "bcqinfo_packed_binary_code",
96 "bcqinfo_number_of_clusters", "bcqinfo_size_of_clusters",
97 "bcqinfo_qbits_of_clusters", "bcqinfo_dequant_weight"
98 ]
99
100 original_graph = load_graph(flags.input_path)
101 original_graph_def = original_graph.as_graph_def()
102
103 prefix_infonames_dict = {}
104
105 for node in original_graph_def.node:
106 if node.op == "Const" and "/bcqinfo_" in node.name:
107 prefix_index = node.name.index("/bcqinfo_")
108 prefix = node.name[:prefix_index]
109 infoname = node.name[prefix_index + 1:]
110
111 if infoname not in allowed_info_names:
112 is_valid = False
113 break
114
115 if prefix not in prefix_infonames_dict:
116 prefix_infonames_dict[prefix] = set()
117
118 prefix_infonames_dict[prefix].add(infoname)
119
120 # All the number of BCQ information should be same
121 num_of_bcqinfo = -1
122 for key in prefix_infonames_dict:
123 infonames = prefix_infonames_dict[key]
124 if num_of_bcqinfo == -1:
125 num_of_bcqinfo = len(infonames)
126 elif num_of_bcqinfo != len(infonames):
127 is_valid = False
128
129 # The number of BCQv1 information should be 6 or 7
130 if num_of_bcqinfo != 6 and num_of_bcqinfo != 7:
131 is_valid = False
132
133 # If BCQ information is invalid, return original model
134 if is_valid == False:
135 return original_graph_def
136
137 new_graph_def = tf.compat.v1.GraphDef()
138 for node in original_graph_def.node:
139 new_node = new_graph_def.node.add()
140 new_node.CopyFrom(node)
141
142 # Generate metadata header
143 metadata_values = generate_metadata_header(original_graph, 1, flags.output_arrays)
144
145 # Append metadata of BCQv1
146 metadata_values = np.append(metadata_values, num_of_bcqinfo + 1)
147
148 # Finish generating metadata
149 metadata_values = np.append(metadata_values, ONE_END_MAGICNUM)
150
151 # Generate metadata tensor
152 metadata_tensor = tf.make_tensor_proto(metadata_values, tf.int32)
153
154 new_node = new_graph_def.node.add()
155 new_node.op = "Const"
156 new_node.name = "one_compiler/bcqinfo_one_metadata"
157 new_node.attr["dtype"].CopyFrom(
158 tf.compat.v1.AttrValue(type=tf.int32.as_datatype_enum))
159 new_node.attr["value"].tensor.CopyFrom(metadata_tensor)
160 return new_graph_def
161
162
164 """
165 CAUTION : For now, BCQ has only one version and thus always returns 1 when BCQ
166 information nodes are included. If new BCQ version is introduced,
167 this function must be updated accordingly.
168
169 When BCQ information does not exist, -1 is returned.
170 """
171 bcq_version = -1
172
173 original_graph = load_graph(flags.input_path)
174 original_graph_def = original_graph.as_graph_def()
175
176 for node in original_graph_def.node:
177 if node.op == "Const" and "/bcqinfo_" in node.name:
178 bcq_version = 1
179 break
180
181 return bcq_version
182
183
185 """
186 Basic format of metadata is as following.
187 - Magic number indicating start
188 - Version of BCQ Format
189 - The number of original outputs
190 - Metadata based on each BCQ format
191 - Magic number indicating end
192 """
193 program_version = 1
194 model_version = determine_bcq_version(flags)
195
196 if model_version == 1:
197 result_graph_def = generate_bcq_metadata_v1(flags)
198 elif model_version == -1:
199 # When there is no BCQ information, do nothing
200 result_graph_def = load_graph(flags.input_path)
201 else:
202 err_msg = "BCQ version of the model(v{}) ".format(model_version)
203 err_msg += "is higher than "
204 err_msg += "the version supported by this program(v{})".format(program_version)
205 raise SystemExit(err_msg)
206
207 tf.io.write_graph(result_graph_def, '.', flags.output_path, False)
208
209
def main():
    """Parse the command line and generate the output model."""
    # parse_known_args tolerates extra, unrecognized arguments; they are ignored.
    known_flags, _unknown_args = _get_parser().parse_known_args(args=sys.argv[1:])

    # Generate a new pb file in which the BCQ metadata is included.
    generate_bcq_metadata(known_flags)
217
218
219if __name__ == "__main__":
220 try:
221 main()
222 except Exception as e:
223 prog_name = os.path.basename(__file__)
224 print(f"{prog_name}: {type(e).__name__}: " + str(e))
225 sys.exit(255)
generate_metadata_header(original_graph, bcq_version, output_arrays)