19#include "kernels/Utils.h"
39 if (
params()->element_type() == DataType::FLOAT32)
45 throw std::runtime_error(
"luci-intp Gather(1) Unsupported type.");
49 indices()->element_type() == DataType::S64);
68 batch_dims += indices_shape.
num_dims();
73 for (
int i = 0; i < batch_dims; ++i)
78 const int num_dimensions = params_shape.
num_dims() + indices_shape.
num_dims() - 1 - batch_dims;
82 for (
int i = 0; i < axis; ++i)
86 for (
int i = batch_dims; i < indices_shape.
num_dims(); ++i)
90 for (
int i = axis + 1; i < params_shape.
num_dims(); ++i)
99 switch (
params()->element_type())
101 case DataType::FLOAT32:
105 throw std::runtime_error(
"luci-intp Gather(2) Unsupported type.");
109void Gather::evalFloat()
const
111 assert(
indices()->element_type() == DataType::S32 ||
indices()->element_type() == DataType::S64);
113 const auto params_data = getTensorData<float>(
params());
114 auto output_data = getTensorData<float>(
output());
116 tflite::GatherParams tparams;
120 if (
indices()->element_type() == DataType::S32)
122 const auto indices_data = getTensorData<int32_t>(
indices());
130 const auto indices_data = getTensorData<int64_t>(
indices());
const GatherParams _params
void resize(const Shape &new_shape)
const Shape & shape() const
const Tensor * indices() const
void configure() override
Gather(const Tensor *params, const Tensor *indices, Tensor *output, const GatherParams &gparams)
const Tensor * params() const
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)