ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Select.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_SELECT_H__
19#define __NNFW_CKER_SELECT_H__
20
21#include "cker/Shape.h"
22#include "cker/Utils.h"
23
24#include <cmath>
25
26namespace nnfw
27{
28namespace cker
29{
30
31template <typename D, typename T>
32void Select(const Shape &input_condition_shape, const D *input_condition_data,
33 const Shape &input_x_shape, const T *input_x_data, const Shape &input_y_shape,
34 const T *input_y_data, const Shape &output_shape, T *output_data)
35{
36 const int64_t flatsize =
37 MatchingFlatSize(input_condition_shape, input_x_shape, input_y_shape, output_shape);
38 for (int64_t i = 0; i < flatsize; ++i)
39 {
40 output_data[i] = (input_condition_data[i] != 0) ? input_x_data[i] : input_y_data[i];
41 }
42}
43
44template <typename D, typename T>
45void RankOneSelect(const Shape &input_condition_shape, const D *input_condition_data,
46 const Shape &input_x_shape, const T *input_x_data, const Shape &input_y_shape,
47 const T *input_y_data, const Shape &output_shape, T *output_data)
48{
49 const int64_t outer_size = input_condition_shape.FlatSize();
50 assert(MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0) == outer_size);
51 const int64_t inner_size = MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
52
53 int64_t offset = 0;
54 for (int64_t i = 0; i < outer_size; i++)
55 {
56 const T *input_data = (input_condition_data[i] != 0) ? input_x_data : input_y_data;
57 memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
58 offset += inner_size;
59 }
60}
61
62template <typename D, typename T>
63void BroadcastSelect4DSlow(const Shape &input_condition_shape, const D *input_condition_data,
64 const Shape &input_x_shape, const T *input_x_data,
65 const Shape &input_y_shape, const T *input_y_data,
66 const Shape &output_shape, T *output_data)
67{
68 assert(input_condition_shape.DimensionsCount() <= 4);
69 assert(input_x_shape.DimensionsCount() <= 4);
70 assert(input_y_shape.DimensionsCount() <= 4);
71 assert(output_shape.DimensionsCount() <= 4);
72
73 const Shape extended_output_shape = Shape::ExtendedShape(4, output_shape);
74
75 NdArrayDesc<4> desc_condition;
76 NdArrayDesc<4> desc_x;
77 NdArrayDesc<4> desc_y;
78 NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape, input_y_shape,
79 &desc_condition, &desc_x, &desc_y);
80
81 // In Tensorflow, the dimensions are canonically named (batch_number, row,
82 // col, channel), with extents (batches, height, width, depth), with the
83 // trailing dimension changing most rapidly (channels has the smallest
84 // stride, typically 1 element).
85 //
86 // In generated C code, we store arrays with the dimensions reversed. The
87 // first dimension has smallest stride.
88 //
89 // We name our variables by their Tensorflow convention, but generate C code
90 // nesting loops such that the innermost loop has the smallest stride for
91 // the best cache behavior.
92 for (int b = 0; b < extended_output_shape.Dims(0); ++b)
93 {
94 for (int y = 0; y < extended_output_shape.Dims(1); ++y)
95 {
96 for (int x = 0; x < extended_output_shape.Dims(2); ++x)
97 {
98 for (int c = 0; c < extended_output_shape.Dims(3); ++c)
99 {
100 const int condition_index = SubscriptToIndex(desc_condition, b, y, x, c);
101 const int x_index = SubscriptToIndex(desc_x, b, y, x, c);
102 const int y_index = SubscriptToIndex(desc_y, b, y, x, c);
103 output_data[Offset(extended_output_shape, b, y, x, c)] =
104 input_condition_data[condition_index] ? input_x_data[x_index] : input_y_data[y_index];
105 }
106 }
107 }
108 }
109}
110
111} // namespace cker
112} // namespace nnfw
113
114#endif // __NNFW_CKER_SELECT_H__
int32_t DimensionsCount() const
Definition Shape.h:91
int32_t Dims(int i) const
Definition Shape.h:92
int FlatSize() const
Definition Shape.h:181
__global uchar * offset(const Image *img, int x, int y)
Definition helpers.h:540
const luci_interpreter::RuntimeShape output_shape
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
Definition Shape.h:220
void Select(const Shape &input_condition_shape, const D *input_condition_data, const Shape &input_x_shape, const T *input_x_data, const Shape &input_y_shape, const T *input_y_data, const Shape &output_shape, T *output_data)
Definition Select.h:32
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
Definition Shape.h:237
void RankOneSelect(const Shape &input_condition_shape, const D *input_condition_data, const Shape &input_x_shape, const T *input_x_data, const Shape &input_y_shape, const T *input_y_data, const Shape &output_shape, T *output_data)
Definition Select.h:45
void NdArrayDescsForElementwiseBroadcast(const Shape &input0_shape, const Shape &input1_shape, NdArrayDesc< N > *desc0_out, NdArrayDesc< N > *desc1_out)
Definition Utils.h:290
int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
Definition Shape.h:304
void BroadcastSelect4DSlow(const Shape &input_condition_shape, const D *input_condition_data, const Shape &input_x_shape, const T *input_x_data, const Shape &input_y_shape, const T *input_y_data, const Shape &output_shape, T *output_data)
Definition Select.h:63
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
Definition Shape.h:297
int SubscriptToIndex(const NdArrayDesc< 4 > &desc, int i0, int i1, int i2, int i3)
Definition Utils.h:255
Definition topk_v2.h:30