ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
PALConv2D.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef ONERT_MICRO_EXECUTE_PAL_CONV_2D_H
19#define ONERT_MICRO_EXECUTE_PAL_CONV_2D_H
20
21#include "PALConv2DCommon.h"
22#include "core/OMKernelData.h"
23#include "core/OMRuntimeShape.h"
24#include "PALUtils.h"
25
26#include <arm_nnfunctions.h>
27
28namespace onert_micro
29{
30namespace execute
31{
32namespace pal
33{
34
35// Fixed-point per-channel-quantization convolution reference kernel.
37 const int8_t *input_data, const core::OMRuntimeShape &filter_shape,
38 const int8_t *filter_data, const int32_t *bias_data,
39 const core::OMRuntimeShape &output_shape, int8_t *output_data)
40{
41 cmsis_nn_conv_params conv_params;
42 conv_params.dilation.h = params.dilation_height_factor;
43 conv_params.dilation.w = params.dilation_width_factor;
44
45 assert(conv_params.dilation.h == 1);
46 assert(conv_params.dilation.w == 1);
47
48 conv_params.input_offset = params.input_offset;
49 conv_params.output_offset = params.output_offset;
50 conv_params.stride.h = params.stride_height;
51 conv_params.stride.w = params.stride_width;
52 conv_params.padding.h = params.pad_h;
53 conv_params.padding.w = params.pad_w;
54 conv_params.activation.min = params.quantized_activation_min;
55 conv_params.activation.max = params.quantized_activation_max;
56
57 cmsis_nn_per_channel_quant_params quant_params;
58 quant_params.multiplier = const_cast<int32_t *>(params.per_channel_output_multiplier.data());
59 quant_params.shift = const_cast<int32_t *>(
60 reinterpret_cast<const int32_t *>(params.per_channel_output_shift.data()));
61
62 assert(conv_params.activation.min <= conv_params.activation.max);
63 const int batch_size = input_shape.dims(0);
64 const int input_depth = input_shape.dims(3);
65 const int output_depth = filter_shape.dims(0);
66
67 cmsis_nn_dims input_dims;
68 input_dims.n = batch_size;
69 input_dims.h = input_shape.dims(1);
70 input_dims.w = input_shape.dims(2);
71 input_dims.c = input_depth;
72
73 cmsis_nn_dims filter_dims;
74 filter_dims.n = output_depth;
75 filter_dims.h = filter_shape.dims(1);
76 filter_dims.w = filter_shape.dims(2);
77 filter_dims.c = input_depth;
78
79 cmsis_nn_dims bias_dims;
80 bias_dims.n = 1;
81 bias_dims.h = 1;
82 bias_dims.w = 1;
83 bias_dims.c = output_depth;
84
85 cmsis_nn_dims output_dims;
86 output_dims.n = batch_size;
87 output_dims.h = output_shape.dims(1);
88 output_dims.w = output_shape.dims(2);
89 output_dims.c = output_depth;
90
91 auto buf_size =
92 arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims, &filter_dims, &output_dims);
93
94 auto buffer = std::make_unique<int8_t[]>(buf_size);
95 assert(buffer != nullptr);
96
97 cmsis_nn_context ctx;
98 ctx.buf = buffer.get();
99 ctx.size = buf_size;
100
101 auto res = arm_convolve_wrapper_s8(&ctx, &conv_params, &quant_params, &input_dims, input_data,
102 &filter_dims, filter_data, &bias_dims, bias_data, &output_dims,
103 output_data);
104
105 assert(res == ARM_CMSIS_NN_SUCCESS);
106 if (res != ARM_CMSIS_NN_SUCCESS)
107 return CmsisNNError;
108 return Ok;
109}
110
111} // namespace pal
112} // namespace execute
113} // namespace onert_micro
114
115#endif // ONERT_MICRO_EXECUTE_PAL_CONV_2D_H
int32_t dims(int i) const
Definition Tensor.h:108
const luci_interpreter::RuntimeShape output_shape
OMStatus ConvPerChannel(const core::ConvQuant &params, const core::OMRuntimeShape &input_shape, const int8_t *input_data, const core::OMRuntimeShape &filter_shape, const int8_t *filter_data, const int32_t *bias_data, const core::OMRuntimeShape &output_shape, int8_t *output_data)
Definition PALConv2D.h:36
std::vector< int > per_channel_output_shift
std::vector< int32_t > per_channel_output_multiplier