18#include "kernels/Utils.h"
// Presumably the position of the operator's input tensor within inputs().
// NOTE(review): no visible statement in this listing reads kInputTensor; the
// code below indexes inputs() with the literal 0 instead. Confirm whether the
// constant is still needed.
25constexpr int kInputTensor = 0;
// UnpackImpl<T>: for each of the operator's `output_count` outputs, fetches the
// output tensor's data buffer and copies elements of type T into it from the
// input tensor's buffer.
//
// NOTE(review): this listing is a lossy extraction. Original line numbers are
// fused onto the statements, and several lines are absent (the template header,
// all braces, and the declarations of outer_size, copy_size, output_size,
// input_ptr and output_ptr). The comments below describe only what the visible
// statements show; anything else is marked as an assumption.
//
// Parameters:
//   cur_op        - operator whose outputs() list supplies the output tensor indices.
//   input         - tensor whose shape and data are read.
//   output_count  - number of entries of cur_op->outputs() that are processed.
//   axis          - axis index; extents before it feed outer_size and extents
//                   after it feed copy_size.
//   runtime_graph - used to resolve tensor indices to tensors and tensors to data.
28void UnpackImpl(
const circle::Operator *cur_op,
const circle::Tensor *input,
int output_count,
29 int axis, RuntimeGraph *runtime_graph)
// Resolve output 0; only its shape is used (for output_dims below).
31 const auto output0_index = cur_op->outputs()->operator[](0);
32 assert(output0_index != -1);
34 const auto output0 = runtime_graph->getCircleTensorByIndex(output0_index);
35 assert(output0 !=
nullptr);
// Shapes of the input and of output 0; `dimensions` is the input's rank.
37 const auto input_dims = Tensor::tensor_shape(input);
38 const auto output_dims = Tensor::tensor_shape(output0);
39 const int dimensions = input_dims.size();
// NOTE(review): the condition guarding this adjustment is not visible here;
// presumably it applies only when axis is negative. Confirm against the full file.
43 axis += input_dims.size();
// outer_size accumulates the product of the input extents that precede `axis`.
47 for (
int i = 0; i < axis; ++i)
49 outer_size *= input_dims[i];
// copy_size accumulates the product of the input extents that follow `axis`.
52 for (
int i = axis + 1; i < dimensions; ++i)
54 copy_size *= input_dims[i];
// output_size accumulates the product of all extents of output 0.
// NOTE(review): no visible statement reads output_size afterwards. Also, `i` is
// an int compared against output_dims.size(); if size() is unsigned this is a
// signed/unsigned comparison. Confirm both.
57 for (
int i = 0; i < output_dims.size(); ++i)
59 output_size *= output_dims[i];
62 const T *
input_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(input));
// Process each output tensor in turn.
64 for (
int i = 0; i < output_count; ++i)
66 const auto output_index = cur_op->outputs()->operator[](i);
67 assert(output_index != -1);
69 const auto t = runtime_graph->getCircleTensorByIndex(output_index);
// NOTE(review): this re-checks `output0`, not the tensor `t` fetched on the
// previous statement. It looks like a copy-paste slip; `t` is probably the
// pointer that should be asserted non-null before it is used below.
70 assert(output0 !=
nullptr);
71 T *
output_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(t));
72 for (
int k = 0; k < outer_size; ++k)
// `loc` combines the outer index k, the output index i and copy_size.
// NOTE(review): presumably the element offset of this slice within input_data;
// the statement that consumes `loc` is not visible in this listing.
75 int loc = k * output_count * copy_size + i * copy_size;
// Copy copy_size elements from input_ptr to output_ptr.
77 for (
int j = 0; j < copy_size; ++j)
78 output_ptr[j] = input_ptr[j];
// NOTE(review): fragment. The enclosing function header is not visible at this
// point in the listing; presumably this is the body of
// configure_kernel_CircleUnpack. Confirm against the full file.
//
// Read the indices of input 0 and output 0 and assert that neither is -1.
86 const auto input_index = cur_op->inputs()->operator[](0);
87 const auto output_index = cur_op->outputs()->operator[](0);
89 assert(input_index != -1);
90 assert(output_index != -1);
// `input` and `output` are declared on lines that are absent from this listing.
95 assert(input !=
nullptr);
96 assert(output !=
nullptr);
98 const auto *options = cur_op->builtin_options_as_UnpackOptions();
// Walk every dimension of the input and compare its position with the axis
// from the Unpack options. The statements executed for the `==` and `<` cases
// are not visible here.
103 for (
int i = 0; i < Tensor::num_dims(input); ++i)
105 if (i == options->axis())
108 if (i < options->axis())
// NOTE(review): fragment. The enclosing function header is not visible at this
// point in the listing; presumably this is the body of
// execute_kernel_CircleUnpack. Confirm against the full file.
//
// Read the index of input 0 and assert that it is not -1.
121 const auto input_index = cur_op->inputs()->operator[](0);
122 assert(input_index != -1);
// `input` is declared on a line that is absent from this listing.
125 assert(input !=
nullptr);
127 const auto type = Tensor::element_type(input);
129 const auto *options = cur_op->builtin_options_as_UnpackOptions();
// FLOAT32 is the only case label visible here: it runs UnpackImpl<float> with
// the output count (num) and axis taken from the Unpack options.
// NOTE(review): the switch header is not visible; presumably it switches on `type`.
134 case DataType::FLOAT32:
136 UnpackImpl<float>(cur_op, input, options->num(), options->axis(), runtime_graph);
// Fails the assertion unconditionally.
// NOTE(review): presumably the default branch for every other element type;
// with assertions compiled out this statement has no effect.
141 assert(
false &&
"Unsupported type");
// NOTE(review): the lines below carry no bodies and no fused line numbers. They
// look like cross-reference signatures appended by the tool that produced this
// listing, not definitions belonging to this file. Confirm before relying on them.
const circle::Tensor * getCircleTensorByIndex(int32_t index)
#define LUCI_INTERPRETER_CHECK(cond)
void execute_kernel_CircleUnpack(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void configure_kernel_CircleUnpack(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
const loco::Dimension & dim(uint32_t axis) const