151{
  // NOTE(review): fragmented chunk — the enclosing function's signature is
  // above this view and several interior lines (originals 154, 158-159, 165,
  // 172-173, 177, 180, 183-186) are missing. Comments below describe only
  // what is visible; do not treat them as a complete contract.
  //
  // Visible purpose: for every recorded gradient buffer, apply an SGD-style
  // update to the matching weight buffer:
  //   weight[i] -= (lambda * grad[i]) / batch_size
 152 assert(!_tensor_index_to_gradient.empty());
  // Release-build guard mirroring the assert; the taken branch (presumably
  // an early error return) is on a line not visible in this chunk.
 153 if (_tensor_index_to_gradient.empty())
 155
  // Iterate over (tensor index -> raw gradient buffer) pairs.
 156 for (auto &tensor_to_data : _tensor_index_to_gradient)
 157 {
  // `shape` is obtained on a missing line (original 158-159); flatSize()
  // gives the element count of the weight/gradient tensor.
 160 auto num_elements = shape.flatSize();
 161
  // First dimension of the static shape; its use is on lines not visible
  // here (likely the sliced/rank-based update path around originals 183-186).
 162 auto original_d = shape.dims(0);
 163
  // When dynamic shapes are enabled, a nonzero runtime size overrides the
  // static element count.
 164#ifndef DIS_DYN_SHAPES
 166 if (dynamic_tensor_size != 0)
 167 num_elements = dynamic_tensor_size;
 168#endif
 169
  // Gradient bytes reinterpreted as float32; assumes the training kernels
  // produced float gradients — TODO confirm against the writer side.
 170 auto *grad_data = reinterpret_cast<float *>(tensor_to_data.second);
 171 uint8_t *weight_data = nullptr;
  // weight_data is filled in by a call on the missing lines (originals
  // 172-173), presumably a lookup by tensor index.
 174
 175 assert(weight_data != nullptr);
  // Null-check counterpart of the assert; its body (original 177) is not
  // visible — presumably an error return.
 176 if (weight_data == nullptr)
 178
 179 auto *f_weight_data = reinterpret_cast<float *>(weight_data);
  // NOTE(review): statement wrapped across two lines by extraction; it is a
  // single declaration reading training_config.batch_size.
 181 const uint32_t batch_size = training_config.
batch_size;
  // Rank/type lookup for this tensor; how `train_it` is consumed is on the
  // missing lines (originals 183-186).
 182 auto train_it = tensor_index_to_rank_type_map.find(tensor_to_data.first);
 187
  // Division-by-zero guard for the averaging below (debug builds only).
 188 assert(batch_size != 0);
 189
  // Plain SGD step, gradient averaged over the batch.
 190 for (uint32_t i = 0; i < num_elements; ++i)
 191 {
 192 f_weight_data[i] -= (lambda * grad_data[i]) / (static_cast<float>(batch_size));
 193 }
 194 }
 196}
// NOTE(review): macro expansion appears truncated in this view — an empty
// body would silently swallow both the error and the message. The original
// presumably logs `msg` and returns `err`; confirm against the full source
// before relying on this definition.
#define OM_LOG_AND_RETURN(err, msg)
// Signature only in this view (no body, no terminator). From the name and
// parameters: writes a pointer to the constant (read-only) data of the tensor
// at `tensor_index` through `data` and reports an OMStatus — presumed; the
// implementation is not visible here.
OMStatus getConstDataByTensorIndex(uint8_t **data, uint16_t tensor_index)
// Signature only in this view. From the name: computes an (upper, lower)
// depth bound pair for a weight tensor given its trainable-rank kind and the
// layer's output depth — presumed from naming; body not visible, so the
// ordering of the pair's members is unconfirmed.
std::pair< uint32_t, uint32_t > getUpLowerWeightTensorDepth(core::OpTrainableRankType rank, const uint32_t output_depth)