ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FullyConnectedLayer.cc
Go to the documentation of this file.
1/*
2 * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "FullyConnectedLayer.h"
18
19#include "GGMLHelper.h"
20#include "../KernelGenerator.h"
21#include "../Validator.h"
22
24{
25
26void Validator::visit(const ir::operation::FullyConnected &node)
27{
28 using ir::operation::FullyConnected;
29
30 const auto weight_index{node.getInputs().at(FullyConnected::Input::WEIGHT)};
31 const auto weight_node = &_graph.operands().at(weight_index);
32
33 _supported = false;
34
35 if (weight_node->typeInfo().type() != ir::DataType::QUANT_GGML_Q4_0 &&
36 weight_node->typeInfo().type() != ir::DataType::QUANT_GGML_Q8_0)
37 return;
38
39 if (node.param().activation != ir::Activation::NONE)
40 return;
41
42 _supported = true;
43}
44
45void KernelGenerator::visit(const ir::operation::FullyConnected &node)
46{
47 using ir::operation::FullyConnected;
48
49 const auto output_index{node.getOutputs().at(0)};
50 const auto input_index{node.getInputs().at(FullyConnected::Input::INPUT)};
51 const auto weight_index{node.getInputs().at(FullyConnected::Input::WEIGHT)};
52 const auto bias_index{node.getInputs().at(FullyConnected::Input::BIAS)};
53 const auto activation = node.param().activation;
54 const auto weights_format = node.param().weights_format;
55 if (weights_format != ir::FullyConnectedWeightsFormat::Default)
56 throw std::runtime_error("Unsupported FullyConnected Weights Format");
57
58 auto output_tensor = _tensor_reg->getPortableTensor(output_index);
59 auto input_tensor = _tensor_reg->getPortableTensor(input_index);
60 auto weight_tensor = _tensor_reg->getPortableTensor(weight_index);
61 auto bias_tensor = bias_index.undefined() ? nullptr : _tensor_reg->getPortableTensor(bias_index);
62
63 auto fn = std::make_unique<ops::FullyConnectedLayer>();
64
65 fn->configure(input_tensor, weight_tensor, bias_tensor, activation, output_tensor,
66 _external_context);
67
68 _return_fn = std::move(fn);
69}
70
71} // namespace onert::backend::ggml
72
74{
75
// Default state: all operand pointers and the external context are unset;
// the actual wiring is done later in configure().
  : _input(nullptr), _weights(nullptr), _bias(nullptr), _output(nullptr),
    _activation(ir::Activation::NONE), _external_context(nullptr)
{
  // DO NOTHING
}
82
84
{
  // Executes output = weights x input through ggml's MUL_MAT operator.
  // Bias fusion is not implemented for the GGML-quantized weight path yet.
  if (_bias)
    throw std::runtime_error{"FullyConnected: GGML weights format does not support bias yet."};

  // convert tensor
  // Wrap each portable tensor as a ggml_tensor descriptor (see getGGMLTensor).
  auto input = getGGMLTensor(_input);
  auto weights = getGGMLTensor(_weights);
  auto output = getGGMLTensor(_output);
  {
    // Turn the output descriptor into a MUL_MAT node whose sources are the
    // wrapped weight and input operands.
    output.op = GGML_OP_MUL_MAT;
    output.src[0] = &weights;
    output.src[1] = &input;
  }
  auto *nodes = &output;

  // create graph
  // Build a single-node compute graph on the stack around that MUL_MAT node.
  struct ggml_cgraph graph;
  {
    memset(&graph, 0, sizeof(graph));
    graph.n_nodes = 1;
    graph.nodes = &nodes;
  }

  // get cplan
  // Ask ggml how much scratch memory the plan needs, then back it with a
  // temporary buffer that only lives for this invocation.
  auto cplan = ggml_graph_plan(&graph, _external_context->maxNumThreads());
  std::vector<uint8_t> buf(cplan.work_size);
  cplan.work_data = buf.data();

  // compute
  // Run the graph; the result is written through the wrapped output
  // descriptor (presumably aliasing _output's buffer -- wiring is inside
  // getGGMLTensor).
  ggml_graph_compute(&graph, &cplan);
}
117
                                   const IPortableTensor *bias, ir::Activation activation,
                                   IPortableTensor *output,
                                   const std::shared_ptr<ExternalContext> &external_context)
{
  // Simply stores the operand tensors and execution parameters; no
  // validation happens here -- the weight data-type check is done in run().
  _input = input;
  _weights = weights;
  _bias = bias;
  _activation = activation;
  _output = output;
  _external_context = external_context;
}
130
{
  // Dispatch on the weight data type: only GGML-quantized weights are
  // supported, matching what Validator::visit accepts.
  if (_weights->data_type() == ir::DataType::QUANT_GGML_Q4_0 ||
      _weights->data_type() == ir::DataType::QUANT_GGML_Q8_0)
  {
    // NOTE(review): this branch body is empty in this view; the original
    // statement here (likely a call such as fullyConnectedGGMLWeight())
    // appears to have been lost in extraction -- confirm against the
    // repository source before relying on this listing.
  }
  else
  {
    throw std::runtime_error{"FullyConnected: unsupported data type"};
  }
}
143
{
  // No ahead-of-time setup is needed; all work happens in run().
  // DO NOTHING
}
148
149} // namespace onert::backend::ggml::ops
A tensor class that is portable for other backends.
ir::DataType data_type() const override final
std::unique_ptr< exec::IFunction > _return_fn
std::shared_ptr< ExternalContext > _external_context
void configure(const IPortableTensor *input, const IPortableTensor *weights, const IPortableTensor *bias, ir::Activation activation, IPortableTensor *output, const std::shared_ptr< ExternalContext > &external_context)
const Operands & operands() const override
Definition Graph.h:103
const Object & at(const Index &index) const
Get the object that is associated with the given index.
struct ggml_tensor getGGMLTensor(const IPortableTensor *tensor)
Definition GGMLHelper.cc:41
CLTensor bias_tensor