ONE - On-device Neural Engine
GatherLayer.cc
/*
 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "GatherLayer.h"

#include "OperationUtils.h"
#include "../KernelGenerator.h"
#include "../Validator.h"

#include <cker/operation/Gather.h>

namespace onert::backend::cpu
{

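// Checks whether the CPU backend can run this Gather node. Inputs quantized as
// GGML Q4_0 are rejected here; any other operand type is marked as supported.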
void Validator::visit(const ir::operation::Gather &node)
{
  using ir::operation::Gather;

  const auto input_index{node.getInputs().at(Gather::Input::INPUT)};
  const auto input_node = &_graph.operands().at(input_index);

  _supported = false;

  if (input_node->typeInfo().type() == ir::DataType::QUANT_GGML_Q4_0)
    return;

  _supported = true;
}

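// Builds the CPU kernel for a Gather node: looks up the registered portable
// tensors, resolves the gather axis from the node parameters and the input
// rank, and returns a configured ops::GatherLayer as the generated function.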
void KernelGenerator::visit(const ir::operation::Gather &node)
{
  const auto output_index{node.getOutputs().at(0)};
  const auto input_index{node.getInputs().at(ir::operation::Gather::Input::INPUT)};
  const auto indices_index{node.getInputs().at(ir::operation::Gather::Input::INDICES)};

  auto output_tensor = _tensor_reg->getPortableTensor(output_index);
  auto input_tensor = _tensor_reg->getPortableTensor(input_index);
  auto indices_tensor = _tensor_reg->getPortableTensor(indices_index);

  const auto rank = _ctx.at(input_index).shape().rank();
  const auto axis = ops::getAxis(rank, node.param().axis);

  auto fn = std::make_unique<ops::GatherLayer>();

  fn->configure(input_tensor, indices_tensor, output_tensor, axis);

  _return_fn = std::move(fn);
}

} // namespace onert::backend::cpu

namespace onert::backend::cpu::ops
{

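// Stores the operand tensors and the resolved axis; the actual gather is
// deferred to run().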
void GatherLayer::configure(const IPortableTensor *input, const IPortableTensor *indices,
                            IPortableTensor *output, int32_t axis)
{
  _input = input;
  _indices = indices;
  _axis = axis;
  _output = output;
}

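// Runs the cker Gather kernel for a fixed input/output element type,
// dispatching on the indices type (int32_t or int64_t).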
template <typename InputType> void GatherLayer::runByInputType()
{
  using OutputType = InputType;
  nnfw::cker::GatherParams op_params;
  op_params.axis = _axis;

  switch (_indices->data_type())
  {
    case OperandType::INT32:
    {
      using IndicesType = int32_t;

      nnfw::cker::Gather<InputType, IndicesType>(
        op_params, getShape(_input), getBuffer<InputType>(_input), getShape(_indices),
        getBuffer<IndicesType>(_indices), getShape(_output), getBuffer<OutputType>(_output));
      break;
    }
    case OperandType::INT64:
    {
      using IndicesType = int64_t;

      nnfw::cker::Gather<InputType, IndicesType>(
        op_params, getShape(_input), getBuffer<InputType>(_input), getShape(_indices),
        getBuffer<IndicesType>(_indices), getShape(_output), getBuffer<OutputType>(_output));
      break;
    }
    default:
      throw std::runtime_error("Gather: unsupported indices data type");
  }
}

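// Entry point of the kernel: dispatches on the input element type and
// delegates to runByInputType<T>().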
void GatherLayer::run()
{
  switch (_input->data_type())
  {
    case OperandType::FLOAT32:
      runByInputType<float>();
      break;
    case OperandType::QUANT_UINT8_ASYMM:
      runByInputType<uint8_t>();
      break;
    case OperandType::INT32:
      runByInputType<int32_t>();
      break;
    case OperandType::BOOL8:
      runByInputType<bool>();
      break;
    default:
      throw std::runtime_error("Gather: unsupported input data type");
  }
}

} // namespace onert::backend::cpu::ops
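For reference, the gather operation implemented above selects slices of the input along the configured axis using the index values. Below is a minimal standalone sketch of the same semantics for a 2-D row-major input and axis 0; it is plain C++ independent of the onert/cker APIs shown above, and all names in it are illustrative.

#include <cstdint>
#include <vector>

// Illustrative only: gather rows of a row-major [rows x cols] matrix.
// output[i][:] = input[indices[i]][:], matching Gather with axis == 0.
std::vector<float> gather_rows(const std::vector<float> &input, int32_t cols,
                               const std::vector<int32_t> &indices)
{
  std::vector<float> output;
  output.reserve(indices.size() * cols);
  for (int32_t row : indices)
    output.insert(output.end(), input.begin() + row * cols, input.begin() + (row + 1) * cols);
  return output;
}

// Example: input = [[1,2],[3,4],[5,6]], indices = [2,0] -> output = [[5,6],[1,2]].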