99{
101
102 assert(input_shape.
Dims(channel_dim) == bias_shape.
Dims(0));
103 assert(input_data);
104 assert(bias_data);
105
106 Tensor bias{bias_shape,
const_cast<T *
>(bias_data)};
107 Tensor input{input_shape, input_data};
108
109 functor::Bias<Device, T> functor;
110 const Eigen::ThreadPoolDevice &d = *eigen_support::GetThreadPoolDevice();
111 functor(d,
static_cast<const Tensor &
>(input).flat<T>(),
112 static_cast<const Tensor &
>(bias).flat<T>(), input.flat<T>(), activation_min,
113 activation_max);
114}
int32_t DimensionsCount() const
int32_t Dims(int i) const