// Applies one Adam update step to `trainable_data` in place, using the gradient in
// `grad_data` and the first/second moment accumulators `m_data` and `v_data`.
// `beta1_power` and `beta2_power` are beta1^t and beta2^t for the current step t,
// used for bias correction inside the functor.
inline void Adam(const Shape &trainable_shape, float *trainable_data, const Shape &grad_shape,
                 const float *grad_data, const Shape &m_shape, float *m_data, const Shape &v_shape,
                 float *v_data, float beta1_power, float beta2_power, float learning_rate,
                 float beta1, float beta2, float epsilon, bool use_nesterov)
{
  Tensor trainable_tensor;
  Tensor grad_tensor;
  Tensor m_tensor;
  Tensor v_tensor;
  Tensor beta1_power_tensor;
  Tensor beta2_power_tensor;
  Tensor lr_tensor;
  Tensor beta1_tensor;
  Tensor beta2_tensor;
  Tensor epsilon_tensor;

  // Wrap the caller's raw buffers in cker Tensors. NOTE: the shape setup is
  // reconstructed from context (only the buffer assignments survive in the
  // original listing) and assumes cker's Shape API (ReplaceWith/DimensionsCount/DimsData).
  trainable_tensor.shape.ReplaceWith(trainable_shape.DimensionsCount(),
                                     trainable_shape.DimsData());
  trainable_tensor.buffer = trainable_data;

  grad_tensor.shape.ReplaceWith(grad_shape.DimensionsCount(), grad_shape.DimsData());
  grad_tensor.buffer = const_cast<float *>(grad_data);

  m_tensor.shape.ReplaceWith(m_shape.DimensionsCount(), m_shape.DimsData());
  m_tensor.buffer = m_data;

  v_tensor.shape.ReplaceWith(v_shape.DimensionsCount(), v_shape.DimsData());
  v_tensor.buffer = v_data;

  // Scalar hyperparameters are wrapped in single-element vectors so they can be
  // handed to the Eigen functor as scalar tensors.
  std::vector<float> beta1_power_vec{beta1_power};
  beta1_power_tensor.buffer = beta1_power_vec.data();

  std::vector<float> beta2_power_vec{beta2_power};
  beta2_power_tensor.buffer = beta2_power_vec.data();

  std::vector<float> lr_vec{learning_rate};
  lr_tensor.buffer = lr_vec.data();

  std::vector<float> beta1_vec{beta1};
  beta1_tensor.buffer = beta1_vec.data();

  std::vector<float> beta2_vec{beta2};
  beta2_tensor.buffer = beta2_vec.data();

  std::vector<float> epsilon_vec{epsilon};
  epsilon_tensor.buffer = epsilon_vec.data();

  if (trainable_shape != m_shape)
    throw std::runtime_error("cker::Adam: output and m do not have the same shape");

  if (trainable_shape != v_shape)
    throw std::runtime_error("cker::Adam: output and v do not have the same shape");

  if (trainable_shape != grad_shape)
    throw std::runtime_error("cker::Adam: output and gradient do not have the same shape");

  // NOTE: the device setup and the functor name below are reconstructed from context;
  // the original listing only shows the argument list, so the exact helper names are
  // assumptions based on cker's other training ops.
  const training_ops::CPUDevice &device = *eigen_support::GetThreadPoolDevice();
  training_ops::functor::ApplyAdam<training_ops::CPUDevice, float>()(
    device, trainable_tensor.flat<float>(), m_tensor.flat<float>(), v_tensor.flat<float>(),
    beta1_power_tensor.scalar<float>(), beta2_power_tensor.scalar<float>(),
    lr_tensor.scalar<float>(), beta1_tensor.scalar<float>(), beta2_tensor.scalar<float>(),
    epsilon_tensor.scalar<float>(), static_cast<const Tensor &>(grad_tensor).flat<float>(),
    use_nesterov);
}
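For reference, a minimal calling sketch under stated assumptions: a single flat parameter
tensor of four elements, illustrative hyperparameter values, and a Shape constructor taking
(rank, dims pointer) as in TFLite's RuntimeShape. The names AdamUsageSketch, dims, params,
grads, m, v, and the step counter are all illustrative and do not come from the listing.

// Minimal usage sketch (assumed Shape(rank, dims) constructor; values are illustrative).
#include <cmath>
#include <vector>

void AdamUsageSketch()
{
  // One flat parameter tensor with four elements; every buffer shares this shape.
  const int32_t dims[] = {4};
  const Shape shape(1, dims);

  std::vector<float> params(4, 0.0f);                 // trainable weights, updated in place
  std::vector<float> grads{0.1f, -0.2f, 0.3f, -0.4f}; // gradient for the current step
  std::vector<float> m(4, 0.0f);                      // first-moment accumulator
  std::vector<float> v(4, 0.0f);                      // second-moment accumulator

  const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-7f, lr = 0.001f;
  const int step = 1;                                 // 1-based training step
  const float beta1_power = std::pow(beta1, step);    // beta1^t for bias correction
  const float beta2_power = std::pow(beta2, step);    // beta2^t for bias correction

  Adam(shape, params.data(), shape, grads.data(), shape, m.data(), shape, v.data(),
       beta1_power, beta2_power, lr, beta1, beta2, epsilon, /*use_nesterov=*/false);
}

Note that the caller passes plain floats for the hyperparameters; the single-element
std::vector wrappers inside Adam exist only so those scalars can reach the Eigen functor
as scalar tensors.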