Run the pass.
116{
117
119 {
120
// Body of a Collector member function; its signature (listing line 121) is not
// part of this excerpt. When the encoder and decoder use equal permutations,
// the pair {encode_node, decode_node->input()} is recorded in forwardCandidates.
122 {
123 using namespace loco;
124
// NOTE(review): the declaration of decode_node (listing line 125) is missing here.
// Nothing to do when there is no matching decode node.
126 if (decode_node == nullptr)
127 {
128 return;
129 }
130 assert(decode_node->input() != nullptr);
131
132 auto encoder = encode_node->encoder();
133 assert(encoder != nullptr);
134
135 auto decoder = decode_node->decoder();
136 assert(decoder != nullptr);
137
138
// NOTE(review): perm_decoder / perm_encoder are declared on listing lines 139-140,
// which are missing here; a null value for either one aborts the check below.
141
142 if (perm_encoder == nullptr || perm_decoder == nullptr)
143 {
144 return;
145 }
146
// Equal permutations: remember encode_node together with the decode node's input.
147 if (equal(perm_encoder->perm(), perm_decoder->perm()))
148 {
149 forwardCandidates.insert({encode_node, decode_node->input()});
150 }
151 }
152
153
// Body of a Collector member function; its signature (listing line 154) is not
// part of this excerpt. When the encoder and decoder use equal permutations,
// the pair {decode_node, encode_node->input()} is recorded in forwardCandidates.
155 {
156 using namespace loco;
157
// NOTE(review): the declaration of encode_node (listing line 158) is missing here.
// Nothing to do when there is no matching encode node.
159 if (encode_node == nullptr)
160 {
161 return;
162 }
163 assert(encode_node->input() != nullptr);
164
165 auto encoder = encode_node->encoder();
166 assert(encoder != nullptr);
167
168 auto decoder = decode_node->decoder();
169 assert(decoder != nullptr);
170
171
// NOTE(review): perm_decoder / perm_encoder are declared on listing lines 172-173,
// which are missing here; a null value for either one aborts the check below.
174
175 if (perm_encoder == nullptr || perm_decoder == nullptr)
176 {
177 return;
178 }
179
// Equal permutations: remember decode_node together with the encode node's input.
180 if (equal(perm_encoder->perm(), perm_decoder->perm()))
181 {
182 forwardCandidates.insert({decode_node, encode_node->input()});
183 }
184 }
185
186
// Body of a Collector member function; its signature (listing line 187) is not
// part of this excerpt. Equal encoder/decoder permutations yield a forward
// candidate; differing permutations yield a TransposeCtx built over FilterAxis.
188 {
189 using namespace loco;
190
// NOTE(review): the declaration of encode_node (listing line 191) is missing here.
// Nothing to do when there is no matching encode node.
192 if (encode_node == nullptr)
193 {
194 return;
195 }
196 assert(encode_node->input() != nullptr);
197
198 auto encoder = encode_node->encoder();
199 assert(encoder != nullptr);
200
201 auto decoder = decode_node->decoder();
202 assert(decoder != nullptr);
203
204
// NOTE(review): perm_decoder / perm_encoder are declared on listing lines 205-206,
// which are missing here; a null value for either one aborts the check below.
207
208 if (perm_encoder == nullptr || perm_decoder == nullptr)
209 {
210 return;
211 }
212
// Equal permutations: remember decode_node together with the encode node's input.
213 if (equal(perm_encoder->perm(), perm_decoder->perm()))
214 {
215 forwardCandidates.insert({decode_node, encode_node->input()});
216 }
217 else
218 {
// Differing permutations: build a four-entry axis mapping instead.
219 std::vector<loco::TensorAxis> perm_vec;
220 perm_vec.resize(4);
221
222 auto enc_perm = perm_encoder->perm();
223 auto dec_perm = perm_decoder->perm();
224
// For every filter axis, the slot indexed by the decoder's axis value receives
// the encoder's axis value.
225 for (const auto &axis :
226 {FilterAxis::Count, FilterAxis::Height, FilterAxis::Width, FilterAxis::Depth})
227 {
228 auto from = enc_perm->axis(axis);
229 auto to = dec_perm->axis(axis);
230 perm_vec[to] = from;
231 }
232
// Record both nodes, the encode node's input and the mapping as a transpose candidate.
233 transposeCandidates.insert(
234 std::make_unique<TransposeCtx>(encode_node, decode_node, encode_node->input(), perm_vec));
235 }
236 }
237
238
// Body of a Collector member function; its signature (listing line 239) is not
// part of this excerpt. If decode_node's input is a loco::BiasEncode, the pair
// {decode_node, encode_node->input()} is recorded in forwardCandidates.
240 {
// The dynamic_cast result is non-null only when the input really is a BiasEncode.
241 if (
auto encode_node =
dynamic_cast<loco::BiasEncode *
>(decode_node->input()))
242 {
243 assert(encode_node->input() != nullptr);
244 forwardCandidates.insert({decode_node, encode_node->input()});
245 }
246 }
247
248
// Body of a Collector member function; its signature (listing line 249) is not
// part of this excerpt. Equal encoder/decoder permutations yield a forward
// candidate; differing permutations yield a TransposeCtx built over
// DepthwiseFilterAxis.
250 {
251 using namespace loco;
252
// NOTE(review): the declaration of encode_node (listing line 253) is missing here.
// Nothing to do when there is no matching encode node.
254 if (encode_node == nullptr)
255 {
256 return;
257 }
258 assert(encode_node->input() != nullptr);
259
260 auto encoder = encode_node->encoder();
261 assert(encoder != nullptr);
262
263 auto decoder = decode_node->decoder();
264 assert(decoder != nullptr);
265
266
// NOTE(review): perm_decoder / perm_encoder are declared on listing lines 267-268,
// which are missing here; a null value for either one aborts the check below.
269
270 if (perm_encoder == nullptr || perm_decoder == nullptr)
271 {
272 return;
273 }
274
// Equal permutations: remember decode_node together with the encode node's input.
275 if (equal(perm_encoder->perm(), perm_decoder->perm()))
276 {
277 forwardCandidates.insert({decode_node, encode_node->input()});
278 }
279 else
280 {
// Differing permutations: build a four-entry axis mapping instead.
281 std::vector<TensorAxis> perm_vec;
282 perm_vec.resize(4);
283
284 auto enc_perm = perm_encoder->perm();
285 auto dec_perm = perm_decoder->perm();
286
// For every depthwise-filter axis, the slot indexed by the decoder's axis value
// receives the encoder's axis value.
287 for (const auto &axis : {DepthwiseFilterAxis::Depth, DepthwiseFilterAxis::Height,
288 DepthwiseFilterAxis::Width, DepthwiseFilterAxis::Multiplier})
289 {
290 auto from = enc_perm->axis(axis);
291 auto to = dec_perm->axis(axis);
292 perm_vec[to] = from;
293 }
294
// Record both nodes, the encode node's input and the mapping as a transpose candidate.
295 transposeCandidates.insert(
296 std::make_unique<TransposeCtx>(encode_node, decode_node, encode_node->input(), perm_vec));
297 }
298 }
299
300
// Body of a Collector member function; its signature (listing line 301) is not
// part of this excerpt. When the encoder and decoder use equal permutations,
// the pair {encode_node, decode_node->input()} is recorded in forwardCandidates.
302 {
303 using namespace loco;
304
// NOTE(review): the declaration of decode_node (listing line 305) is missing here.
// Nothing to do when there is no matching decode node.
306 if (decode_node == nullptr)
307 {
308 return;
309 }
310 assert(decode_node->input() != nullptr);
311
312 auto encoder = encode_node->encoder();
313 assert(encoder != nullptr);
314
315 auto decoder = decode_node->decoder();
316 assert(decoder != nullptr);
317
318
// NOTE(review): perm_decoder / perm_encoder are declared on listing lines 319-320,
// which are missing here; a null value for either one aborts the check below.
321
322 if (perm_encoder == nullptr || perm_decoder == nullptr)
323 {
324 return;
325 }
326
// Equal permutations: remember encode_node together with the decode node's input.
327 if (equal(perm_encoder->perm(), perm_decoder->perm()))
328 {
329 forwardCandidates.insert({encode_node, decode_node->input()});
330 }
331 }
332
333
// Body of a Collector member function; its signature (listing line 334) is not
// part of this excerpt. Equal encoder/decoder permutations yield a forward
// candidate; differing permutations yield a TransposeCtx built over MatrixAxis.
335 {
336 using namespace loco;
337
// NOTE(review): the declaration of encode_node (listing line 338) is missing here.
// Nothing to do when there is no matching encode node.
339 if (encode_node == nullptr)
340 {
341 return;
342 }
343 assert(encode_node->input() != nullptr);
344
345 auto encoder = encode_node->encoder();
346 assert(encoder != nullptr);
347
348 auto decoder = decode_node->decoder();
349 assert(decoder != nullptr);
350
351
// NOTE(review): perm_decoder / perm_encoder are declared on listing lines 352-353,
// which are missing here; a null value for either one aborts the check below.
354
355 if (perm_encoder == nullptr || perm_decoder == nullptr)
356 {
357 return;
358 }
359
// Equal permutations: remember decode_node together with the encode node's input.
360 if (equal(perm_encoder->perm(), perm_decoder->perm()))
361 {
362 forwardCandidates.insert({decode_node, encode_node->input()});
363 }
364 else
365 {
// Differing permutations: build a two-entry axis mapping instead.
366 std::vector<loco::TensorAxis> perm_vec;
367 perm_vec.resize(2);
368
369 auto enc_perm = perm_encoder->perm();
370 auto dec_perm = perm_decoder->perm();
371
// For every matrix axis, the slot indexed by the decoder's axis value receives
// the encoder's axis value.
372 for (const auto &axis : {MatrixAxis::Height, MatrixAxis::Width})
373 {
374 auto from = enc_perm->axis(axis);
375 auto to = dec_perm->axis(axis);
376 perm_vec[to] = from;
377 }
378
// Record both nodes, the encode node's input and the mapping as a transpose candidate.
379 transposeCandidates.insert(
380 std::make_unique<TransposeCtx>(encode_node, decode_node, encode_node->input(), perm_vec));
381 }
382 }
383
385
// Pair type stored in forwardCandidates. The visible insert() calls store a node
// first and the input it should be connected to second.
// NOTE(review): the second template argument (listing line 387) is truncated in
// this excerpt.
386 using SimplifyingInfo = std::pair<
loco::Node * ,
388 std::set<SimplifyingInfo> forwardCandidates;
389
// Data recorded for an encode/decode pair whose permutations differ: three nodes
// (first, last, input) plus the axis mapping computed by the collecting method.
390 struct TransposeCtx
391 {
// NOTE(review): the first_node / last_node / input_node member declarations
// (listing lines 392-394) are missing from this excerpt.
395 std::vector<loco::TensorAxis> perm_vec;
396
// NOTE(review): the constructor's first line (listing line 397) is missing from
// this excerpt; the initializer list copies each argument into the member of the
// matching name.
398 std::vector<loco::TensorAxis> perm)
399 : first_node(first), last_node(last),
input_node(
input), perm_vec(perm)
400 {
401 }
402 };
403
// Owning set of TransposeCtx entries, filled by the methods above when the
// encoder and decoder permutations differ.
404 std::set<std::unique_ptr<TransposeCtx>> transposeCandidates;
405 };
406
// Collection phase: every visited node is cast to loco::CanonicalNode and handed
// to the collector through accept().
407 Collector collector;
408
// NOTE(review): the enclosing condition / loop headers (listing lines 409 and 411)
// are missing from this excerpt, so the set of nodes iterated is not visible here.
410 {
412 {
413 auto canonical_node = loco::must_cast<loco::CanonicalNode *>(node);
414 canonical_node->accept(&collector);
415 }
416 }
417
// Apply every forward candidate: p.second becomes the input of forward_node and
// the input of p.first is cleared through set_input_null().
// NOTE(review): the creation of forward_node (listing line 420) and listing line
// 422 are missing from this excerpt.
418 for (auto p : collector.forwardCandidates)
419 {
421 forward_node->
input(p.second);
423 set_input_null(p.first);
424 }
425
// Apply every transpose candidate: transpose_node receives the recorded axis
// mapping and input, then the input of ctx->first_node is cleared.
// NOTE(review): the creation of transpose_node (listing line 428), the axis loop
// header (listing line 432) and listing line 437 are missing from this excerpt.
426 for (auto &ctx : collector.transposeCandidates)
427 {
429 {
// Size the permutation to the recorded mapping before filling it in.
430 transpose_node->
perm()->
size(ctx->perm_vec.size());
431
433 transpose_node->perm()->axis(axis) = ctx->perm_vec[axis];
434 }
435
436 transpose_node->input(ctx->input_node);
438 set_input_null(ctx->first_node);
439 }
440
// Evaluates to true when at least one forward or transpose candidate was collected.
441 return (collector.forwardCandidates.size() > 0 or collector.transposeCandidates.size() > 0);
442}
Create a "Tensor" from a "Bias".
Create a "Bias" from a "Tensor".
static Dialect * get(void)
Create a tensor from a depthwise filter.
Create a depthwise filter from a tensor.
Create a tensor from a feature map.
Create a feature map from a tensor.
Create a tensor from a filter.
Create a filter from a tensor.
Create a new value identical to its input.
Create Tensor from Matrix.
Create Matrix from Tensor.
Logical unit of computation.
void with(Node *into) const
std::set< loco::Node * > active_nodes(const std::vector< loco::Node * > &roots)
Enumerate all the nodes required to compute "roots".
std::vector< Node * > output_nodes(Graph *)
Subst< SubstQualifier::Default > replace(Node *node)
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a Pull node with a given input index.