ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
onert::compiler::StaticShapeInferer Class Reference

Class to infer shape before running kernels. It does the following: More...

#include <StaticShapeInferer.h>

Collaboration diagram for onert::compiler::StaticShapeInferer:

Public Member Functions

 StaticShapeInferer (compiler::ILoweredGraph *lowered_subg)
 
virtual ~StaticShapeInferer ()=default
 
void appendSubgInputObserver (const ir::SubgraphIndex &subg_idx, std::unique_ptr< OperandObserver > &&subg_input_observer) noexcept
 
void setControlflowOutputObserver (std::unique_ptr< OperandObserver > &&output_observer) noexcept
 
void appendChildInferer (const ir::SubgraphIndex &subg_idx, compiler::StaticShapeInferer *inferer)
 
void infer (void)
 Infer shape of operands belonging to ops and set the output shape. If output shape cannot be known without running op, mark it so that it can be allocated when running kernel.
 
void dump ()
 
- Public Member Functions inherited from onert::ir::OperationVisitor
virtual ~OperationVisitor ()=default
 

Static Public Member Functions

static std::unordered_map< ir::SubgraphIndex, std::unique_ptr< StaticShapeInferer > > createStaticShapeInferers (const std::unordered_map< ir::SubgraphIndex, ILoweredGraph * > &lowered_subgs)
 Create a shape inferer map for a lowered model.
 

Detailed Description

Class to infer shape before running kernels. It does the following:

  • re-calculate and set output shape at compile time (before running kernels)
  • if the calculation cannot be done at compile time, mark the outputs as dynamic, meaning the output shapes will be calculated while running kernels

Definition at line 66 of file StaticShapeInferer.h.

Constructor & Destructor Documentation

◆ StaticShapeInferer()

onert::compiler::StaticShapeInferer::StaticShapeInferer ( compiler::ILoweredGraph * lowered_subg)
inline

Definition at line 69 of file StaticShapeInferer.h.

70 : _lowered_subg{lowered_subg}, _subg_input_observers{}, _controlflow_output_observer{nullptr},
71 _child_inferers{}
72 {
73 }

◆ ~StaticShapeInferer()

virtual onert::compiler::StaticShapeInferer::~StaticShapeInferer ( )
virtual default

Member Function Documentation

◆ appendChildInferer()

void onert::compiler::StaticShapeInferer::appendChildInferer ( const ir::SubgraphIndex &  subg_idx,
compiler::StaticShapeInferer *  inferer 
)
inline

Definition at line 88 of file StaticShapeInferer.h.

89 {
90 _child_inferers[subg_idx] = inferer;
91 }

Referenced by createStaticShapeInferers().

◆ appendSubgInputObserver()

void onert::compiler::StaticShapeInferer::appendSubgInputObserver ( const ir::SubgraphIndex &  subg_idx,
std::unique_ptr< OperandObserver > &&  subg_input_observer 
)
inline noexcept

Definition at line 77 of file StaticShapeInferer.h.

79 {
80 _subg_input_observers[subg_idx] = std::move(subg_input_observer);
81 }

◆ createStaticShapeInferers()

std::unordered_map< ir::SubgraphIndex, std::unique_ptr< StaticShapeInferer > > onert::compiler::StaticShapeInferer::createStaticShapeInferers ( const std::unordered_map< ir::SubgraphIndex, ILoweredGraph * > &  lowered_subgs)
static

Create a shape inferer map for a lowered model.

Parameters
[in] lowered_subgs — lowered model map
Returns
Shape inferer map

Definition at line 192 of file StaticShapeInferer.cc.

194{
195 // Allocate StaticShapeInferer per each subgraph
196 std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers;
197 for (auto &&[subg_index, lowered_subg] : lowered_subgs)
198 {
199 inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg);
200 }
201
202 // Append observers in all StaticShapeInferers
203 for (auto &&pair : lowered_subgs)
204 {
205 const auto &subg_index = pair.first;
206 auto &lowered_subg = pair.second;
207
208 // TODO: Change this iteration for all to controlflow iteration
209 lowered_subg->graph().operations().iterate(
210 [&](const ir::OperationIndex &, const ir::IOperation &op) {
211 // A Function to append child inferers. These make it possible for a StaticShapeInferer to
212 // call StaticShapeInferers of child subgraphs recursively
213 auto appendChildInferer = [&](const ir::SubgraphIndex &child_subg_idx) {
214 auto *child_inferer = inferers.at(child_subg_idx).get();
215 inferers.at(subg_index)->appendChildInferer(child_subg_idx, child_inferer);
216 };
217
218 // A Function to append subg input observers. This makes it possible for a
219 // StaticShapeInferer to update inputs of child subgraphs
220 auto appendSubgraphInputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
221 std::vector<ir::Operand *> child_subg_inputs;
222 auto &child_subg = lowered_subgs.at(child_subg_idx)->graph();
223 for (const auto &input_idx : child_subg.getInputs())
224 {
225 auto operand_ptr = child_subg.operands().getRawPtr(input_idx);
226 child_subg_inputs.emplace_back(operand_ptr);
227 }
228 inferers.at(subg_index)
229 ->appendSubgInputObserver(child_subg_idx,
230 std::make_unique<OperandObserver>(child_subg_inputs));
231 };
232
233 // A Function to set controlflow output observers. This makes it possible for a
234 // StaticShapeInferer to update outputs of parent controlflow operations
235 auto setControlFlowOutputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
236 std::vector<ir::Operand *> cf_outputs;
237 auto &subg = lowered_subg->graph();
238 for (const auto &output_idx : op.getOutputs())
239 {
240 auto operand_ptr = subg.operands().getRawPtr(output_idx);
241 cf_outputs.emplace_back(operand_ptr);
242 }
243 inferers.at(child_subg_idx)
244 ->setControlflowOutputObserver(std::make_unique<OperandObserver>(cf_outputs));
245 };
246
247 // Append Observers in a StaticShapeInferer
248 if (op.opcode() == ir::OpCode::If)
249 {
250 // TODO Remove dynamic_cast
251 // A virtual base class cannot be downcast by static_cast
252 try
253 {
254 const auto &if_op = dynamic_cast<const ir::operation::If &>(op);
255
256 appendChildInferer(if_op.param().then_subg_index);
257 appendChildInferer(if_op.param().else_subg_index);
258
259 appendSubgraphInputObserver(if_op.param().then_subg_index);
260 appendSubgraphInputObserver(if_op.param().else_subg_index);
261
262 setControlFlowOutputObserver(if_op.param().then_subg_index);
263 }
264 catch (const std::bad_cast &)
265 {
266 throw std::runtime_error("StaticShapeInferer: Invalid If operation");
267 }
268 }
269 else if (op.opcode() == ir::OpCode::While)
270 {
271 // TODO Remove dynamic_cast
272 try
273 {
274 const auto &while_op = dynamic_cast<const ir::operation::While &>(op);
275
276 appendChildInferer(while_op.param().cond_subg_index);
277 appendChildInferer(while_op.param().body_subg_index);
278
279 appendSubgraphInputObserver(while_op.param().cond_subg_index);
280 appendSubgraphInputObserver(while_op.param().body_subg_index);
281
282 setControlFlowOutputObserver(while_op.param().body_subg_index);
283 }
284 catch (const std::bad_cast &)
285 {
286 throw std::runtime_error("StaticShapeInferer: Invalid While operation");
287 }
288 }
289 });
290 }
291
292 return inferers;
293}
void appendChildInferer(const ir::SubgraphIndex &subg_idx, compiler::StaticShapeInferer *inferer)
::onert::util::Index< uint32_t, OperationIndexTag > OperationIndex
Definition Index.h:30
::onert::util::Index< uint16_t, SubgraphIndexTag > SubgraphIndex
Definition Index.h:39

References appendChildInferer(), onert::ir::IOperation::getOutputs(), and onert::ir::IOperation::opcode().

◆ dump()

void onert::compiler::StaticShapeInferer::dump ( )

Definition at line 167 of file StaticShapeInferer.cc.

168{
169 auto get_shape_str = [](const ir::Shape &shape) {
170 std::stringstream sstream;
171 sstream << "shape : {";
172 for (int i = 0; i < shape.rank(); i++)
173 {
174 if (i == 0)
175 sstream << shape.dim(i);
176 else
177 sstream << " " << shape.dim(i);
178 }
179 sstream << "}";
180 return sstream.str();
181 };
182
183 _lowered_subg->graph().operands().iterate(
184 [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
185 VERBOSE(StaticShapeInferer) << " " << ind << ", "
186 << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
187 << get_shape_str(operand.info().shape()) << std::endl;
188 });
189}
const Operands & operands() const override
Definition Graph.h:110
void iterate(const std::function< void(const Index &, const Object &)> &fn) const
Iterate over the container with given function.
#define VERBOSE(name, lv)
Definition Log.h:71
::onert::util::Index< uint32_t, OperandIndexTag > OperandIndex
Definition Index.h:33
virtual ir::Graph & graph()=0

References onert::compiler::ILoweredGraph::graph(), onert::ir::Operand::info(), onert::ir::OperandInfo::isDynamic(), onert::util::ObjectManager< Index, Object >::iterate(), onert::ir::Graph::operands(), onert::ir::OperandInfo::shape(), and VERBOSE.

◆ infer()

void onert::compiler::StaticShapeInferer::infer ( void  )

Infer shape of operands belonging to ops and set the output shape. If output shape cannot be known without running op, mark it so that it can be allocated when running kernel.

Definition at line 56 of file StaticShapeInferer.cc.

57{
58 for (const auto &op_idx : _lowered_subg->graph().topolSortOperations())
59 {
60 const auto &op = _lowered_subg->graph().operations().at(op_idx);
61 bool has_dynamic_tensor = false;
62 const auto opcode = op.opcode();
63 // IF: requires shape inference for then, else
64 // While: requires shape inference for condition, body
65 if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
66 {
67 op.accept(*this);
68 }
69 else
70 {
71 has_dynamic_tensor = checkDynamicInput(op);
72 if (has_dynamic_tensor)
73 {
74 setDynamicOutput(op);
75 }
76 else
77 {
78 op.accept(*this);
79 }
80 }
81 has_dynamic_tensor = has_dynamic_tensor || checkDynamicOutput(op);
82 _lowered_subg->setHasDynamicTensor(op_idx, has_dynamic_tensor);
83 }
84
85 if (_controlflow_output_observer != nullptr)
86 {
87 // re-sizing output shapes of the controlflow operation branching to this subgraph
88 std::vector<ir::OperandInfo> outputs_info;
89 const auto &graph = _lowered_subg->graph();
90 const auto &outputs = graph.getOutputs();
91 for (size_t i = 0; i < outputs.size(); ++i)
92 {
93 const auto &operand_info = graph.operands().at(outputs.at(i)).info();
94 outputs_info.emplace_back(operand_info);
95 }
96 _controlflow_output_observer->updateShapes(outputs_info);
97 }
98}
const Operations & operations() const override
Definition Graph.h:112
const Object & at(const Index &index) const
Get the object that is associated with the given index.
virtual void setHasDynamicTensor(ir::OperationIndex ind, bool val)=0

References onert::util::ObjectManager< Index, Object >::at(), onert::compiler::ILoweredGraph::graph(), onert::ir::Graph::operations(), onert::compiler::ILoweredGraph::setHasDynamicTensor(), and onert::ir::Graph::topolSortOperations().

◆ setControlflowOutputObserver()

void onert::compiler::StaticShapeInferer::setControlflowOutputObserver ( std::unique_ptr< OperandObserver > &&  output_observer)
inline noexcept

Definition at line 83 of file StaticShapeInferer.h.

84 {
85 _controlflow_output_observer = std::move(output_observer);
86 }

The documentation for this class was generated from the following files: