19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
43 axis += dimension_size;
51 throw std::runtime_error(
"luci-intp Pack(1) Unsupported type.");
54 for (uint32_t i = 1; i <
_inputs.size(); ++i)
59 for (
int d = 0; d < t0->
shape().num_dims(); ++d)
67 for (
int index = 0; index < dimension_size; ++index)
98 switch (
_inputs[0]->element_type())
100 case DataType::FLOAT32:
101 evalGeneric<float>();
104 evalGeneric<uint8_t>();
107 evalGeneric<int8_t>();
110 evalGeneric<int16_t>();
113 evalGeneric<int32_t>();
116 evalGeneric<int64_t>();
119 throw std::runtime_error(
"luci-intp Pack(2) Unsupported type.");
123template <
typename T>
void Pack::evalGeneric()
const
130 axis += dimension_size;
133 VectorOfTensors<T, true> inputs(
_inputs);
134 tflite::PackParams
params{};
138 getTensorData<T>(
output()));
const std::vector< const Tensor * > _inputs
const PackParams & params() const
void resize(const Shape &new_shape)
const Shape & shape() const
DataType element_type() const
int32_t zero_point() const
void configure() override
Pack(std::vector< const Tensor * > inputs, Tensor *output, const PackParams ¶ms)
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)