19#include "kernels/Utils.h"
21#include "PALConcatenation.h"
30void evalGeneric(
const circle::Operator *cur_op,
BaseRuntimeGraph *runtime_graph)
32 const auto output_index = cur_op->outputs()->operator[](0);
34 assert(output_index != -1);
36 auto output = runtime_graph->getCircleTensorByIndex(output_index);
38 const auto *
options = cur_op->builtin_options_as_ConcatenationOptions();
42 axis += Tensor::num_dims(output);
44 const auto input_sizes = cur_op->inputs()->size();
46 std::vector<const T *> all_input_data;
47 std::vector<luci_interpreter::RuntimeShape> all_shape;
48 std::vector<luci_interpreter::RuntimeShape *> all_shape_ptr;
50 for (int32_t i = 0; i < input_sizes; ++i)
53 const auto *
tensor = runtime_graph->getCircleTensorByIndex(input_index);
55 const auto *tensor_data = runtime_graph->getDataByTensor(tensor);
56 if (tensor_data ==
nullptr)
57 tensor_data = runtime_graph->getConstDataByTensor(tensor);
59 auto *
data =
reinterpret_cast<const T *
>(tensor_data);
63 all_input_data.push_back(
data);
64 all_shape.push_back(runtime_shape);
69 all_shape_ptr.push_back(&shape);
72 auto *
output_data =
reinterpret_cast<T *
>(runtime_graph->getDataByTensor(output));
76 params.inputs_count = all_shape.size();
86 const int num_inputs = cur_op->inputs()->size();
89 auto input_index = cur_op->inputs()->operator[](0);
90 auto output_index = cur_op->outputs()->operator[](0);
92 assert(input_index != -1);
93 assert(output_index != -1);
98 const auto *params = cur_op->builtin_options_as_ConcatenationOptions();
103 int axis = params->axis();
105 axis += Tensor::num_dims(t0);
108 for (
int i = 1; i < num_inputs; ++i)
110 input_index = cur_op->inputs()->operator[](i);
119 for (
int i = 1; i < num_inputs; ++i)
121 input_index = cur_op->inputs()->operator[](i);
123 if (Tensor::element_type(tensor) == DataType::S8)
126 Tensor::quantized_dimension(output));
139 int num_inputs = cur_op->inputs()->size();
142 const auto input_index = cur_op->inputs()->operator[](0);
143 assert(input_index != -1);
146 switch (Tensor::element_type(t0))
149 case DataType::FLOAT32:
150 evalGeneric<float>(cur_op, runtime_graph);
155 evalGeneric<int8_t>(cur_op, runtime_graph);
159 evalGeneric<int32_t>(cur_op, runtime_graph);
162 evalGeneric<int64_t>(cur_op, runtime_graph);
165 assert(
false &&
"Unsupported type.");
const circle::Tensor * getCircleTensorByIndex(int32_t index)
#define LUCI_INTERPRETER_CHECK(cond)
const T * data(const std::vector< T, Alloc > &v)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
void Concatenation(const ConcatenationParams ¶ms, const luci_interpreter::RuntimeShape *const *input_shapes, const Scalar *const *input_data, const luci_interpreter::RuntimeShape &output_shape, Scalar *output_data)
void execute_kernel_CircleConcatenation(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
RuntimeGraph BaseRuntimeGraph
void configure_kernel_CircleConcatenation(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)