ONE - On-device Neural Engine
Loading...
Searching...
No Matches
dalgona::PostOperatorHook Class Reference (final)

#include <PostOperatorHook.h>

Collaboration diagram for dalgona::PostOperatorHook:

Public Member Functions

 PostOperatorHook (py::object analysis, luci_interpreter::Interpreter *interpreter)
 
void visit (const luci::CircleNode *node)
 Default fallback.
 
void visit (const luci::CircleConv2D *node)
 
void visit (const luci::CircleDepthwiseConv2D *node)
 
void visit (const luci::CircleAdd *node)
 
void visit (const luci::CircleFullyConnected *node)
 
void visit (const luci::CircleTransposeConv *node)
 
void visit (const luci::CircleInstanceNorm *node)
 
void visit (const luci::CircleSplit *node)
 
- Public Member Functions inherited from luci::CircleNodeVisitor< void >
virtual ~CircleNodeVisitor ()=default
 
- Public Member Functions inherited from luci::CircleNodeVisitorBase< T >
virtual ~CircleNodeVisitorBase ()=default
 

Detailed Description

Definition at line 37 of file PostOperatorHook.h.

Constructor & Destructor Documentation

◆ PostOperatorHook()

dalgona::PostOperatorHook::PostOperatorHook ( py::object  analysis,
luci_interpreter::Interpreter * interpreter 
)
inline explicit

Definition at line 72 of file PostOperatorHook.h.

73 : _analysis(analysis), _interpreter(interpreter)
74 {
75 // Do nothing
76 }

Member Function Documentation

◆ visit() [1/8]

void dalgona::PostOperatorHook::visit ( const luci::CircleAdd * node)
inline

Definition at line 170 of file PostOperatorHook.h.

171 {
172 POST_OPERATOR_HOOK_PROLOGUE(Add)
173
174 auto fused_act = node->fusedActivationFunction();
175
176 pySafeCall(hook,
177 node->name(), // name
178 inputs[0], // x
179 inputs[1], // y
180 output, // output
181 toString(fused_act) // fused activation
182 );
183 }
void Add(const float *input1_data, const Dims< 4 > &input1_dims, const float *input2_data, const Dims< 4 > &input2_dims, float *output_data, const Dims< 4 > &output_dims)
Definition Add.float.cpp:28
#define POST_OPERATOR_HOOK_PROLOGUE(OP_NAME)
const std::string toString(luci::CircleOpcode opcode)
void pySafeCall(py::object func, Args... args)
Definition Utils.h:29
NodeName name(void) const

References Add(), luci::CircleNodeMixin< CircleNodeTrait::FusedActFunc >::fusedActivationFunction(), luci::CircleNode::name(), POST_OPERATOR_HOOK_PROLOGUE, dalgona::pySafeCall(), and dalgona::toString().

◆ visit() [2/8]

void dalgona::PostOperatorHook::visit ( const luci::CircleConv2D * node)
inline

Definition at line 116 of file PostOperatorHook.h.

117 {
118 POST_OPERATOR_HOOK_PROLOGUE(Conv2D)
119
120 auto padding = node->padding();
121 auto stride = node->stride();
122 auto dilation = node->dilation();
123
124 auto py_stride = py::dict("w"_a = stride->w(), "h"_a = stride->h());
125 auto py_dilation = py::dict("w"_a = dilation->w(), "h"_a = dilation->h());
126
127 auto fused_act = node->fusedActivationFunction();
128
129 pySafeCall(hook,
130 node->name(), // name
131 inputs[0], // input
132 inputs[1], // filter
133 inputs[2], // bias
134 padding == luci::Padding::SAME ? "SAME" : "VALID", // padding
135 py_stride, // stride
136 py_dilation, // dilation
137 output, // output
138 toString(fused_act) // fused activation
139 );
140 }
const Stride * stride(void) const
Padding padding() const
const Dilation * dilation(void) const

References luci::CircleConv2D::dilation(), luci::CircleNodeMixin< CircleNodeTrait::FusedActFunc >::fusedActivationFunction(), luci::CircleNode::name(), luci::CircleConv2D::padding(), POST_OPERATOR_HOOK_PROLOGUE, dalgona::pySafeCall(), luci::SAME, luci::CircleConv2D::stride(), and dalgona::toString().

◆ visit() [3/8]

void dalgona::PostOperatorHook::visit ( const luci::CircleDepthwiseConv2D * node)
inline

Definition at line 142 of file PostOperatorHook.h.

143 {
144 POST_OPERATOR_HOOK_PROLOGUE(DepthwiseConv2D)
145
146 auto padding = node->padding();
147 auto stride = node->stride();
148 auto dilation = node->dilation();
149 auto depthMultiplier = node->depthMultiplier();
150
151 auto py_stride = py::dict("w"_a = stride->w(), "h"_a = stride->h());
152 auto py_dilation = py::dict("w"_a = dilation->w(), "h"_a = dilation->h());
153
154 auto fused_act = node->fusedActivationFunction();
155
156 pySafeCall(hook,
157 node->name(), // name
158 inputs[0], // input
159 inputs[1], // filter
160 inputs[2], // bias
161 padding == luci::Padding::SAME ? "SAME" : "VALID", // padding
162 py_stride, // stride
163 depthMultiplier, // depthMultiplier
164 py_dilation, // dilation
165 output, // output
166 toString(fused_act) // fused activation
167 );
168 }
const Stride * stride(void) const
const Dilation * dilation(void) const

References luci::CircleDepthwiseConv2D::depthMultiplier(), DepthwiseConv2D, luci::CircleDepthwiseConv2D::dilation(), luci::CircleNodeMixin< CircleNodeTrait::FusedActFunc >::fusedActivationFunction(), luci::CircleNode::name(), luci::CircleDepthwiseConv2D::padding(), POST_OPERATOR_HOOK_PROLOGUE, dalgona::pySafeCall(), luci::SAME, luci::CircleDepthwiseConv2D::stride(), and dalgona::toString().

◆ visit() [4/8]

void dalgona::PostOperatorHook::visit ( const luci::CircleFullyConnected * node)
inline

Definition at line 185 of file PostOperatorHook.h.

186 {
187 POST_OPERATOR_HOOK_PROLOGUE(FullyConnected)
188
189 auto fused_act = node->fusedActivationFunction();
190 py::dict bias;
191 // bias is optional
192 if (inputs.size() == 3)
193 {
194 bias = inputs[2];
195 }
196 pySafeCall(hook,
197 node->name(), // name
198 inputs[0], // input
199 inputs[1], // weights
200 bias, // bias
201 output, // output
202 toString(fused_act) // fused activation
203 );
204 }
void FullyConnected(const float *input_data, const Dims< 4 > &input_dims, const float *weights_data, const Dims< 4 > &weights_dims, const float *bias_data, const Dims< 4 > &bias_dims, float *output_data, const Dims< 4 > &output_dims)

References FullyConnected(), luci::CircleNodeMixin< CircleNodeTrait::FusedActFunc >::fusedActivationFunction(), luci::CircleNode::name(), POST_OPERATOR_HOOK_PROLOGUE, dalgona::pySafeCall(), and dalgona::toString().

◆ visit() [5/8]

void dalgona::PostOperatorHook::visit ( const luci::CircleInstanceNorm * node)
inline

Definition at line 227 of file PostOperatorHook.h.

228 {
229 POST_OPERATOR_HOOK_PROLOGUE(InstanceNorm)
230
231 auto epsilon = node->epsilon();
232
233 auto fused_act = node->fusedActivationFunction();
234
235 pySafeCall(hook,
236 node->name(), // name
237 inputs[0], // input
238 inputs[1], // gamma
239 inputs[2], // beta
240 epsilon, // epsilon
241 output, // output
242 toString(fused_act) // fused activation
243 );
244 }

References luci::CircleInstanceNorm::epsilon(), luci::CircleNodeMixin< CircleNodeTrait::FusedActFunc >::fusedActivationFunction(), POST_OPERATOR_HOOK_PROLOGUE, dalgona::pySafeCall(), and dalgona::toString().

◆ visit() [6/8]

void dalgona::PostOperatorHook::visit ( const luci::CircleNode * node)
inline virtual

Default fallback.

Reimplemented from luci::CircleNodeVisitor< void >.

Definition at line 79 of file PostOperatorHook.h.

80 {
81 if (not py::hasattr(_analysis, "DefaultOpPost"))
82 return;
83
84 py::object hook = _analysis.attr("DefaultOpPost");
85 auto inputs = inputsPyArray(node, _interpreter);
86
87 py::list input_list;
88 for (uint32_t i = 0; i < inputs.size(); i++)
89 {
90 input_list.append(inputs[i]);
91 }
92
93 py::list output_list;
94 if (multi_out_node(node))
95 {
96 auto outputs = outputsPyArray(node, _interpreter);
97 for (uint32_t i = 0; i < outputs.size(); i++)
98 {
99 output_list.append(outputs[i]);
100 }
101 }
102 else
103 {
104 auto output = outputPyArray(node, _interpreter);
105 output_list.append(output);
106 }
107
108 pySafeCall(hook,
109 node->name(), // name
110 toString(node->opcode()), // opcode
111 input_list, // list of inputs
112 output_list // list of outputs
113 );
114 }
std::vector< py::dict > inputsPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
Definition Utils.cpp:109
std::vector< py::dict > outputsPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
Definition Utils.cpp:134
py::dict outputPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
Definition Utils.cpp:160
bool multi_out_node(const luci::CircleNode *node)
Definition Utils.cpp:175

References dalgona::inputsPyArray(), dalgona::multi_out_node(), luci::CircleNode::name(), luci::CircleNode::opcode(), dalgona::outputPyArray(), dalgona::outputsPyArray(), dalgona::pySafeCall(), and dalgona::toString().

◆ visit() [7/8]

void dalgona::PostOperatorHook::visit ( const luci::CircleSplit * node)
inline

Definition at line 246 of file PostOperatorHook.h.

247 {
248 POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS(Split)
249
250 py::list output_list;
251 for (uint32_t i = 0; i < outputs.size(); i++)
252 {
253 output_list.append(outputs[i]);
254 }
255
256 auto num_split = node->num_split();
257
258 pySafeCall(hook,
259 node->name(), // name
260 inputs[0], // split_dim
261 inputs[1], // input
262 num_split, // num_split
263 output_list // list of outputs
264 );
265 }
#define POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS(OP_NAME)
int32_t num_split(void) const
Definition CircleSplit.h:42

References luci::CircleNode::name(), luci::CircleSplit::num_split(), POST_OPERATOR_HOOK_PROLOGUE_MULTI_OUTS, and dalgona::pySafeCall().

◆ visit() [8/8]

void dalgona::PostOperatorHook::visit ( const luci::CircleTransposeConv * node)
inline

Definition at line 206 of file PostOperatorHook.h.

207 {
208 POST_OPERATOR_HOOK_PROLOGUE(TransposeConv)
209
210 auto padding = node->padding();
211 auto stride = node->stride();
212
213 auto py_stride = py::dict("w"_a = stride->w(), "h"_a = stride->h());
214
215 pySafeCall(hook,
216 node->name(), // name
217 inputs[2], // input
218 inputs[1], // filter
219 inputs[0], // output shape
220 inputs.size() == 4 ? inputs[3] : none(), // bias
221 padding == luci::Padding::SAME ? "SAME" : "VALID", // padding
222 py_stride, // stride
223 output // output
224 );
225 }
const Stride * stride(void) const
const Padding & padding(void) const
py::object none()
Definition Utils.cpp:107

References luci::CircleNode::name(), dalgona::none(), luci::CircleTransposeConv::padding(), POST_OPERATOR_HOOK_PROLOGUE, dalgona::pySafeCall(), luci::SAME, and luci::CircleTransposeConv::stride().


The documentation for this class was generated from the following file: