ONE - On-device Neural Engine
PALConv2d.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LUCI_INTERPRETER_PAL_CONV2D_H
18#define LUCI_INTERPRETER_PAL_CONV2D_H
19
20#include <tensorflow/lite/kernels/internal/reference/conv.h>
21#include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
22#include <arm_nn_types.h>
23#include <arm_nnfunctions.h>
24
26{
27static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
28 const float *input_data, const tflite::RuntimeShape &filter_shape,
29 const float *filter_data, const tflite::RuntimeShape &bias_shape,
30 const float *bias_data, const tflite::RuntimeShape &output_shape,
31 float *output_data, const tflite::RuntimeShape &scratchpad_shape,
32 float *scratchpad_data)
33{
34 (void)scratchpad_shape;
35 (void)scratchpad_data;
36 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
37 bias_shape, bias_data, output_shape, output_data,
38 tflite::RuntimeShape(), nullptr);
39}
40
41static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
42 const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
43 const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
44 const int32 *bias_data, const tflite::RuntimeShape &output_shape,
45 uint8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
46 uint8 *scratchpad_data)
47{
48 (void)scratchpad_shape;
49 (void)scratchpad_data;
50 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
51 bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
52 scratchpad_data, nullptr);
53}
54
55static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_t *mult,
56 const int32_t *shifts, const tflite::RuntimeShape &input_shape,
57 const int8 *input_data, const tflite::RuntimeShape &filter_shape,
58 const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
59 const int32 *bias_data, const tflite::RuntimeShape &output_shape,
60 int8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
61 int8 *scratchpad_data)
62{
63 if (scratchpad_data)
64 {
65 cmsis_nn_conv_params conv_params;
66 conv_params.dilation.h = params.dilation_height_factor;
67 conv_params.dilation.w = params.dilation_width_factor;
68
69 assert(conv_params.dilation.h == 1);
70 assert(conv_params.dilation.w == 1);
71
72 conv_params.input_offset = params.input_offset;
73 conv_params.output_offset = params.output_offset;
74 conv_params.stride.h = params.stride_height;
75 conv_params.stride.w = params.stride_width;
76 conv_params.padding.h = params.padding_values.height;
77 conv_params.padding.w = params.padding_values.width;
78 conv_params.activation.min = params.quantized_activation_min;
79 conv_params.activation.max = params.quantized_activation_max;
80
81 cmsis_nn_per_channel_quant_params quant_params;
82 quant_params.multiplier = const_cast<int32_t *>(mult);
83 quant_params.shift = const_cast<int32_t *>(shifts);
84
85 assert(conv_params.activation.min <= conv_params.activation.max);
86 assert(input_shape.DimensionsCount() == 4);
87 assert(filter_shape.DimensionsCount() == 4);
88 assert(output_shape.DimensionsCount() == 4);
89 const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
90 const int input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
91 const int output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
92 if (bias_data)
93 {
94 assert(bias_shape.FlatSize() == output_depth);
95 }
96
97 cmsis_nn_dims input_dims;
98 input_dims.n = batch_size;
99 input_dims.h = input_shape.Dims(1);
100 input_dims.w = input_shape.Dims(2);
101 input_dims.c = input_depth;
102
103 cmsis_nn_dims filter_dims;
104 filter_dims.n = output_depth;
105 filter_dims.h = filter_shape.Dims(1);
106 filter_dims.w = filter_shape.Dims(2);
107 filter_dims.c = input_depth;
108
109 cmsis_nn_dims bias_dims;
110 bias_dims.n = 1;
111 bias_dims.h = 1;
112 bias_dims.w = 1;
113 bias_dims.c = output_depth;
114
115 cmsis_nn_dims output_dims;
116 output_dims.n = batch_size;
117 output_dims.h = output_shape.Dims(1);
118 output_dims.w = output_shape.Dims(2);
119 output_dims.c = output_depth;
120
121 cmsis_nn_context ctx;
122 ctx.buf = scratchpad_data;
123 ctx.size = scratchpad_shape.Dims(0);
124
125 auto res = arm_convolve_wrapper_s8(&ctx, &conv_params, &quant_params, &input_dims, input_data,
126 &filter_dims, filter_data, &bias_dims, bias_data,
127 &output_dims, output_data);
128 assert(res == ARM_MATH_SUCCESS);
129 }
130 else
131 {
132 tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
133 filter_shape, filter_data, bias_shape, bias_data,
134 output_shape, output_data);
135 }
136}
137
138static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
139 const luci_interpreter::DataType &input_data_type,
140 const tflite::ConvParams &params,
141 const tflite::RuntimeShape &input_shape,
142 const tflite::RuntimeShape &filter_shape,
143 const tflite::RuntimeShape &output_shape)
144{
145 cmsis_nn_conv_params conv_params;
146 conv_params.dilation.h = params.dilation_height_factor;
147 conv_params.dilation.w = params.dilation_width_factor;
148
149 if (input_data_type == loco::DataType::S8 && conv_params.dilation.h == 1 &&
150 conv_params.dilation.w == 1)
151 {
152 const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
153 const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
154 const int32_t output_depth = tflite::MatchingDim(filter_shape, 0, output_shape, 3);
155 const int32_t filter_height = filter_shape.Dims(1);
156 const int32_t filter_width = filter_shape.Dims(2);
157 const int32_t output_height = output_shape.Dims(1);
158 const int32_t output_width = output_shape.Dims(2);
159
160 conv_params.input_offset = params.input_offset;
161 conv_params.output_offset = params.output_offset;
162 conv_params.stride.h = params.stride_height;
163 conv_params.stride.w = params.stride_width;
164 conv_params.padding.h = params.padding_values.height;
165 conv_params.padding.w = params.padding_values.width;
166
167 cmsis_nn_dims input_dims;
168 input_dims.n = batches;
169 input_dims.h = input_shape.Dims(1);
170 input_dims.w = input_shape.Dims(2);
171 input_dims.c = input_depth;
172
173 cmsis_nn_dims filter_dims;
174 filter_dims.n = output_depth;
175 filter_dims.h = filter_height;
176 filter_dims.w = filter_width;
177 filter_dims.c = input_depth;
178
179 cmsis_nn_dims output_dims;
180 output_dims.n = batches;
181 output_dims.h = output_height;
182 output_dims.w = output_width;
183 output_dims.c = output_depth;
184
185 const int32_t buf_size = arm_convolve_wrapper_s8_get_buffer_size(&conv_params, &input_dims,
186 &filter_dims, &output_dims);
187
188 luci_interpreter::Shape scratchpad_shape{buf_size};
189 scratchpad->resize(scratchpad_shape);
190 }
191 else
192 {
193 scratchpad->set_allocatable(false);
194 }
195}
196
197} // namespace luci_interpreter_pal
198
199#endif // LUCI_INTERPRETER_PAL_CONV2D_H
Cross-references:
- uint8: alias of std::uint8_t (defined in Macro.h:52)
- void set_allocatable(bool value) (defined in Tensor.h:168)
- void resize(const Shape &new_shape) (defined in Tensor.cpp:56)
- const luci_interpreter::RuntimeShape output_shape
- DataType: "scalar" value type (defined in DataType.h:32)
- void Conv(const ConvParams &params, const Shape &input_shape, const uint8_t *input_data, const Shape &filter_shape, const uint8_t *filter_data, const Shape &bias_shape, const int32_t *bias_data, const Shape &output_shape, uint8_t *output_data, const Shape &im2col_shape, uint8_t *im2col_data) (defined in Conv.h:83)
- int32: alias of int32_t (defined in topk_v2.h:27)