53{
58
59
60 if (lhs->element_type() != DataType::FLOAT32 || rhs->element_type() != DataType::FLOAT32)
61 throw std::runtime_error("luci-intp BatchMatMul(1) Unsupported type.");
62
64
65 auto lhs_rank = lhs->shape().num_dims();
66 auto rhs_rank = rhs->shape().num_dims();
69
70 auto lhs_scratchpad = temp_lhs();
71 auto rhs_scratchpad = temp_rhs();
72 luci_interpreter_pal::SetupScratchpadTensor(lhs_scratchpad, rhs_scratchpad,
getTensorShape(lhs),
74
75 auto output_rank = std::max(lhs_rank, rhs_rank);
76
77 auto extended_lhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank,
getTensorShape(lhs));
78 auto extended_rhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank,
getTensorShape(rhs));
79
80
81 for (int i = 0; i < output_rank - 2; ++i)
82 {
83 const int lhs_dim = extended_lhs_shape.Dims(i);
84 const int rhs_dim = extended_rhs_shape.Dims(i);
85 if (lhs_dim != rhs_dim)
86 {
87 if (lhs_dim != 1)
88 {
90 }
91 }
92 }
93
94
95 int accum_dim_lhs =
96 adj_x ? extended_lhs_shape.Dims(output_rank - 2) : extended_lhs_shape.Dims(output_rank - 1);
97 int accum_dim_rhs =
98 adj_y ? extended_rhs_shape.Dims(output_rank - 1) : extended_rhs_shape.Dims(output_rank - 2);
100
102
103 for (int i = 0; i < output_rank - 2; ++i)
104 {
105 const int lhs_dim = extended_lhs_shape.Dims(i);
106 const int rhs_dim = extended_rhs_shape.Dims(i);
107 int broadcast_dim = lhs_dim;
108 if ((lhs_dim != rhs_dim) && (lhs_dim == 1))
109 {
110 broadcast_dim = rhs_dim;
111 }
113 }
114
115 int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
116 int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
117
118 output_shape.dim(output_rank - 2) = extended_lhs_shape.Dims(lhs_rows_index);
119 output_shape.dim(output_rank - 1) = extended_rhs_shape.Dims(rhs_cols_index);
120
122}
void resize(const Shape &new_shape)
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)