void operator()(const Device &device, const Eigen::DSizes<Eigen::Index, num_dims> &input_dims,
                const Tensor &input, Tensor *output) const
{
  // Compute the inner and outer dims after reshaping into a 2d tensor.
  const int num_output_dims = output->shape.DimensionsCount();
  auto output_dims = output->template flat<OutputT>().dimensions();

  Eigen::Index inner_dim = 1, outer_dim = 1;
  for (int i = 0; i < num_dims - num_output_dims; ++i)
    outer_dim *= input_dims[i];
  for (int i = num_dims - num_output_dims; i < num_dims; ++i)
    inner_dim *= input_dims[i];
  if (1 == outer_dim)
  {
    // Nothing to reduce: just cast and reshape the input into the output.
    output->template flat<OutputT>() =
      input.template flat<InputT>().template cast<OutputT>().reshape(output_dims);
    return;
  }
  // Get the device thread count.
  const Eigen::Index num_threads = device.numThreads();
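  // If the inner dimension is much larger than the thread count, parallelize
  // across the inner dimension: each block reduces every outer row into its
  // own disjoint slice of a shared accumulation buffer, so no merge step is
  // needed afterwards.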
  if (inner_dim > num_threads * 32)
  {
    // Do not create more blocks than there are threads in the pool.
    const Eigen::Index num_blocks = num_threads;

    // Block size along the inner dimension.
    const Eigen::Index inner_block_size = Eigen::divup(inner_dim, num_blocks);
    const InputT *input_data = input.template flat<InputT>().data();

    // Allocate a zero-initialized temporary buffer for the partial reductions.
    Eigen::Tensor<AccumT, 1, Eigen::RowMajor, Eigen::Index> buffer({inner_dim});
    buffer.setZero();
    AccumT *buffer_data = buffer.data();

    using Buffer =
      Eigen::TensorMap<Eigen::Tensor<AccumT, 1, Eigen::RowMajor, Eigen::Index>, Eigen::Unaligned>;
    using Input = Eigen::TensorMap<Eigen::Tensor<const InputT, 1, Eigen::RowMajor, Eigen::Index>,
                                   Eigen::Unaligned>;
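    // Each block reduces all outer rows into its [inner_dim_start, inner_dim_limit)
    // slice of the shared buffer.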
    const auto compute = [inner_dim, outer_dim, inner_block_size, input_data,
                          buffer_data](Eigen::Index start, Eigen::Index limit) -> void {
      Eigen::Index inner_dim_start = start * inner_block_size;
      Eigen::Index inner_dim_limit = limit * inner_block_size;
      inner_dim_limit = std::min(inner_dim, inner_dim_limit);
      Eigen::Index my_job_len = inner_dim_limit - inner_dim_start;

      const InputT *my_job_start = input_data + inner_dim_start;
      Buffer buf(buffer_data + inner_dim_start, my_job_len);

      for (Eigen::Index i = 0; i < outer_dim; ++i)
      {
        auto in = Input(my_job_start + i * inner_dim, my_job_len);
        auto cast = in.template cast<AccumT>();
        buf = Eigen::TensorCwiseBinaryOp<BinaryFunctor, const decltype(buf), const decltype(cast)>(
          buf, cast);
      }
    };
    // Compute the cost of reducing a single block.
    const Eigen::Index compute_size = outer_dim * inner_block_size;
    const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
    const Eigen::TensorOpCost cost(
      compute_input_bytes,
      0, // Accumulation mostly stays in cache; assume the store cost is 0.
      compute_size * Eigen::internal::functor_traits<BinaryFunctor>::Cost);

    device.parallelFor(num_blocks, cost, compute);

    // Write the final result to the output.
    output->template flat<OutputT>() = buffer.template cast<OutputT>().reshape(output_dims);
  }
  else
  {
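    // Otherwise the inner dimension is small, so parallelize across the outer
    // dimension instead: each block accumulates its share of outer rows into a
    // private inner_dim-sized row of a scratch buffer, and the rows are merged
    // after the parallel phase.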
    const Eigen::Index parallel_cell_size = inner_dim;
    const Eigen::Index total_workload = outer_dim * inner_dim;
    const Eigen::Index max_parallelism = total_workload / parallel_cell_size;

    // Choose a block count that keeps each block's workload above a minimum.
    const Eigen::Index min_block_workload = 2000;
    const Eigen::Index min_block_size = Eigen::divup(min_block_workload, parallel_cell_size);
    const Eigen::Index max_num_blocks =
      std::min(max_parallelism, Eigen::divup(total_workload, min_block_size));

    // Do not create more blocks than there are threads in the pool.
    const Eigen::Index num_blocks = std::min(max_num_blocks, num_threads);

    // Block size along the outer dimension.
    const Eigen::Index outer_block_size = Eigen::divup(outer_dim, num_blocks);
    const InputT *input_data = input.template flat<InputT>().data();

    // Allocate a temporary buffer with one inner_dim-sized partial-result row
    // per block; std::vector value-initializes the elements to zero.
    std::vector<AccumT> buffer(num_blocks * inner_dim);
    AccumT *buffer_data = buffer.data();

    using Buffer =
      Eigen::TensorMap<Eigen::Tensor<AccumT, 1, Eigen::RowMajor, Eigen::Index>, Eigen::Unaligned>;
    using Input = Eigen::TensorMap<Eigen::Tensor<const InputT, 1, Eigen::RowMajor, Eigen::Index>,
                                   Eigen::Unaligned>;
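    // Each block reduces outer rows [outer_dim_start, outer_dim_limit) into its
    // own row of the scratch buffer.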
    const auto compute = [inner_dim, outer_block_size, buffer_data, input_data,
                          outer_dim](Eigen::Index start, Eigen::Index limit) -> void {
      Eigen::Index outer_dim_start = start * outer_block_size;
      Eigen::Index outer_dim_limit = limit * outer_block_size;
      outer_dim_limit = std::min(outer_dim, outer_dim_limit);

      Buffer buf(buffer_data + start * inner_dim, inner_dim);
      for (Eigen::Index i = outer_dim_start; i < outer_dim_limit; ++i)
      {
        auto in = Input(input_data + i * inner_dim, inner_dim);
        auto cast = in.template cast<AccumT>();
        buf = Eigen::TensorCwiseBinaryOp<BinaryFunctor, const decltype(buf), const decltype(cast)>(
          buf, cast);
      }
    };
    // Compute the cost of reducing a single block.
    const Eigen::Index compute_size = outer_block_size * inner_dim;
    const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
    const Eigen::TensorOpCost cost(
      compute_input_bytes,
      0, // Accumulation mostly stays in cache; assume the store cost is 0.
      compute_size * Eigen::internal::functor_traits<BinaryFunctor>::Cost);

    device.parallelFor(num_blocks, cost, compute);
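    // Merge the per-block partial rows sequentially; inner_dim is small in this
    // branch, so the serial pass over num_blocks rows is cheap.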
    auto buf0 = Buffer(buffer_data, inner_dim);
    for (int i = 1; i < num_blocks; ++i)
    {
      auto buf = Buffer(buffer_data + i * inner_dim, inner_dim);
      buf0 = Eigen::TensorCwiseBinaryOp<BinaryFunctor, const decltype(buf0), const decltype(buf)>(
        buf0, buf);
    }

    // Write the final result to the output.
    output->template flat<OutputT>() = buf0.template cast<OutputT>().reshape(output_dims);