43 {
44
45 const int num_output_dims =
output->shape.DimensionsCount();
46 auto output_dims =
output->template flat<OutputT>().dimensions();
47
48 Eigen::Index inner_dim = 1, outer_dim = 1;
49 for (int i = 0; i < num_dims - num_output_dims; ++i)
50 outer_dim *= input_dims[i];
51 for (int i = num_dims - num_output_dims; i < num_dims; ++i)
52 inner_dim *= input_dims[i];
53
54 if (1 == outer_dim)
55 {
56
57 output->template flat<OutputT>() =
58 input.template flat<InputT>().template cast<OutputT>().reshape(output_dims);
59 return;
60 }
61
62
63 const Eigen::Index num_threads = device.numThreads();
64
65
66
67
68 if (inner_dim > num_threads * 32)
69 {
70
71 const Eigen::Index num_blocks = num_threads;
72
73
74 const Eigen::Index inner_block_size = Eigen::divup(inner_dim, num_blocks);
76
77
78 Eigen::Tensor<AccumT, 1, Eigen::RowMajor, Eigen::Index> buffer({inner_dim});
79 buffer.setZero();
80 AccumT *buffer_data = buffer.data();
81
83 Eigen::TensorMap<Eigen::Tensor<AccumT, 1, Eigen::RowMajor, Eigen::Index>, Eigen::Unaligned>;
84
85 using Input = Eigen::TensorMap<Eigen::Tensor<const InputT, 1, Eigen::RowMajor, Eigen::Index>,
86 Eigen::Unaligned>;
87
88 const auto compute = [inner_dim, outer_dim, inner_block_size,
input_data,
89 buffer_data](Eigen::Index start, Eigen::Index limit) -> void {
90 Eigen::Index inner_dim_start = start * inner_block_size;
91 Eigen::Index inner_dim_limit = limit * inner_block_size;
92 inner_dim_limit = std::min(inner_dim, inner_dim_limit);
93 Eigen::Index my_job_len = inner_dim_limit - inner_dim_start;
94
95 const InputT *my_job_start =
input_data + inner_dim_start;
96 Buffer buf(buffer_data + inner_dim_start, my_job_len);
97
98 for (Eigen::Index i = 0; i < outer_dim; ++i)
99 {
100 auto in = Input(my_job_start + i * inner_dim, my_job_len);
101 auto cast = in.template cast<AccumT>();
103 Eigen::TensorCwiseBinaryOp<BinaryFunctor, const decltype(buf), const decltype(cast)>(
104 buf, cast);
105 }
106 };
107
108
109 const Eigen::Index compute_size = outer_dim * inner_block_size;
110 const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
111 const Eigen::TensorOpCost cost(compute_input_bytes,
112 0,
113 compute_size *
114 Eigen::internal::functor_traits<BinaryFunctor>::Cost);
115
116 device.parallelFor(num_blocks, cost, compute);
117
118
119 output->template flat<OutputT>() = buffer.template cast<OutputT>().reshape(output_dims);
120 }
121 else
122 {
123
124 const Eigen::Index parallel_cell_size = inner_dim;
125 const Eigen::Index total_workload = outer_dim * inner_dim;
126 const Eigen::Index max_parallelism = total_workload / parallel_cell_size;
127
128 const Eigen::Index min_block_workload = 2000;
129 const Eigen::Index min_block_size = Eigen::divup(min_block_workload, parallel_cell_size);
130 const Eigen::Index max_num_blocks =
131 std::min(max_parallelism, Eigen::divup(total_workload, min_block_size));
132
133
134 const Eigen::Index num_blocks = std::min(max_num_blocks, num_threads);
135
136
137 const Eigen::Index outer_block_size = Eigen::divup(outer_dim, num_blocks);
138
140
141
142 std::vector<AccumT> buffer(num_blocks * inner_dim);
143 AccumT *buffer_data = buffer.data();
144
146 Eigen::TensorMap<Eigen::Tensor<AccumT, 1, Eigen::RowMajor, Eigen::Index>, Eigen::Unaligned>;
147
148 using Input = Eigen::TensorMap<Eigen::Tensor<const InputT, 1, Eigen::RowMajor, Eigen::Index>,
149 Eigen::Unaligned>;
150
151 const auto compute = [inner_dim, outer_block_size, buffer_data,
input_data,
152 outer_dim](Eigen::Index start, Eigen::Index limit) -> void {
153 Eigen::Index outer_dim_start = start * outer_block_size;
154 Eigen::Index outer_dim_limit = limit * outer_block_size;
155 outer_dim_limit = std::min(outer_dim, outer_dim_limit);
156
157 Buffer buf(buffer_data + start * inner_dim, inner_dim);
158 for (Eigen::Index i = outer_dim_start; i < outer_dim_limit; ++i)
159 {
160 auto in = Input(input_data + i * inner_dim, inner_dim);
161 auto cast = in.template cast<AccumT>();
163 Eigen::TensorCwiseBinaryOp<BinaryFunctor, const decltype(buf), const decltype(cast)>(
164 buf, cast);
165 }
166 };
167
168
169 const Eigen::Index compute_size = outer_block_size * inner_dim;
170 const Eigen::Index compute_input_bytes = compute_size * sizeof(InputT);
171 const Eigen::TensorOpCost cost(compute_input_bytes,
172 0,
173 compute_size *
174 Eigen::internal::functor_traits<BinaryFunctor>::Cost);
175
176 device.parallelFor(num_blocks, cost, compute);
177
178
179 auto buf0 =
Buffer(buffer_data, inner_dim);
180
181 for (int i = 1; i < num_blocks; ++i)
182 {
183 auto buf =
Buffer(buffer_data + i * inner_dim, inner_dim);
184 buf0 = Eigen::TensorCwiseBinaryOp<BinaryFunctor, const decltype(buf0), const decltype(buf)>(
185 buf0, buf);
186 }
187
188 output->template flat<OutputT>() = buf0.template cast<OutputT>().reshape(output_dims);
189 }
190 }
// NOTE(review): the following lines are symbol-tooltip residue captured by the
// extraction tooling, not part of this translation unit — kept as a comment:
//   T * cast(Object *) — Cast a generic object as a specific one.
//   std::vector< char > Buffer