ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALGatherND.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_PAL_GATHER_ND_COMMON_H
19#define LUCI_INTERPRETER_PAL_GATHER_ND_COMMON_H
20
21#include "PALUtils.h"
22#include <cmath>
23
25{
26
// Number of entries in the fixed-size local array that GatherND uses to hold
// one element count per indexed dimension of the params tensor.
27constexpr int MAX_INDICES_ND = 5;
28
29template <typename ParamsT, typename IndicesT>
30inline void GatherND(luci_interpreter::RuntimeShape params_shape, const ParamsT *param_data,
31 luci_interpreter::RuntimeShape indices_shape, const IndicesT *index_data,
32 ParamsT *output_data)
33{
34 const int indices_dims = indices_shape.dimensionsCount();
35 const int indices_nd = indices_shape.dims(indices_dims - 1);
36 const int params_dims = params_shape.dimensionsCount();
37
38 int n_slices = 1;
39 for (int i = 0; i < indices_dims - 1; ++i)
40 {
41 n_slices *= indices_shape.dims(i);
42 }
43
44 // If indices[-1] == params.rank, fetch single elements.
45 // If indices[-1] < params.rank, fetch slices.
46 int slice_size = 1;
47 for (int i = indices_nd; i < params_dims; ++i)
48 {
49 slice_size *= params_shape.dims(i);
50 }
51
52 int params_flat_size = params_shape.flatSize();
53 int remain_flat_size = params_flat_size;
54
55 // Number of elements per dimension
56 int dims_to_count[MAX_INDICES_ND];
57 for (int i = 0; i < indices_nd; ++i)
58 {
59 dims_to_count[i] = remain_flat_size / params_shape.dims(i);
60 remain_flat_size = dims_to_count[i];
61 }
62
63 for (int i = 0; i < n_slices; ++i)
64 {
65 int from_pos = 0;
66 for (int j = 0; j < indices_nd; ++j)
67 {
68 int offset = i * indices_nd + j;
69 IndicesT index = index_data[offset];
70 from_pos += index * dims_to_count[j];
71 }
72 if (from_pos < 0 || from_pos + slice_size > params_flat_size)
73 {
74 assert(false && "GatherND error");
75 return;
76 }
77 std::memcpy(output_data + i * slice_size, param_data + from_pos, sizeof(ParamsT) * slice_size);
78 }
79}
80
81} // namespace luci_interpreter_pal
82
83#endif // LUCI_INTERPRETER_PAL_GATHER_ND_COMMON_H
int32_t dimensionsCount() const
Definition Tensor.h:106
int32_t dims(int i) const
Definition Tensor.h:108
int offset(const int32_t *dims_data, int i0, int i1, int i2, int i3)
Definition PALUtils.h:193
constexpr int MAX_INDICES_ND
Definition PALGatherND.h:27
void GatherND(luci_interpreter::RuntimeShape params_shape, const ParamsT *param_data, luci_interpreter::RuntimeShape indices_shape, const IndicesT *index_data, ParamsT *output_data)
Definition PALGatherND.h:30