ONE - On-device Neural Engine
onert::compiler::train::StaticBackwardShapeInferer Class Reference

Class to infer shapes before running kernels.

#include <StaticBackwardShapeInferer.h>


Public Member Functions

 StaticBackwardShapeInferer (compiler::train::LoweredTrainableGraph *lowered_subg)
 
void infer (void)
 Infers the shapes of operands belonging to operations and sets the output shapes. If an output shape cannot be known without running the operation, marks it so that it can be allocated when the kernel runs.
 
void dump ()
 
- Public Member Functions inherited from onert::ir::train::TrainableOperationVisitor
virtual ~TrainableOperationVisitor ()=default
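
Because the class inherits from onert::ir::train::TrainableOperationVisitor, infer() can call op.accept(*this) on a base-class reference and still reach a handler specific to the operation type. The sketch below illustrates that double-dispatch pattern in isolation. All type names (OperationVisitor, Operation, FullyConnected, Relu, RecordingVisitor) and the dispatch() helper are hypothetical stand-ins, not onert's real classes.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-ins for onert's trainable operations; the real
// visitor declares one visit() overload per operation type.
struct FullyConnected;
struct Relu;

struct OperationVisitor
{
  virtual ~OperationVisitor() = default;
  // Default no-op overloads, so a visitor overrides only what it handles.
  virtual void visit(const FullyConnected &) {}
  virtual void visit(const Relu &) {}
};

struct Operation
{
  virtual ~Operation() = default;
  // First dispatch: the virtual call selects the dynamic operation type.
  virtual void accept(OperationVisitor &v) const = 0;
};

struct FullyConnected : Operation
{
  // Second dispatch: *this has static type FullyConnected, picking that overload.
  void accept(OperationVisitor &v) const override { v.visit(*this); }
};

struct Relu : Operation
{
  void accept(OperationVisitor &v) const override { v.visit(*this); }
};

// Records which overload ran, standing in for per-operation shape inference.
struct RecordingVisitor : OperationVisitor
{
  std::string visited;
  void visit(const FullyConnected &) override { visited = "FullyConnected"; }
  void visit(const Relu &) override { visited = "Relu"; }
};

// Mirrors the op.accept(*this) call: the caller holds only a base-class
// reference, yet the type-specific visit() is executed.
std::string dispatch(const Operation &op)
{
  RecordingVisitor visitor;
  op.accept(visitor);
  assert(!visitor.visited.empty());
  return visitor.visited;
}
```

For example, dispatch(Relu{}) runs the Relu overload even though dispatch() only sees an Operation reference.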
 

Detailed Description

Class to infer shapes before running kernels. It does the following:

  • re-calculates and sets output shapes at compile time (before running kernels)
  • if the calculation cannot be done at compile time, marks the outputs as dynamic, meaning that their shapes will be calculated while the kernels run
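
A minimal sketch of the two cases above, assuming a hypothetical operand model in which a dimension of -1 is unknown at compile time, and an element-wise operation whose output shape equals its input shape. Operand, isFullyKnown() and inferElementwise() are illustrative names, not part of onert. Note that the infer() listing on this page currently throws for dynamic inputs rather than marking outputs dynamic.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical operand model: a dimension of -1 means "unknown until run time".
struct Operand
{
  std::vector<int64_t> shape;
  bool dynamic = false;
};

// True when every dimension is known at compile time.
bool isFullyKnown(const std::vector<int64_t> &shape)
{
  for (int64_t d : shape)
    if (d < 0)
      return false;
  return true;
}

// Compile-time shape inference for an element-wise op (e.g. ReLU).
// Returns true if the output shape was set statically.
bool inferElementwise(const Operand &input, Operand &output)
{
  if (input.dynamic || !isFullyKnown(input.shape))
  {
    // Cannot decide now: mark the output dynamic so the kernel computes
    // its shape (and allocates it) at run time.
    output.dynamic = true;
    return false;
  }
  // Re-calculate and set the output shape at compile time.
  output.shape = input.shape;
  output.dynamic = false;
  return true;
}
```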

Definition at line 41 of file StaticBackwardShapeInferer.h.

Constructor & Destructor Documentation

◆ StaticBackwardShapeInferer()

onert::compiler::train::StaticBackwardShapeInferer::StaticBackwardShapeInferer ( compiler::train::LoweredTrainableGraph * lowered_subg)
inline

Definition at line 44 of file StaticBackwardShapeInferer.h.

StaticBackwardShapeInferer(compiler::train::LoweredTrainableGraph *lowered_subg)
  : _lowered_subg{lowered_subg}
{
}

Member Function Documentation

◆ dump()

void onert::compiler::train::StaticBackwardShapeInferer::dump ( )

Definition at line 55 of file StaticBackwardShapeInferer.cc.

void StaticBackwardShapeInferer::dump()
{
  // TODO dump
}

◆ infer()

void onert::compiler::train::StaticBackwardShapeInferer::infer ( void  )

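Infers the shapes of operands belonging to operations and sets the output shapes. If an output shape cannot be known without running the operation, marks it so that it can be allocated when the kernel runs. In the current implementation, however, dynamic inputs are not supported: infer() throws std::runtime_error when an operation has a dynamic input.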

Definition at line 33 of file StaticBackwardShapeInferer.cc.

void StaticBackwardShapeInferer::infer(void)
{
  // It has not been decided yet whether to iterate in reverse order.
  auto sorted_ops = _lowered_subg->graph().topolSortOperations();
  // Visit operations from the last to the first in topological order.
  for (auto it = sorted_ops.rbegin(); it != sorted_ops.rend(); ++it)
  {
    const auto op_idx = *it;
    const auto &op = _lowered_subg->trainable_graph().operation(op_idx);
    if (checkDynamicInput(op))
    {
      std::stringstream msg;
      msg << "StaticBackwardShapeInferer does not support dynamic shape yet, ";
      msg << op.name() << "(op index: " << op_idx << ") has dynamic shape.";
      throw std::runtime_error(msg.str());
    }

    checkOutput(op);

    // Dispatch to the visit() overload for this operation type.
    op.accept(*this);
  }
}

References onert::compiler::train::LoweredTrainableGraph::graph(), onert::ir::IOperation::name(), onert::ir::train::TrainableGraph::operation(), onert::ir::Graph::topolSortOperations(), and onert::compiler::train::LoweredTrainableGraph::trainable_graph().
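
The traversal in infer() can be sketched with a toy graph: sort the operations topologically (here with Kahn's algorithm, which may differ from how Graph::topolSortOperations() is actually implemented), walk the result in reverse as a backward pass would, and reject any operation with a dynamic input. Graph, Op, Operand and inferBackwardOrder() are hypothetical names; recording the operation name stands in for op.accept(*this).

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <queue>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical minimal graph; operations reference operands by index.
struct Operand
{
  bool dynamic = false;
};

struct Op
{
  std::string name;
  std::vector<std::size_t> inputs;
  std::vector<std::size_t> outputs;
};

struct Graph
{
  std::vector<Operand> operands;
  std::vector<Op> ops;

  // Kahn's algorithm: op B depends on op A if B reads an operand A writes.
  std::vector<std::size_t> topolSortOperations() const
  {
    std::map<std::size_t, std::size_t> producer; // operand -> producing op
    for (std::size_t i = 0; i < ops.size(); ++i)
      for (std::size_t o : ops[i].outputs)
        producer[o] = i;

    std::vector<std::vector<std::size_t>> consumers(ops.size());
    std::vector<std::size_t> indegree(ops.size(), 0);
    for (std::size_t i = 0; i < ops.size(); ++i)
      for (std::size_t in : ops[i].inputs)
      {
        auto it = producer.find(in);
        if (it != producer.end())
        {
          consumers[it->second].push_back(i);
          ++indegree[i];
        }
      }

    std::queue<std::size_t> ready;
    for (std::size_t i = 0; i < ops.size(); ++i)
      if (indegree[i] == 0)
        ready.push(i);

    std::vector<std::size_t> order;
    while (!ready.empty())
    {
      std::size_t cur = ready.front();
      ready.pop();
      order.push_back(cur);
      for (std::size_t next : consumers[cur])
        if (--indegree[next] == 0)
          ready.push(next);
    }
    return order;
  }
};

// Visits operations in reverse topological order and throws on dynamic input.
std::vector<std::string> inferBackwardOrder(const Graph &g)
{
  std::vector<std::string> visited;
  auto sorted = g.topolSortOperations();
  for (auto it = sorted.rbegin(); it != sorted.rend(); ++it)
  {
    const Op &op = g.ops[*it];
    for (std::size_t in : op.inputs)
      if (g.operands[in].dynamic)
      {
        std::stringstream msg;
        msg << "dynamic shape is not supported: " << op.name << "(op index: " << *it << ")";
        throw std::runtime_error(msg.str());
      }
    visited.push_back(op.name); // stand-in for op.accept(*this)
  }
  return visited;
}
```

For a chain fc → relu, the operations are visited as relu, then fc, which is the order in which gradients flow during backpropagation.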


The documentation for this class was generated from the following files:

  • StaticBackwardShapeInferer.h
  • StaticBackwardShapeInferer.cc