205 {
206 if (!_prepared)
207 {
209 }
210
212 std::vector<InputTensor<float>>
inputs(num_inputs);
214 {
215 inputs[i].shape.ReplaceWith(input_shapes[i].DimensionsCount(), input_shapes[i].DimsData());
217 }
218
220 Labels output_labels(_output_labels);
221 std::vector<DimensionType> label_types(_label_types);
223 LabelCounts output_label_counts(_output_label_counts);
225
226 processDimensions(inputs, &input_labels, &output_labels, &label_types, &input_label_counts,
227 &output_label_counts, &label_to_dim_sizes);
228
229
230
231
232
233
234
236 std::vector<Tensor> inputs_reduced(num_inputs);
237 std::vector<bool> swap_free_and_contract(num_inputs);
239 {
240 bool temp_swap_free_and_contract = false;
241 reduceOperand<float>(inputs[i], label_types, input_label_counts[i], &input_labels[i],
242 &free_labels[i], &temp_swap_free_and_contract, &inputs_reduced[i]);
243 swap_free_and_contract[i] = temp_swap_free_and_contract;
244 }
245
246
247
248
249 Tensor contraction_output_reshaped;
250 contractOperands(inputs_reduced, swap_free_and_contract, &contraction_output_reshaped);
251
252
253
254 std::vector<int32_t> result_shape_dims(contraction_output_reshaped.shape.DimensionsCount() - 2);
255
256 for (size_t i = 0; i < result_shape_dims.size(); i++)
257 {
258 result_shape_dims[i] = contraction_output_reshaped.shape.Dims(i);
259 }
260
261 int num_labels = label_types.size();
263
264
265 for (int label = 0; label < num_labels; ++label)
266 {
268 result_labels.push_back(label);
269 }
270 for (int label = 0; label < num_labels; ++label)
271 {
272 if (label_types[label] ==
kBatch)
273 result_labels.push_back(label);
274 }
276 {
277 for (auto &&label : free_labels[i])
278 {
279 result_labels.push_back(label);
280 result_shape_dims.push_back(label_to_dim_sizes[label]);
281 }
282 }
283
284 Shape result_shape(result_shape_dims.size(), result_shape_dims.data());
285
286
287
288 Tensor contraction_output;
289 copyFrom(contraction_output_reshaped, result_shape, &contraction_output);
290
291
292
293
294
296 strideOrInflate<float>(contraction_output, result_labels, output_label_counts,
297 true , &output_inflated);
298
299 if (output_inflated.shape.DimensionsCount() > contraction_output.shape.DimensionsCount())
300 {
301
303 for (auto &&label : result_labels)
304 {
305 inflated_labels.insert(inflated_labels.end(), output_label_counts[label], label);
306 }
307 result_labels.swap(inflated_labels);
308 }
309
310
311
312
313
314
315
316 std::vector<int32_t> output_permutation(output_labels.size());
317 std::vector<int32_t> label_to_position(num_labels, -1);
318 for (size_t i = 0; i < result_labels.size(); ++i)
319 {
320
321 if (label_to_position[result_labels[i]] == -1)
322 {
323 label_to_position[result_labels[i]] = i;
324 }
325 }
326 for (size_t i = 0; i < output_labels.size(); ++i)
327 {
328 output_permutation[i] = label_to_position[output_labels[i]];
329
330 label_to_position[output_labels[i]] += 1;
331 }
332
333 InputTensor<float> temp_inflated;
334 temp_inflated.shape.ReplaceWith(output_inflated.shape.DimensionsCount(),
335 output_inflated.shape.DimsData());
336 temp_inflated.buffer = (reinterpret_cast<const float *>(output_inflated.buffer));
337 ;
338
340 transposeOperand<float>(temp_inflated, output_permutation, &output);
341
343
344 temp_operand.clear();
345 }
void prepare(std::string &equation)
const luci_interpreter::RuntimeShape output_shape
std::vector< Labels > OperandLabels
std::vector< int32_t > LabelCounts
std::vector< LabelCounts > OperandLabelCounts
std::vector< int32_t > Labels
std::vector< int32_t > LabelToDimSizes