135{
137 bool all_equal = true;
138 size_t largest_rank = 0;
140 for (int i = 0; i < N; ++i)
141 {
142 if (x[i] != x[0])
143 {
144 all_equal = false;
145 }
146 if (x[i].
size() > largest_rank)
147 {
148 largest_rank = x[i].size();
149 }
150 }
151 if (all_equal)
152 {
154 }
155 if (all_equal && fewer_dims_optimization)
156 {
157
158 int32_t elements = 1;
159 const int rank = x[0].size();
161 for (int i = 0; i < rank; i++)
162 {
163 const int32_t dim = x[0][i];
164 elements *= dim;
166 }
169 for (int i = 0; i < N; ++i)
170 {
173 }
174
175 return;
176 }
177
178
179
181 for (int i = 0; i < N; ++i)
182 {
183 copy[i] = x[i];
185 }
186
187
188 for (int i = 0; i < N; ++i)
189 {
190 if (copy[i].
size() < largest_rank)
191 {
192 copy[i].resize(largest_rank, 1);
193 }
194 }
195
196
197
198
199
200 bool prev_is_one[N];
201 bool current_is_one[N];
202 for (int i = 0; i < N; ++i)
203 {
204 prev_is_one[i] = false;
205 current_is_one[i] = false;
206 }
208 bool output_dim_set = false;
209 int output_dim = -1;
210 bool none_is_one = true;
211 bool set_one = false;
212 for (size_t j = 0; j < largest_rank; ++j)
213 {
214 output_dim = -1;
215 output_dim_set = false;
216 none_is_one = true;
217
218 for (int i = 0; i < N; ++i)
219 {
220
221 if (copy[i][j] == 1)
222 {
223 current_is_one[i] = true;
224 none_is_one = false;
225 }
226 else
227 {
228 current_is_one[i] = false;
229 if (!output_dim_set || copy[i][j] == output_dim)
230 {
231 output_dim = copy[i][j];
232 output_dim_set = true;
233 }
234 else
235 {
237 return;
238 }
239 }
240 }
241 output_.push_back(output_dim_set ? output_dim : 1);
243
244 if (!output_dim_set)
245 {
246 if (!fewer_dims_optimization)
247 {
248 for (int i = 0; i < N; ++i)
249 {
252 }
254 }
255 for (int i = 0; i < N; ++i)
256 {
258 }
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274 continue;
275 }
276 else if ((fewer_dims_optimization) &&
277 std::equal(current_is_one, current_is_one + N, prev_is_one) && set_one)
278 {
279
280
281
283 for (int i = 0; i < N; ++i)
284 {
286 bcast_[i].back() *= current_is_one[i] ? output_dim : 1;
287 if (current_is_one[i] && !none_is_one)
288 {
290 }
291 }
292 }
293 else
294 {
296 for (int i = 0; i < N; ++i)
297 {
299 bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
300 if (current_is_one[i] && !none_is_one)
301 {
303 }
304 }
305 }
306 set_one = true;
307 for (int i = 0; i < N; ++i)
308 {
309 prev_is_one[i] = current_is_one[i];
310 }
311 }
313 {
315 for (int i = 0; i < N; ++i)
316 {
319 }
320 }
321
322 for (int i = 0; i < N; ++i)
323 {
327 }
330
331
332
334 {
335 for (int i = 0; i < N; ++i)
336 {
338 }
339 }
340}
bool broadcasting_required_
static void Reverse(Vec *shape)
std::vector< int32_t > Vec
int32_t output_batch_size_
std::vector< int32_t > batch_indices_[N]
void ComputeBatchIndices(const int32_t output_batch_size, const std::vector< int32_t > &reshape, const std::vector< int32_t > &bcast, std::vector< int32_t > *out_indices)