// NOTE(review): this is an excerpt. The enclosing function header, braces and several
// intermediate statements are not present here, so the comments below describe only the
// statements that are actually visible.
//
// Type guard: a std::runtime_error is thrown unless both lhs and rhs have element type FLOAT32.
60 if (lhs->element_type() != DataType::FLOAT32 || rhs->element_type() != DataType::FLOAT32)
61 throw std::runtime_error(
"luci-intp BatchMatMul(1) Unsupported type.");
// Number of dimensions of each operand.
65 auto lhs_rank = lhs->shape().num_dims();
66 auto rhs_rank = rhs->shape().num_dims();
// The tensors returned by temp_lhs()/temp_rhs() are passed to the PAL scratchpad setup helper.
// NOTE(review): the remaining arguments of the SetupScratchpadTensor call are cut off in this
// excerpt.
70 auto lhs_scratchpad = temp_lhs();
71 auto rhs_scratchpad = temp_rhs();
72 luci_interpreter_pal::SetupScratchpadTensor(lhs_scratchpad, rhs_scratchpad,
getTensorShape(lhs),
// The output rank is the larger of the two operand ranks.
75 auto output_rank = std::max(lhs_rank, rhs_rank);
// Both operand shapes are extended to output_rank so they can be indexed with the same
// dimension index below (presumably left-padded with 1s by ExtendedShape -- confirm against
// the TFLite RuntimeShape documentation).
77 auto extended_lhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank,
getTensorShape(lhs));
78 auto extended_rhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank,
getTensorShape(rhs));
// First pass over every dimension except the last two: compares the lhs and rhs extents.
// NOTE(review): the body executed on a mismatch is not visible in this excerpt.
81 for (
int i = 0; i < output_rank - 2; ++i)
83 const int lhs_dim = extended_lhs_shape.Dims(i);
84 const int rhs_dim = extended_rhs_shape.Dims(i);
85 if (lhs_dim != rhs_dim)
// Right-hand sides of two statements whose left-hand sides are not visible here.
// lhs: extent at index output_rank - 2 when adj_x is set, otherwise at output_rank - 1.
// rhs: extent at index output_rank - 1 when adj_y is set, otherwise at output_rank - 2.
// NOTE(review): presumably these are the contracted (accumulation) extents of each operand,
// to be checked for equality -- confirm against the full file.
96 adj_x ? extended_lhs_shape.Dims(output_rank - 2) : extended_lhs_shape.Dims(output_rank - 1);
98 adj_y ? extended_rhs_shape.Dims(output_rank - 1) : extended_rhs_shape.Dims(output_rank - 2);
// Second pass over every dimension except the last two: broadcast_dim starts as the lhs
// extent and is replaced by the rhs extent when the two differ and the lhs extent is 1.
// NOTE(review): where broadcast_dim is stored is not visible in this excerpt.
103 for (
int i = 0; i < output_rank - 2; ++i)
105 const int lhs_dim = extended_lhs_shape.Dims(i);
106 const int rhs_dim = extended_rhs_shape.Dims(i);
107 int broadcast_dim = lhs_dim;
108 if ((lhs_dim != rhs_dim) && (lhs_dim == 1))
110 broadcast_dim = rhs_dim;
// The last two output dimensions: the row count is read from lhs and the column count from
// rhs; adj_x / adj_y choose which of the operand's last two axes supplies each value.
115 int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
116 int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
118 output_shape.dim(output_rank - 2) = extended_lhs_shape.Dims(lhs_rows_index);
119 output_shape.dim(output_rank - 1) = extended_rhs_shape.Dims(rhs_cols_index);
// NOTE(review): excerpt of a routine whose header and some statements are not visible here.
// The visible statements build a permutation that swaps the last two axes of `shape` and run
// the TFLite reference Transpose from tensor_in into tensor_out.
128 tflite::TransposeParams params;
129 int rank = shape.DimensionsCount();
130 params.perm_count = rank;
// Loop over every axis except the last two.
// NOTE(review): the loop body is not visible; presumably it sets params.perm[i] = i so the
// leading axes stay in place -- confirm against the full file.
131 for (
int i = 0; i < rank - 2; ++i)
// The last two permutation entries are exchanged.
136 params.perm[rank - 2] = rank - 1;
137 params.perm[rank - 1] = rank - 2;
// transposed_shape receives the last two extents of `shape` in swapped order.
138 transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
139 transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
// FLOAT32 case: data is read and written as float via getTensorData<float>.
142 case DataType::FLOAT32:
143 tflite::reference_ops::Transpose(params, shape, getTensorData<float>(tensor_in),
144 transposed_shape, getTensorData<float>(tensor_out));
// NOTE(review): the message below misspells "support" as "suppport". It is runtime text, so
// it is left untouched here; fix it in a behaviour-changing edit if callers or tests do not
// match on the exact string.
147 throw std::runtime_error(
"Only suppport fp32 BatchMatMul for now.");
// NOTE(review): excerpt of a routine whose header and some statements are not visible here.
//
// Operand selection is deliberately asymmetric in the visible code: rhs uses the tensor from
// temp_rhs() when adj_y is NOT set, whereas lhs uses the tensor from temp_lhs() when adj_x IS
// set. NOTE(review): presumably the temp tensors hold transposed copies and the downstream
// kernel expects a particular operand layout -- confirm against the PAL BatchMatMul.
162 auto rhs_tensor = adj_y ? rhs : temp_rhs();
163 auto lhs_tensor = adj_x ? temp_lhs() : lhs;
// Shapes follow the same asymmetry: the original shape is kept when the adj flag is set,
// otherwise SwapRowColumnDims is applied to it.
172 tflite::RuntimeShape rhs_shape = adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape);
173 tflite::RuntimeShape lhs_shape = adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape);
// Dispatch on the element type of x(). The FLOAT32 case writes float data into output().
// NOTE(review): the call that consumes these arguments is only partially visible.
175 switch (
x()->element_type())
177 case DataType::FLOAT32:
180 getTensorData<float>(
output()));
// Unsupported element types are reported with std::runtime_error.
183 throw std::runtime_error(
"luci-intp BatchMatMul(2) Unsupported type.");
void BatchMatMul(const tflite::RuntimeShape &lhs_shape, const float *lhs_data, const tflite::RuntimeShape &rhs_shape, const float *rhs_data, const tflite::RuntimeShape &output_shape, float *output_data)