ONE - On-device Neural Engine
Conv.h
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __NNFW_RUY_CONV_H__
#define __NNFW_RUY_CONV_H__

#include "ruy/Types.h"
#include "ruy/Shape.h"
#include "ruy/Utils.h"
#include "ruy/RuySupport.h"

#include <ruy/ruy.h>
#include <ruy/context.h>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

namespace nnfw
{
namespace ruy
{

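// Float 2D convolution kernel. When the shapes require it, the input is first
// expanded with an (optionally dilated) im2col transform, and the convolution is
// then executed as a single GEMM through ruy. prepare() may be called ahead of
// time for static shapes so that the im2col decision and buffer shape are
// computed only once.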
class Conv
{
public:
  Conv() : _im2col_shape(4), _need_im2col(false), _prepared(false) {}

  void prepare(const Shape &input_shape, const Shape &kernel_shape, const Shape &output_shape,
               uint32_t stride_width, uint32_t stride_height, uint32_t dilation_width_factor,
               uint32_t dilation_height_factor)
  {
    if (!_prepared)
    {
      IsRequiredIm2col(input_shape, kernel_shape, output_shape, stride_width, stride_height,
                       dilation_width_factor, dilation_height_factor);
      _prepared = true;
    }
  }

  void operator()(const ConvParams &params, const Shape &input_shape, const float *input_data,
                  const Shape &filter_shape, const float *filter_data, const Shape &bias_shape,
                  const float *bias_data, const Shape &output_shape, float *output_data,
                  ::ruy::Context *ruy_context)
  {
    if (!_prepared)
    {
      // This means that the input or output is dynamic, or the filter is not constant
      IsRequiredIm2col(input_shape, filter_shape, output_shape, params.stride_width,
                       params.stride_height, params.dilation_width_factor,
                       params.dilation_height_factor);
      _prepared = true;
    }

    int im2col_size = _need_im2col ? _im2col_shape.FlatSize() : 0;
    std::vector<float> im2col_data(im2col_size);
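    // im2col_data is per-call scratch memory; it stays empty and is passed to
    // ConvFloat as nullptr when no im2col transform is needed.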
    if (im2col_size > 0)
    {
      ConvFloat(params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data,
                output_shape, output_data, _im2col_shape, im2col_data.data(), ruy_context);
    }
    else
    {
      ConvFloat(params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data,
                output_shape, output_data, _im2col_shape, nullptr, ruy_context);
    }
  }

private:
  void ConvFloat(const ConvParams &params, const Shape &input_shape, const float *input_data,
                 const Shape &filter_shape, const float *filter_data,
                 [[maybe_unused]] const Shape &bias_shape, const float *bias_data,
                 const Shape &output_shape, float *output_data, const Shape &im2col_shape,
                 float *im2col_data, ::ruy::Context *ruy_context)
  {
    const int stride_width = params.stride_width;
    const int stride_height = params.stride_height;
    const int dilation_width_factor = params.dilation_width_factor;
    const int dilation_height_factor = params.dilation_height_factor;
    const float output_activation_min = params.float_activation_min;
    const float output_activation_max = params.float_activation_max;
    assert(input_shape.DimensionsCount() == 4);
    assert(filter_shape.DimensionsCount() == 4);
    assert(output_shape.DimensionsCount() == 4);

    // NB: the float 0.0f value is represented by all zero bytes.
    const uint8_t float_zero_byte = 0x00;
    const float *gemm_input_data = nullptr;
    const Shape *gemm_input_shape = nullptr;
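    // Shapes follow the TFLite convention: input and output tensors are NHWC and
    // the filter is laid out as [out_channels, filter_height, filter_width, in_channels].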
    const int filter_width = filter_shape.Dims(2);
    const int filter_height = filter_shape.Dims(1);
    const bool need_dilated_im2col = dilation_width_factor != 1 || dilation_height_factor != 1;
    const bool need_im2col =
      stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1;
    if (need_dilated_im2col)
    {
      DilatedIm2col(params, float_zero_byte, input_shape, input_data, filter_shape, output_shape,
                    im2col_data);
      gemm_input_data = im2col_data;
      gemm_input_shape = &im2col_shape;
    }
    else if (need_im2col)
    {
      assert(im2col_data);
      Im2col(params, filter_height, filter_width, float_zero_byte, input_shape, input_data,
             im2col_shape, im2col_data);
      gemm_input_data = im2col_data;
      gemm_input_shape = &im2col_shape;
    }
    else
    {
      // TODO(aselle): We need to make sure to not send im2col if it is not
      // needed.
      assert(!im2col_data);
      gemm_input_data = input_data;
      gemm_input_shape = &input_shape;
    }

    const int gemm_input_dims = gemm_input_shape->DimensionsCount();
    int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
    int n = output_shape.Dims(3);
    int k = gemm_input_shape->Dims(gemm_input_dims - 1);
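    // GEMM view of the convolution: m is the number of output positions
    // (batch * out_height * out_width), n is the number of output channels, and
    // k is the size of one flattened filter patch (in_channels * filter_height *
    // filter_width, or just in_channels when no im2col is needed).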

    // Express the convolution as a single GEMM: lhs is the filter, viewed as a
    // row-major n x k matrix, and rhs is the (possibly im2col'ed) input, viewed
    // as a column-major k x m matrix.
    MatrixParams<float> lhs_params;
    lhs_params.order = Order::kRowMajor;
    lhs_params.rows = n;
    lhs_params.cols = k;
    MatrixParams<float> rhs_params;
    rhs_params.order = Order::kColMajor;
    rhs_params.rows = k;
    rhs_params.cols = m;
    MatrixParams<float> dst_params;
    dst_params.order = Order::kColMajor;
    dst_params.rows = n;
    dst_params.cols = m;
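    // dst is declared as a column-major n x m matrix, which has exactly the same
    // memory layout as the row-major NHWC output tensor viewed as [m, n].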
    GemmParams<float, float> gemm_params;
    gemm_params.bias = bias_data;
    gemm_params.clamp_min = output_activation_min;
    gemm_params.clamp_max = output_activation_max;

    // Below code is from tflite::cpu_backend_gemm::detail::GemmImplUsingRuy
    ::ruy::Matrix<float> ruy_lhs;
    ::ruy::Matrix<float> ruy_rhs;
    ::ruy::Matrix<float> ruy_dst;
    // Note that cache is always enabled for input and weight tensors
    ruy_support::MakeRuyMatrix(lhs_params, filter_data, &ruy_lhs, true);
    ruy_support::MakeRuyMatrix(rhs_params, gemm_input_data, &ruy_rhs, true);
    ruy_support::MakeRuyMatrix(dst_params, output_data, &ruy_dst);

    ::ruy::MulParams<float, float> ruy_mul_params;
    ruy_support::MakeRuyMulParams(gemm_params, &ruy_mul_params);

    ::ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, ruy_context, &ruy_dst);
  }

  void IsRequiredIm2col(const Shape &input_shape, const Shape &kernel_shape,
                        const Shape &output_shape, uint32_t stride_width, uint32_t stride_height,
                        uint32_t dilation_width_factor, uint32_t dilation_height_factor)
  {
    const bool need_dilated_im2col = dilation_width_factor != 1 || dilation_height_factor != 1;
    const bool need_non_dilated_im2col = stride_width != 1 || stride_height != 1 ||
                                         kernel_shape.Dims(1) != 1 || kernel_shape.Dims(2) != 1;

    _need_im2col = need_dilated_im2col || need_non_dilated_im2col;

    if (_need_im2col)
    {
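      // The im2col buffer stores one flattened filter patch
      // (in_channels * kernel_height * kernel_width floats) per output position.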
      _im2col_shape.SetDim(0, output_shape.Dims(0));
      _im2col_shape.SetDim(1, output_shape.Dims(1));
      _im2col_shape.SetDim(2, output_shape.Dims(2));
      _im2col_shape.SetDim(3, input_shape.Dims(3) * kernel_shape.Dims(1) * kernel_shape.Dims(2));
    }
  }

private:
  Shape _im2col_shape;
  bool _need_im2col;
  bool _prepared;
};
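
// Usage sketch (hypothetical caller: every name other than Conv and its methods
// is assumed to be provided by the surrounding runtime):
//
//   nnfw::ruy::Conv conv;
//   conv.prepare(input_shape, filter_shape, output_shape, stride_w, stride_h,
//                dilation_w, dilation_h); // optional, when shapes are static
//   conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
//        bias_data, output_shape, output_data, &ruy_context);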
} // namespace ruy
} // namespace nnfw

#endif // __NNFW_RUY_CONV_H__