18#include "kernels/BatchMatMul.h"
19#include "kernels/Utils.h"
23#include <tensorflow/lite/kernels/internal/reference/transpose.h>
// SwapRowColumnDims: builds a copy of `shape` whose last two dimensions are
// exchanged ([..., A, B] -> [..., B, A]); every leading dimension is copied
// unchanged from `shape`.
// NOTE(review): the function's braces and its `return swapped_shape;`
// statement are not present in this view, so the fragment does not compile as
// shown. Presumably the swapped copy is returned; confirm against the full file.
// NOTE(review): assumes shape.DimensionsCount() >= 2; with fewer dimensions
// the `dims - 2` index below is negative.
28tflite::RuntimeShape SwapRowColumnDims(
const tflite::RuntimeShape &shape)
30 tflite::RuntimeShape swapped_shape(shape);
// The two innermost axes are at indices dims - 2 and dims - 1.
31 const int32_t dims = shape.DimensionsCount();
// Values are read from the untouched `shape`, so the second SetDim is not
// affected by the first.
32 swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
33 swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
// Tail of the BatchMatMul constructor: hands the inputs {x, y}, the tensors
// {output, x_tmp, y_tmp} and `params` to KernelWithParams. The beginning of the
// parameter list and the constructor body are not visible in this view.
// NOTE(review): `¶ms` below appears to be a mis-encoded `&params` (the
// `&para` HTML entity rendered as a pilcrow). As written it will not compile;
// it should read `const BatchMatMulParams &params` to match the `params` used
// on the next line.
45 Tensor *y_tmp,
const BatchMatMulParams ¶ms)
46 : KernelWithParams({x, y}, {
output, x_tmp, y_tmp}, params)
// BatchMatMul::configure: validates the two operands and derives the output
// shape. The visible steps are:
//  - reject non-FLOAT32 inputs;
//  - set up the lhs/rhs scratch tensors;
//  - check that the batch dimensions are broadcast-compatible;
//  - compute the contraction dimensions;
//  - fill in the output's batch, row and column extents.
// NOTE(review): many lines are missing from this view: the opening brace, the
// definitions of `lhs`, `rhs` and `output_shape`, the loop bodies, the
// contraction-dim check and the final output resize. Presumably lhs/rhs are
// x()/y(), and output_shape is a luci_interpreter::RuntimeShape. Confirm
// against the full file.
50void BatchMatMul::configure()
54 auto adj_x = params().adj_x;
55 auto adj_y = params().adj_y;
// Only FLOAT32 operands are accepted.
// NOTE(review): assert() is removed when NDEBUG is defined, so in release
// builds an unsupported dtype is not rejected here.
58 if (lhs->element_type() != DataType::FLOAT32 || rhs->element_type() != DataType::FLOAT32)
59 assert(
false &&
"Unsupported type.");
63 auto lhs_rank = lhs->shape().num_dims();
64 auto rhs_rank = rhs->shape().num_dims();
// Prepare the temp_lhs()/temp_rhs() scratch tensors. They are presumably sized
// to hold the row/column-transposed operands used by execute(). The remaining
// arguments of this call are not visible here.
68 auto lhs_scratchpad = temp_lhs();
69 auto rhs_scratchpad = temp_rhs();
70 luci_interpreter_pal::SetupScratchpadTensor(lhs_scratchpad, rhs_scratchpad,
getTensorShape(lhs),
// The output has the larger of the two input ranks.
73 auto output_rank = std::max(lhs_rank, rhs_rank);
// Bring both input shapes up to output_rank. TFLite's ExtendedShape prepends
// size-1 dimensions, so the shapes can be compared axis by axis.
// NOTE(review): the loops below use output_rank - 2, which assumes both inputs
// have rank >= 2.
75 auto extended_lhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank,
getTensorShape(lhs));
76 auto extended_rhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank,
getTensorShape(rhs));
// Compare the batch dimensions (all axes except the last two). The handling of
// a mismatch is not visible here; presumably one side must be 1.
79 for (
int i = 0; i < output_rank - 2; ++i)
81 const int lhs_dim = extended_lhs_shape.Dims(i);
82 const int rhs_dim = extended_rhs_shape.Dims(i);
83 if (lhs_dim != rhs_dim)
// Contraction dimension of each operand:
//  - lhs: axis -2 when adj_x, else axis -1;
//  - rhs: axis -1 when adj_y, else axis -2.
// The variables receiving these values and their equality check are not
// visible here.
94 adj_x ? extended_lhs_shape.Dims(output_rank - 2) : extended_lhs_shape.Dims(output_rank - 1);
96 adj_y ? extended_rhs_shape.Dims(output_rank - 1) : extended_rhs_shape.Dims(output_rank - 2);
// Broadcast batch dimensions. The lhs extent is kept unless lhs is 1 and rhs
// differs, in which case the rhs extent is used. Where broadcast_dim is stored
// is not visible; presumably it goes into output_shape.dim(i).
101 for (
int i = 0; i < output_rank - 2; ++i)
103 const int lhs_dim = extended_lhs_shape.Dims(i);
104 const int rhs_dim = extended_rhs_shape.Dims(i);
105 int broadcast_dim = lhs_dim;
106 if ((lhs_dim != rhs_dim) && (lhs_dim == 1))
108 broadcast_dim = rhs_dim;
// Output rows come from lhs (axis -1 if adj_x, else axis -2). Output columns
// come from rhs (axis -2 if adj_y, else axis -1).
113 int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
114 int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
116 output_shape.dim(output_rank - 2) = extended_lhs_shape.Dims(lhs_rows_index);
117 output_shape.dim(output_rank - 1) = extended_rhs_shape.Dims(rhs_cols_index);
// Body of a helper that writes `tensor_in` into `tensor_out` with its two
// innermost axes swapped, using tflite::reference_ops::Transpose. The signature
// is not visible here; presumably it is
// `void TransposeRowsColumns(const Tensor *tensor_in, Tensor *tensor_out)`.
// NOTE(review): the definitions of `shape` and `transposed_shape` are not
// visible. Presumably shape = getTensorShape(tensor_in) and transposed_shape
// starts as a copy of it. Confirm against the full file.
// NOTE(review): tflite::TransposeParams::perm has a fixed capacity in TFLite;
// confirm that `rank` cannot exceed it.
126 tflite::TransposeParams params;
127 int rank = shape.DimensionsCount();
128 params.perm_count = rank;
// Leading (batch) axes keep their position. The loop body is not visible here;
// presumably it sets params.perm[i] = i.
129 for (
int i = 0; i < rank - 2; ++i)
// Swap the two innermost axes in the permutation...
134 params.perm[rank - 2] = rank - 1;
135 params.perm[rank - 1] = rank - 2;
// ...and in the destination shape.
136 transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
137 transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
// Only FLOAT32 data is transposed. Other dtypes presumably reach the assert
// below (the default label is not visible), which is active only when NDEBUG is
// not defined.
// NOTE(review): the assert message misspells "support"; it is a runtime string
// and is intentionally left unchanged here.
138 switch (tensor_in->element_type())
140 case DataType::FLOAT32:
141 tflite::reference_ops::Transpose(params, shape, getTensorData<float>(tensor_in),
142 transposed_shape, getTensorData<float>(tensor_out));
145 assert(
false &&
"Only suppport fp32 BatchMatMul for now.");
// BatchMatMul::execute: chooses the operand tensors and shapes according to
// adj_x/adj_y, then dispatches on the input dtype (FLOAT32 only).
// NOTE(review): most of this function is not visible in this view, including:
//  - the definitions of lhs, rhs, orig_lhs_shape and orig_rhs_shape;
//  - the code that fills the temp tensors;
//  - most of the kernel call.
// The comments below describe only the lines shown.
149void BatchMatMul::execute()
const
154 bool adj_x = params().adj_x;
155 bool adj_y = params().adj_y;
// rhs is used directly when adj_y is set; otherwise the temp_rhs() scratch
// tensor is used. lhs is the reverse: temp_lhs() is used only when adj_x is
// set. The scratch tensors presumably hold row/column-transposed copies.
160 auto rhs_tensor = adj_y ? rhs : temp_rhs();
161 auto lhs_tensor = adj_x ? temp_lhs() : lhs;
// Each shape is the original one when its adj flag is set, otherwise it has its
// last two dimensions swapped.
// NOTE(review): for rhs the scratch tensor pairs with the swapped shape, while
// for lhs the scratch tensor pairs with the original shape. This asymmetry is
// presumably intentional (it appears to mirror upstream TFLite's batch_matmul
// kernel); confirm against the full BatchMatMul call.
170 tflite::RuntimeShape rhs_shape = adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape);
171 tflite::RuntimeShape lhs_shape = adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape);
// Only FLOAT32 is dispatched; the result is written to output(). Other dtypes
// presumably reach the assert (the default label is not visible), which is
// compiled out under NDEBUG.
173 switch (x()->element_type())
175 case DataType::FLOAT32:
178 getTensorData<float>(
output()));
181 assert(
false &&
"Unsupported type.");
// NOTE(review): the following lines look like a signature/symbol index appended
// during extraction rather than real definitions. None of them has a body, and
// one is an unterminated declaration. If compiled:
//  - `#define LUCI_INTERPRETER_CHECK(cond)` would make the check macro expand
//    to nothing;
//  - `¶ms` appears to be a mis-encoded `&params`.
// Confirm against the real file and remove if this is residue.
BatchMatMul(const Tensor *x, const Tensor *y, Tensor *output, Tensor *x_tmp, Tensor *y_tmp, const BatchMatMulParams ¶ms)
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
void TransposeRowsColumns(const Tensor *tensor_in, Tensor *tensor_out)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void BatchMatMul(const tflite::RuntimeShape &lhs_shape, const float *lhs_data, const tflite::RuntimeShape &rhs_shape, const float *rhs_data, const tflite::RuntimeShape &output_shape, float *output_data)