ONE - On-device Neural Engine
Loading...
Searching...
No Matches
SpaceToBatchND.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_SPACE_TO_BATCH_ND_H__
19#define __NNFW_CKER_SPACE_TO_BATCH_ND_H__
20
21#include "cker/Shape.h"
22#include "cker/Types.h"
23
24namespace nnfw
25{
26namespace cker
27{
28
29template <typename T>
30inline void
31SpaceToBatchND(const SpaceToBatchParams &params, const Shape &unextended_input_shape,
32 const T *input_data, [[maybe_unused]] const Shape &unextended_block_shape_shape,
33 const int32_t *block_shape_data,
34 [[maybe_unused]] const Shape &unextended_padding_shape, const int32_t *paddings_data,
35 const Shape &unextended_output_shape, T *output_data)
36{
37 assert(unextended_input_shape.DimensionsCount() <= 4);
38 assert(unextended_output_shape.DimensionsCount() <= 4);
39 const Shape input_shape = Shape::ExtendedShape(4, unextended_input_shape);
40 const Shape output_shape = Shape::ExtendedShape(4, unextended_output_shape);
41
42 const int depth = input_shape.Dims(3);
43 const int input_width = input_shape.Dims(2);
44 const int input_height = input_shape.Dims(1);
45 const int input_batch_size = input_shape.Dims(0);
46
47 const int output_width = output_shape.Dims(2);
48 const int output_height = output_shape.Dims(1);
49 const int output_batch_size = output_shape.Dims(0);
50
51 const int block_shape_height = block_shape_data[0];
52 const int block_shape_width = block_shape_data[1];
53 const int padding_top = paddings_data[0];
54 const int padding_left = paddings_data[2];
55
56 // For uint8 quantized, the correct padding "zero value" is the output offset.
57 const int32_t pad_value = params.output_offset;
58
59 for (int out_b = 0; out_b < output_batch_size; ++out_b)
60 {
61 int input_batch = out_b % input_batch_size;
62 int shift_w = (out_b / input_batch_size) % block_shape_width;
63 int shift_h = (out_b / input_batch_size) / block_shape_width;
64 for (int out_h = 0; out_h < output_height; ++out_h)
65 {
66 for (int out_w = 0; out_w < output_width; ++out_w)
67 {
68 T *out = output_data + Offset(output_shape, out_b, out_h, out_w, 0);
69 if (out_h * block_shape_height + shift_h < padding_top ||
70 out_h * block_shape_height + shift_h >= padding_top + input_height ||
71 out_w * block_shape_width + shift_w < padding_left ||
72 out_w * block_shape_width + shift_w >= padding_left + input_width)
73 {
74 // This may not execute correctly when pad_value != 0 and T != uint8.
75 memset(out, pad_value, depth * sizeof(T));
76 }
77 else
78 {
79 const T *in =
80 input_data + Offset(input_shape, input_batch,
81 (out_h * block_shape_height + shift_h) - padding_top,
82 (out_w * block_shape_width + shift_w) - padding_left, 0);
83 memcpy(out, in, depth * sizeof(T));
84 }
85 }
86 }
87 }
88}
89
90} // namespace cker
91} // namespace nnfw
92
93#endif // __NNFW_CKER_SPACE_TO_BATCH_ND_H__
int32_t DimensionsCount() const
Definition Shape.h:91
int32_t Dims(int i) const
Definition Shape.h:92
const luci_interpreter::RuntimeShape output_shape
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
Definition Shape.h:237
void SpaceToBatchND(const SpaceToBatchParams &params, const Shape &unextended_input_shape, const T *input_data, const Shape &unextended_block_shape_shape, const int32_t *block_shape_data, const Shape &unextended_padding_shape, const int32_t *paddings_data, const Shape &unextended_output_shape, T *output_data)
Definition topk_v2.h:30