19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/concatenation.h>
38 const int num_inputs =
_inputs.size();
50 int32_t sum_axis = t0->
shape().
dim(axis);
51 for (
int i = 1; i < num_inputs; ++i)
56 for (
int d = 0; d < t0->
shape().num_dims(); ++d)
60 sum_axis += tensor->shape().dim(axis);
74 for (
auto current_tensor :
_inputs)
76 if (current_tensor->element_type() == DataType::S8)
79 output()->quantized_dimension());
82 current_tensor->scales().size());
92 switch (
_inputs[0]->element_type())
94 case DataType::FLOAT32:
101 evalGeneric<int8_t>();
104 evalGeneric<int32_t>();
107 evalGeneric<int64_t>();
110 throw std::runtime_error(
"luci-intp Concatenation Unsupported type.");
114template <
typename T>
void Concatenation::evalGeneric()
const
121 tflite::ConcatenationParams
params{};
124 tflite::reference_ops::Concatenation(
params, inputs.shapes(), inputs.data(),
128void Concatenation::evalQuantized()
const
134 VectorOfQuantizedTensors<true> inputs(
_inputs);
135 tflite::ConcatenationParams
params{};
143 tflite::reference_ops::ConcatenationWithScaling(
params,
inputs.shapes(),
inputs.data(),
145 getTensorData<uint8_t>(
output()));
const std::vector< const Tensor * > _inputs
const ConcatenationParams _params
const ConcatenationParams & params() const
void resize(const Shape &new_shape)
const Shape & shape() const
DataType element_type() const
int32_t zero_point() const
Concatenation(std::vector< const Tensor * > inputs, Tensor *output, const ConcatenationParams ¶ms)
void execute() const override
void configure() override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)