ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALDepthwiseConv2d.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
18#define LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
19
20#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h>
21#include <tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h>
22#include <tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h>
23#include <arm_nnfunctions.h>
24
26{
27template <typename T>
28static inline void
29DepthwiseConvPerChannel(const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
30 const int32_t *output_shift, const tflite::RuntimeShape &input_shape,
31 const T *input_data, const tflite::RuntimeShape &filter_shape,
32 const T *filter_data, const tflite::RuntimeShape &bias_shape,
33 const int32_t *bias_data, const tflite::RuntimeShape &output_shape,
34 T *output_data, const tflite::RuntimeShape &scratchpad_shape,
35 T *scratchpad_data)
36{
37 {
38 // MARK: At this moment this operation is not supported
39 assert(false && "DepthwiseConvPerChannel NYI");
40 (void)params;
41 (void)output_multiplier;
42 (void)output_shift;
43 (void)input_shape;
44 (void)output_data;
45 (void)input_data;
46 (void)filter_shape;
47 (void)filter_data;
48 (void)bias_shape;
49 (void)bias_data;
50 (void)output_shape;
51 (void)output_data;
52 (void)scratchpad_shape;
53 (void)scratchpad_data;
54 }
55}
56
57template <>
59 const tflite::DepthwiseParams &params, const int32_t *output_multiplier,
60 const int32_t *output_shift, const tflite::RuntimeShape &input_shape, const int8_t *input_data,
61 const tflite::RuntimeShape &filter_shape, const int8_t *filter_data,
62 const tflite::RuntimeShape &bias_shape, const int32_t *bias_data,
63 const tflite::RuntimeShape &output_shape, int8_t *output_data,
64 const tflite::RuntimeShape &scratchpad_shape, int8_t *scratchpad_data)
65{
66 if (scratchpad_data)
67 {
68 cmsis_nn_dw_conv_params dw_conv_params;
69 dw_conv_params.dilation.h = params.dilation_height_factor;
70 dw_conv_params.dilation.w = params.dilation_width_factor;
71 assert(dw_conv_params.dilation.h == 1);
72 assert(dw_conv_params.dilation.w == 1);
73
74 dw_conv_params.input_offset = params.input_offset;
75 dw_conv_params.output_offset = params.output_offset;
76 dw_conv_params.stride.h = params.stride_height;
77 dw_conv_params.stride.w = params.stride_width;
78 dw_conv_params.padding.h = params.padding_values.height;
79 dw_conv_params.padding.w = params.padding_values.width;
80
81 dw_conv_params.activation.min = params.quantized_activation_min;
82 dw_conv_params.activation.max = params.quantized_activation_max;
83 dw_conv_params.ch_mult = params.depth_multiplier;
84
85 cmsis_nn_per_channel_quant_params quant_params;
86 int32_t output_multiplier = params.output_multiplier;
87 int32_t output_shift = params.output_shift;
88
89 quant_params.multiplier = &output_multiplier;
90 quant_params.shift = &output_shift;
91
92 assert(dw_conv_params.activation.min <= dw_conv_params.activation.max);
93 const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
94 const int output_depth = tflite::MatchingDim(filter_shape, 3, output_shape, 3);
95 if (bias_data)
96 {
97 assert(bias_shape.FlatSize() == output_depth);
98 }
99
100 cmsis_nn_dims input_dims;
101 input_dims.n = batch_size;
102 input_dims.h = input_shape.Dims(1);
103 input_dims.w = input_shape.Dims(2);
104 input_dims.c = input_shape.Dims(3);
105
106 cmsis_nn_dims filter_dims;
107 filter_dims.n = filter_shape.Dims(0);
108 filter_dims.h = filter_shape.Dims(1);
109 filter_dims.w = filter_shape.Dims(2);
110 filter_dims.c = output_depth;
111
112 cmsis_nn_dims bias_dims;
113 bias_dims.n = 1;
114 bias_dims.h = 1;
115 bias_dims.w = 1;
116 bias_dims.c = output_depth;
117
118 cmsis_nn_dims output_dims;
119 output_dims.n = batch_size;
120 output_dims.h = output_shape.Dims(1);
121 output_dims.w = output_shape.Dims(2);
122 output_dims.c = output_depth;
123
124 cmsis_nn_context ctx;
125 ctx.buf = scratchpad_data;
126 ctx.size = scratchpad_shape.Dims(0);
127
128 auto res = arm_depthwise_conv_wrapper_s8(&ctx, &dw_conv_params, &quant_params, &input_dims,
129 input_data, &filter_dims, filter_data, &bias_dims,
130 bias_data, &output_dims, output_data);
131 assert(res == ARM_MATH_SUCCESS);
132 }
133 else
134 {
135 tflite::reference_integer_ops::DepthwiseConvPerChannel(
136 params, output_multiplier, output_shift, input_shape, input_data, filter_shape, filter_data,
137 bias_shape, bias_data, output_shape, output_data);
138 }
139}
140
141static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
142 const tflite::DepthwiseParams &params,
143 const luci_interpreter::DataType &input_data_type,
144 const tflite::RuntimeShape &input_shape,
145 const tflite::RuntimeShape &filter_shape,
146 const tflite::RuntimeShape &output_shape)
147{
148 cmsis_nn_dw_conv_params dw_conv_params;
149 dw_conv_params.dilation.h = params.dilation_height_factor;
150 dw_conv_params.dilation.w = params.dilation_width_factor;
151
152 if (input_data_type == loco::DataType::S8 && dw_conv_params.dilation.h == 1 &&
153 dw_conv_params.dilation.w == 1)
154 {
155 const int batch_size = tflite::MatchingDim(input_shape, 0, output_shape, 0);
156 const int output_depth = tflite::MatchingDim(filter_shape, 3, output_shape, 3);
157
158 cmsis_nn_dims input_dims;
159 input_dims.n = batch_size;
160 input_dims.h = input_shape.Dims(1);
161 input_dims.w = input_shape.Dims(2);
162 input_dims.c = input_shape.Dims(3);
163
164 cmsis_nn_dims filter_dims;
165 filter_dims.n = filter_shape.Dims(0);
166 filter_dims.h = filter_shape.Dims(1);
167 filter_dims.w = filter_shape.Dims(2);
168 filter_dims.c = output_depth;
169
170 cmsis_nn_dims output_dims;
171 output_dims.n = batch_size;
172 output_dims.h = output_shape.Dims(1);
173 output_dims.w = output_shape.Dims(2);
174 output_dims.c = output_depth;
175
176 const int32_t buf_size = arm_depthwise_conv_wrapper_s8_get_buffer_size(
177 &dw_conv_params, &input_dims, &filter_dims, &output_dims);
178
179 auto data_type_size = static_cast<int32_t>(luci_interpreter::getDataTypeSize(input_data_type));
180
181 luci_interpreter::Shape scratchpad_shape{buf_size * data_type_size};
182 scratchpad->resize(scratchpad_shape);
183 }
184 else
185 {
186 scratchpad->set_allocatable(false);
187 }
188}
189
190} // namespace luci_interpreter_pal
191
192#endif // LUCI_INTERPRETER_PAL_DEPTHWISECONV2D_H
void set_allocatable(bool value)
Definition Tensor.h:168
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const luci_interpreter::RuntimeShape output_shape
void DepthwiseConvPerChannel< int8_t >(const tflite::DepthwiseParams &params, const int32_t *output_multiplier, const int32_t *output_shift, const tflite::RuntimeShape &input_shape, const int8_t *input_data, const tflite::RuntimeShape &filter_shape, const int8_t *filter_data, const tflite::RuntimeShape &bias_shape, const int32_t *bias_data, const tflite::RuntimeShape &output_shape, int8_t *output_data, const tflite::RuntimeShape &scratchpad_shape, int8_t *scratchpad_data)
size_t getDataTypeSize(DataType data_type)
Definition DataType.h:33
DataType
"scalar" value type
Definition DataType.h:32
void DepthwiseConvPerChannel(const DepthwiseConvParams &params, const int32_t *output_multiplier, const int32_t *output_shift, const Shape &input_shape, const int8_t *input_data, const Shape &filter_shape, const int8_t *filter_data, const Shape &bias_shape, const int32_t *bias_data, const Shape &output_shape, int8_t *output_data, ruy::Context *ruy_context)