// NOTE(review): This chunk is a corrupted extraction of a backward-pass (training)
// executor body. The enclosing function's signature is above this window (this
// first line is only its opening brace). Two classes of damage are visible and
// MUST be repaired against the original file before this compiles:
//   1. Original source line numbers are fused onto the code (e.g. "28{",
//      "37 const auto ..."); they are not part of the program.
//   2. Statements are missing: the declarations of `status`, `context`,
//      `backward_storage`/`forward_storage`, `args`, `operators`, `op_codes`,
//      `trainable_ops_config`, `index`/`opcode` consumers, and `train_func`;
//      the `if (status != Ok)` guards that presumably preceded each bare
//      `return status;` inside the loop — TODO confirm; and the kernel-dispatch
//      condition (presumably a getKernelTrainFunc/builder-id lookup — see the
//      trailing symbol index) whose `else` branch survives below.
// Comments describe only what the visible code demonstrates.
28{
30
34
36
// Total number of operators in the graph (flatbuffer vector size).
37 const auto num_operators = operators->
size();
39
// How many layers (counted from the end of the graph) take part in training:
// a configured value of 0 means "train everything", otherwise the configured
// count clamped to the number of operators.
40 uint32_t num_train_layers =
41 config.training_context.num_of_train_layers == 0
42 ? num_operators
43 : std::min(num_operators,
config.training_context.num_of_train_layers);
45
46
47
48
// Index of the earliest (closest-to-input) operator included in training.
49 uint16_t last_train_op_indx = num_operators - num_train_layers;
// If an explicit per-op trainable configuration exists, it overrides the
// count-based selection: training reaches back to the smallest op index
// present in that config.
50 if (!trainable_ops_config.empty())
51 {
52 last_train_op_indx = std::numeric_limits<uint16_t>::max();
53
// Find the trainable op with the smallest position.
54 for (auto &p : trainable_ops_config)
55 {
56 last_train_op_indx = std::min(p.first, last_train_op_indx);
57 }
58 num_train_layers = (num_operators - last_train_op_indx);
59 }
60
// Walk the operators back-to-front (backpropagation order): i == 0 is the
// last operator in the graph.
// NOTE(review): `int32_t i` vs unsigned `num_train_layers` is a signed/
// unsigned comparison — benign here but worth normalizing upstream.
61 for (int32_t i = 0; i < num_train_layers; ++i)
62 {
63 uint32_t cur_op_index = num_operators - i - 1;
64 auto *cur_op = operators->operator[](cur_op_index);
65
// Allocate backward (gradient) storage for this kernel.
66 status = allocator.
allocate(i, &context, &backward_storage);
67
// NOTE(review): bare early-return — the guarding `if (status != Ok)` (or
// similar) was lost in extraction; as written this would always return.
69 return status;
70
// Resolve the operator's opcode from the model's opcode table.
72 const circle::Operator *op = operators->operator[](cur_op_index);
73 uint32_t
index = op->opcode_index();
74
75 assert(index < op_codes->
size());
76
77 const auto opcode = op_codes->operator[](
index);
78
80
// NOTE(review): another orphaned early-return; its condition (presumably a
// builder-id lookup failure check — TODO confirm) is missing.
83 return status;
84
// Tell the kernel which operator it is backpropagating.
85 args.kernel_index = cur_op_index;
86
// The first operator processed going backwards from the end of the training
// window is flagged as the "last layer" (i counts from the graph's tail).
87 if (i == num_train_layers - 1)
88 {
89 args.is_last_layer =
true;
90 }
91 else
92 {
93 args.is_last_layer =
false;
94 }
95
// Trainability: with no explicit config every layer in the window trains;
// otherwise only ops listed in the config receive weight updates.
96 if (trainable_ops_config.empty())
97 {
98 args.is_trainable_layer =
true;
100 }
101 else if (trainable_ops_config.find(cur_op_index) != trainable_ops_config.end())
102 {
103 args.is_trainable_layer =
true;
105 }
106 else
107 {
108 args.is_trainable_layer =
false;
109 }
110
// NOTE(review): the `if (...)` head and then-branch body for this dispatch
// were lost (lines 111-117 of the original); only the `else` survives.
// Presumably it fetched `train_func` for a builtin kernel — verify against
// the original (cf. getKernelTrainFunc / kernel_builtin_train in the index).
111
114 {
115
117 }
118 else
119 {
120 assert(false && "Unsupported kernel type for training");
122 }
123
124 assert(train_func != nullptr);
125
// NOTE(review): orphaned early-return; missing guard condition.
127 return status;
128
// Run the kernel's backward (training) function.
129 status = train_func(args);
130
131 assert(status ==
Ok);
132
// NOTE(review): orphaned early-return; missing guard condition.
134 return status;
135
136
// Release this kernel's storages. The memory-estimate build passes the
// runtime context through so deallocation can be accounted for.
137#ifdef OM_MEMORY_ESTIMATE
138 status = allocator.
deallocate(i, &backward_storage, &context);
// NOTE(review): orphaned early-return; missing guard condition.
140 return status;
141
142
143 status = allocator.
deallocate(i, &forward_storage, &context);
144#else
145 status = allocator.
deallocate(i, &backward_storage);
// NOTE(review): orphaned early-return; missing guard condition.
147 return status;
148
149
150 status = allocator.
deallocate(i, &forward_storage);
151#endif
152 }
153
154 return status;
155}
std::unordered_map< uint16_t, uint8_t > getTrainableOpsIndexes()
const reader::CircleOperatorCodes * getCircleOpcodes()
const reader::CircleOperators * getCircleOperators()
OMStatus allocate(size_t kernel_index, OMRuntimeContext *context, OMRuntimeStorage *storage)
OMStatus deallocate(size_t kernel_index, OMRuntimeStorage *storage)
OMStatus getKernelTrainFunc(core::OMBuilderID builderID, KernelTrainFunc **train_func) const
loco::GraphInputIndex index(const TFPlaceholder *node)
OMStatus getBuilderId(const circle::OperatorCode *opcode, core::OMBuilderID &builderID)
constexpr KernelBuiltinTrainRegistry kernel_builtin_train
OMStatus(const OMBackpropExecuteArgs &) KernelTrainFunc