ONE - On-device Neural Engine
FullyConnected.cpp
/*
 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "OMStatus.h"

#include "import/OMKernelConfigureBuilder.h"

#include "core/OMUtils.h"
#include "core/OMKernelData.h"

#include "execute/OMRuntimeKernel.h"

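// Build-time feature switches used below: defining DIS_FLOAT, DIS_QUANT or
// DIS_DYN_SHAPES removes the float, quantized or dynamic-shape validation
// paths from this translation unit.
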
using namespace onert_micro;
using namespace onert_micro::core;

namespace
{

// Operand positions of FullyConnected in the Circle operator's input/output
// tensor lists; the bias input is optional and may be absent.
constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t weightTensorIdx = 1;
constexpr uint32_t biasTensorIdx = 2;

constexpr uint32_t outputTensorIdx = 0;

} // namespace

namespace onert_micro
{
namespace import
{
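
// Import-time validation for the Circle FullyConnected kernel: verifies that
// the operands are present, that their type combination is supported, that
// weight and bias shapes are consistent, and that quantized models carry the
// required quantization parameters.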
OMStatus configure_kernel_CircleFullyConnected(const OMConfigureArgs &config_args)
{

  OMRuntimeContext &runtime_context = config_args.runtime_context;
  uint16_t op_index = config_args.kernel_index;
  OMRuntimeStorage &runtime_storage = config_args.runtime_storage;

  execute::OMRuntimeKernel runtime_kernel;
  runtime_kernel.readKernel(op_index, runtime_context);

  const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
  const circle::Tensor *weight = runtime_kernel.inputs[weightTensorIdx];
  const circle::Tensor *bias = runtime_kernel.inputs[biasTensorIdx];
  const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];

  assert(input != nullptr);
  assert(weight != nullptr);
  // Bias can be nullptr
  assert(output != nullptr);

  OMStatus status = Ok;

#ifndef DIS_FLOAT
  if (weight->type() == circle::TensorType_FLOAT32)
  {
    status = utils::checkCondition(input->type() == circle::TensorType_FLOAT32 and
                                   output->type() == circle::TensorType_FLOAT32 and
                                   (!bias or bias->type() == circle::TensorType_FLOAT32));
    if (status != Ok)
      return status;
  }
#endif // DIS_FLOAT
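  // Accepted quantized type combinations, as enforced below:
  //   UINT8 weights: UINT8 input and output, INT32 bias
  //   INT8 weights:  INT8 or FLOAT32 (hybrid) input, INT8 or FLOAT32 output,
  //                  INT32/INT64/FLOAT32 bias
  //   INT16 weights: INT16 input and output, INT32 bias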
#ifndef DIS_QUANT
  if (weight->type() == circle::TensorType_UINT8)
  {
    status = utils::checkCondition(input->type() == circle::TensorType_UINT8 and
                                   output->type() == circle::TensorType_UINT8 and
                                   (!bias or bias->type() == circle::TensorType_INT32));
    if (status != Ok)
      return status;
  }
  else if (weight->type() == circle::TensorType_INT8)
  {
    status = utils::checkCondition(input->type() == circle::TensorType_INT8 or
                                   input->type() == circle::TensorType_FLOAT32);
    if (status != Ok)
      return status;

    status = utils::checkCondition(output->type() == circle::TensorType_INT8 or
                                   output->type() == circle::TensorType_FLOAT32);
    if (status != Ok)
      return status;

    status = utils::checkCondition(!bias or bias->type() == circle::TensorType_INT32 or
                                   bias->type() == circle::TensorType_INT64 or
                                   bias->type() == circle::TensorType_FLOAT32);
    if (status != Ok)
      return status;

    if (input->type() == circle::TensorType_FLOAT32)
    {
      // Hybrid mode (float activations with int8 weights) requires
      // channel-wise quantization scales on the weight tensor.
      status = utils::checkCondition(weight->quantization() != nullptr and
                                     weight->quantization()->scale() != nullptr);
      if (status != Ok)
        return status;
    }
  }
  else if (weight->type() == circle::TensorType_INT16)
  {
    status = utils::checkCondition(input->type() == circle::TensorType_INT16 and
                                   output->type() == circle::TensorType_INT16 and
                                   (!bias or bias->type() == circle::TensorType_INT32));
    if (status != Ok)
      return status;
  }
#endif // DIS_QUANT
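
  // Shape consistency: the weight tensor must be a 2-D matrix whose first
  // dimension (the number of output units) matches the bias element count
  // when a bias is present.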

  core::OMRuntimeShape weight_shape(weight);
  core::OMRuntimeShape bias_shape(bias);
  core::OMRuntimeShape input_shape(input);
  core::OMRuntimeShape output_shape(output);

  status = utils::checkCondition(weight_shape.dimensionsCount() == 2);
  if (status != Ok)
    return status;

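  // A flat input size of 1 combined with a larger output indicates a dynamic
  // input tensor: its actual shape is looked up in runtime storage, and the
  // case is rejected when dynamic-shape support is compiled out.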
  if (input_shape.flatSize() == 1 and output_shape.flatSize() != 1)
  {
#ifndef DIS_DYN_SHAPES
    input_shape =
      runtime_storage.getDynamicRuntimeShape(runtime_kernel.inputs_index[inputTensorIdx]);
    if (input_shape.flatSize() == 0)
      return UnsupportedDynamicShapeCase;
#else
    return UnsupportedDynamicShapeCase;
#endif // DIS_DYN_SHAPES
  }

  status = utils::checkCondition(bias == nullptr or weight_shape.dims(0) == bias_shape.flatSize());

  if (input->type() == circle::TensorType_FLOAT32)
    return status;

#ifndef DIS_QUANT

  // Check quantized version
  if (input->quantization() == nullptr or output->quantization() == nullptr or
      weight->quantization() == nullptr)
    return NoQuantization;

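  // Only per-tensor quantization is accepted here: each scale and zero_point
  // array on the output and weight tensors must hold exactly one entry.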
  if (output->quantization()->scale() == nullptr or output->quantization()->scale()->size() != 1)
    return UnsupportedQuantizationType;

  if (output->quantization()->zero_point() == nullptr or
      output->quantization()->zero_point()->size() != 1)
    return UnsupportedQuantizationType;

  if (weight->quantization()->scale() == nullptr or weight->quantization()->scale()->size() != 1)
    return UnsupportedQuantizationType;

  if (weight->quantization()->zero_point() == nullptr or
      weight->quantization()->zero_point()->size() != 1)
    return UnsupportedQuantizationType;

#endif // DIS_QUANT

  return status;
}

} // namespace import
} // namespace onert_micro