19#include "kernels/Utils.h"
39 if (
params()->element_type() == DataType::FLOAT32)
43 else if (
params()->element_type() == DataType::S32)
49 throw std::runtime_error(
"luci-intp Gather(1) Unsupported type.");
53 indices()->element_type() == DataType::S64);
63 indices_shape =
Shape({1});
79 batch_dims += indices_shape.
num_dims();
84 for (
int i = 0; i < batch_dims; ++i)
89 const int num_dimensions = params_shape.
num_dims() + indices_shape.
num_dims() - 1 - batch_dims;
93 for (
int i = 0; i < axis; ++i)
97 for (
int i = batch_dims; i < indices_shape.
num_dims(); ++i)
101 for (
int i = axis + 1; i < params_shape.
num_dims(); ++i)
110 switch (
params()->element_type())
112 case DataType::FLOAT32:
119 throw std::runtime_error(
"luci-intp Gather(2) Unsupported type.");
123template <
typename T>
void Gather::eval()
const
125 assert(
indices()->element_type() == DataType::S32 ||
indices()->element_type() == DataType::S64);
127 const auto params_data = getTensorData<T>(
params());
128 auto output_data = getTensorData<T>(
output());
130 tflite::GatherParams tparams;
134 if (
indices()->element_type() == DataType::S32)
136 const auto indices_data = getTensorData<int32_t>(
indices());
144 const auto indices_data = getTensorData<int64_t>(
indices());
const GatherParams _params
void resize(const Shape &new_shape)
const Shape & shape() const
const Tensor * indices() const
void configure() override
Gather(const Tensor *params, const Tensor *indices, Tensor *output, const GatherParams &gparams)
const Tensor * params() const
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)