ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALConv2d.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LUCI_INTERPRETER_PAL_CONV2D_H
18#define LUCI_INTERPRETER_PAL_CONV2D_H
19
20#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
21#include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
22
23namespace luci_interpreter_pal
24{
25static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
26 const float *input_data, const tflite::RuntimeShape &filter_shape,
27 const float *filter_data, const tflite::RuntimeShape &bias_shape,
28 const float *bias_data, const tflite::RuntimeShape &output_shape,
29 float *output_data, const tflite::RuntimeShape &scratchpad_shape,
30 float *scratchpad_data)
31{
32 (void)scratchpad_shape;
33
34 const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
35 const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
36 const int32_t output_height = output_shape.Dims(1);
37 const int32_t output_width = output_shape.Dims(2);
38 const int32_t filter_height = filter_shape.Dims(1);
39 const int32_t filter_width = filter_shape.Dims(2);
40
41 int64_t im2col_flat_size = 1;
42 im2col_flat_size *= batches;
43 im2col_flat_size *= output_height;
44 im2col_flat_size *= output_width;
45 im2col_flat_size *= input_depth;
46 im2col_flat_size *= filter_height;
47 im2col_flat_size *= filter_width;
48
49 // This condition checks if integer overflow will occur inside the optimized kernel.
50 // https://github.com/tensorflow/tensorflow/blob/v2.12.1/tensorflow/lite/kernels/internal/optimized/im2col_utils.h#L81
51 // If overflow is expected, we fall back to the reference kernel.
52 // NOTE This is just a rough check.
53 bool opt_kernel_overflow = im2col_flat_size > std::numeric_limits<int32_t>::max();
54
55 if (scratchpad_data and not opt_kernel_overflow)
56 {
57 tflite::RuntimeShape im2col_shape{batches, output_height, output_width,
58 input_depth * filter_height * filter_width};
59
60 tflite::optimized_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
61 bias_shape, bias_data, output_shape, output_data, im2col_shape,
62 scratchpad_data);
63 }
64 else
65 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
66 bias_shape, bias_data, output_shape, output_data,
67 tflite::RuntimeShape(), nullptr);
68}
69
70static inline void Conv(const tflite::ConvParams &params, const tflite::RuntimeShape &input_shape,
71 const uint8 *input_data, const tflite::RuntimeShape &filter_shape,
72 const uint8 *filter_data, const tflite::RuntimeShape &bias_shape,
73 const int32 *bias_data, const tflite::RuntimeShape &output_shape,
74 uint8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
75 uint8 *scratchpad_data)
76{
77 // TODO This should only be done once (although it takes only a few microseconds).
78 // Also, the user should be able to adjust the number of threads.
79 auto gemmlowp_context = std::make_unique<gemmlowp::GemmContext>();
80 gemmlowp_context->set_max_num_threads(static_cast<int>(std::thread::hardware_concurrency()));
81
82 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
83 bias_shape, bias_data, output_shape, output_data, scratchpad_shape,
84 scratchpad_data, gemmlowp_context.get());
85}
86
87static inline void ConvPerChannel(const tflite::ConvParams &params, const int32_t *mult,
88 const int32_t *shifts, const tflite::RuntimeShape &input_shape,
89 const int8 *input_data, const tflite::RuntimeShape &filter_shape,
90 const int8 *filter_data, const tflite::RuntimeShape &bias_shape,
91 const int32 *bias_data, const tflite::RuntimeShape &output_shape,
92 int8 *output_data, const tflite::RuntimeShape &scratchpad_shape,
93 int8 *scratchpad_data)
94{
95 (void)scratchpad_shape;
96 (void)scratchpad_data;
97 // TODO enable optimized version
98 tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
99 filter_shape, filter_data, bias_shape, bias_data,
100 output_shape, output_data);
101}
102
103static inline void SetupScratchpadTensor(luci_interpreter::Tensor *scratchpad,
104 const luci_interpreter::DataType &input_data_type,
105 const tflite::ConvParams &params,
106 const tflite::RuntimeShape &input_shape,
107 const tflite::RuntimeShape &filter_shape,
108 const tflite::RuntimeShape &output_shape)
109{
110 const int32_t filter_height = filter_shape.Dims(1);
111 const int32_t filter_width = filter_shape.Dims(2);
112
113 // Allocate tensor for scratchpad, if needed.
114 // The checks here should be aligned with the actual implementation.
115 const bool need_dilated_scratchpad =
116 params.dilation_height_factor != 1 || params.dilation_width_factor != 1;
117 const bool need_non_dilated_scratchpad = params.stride_height != 1 || params.stride_width != 1 ||
118 filter_height != 1 || filter_width != 1;
119 auto _need_scratchpad = input_data_type != luci_interpreter::DataType::S16 &&
120 (need_dilated_scratchpad || need_non_dilated_scratchpad);
121
122 if (_need_scratchpad)
123 {
124 const int32_t batches = tflite::MatchingDim(input_shape, 0, output_shape, 0);
125 const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
126 const int32_t output_height = output_shape.Dims(1);
127 const int32_t output_width = output_shape.Dims(2);
128
129 auto data_type_size = static_cast<int32_t>(luci_interpreter::getDataTypeSize(input_data_type));
130 // im2col_shape
131 // data_type_size is added because we use U8 for scratchpad buffer dtype
132 luci_interpreter::Shape scratchpad_shape{batches, output_height, output_width,
133 input_depth * filter_height * filter_width,
134 data_type_size};
135 scratchpad->resize(scratchpad_shape);
136 }
137 else
138 {
139 scratchpad->set_allocatable(false);
140 }
141}
142
143} // namespace luci_interpreter_pal
144
145#endif // LUCI_INTERPRETER_PAL_CONV2D_H
std::uint8_t uint8
Definition Macro.h:52
void set_allocatable(bool value)
Definition Tensor.h:168
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const luci_interpreter::RuntimeShape output_shape
size_t getDataTypeSize(DataType data_type)
Definition DataType.h:33
DataType
"scalar" value type
Definition DataType.h:32
void Conv(const ConvParams &params, const Shape &input_shape, const uint8_t *input_data, const Shape &filter_shape, const uint8_t *filter_data, const Shape &bias_shape, const int32_t *bias_data, const Shape &output_shape, uint8_t *output_data, const Shape &im2col_shape, uint8_t *im2col_data)
Definition Conv.h:83
int32_t int32
Definition topk_v2.h:27