19#include "kernels/Utils.h"
21#include "PALConcatenation.h"
38 const auto *
options =
cur_op->builtin_options_as_ConcatenationOptions();
42 axis += Tensor::num_dims(output);
44 const auto input_sizes =
cur_op->inputs()->size();
47 std::vector<luci_interpreter::RuntimeShape>
all_shape;
50 for (int32_t
i = 0;
i < input_sizes; ++
i)
53 const auto *
tensor = runtime_graph->getCircleTensorByIndex(input_index);
55 const auto *
tensor_data = runtime_graph->getDataByTensor(tensor);
57 tensor_data = runtime_graph->getConstDataByTensor(tensor);
72 auto *
output_data =
reinterpret_cast<T *
>(runtime_graph->getDataByTensor(output));
86 const int num_inputs =
cur_op->inputs()->size();
89 auto input_index =
cur_op->inputs()->operator[](0);
92 assert(input_index != -1);
98 const auto *params =
cur_op->builtin_options_as_ConcatenationOptions();
103 int axis = params->axis();
105 axis += Tensor::num_dims(
t0);
108 for (
int i = 1;
i < num_inputs; ++
i)
110 input_index =
cur_op->inputs()->operator[](
i);
119 for (
int i = 1;
i < num_inputs; ++
i)
121 input_index =
cur_op->inputs()->operator[](
i);
123 if (Tensor::element_type(tensor) == DataType::S8)
126 Tensor::quantized_dimension(output));
139 int num_inputs =
cur_op->inputs()->size();
142 const auto input_index =
cur_op->inputs()->operator[](0);
143 assert(input_index != -1);
146 switch (Tensor::element_type(
t0))
149 case DataType::FLOAT32:
165 assert(
false &&
"Unsupported type.");
const circle::Tensor * getCircleTensorByIndex(int32_t index)
#define LUCI_INTERPRETER_CHECK(cond)
const T * data(const std::vector< T, Alloc > &v)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
void Concatenation(const ConcatenationParams &params, const luci_interpreter::RuntimeShape *const *input_shapes, const Scalar *const *input_data, const luci_interpreter::RuntimeShape &output_shape, Scalar *output_data)
void execute_kernel_CircleConcatenation(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
RuntimeGraph BaseRuntimeGraph
void configure_kernel_CircleConcatenation(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
T must_cast(loco::Node *node)
FusedActFunc luci_actfunc(const circle::ActivationFunctionType type)