ONE - On-device Neural Engine
Loading...
Searching...
No Matches
DepthwiseConv.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __NNFW_CKER_TRAIN_OPERATION_DEPTHWISECONV_H__
18#define __NNFW_CKER_TRAIN_OPERATION_DEPTHWISECONV_H__
19
21#include "cker/Shape.h"
22#include "cker/Types.h"
23
24namespace nnfw
25{
26namespace cker
27{
28namespace train
29{
30
/**
 * @brief Backward step of depthwise convolution with respect to its input (per the
 *        function name). Validates @p params, derives sizes from the shapes, and forwards
 *        them, together with every data/buffer pointer, to a kernel call.
 *
 * @param params         Only params.stride_height is forwarded; stride_width must equal it.
 *                       Both dilation factors must be 1. depth_multiplier and
 *                       padding_values.height/width are forwarded as-is.
 * @param incoming_shape Dims(1)/Dims(2)/Dims(3) are read as height/width/depth; Dims(0) is
 *                       the batch. Presumably the gradient w.r.t. the conv output (TODO confirm).
 * @param incoming_data  Data for @p incoming_shape, forwarded to the kernel.
 * @param filter_shape   Dims(1)/Dims(2) are read as filter height/width.
 * @param filter_data    Filter values, forwarded to the kernel.
 * @param padded_filter_data  Forwarded unchanged; presumably caller-provided scratch (confirm size).
 * @param grad_shape     Dims(1)/Dims(2)/Dims(3) are read as height/width/input depth.
 * @param grad_data      Forwarded as a mutable pointer; presumably receives the input gradient.
 * @param pad_filter     Forwarded unchanged to the kernel.
 * @param filter_buffers_data      Forwarded unchanged; presumably scratch (confirm size).
 * @param filter_dim_buffers_data  Forwarded unchanged; presumably scratch (confirm size).
 *
 * @throws std::runtime_error if stride_height != stride_width, or if either dilation
 *         factor is not 1.
 *
 * NOTE(review): std::runtime_error is declared in <stdexcept>, which the visible
 * includes do not pull in directly; this relies on a transitive include.
 */
31template <typename T>
32void backpropInput(const DepthwiseConvParams &params, const Shape &incoming_shape,
33 const T *incoming_data, const Shape &filter_shape, const T *filter_data,
34 T *padded_filter_data, const Shape &grad_shape, T *grad_data, bool pad_filter,
35 T *filter_buffers_data, T *filter_dim_buffers_data)
36{
  // Only a single stride value (params.stride_height) is forwarded below, so the
  // height and width strides must agree.
37 if (params.stride_height != params.stride_width)
38 throw std::runtime_error("Not support different length strides");
39
  // Dilation factors are not passed to the kernel call at all, so only 1 is accepted.
40 if (params.dilation_height_factor != 1 || params.dilation_width_factor != 1)
41 throw std::runtime_error{"Not support dilation other than 1."};
42
  // MatchingDim (defined in cker/Shape.h, not shown) presumably checks that dim 0 of
  // both shapes agrees and returns it.
43 const int batch = MatchingDim(incoming_shape, 0, grad_shape, 0);
44 const int input_depth = grad_shape.Dims(3);
45 const int output_depth = incoming_shape.Dims(3);
46 const int incoming_height = incoming_shape.Dims(1);
47 const int incoming_width = incoming_shape.Dims(2);
48 const int grad_height = grad_shape.Dims(1);
49 const int grad_width = grad_shape.Dims(2);
50 const int stride = params.stride_height;
51 const int depth_multiplier = params.depth_multiplier;
52 const int filter_height = filter_shape.Dims(1);
53 const int filter_width = filter_shape.Dims(2);
54 const int pad_height = params.padding_values.height;
55 const int pad_width = params.padding_values.width;
  // NOTE(review): nothing here checks output_depth == input_depth * depth_multiplier
  // (presumably expected for depthwise conv); confirm the kernel validates it.
56
  // NOTE(review): listing line 57, which opened the backprop-input kernel call, is
  // missing from this extracted text (numbering jumps 56 -> 58). The lines below are
  // that call's argument list; restore the call head from the original header.
58 batch, grad_height, grad_width, input_depth, filter_height, filter_width, depth_multiplier,
59 stride, pad_height, pad_width, incoming_height, incoming_width, output_depth, incoming_data,
60 filter_data, padded_filter_data, grad_data, pad_filter, filter_buffers_data,
61 filter_dim_buffers_data);
62}
63
/**
 * @brief Backward step of depthwise convolution with respect to its filter (per the
 *        function name). Validates @p params, derives sizes from the shapes, and forwards
 *        them, together with every data/buffer pointer, to a kernel call.
 *
 * @param params            Only params.stride_height is forwarded; stride_width must equal it.
 *                          Both dilation factors must be 1. depth_multiplier and
 *                          padding_values.height/width are forwarded as-is.
 * @param incoming_shape    Dims(1)/Dims(2)/Dims(3) are read as height/width/depth; Dims(0) is
 *                          the batch. Presumably the gradient w.r.t. the conv output (TODO confirm).
 * @param incoming_data     Data for @p incoming_shape, forwarded to the kernel.
 * @param input_shape       Dims(1)/Dims(2)/Dims(3) are read as height/width/input depth.
 * @param input_data        Forward-pass input values (presumably), forwarded to the kernel.
 * @param filter_grad_shape Dims(1)/Dims(2) are read as filter height/width.
 * @param filter_grad_data  Forwarded as a mutable pointer; presumably receives the filter gradient.
 * @param padded_filter_data   Forwarded unchanged; presumably caller-provided scratch (confirm size).
 * @param filter_buffers_data  Forwarded unchanged; presumably scratch (confirm size).
 *
 * @throws std::runtime_error if stride_height != stride_width, or if either dilation
 *         factor is not 1.
 *
 * NOTE(review): std::runtime_error is declared in <stdexcept>, which the visible
 * includes do not pull in directly; this relies on a transitive include.
 */
64template <typename T>
65void backpropFilter(const DepthwiseConvParams &params, const Shape &incoming_shape,
66 const T *incoming_data, const Shape &input_shape, const T *input_data,
67 const Shape &filter_grad_shape, T *filter_grad_data, T *padded_filter_data,
68 T *filter_buffers_data)
69{
  // Only a single stride value (params.stride_height) is forwarded below, so the
  // height and width strides must agree.
70 if (params.stride_height != params.stride_width)
71 throw std::runtime_error("Not support different length strides");
72
  // Dilation factors are not passed to the kernel call at all, so only 1 is accepted.
73 if (params.dilation_height_factor != 1 || params.dilation_width_factor != 1)
74 throw std::runtime_error{"Not support dilation other than 1."};
75
  // MatchingDim (defined in cker/Shape.h, not shown) presumably checks that dim 0 of
  // both shapes agrees and returns it.
76 const int batch = MatchingDim(incoming_shape, 0, input_shape, 0);
77 const int input_depth = input_shape.Dims(3);
78 const int output_depth = incoming_shape.Dims(3);
79 const int incoming_height = incoming_shape.Dims(1);
80 const int incoming_width = incoming_shape.Dims(2);
81 const int input_height = input_shape.Dims(1);
82 const int input_width = input_shape.Dims(2);
83 const int stride = params.stride_height;
84 const int depth_multiplier = params.depth_multiplier;
85 const int filter_height = filter_grad_shape.Dims(1);
86 const int filter_width = filter_grad_shape.Dims(2);
87 const int pad_height = params.padding_values.height;
88 const int pad_width = params.padding_values.width;
  // NOTE(review): nothing here checks output_depth == input_depth * depth_multiplier
  // (presumably expected for depthwise conv); confirm the kernel validates it.
89
  // NOTE(review): listing line 90, which opened the backprop-filter kernel call, is
  // missing from this extracted text (numbering jumps 89 -> 91). The lines below are
  // that call's argument list; restore the call head from the original header.
91 batch, input_height, input_width, input_depth, filter_height, filter_width, depth_multiplier,
92 stride, pad_height, pad_width, incoming_height, incoming_width, output_depth, incoming_data,
93 input_data, filter_grad_data, padded_filter_data, filter_buffers_data);
94}
95
96} // namespace train
97} // namespace cker
98} // namespace nnfw
99
100#endif // __NNFW_CKER_TRAIN_OPERATION_DEPTHWISECONV_H__
int32_t Dims(int i) const
Definition Shape.h:92
void backpropFilter(const DepthwiseConvParams &params, const Shape &incoming_shape, const T *incoming_data, const Shape &input_shape, const T *input_data, const Shape &filter_grad_shape, T *filter_grad_data, T *padded_filter_data, T *filter_buffers_data)
void backpropInput(const DepthwiseConvParams &params, const Shape &incoming_shape, const T *incoming_data, const Shape &filter_shape, const T *filter_data, T *padded_filter_data, const Shape &grad_shape, T *grad_data, bool pad_filter, T *filter_buffers_data, T *filter_dim_buffers_data)
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
Definition Shape.h:220
Definition topk_v2.h:30
PaddingValues padding_values
Definition Types.h:234