ONE - On-device Neural Engine
Loading...
Searching...
No Matches
generate_bcq_output_arrays.py
Go to the documentation of this file.
1#!/usr/bin/env python3
2
3# Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18import tensorflow as tf
19
20import argparse
21import sys
22
23
24# This function is copied from
25# https://github.com/tensorflow/tensorflow/blob/r2.3/tensorflow/examples/label_image/label_image.py#L26
def load_graph(model_file):
    """
    Read a binary-serialized GraphDef from model_file and import it into a
    freshly created tf.Graph, which is returned.

    The GraphDef is imported with name="" (an empty name prefix), so the
    imported nodes are not prefixed.
    """
    serialized_graph_def = tf.compat.v1.GraphDef()
    with open(model_file, "rb") as model_stream:
        raw_bytes = model_stream.read()
    serialized_graph_def.ParseFromString(raw_bytes)

    loaded_graph = tf.Graph()
    with loaded_graph.as_default():
        tf.import_graph_def(serialized_graph_def, name="")
    return loaded_graph
36
37
def get_bcq_version(input_path):
    """
    Return the BCQ version stored in the model's BCQ metadata.

    The metadata is the first Const node whose name contains
    "one_compiler/bcqinfo_one_metadata"; its value tensor holds the BCQ
    version at index 1. When no such node exists, -1 is returned.
    """
    metadata_node = next(
        (
            candidate
            for candidate in load_graph(input_path).as_graph_def().node
            if candidate.op == "Const"
            and "one_compiler/bcqinfo_one_metadata" in candidate.name
        ),
        None,
    )
    if metadata_node is None:
        return -1
    # The version is the second element of the metadata tensor.
    return tf.make_ndarray(metadata_node.attr["value"].tensor)[1]
50
51
def get_bcqinfo_output_arrays_v1(input_path, output_arrays):
    """
    Return the output arrays needed to keep BCQ v1 information in a model.

    Each BCQ v1 information bundle consists of one candidate operation
    (a constant node BCQ may be applied to) and the BCQ constant nodes
    related to it. A Const node named "PREFIX/bcqinfo_<info>" carries BCQ
    information for the constant node named PREFIX.

    Args:
        input_path: path to a binary-serialized GraphDef file.
        output_arrays: comma-separated output node names given by the user.

    Returns:
        A list of node names: the BCQ metadata node, then the user's output
        arrays, then, for each prefix (in sorted order), its bcqinfo_* nodes
        followed by the prefix node itself and, when present for that
        prefix, its bcqinfo_dequant_weight node.
    """
    graph = load_graph(input_path)

    # Maps each collected prefix to whether a
    # "PREFIX/bcqinfo_dequant_weight" Const node exists for it. The prefixes
    # are used for connecting bcqinfo nodes and user operations of prefix
    # nodes.
    prefix_has_dequant = {}
    for op in graph.get_operations():
        # op.name is the output tensor name without the ":0" suffix.
        if op.type != "Const" or "/bcqinfo_" not in op.name:
            continue
        # The metadata node does not have a prefix.
        if "one_compiler/bcqinfo_one_metadata" in op.name:
            continue

        prefix_index = op.name.index("/bcqinfo_")
        prefix = op.name[:prefix_index]
        infoname = op.name[prefix_index + 1:]
        # Tracked per prefix so a dequant_weight name is only emitted for
        # prefixes that actually have that node.
        prefix_has_dequant[prefix] = (prefix_has_dequant.get(prefix, False)
                                      or infoname == "bcqinfo_dequant_weight")

    # The name of the metadata node.
    ret_output_arrays = ['one_compiler/bcqinfo_one_metadata']

    # Output nodes given by the user.
    ret_output_arrays += output_arrays.split(',')

    # All pairs of a constant node and its related BCQ information nodes.
    # Prefixes are sorted: iterating a set/dict built from str keys in hash
    # order would make the generated order depend on hash randomization.
    for prefix in sorted(prefix_has_dequant):
        ret_output_arrays.append(prefix + '/bcqinfo_do_w_x')
        ret_output_arrays.append(prefix + '/bcqinfo_alpha')
        ret_output_arrays.append(prefix + '/bcqinfo_packed_binary_code')
        ret_output_arrays.append(prefix + '/bcqinfo_number_of_clusters')
        ret_output_arrays.append(prefix + '/bcqinfo_size_of_clusters')
        ret_output_arrays.append(prefix + '/bcqinfo_qbits_of_clusters')
        ret_output_arrays.append(prefix)
        if prefix_has_dequant[prefix]:
            ret_output_arrays.append(prefix + '/bcqinfo_dequant_weight')

    return ret_output_arrays
103
104
def get_bcq_output_arrays(input_path, output_arrays):
    """
    Return the BCQ output arrays that the model from input_path has.

    Args:
        input_path: path to a binary-serialized GraphDef file.
        output_arrays: comma-separated output node names given by the user.

    Returns:
        For a BCQ v1 model, the result of get_bcqinfo_output_arrays_v1.
        For a model without BCQ metadata (version -1), the user's output
        arrays split on ','.

    Raises:
        SystemExit: if the model's BCQ version is not supported by this
            program.
    """
    program_version = 1
    model_version = get_bcq_version(input_path)

    if model_version == -1:
        # No BCQ metadata: only the user-given output arrays are needed.
        return output_arrays.split(',')
    if model_version == program_version:
        return get_bcqinfo_output_arrays_v1(input_path, output_arrays)

    if model_version > program_version:
        err_msg = "BCQ version of the model(v{}) ".format(model_version)
        err_msg += "is higher than "
        err_msg += "the version supported by this program(v{})".format(program_version)
    else:
        # Previously this case also claimed the version was "higher".
        err_msg = "BCQ version of the model(v{}) ".format(model_version)
        err_msg += "is not supported by "
        err_msg += "this program(v{})".format(program_version)
    raise SystemExit(err_msg)
get_bcqinfo_output_arrays_v1(input_path, output_arrays)
get_bcq_output_arrays(input_path, output_arrays)