19#include "kernels/Utils.h"
21#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
43 assert(axis >= 0 && axis < input_shape.
num_dims());
47 for (
int in_index = 0; in_index < input_shape.
num_dims(); ++in_index)
60template <
typename T>
void Unpack::executeImpl()
const
62 tflite::UnpackParams
params{};
65 VectorOfTensors<T, false> all_outputs(
_outputs);
67 **all_outputs.shapes(), all_outputs.data());
72 switch (
input()->element_type())
74 case DataType::FLOAT32:
75 return executeImpl<float>();
77 return executeImpl<uint8_t>();
79 throw std::runtime_error(
"luci-intp Unpack Unsupported type.");
const std::vector< Tensor * > _outputs
const UnpackParams _params
const UnpackParams & params() const
void resize(const Shape &new_shape)
const Shape & shape() const
DataType element_type() const
Tensor * output(int index) const
const Tensor * input() const
void execute() const override
void configure() override
Unpack(const Tensor *input, std::vector< Tensor * > outputs, const UnpackParams ¶ms)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)