20#include "../KernelGenerator.h"
21#include "../Validator.h"
// Validator visitor for the Gather operation.
// NOTE(review): this excerpt is discontinuous - the braces, the definition of
// `input_node` and the body of the `if` below are not visible here.
28void Validator::visit(
const ir::operation::Gather &node)
// Bring the Gather operation type into scope so its nested names can be written unqualified.
30 using ir::operation::Gather;
// Operand index of the node's Gather::Input::INPUT input.
32 const auto input_index{node.getInputs().at(Gather::Input::INPUT)};
// Condition: the data type reported by `input_node` is QUANT_GGML_Q4_0.
// Presumably `input_node` is the operand looked up via `input_index` - TODO confirm.
// What the branch does when the condition holds is not shown in this excerpt.
37 if (
input_node->typeInfo().type() == ir::DataType::QUANT_GGML_Q4_0)
// Kernel-generation visitor for Gather: looks up the tensors involved, derives
// the axis and configures an ops::GatherLayer with them.
// NOTE(review): `input_index` and `indices_index` are used below but are defined
// on lines that are not visible in this excerpt.
43void KernelGenerator::visit(
const ir::operation::Gather &node)
// Index of the node's first (position 0) output.
45 const auto output_index{node.getOutputs().at(0)};
// Fetch the portable tensors for output, input and indices from the tensor registry.
49 auto output_tensor = _tensor_reg->getPortableTensor(output_index);
50 auto input_tensor = _tensor_reg->getPortableTensor(input_index);
51 auto indices_tensor = _tensor_reg->getPortableTensor(indices_index);
// Rank of the input's shape, read from the entry that `_ctx` holds for `input_index`.
53 const auto rank = _ctx.
at(input_index).shape().rank();
// Axis derived from the node's axis parameter and the input rank via ops::getAxis.
// Presumably this normalizes a negative axis - TODO confirm against ops::getAxis.
54 const auto axis = ops::getAxis(rank, node.param().axis);
// Create the Gather layer and hand it the three tensors plus the derived axis.
56 auto fn = std::make_unique<ops::GatherLayer>();
58 fn->configure(input_tensor, indices_tensor, output_tensor, axis);
// NOTE(review): what happens to `fn` after configure() is not visible in this excerpt.
// Runs the Gather kernel for element type `InputType`, selecting the index
// element type (int32_t or int64_t) from the case labels below.
// Throws std::runtime_error("Gather: unsupported indices data type") on the
// remaining path (presumably the switch's default branch - TODO confirm).
// NOTE(review): the switch header, the declaration of `op_params`, and the
// braces/breaks are not visible in this excerpt; presumably the switch is on the
// data type of `_indices` - TODO confirm.
77template <
typename InputType>
void GatherLayer::runByInputType()
// The output buffer is accessed with the same element type as the input.
79 using OutputType = InputType;
// Pass the layer's `_axis` member to the kernel through its parameter struct.
81 op_params.
axis = _axis;
// INT32 case: indices are read as int32_t.
85 case OperandType::INT32:
87 using IndicesType = int32_t;
// Kernel call: params, then shape + typed buffer for input, indices and output.
89 nnfw::cker::Gather<InputType, IndicesType>(
90 op_params,
getShape(_input), getBuffer<InputType>(_input),
getShape(_indices),
91 getBuffer<IndicesType>(_indices),
getShape(_output), getBuffer<OutputType>(_output));
// INT64 case: identical call, with indices read as int64_t.
94 case OperandType::INT64:
96 using IndicesType = int64_t;
98 nnfw::cker::Gather<InputType, IndicesType>(
99 op_params,
getShape(_input), getBuffer<InputType>(_input),
getShape(_indices),
100 getBuffer<IndicesType>(_indices),
getShape(_output), getBuffer<OutputType>(_output));
// Reached for any other indices type (presumably via `default:` - not shown).
104 throw std::runtime_error(
"Gather: unsupported indices data type");
// Dispatch on an operand type to the matching runByInputType<T>() instantiation:
//   FLOAT32 -> float, QUANT_UINT8_ASYMM -> uint8_t, INT32 -> int32_t, BOOL8 -> bool.
// NOTE(review): the enclosing function header and the switch expression are not
// visible in this excerpt; presumably this is the layer's run entry point
// switching on the input tensor's data type - TODO confirm.
112 case OperandType::FLOAT32:
113 runByInputType<float>();
// Quantized uint8 data is processed as raw uint8_t elements.
115 case OperandType::QUANT_UINT8_ASYMM:
116 runByInputType<uint8_t>();
118 case OperandType::INT32:
119 runByInputType<int32_t>();
121 case OperandType::BOOL8:
122 runByInputType<bool>();
// Reached for any other type (presumably via `default:` - not shown).
125 throw std::runtime_error(
"Gather: unsupported input data type");
A tensor class that is portable for other backends.
ir::DataType data_type() const override final
std::unique_ptr< exec::IFunction > _return_fn
void configure(const IPortableTensor *input, const IPortableTensor *indices, IPortableTensor *output, int32_t axis)
const Operands & operands() const override
const Object & at(const Index &index) const
Get the object that is associated with the given index.
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a Pull node with a given input index.
nnfw::cker::Shape getShape(const IPortableTensor *tensor)