19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/concatenation.h>
38 const int num_inputs =
_inputs.size();
47 axis +=
t0->shape().num_dims();
51 for (
int i = 1;
i < num_inputs; ++
i)
56 for (
int d = 0; d <
t0->shape().num_dims(); ++d)
60 sum_axis += tensor->shape().dim(axis);
79 output()->quantized_dimension());
92 switch (
_inputs[0]->element_type())
94 case DataType::FLOAT32:
110 throw std::runtime_error(
"luci-intp Concatenation Unsupported type.");
114template <
typename T>
void Concatenation::evalGeneric()
const
121 tflite::ConcatenationParams
params{};
124 tflite::reference_ops::Concatenation(
params, inputs.shapes(), inputs.data(),
128void Concatenation::evalQuantized()
const
135 tflite::ConcatenationParams
params{};
143 tflite::reference_ops::ConcatenationWithScaling(
params,
inputs.shapes(),
inputs.data(),
const std::vector< const Tensor * > _inputs
const ConcatenationParams _params
const ConcatenationParams & params() const
void resize(const Shape &new_shape)
const Shape & shape() const
int32_t zero_point() const
Concatenation(std::vector< const Tensor * > inputs, Tensor *output, const ConcatenationParams ¶ms)
void execute() const override
void configure() override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
T must_cast(loco::Node *node)