ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert_micro::import::OMDynamicShapesHandler Struct Reference

#include <OMDynamicShapesHandler.h>

Static Public Member Functions

static OMStatus importDynamicShapesFromTrainConfig (core::OMRuntimeStorage &storage, core::OMRuntimeContext &context, core::train::OMTrainingStorage &train_storage)
 

Detailed Description

Definition at line 34 of file OMDynamicShapesHandler.h.

Member Function Documentation

◆ importDynamicShapesFromTrainConfig()

OMStatus OMDynamicShapesHandler::importDynamicShapesFromTrainConfig ( core::OMRuntimeStorage &storage,
core::OMRuntimeContext &context,
core::train::OMTrainingStorage &train_storage 
)
static

Definition at line 77 of file OMDynamicShapesHandler.cpp.

80{
81 std::unordered_map<uint16_t, uint8_t> train_op_indexes_to_train_rank =
82 context.getTrainableOpsIndexes();
83 const auto opcodes = context.getCircleOpcodes();
84
85 // Goes over pairs of op index and train rank value
86 for (auto &p : train_op_indexes_to_train_rank)
87 {
88 const uint16_t op_index = p.first;
89 const auto train_rank = static_cast<OpTrainableRankType>(p.second);
90
91 switch (train_rank)
92 {
93 case LOWER_1_2_PART:
94 case UP_1_2_PART:
95 {
96 const auto cur_op = context.getCircleOperatorAt(op_index);
97 const auto opcode = opcodes->operator[](cur_op->opcode_index());
98
99 int32_t res_index = getWeightTensorIndexForOperatorWithOpcode(opcode);
100 // The operation doesn't support such behaviour
101 if (res_index == -1)
102 continue;
103
104 auto tensor_local_index = static_cast<uint16_t>(res_index);
105 auto tensor_index = cur_op->inputs()->operator[](tensor_local_index);
106 auto tensor = context.getTensorByIndex(tensor_index);
107 OMRuntimeShape old_shape(tensor);
108 const float partition_size = 2.f;
109 OMRuntimeShape new_shape =
110 createDynamicRuntimeShapeForOperator(old_shape, opcode, partition_size);
111 storage.setDynamicRuntimeShape(tensor_index, new_shape);
112
113 train_storage.addTrainRank(tensor_index, train_rank);
114
115 break;
116 }
117 default:
118 continue;
119 }
120 }
121
122 return Ok;
123}
std::unordered_map< uint16_t, uint8_t > getTrainableOpsIndexes()
const circle::Operator * getCircleOperatorAt(uint16_t index)
const reader::CircleOperatorCodes * getCircleOpcodes()
const circle::Tensor * getTensorByIndex(int32_t tensor_index)
OMStatus setDynamicRuntimeShape(uint16_t tensor_index, const OMRuntimeShape &shape)
void addTrainRank(uint16_t tensor_index, core::OpTrainableRankType train_rank)

References onert_micro::core::train::OMTrainingStorage::addTrainRank(), onert_micro::core::OMRuntimeContext::getCircleOpcodes(), onert_micro::core::OMRuntimeContext::getCircleOperatorAt(), onert_micro::core::OMRuntimeContext::getTensorByIndex(), onert_micro::core::OMRuntimeContext::getTrainableOpsIndexes(), onert_micro::core::LOWER_1_2_PART, onert_micro::Ok, onert_micro::core::OMRuntimeStorage::setDynamicRuntimeShape(), and onert_micro::core::UP_1_2_PART.

Referenced by onert_micro::core::OMTrainingRuntimeModule::importTrainModel().


The documentation for this struct was generated from the following files: