Apply gradient to a trainable tensor.
33{
34 auto [grad_tensor, trainable_tensor, training_step] = factors;
35 assert(trainable_tensor.data_type() == grad_tensor.data_type());
36
37 if (trainable_tensor.getShape() != grad_tensor.getShape())
38 {
39 throw std::runtime_error("SGD: Invalid gradient tensor");
40 }
41
43 switch (grad_tensor.data_type())
44 {
45 case ir::DataType::FLOAT32:
47 ops::getShape(&trainable_tensor), ops::getBuffer<float>(&trainable_tensor),
48 ops::getShape(&grad_tensor), ops::getBuffer<float>(&grad_tensor), lr);
49 break;
50 default:
51 throw std::runtime_error("SGD: Not supported data type");
52 }
53}
double getLearningRate(uint32_t iteration=0) const override
Get the learning rate.
void GradientDescent(const Shape &output_shape, float *output_data, const Shape &grad_shape, const float *grad_data, float learning_rate)