18#ifndef __NNFW_CKER_EIGEN_BIAS_OP_H__
19#define __NNFW_CKER_EIGEN_BIAS_OP_H__
22#define EIGEN_USE_THREADS
24#include "unsupported/Eigen/CXX11/Tensor"
// Call operator of a small dispatch functor (fragment: the enclosing struct and
// the braces around the body are not present in this listing).
// It invokes `func` with `args`, forwarding each argument with std::forward so
// its value category (lvalue/rvalue) is preserved. The operator is const.
// NOTE(review): the leading numbers ("43", "45") look like line numbers from a
// rendered listing that were fused into the text; they are not part of the code.
43 template <
typename Func,
typename... Args>
void operator()(Func func, Args &&...args)
const
45 func(std::forward<Args>(args)...);
// Template header only: a device type, a callable type and a variadic argument
// pack. The declaration it introduces is not present at this point.
// NOTE(review): presumably this is the header of
// MaybeWith32BitIndexing(Func func, Args &&...args), which is invoked with an
// explicit <Device> argument further down; confirm against the original file.
51template <
typename Device,
typename Func,
typename... Args>
// Bias<Device, T>: functor fragment (struct braces, the operator() head and the
// statement that stores the result are missing from this listing).
// The visible expression computes (input + bias broadcast over the input) and
// clamps it to [activation_min, activation_max].
59template <
typename Device,
typename T>
struct Bias
64 T activation_min, T activation_max)
// Number of times the bias vector has to be repeated to cover the flattened
// input: total element count divided by the bias length (integer division).
// NOTE(review): assumes input.size() is an exact multiple of bias.dimension(0)
// and that the bias is non-empty; nothing visible here checks either.
66 const Eigen::Index rest_size = input.size() / bias.dimension(0);
// One-dimensional broadcast factor used to tile the bias `rest_size` times.
67 Eigen::DSizes<Eigen::Index, 1> bcast(rest_size);
// Hand the tensors and a generic lambda to MaybeWith32BitIndexing<Device>.
// NOTE(review): the helper's body is not visible; by its name it may re-index
// the arguments with 32-bit indices before calling the lambda. Confirm.
68 MaybeWith32BitIndexing<Device>(
// The lambda captures by reference ([&]); the activation bounds are also passed
// in explicitly as parameters, shadowing the outer names.
69 [&](
auto input32,
auto bias32,
typename TTypes<T>::Flat output32,
const auto &bcast32,
70 T activation_min, T activation_max) {
// Element-wise: add the broadcast bias, then apply the lower bound (cwiseMax)
// followed by the upper bound (cwiseMin). Both use the Eigen::PropagateNaN
// policy, so a NaN operand is expected to yield NaN rather than be clamped.
// NOTE(review): the left-hand side of this expression (where the result is
// written) is missing from the listing.
72 (input32 + bias32.broadcast(bcast32))
73 .template cwiseMax<Eigen::PropagateNaN>(
static_cast<T
>(activation_min))
74 .template cwiseMin<Eigen::PropagateNaN>(
static_cast<T
>(activation_max));
// Arguments forwarded to the lambda, in the lambda's parameter order.
76 input, bias, output, bcast, activation_min, activation_max);
// Alias: `Device` names Eigen::ThreadPoolDevice for the code that follows.
// NOTE(review): the leading "94" looks like a fused listing line number.
94using Device = Eigen::ThreadPoolDevice;
// biasHelper fragment: only the tail of the parameter list and part of the body
// are present. The definitions of `channel_dim`, `d` and `functor` and the end
// of the final call are missing from this listing.
98 T *input_data, T activation_min, T activation_max)
// Precondition: the input's extent along `channel_dim` must equal the bias
// length (dimension 0 of bias_shape). This is checked only when assert is active.
102 assert(input_shape.
Dims(channel_dim) == bias_shape.
Dims(0));
// Wrap the raw buffers in Tensor objects. const is cast away from bias_data to
// build the Tensor.
// NOTE(review): presumably Tensor only accepts a mutable pointer; the bias is
// accessed through a const Tensor & below, so it does not appear to be written.
106 Tensor bias{bias_shape,
const_cast<T *
>(bias_data)};
107 Tensor input{input_shape, input_data};
// Invoke the functor with a const flat view of the input, a const flat view of
// the bias, and a mutable flat view of the same input tensor as the next
// argument.
// NOTE(review): that mutable view appears to be the output slot, i.e. the
// operation looks in-place on input_data. Confirm against the functor signature.
111 functor(d,
static_cast<const Tensor &
>(input).flat<T>(),
112 static_cast<const Tensor &
>(bias).flat<T>(), input.flat<T>(), activation_min,
int32_t DimensionsCount() const
int32_t Dims(int i) const
void MaybeWith32BitIndexing(Func func, Args &&...args)
void biasHelper(const Shape &bias_shape, const T *bias_data, const Shape &input_shape, T *input_data, T activation_min, T activation_max)
Eigen::ThreadPoolDevice Device
const Eigen::ThreadPoolDevice * GetThreadPoolDevice()
Eigen::TensorMap< Eigen::Tensor< const T, 1, Eigen::RowMajor, IndexType >, Eigen::Aligned > ConstFlat
Eigen::TensorMap< Eigen::Tensor< T, 1, Eigen::RowMajor, IndexType >, Eigen::Aligned > Flat
Eigen::TensorMap< Eigen::Tensor< const T, 1, Eigen::RowMajor, IndexType >, Eigen::Aligned > ConstVec
void operator()(const Device &d, typename TTypes< T >::ConstFlat input, typename TTypes< T >::ConstVec bias, typename TTypes< T >::Flat output, T activation_min, T activation_max)
void operator()(Func func, Args &&...args) const