ONE - On-device Neural Engine
DotBuilder.DotBuilder Class Reference

Public Member Functions

__init__(self, str circle_path, str dot_path, str metric, str colors)
save(self, dict qerror_map)

Protected Member Functions

_get_color(self, float qerror)
_gen_color_table(self)

Protected Attributes

_model
_name
_dot_path
_metric
_colors

Detailed Description

Definition at line 34 of file DotBuilder.py.

Constructor & Destructor Documentation

◆ __init__()

DotBuilder.DotBuilder.__init__(self, str circle_path, str dot_path, str metric, str colors)
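circle_path: Path to the fp32 Circle model (required to build the graph)
dot_path: Path where the DOT file is written
metric: Metric name (e.g., MPEIR, MSE), used in the node and color-table labels
colors: List of color slots [{'b': begin, 'e': end, 'c': color}, ...]. Although the parameter is annotated as str, the methods iterate over it and index each slot by 'b', 'e', and 'c'.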

Definition at line 35 of file DotBuilder.py.

    def __init__(self, circle_path: str, dot_path: str, metric: str, colors: str):
        '''
        circle_path: Path to the fp32 circle model (required to build graph)
        dot_path: Path to the saved dot file
        metric: Metric name (ex: MPEIR, MSE)
        colors: List of color slots [{'b': begin, 'e': end, 'c':color}, ..]
        '''
        with open(circle_path, 'rb') as f:
            self._model = Model.Model.GetRootAsModel(f.read())

        if self._model.SubgraphsLength() != 1:
            raise RuntimeError("Only one subgraph is supported")

        self._name = Path(circle_path).name
        self._dot_path = dot_path
        self._metric = metric
        self._colors = colors
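
The constructor only stores `colors`; the slots are consumed later by `_get_color()` and `_gen_color_table()`. As a hedged sketch of producing a list in that shape, `uniform_color_slots` below is a hypothetical helper (not part of visq or its `Palette` module), assuming equal-width, contiguous slots in ascending order.

```python
from typing import Dict, List, Union

# One color slot in the documented shape: {'b': begin, 'e': end, 'c': color}
Slot = Dict[str, Union[float, str]]


def uniform_color_slots(begin: float, end: float, colors: List[str]) -> List[Slot]:
    """Hypothetical helper: split [begin, end] into len(colors) equal-width slots."""
    if not colors:
        raise ValueError("at least one color is required")
    if end <= begin:
        raise ValueError("end must be greater than begin")
    step = (end - begin) / len(colors)
    slots = []
    for i, color in enumerate(colors):
        slot_begin = begin + i * step
        # Pin the last slot's end to `end` so rounding cannot leave a gap at the top
        slot_end = end if i == len(colors) - 1 else begin + (i + 1) * step
        slots.append({'b': slot_begin, 'e': slot_end, 'c': color})
    return slots


# Usage: three slots over [0, 1]; the result has the shape expected for `colors`
slots = uniform_color_slots(0.0, 1.0, ['green', 'yellow', 'red'])
for slot in slots:
    print(slot)
```

Contiguous, ascending slots matter here because `_get_color()` falls back to the first or last color only when qerror lies outside the overall range; a qerror that falls into a gap between two slots makes it raise `RuntimeError`.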

Member Function Documentation

◆ _gen_color_table()

DotBuilder.DotBuilder._gen_color_table(self)
protected

Definition at line 74 of file DotBuilder.py.

    def _gen_color_table(self):
        color_table = "< <table>"
        for slot in self._colors:
            begin = slot['b']
            end = slot['e']
            color = slot['c']
            color_table += "<tr> <td bgcolor=\""
            color_table += color
            color_table += "\">"
            color_table += self._metric + ": {:.4f}".format(
                begin) + " ~ " + "{:.4f}".format(end)
            color_table += "</td> </tr>"
        color_table += "</table> >"
        return pydot.Node("color_table", shape='none', label=color_table)

References DotBuilder.DotBuilder._colors and DotBuilder.DotBuilder._metric.

Referenced by DotBuilder.DotBuilder.save().
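
To make the legend format concrete, the sketch below rebuilds the same label string without pydot. `color_table_label` is a hypothetical standalone function, not part of visq; it assumes `colors` is a list of slot dicts as described for the constructor. The outer `<` and `>` are what make Graphviz treat the label as HTML-like markup rather than a quoted string.

```python
from typing import Dict, List, Union

Slot = Dict[str, Union[float, str]]


def color_table_label(metric: str, colors: List[Slot]) -> str:
    """Hypothetical standalone version of the label built by _gen_color_table()."""
    rows = []
    for slot in colors:
        # One table row per slot: background color plus "<metric>: begin ~ end"
        rows.append('<tr> <td bgcolor="{}">{}: {:.4f} ~ {:.4f}</td> </tr>'.format(
            slot['c'], metric, slot['b'], slot['e']))
    # The outer "<" and ">" mark the value as a Graphviz HTML-like label
    return "< <table>" + "".join(rows) + "</table> >"


label = color_table_label("MSE", [{'b': 0.0, 'e': 0.5, 'c': 'green'},
                                  {'b': 0.5, 'e': 1.0, 'c': 'red'}])
print(label)
```

Neither the metric name nor the color strings are escaped, so a value containing `<` or `&` would produce malformed markup; wrapping them with `html.escape` would be one way to guard against that.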

◆ _get_color()

DotBuilder.DotBuilder._get_color(self, float qerror)
protected

Definition at line 54 of file DotBuilder.py.

    def _get_color(self, qerror: float):
        # Find a slot where qerror is in the range of [begin, end]
        for slot in self._colors:
            begin = slot['b']
            end = slot['e']
            if (qerror > begin or math.isclose(
                    qerror, begin)) and (qerror < end or math.isclose(qerror, end)):
                return slot['c']

        # Use the first color if qerror is smaller than the first begin
        if qerror < self._colors[0]['b']:
            return self._colors[0]['c']

        # Use the last color if qerror is larger than the last end
        if qerror > self._colors[-1]['e']:
            return self._colors[-1]['c']

        raise RuntimeError("Color ID not found. QError: " + str(qerror))

References DotBuilder.DotBuilder._colors.

Referenced by DotBuilder.DotBuilder.save().
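
The lookup can be exercised in isolation. `pick_color` below is a hypothetical free-function copy of `_get_color()` that takes the slot list as an argument; the slot values are illustrative assumptions. Because each range is closed, a qerror on a shared boundary goes to the earlier slot, and a NaN matches no slot and no fallback, so it raises.

```python
import math
from typing import Dict, List, Union

Slot = Dict[str, Union[float, str]]


def pick_color(colors: List[Slot], qerror: float) -> str:
    """Hypothetical free-function copy of DotBuilder._get_color()."""
    # First slot whose closed range [begin, end] contains qerror
    for slot in colors:
        begin, end = slot['b'], slot['e']
        if (qerror > begin or math.isclose(qerror, begin)) and \
                (qerror < end or math.isclose(qerror, end)):
            return slot['c']
    # Clamp values below or above the covered range to the first / last color
    if qerror < colors[0]['b']:
        return colors[0]['c']
    if qerror > colors[-1]['e']:
        return colors[-1]['c']
    # Reached only for a gap between slots or a NaN qerror
    raise RuntimeError("Color ID not found. QError: " + str(qerror))


SLOTS = [{'b': 0.0, 'e': 0.5, 'c': 'green'}, {'b': 0.5, 'e': 1.0, 'c': 'red'}]
for q in (0.25, 0.5, 0.75, -1.0, 2.0):
    print(q, pick_color(SLOTS, q))
```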

◆ save()

DotBuilder.DotBuilder.save(self, dict qerror_map)
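qerror_map: Dictionary of {op_name (str) -> qerror (float)}, where op_name is the name of the Op's first output tensor. Ops missing from the map are drawn in gray.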

Definition at line 90 of file DotBuilder.py.

    def save(self, qerror_map: dict):
        '''
        qerror_map: Dictionary of {op_name (str) -> qerror (float)}
        '''
        # Build graph
        DOT = pydot.Dot(self._name, graph_type="digraph")

        # Add color table
        DOT.add_node(self._gen_color_table())

        # Dictionary from output tensor name to Op name {str -> str}
        # This dict handles Ops with multiple output tensors.
        # We use the first output tensor's name as the Op name, following
        # the implementation of luci IR
        output_to_op = dict()

        graph = self._model.Subgraphs(0)

        # Add Input nodes
        for i in range(graph.InputsLength()):
            name = _tensor_name(graph, graph.Inputs(i))
            output_to_op[name] = name
            DOT.add_node(pydot.Node(_quote(name)))

        # Add Output nodes
        for i in range(graph.OutputsLength()):
            name = _tensor_name(graph, graph.Outputs(i))
            output_to_op[name] = name
            DOT.add_node(pydot.Node(_quote(name)))

        # Add Op nodes and edges
        for i in range(graph.OperatorsLength()):
            op = graph.Operators(i)
            # Check before reading op.Outputs(0): for an Op with no outputs,
            # Outputs(0) does not return a valid tensor index
            if op.OutputsLength() == 0:
                print("Skipping Op #{} with no outputs".format(i))
                continue
            # Name of the first output tensor
            op_name = _tensor_name(graph, op.Outputs(0))

            if op_name in qerror_map:
                qerror = qerror_map[op_name]
                node = pydot.Node(_quote(op_name),
                                  style="filled",
                                  fillcolor=self._get_color(qerror),
                                  xlabel=self._metric + ": {:.4f}".format(qerror))
            else:
                # qerror_map has no qerror info for the Op, so color it gray.
                # When does this happen? visq does not collect qerror info for some Ops.
                # For example, Reshape does not change values, so its qerror
                # info is not collected.
                node = pydot.Node(_quote(op_name), style="filled", fillcolor='gray')

            DOT.add_node(node)

            for output_idx in range(op.OutputsLength()):
                output_name = _tensor_name(graph, op.Outputs(output_idx))
                # Map every output tensor to the Op name (first output tensor name)
                output_to_op[output_name] = op_name

            for j in range(op.InputsLength()):
                op_input = op.Inputs(j)

                # Optional input case (e.g. for TConv with no bias, bias is -1)
                if op_input == -1:
                    continue

                op_input_name = _tensor_name(graph, op_input)
                # Skip tensors no node produces (e.g. constant weights)
                if op_input_name not in output_to_op:
                    continue

                # Use the saved name to handle multiple outputs
                op_input_name = output_to_op[op_input_name]
                DOT.add_edge(pydot.Edge(_quote(op_input_name), _quote(op_name)))

        DOT.write(self._dot_path)

References DotBuilder.DotBuilder._dot_path, DotBuilder.DotBuilder._gen_color_table(), DotBuilder.DotBuilder._get_color(), DotBuilder.DotBuilder._metric, DotBuilder.DotBuilder._model, DotBuilder.DotBuilder._name, DotBuilder._quote(), and DotBuilder._tensor_name().
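
The edge-building part of `save()` can be sketched without a Circle model or pydot. In the sketch below, `build_edges` and the `(input indices, output indices)` operator records are hypothetical simplifications of the flatbuffer accessors, and the tensor names are made up for illustration; the mapping logic mirrors `output_to_op` above.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical, simplified operator record: (input tensor indices, output tensor indices)
Operator = Tuple[Sequence[int], Sequence[int]]


def build_edges(tensor_names: List[str], graph_inputs: Sequence[int],
                graph_outputs: Sequence[int],
                operators: List[Operator]) -> List[Tuple[str, str]]:
    """Return (source node, destination node) pairs following the logic of save()."""
    # Output tensor name -> node name; graph inputs/outputs are their own nodes
    output_to_op: Dict[str, str] = {}
    for idx in list(graph_inputs) + list(graph_outputs):
        output_to_op[tensor_names[idx]] = tensor_names[idx]

    edges = []
    for inputs, outputs in operators:
        if not outputs:
            continue  # no output tensor to name the node after
        op_name = tensor_names[outputs[0]]  # the first output names the node
        for out in outputs:
            output_to_op[tensor_names[out]] = op_name
        for inp in inputs:
            if inp == -1:
                continue  # omitted optional input
            in_name = tensor_names[inp]
            if in_name not in output_to_op:
                continue  # e.g. a constant weight that no node produces
            edges.append((output_to_op[in_name], op_name))
    return edges


names = ['input', 'weight', 'split_a', 'split_b', 'add_out', 'fc_out']
ops = [
    ([0], [2, 3]),        # Split: two outputs, one node named 'split_a'
    ([2, 3], [4]),        # Add consumes both Split outputs
    ([4, 1, -1], [5]),    # FullyConnected with a constant weight and no bias
]
for edge in build_edges(names, [0], [5], ops):
    print(edge)
```

Because both Split outputs map to `split_a`, the Add node receives two parallel edges from the same source; `save()` behaves the same way, since it adds one `pydot.Edge` per input tensor.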

Field Documentation

◆ _colors

DotBuilder.DotBuilder._colors
protected

◆ _dot_path

DotBuilder.DotBuilder._dot_path
protected

Definition at line 49 of file DotBuilder.py.

Referenced by DotBuilder.DotBuilder.save().

◆ _metric

DotBuilder.DotBuilder._metric
protected

◆ _model

DotBuilder.DotBuilder._model
protected

◆ _name

DotBuilder.DotBuilder._name
protected

Definition at line 48 of file DotBuilder.py.

Referenced by DotBuilder.DotBuilder.save().


The documentation for this class was generated from the following file: