18#include "kernels/BatchMatMul.h"
19#include "kernels/Utils.h"
21#include "PALBatchMatMul.h"
23#include <tensorflow/lite/kernels/internal/reference/transpose.h>
30tflite::RuntimeShape SwapRowColumnDims(
const tflite::RuntimeShape &shape)
32 tflite::RuntimeShape swapped_shape(shape);
33 const int32_t
dims = shape.DimensionsCount();
34 swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
35 swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
60 if (lhs->element_type() != DataType::FLOAT32 || rhs->element_type() != DataType::FLOAT32)
61 throw std::runtime_error(
"luci-intp BatchMatMul(1) Unsupported type.");
65 auto lhs_rank = lhs->shape().num_dims();
66 auto rhs_rank = rhs->shape().num_dims();
128 tflite::TransposeParams params;
129 int rank = shape.DimensionsCount();
130 params.perm_count = rank;
131 for (
int i = 0;
i < rank - 2; ++
i)
136 params.perm[rank - 2] = rank - 1;
137 params.perm[rank - 1] = rank - 2;
142 case DataType::FLOAT32:
147 throw std::runtime_error(
"Only suppport fp32 BatchMatMul for now.");
175 switch (
x()->element_type())
177 case DataType::FLOAT32:
183 throw std::runtime_error(
"luci-intp BatchMatMul(2) Unsupported type.");
const BatchMatMulParams & params() const
void resize(const Shape &new_shape)
void configure() override
void execute() const override
BatchMatMul(const Tensor *x, const Tensor *y, Tensor *output, Tensor *x_tmp, Tensor *y_tmp, const BatchMatMulParams ¶ms)
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
std::vector< int > dims(const std::string &src)
void TransposeRowsColumns(const Tensor *tensor_in, Tensor *tensor_out)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void BatchMatMul(const tflite::RuntimeShape &lhs_shape, const float *lhs_data, const tflite::RuntimeShape &rhs_shape, const float *rhs_data, const tflite::RuntimeShape &output_shape, float *output_data)
T must_cast(loco::Node *node)