18#ifndef __NNFW_CKER_EIGEN_TRAINING_OPS_H__
19#define __NNFW_CKER_EIGEN_TRAINING_OPS_H__
22#define EIGEN_USE_THREADS
24#include "unsupported/Eigen/CXX11/Tensor"
37template <
typename Device,
typename T>
struct ApplyAdam
86 Index length = var.size();
87 Index packet_size = Eigen::internal::packet_traits<T>::size;
88 if (length % packet_size == 0)
90 length = length / packet_size;
97 T *var_ptr = var.data();
100 const T *g_ptr = grad.data();
101 const T alpha = lr() * Eigen::numext::sqrt(T(1) - beta2_power()) / (T(1) - beta1_power());
107 auto shard = [var_ptr, m_ptr, v_ptr, g_ptr, alpha, beta1, beta2, epsilon, use_nesterov,
109 int t_size = (
end -
begin) * packet_size;
118 m += (g -
m) * (T(1) - beta1());
119 v += (g.square() - v) * (T(1) - beta2());
120 var -= ((g * (T(1) - beta1()) + beta1() *
m) * alpha) / (v.sqrt() + epsilon());
124 m += (g -
m) * (T(1) - beta1());
125 v += (g.square() - v) * (T(1) - beta2());
126 var -= (
m * alpha) / (v.sqrt() + epsilon());
132 const int input_bytes = length * packet_size *
sizeof(T) * 4;
133 const int output_bytes = length * packet_size *
sizeof(T) * 3;
134 const int compute_cycles =
136 (Eigen::TensorOpCost::AddCost<int>() * 5 + Eigen::TensorOpCost::MulCost<int>() * 2 +
137 Eigen::TensorOpCost::AddCost<T>() * 10 + Eigen::TensorOpCost::MulCost<T>() * 6 +
138 Eigen::TensorOpCost::DivCost<T>()) *
140 const Eigen::TensorOpCost cost(input_bytes, output_bytes, compute_cycles);
145 d.parallelFor(length, cost, shard);
158 var.device(d) -= grad * lr();
Eigen::ThreadPoolDevice CPUDevice
ShapeIterator end(const Shape &s)
Eigen::TensorMap< Eigen::TensorFixedSize< const T, Eigen::Sizes<>, Eigen::RowMajor, IndexType >, Eigen::Aligned > ConstScalar
Eigen::TensorMap< Eigen::Tensor< const T, NDIMS, Eigen::RowMajor, IndexType > > UnalignedConstTensor
Eigen::TensorMap< Eigen::Tensor< T, NDIMS, Eigen::RowMajor, IndexType > > UnalignedTensor
Eigen::TensorMap< Eigen::Tensor< const T, 1, Eigen::RowMajor, IndexType >, Eigen::Aligned > ConstFlat
Eigen::TensorMap< Eigen::Tensor< T, 1, Eigen::RowMajor, IndexType >, Eigen::Aligned > Flat
void operator()(const Device &d, typename TTypes< T >::Flat var, typename TTypes< T >::Flat m, typename TTypes< T >::Flat v, typename TTypes< T >::ConstScalar beta1_power, typename TTypes< T >::ConstScalar beta2_power, typename TTypes< T >::ConstScalar lr, typename TTypes< T >::ConstScalar beta1, typename TTypes< T >::ConstScalar beta2, typename TTypes< T >::ConstScalar epsilon, typename TTypes< T >::ConstFlat grad, bool use_nesterov)
void operator()(const Device &d, typename TTypes< T >::Flat var, typename TTypes< T >::Flat m, typename TTypes< T >::Flat v, typename TTypes< T >::ConstScalar beta1_power, typename TTypes< T >::ConstScalar beta2_power, typename TTypes< T >::ConstScalar lr, typename TTypes< T >::ConstScalar beta1, typename TTypes< T >::ConstScalar beta2, typename TTypes< T >::ConstScalar epsilon, typename TTypes< T >::ConstFlat grad, bool use_nesterov)
void operator()(const CPUDevice &d, typename TTypes< T >::Flat var, typename TTypes< T >::ConstScalar lr, typename TTypes< T >::ConstFlat grad)
void operator()(const Device &d, typename TTypes< T >::Flat var, typename TTypes< T >::ConstScalar alpha, typename TTypes< T >::ConstFlat delta)