20#include "kernels/Utils.h"
22#include "PALTranspose.h"
33 const int32_t dims = Tensor::num_dims(kernel.
input1());
49 const circle::Tensor *input = kernel.
input1();
50 const circle::Tensor *perm = kernel.
input2();
51 const circle::Tensor *output = kernel.
output();
62 switch (Tensor::element_type(input))
65 case DataType::FLOAT32:
67 kernels::getTensorData<float>(
tiso_data.input1_data),
69 kernels::getTensorData<float>(
tiso_data.output_data));
75 kernels::getTensorData<uint8_t>(
tiso_data.input1_data),
77 kernels::getTensorData<uint8_t>(
tiso_data.output_data));
81 assert(
false &&
"Unsupported type");
uint8_t * getConstDataByTensor(const circle::Tensor *raw_tensor)
const circle::Tensor * output() const
const circle::Tensor * input2() const
const circle::Tensor * input1() const
#define LUCI_INTERPRETER_CHECK(cond)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
void Transpose(const TransposeParams ¶ms, const luci_interpreter::RuntimeShape &unextended_input_shape, const T *input_data, const luci_interpreter::RuntimeShape &unextended_output_shape, T *output_data)
void execute_kernel_CircleTranspose(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void configure_kernel_CircleTranspose(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
T must_cast(loco::Node *node)
const loco::Dimension & dim(uint32_t axis) const