ONE - On-device Neural Engine
Loading...
Searching...
No Matches
GatherND.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2021 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "OMStatus.h"
19
20#include "core/OMUtils.h"
23#include "PALGatherND.h"
24
25using namespace onert_micro;
26using namespace onert_micro::core;
27using namespace onert_micro::execute;
28
namespace
{

// Slot numbers of the GatherND operands within the kernel's input/output tensor arrays.
constexpr uint32_t inputTensorIdx{0};     // input: data tensor values are gathered from
constexpr uint32_t positionsTensorIdx{1}; // input: positions (indices) tensor
constexpr uint32_t outputTensorIdx{0};    // output: gathered result tensor

} // namespace
37
38namespace onert_micro
39{
40namespace import
41{
42
44{
45
46 OMRuntimeContext &runtime_context = config_args.runtime_context;
47 uint16_t op_index = config_args.kernel_index;
48
50
51 OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
52 if (status != Ok)
53 return status;
54
55 const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
56 const circle::Tensor *positions = runtime_kernel.inputs[positionsTensorIdx];
57 const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];
58
59 assert(input != nullptr);
60 assert(positions != nullptr);
61 assert(output != nullptr);
62
63 status = utils::checkCondition(input->type() == output->type());
64 if (status != Ok)
65 return status;
66
67 status = utils::checkCondition(positions->type() == circle::TensorType_INT32);
68 if (status != Ok)
69 return status;
70
71 auto input_type = input->type();
72 status = utils::checkCondition(input_type == circle::TensorType_FLOAT32);
73 if (status != Ok)
74 return status;
75
76 core::OMRuntimeShape input_shape(input);
77 core::OMRuntimeShape positions_shape(positions);
78
79 int32_t shape_num_dims = input_shape.dimensionsCount();
80
81 status = utils::checkCondition(shape_num_dims >= 1);
82 if (status != Ok)
83 return status;
84
85 int32_t positions_num_dims = positions_shape.dimensionsCount();
86 int32_t positions_num_dims_nd = positions_shape.dims(positions_num_dims - 1);
87
88 status = utils::checkCondition(positions_num_dims >= 1);
89 if (status != Ok)
90 return status;
91
92 status = utils::checkCondition(positions_num_dims_nd <= shape_num_dims);
93 if (status != Ok)
94 return status;
95
96 status =
97 utils::checkCondition(positions_num_dims_nd <= onert_micro::execute::pal::MAX_INDICES_ND);
98 if (status != Ok)
99 return status;
100
101 return Ok;
102}
103
104} // namespace import
105} // namespace onert_micro
size_t dimensionsCount() const noexcept
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
constexpr uint32_t outputTensorIdx
constexpr int MAX_INDICES_ND
Definition PALGatherND.h:32
OMStatus configure_kernel_CircleGatherND(const OMConfigureArgs &config_args)
Definition GatherND.cpp:43
core::OMRuntimeContext & runtime_context