18#ifndef __NNFW_CKER_TRAIN_OPTIMIZER_SGD_H__
19#define __NNFW_CKER_TRAIN_OPTIMIZER_SGD_H__
34 const float *grad_data,
float learning_rate)
41 output_tensor.buffer = output_data;
44 grad_tensor.
buffer =
const_cast<float *
>(grad_data);
46 std::vector<float> lr_vec{learning_rate};
47 lr_tensor.
buffer = lr_vec.data();
50 throw std::runtime_error(
51 "cker::GradientDescent: output and gradient do not have the same shape");
55 device, output_tensor.flat<
float>(), lr_tensor.
scalar<
float>(),
56 static_cast<const Tensor &
>(grad_tensor).flat<float>());
int32_t DimensionsCount() const
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
const luci_interpreter::RuntimeShape output_shape
const Eigen::ThreadPoolDevice * GetThreadPoolDevice()
void GradientDescent(const Shape &output_shape, float *output_data, const Shape &grad_shape, const float *grad_data, float learning_rate)
Eigen::ThreadPoolDevice CPUDevice
TTypes< T >::ConstScalar scalar() const