// NOTE(review): the numerals at the start of many lines (29, 31, 33, ...) look
// like listing line numbers fused into the text, and several lines appear to be
// missing: there are no braces, and `T`, `cur_op`, `runtime_graph`, `axis`,
// `outer_size`, `copy_size`, `input_size` and `output_ptr` are used below
// without a visible declaration. The call sites pass five arguments
// (input, output, cur_op, runtime_graph, output_data), so two parameters are
// probably absent from this parameter list as well. Confirm against the real
// file before relying on this text.
//
// For each of `values_count` inputs, copies `outer_size` chunks of `copy_size`
// elements of type T from that input's data to `output_ptr`.
//   input0          - tensor whose shape is used for the size sanity check.
//   output          - tensor whose shape drives the outer/copy size computation.
//   output_data_raw - raw output bytes; passed to kernels::getTensorData<T> to
//                     obtain the T* that is asserted non-null below.
29void packImpl(
const circle::Tensor *input0,
const circle::Tensor *output,
31 uint8_t *output_data_raw)
// Operator parameters come from the PackOptions table of the current operator.
33 const auto *
options = cur_op->builtin_options_as_PackOptions();
// Number of input tensors that will be read in the main loop.
35 const int values_count =
options->values_count();
37 const int dimensions = Tensor::num_dims(output);
// Shapes are wrapped so they can be indexed and queried for size() below.
39 const auto input_dims =
wrap(input0->shape());
40 const auto output_dims =
wrap(
output->shape());
// outer_size accumulates the output dimensions that precede `axis`.
// NOTE(review): the initial value of outer_size is not visible (presumably 1).
48 for (
int i = 0; i < axis; ++i)
49 outer_size *= output_dims[i];
// copy_size accumulates the output dimensions that follow `axis`; the
// dimension at `axis` itself is skipped by both loops.
52 for (
int i = axis + 1; i < dimensions; ++i)
53 copy_size *= output_dims[i];
// input_size accumulates every dimension of the first input's shape.
56 for (
int i = 0; i < input_dims.size(); ++i)
57 input_size *= input_dims[i];
// Sanity check: one input must hold exactly outer_size chunks of copy_size
// elements. This is an assert, so it only fires where asserts are enabled.
59 assert(input_size == copy_size * outer_size);
61 T *
output_data = kernels::getTensorData<T>(output_data_raw);
62 assert(output_data !=
nullptr);
// Walk every input of the operator; i is the input's position in the op.
64 for (
int i = 0; i < values_count; ++i)
// Resolve the i-th input: index -> tensor -> data pointer typed as T.
66 const auto input_index = cur_op->inputs()->operator[](i);
67 assert(input_index != -1);
68 const auto input = runtime_graph->getCircleTensorByIndex(input_index);
70 auto input_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(input));
71 assert(input_data !=
nullptr);
// k selects one chunk of copy_size consecutive elements inside this input.
72 for (
int k = 0; k < outer_size; ++k)
74 const T *input_ptr =
input_data + copy_size * k;
// loc is an element offset: each k advances by values_count chunks and each
// input i occupies the i-th chunk within that group.
75 int loc = k * values_count * copy_size + i * copy_size;
// NOTE(review): the definition of output_ptr is not visible here; presumably
// it is output_data + loc, since loc is not used anywhere else in view.
// Element-wise copy of one chunk.
77 for (
int j = 0; j < copy_size; ++j)
78 output_ptr[j] = input_ptr[j];
// NOTE(review): these statements appear to be the body of
// execute_kernel_CirclePack, but its signature line is not visible in this
// listing. The leading numerals (92, 93, ...) look like listing line numbers
// fused into the text. Braces, the declarations of `input`, `output` and
// `output_data`, every `case` label except FLOAT32, and any `break`
// statements are also not visible. Confirm against the real file.
//
// Looks up the operator's first input and first output, then dispatches to
// packImpl<T> according to the element type of the output tensor.
// Only the first input index is read here; packImpl fetches each input of the
// operator itself through cur_op and runtime_graph.
92 const auto input_index = cur_op->inputs()->operator[](0);
93 const auto output_index = cur_op->outputs()->operator[](0);
// Both indices are compared against -1 before use (assert-only checks).
94 assert(output_index != -1);
95 assert(input_index != -1);
100 assert(output_data !=
nullptr);
// The element type of the output tensor selects the template instantiation.
102 switch (Tensor::element_type(output))
105 case DataType::FLOAT32:
106 packImpl<float>(input, output, cur_op, runtime_graph, output_data);
// NOTE(review): the case labels for the calls below are missing from this
// listing; presumably each matches the DataType of its template argument.
111 packImpl<int8_t>(input, output, cur_op, runtime_graph, output_data);
114 packImpl<uint8_t>(input, output, cur_op, runtime_graph, output_data);
118 packImpl<int32_t>(input, output, cur_op, runtime_graph, output_data);
121 packImpl<int64_t>(input, output, cur_op, runtime_graph, output_data);
// Any other element type trips this assert; presumably the default branch.
124 assert(
false &&
"Unsupported types");
const circle::Tensor * getCircleTensorByIndex(int32_t index)
uint8_t * getDataByTensor(const circle::Tensor *raw_tensor)
void execute_kernel_CirclePack(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
RuntimeGraph BaseRuntimeGraph
void configure_kernel_CirclePack(const circle::Operator *, BaseRuntimeGraph *)
VectorWrapper< T > wrap(const flatbuffers::Vector< T > *vec)
This file contains utility macro.