129{
// NOTE(review): This is a lossy extraction of the function body. The
// signature preceding this brace is not visible, and several original lines
// are missing from this view. Missing pieces include the body of the
// `if (all_equal)` block, the per-input loop body in the all-equal fast path,
// the statement before the early `return` on a dimension mismatch, and the
// conditions of the two trailing brace blocks. Comments below describe only
// what is visible.
//
// Visible behavior: for N input shapes x[0..N-1] (presumably integer
// dimension vectors -- confirm against the signature), produce one output
// size per dimension (pushed onto output_). It also produces per-input
// factors, which are pushed onto or folded into bcast_[i]. At every
// dimension, each input's size must be either 1 or one common non-1 value;
// otherwise the function returns early.

// Pass 1: detect whether every shape equals x[0], and find the largest rank.
131 bool all_equal = true;
132 size_t largest_rank = 0;
134 for (int i = 0; i < N; ++i)
135 {
136 if (x[i] != x[0])
137 {
138 all_equal = false;
139 }
140 if (x[i].
size() > largest_rank)
141 {
142 largest_rank = x[i].size();
143 }
144 }
145 if (all_equal)
146 {
// NOTE(review): the body of this block is missing from this view.
148 }
// Fast path: all shapes are identical and dimension merging is enabled.
149 if (all_equal && fewer_dims_optimization)
150 {
151
// Product of all dimensions of x[0].
// NOTE(review): the product is accumulated in int32_t. A product exceeding
// INT32_MAX would overflow (signed overflow is undefined behavior) -- confirm
// that input sizes are bounded.
152 int32_t elements = 1;
153 const int rank = x[0].size();
155 for (int i = 0; i < rank; i++)
156 {
157 const int32_t dim = x[0][i];
158 elements *= dim;
160 }
// NOTE(review): the body of this per-input loop is missing from this view.
163 for (int i = 0; i < N; ++i)
164 {
167 }
168
169 return;
170 }
171
172
173
// General path: copy each input shape. The padding below modifies these
// copies, not x.
// NOTE(review): `copy` is declared outside this view (presumably an array of
// N shape vectors).
175 for (int i = 0; i < N; ++i)
176 {
177 copy[i] = x[i];
179 }
180
181
// Pad every shorter copy to largest_rank by appending trailing 1s.
// resize() appends at the end. That is equivalent to prepending leading 1s
// only if the shapes were stored in reversed order beforehand (a
// Reverse(Vec*) helper is declared nearby).
// NOTE(review): confirm that the reversal happens in lines not shown here.
182 for (int i = 0; i < N; ++i)
183 {
184 if (copy[i].
size() < largest_rank)
185 {
186 copy[i].resize(largest_rank, 1);
187 }
188 }
189
190
191
192
193
// Per-input flags. current_is_one[i]: input i has size 1 at dimension j.
// prev_is_one[i]: the same flag for the last dimension that reached the
// bottom of the loop below.
// NOTE(review): these are plain arrays sized by N, so N is presumably a
// compile-time constant (it also sizes the batch_indices_[N] member);
// otherwise this would be a non-standard VLA.
194 bool prev_is_one[N];
195 bool current_is_one[N];
196 for (int i = 0; i < N; ++i)
197 {
198 prev_is_one[i] = false;
199 current_is_one[i] = false;
200 }
// Per-dimension state, reset at the top of each iteration:
//   output_dim / output_dim_set: the common non-1 size seen so far.
//   none_is_one: stays true only if no input has size 1 at this dimension.
// set_one becomes true after the first push into bcast_ in this loop. The
// merge branch (which uses bcast_[i].back()) therefore only runs once that
// earlier push has happened.
202 bool output_dim_set = false;
203 int output_dim = -1;
204 bool none_is_one = true;
205 bool set_one = false;
// Pass 2: walk the dimensions of the padded copies.
206 for (size_t j = 0; j < largest_rank; ++j)
207 {
208 output_dim = -1;
209 output_dim_set = false;
210 none_is_one = true;
211
// Classify each input's size at dimension j. All sizes other than 1 must be
// equal.
212 for (int i = 0; i < N; ++i)
213 {
214
215 if (copy[i][j] == 1)
216 {
217 current_is_one[i] = true;
218 none_is_one = false;
219 }
220 else
221 {
222 current_is_one[i] = false;
223 if (!output_dim_set || copy[i][j] == output_dim)
224 {
225 output_dim = copy[i][j];
226 output_dim_set = true;
227 }
228 else
229 {
// Two different non-1 sizes at the same dimension: give up.
// NOTE(review): the statement preceding this early return is missing from
// this view (presumably it records that the shapes are incompatible --
// confirm).
231 return;
232 }
233 }
234 }
// Output size for dimension j: the common non-1 size, or 1 if every input
// is 1 here.
235 output_.push_back(output_dim_set ? output_dim : 1);
237
// Every input has size 1 at this dimension.
238 if (!output_dim_set)
239 {
240 if (!fewer_dims_optimization)
241 {
// NOTE(review): this loop body is missing from this view.
242 for (int i = 0; i < N; ++i)
243 {
246 }
248 }
// NOTE(review): this loop body is missing from this view.
249 for (int i = 0; i < N; ++i)
250 {
252 }
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
// `continue` skips the set_one / prev_is_one update at the bottom of the
// loop. prev_is_one therefore still describes the last dimension that was
// not all ones.
268 continue;
269 }
// Merge branch: the set of size-1 inputs matches the previous recorded
// dimension. Fold this dimension into the last bcast_ entry instead of
// adding a new one. Inputs of size 1 here scale their factor by output_dim;
// the others by 1.
270 else if ((fewer_dims_optimization) &&
271 std::equal(current_is_one, current_is_one + N, prev_is_one) && set_one)
272 {
273
274
275
277 for (int i = 0; i < N; ++i)
278 {
280 bcast_[i].back() *= current_is_one[i] ? output_dim : 1;
// NOTE(review): whenever current_is_one[i] is true, none_is_one has already
// been cleared for this dimension, so `!none_is_one` is redundant here. The
// body of this if is missing from this view.
281 if (current_is_one[i] && !none_is_one)
282 {
284 }
285 }
286 }
// Otherwise start a new entry: push one factor per input (output_dim for
// inputs of size 1, otherwise 1).
287 else
288 {
290 for (int i = 0; i < N; ++i)
291 {
293 bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
// NOTE(review): as above, `!none_is_one` is implied by current_is_one[i].
// The body of this if is missing from this view.
294 if (current_is_one[i] && !none_is_one)
295 {
297 }
298 }
299 }
// Remember this dimension's size-1 pattern for the next merge check.
300 set_one = true;
301 for (int i = 0; i < N; ++i)
302 {
303 prev_is_one[i] = current_is_one[i];
304 }
305 }
// NOTE(review): the condition guarding this block, and the body of its loop,
// are missing from this view.
307 {
309 for (int i = 0; i < N; ++i)
310 {
313 }
314 }
315
// NOTE(review): this loop body is missing from this view.
316 for (int i = 0; i < N; ++i)
317 {
321 }
324
325
326
// NOTE(review): the condition guarding this block, and the body of its loop,
// are missing from this view.
328 {
329 for (int i = 0; i < N; ++i)
330 {
332 }
333 }
334}
// NOTE(review): The lines below are declaration fragments as rendered by a
// documentation extractor. No terminating semicolons or bodies are shown.

// Presumably records whether any input requires broadcasting. It is not
// assigned by any line visible in this view -- confirm.
bool broadcasting_required_
// Presumably reverses the order of the dimensions in *shape (signature only).
static void Reverse(Vec *shape)
// Appears to be the shape-vector alias (std::vector<int32_t>), likely a
// typedef/using declaration as rendered by the extractor -- confirm.
std::vector< int32_t > Vec
// Presumably the batch size of the broadcast output (declaration only).
int32_t output_batch_size_
// One index vector per input. For this member declaration to be valid C++,
// the array bound N must be a compile-time constant.
std::vector< int32_t > batch_indices_[N]
// Presumably fills *out_indices, for each of output_batch_size output batch
// elements, with the index of the corresponding element in an input whose
// shape is described by `reshape` and `bcast`.
// NOTE(review): inferred from parameter names only; confirm against the
// definition.
void ComputeBatchIndices(const int32_t output_batch_size, const std::vector< int32_t > &reshape, const std::vector< int32_t > &bcast, std::vector< int32_t > *out_indices)