ONE - On-device Neural Engine
onert_micro::train::OMBackpropExecute Struct Reference

#include <OMBackpropExecute.h>

Static Public Member Functions

static OMStatus runBackward (const OMConfig &config, OMBackpropExecuteArgs &args, core::memory::OMRuntimeAllocator &allocator)
 

Detailed Description

Executes the backward (gradient) pass of a training step: the model's operators are walked in reverse order, backward tensors are allocated per operator, the per-kernel training function is dispatched for each operator inside the trainable window, and forward/backward tensor data is deallocated once an operator's gradients have been computed.

Definition at line 31 of file OMBackpropExecute.h.

Member Function Documentation

◆ runBackward()

OMStatus OMBackpropExecute::runBackward ( const OMConfig & config,
                                          OMBackpropExecuteArgs & args,
                                          core::memory::OMRuntimeAllocator & allocator 
                                        )
static
Definition at line 26 of file OMBackpropExecute.cpp.

{
  OMStatus status = Ok;

  core::OMRuntimeContext &context = args.backward_context;
  core::OMRuntimeStorage &forward_storage = args.forward_storage;
  core::OMRuntimeStorage &backward_storage = args.backward_storage;

  const core::reader::CircleOperators *operators = context.getCircleOperators();

  const auto num_operators = operators->size();
  const auto *op_codes = context.getCircleOpcodes();

  uint32_t num_train_layers =
    config.training_context.num_of_train_layers == 0
      ? num_operators
      : std::min(num_operators, config.training_context.num_of_train_layers);
  std::unordered_map<uint16_t, uint8_t> trainable_ops_config = context.getTrainableOpsIndexes();

  // If the context has trainable operations defined via a config file,
  // then ignore config.training_context.num_of_train_layers
  // and use the smallest trainable op index to define the last (earliest) op to train
  uint16_t last_train_op_indx = num_operators - num_train_layers;
  if (!trainable_ops_config.empty())
  {
    last_train_op_indx = std::numeric_limits<uint16_t>::max();
    // Find op trainable index with min value
    for (auto &p : trainable_ops_config)
    {
      last_train_op_indx = std::min(p.first, last_train_op_indx);
    }
    num_train_layers = (num_operators - last_train_op_indx);
  }

  for (int32_t i = 0; i < num_train_layers; ++i)
  {
    uint32_t cur_op_index = num_operators - i - 1;
    auto *cur_op = operators->operator[](cur_op_index);

    status = allocator.allocate(i, &context, &backward_storage);

    if (status != Ok)
      return status;

    core::OMBuilderID builder_id = core::OMBuilderID::Size;
    const circle::Operator *op = operators->operator[](cur_op_index);
    uint32_t index = op->opcode_index();

    assert(index < op_codes->size());

    const auto opcode = op_codes->operator[](index);

    status = core::getBuilderId(opcode, builder_id);

    assert(status == Ok);
    if (status != Ok)
      return status;

    args.kernel_index = cur_op_index;

    if (i == num_train_layers - 1)
    {
      args.is_last_layer = true;
    }
    else
    {
      args.is_last_layer = false;
    }

    if (trainable_ops_config.empty())
    {
      args.is_trainable_layer = true;
      args.train_rank_type = core::OpTrainableRankType::ALL;
    }
    else if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end())
    {
      args.is_trainable_layer = true;
      args.train_rank_type = core::OpTrainableRankType(trainable_ops_config[cur_op_index]);
    }
    else
    {
      args.is_trainable_layer = false;
    }

    // Calculate gradients
    KernelTrainFunc *train_func = nullptr;
    if (size_t(builder_id) < size_t(core::OMBuilderID::BuiltinOperatorsSize))
    {
      // Builtin operator
      status = kernel_builtin_train.getKernelTrainFunc(builder_id, &train_func);
    }
    else
    {
      assert(false && "Unsupported kernel type for training");
      return UnsupportedOp;
    }

    assert(train_func != nullptr);

    if (status != Ok)
      return status;

    status = train_func(args);

    assert(status == Ok);

    if (status != Ok)
      return status;

    // Deallocate tensors data in backward storage
#ifdef OM_MEMORY_ESTIMATE
    status = allocator.deallocate(i, &backward_storage, &context);
    if (status != Ok)
      return status;

    // Deallocate tensors data in forward storage
    status = allocator.deallocate(i, &forward_storage, &context);
#else
    status = allocator.deallocate(i, &backward_storage);
    if (status != Ok)
      return status;

    // Deallocate tensors data in forward storage
    status = allocator.deallocate(i, &forward_storage);
#endif
  }

  return status;
}
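
The window of operators visited by the backward loop is derived from config.training_context.num_of_train_layers and, when present, from the trainable-ops map returned by getTrainableOpsIndexes(). The standalone sketch below (not part of onert-micro; the operator count and map contents are made-up values) reproduces only that arithmetic so the resulting window can be inspected in isolation.

// Illustrative sketch: derive the backward-pass window as runBackward() does.
// num_operators, num_of_train_layers and trainable_ops are assumed example values.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <unordered_map>

int main()
{
  const uint32_t num_operators = 10;   // assumed model size
  uint32_t num_of_train_layers = 3;    // stands in for config.training_context.num_of_train_layers

  // op index -> trainable rank type; non-empty map overrides num_of_train_layers
  std::unordered_map<uint16_t, uint8_t> trainable_ops = {{4, 1}, {7, 1}};

  uint32_t num_train_layers =
    num_of_train_layers == 0 ? num_operators
                             : std::min(num_operators, num_of_train_layers);

  uint16_t last_train_op_indx = num_operators - num_train_layers;
  if (!trainable_ops.empty())
  {
    last_train_op_indx = std::numeric_limits<uint16_t>::max();
    for (auto &p : trainable_ops)
      last_train_op_indx = std::min(p.first, last_train_op_indx);
    num_train_layers = num_operators - last_train_op_indx;
  }

  // Operators are then visited in reverse: num_operators - 1 down to last_train_op_indx.
  std::cout << "first trained op index: " << last_train_op_indx
            << ", ops visited: " << num_train_layers << '\n';
  return 0;
}

With the example values above, the window starts at operator index 4 and the backward loop visits 6 operators (indices 9 down to 4).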

References onert_micro::core::ALL, onert_micro::core::memory::OMRuntimeAllocator::allocate(), onert_micro::core::BuiltinOperatorsSize, onert_micro::core::memory::OMRuntimeAllocator::deallocate(), onert_micro::core::getBuilderId(), onert_micro::core::OMRuntimeContext::getCircleOpcodes(), onert_micro::core::OMRuntimeContext::getCircleOperators(), onert_micro::train::KernelBuiltinTrainRegistry::getKernelTrainFunc(), onert_micro::core::OMRuntimeContext::getTrainableOpsIndexes(), onert_micro::train::kernel_builtin_train, onert_micro::Ok, flatbuffers::Vector< T >::size(), size, onert_micro::core::Size, and onert_micro::UnsupportedOp.

Referenced by onert_micro::core::OMTrainingRuntimeModule::trainSingleStep().
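
A hedged sketch of how a caller such as OMTrainingRuntimeModule::trainSingleStep() could drive the backward pass is shown below; only names that appear on this page are used, and the namespace placement of OMConfig and OMBackpropExecuteArgs is an assumption, not taken from the library sources.

// Sketch of a caller; the contexts, storages and allocator are assumed to be
// set up by the training runtime before the backward pass starts.
#include "OMBackpropExecute.h"

using namespace onert_micro;

OMStatus runBackwardPass(const OMConfig &config, train::OMBackpropExecuteArgs &args,
                         core::memory::OMRuntimeAllocator &allocator)
{
  // args.backward_context, args.forward_storage and args.backward_storage are
  // expected to be populated by the forward pass; runBackward() fills in
  // kernel_index / is_last_layer / is_trainable_layer / train_rank_type per operator.
  return train::OMBackpropExecute::runBackward(config, args, allocator);
}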


The documentation for this struct was generated from the following files:

OMBackpropExecute.h
OMBackpropExecute.cpp