198 parseEquation(equation);
// Evaluates the parsed einsum over the given input buffers/shapes and copies
// the result into the caller-provided output buffer.
//
// NOTE(review): damaged extraction — original line numbers are fused into the
// text and several lines are missing (the signature's output parameters, and
// the declarations of `input_labels`, `input_label_counts`,
// `label_to_dim_sizes`, `free_labels`, `result_labels`, `inflated_labels`,
// `output_inflated`, `temp_inflated`, `output`). Verify against the pristine
// file before relying on any comment below.
202 void operator()(std::string_view equation,
const std::vector<Shape> &input_shapes,
// Wrap each raw float buffer with its shape into an InputTensor.
211 const int num_inputs = input_shapes.size();
212 std::vector<InputTensor<float>> inputs(num_inputs);
213 for (
int i = 0; i < num_inputs; i++)
215 inputs[i].shape.ReplaceWith(input_shapes[i].DimensionsCount(), input_shapes[i].DimsData());
216 inputs[i].buffer = input_data[i];
// Work on copies of the cached parse results so a call with new shapes does
// not mutate the state produced by parseEquation().
220 Labels output_labels(_output_labels);
221 std::vector<DimensionType> label_types(_label_types);
223 LabelCounts output_label_counts(_output_label_counts);
// Validate ranks, resolve ellipsis/broadcasting dims, and record the concrete
// size bound to every label.
226 processDimensions(inputs, &input_labels, &output_labels, &label_types, &input_label_counts,
227 &output_label_counts, &label_to_dim_sizes);
// Step 1: per operand, fold away repeated/summed-out labels and transpose to
// canonical (batch..., free..., contract...) order.
236 std::vector<Tensor> inputs_reduced(num_inputs);
237 std::vector<bool> swap_free_and_contract(num_inputs);
238 for (
int i = 0; i < num_inputs; ++i)
240 bool temp_swap_free_and_contract =
false;
241 reduceOperand<float>(inputs[i], label_types, input_label_counts[i], &input_labels[i],
242 &free_labels[i], &temp_swap_free_and_contract, &inputs_reduced[i]);
243 swap_free_and_contract[i] = temp_swap_free_and_contract;
// Step 2: contract the reduced operands with a (batched) matmul.
249 Tensor contraction_output_reshaped;
250 contractOperands(inputs_reduced, swap_free_and_contract, &contraction_output_reshaped);
// Step 3: rebuild the full result shape. The contraction output's last two
// dims are the free dims of the two operands, hence the "- 2".
254 std::vector<int32_t> result_shape_dims(contraction_output_reshaped.
shape.
DimensionsCount() - 2);
256 for (
size_t i = 0; i < result_shape_dims.size(); i++)
258 result_shape_dims[i] = contraction_output_reshaped.
shape.
Dims(i);
261 int num_labels = label_types.size();
// NOTE(review): the first loop's filtering condition (fused lines 266-267,
// presumably `label_types[label] == kBroadcasting`) is missing from this
// extraction — as written it would push every label. Confirm upstream.
265 for (
int label = 0; label < num_labels; ++label)
268 result_labels.push_back(label);
// Batch labels come next in the contraction's output layout.
270 for (
int label = 0; label < num_labels; ++label)
272 if (label_types[label] ==
kBatch)
273 result_labels.push_back(label);
// Then the free labels of each operand, in operand order.
275 for (
int i = 0; i < num_inputs; ++i)
277 for (
auto &&label : free_labels[i])
279 result_labels.push_back(label);
280 result_shape_dims.push_back(label_to_dim_sizes[label]);
284 Shape result_shape(result_shape_dims.size(), result_shape_dims.data());
// Reshape (no data movement) to the multi-dim result shape.
288 Tensor contraction_output;
289 copyFrom(contraction_output_reshaped, result_shape, &contraction_output);
// Step 4: re-inflate labels that are repeated in the OUTPUT subscript
// (should_inflate == true).
296 strideOrInflate<float>(contraction_output, result_labels, output_label_counts,
297 true , &output_inflated);
// Expand result_labels so each label appears once per output occurrence,
// matching the inflated tensor's rank.
303 for (
auto &&label : result_labels)
305 inflated_labels.insert(inflated_labels.end(), output_label_counts[label], label);
307 result_labels.swap(inflated_labels);
// Step 5: compute the permutation mapping the inflated layout onto the
// user-requested output label order. label_to_position tracks the next
// position of each (possibly repeated) label.
316 std::vector<int32_t> output_permutation(output_labels.size());
317 std::vector<int32_t> label_to_position(num_labels, -1);
318 for (
size_t i = 0; i < result_labels.size(); ++i)
321 if (label_to_position[result_labels[i]] == -1)
323 label_to_position[result_labels[i]] = i;
326 for (
size_t i = 0; i < output_labels.size(); ++i)
328 output_permutation[i] = label_to_position[output_labels[i]];
// Repeated labels consume successive positions.
330 label_to_position[output_labels[i]] += 1;
// Re-wrap the inflated buffer as an input for the final transpose.
336 temp_inflated.
buffer = (
reinterpret_cast<const float *
>(output_inflated.
buffer));
340 transposeOperand<float>(temp_inflated, output_permutation, &output);
// Copy into the caller's output buffer and release scratch allocations.
342 memcpy(output_data, output.buffer,
output_shape.FlatSize() *
sizeof(
float));
344 temp_operand.clear();
// NOTE(review): extraction artifact below — a duplicated, collapsed copy of
// the operator() signature; not valid code in this form.
202 void operator()(std::string_view equation,
const std::vector<Shape> &input_shapes, {
…}
// Parses `equation` once and caches the per-input/per-output label lists,
// label occurrence counts, ellipsis flags and per-label dimension types in
// the class members (_input_labels, _output_labels, _input_label_counts,
// _output_label_counts, _input_has_ellipsis, _output_has_ellipsis,
// _label_types).
//
// NOTE(review): damaged extraction — the branches that detect the ellipsis
// label (before fused lines 380 and 389) and the declaration of `unique`
// (fused line 397, completing the expression on line 398) are missing.
348 void parseEquation(std::string_view equation)
350 std::vector<std::string> input_str;
351 std::string output_str;
// Split into input subscripts and the output subscript (throws on bad input).
353 parseEinsumEquation(equation, input_str, output_str);
// Map each distinct subscript character to a small integer label.
357 std::map<char, int> label_mapping;
358 int num_inputs = input_str.size();
359 _input_labels.resize(num_inputs);
362 for (
int i = 0; i < num_inputs; ++i)
364 mapToLabels(input_str[i], _input_labels.at(i), label_mapping);
366 mapToLabels(output_str, _output_labels, label_mapping);
// Count how many times each label occurs per input and record ellipses.
369 int num_labels = label_mapping.size();
370 _input_label_counts.resize(num_inputs);
371 _input_has_ellipsis.resize(num_inputs);
372 for (
int i = 0; i < num_inputs; ++i)
374 _input_label_counts.at(i).resize(num_labels);
375 for (
const int label : _input_labels.at(i))
377 if (label != kEllipsisLabel)
378 _input_label_counts.at(i)[label] += 1;
// NOTE(review): presumably the else-branch of the kEllipsisLabel test above.
380 _input_has_ellipsis.at(i) =
true;
// Same bookkeeping for the output subscript.
383 _output_label_counts.resize(num_labels);
384 for (
const int label : _output_labels)
386 if (label != kEllipsisLabel)
387 _output_label_counts.at(label) += 1;
389 _output_has_ellipsis =
true;
// Classify each label (free/batch/contract/reduce) from whether it survives
// to the output and whether it is unique to one input.
393 _label_types.resize(num_labels);
394 for (
int label = 0; label < num_labels; ++label)
396 bool removed = (_output_label_counts[label] == 0);
// NOTE(review): the `bool unique =` declaration (fused line 397) is missing.
398 num_inputs == 1 || _input_label_counts[0][label] == 0 || _input_label_counts[1][label] == 0;
399 _label_types[label] = getDimensionType(removed, unique);
403 void parseEinsumEquation(std::string_view &equation, std::vector<std::string> &input_subscripts,
404 std::string &output_subscript)
406 std::vector<std::string> inputs_and_output_subscripts = strSplit(equation,
"->");
407 if (inputs_and_output_subscripts.size() != 2)
409 throw std::runtime_error{
"Einsum: Expecting exactly one '->' in einsum equation: " +
410 std::string(equation)};
413 output_subscript = inputs_and_output_subscripts[1];
414 input_subscripts = strSplit(inputs_and_output_subscripts[0],
",");
415 if (input_subscripts.size() != 1 && input_subscripts.size() != 2)
417 throw std::runtime_error{
"Einsum: Expecting 1 or 2 input subscripts in equation '" +
418 std::string(equation) +
419 "' but got: " + std::to_string(input_subscripts.size())};
// Converts one subscript string into a sequence of integer labels, assigning
// a fresh label id (in first-seen order) to each new character via
// `label_mapping`, which is shared across all subscripts of the equation.
// A '.' maps to the sentinel kEllipsisLabel.
//
// NOTE(review): fused lines 432-434 are missing; they presumably skip the
// remaining two dots of "..." and `continue` past the mapping below —
// confirm against the pristine file.
424 void mapToLabels(std::string_view subscript,
Labels &labels, std::map<char, int> &label_mapping)
426 for (
size_t i = 0; i < subscript.size(); ++i)
428 const char label_char = subscript[i];
429 if (label_char ==
'.')
431 labels.push_back(kEllipsisLabel);
// Unseen characters get the next free integer label.
435 if (label_mapping.find(label_char) == label_mapping.end())
437 const int next_label = label_mapping.size();
438 label_mapping[label_char] = next_label;
440 const int mapped_label = label_mapping[label_char];
441 labels.push_back(mapped_label);
// Validates every input's rank against its subscript, resolves the ellipsis
// ("...") into concrete broadcasting labels, and records the dimension size
// bound to each label in *label_to_dim_sizes. Extends label_types /
// output_labels / label counts with the inferred broadcasting labels.
//
// NOTE(review): damaged extraction — the trailing signature parameters
// (input_label_counts, output_label_counts, label_to_dim_sizes), the
// per-input `for` header around fused line 464, several braces and the
// bodies of some branches are missing.
445 template <
typename T>
446 void processDimensions(
const std::vector<InputTensor<T>> &inputs,
OperandLabels *input_labels,
447 Labels *output_labels, std::vector<DimensionType> *label_types,
// One labels entry is required per input operand.
451 if (
inputs.size() != input_labels->size())
453 throw std::runtime_error{
"Expected " + std::to_string(input_labels->size()) +
454 " inputs but got: " + std::to_string(
inputs.size())};
// Track the widest broadcasting sub-shape seen across inputs.
460 int max_bcast_dims = 0;
461 const int num_named_labels = label_types->size();
462 label_to_dim_sizes->resize(num_named_labels);
465 Labels *labels = &(*input_labels)[i];
// Without an ellipsis, rank must match the subscript length exactly.
467 if (!_input_has_ellipsis[i])
469 if (inputs[i].shape.DimensionsCount() != ((int32_t)labels->size()))
471 throw std::runtime_error{
"Expected input " + std::to_string(i) +
" to have rank " +
472 std::to_string(labels->size()) +
" but got: " +
473 std::to_string(inputs[i].shape.DimensionsCount())};
475 for (
size_t label_idx = 0; label_idx < labels->size(); ++label_idx)
477 const int label = (*labels)[label_idx];
478 recordLabelToDimension(label, label_idx, inputs[i].shape, label_to_dim_sizes);
// With an ellipsis, rank may exceed the named labels by the broadcast dims.
484 if (inputs[i].shape.DimensionsCount() + 1 < (int32_t)labels->size())
486 throw std::runtime_error{
"Expected input " + std::to_string(i) +
" to have rank at least " +
487 std::to_string(labels->size() - 1) +
488 " but got: " + std::to_string(inputs[i].shape.DimensionsCount())};
490 int ellipsis_axis = -1;
491 const int num_bcast_dims =
inputs[i].shape.DimensionsCount() - labels->size() + 1;
492 for (
size_t label_idx = 0; label_idx < labels->size(); ++label_idx)
494 const int label = (*labels)[label_idx];
495 if (label == kEllipsisLabel)
497 ellipsis_axis = label_idx;
// Labels after the ellipsis are shifted right by the broadcast dims.
501 const int axis = label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1);
502 recordLabelToDimension(label, axis, inputs[i].shape, label_to_dim_sizes);
// Replace the ellipsis with fresh broadcasting labels for this input.
506 if (ellipsis_axis != -1)
508 insertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis, labels,
509 &input_label_counts->at(i));
510 max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims);
// Nothing more to do if no operand (and not the output) used an ellipsis.
514 std::vector<bool>::iterator it_input =
515 std::find(_input_has_ellipsis.begin(), _input_has_ellipsis.end(),
true);
516 if (it_input == _input_has_ellipsis.end() && !_output_has_ellipsis)
// Mirror the broadcast-label insertion into the output subscript.
521 auto it = std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel);
522 if (it != output_labels->end())
524 const int ellipsis_axis = it - output_labels->begin();
525 insertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis, output_labels,
526 output_label_counts);
528 else if (max_bcast_dims > 0)
// NOTE(review): no `throw` visible here, unlike every other error site in
// this file which spells `throw std::runtime_error{...}` on one line — as
// written this constructs and silently discards the exception. Likely a
// missing `throw`; confirm and fix upstream.
530 std::runtime_error{
"Output contains " + std::to_string(max_bcast_dims) +
531 " broadcasting dimension(s) but no ellipsis " +
532 "(...) was found in the output subscripts."};
// All broadcasting labels share the kBroadcasting dimension type.
535 label_types->resize(num_named_labels + max_bcast_dims,
kBroadcasting);
538 void recordLabelToDimension(
const int32_t label,
const int axis,
const Shape &input_shape,
541 const int32_t input_dim = input_shape.Dims(axis);
543 if (label_to_dim_sizes->at(label) != 0 && label_to_dim_sizes->at(label) != input_dim)
545 std::runtime_error{
"Expected dimension " + std::to_string(label_to_dim_sizes->at(label)) +
546 " at axis " + std::to_string(axis) +
547 " of the input shaped but got dimension " + std::to_string(input_dim)};
549 (*label_to_dim_sizes)[label] = input_dim;
// Replaces the single kEllipsisLabel entry at `ellipsis_axis` in *labels with
// `num_bcast_dims` fresh, consecutive broadcasting labels, and extends
// *label_counts so each new label has a count of 1.
//
// NOTE(review): damaged extraction — the trailing parameters (Labels *labels,
// and the label-counts pointer) on fused lines 553-554, and the start value
// of std::iota (fused line 558, presumably `num_named_labels` so broadcast
// labels follow the named ones) are missing; confirm upstream.
552 void insertBroadcastLabels(
int num_bcast_dims,
int num_named_labels,
int ellipsis_axis,
// Drop the ellipsis placeholder, then splice in the new label slots.
555 labels->erase(labels->begin() + ellipsis_axis);
556 labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0);
557 std::iota(labels->begin() + ellipsis_axis, labels->begin() + ellipsis_axis + num_bcast_dims,
// Broadcast labels are never repeated, so each counts exactly once.
561 label_counts->resize(num_named_labels + num_bcast_dims, 1);
// Canonicalizes a single operand for contraction: transposes its axes into
// (broadcasting, batch, free, contract, reduce) label order, collapses
// repeated labels via striding, sums out reduce labels, and reshapes to the
// rank-3-ish layout contractOperands expects. Sets *swap_free_and_contract
// when the operand's free and contract groups arrive swapped.
//
// NOTE(review): damaged extraction — the free_labels/labels/label_counts
// parameters, the `input_transposed`/`input_deduped` declarations, the
// Eigen reduction invocation tail, and several braces are missing.
564 template <
typename T>
565 void reduceOperand(
const InputTensor<T> &input,
const std::vector<DimensionType> &label_types,
567 bool *swap_free_and_contract,
Tensor *output)
// Start from the identity permutation over the operand's axes.
572 std::vector<int32_t> permutation(
input.shape.DimensionsCount());
573 std::iota(permutation.begin(), permutation.end(), 0);
579 if (shouldSwapFreeAndContract(*labels, label_types))
581 *swap_free_and_contract =
true;
// Sort axes by (dimension type, label) so same-typed labels are contiguous.
585 std::sort(permutation.begin(), permutation.end(), [&](
int i,
int j) {
586 int label_i = (*labels)[i];
587 int label_j = (*labels)[j];
588 return std::tie(label_types[label_i], label_i) < std::tie(label_types[label_j], label_j);
592 transposeOperand<T>(input, permutation, &input_transposed);
// Keep *labels in sync with the physical transpose.
594 permuteLabels(permutation, labels);
// Collapse repeated labels (striding, should_inflate == false).
598 labels->erase(std::unique(labels->begin(), labels->end()), labels->end());
599 strideOrInflate<T>(input_transposed, *labels, label_counts,
false ,
// One accumulated extent per dimension type (indexed by DimensionType).
604 std::vector<int32_t>
reshape(5, 1);
// Partition each remaining axis extent into its type bucket; free labels are
// also reported back to the caller.
611 std::vector<int32_t> output_shape_dims;
612 for (
size_t label_idx = 0; label_idx < labels->size(); ++label_idx)
614 const int label = labels->at(label_idx);
615 int32_t dim = input_deduped.shape.Dims(label_idx);
618 output_shape_dims.push_back(dim);
620 else if (label_types[label] ==
kFree)
622 free_labels->push_back(label);
624 reshape[label_types[label]] *= dim;
// Emit (free, contract) or (contract, free) depending on the swap flag;
// NOTE(review): the non-swap branch (fused lines ~628-629/632) is missing.
627 if (*swap_free_and_contract)
630 output_shape_dims.push_back(reshape[
kFree]);
631 output_shape_dims.push_back(reshape[
kContract]);
633 output_shape.ReplaceWith(output_shape_dims.size(), output_shape_dims.data());
// Sum over the reduce extent with an Eigen sum reduction.
642 using Reducer = Eigen::internal::SumReducer<T>;
643 using Index =
typename TTypes<T>::Tensor::Index;
648 const int32_t output_size =
651 device,
output->shaped<T, 1>({output_size}),
652 input_deduped.shaped<T, 2>({output_size, reshape[kReduce]}), Eigen::array<Index, 1>({1}),
// Decides whether sorting this operand's axes into canonical type order
// would place contract labels before free labels. `remap` swaps the relative
// order of kFree(2)/kContract(3) so the comparison below detects exactly
// that inversion.
//
// NOTE(review): the `return true;` inside the loop and the trailing
// `return false;` (fused lines ~667-671) are missing from this extraction.
656 bool shouldSwapFreeAndContract(
const Labels &labels,
657 const std::vector<DimensionType> &label_types)
// Remapped priority per DimensionType; indices 2 and 3 are exchanged.
661 std::vector<int> remap = {0, 1, 3, 2, 4};
662 for (
size_t i = 0; i + 1 < labels.size(); ++i)
664 const int dimtype_a = remap[label_types[labels[i]]];
665 const int dimtype_b = remap[label_types[labels[i + 1]]];
// Any out-of-order adjacent pair under the remapped priority means swap.
666 if (dimtype_a > dimtype_b || (dimtype_a == dimtype_b && labels[i] > labels[i + 1]))
// Physically transposes `input` according to `permutation` into a scratch
// buffer owned by temp_operand, or falls back to a cheap copy when the
// permutation is (effectively) the identity or the tensor is empty.
//
// NOTE(review): damaged extraction — early `return`s after the copyFrom
// calls and several braces (fused lines 681-682, 693-694) are missing.
674 template <
typename T>
675 void transposeOperand(
const InputTensor<T> &input,
const std::vector<int32_t> &permutation,
// Identity permutation (or rank < 2): no data movement needed.
678 if (!shouldTranspose(
input.shape, permutation))
680 copyFrom(input,
input.shape, output);
// Build the permuted shape.
683 Shape transposed_shape(
input.shape.DimensionsCount());
684 for (
int i = 0; i <
input.shape.DimensionsCount(); ++i)
686 transposed_shape.SetDim(i,
input.shape.Dims(permutation[i]));
// Empty tensors only need their shape fixed up.
690 if (
input.shape.FlatSize() == 0)
692 copyFrom(input, transposed_shape, output);
// Allocate destination storage owned by temp_operand.
696 temp_operand.emplace_back(std::make_unique<T[]>(transposed_shape.FlatSize()));
697 T *new_buffer = temp_operand.back().get();
699 TransposeParams transpose_params;
700 transpose_params.perm_count = permutation.size();
701 for (
size_t i = 0; i < permutation.size(); i++)
703 transpose_params.perm[i] = permutation[i];
706 Transpose<T>(transpose_params,
input.shape,
input.buffer, transposed_shape, new_buffer);
708 output->shape.ReplaceWith(transposed_shape.DimensionsCount(), transposed_shape.DimsData());
709 output->buffer = new_buffer;
// Returns whether applying `permutation` to `input_shape` would actually
// move data: rank < 2 never needs a transpose, and an identity permutation
// does not either.
//
// NOTE(review): the `return` statements (after the rank check, inside the
// loop, and at the end — fused lines ~715, 719-722) are missing from this
// extraction.
712 bool shouldTranspose(
const Shape &input_shape,
const std::vector<int32_t> &permutation)
714 if (input_shape.DimensionsCount() < 2)
716 for (
size_t i = 0; i < permutation.size(); ++i)
718 if (permutation[i] != (int32_t)i)
// copyFrom (InputTensor overload): deep-copies `input`'s bytes into a fresh
// temp_operand-owned buffer wrapped in a Tensor, then delegates to the
// Tensor overload below to reshape it into `shape`.
//
// NOTE(review): damaged extraction — the `Tensor temp_tensor;` declaration
// (fused line ~727) and the braces are missing.
724 template <
typename T>
725 void copyFrom(
const InputTensor<T> &input,
const Shape &shape,
Tensor *output)
728 temp_tensor.shape.ReplaceWith(
input.shape.DimensionsCount(),
input.shape.DimsData());
// NOTE(review): buffer is allocated as float[] even though T is the template
// parameter — presumably this kernel only runs with T = float; verify.
729 temp_operand.emplace_back(std::make_unique<
float[]>(
input.shape.FlatSize()));
730 temp_tensor.buffer = temp_operand.back().get();
731 memcpy(temp_tensor.buffer,
input.buffer,
input.shape.FlatSize() *
sizeof(
float));
733 copyFrom(temp_tensor, shape, output);
// copyFrom (Tensor overload): the signature (fused lines 735-737) is missing
// from this extraction; the body reshapes `input` into `shape` via the
// Tensor's own copyFrom and reports failure as an exception.
738 if (
output->copyFrom(input, shape))
741 throw std::runtime_error{
"Einsum: Encountered error while reshaping a Tensor"};
745 void permuteLabels(
const std::vector<int32_t> &permutation,
Labels *labels)
747 Labels permuted_labels(labels->size());
748 for (
size_t i = 0; i < labels->size(); ++i)
750 permuted_labels[i] = (*labels)[permutation[i]];
752 labels->swap(permuted_labels);
// Handles labels that repeat within one subscript: when should_inflate is
// false it STRIDES (reads the generalized diagonal, collapsing k copies of a
// size-d axis into one size-d axis); when true it INFLATES (scatters a
// size-d axis back onto a d^k diagonal). Dispatches on rank via NDIMS_CASE.
//
// NOTE(review): damaged extraction — the leading parameters (input, labels,
// label_counts), the `reshape`/`strided_shape`/`output_shape` declarations,
// the allocateTemp call and the rank switch (fused lines 758, 767-775,
// 798-808, 824-834) are missing.
757 template <
typename T>
759 const bool should_inflate,
Tensor *output)
// Fast path: no label repeats anywhere — a plain reshape/copy suffices.
762 if (std::all_of(label_counts.begin(), label_counts.end(), [](
int c) { return c <= 1; }))
764 return copyFrom(input,
input.shape, output);
// Per-axis strides for walking the generalized diagonal.
771 std::vector<int32_t> strides;
776 Shape inflated_shape;
777 std::vector<int32_t> strided_shape_dims;
778 std::vector<int32_t> inflated_shape_dims;
779 for (
auto &&label : labels)
781 const int32_t count = label_counts[label];
// The axis to read depends on which side (strided vs inflated) is the input.
782 const int current_axis =
783 should_inflate ? strided_shape_dims.size() : inflated_shape_dims.size();
784 const int32_t dim =
input.shape.Dims(current_axis);
785 strided_shape_dims.push_back(dim);
786 inflated_shape_dims.insert(inflated_shape_dims.end(), count, dim);
// k repeats of a size-d axis collapse into one d^k axis in the flat view.
787 const int32_t reshape_dim = std::pow(dim, count);
788 reshape.push_back(reshape_dim);
// Diagonal stride: 1 + d + d^2 + ... = (d^k - 1) / (d - 1).
792 const int32_t stride = (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1;
793 strides.push_back(stride);
796 strided_shape.ReplaceWith(strided_shape_dims.size(), strided_shape_dims.data());
797 inflated_shape.ReplaceWith(inflated_shape_dims.size(), inflated_shape_dims.data());
// Scratch output buffer owned by temp_operand.
802 temp_operand.emplace_back(std::make_unique<
float[]>(
output_shape.FlatSize()));
803 output->buffer = temp_operand.back().get();
// Rank-dispatched kernel: inflate scatters, stride gathers the diagonal.
809#define NDIMS_CASE(N) \
812 if (should_inflate) \
814 auto output_map = output->shaped<T, N>(reshape); \
815 auto input_map = input.shaped<T, N>(strided_shape_dims); \
816 functor::InflateFunctor<Eigen::ThreadPoolDevice, T, N>()(device, input_map, strides, \
821 auto input_map = input.shaped<T, N>(reshape); \
822 auto output_map = output->shaped<T, N>(strided_shape_dims); \
823 functor::StrideFunctor<Eigen::ThreadPoolDevice, T, N>()(device, input_map, strides, \
// Ranks above 6 are unsupported by the functor dispatch above.
835 throw std::runtime_error{
"Unsupported rank: " + std::to_string(
reshape.size()) +
836 " while handling repeated indices. Up to rank 6 is supported."};
841 void allocateTemp(
const Shape &shape,
Tensor *output)
843 output->shape.ReplaceWith(shape.DimensionsCount(), shape.DimsData());
844 temp_operand.emplace_back(std::make_unique<
float[]>(shape.FlatSize()));
845 output->buffer = temp_operand.back().get();
// Contracts the (one or two) reduced operands: broadcasts their batch dims,
// reshapes both to rank 3 and runs a batched matmul, using the swap flags as
// the matmul's adjoint/transpose options.
//
// NOTE(review): damaged extraction — the output parameter of the signature,
// the single-input early-exit condition before fused line 861, the lhs/rhs
// declarations, the output_shape construction, the zero-size early return
// and the batchMatMul declaration are all missing.
857 void contractOperands(std::vector<Tensor> &inputs, std::vector<bool> &swap_free_and_contract,
// Single operand: nothing to contract, just pass it through.
861 return copyFrom(inputs[0], inputs[0].shape, output);
// Compute the broadcast batch shape of the two operands.
863 MatMulBCast bcast(inputs[0].shape, inputs[1].shape);
864 if (!bcast.IsValid())
866 throw std::runtime_error{
"Einsum: Invalid broadcasting dimensions"};
870 reshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs);
872 reshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs);
873 Shape old_output_shape = bcast.output_batch_shape();
875 for (
int i = 0; i < old_output_shape.DimensionsCount(); i++)
// Append each operand's free extent after the batch dims; the free axis
// position depends on whether that operand was emitted swapped.
880 for (
size_t i = 0; i <
inputs.size(); ++i)
882 const int32_t free_axis =
883 inputs[i].shape.DimensionsCount() - (swap_free_and_contract[i] ? 1 : 2);
884 output_shape.SetDim(i + old_output_shape.DimensionsCount(),
inputs[i].shape.Dims(free_axis));
// Swapped operands enter the matmul adjointed (note rhs is negated).
886 bool adj_x = swap_free_and_contract[0];
887 bool adj_y = !swap_free_and_contract[1];
// Degenerate contraction: result is all zeros.
893 if (lhs.shape.FlatSize() == 0 || rhs.shape.FlatSize() == 0)
895 functor::SetZeroFunctor<Eigen::ThreadPoolDevice, float> set_zero;
902 reshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped);
907 batchMatMul.prepare(lhs.shape, rhs.shape, adj_x, adj_y,
false);
908 batchMatMul(lhs.shape, lhs.base<
float>(), rhs.shape, rhs.base<
float>(), adj_x, adj_y,
909 output_reshaped.shape, output_reshaped.base<
float>());
// Views `input` as a rank-3 tensor [batch_size, rows, cols] for the batched
// matmul; the last two axes of `input` become rows/cols.
//
// NOTE(review): the body is truncated in this extraction (fused lines
// 915-922 are missing) — only the rank computation is visible.
912 void reshapeToRank3(
const Tensor &input,
int batch_size,
Tensor *output)
914 const int rank =
input.shape.DimensionsCount();
// Cached parse state, filled once by parseEquation().
// NOTE(review): fused lines 925-926/929 (presumably _input_labels,
// _output_labels, _input_label_counts, _output_label_counts) are missing
// from this extraction.
// Dimension type (free/batch/contract/reduce/broadcast) per label.
924 std::vector<DimensionType> _label_types;
// Whether each input subscript contained "...".
927 std::vector<bool> _input_has_ellipsis;
// Whether the output subscript contained "...".
928 bool _output_has_ellipsis =
false;
// Owns every scratch buffer allocated during operator(); cleared at its end.
930 std::vector<std::unique_ptr<float[]>> temp_operand;