19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/transpose.h>
42 assert(
input()->shape().num_dims() <= 5);
43 assert(
input()->element_type() ==
output()->element_type());
45 assert(
perm()->shape().num_dims() == 1);
46 assert(
perm()->shape().dim(0) == dims);
49 for (
int i = 0;
i < dims;
i++)
60 tflite::TransposeParams params{};
63 params.perm_count =
size;
66 switch (
input()->element_type())
68 case DataType::FLOAT32:
84 throw std::runtime_error(
"luci-intp Transpose Unsupported type.");
void resize(const Shape &new_shape)
const Shape & shape() const
const Tensor * input() const
Transpose(const Tensor *input, const Tensor *perm, Tensor *output)
void configure() override
void execute() const override
const Tensor * perm() const
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
T must_cast(loco::Node *node)