Apply the gradient to a trainable tensor.
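For reference, this step realizes the standard Adam update (Kingma & Ba, 2015). With gradient $g_t$, moment estimates $m_t$ and $v_t$, and step $t = \text{training\_step} + 1$:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$
$$\theta_t = \theta_{t-1} - \alpha \, \frac{m_t / (1 - \beta_1^t)}{\sqrt{v_t / (1 - \beta_2^t)} + \epsilon}$$

Implementations differ in exactly where the bias corrections and $\epsilon$ are applied; the code below computes $\beta_1^t$ and $\beta_2^t$ up front and hands them to the kernel, which applies the corresponding corrections.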
void Adam::applyGradient(const UpdateFactors &factors) const
{
  auto [grad_tensor, trainable_tensor, training_step] = factors;
  assert(trainable_tensor.data_type() == grad_tensor.data_type());

  const auto opt_vars = trainable_tensor.optVars();
  assert(opt_vars.size() == 2);

  // First optimizer variable: exponential moving average of the gradients (m)
  auto m_tensor = nnfw::misc::polymorphic_downcast<IPortableTensor *>(opt_vars.at(0));

  // Second optimizer variable: exponential moving average of the squared gradients (v)
  auto v_tensor = nnfw::misc::polymorphic_downcast<IPortableTensor *>(opt_vars.at(1));
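
  // m and v start at zero, so early moment estimates are biased toward zero;
  // raising beta1/beta2 to the (training_step + 1)-th power gives the kernel
  // the beta^t terms it needs for the 1 - beta^t bias-correction denominators.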
  const auto beta1_power = std::pow(_props.beta1, training_step + 1);
  const auto beta2_power = std::pow(_props.beta2, training_step + 1);

  // Nesterov momentum is not supported here
  const bool use_nesterov = false;

  if (trainable_tensor.getShape() != grad_tensor.getShape())
  {
    throw std::runtime_error("Adam: Invalid gradient tensor");
  }

  switch (grad_tensor.data_type())
  {
    case ir::DataType::FLOAT32:
      ops::Adam(ops::getShape(&trainable_tensor), ops::getBuffer<float>(&trainable_tensor),
                ops::getShape(&grad_tensor), ops::getBuffer<float>(&grad_tensor),
                ops::getShape(m_tensor), ops::getBuffer<float>(m_tensor), ops::getShape(v_tensor),
                ops::getBuffer<float>(v_tensor), beta1_power, beta2_power, _learning_rate,
                _props.beta1, _props.beta2, _props.epsilon, use_nesterov);
      break;
    default:
      throw std::runtime_error("Adam: Not supported data type");
  }
}

// Per-element Adam kernel called above for FLOAT32 tensors.
void Adam(const Shape &trainable_shape, float *trainable_data, const Shape &grad_shape,
          const float *grad_data, const Shape &m_shape, float *m_data, const Shape &v_shape,
          float *v_data, float beta1_power, float beta2_power, float learning_rate, float beta1,
          float beta2, float epsilon, bool use_nesterov);
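
A minimal scalar sketch of what such a kernel computes is shown below. It is not the production implementation (real kernels are typically vectorized and validate the shapes); the fused bias correction folded into alpha follows the common TensorFlow-style formulation, and the flat element count passed as size is assumed to have been derived from the (already matching) shapes.

#include <cmath>

// Sketch only: element-wise Adam update over flat float buffers.
void AdamSketch(int size, float *var, const float *grad, float *m, float *v, float beta1_power,
                float beta2_power, float lr, float beta1, float beta2, float epsilon,
                bool use_nesterov)
{
  // Fold both bias corrections into the step size: lr * sqrt(1 - beta2^t) / (1 - beta1^t)
  const float alpha = lr * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  for (int i = 0; i < size; ++i)
  {
    m[i] += (grad[i] - m[i]) * (1.0f - beta1);           // m <- beta1*m + (1-beta1)*g
    v[i] += (grad[i] * grad[i] - v[i]) * (1.0f - beta2); // v <- beta2*v + (1-beta2)*g^2
    if (use_nesterov)
      var[i] -= alpha * (m[i] * beta1 + (1.0f - beta1) * grad[i]) / (std::sqrt(v[i]) + epsilon);
    else
      var[i] -= alpha * m[i] / (std::sqrt(v[i]) + epsilon);
  }
}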