21#include "../KernelGenerator.h"
22#include "../Validator.h"
27void Validator::visit(
const ir::operation::Gather &node)
29 using ir::operation::Gather;
31 const auto input_index{node.getInputs().at(Gather::Input::INPUT)};
36 if (
input_node->typeInfo().type() != ir::DataType::QUANT_GGML_Q4_0)
42void KernelGenerator::visit(
const ir::operation::Gather &node)
44 const auto output_index{node.getOutputs().at(0)};
48 auto output_tensor = _tensor_reg->getPortableTensor(output_index);
49 auto input_tensor = _tensor_reg->getPortableTensor(input_index);
50 auto indices_tensor = _tensor_reg->getPortableTensor(indices_index);
52 const auto rank = _ctx.
at(input_index).shape().rank();
53 const auto axis = ops::getAxis(rank, node.param().axis);
55 auto fn = std::make_unique<ops::GatherLayer>();
57 fn->configure(input_tensor, indices_tensor, output_tensor, axis, _external_context.get());
77void GatherLayer::runByGGMLQuantInputType()
84 throw std::runtime_error(
"Gather: block quantized input tensor must be rank 2");
86 if (_indices->
getShape().rank() >= 4 &&
88 throw std::runtime_error(
"Gather: invalid indices tensor shape");
90 if (_indices->
data_type() != ir::DataType::INT32)
91 throw std::runtime_error(
"Gather: indices tensor must be int32 type");
94 throw std::runtime_error(
"Gather: axis must be 0");
101 output.op = GGML_OP_GET_ROWS;
102 output.src[0] = &input;
103 output.src[1] = &indices;
105 auto *nodes = &output;
108 struct ggml_cgraph graph;
110 memset(&graph, 0,
sizeof(graph));
112 graph.nodes = &nodes;
117 std::vector<uint8_t> buf(cplan.work_size);
118 cplan.work_data = buf.data();
121 ggml_graph_compute(&graph, &cplan);
128 case ir::DataType::QUANT_GGML_Q4_0:
129 runByGGMLQuantInputType();
132 throw std::runtime_error(
"Gather: unsupported input data type");
A tensor class that is portable for other backends.
ir::DataType data_type() const override final
ir::Shape getShape() const override final
Get ir::Shape of tensor.
std::unique_ptr< exec::IFunction > _return_fn
int32_t maxNumThreads() const
void configure(const IPortableTensor *input, const IPortableTensor *indices, IPortableTensor *output, int32_t axis, ExternalContext *ctx)
const Operands & operands() const override
const Object & at(const Index &index) const
Get the object that is associated with the given index.
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a Pull node with a given input index.
struct ggml_tensor getGGMLTensor(const IPortableTensor *tensor)