ONE - On-device Neural Engine
Loading...
Searching...
No Matches
GatherND.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2021 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
#include "OMStatus.h"

#include "core/OMUtils.h"
#include "execute/OMRuntimeKernel.h"
#include "PALGatherND.h"
24
25using namespace onert_micro;
26using namespace onert_micro::core;
27using namespace onert_micro::execute;
28
namespace
{

// Slot of the data tensor in the kernel's input list.
constexpr uint32_t inputTensorIdx{0};
// Slot of the positions tensor in the kernel's input list.
constexpr uint32_t positionsTensorIdx{1};
// Slot of the result tensor in the kernel's output list.
constexpr uint32_t outputTensorIdx{0};

} // namespace
37
38OMStatus onert_micro::import::configure_kernel_CircleGatherND(const OMConfigureArgs &config_args)
39{
40
41 OMRuntimeContext &runtime_context = config_args.runtime_context;
42 uint16_t op_index = config_args.kernel_index;
43
45
46 OMStatus status = runtime_kernel.readKernel(op_index, runtime_context);
47 if (status != Ok)
48 return status;
49
50 const circle::Tensor *input = runtime_kernel.inputs[inputTensorIdx];
51 const circle::Tensor *positions = runtime_kernel.inputs[positionsTensorIdx];
52 const circle::Tensor *output = runtime_kernel.outputs[outputTensorIdx];
53
54 assert(input != nullptr);
55 assert(positions != nullptr);
56 assert(output != nullptr);
57
58 status = utils::checkCondition(input->type() == output->type());
59 if (status != Ok)
60 return status;
61
62 status = utils::checkCondition(positions->type() == circle::TensorType_INT32);
63 if (status != Ok)
64 return status;
65
66 auto input_type = input->type();
67 status = utils::checkCondition(input_type == circle::TensorType_FLOAT32);
68 if (status != Ok)
69 return status;
70
71 core::OMRuntimeShape input_shape(input);
72 core::OMRuntimeShape positions_shape(positions);
73
74 int32_t shape_num_dims = input_shape.dimensionsCount();
75
76 status = utils::checkCondition(shape_num_dims >= 1);
77 if (status != Ok)
78 return status;
79
80 int32_t positions_num_dims = positions_shape.dimensionsCount();
81 int32_t positions_num_dims_nd = positions_shape.dims(positions_num_dims - 1);
82
83 status = utils::checkCondition(positions_num_dims >= 1);
84 if (status != Ok)
85 return status;
86
87 status = utils::checkCondition(positions_num_dims_nd <= shape_num_dims);
88 if (status != Ok)
89 return status;
90
91 status =
92 utils::checkCondition(positions_num_dims_nd <= onert_micro::execute::pal::MAX_INDICES_ND);
93 if (status != Ok)
94 return status;
95
96 return Ok;
97}
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
constexpr uint32_t outputTensorIdx
constexpr int MAX_INDICES_ND
Definition PALGatherND.h:32