ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALGatherND.h
Go to the documentation of this file.
/*
 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#ifndef ONERT_MICRO_EXECUTE_PAL_GATHER_ND_COMMON_H
18#define ONERT_MICRO_EXECUTE_PAL_GATHER_ND_COMMON_H
19
#include "OMStatus.h"
#include "core/OMRuntimeShape.h"

#include <cassert>
#include <cmath>
#include <cstring>
24
25namespace onert_micro
26{
27namespace execute
28{
29namespace pal
30{
31
32constexpr int MAX_INDICES_ND = 5;
33
34template <typename ParamsT, typename IndicesT>
35inline OMStatus GatherND(core::OMRuntimeShape params_shape, const ParamsT *param_data,
36 core::OMRuntimeShape indices_shape, const IndicesT *index_data,
37 ParamsT *output_data)
38{
39 const int indices_dims = indices_shape.dimensionsCount();
40 const int indices_nd = indices_shape.dims(indices_dims - 1);
41 const int params_dims = params_shape.dimensionsCount();
42
43 int n_slices = 1;
44 for (int i = 0; i < indices_dims - 1; ++i)
45 {
46 n_slices *= indices_shape.dims(i);
47 }
48
49 // If indices[-1] == params.rank, fetch single elements.
50 // If indices[-1] < params.rank, fetch slices.
51 int slice_size = 1;
52 for (int i = indices_nd; i < params_dims; ++i)
53 {
54 slice_size *= params_shape.dims(i);
55 }
56
57 int params_flat_size = params_shape.flatSize();
58 int remain_flat_size = params_flat_size;
59
60 // Number of elements per dimension
61 int dims_to_count[MAX_INDICES_ND];
62 for (int i = 0; i < indices_nd; ++i)
63 {
64 dims_to_count[i] = remain_flat_size / params_shape.dims(i);
65 remain_flat_size = dims_to_count[i];
66 }
67
68 for (int i = 0; i < n_slices; ++i)
69 {
70 int from_pos = 0;
71 for (int j = 0; j < indices_nd; ++j)
72 {
73 int offset = i * indices_nd + j;
74 IndicesT index = index_data[offset];
75 from_pos += index * dims_to_count[j];
76 }
77 if (from_pos < 0 || from_pos + slice_size > params_flat_size)
78 {
79 assert(false && "GatherND error");
80 return UnknownError;
81 }
82 std::memcpy(output_data + i * slice_size, param_data + from_pos, sizeof(ParamsT) * slice_size);
83 }
84
85 return Ok;
86}
87
88} // namespace pal
89
90} // namespace execute
91
92} // namespace onert_micro
93
94#endif // ONERT_MICRO_EXECUTE_PAL_GATHER_ND_COMMON_H
OMStatus GatherND(core::OMRuntimeShape params_shape, const ParamsT *param_data, core::OMRuntimeShape indices_shape, const IndicesT *index_data, ParamsT *output_data)
Definition PALGatherND.h:35
constexpr int MAX_INDICES_ND
Definition PALGatherND.h:32
int offset(const int32_t *dims_data, int i0, int i1, int i2, int i3)
Definition PALUtils.h:220