// Apply gradient to a trainable tensor.
void Adam::applyGradient(const UpdateFactors &factors) const
{
  // factors bundles the gradient tensor, the trainable tensor, and the current training step
  auto [grad_tensor, trainable_tensor, training_step] = factors;
  assert(trainable_tensor.data_type() == grad_tensor.data_type());

  const auto opt_vars = trainable_tensor.optVars();
  assert(opt_vars.size() == 2);
  // First optimizer variable: exponential moving average of the gradient (m)
  auto m_tensor = nnfw::misc::polymorphic_downcast<IPortableTensor *>(opt_vars.at(0));
  // Second optimizer variable: exponential moving average of the squared gradient (v)
  auto v_tensor = nnfw::misc::polymorphic_downcast<IPortableTensor *>(opt_vars.at(1));
  const auto beta1_power = std::pow(_props.beta1, training_step + 1);
  const auto beta2_power = std::pow(_props.beta2, training_step + 1);
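  // m and v are initialized to zero, so early moment estimates are biased toward zero.
  // These beta powers feed the (1 - beta^t) bias-correction terms applied in the kernel;
  // see the sketch after the Adam() signature below.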

  // Nesterov momentum is not supported yet
  const bool use_nesterov = false;

  if (trainable_tensor.getShape() != grad_tensor.getShape())
  {
    throw std::runtime_error("Adam: Invalid gradient tensor");
  }

  switch (grad_tensor.data_type())
  {
    case ir::DataType::FLOAT32:
      ops::Adam(ops::getShape(&trainable_tensor), ops::getBuffer<float>(&trainable_tensor),
                ops::getShape(&grad_tensor), ops::getBuffer<float>(&grad_tensor),
                ops::getShape(m_tensor), ops::getBuffer<float>(m_tensor), ops::getShape(v_tensor),
                ops::getBuffer<float>(v_tensor), beta1_power, beta2_power, _learning_rate,
                _props.beta1, _props.beta2, _props.epsilon, use_nesterov);
      break;
    default:
      throw std::runtime_error("Adam: Not supported data type");
  }
}
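
// Per-element kernel invoked by the FLOAT32 case above. The body below is a minimal scalar
// sketch of the update this signature implies, following the conventional (TensorFlow-style)
// Adam rule; it is illustrative, not the shipped implementation. It assumes Shape exposes a
// FlatSize() element count and that the caller has already validated matching shapes.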
void Adam(const Shape &trainable_shape, float *trainable_data, const Shape &grad_shape,
          const float *grad_data, const Shape &m_shape, float *m_data, const Shape &v_shape,
          float *v_data, float beta1_power, float beta2_power, float learning_rate, float beta1,
          float beta2, float epsilon, bool use_nesterov)
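{
  // Bias-corrected step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t)
  const float alpha = learning_rate * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);

  const int size = trainable_shape.FlatSize(); // assumption: FlatSize() gives the element count
  assert(size == grad_shape.FlatSize() && size == m_shape.FlatSize() &&
         size == v_shape.FlatSize());

  for (int i = 0; i < size; ++i)
  {
    const float g = grad_data[i];
    // m <- m + (1 - beta1) * (g - m), i.e. beta1 * m + (1 - beta1) * g
    m_data[i] += (1.0f - beta1) * (g - m_data[i]);
    // v <- v + (1 - beta2) * (g^2 - v), i.e. beta2 * v + (1 - beta2) * g^2
    v_data[i] += (1.0f - beta2) * (g * g - v_data[i]);
    if (use_nesterov)
    {
      // Nesterov variant looks one step ahead along the updated first moment
      trainable_data[i] -=
        alpha * (m_data[i] * beta1 + (1.0f - beta1) * g) / (std::sqrt(v_data[i]) + epsilon);
    }
    else
    {
      trainable_data[i] -= alpha * m_data[i] / (std::sqrt(v_data[i]) + epsilon);
    }
  }
}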