ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Einsum.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_EINSUM_H__
19#define __NNFW_CKER_EINSUM_H__
20
21#include "cker/Types.h"
22#include "cker/Shape.h"
23#include "cker/Utils.h"
24
27
28#include "Transpose.h"
29#include "BatchMatMul.h"
30
#include <algorithm>
#include <map>
#include <numeric>
#include <string>
#include <type_traits>
#include <vector>
36
37namespace nnfw
38{
39namespace cker
40{
41
42namespace functor
43{
44
45template <typename Device, typename T, int N> struct StrideFunctor
46{
47 void operator()(const Device &d, typename TTypes<T, N>::ConstTensor input,
48 const std::vector<int32_t> &strides, typename TTypes<T, N>::Tensor output)
49 {
50
51 Eigen::DSizes<Eigen::DenseIndex, N> dsizes;
52 for (size_t d = 0; d < strides.size(); d++)
53 {
54 dsizes[d] = static_cast<Eigen::DenseIndex>(strides[d]);
55 }
56 for (size_t d = strides.size(); d < N; d++)
57 {
58 dsizes[d] = 1;
59 }
60
61 output.device(d) = input.stride(dsizes);
62 }
63};
64
65template <typename Device, typename T, int N> struct InflateFunctor
66{
67 void operator()(const Device &d, typename TTypes<T, N>::ConstTensor input,
68 const std::vector<int32_t> &strides, typename TTypes<T, N>::Tensor output)
69 {
70
71 Eigen::DSizes<Eigen::DenseIndex, N> dsizes;
72 for (size_t d = 0; d < strides.size(); d++)
73 {
74 dsizes[d] = static_cast<Eigen::DenseIndex>(strides[d]);
75 }
76 for (size_t d = strides.size(); d < N; d++)
77 {
78 dsizes[d] = 1;
79 }
80
81 output.device(d) = input.inflate(dsizes);
82 }
83};
84
// Thin wrapper over Eigen's Tensor reduce: out = in.reduce(reduction_axes,
// reducer), evaluated on device 'd'. Used to sum over reduce dimensions.
template <typename Device, typename Reducer> struct ReduceFunctor
{
  template <typename OUT_T, typename IN_T, typename ReductionAxes>
  static void Reduce(const Device &d, OUT_T out, IN_T in, const ReductionAxes &reduction_axes,
                     const Reducer &reducer)
  {
    out.device(d) = in.reduce(reduction_axes, reducer);
  }
};
94
// Fills a flat tensor view with zeros on the given device.
template <typename Device, typename T> struct SetZeroFunctor
{
  // Computes on device "d": out = out.setZero(),
  void operator()(const Device &d, typename TTypes<T>::Flat out)
  {
    out.device(d) = out.constant(T(0));
  }
};
103
104} // namespace functor
105
// Dimension sizes of a single operand.
using ShapeVec = std::vector<int32_t>;
// Integer labels (one per subscript character) for a single operand.
using Labels = std::vector<int32_t>;
// Labels for every input operand.
using OperandLabels = std::vector<Labels>;
// Occurrence count of each label within one subscript.
using LabelCounts = std::vector<int32_t>;
// Per-operand label counts for every input.
using OperandLabelCounts = std::vector<LabelCounts>;
// Maps an integer label to its (consistent) dimension size.
using LabelToDimSizes = std::vector<int32_t>;
112
113// Each dimension is categorized into exactly one of five types based on
114// whether its corresponding label is present in the input and/or the output
115// subscripts.
117{
118 // Batch dimensions are those present in two inputs as well as the output.
119 // They are part of the batch dimensions during Tensor contraction.
120 // Such dimensions may be broadcasting dimensions (those mapping to
121 // ellipsis)
122 // or explicit batch dimensions corresponding to named axis labels.
125 // Free dimensions are present in exactly one of the inputs, and also the
126 // output. These are non-contracted axes in the Tensor contraction.
127 kFree = 2,
128 // Contract dimensions are present in two inputs, but not the output. These
129 // dimensions are contracted in Tensor contraction.
131 // Reduce dimensions are present in exactly one input; and not in the output
132 // and are summed over prior to Tensor contraction.
134};
135
136namespace
137{
138
// Sentinel label recorded for the "..." (ellipsis) token in a subscript;
// real labels are non-negative consecutive integers.
constexpr int kEllipsisLabel = -1;
140
// Splits 'text' on every occurrence of 'delimiter' and returns the pieces.
// Consecutive delimiters produce empty tokens, and the result always contains
// at least one element (the whole text when no delimiter is found).
// An empty delimiter returns the whole text as a single token; the original
// implementation looped forever in that case.
std::vector<std::string> strSplit(std::string_view text, std::string_view delimiter)
{
  std::vector<std::string> result;

  // Guard: find("") matches at every position, which made the original loop
  // never terminate. Treat the whole text as one token instead.
  if (delimiter.empty())
  {
    result.emplace_back(text);
    return result;
  }

  size_t start = 0;
  while (true)
  {
    const size_t pos = text.find(delimiter, start);
    if (pos == std::string_view::npos)
    {
      // No more delimiters: the remainder is the final token.
      result.emplace_back(text.substr(start));
      break;
    }
    result.emplace_back(text.substr(start, pos - start));
    start = pos + delimiter.size();
  }

  return result;
}
163
164inline DimensionType getDimensionType(bool is_removed, bool is_unique)
165{
166 if (!is_removed && !is_unique)
167 return kBatch;
168 else if (!is_removed && is_unique)
169 return kFree;
170 else if (is_removed && !is_unique)
171 return kContract;
172 else // is_removed && is_unique
173 return kReduce;
174}
175
// Returns a copy of 'shape' with identical rank and dimensions, built via
// Shape::ExtendedShape.
inline Shape copyShape(const Shape &shape)
{
  return Shape::ExtendedShape(shape.DimensionsCount(), shape);
}
180} // namespace
181
183{
184public:
  // Constructs an evaluator; the equation is parsed lazily on the first
  // prepare() or operator() call.
  Einsum() : _prepared(false)
  {
    // DO NOTHING
  }
189
  // Parses 'equation' once and caches the resulting label metadata on the
  // members; subsequent calls are no-ops.
  void prepare(std::string_view equation)
  {
    if (_prepared)
    {
      return;
    }

    // Parse equation
    parseEquation(equation);
    _prepared = true;
  }
201
  // Evaluates einsum 'equation' over the given inputs and writes the result
  // into 'output_data', which must hold output_shape.FlatSize() floats.
  // Pipeline: reduce each operand -> contract -> reshape -> inflate repeated
  // output labels -> transpose into the requested output label order.
  void operator()(std::string_view equation, const std::vector<Shape> &input_shapes,
                  const std::vector<const float *> &input_data, const Shape &output_shape,
                  float *output_data)
  {
    if (!_prepared)
    {
      prepare(equation);
    }

    // Wrap the borrowed shape/data pairs into InputTensor views.
    const int num_inputs = input_shapes.size();
    std::vector<InputTensor<float>> inputs(num_inputs);
    for (int i = 0; i < num_inputs; i++)
    {
      inputs[i].shape.ReplaceWith(input_shapes[i].DimensionsCount(), input_shapes[i].DimsData());
      inputs[i].buffer = input_data[i];
    }

    // Work on copies of the parsed metadata: processDimensions and
    // reduceOperand mutate the labels in place, and the parsed state must
    // stay reusable across calls.
    OperandLabels input_labels(_input_labels);
    Labels output_labels(_output_labels);
    std::vector<DimensionType> label_types(_label_types);
    OperandLabelCounts input_label_counts(_input_label_counts);
    LabelCounts output_label_counts(_output_label_counts);
    LabelToDimSizes label_to_dim_sizes;

    processDimensions(inputs, &input_labels, &output_labels, &label_types, &input_label_counts,
                      &output_label_counts, &label_to_dim_sizes);

    // The reduction phase (a) sums across reduction dimensions, (b) takes
    // generalized diagonals, and (c) reshapes it into shape
    //   [(broadcasting) batch shape] + [F,C]
    // where F and C denote the total (compacted) size of free and contract
    // dimensions, respectively.

    OperandLabels free_labels(num_inputs);
    std::vector<Tensor> inputs_reduced(num_inputs);
    std::vector<bool> swap_free_and_contract(num_inputs);
    for (int i = 0; i < num_inputs; ++i)
    {
      bool temp_swap_free_and_contract = false;
      reduceOperand<float>(inputs[i], label_types, input_label_counts[i], &input_labels[i],
                           &free_labels[i], &temp_swap_free_and_contract, &inputs_reduced[i]);
      swap_free_and_contract[i] = temp_swap_free_and_contract;
    }

    // After reduction, the inputs should be reshaped to Tensors suitable for
    // contraction. If num_inputs is 1, the reduced input is simply forwarded to
    // the output.
    Tensor contraction_output_reshaped;
    contractOperands(inputs_reduced, swap_free_and_contract, &contraction_output_reshaped);

    // Copy the batch labels from the contraction output. Recover the batch
    // shape, which may have been broadcasted.
    std::vector<int32_t> result_shape_dims(contraction_output_reshaped.shape.DimensionsCount() - 2);

    for (size_t i = 0; i < result_shape_dims.size(); i++)
    {
      result_shape_dims[i] = contraction_output_reshaped.shape.Dims(i);
    }

    int num_labels = label_types.size();
    Labels result_labels;
    // All batch dimensions should be present in the contracted result. First
    // the broadcasting dimensions, then the named batch dimensions.
    for (int label = 0; label < num_labels; ++label)
    {
      if (label_types[label] == kBroadcasting)
        result_labels.push_back(label);
    }
    for (int label = 0; label < num_labels; ++label)
    {
      if (label_types[label] == kBatch)
        result_labels.push_back(label);
    }
    // Then the free dimensions of each operand, in operand order.
    for (int i = 0; i < num_inputs; ++i)
    {
      for (auto &&label : free_labels[i])
      {
        result_labels.push_back(label);
        result_shape_dims.push_back(label_to_dim_sizes[label]);
      }
    }

    Shape result_shape(result_shape_dims.size(), result_shape_dims.data());

    // Reshape the contraction (or reduction) result to its expanded shape:
    // [(broadcasted) batch shape] + [free shape 0] + [free shape 1].
    Tensor contraction_output;
    copyFrom(contraction_output_reshaped, result_shape, &contraction_output);

    // Inflate the output if necessary. (E.g. for the equation 'i->iii' which
    // may arise while computing gradient of a regular Einsum).
    // TODO(anudhyan): It's possible that Eigen's contract and inflate can be
    // chained here to avoid materializing an intermediate.
    Tensor output_inflated;
    strideOrInflate<float>(contraction_output, result_labels, output_label_counts,
                           true /* should_inflate */, &output_inflated);

    if (output_inflated.shape.DimensionsCount() > contraction_output.shape.DimensionsCount())
    {
      // We inflated the output. Modify result labels accordingly: each label
      // is repeated once per occurrence in the output subscript.
      Labels inflated_labels;
      for (auto &&label : result_labels)
      {
        inflated_labels.insert(inflated_labels.end(), output_label_counts[label], label);
      }
      result_labels.swap(inflated_labels);
    }

    // Find the permutation to map the result labels to the output labels. Note
    // that both the result and the final output may have the repeated labels,
    // in which case the permutation preserves the left-to-right ordering.
    // E.g. if result labels are [0, 0, 1] and output is [0, 1, 0] then the
    // permutation should be [0, 2, 1]. We also use the fact that repeated
    // labels in the result are adjacent to each other.
    std::vector<int32_t> output_permutation(output_labels.size());
    std::vector<int32_t> label_to_position(num_labels, -1);
    for (size_t i = 0; i < result_labels.size(); ++i)
    {
      // Remember the position of only the leftmost result label.
      if (label_to_position[result_labels[i]] == -1)
      {
        label_to_position[result_labels[i]] = i;
      }
    }
    for (size_t i = 0; i < output_labels.size(); ++i)
    {
      output_permutation[i] = label_to_position[output_labels[i]];
      // We have found the leftmost occurrence. The next one would be adjacent.
      label_to_position[output_labels[i]] += 1;
    }

    // Re-wrap the inflated tensor as an InputTensor view for the final
    // transpose into output label order.
    InputTensor<float> temp_inflated;
    temp_inflated.shape.ReplaceWith(output_inflated.shape.DimensionsCount(),
                                    output_inflated.shape.DimsData());
    temp_inflated.buffer = (reinterpret_cast<const float *>(output_inflated.buffer));
    ;

    Tensor output;
    transposeOperand<float>(temp_inflated, output_permutation, &output);

    memcpy(output_data, output.buffer, output_shape.FlatSize() * sizeof(float));

    // Release every intermediate buffer allocated during this evaluation.
    temp_operand.clear();
  }
346
347private:
  // Parses 'equation' and populates the cached metadata used by operator():
  // per-operand integer labels, per-label occurrence counts, ellipsis flags,
  // and the DimensionType of every named label.
  void parseEquation(std::string_view equation)
  {
    std::vector<std::string> input_str;
    std::string output_str;

    parseEinsumEquation(equation, input_str, output_str);

    // Temporary map from single character labels to (consecutive) integer
    // labels.
    std::map<char, int> label_mapping;
    int num_inputs = input_str.size();
    _input_labels.resize(num_inputs);

    // Map from single characters to integer labels.
    for (int i = 0; i < num_inputs; ++i)
    {
      mapToLabels(input_str[i], _input_labels.at(i), label_mapping);
    }
    mapToLabels(output_str, _output_labels, label_mapping);

    // Compute counts for input and output labels.
    int num_labels = label_mapping.size();
    _input_label_counts.resize(num_inputs);
    _input_has_ellipsis.resize(num_inputs);
    for (int i = 0; i < num_inputs; ++i)
    {
      _input_label_counts.at(i).resize(num_labels);
      for (const int label : _input_labels.at(i))
      {
        if (label != kEllipsisLabel)
          _input_label_counts.at(i)[label] += 1;
        else
          _input_has_ellipsis.at(i) = true;
      }
    }
    _output_label_counts.resize(num_labels);
    for (const int label : _output_labels)
    {
      if (label != kEllipsisLabel)
        _output_label_counts.at(label) += 1;
      else
        _output_has_ellipsis = true;
    }

    // Map each label to a unique DimensionType.
    _label_types.resize(num_labels);
    for (int label = 0; label < num_labels; ++label)
    {
      // 'removed': the label does not appear in the output subscript.
      // 'unique': the label appears in at most one input (always true when
      // there is a single input).
      bool removed = (_output_label_counts[label] == 0);
      bool unique =
        num_inputs == 1 || _input_label_counts[0][label] == 0 || _input_label_counts[1][label] == 0;
      _label_types[label] = getDimensionType(removed, unique);
    }
  }
402
403 void parseEinsumEquation(std::string_view &equation, std::vector<std::string> &input_subscripts,
404 std::string &output_subscript)
405 {
406 std::vector<std::string> inputs_and_output_subscripts = strSplit(equation, "->");
407 if (inputs_and_output_subscripts.size() != 2)
408 {
409 throw std::runtime_error{"Einsum: Expecting exactly one '->' in einsum equation: " +
410 std::string(equation)};
411 }
412
413 output_subscript = inputs_and_output_subscripts[1];
414 input_subscripts = strSplit(inputs_and_output_subscripts[0], ",");
415 if (input_subscripts.size() != 1 && input_subscripts.size() != 2)
416 {
417 throw std::runtime_error{"Einsum: Expecting 1 or 2 input subscripts in equation '" +
418 std::string(equation) +
419 "' but got: " + std::to_string(input_subscripts.size())};
420 }
421 }
422
  // Maps the character labels to consecutive integers.
  // Each previously unseen character is assigned the next integer id in
  // 'label_mapping' (shared across subscripts so the same character maps to
  // the same label everywhere). A '.' is assumed to begin a "..." ellipsis:
  // it is recorded as kEllipsisLabel and the next two characters are skipped.
  void mapToLabels(std::string_view subscript, Labels &labels, std::map<char, int> &label_mapping)
  {
    for (size_t i = 0; i < subscript.size(); ++i)
    {
      const char label_char = subscript[i];
      if (label_char == '.')
      {
        labels.push_back(kEllipsisLabel);
        i += 2; // Skip next 2 characters as well.
        continue;
      }
      if (label_mapping.find(label_char) == label_mapping.end())
      {
        // First occurrence of this character: assign the next consecutive id.
        const int next_label = label_mapping.size();
        label_mapping[label_char] = next_label;
      }
      const int mapped_label = label_mapping[label_char];
      labels.push_back(mapped_label);
    }
  }
444
445 template <typename T>
446 void processDimensions(const std::vector<InputTensor<T>> &inputs, OperandLabels *input_labels,
447 Labels *output_labels, std::vector<DimensionType> *label_types,
448 OperandLabelCounts *input_label_counts, LabelCounts *output_label_counts,
449 LabelToDimSizes *label_to_dim_sizes)
450 {
451 if (inputs.size() != input_labels->size())
452 {
453 throw std::runtime_error{"Expected " + std::to_string(input_labels->size()) +
454 " inputs but got: " + std::to_string(inputs.size())};
455 }
456 const int num_inputs = inputs.size();
457
458 // We infer the number of broadcasting dimensions by taking the maximum rank
459 // among the broadcasting subshapes of the input.
460 int max_bcast_dims = 0;
461 const int num_named_labels = label_types->size();
462 label_to_dim_sizes->resize(num_named_labels);
463 for (int i = 0; i < num_inputs; ++i)
464 {
465 Labels *labels = &(*input_labels)[i];
466
467 if (!_input_has_ellipsis[i])
468 {
469 if (inputs[i].shape.DimensionsCount() != ((int32_t)labels->size()))
470 {
471 throw std::runtime_error{"Expected input " + std::to_string(i) + " to have rank " +
472 std::to_string(labels->size()) + " but got: " +
473 std::to_string(inputs[i].shape.DimensionsCount())};
474 }
475 for (size_t label_idx = 0; label_idx < labels->size(); ++label_idx)
476 {
477 const int label = (*labels)[label_idx];
478 recordLabelToDimension(label, label_idx, inputs[i].shape, label_to_dim_sizes);
479 }
480 continue;
481 }
482
483 // Input has an ellipsis.
484 if (inputs[i].shape.DimensionsCount() + 1 < (int32_t)labels->size())
485 {
486 throw std::runtime_error{"Expected input " + std::to_string(i) + " to have rank at least " +
487 std::to_string(labels->size() - 1) +
488 " but got: " + std::to_string(inputs[i].shape.DimensionsCount())};
489 }
490 int ellipsis_axis = -1;
491 const int num_bcast_dims = inputs[i].shape.DimensionsCount() - labels->size() + 1;
492 for (size_t label_idx = 0; label_idx < labels->size(); ++label_idx)
493 {
494 const int label = (*labels)[label_idx];
495 if (label == kEllipsisLabel)
496 {
497 ellipsis_axis = label_idx;
498 continue;
499 }
500 // Current label is not an ellipsis.
501 const int axis = label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1);
502 recordLabelToDimension(label, axis, inputs[i].shape, label_to_dim_sizes);
503 }
504 // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting
505 // dimensions.
506 if (ellipsis_axis != -1)
507 {
508 insertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis, labels,
509 &input_label_counts->at(i));
510 max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims);
511 }
512 }
513
514 std::vector<bool>::iterator it_input =
515 std::find(_input_has_ellipsis.begin(), _input_has_ellipsis.end(), true);
516 if (it_input == _input_has_ellipsis.end() && !_output_has_ellipsis)
517 {
518 return;
519 }
520 // Insert broadcasting dimensions in the output labels.
521 auto it = std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel);
522 if (it != output_labels->end())
523 {
524 const int ellipsis_axis = it - output_labels->begin();
525 insertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis, output_labels,
526 output_label_counts);
527 }
528 else if (max_bcast_dims > 0)
529 {
530 std::runtime_error{"Output contains " + std::to_string(max_bcast_dims) +
531 " broadcasting dimension(s) but no ellipsis " +
532 "(...) was found in the output subscripts."};
533 }
534 // Populate DimensionType for the new broadcasting labels.
535 label_types->resize(num_named_labels + max_bcast_dims, kBroadcasting);
536 }
537
538 void recordLabelToDimension(const int32_t label, const int axis, const Shape &input_shape,
539 LabelToDimSizes *label_to_dim_sizes)
540 {
541 const int32_t input_dim = input_shape.Dims(axis);
542 // We know that label_to_dim_sizes has the size to accommodate named labels.
543 if (label_to_dim_sizes->at(label) != 0 && label_to_dim_sizes->at(label) != input_dim)
544 {
545 std::runtime_error{"Expected dimension " + std::to_string(label_to_dim_sizes->at(label)) +
546 " at axis " + std::to_string(axis) +
547 " of the input shaped but got dimension " + std::to_string(input_dim)};
548 }
549 (*label_to_dim_sizes)[label] = input_dim;
550 }
551
552 void insertBroadcastLabels(int num_bcast_dims, int num_named_labels, int ellipsis_axis,
553 Labels *labels, LabelCounts *label_counts)
554 {
555 labels->erase(labels->begin() + ellipsis_axis);
556 labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0);
557 std::iota(labels->begin() + ellipsis_axis, labels->begin() + ellipsis_axis + num_bcast_dims,
558 num_named_labels);
559 // Increment label counts. Since these are new labels, the count is set
560 // to 1.
561 label_counts->resize(num_named_labels + num_bcast_dims, 1);
562 }
563
564 template <typename T>
565 void reduceOperand(const InputTensor<T> &input, const std::vector<DimensionType> &label_types,
566 const LabelCounts &label_counts, Labels *labels, Labels *free_labels,
567 bool *swap_free_and_contract, Tensor *output)
568 {
569 // Find the permutation to transpose the input dimensions in the order of
570 // DimensionType; i.e. batch, free, contract and reduce dimensions. This
571 // makes it more convenient to invoke Reduce/Contract operations.
572 std::vector<int32_t> permutation(input.shape.DimensionsCount());
573 std::iota(permutation.begin(), permutation.end(), 0);
574 Tensor input_transposed;
575
576 // Check if we can avoid the transpose. We need to flip the adj_x (or adj_y)
577 // flag during BatchMatMul. This is an extra optimization not necessary for
578 // correctness.
579 if (shouldSwapFreeAndContract(*labels, label_types))
580 {
581 *swap_free_and_contract = true;
582 }
583 else
584 {
585 std::sort(permutation.begin(), permutation.end(), [&](int i, int j) {
586 int label_i = (*labels)[i];
587 int label_j = (*labels)[j];
588 return std::tie(label_types[label_i], label_i) < std::tie(label_types[label_j], label_j);
589 });
590 }
591 // Transpose the input so that DimensionTypes are in order.
592 transposeOperand<T>(input, permutation, &input_transposed);
593
594 permuteLabels(permutation, labels);
595
596 // Take the generalized diagonal for dimensions with repeated axis labels.
597 Tensor input_deduped;
598 labels->erase(std::unique(labels->begin(), labels->end()), labels->end());
599 strideOrInflate<T>(input_transposed, *labels, label_counts, false /* should_inflate */,
600 &input_deduped);
601
602 // Reshape denotes the rank-5 shape [broadcast, batch, free, contract,
603 // reduce] where we've compacted the dimensions of each DimensionType.
604 std::vector<int32_t> reshape(5, 1);
605
606 // The output shape is [batch shape] + [free size, contract size]
607 // That is, the batch shape is preserved (for broadcasting while
608 // contracting) while the free dims and contract dims are compressed to one
609 // dimension each.
611 std::vector<int32_t> output_shape_dims;
612 for (size_t label_idx = 0; label_idx < labels->size(); ++label_idx)
613 {
614 const int label = labels->at(label_idx);
615 int32_t dim = input_deduped.shape.Dims(label_idx);
616 if (label_types[label] == kBroadcasting || label_types[label] == kBatch)
617 {
618 output_shape_dims.push_back(dim);
619 }
620 else if (label_types[label] == kFree)
621 {
622 free_labels->push_back(label);
623 }
624 reshape[label_types[label]] *= dim;
625 }
626
627 if (*swap_free_and_contract)
628 std::swap(reshape[kFree], reshape[kContract]);
629
630 output_shape_dims.push_back(reshape[kFree]);
631 output_shape_dims.push_back(reshape[kContract]);
632
633 output_shape.ReplaceWith(output_shape_dims.size(), output_shape_dims.data());
634
635 if (reshape[kReduce] == 1)
636 { // No need to actually reduce.
637 return copyFrom(input_deduped, output_shape, output);
638 }
639
640 allocateTemp(output_shape, output);
641
642 using Reducer = Eigen::internal::SumReducer<T>;
643 using Index = typename TTypes<T>::Tensor::Index;
644
645 const Eigen::ThreadPoolDevice &device = *eigen_support::GetThreadPoolDevice();
646
647 // Reduce along the last axis (i.e axis 1) of the rank-2 Tensor.
648 const int32_t output_size =
651 device, output->shaped<T, 1>({output_size}),
652 input_deduped.shaped<T, 2>({output_size, reshape[kReduce]}), Eigen::array<Index, 1>({1}),
653 Reducer());
654 }
655
656 bool shouldSwapFreeAndContract(const Labels &labels,
657 const std::vector<DimensionType> &label_types)
658 {
659 // Check that ordering is according to dimension type, with the role of
660 // free and contract dimensions swapped.
661 std::vector<int> remap = {0, 1, 3, 2, 4};
662 for (size_t i = 0; i + 1 < labels.size(); ++i)
663 {
664 const int dimtype_a = remap[label_types[labels[i]]];
665 const int dimtype_b = remap[label_types[labels[i + 1]]];
666 if (dimtype_a > dimtype_b || (dimtype_a == dimtype_b && labels[i] > labels[i + 1]))
667 {
668 return false;
669 }
670 }
671 return true;
672 }
673
  // Transposes 'input' according to 'permutation' into 'output'. When the
  // permutation is the identity (or the rank is < 2) the data is only copied;
  // for empty tensors only the shape is permuted.
  template <typename T>
  void transposeOperand(const InputTensor<T> &input, const std::vector<int32_t> &permutation,
                        Tensor *output)
  {
    if (!shouldTranspose(input.shape, permutation))
    {
      copyFrom(input, input.shape, output);
      return;
    }
    Shape transposed_shape(input.shape.DimensionsCount());
    for (int i = 0; i < input.shape.DimensionsCount(); ++i)
    {
      transposed_shape.SetDim(i, input.shape.Dims(permutation[i]));
    }
    // For empty Tensors, just change the shape. E.g. we may need to transpose
    // from shape [1, 0, 5] to [5, 1, 0].
    if (input.shape.FlatSize() == 0)
    {
      copyFrom(input, transposed_shape, output);
      return;
    }

    // Allocate the destination buffer; temp_operand keeps it alive until the
    // end of the current evaluation.
    temp_operand.emplace_back(std::make_unique<T[]>(transposed_shape.FlatSize()));
    T *new_buffer = temp_operand.back().get();

    TransposeParams transpose_params;
    transpose_params.perm_count = permutation.size();
    for (size_t i = 0; i < permutation.size(); i++)
    {
      transpose_params.perm[i] = permutation[i];
    }

    Transpose<T>(transpose_params, input.shape, input.buffer, transposed_shape, new_buffer);

    output->shape.ReplaceWith(transposed_shape.DimensionsCount(), transposed_shape.DimsData());
    output->buffer = new_buffer;
  }
711
712 bool shouldTranspose(const Shape &input_shape, const std::vector<int32_t> &permutation)
713 {
714 if (input_shape.DimensionsCount() < 2)
715 return false;
716 for (size_t i = 0; i < permutation.size(); ++i)
717 {
718 if (permutation[i] != (int32_t)i)
719 return true;
720 }
721 return false;
722 }
723
724 template <typename T>
725 void copyFrom(const InputTensor<T> &input, const Shape &shape, Tensor *output)
726 {
727 Tensor temp_tensor;
728 temp_tensor.shape.ReplaceWith(input.shape.DimensionsCount(), input.shape.DimsData());
729 temp_operand.emplace_back(std::make_unique<float[]>(input.shape.FlatSize()));
730 temp_tensor.buffer = temp_operand.back().get();
731 memcpy(temp_tensor.buffer, input.buffer, input.shape.FlatSize() * sizeof(float));
732
733 copyFrom(temp_tensor, shape, output);
734 }
735
736 void copyFrom(const Tensor &input, const Shape &shape, Tensor *output)
737 {
738 if (output->copyFrom(input, shape))
739 return;
740
741 throw std::runtime_error{"Einsum: Encountered error while reshaping a Tensor"};
742 }
743
744 // Permutes the labels according to the given permutation.
745 void permuteLabels(const std::vector<int32_t> &permutation, Labels *labels)
746 {
747 Labels permuted_labels(labels->size());
748 for (size_t i = 0; i < labels->size(); ++i)
749 {
750 permuted_labels[i] = (*labels)[permutation[i]];
751 }
752 labels->swap(permuted_labels);
753 }
754
  // If there are repeated labels in either the input or output, then this
  // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively.
  // When should_inflate is false the generalized diagonal is extracted
  // (strided read); when true the values are scattered back onto the diagonal
  // of a larger tensor (inflate). With no repeated labels the input is copied
  // through unchanged.
  template <typename T>
  void strideOrInflate(const Tensor &input, const Labels &labels, const LabelCounts &label_counts,
                       const bool should_inflate, Tensor *output)
  {
    // Return early if there are no repeated indices.
    if (std::all_of(label_counts.begin(), label_counts.end(), [](int c) { return c <= 1; }))
    {
      return copyFrom(input, input.shape, output);
    }
    // We reshape so that each repeated label is compressed to one dimension.
    // E.g. For iiij -> ij, The shape [3, 3, 3, 5] would be compressed to [27,
    // 5]. Striding appropriately (in this case with strides 14 (=1+3+9) and 1)
    // recovers the generalized diagonal of shape [3, 5].
    std::vector<int32_t> reshape;
    std::vector<int32_t> strides;
    // Strided and inflated shapes correspond to input and output shapes,
    // respectively, should_inflate is true (vice-versa if should_inflate is
    // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example.
    Shape strided_shape;
    Shape inflated_shape;
    std::vector<int32_t> strided_shape_dims;
    std::vector<int32_t> inflated_shape_dims;
    for (auto &&label : labels)
    {
      const int32_t count = label_counts[label];
      // The axis to read the dimension from depends on which side (input or
      // output) currently has the compressed form.
      const int current_axis =
        should_inflate ? strided_shape_dims.size() : inflated_shape_dims.size();
      const int32_t dim = input.shape.Dims(current_axis);
      strided_shape_dims.push_back(dim);
      inflated_shape_dims.insert(inflated_shape_dims.end(), count, dim);
      const int32_t reshape_dim = std::pow(dim, count);
      reshape.push_back(reshape_dim);
      // While taking the d-diagonal in a rank k Tensor, we take d
      // equally-spaced elements including the first and last element. Then, (k
      // - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1).
      const int32_t stride = (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1;
      strides.push_back(stride);
    }

    strided_shape.ReplaceWith(strided_shape_dims.size(), strided_shape_dims.data());
    inflated_shape.ReplaceWith(inflated_shape_dims.size(), inflated_shape_dims.data());

    Shape output_shape = Shape(should_inflate ? inflated_shape : strided_shape);

    // Allocate an owned buffer for the result; temp_operand keeps it alive
    // until the end of the current evaluation.
    output->shape.ReplaceWith(output_shape.DimensionsCount(), output_shape.DimsData());
    temp_operand.emplace_back(std::make_unique<float[]>(output_shape.FlatSize()));
    output->buffer = temp_operand.back().get();

    const Eigen::ThreadPoolDevice &device = *eigen_support::GetThreadPoolDevice();

    // Dispatch on the compressed rank; the Eigen functors need it as a
    // compile-time constant, hence the macro-generated cases.
    switch (reshape.size())
    {
#define NDIMS_CASE(N)                                                                    \
  case N:                                                                                \
  {                                                                                      \
    if (should_inflate)                                                                  \
    {                                                                                    \
      auto output_map = output->shaped<T, N>(reshape);                                   \
      auto input_map = input.shaped<T, N>(strided_shape_dims);                           \
      functor::InflateFunctor<Eigen::ThreadPoolDevice, T, N>()(device, input_map,        \
                                                               strides, output_map);     \
    }                                                                                    \
    else                                                                                 \
    {                                                                                    \
      auto input_map = input.shaped<T, N>(reshape);                                      \
      auto output_map = output->shaped<T, N>(strided_shape_dims);                        \
      functor::StrideFunctor<Eigen::ThreadPoolDevice, T, N>()(device, input_map,         \
                                                              strides, output_map);      \
    }                                                                                    \
  }                                                                                      \
  break;
      NDIMS_CASE(1);
      NDIMS_CASE(2);
      NDIMS_CASE(3);
      NDIMS_CASE(4);
      NDIMS_CASE(5);
      NDIMS_CASE(6);
      default:
        throw std::runtime_error{"Unsupported rank: " + std::to_string(reshape.size()) +
                                 " while handling repeated indices. Up to rank 6 is supported."};
#undef NDIMS_CASE
    }
  }
840
  // Allocates an owned float buffer of shape.FlatSize() elements, registers
  // it in temp_operand (so it lives until the end of the evaluation), and
  // points 'output' at it with the given shape.
  void allocateTemp(const Shape &shape, Tensor *output)
  {
    output->shape.ReplaceWith(shape.DimensionsCount(), shape.DimsData());
    temp_operand.emplace_back(std::make_unique<float[]>(shape.FlatSize()));
    output->buffer = temp_operand.back().get();
  }
847
  // Contracts the inputs along the last axis. (or the second last if the
  // corresponding value of swap_free_and_contract is true). The batch
  // dimensions are broadcast to the output shape.
  // TODO(anudhyan): Factor this function into a BatchMatMul functor and support
  // transpose_x and transpose_y attributes (in addition to adj_x and adj_y).
  // Also, the BatchMatMul might devolve into a component-wise multiplication
  // when the matrix shape is [1,1]; in this case BatchMatMul functor would be
  // very inefficient. The functor should detect if this is the case and perform
  // componentwise multiplication functor instead.
  void contractOperands(std::vector<Tensor> &inputs, std::vector<bool> &swap_free_and_contract,
                        Tensor *output)
  {
    // With a single operand there is nothing to contract; forward it.
    if (inputs.size() == 1)
      return copyFrom(inputs[0], inputs[0].shape, output);

    MatMulBCast bcast(inputs[0].shape, inputs[1].shape);
    if (!bcast.IsValid())
    {
      throw std::runtime_error{"Einsum: Invalid broadcasting dimensions"};
    }

    // Collapse each operand's batch dimensions to a single leading dimension.
    Tensor lhs;
    reshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs);
    Tensor rhs;
    reshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs);
    // Result shape: broadcast batch shape followed by one free dimension per
    // operand.
    Shape old_output_shape = bcast.output_batch_shape();
    Shape output_shape(static_cast<int>(old_output_shape.DimensionsCount() + inputs.size()));
    for (int i = 0; i < old_output_shape.DimensionsCount(); i++)
    {
      output_shape.SetDim(i, old_output_shape.Dims(i));
    }

    for (size_t i = 0; i < inputs.size(); ++i)
    {
      // The free axis is the last axis when free/contract were swapped,
      // otherwise the second-to-last.
      const int32_t free_axis =
        inputs[i].shape.DimensionsCount() - (swap_free_and_contract[i] ? 1 : 2);
      output_shape.SetDim(i + old_output_shape.DimensionsCount(), inputs[i].shape.Dims(free_axis));
    }
    // Swapped operands are handled by BatchMatMul's adjoint flags instead of
    // an explicit transpose.
    bool adj_x = swap_free_and_contract[0];
    bool adj_y = !swap_free_and_contract[1];

    allocateTemp(output_shape, output);

    const Eigen::ThreadPoolDevice &device = *eigen_support::GetThreadPoolDevice();

    // An empty operand yields an all-zero result (sum over an empty range).
    if (lhs.shape.FlatSize() == 0 || rhs.shape.FlatSize() == 0)
    {
      functor::SetZeroFunctor<Eigen::ThreadPoolDevice, float> set_zero;
      set_zero(device,
               typename TTypes<float, 1>::Tensor(output->base<float>(), output->shape.FlatSize()));
      return;
    }

    Tensor output_reshaped;
    reshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped);

    // LaunchBatchMatMul::Launch(lhs, rhs, adj_x, adj_y, bcast, &output_reshaped);
    BatchMatMul batchMatMul;
    // Set rhs is not constant: don't use optimization
    batchMatMul.prepare(lhs.shape, rhs.shape, adj_x, adj_y, false);
    batchMatMul(lhs.shape, lhs.base<float>(), rhs.shape, rhs.base<float>(), adj_x, adj_y,
                output_reshaped.shape, output_reshaped.base<float>());
  }
911
  // Collapses all leading (batch) dimensions of 'input' into a single
  // dimension of size 'batch_size', keeping the last two matrix dimensions.
  void reshapeToRank3(const Tensor &input, int batch_size, Tensor *output)
  {
    const int rank = input.shape.DimensionsCount();
    Shape output_shape({batch_size, input.shape.Dims(rank - 2), input.shape.Dims(rank - 1)});
    copyFrom(input, output_shape, output);
  }
918
private:
  // True once parseEquation() has run; guards prepare() against re-parsing.
  bool _prepared;

  // Integer labels of each input operand's subscript characters.
  OperandLabels _input_labels;
  // Integer labels of the output subscript.
  Labels _output_labels;
  // DimensionType classification of every label.
  std::vector<DimensionType> _label_types;
  // Occurrence count of each label, per input operand.
  OperandLabelCounts _input_label_counts;
  // Occurrence count of each label in the output subscript.
  LabelCounts _output_label_counts;
  // Whether each input subscript contains an ellipsis ("...").
  std::vector<bool> _input_has_ellipsis;
  // Whether the output subscript contains an ellipsis.
  bool _output_has_ellipsis = false;

  // Owns every intermediate buffer allocated during operator(); cleared at
  // the end of each evaluation.
  std::vector<std::unique_ptr<float[]>> temp_operand;
931};
932
933} // namespace cker
934} // namespace nnfw
935
936#endif // __NNFW_CKER_EINSUM_H__
void operator()(std::string_view equation, const std::vector< Shape > &input_shapes, const std::vector< const float * > &input_data, const Shape &output_shape, float *output_data)
Definition Einsum.h:202
void prepare(std::string_view equation)
Definition Einsum.h:190
int32_t DimensionsCount() const
Definition Shape.h:107
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
Definition Shape.h:203
int32_t Dims(int i) const
Definition Shape.h:110
int32_t * DimsData()
Definition Shape.h:138
#define NDIMS_CASE(N)
const luci_interpreter::RuntimeShape output_shape
void BatchMatMul(const tflite::RuntimeShape &lhs_shape, const float *lhs_data, const tflite::RuntimeShape &rhs_shape, const float *rhs_data, const tflite::RuntimeShape &output_shape, float *output_data)
const Eigen::ThreadPoolDevice * GetThreadPoolDevice()
std::vector< Labels > OperandLabels
Definition Einsum.h:108
std::vector< int32_t > LabelCounts
Definition Einsum.h:109
std::vector< int32_t > ShapeVec
Definition Einsum.h:106
std::vector< LabelCounts > OperandLabelCounts
Definition Einsum.h:110
@ kBroadcasting
Definition Einsum.h:123
@ kContract
Definition Einsum.h:130
std::vector< int32_t > Labels
Definition Einsum.h:107
std::vector< int32_t > LabelToDimSizes
Definition Einsum.h:111
Definition topk_v2.h:30
Definition Shape.h:28
Eigen::TensorMap< Eigen::Tensor< const T, NDIMS, Eigen::RowMajor, IndexType >, Eigen::Aligned > ConstTensor
Definition Tensor.h:35
Eigen::TensorMap< Eigen::Tensor< T, 1, Eigen::RowMajor, IndexType >, Eigen::Aligned > Flat
Definition Tensor.h:61
Eigen::TensorMap< Eigen::Tensor< T, NDIMS, Eigen::RowMajor, IndexType >, Eigen::Aligned > Tensor
Definition Tensor.h:32
void operator()(const Device &d, typename TTypes< T, N >::ConstTensor input, const std::vector< int32_t > &strides, typename TTypes< T, N >::Tensor output)
Definition Einsum.h:67
static void Reduce(const Device &d, OUT_T out, IN_T in, const ReductionAxes &reduction_axes, const Reducer &reducer)
Definition Einsum.h:88
void operator()(const Device &d, typename TTypes< T >::Flat out)
Definition Einsum.h:98
void operator()(const Device &d, typename TTypes< T, N >::ConstTensor input, const std::vector< int32_t > &strides, typename TTypes< T, N >::Tensor output)
Definition Einsum.h:47