19#include "kernels/Utils.h"
39 if (
params()->element_type() == DataType::FLOAT32)
43 else if (
params()->element_type() == DataType::S32)
49 throw std::runtime_error(
"luci-intp Gather(1) Unsupported type.");
53 indices()->element_type() == DataType::S64);
84 for (
int i = 0;
i < batch_dims; ++
i)
93 for (
int i = 0;
i < axis; ++
i)
110 switch (
params()->element_type())
112 case DataType::FLOAT32:
119 throw std::runtime_error(
"luci-intp Gather(2) Unsupported type.");
123template <
typename T>
void Gather::eval()
const
125 assert(
indices()->element_type() == DataType::S32 ||
indices()->element_type() == DataType::S64);
134 if (
indices()->element_type() == DataType::S32)
const GatherParams _params
void resize(const Shape &new_shape)
const Shape & shape() const
const Tensor * indices() const
void configure() override
Gather(const Tensor *params, const Tensor *indices, Tensor *output, const GatherParams &gparams)
const Tensor * params() const
void execute() const override
#define LUCI_INTERPRETER_CHECK(cond)
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
T must_cast(loco::Node *node)