332{
333 assert(!_tensor_index_to_gradient.empty());
334 for (auto &tensor_to_data : _tensor_index_to_gradient)
335 {
336 auto exponent_squares_it = _tensor_to_exponent_avg_squares.find(tensor_to_data.first);
337 if (exponent_squares_it == _tensor_to_exponent_avg_squares.end())
339
340 auto exponent_it = _tensor_to_exponent_avg.find(tensor_to_data.first);
341 if (exponent_it == _tensor_to_exponent_avg.end())
343
346
347 auto original_d = shape.dims(0);
348
350
351#ifndef DIS_DYN_SHAPES
353 if (dynamic_tensor_size != 0)
355#endif
356
357 auto *exponent_data = reinterpret_cast<float *>(exponent_it->second);
358 auto *exponent_square_data = reinterpret_cast<float *>(exponent_squares_it->second);
359 auto *calculated_data = reinterpret_cast<float *>(tensor_to_data.second);
360 float beta = training_config.
beta;
362 auto batches =
static_cast<float>(training_config.
batch_size);
364 {
365 const auto cur_val = calculated_data[i];
366 exponent_data[i] = beta * exponent_data[i] + (1 - beta) * cur_val;
367 exponent_square_data[i] =
368 beta_squares * exponent_square_data[i] + (1 - beta_squares) * cur_val * cur_val;
369 }
370
371 uint8_t *weight_data = nullptr;
374
375 assert(weight_data != nullptr);
376 if (weight_data == nullptr)
378
379 auto *f_weight_data = reinterpret_cast<float *>(weight_data);
381 auto num_step =
static_cast<float>(training_config.
num_step);
382 float beta_in_pow_batch = std::pow(beta, num_step);
383 float beta_square_in_pow_batch = std::pow(beta_squares, num_step);
384 float epsilon = training_config.
epsilon;
385
386 assert((1.f - beta_in_pow_batch) != 0);
387 assert((1.f - beta_square_in_pow_batch) != 0);
388 auto train_it = tensor_index_to_rank_type_map.find(tensor_to_data.first);
393
395 {
396 float exponent_corrected = exponent_data[i] / (1.f - beta_in_pow_batch);
397 float exponent_square_corrected = exponent_square_data[i] / (1.f - beta_square_in_pow_batch);
398 f_weight_data[i + depth_bounds.first] -=
399 lambda * (exponent_corrected / (std::sqrt(exponent_square_corrected + epsilon)));
400 }
401 }
402
404}
OMStatus getConstDataByTensorIndex(uint8_t **data, uint16_t tensor_index)
std::pair< uint32_t, uint32_t > getUpLowerWeightTensorDepth(core::OpTrainableRankType rank, const uint32_t output_depth)