ONE - On-device Neural Engine
Loading...
Searching...
No Matches
SelectV2.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "OMStatus.h"
18
19#include "core/OMUtils.h"
20#include "core/OMRuntimeShape.h"
21
22#include "execute/OMUtils.h"
25#include "PALSelectV2.h"
26
27using namespace onert_micro;
28using namespace onert_micro::execute;
29
30namespace
31{
32
33constexpr int inputCond = 0;
34constexpr int inputX = 1;
35constexpr int inputY = 2;
36constexpr int outputIndex = 0;
37
38template <typename T>
39void CallSelect(const core::OMRuntimeShape &input_cond_shape, const bool *input_cond_data,
40 const core::OMRuntimeShape &input_x_shape, const T *input_x_data,
41 const core::OMRuntimeShape &input_y_shape, const T *input_y_data,
42 const core::OMRuntimeShape &output_shape, T *output_data)
43{
44 using Func = decltype(onert_micro::execute::pal::Select<bool, T>) *;
45 Func select_func;
46 select_func = onert_micro::execute::pal::Select<bool, T>;
47
48 select_func(input_cond_shape, input_cond_data, input_x_shape, input_x_data, input_y_shape,
49 input_y_data, output_shape, output_data);
50}
51
52} // namespace
53
54// NOTE: doesnt currently support dynamic shapes
55namespace onert_micro
56{
57namespace execute
58{
59
61{
62 core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
63 core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
64 uint16_t op_index = execute_args.kernel_index;
65
66 const circle::Tensor *input_cond;
67 const circle::Tensor *input_x;
68 const circle::Tensor *input_y;
69 const circle::Tensor *output;
70
71 uint8_t *input_cond_data;
72 uint8_t *input_x_data;
73 uint8_t *input_y_data;
74 uint8_t *output_data;
75
76 OMStatus status = Ok;
77
78 // Read kernel
79 {
80 execute::OMRuntimeKernel runtime_kernel;
81 runtime_kernel.readKernel(op_index, runtime_context);
82
83 input_cond = runtime_kernel.inputs[inputCond];
84 input_x = runtime_kernel.inputs[inputX];
85 input_y = runtime_kernel.inputs[inputY];
86 output = runtime_kernel.outputs[outputIndex];
87
88 assert(input_cond != nullptr);
89 assert(input_x != nullptr);
90 assert(input_y != nullptr);
91 assert(output != nullptr);
92
93 status = runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);
94 if (status != Ok)
95 return status;
96
97 input_cond_data = runtime_kernel.inputs_data[inputCond];
98 input_x_data = runtime_kernel.inputs_data[inputX];
99 input_y_data = runtime_kernel.inputs_data[inputY];
100 output_data = runtime_kernel.outputs_data[outputIndex];
101
102 assert(input_cond_data != nullptr);
103 assert(input_x_data != nullptr);
104 assert(input_y_data != nullptr);
105 assert(output_data != nullptr);
106 }
107
108 const core::OMRuntimeShape input_cond_shape(input_cond);
109 assert(input_cond_shape.flatSize() > 0);
110 const core::OMRuntimeShape input_x_shape(input_x);
111 const core::OMRuntimeShape input_y_shape(input_y);
113
114 switch (input_x->type())
115 {
116#ifndef DIS_FLOAT
117 case circle::TensorType_FLOAT32:
118 {
119 CallSelect<float>(input_cond_shape, core::utils::castInputData<bool>(input_cond_data),
120 input_x_shape, core::utils::castInputData<float>(input_x_data),
121 input_y_shape, core::utils::castInputData<float>(input_y_data),
122 output_shape, core::utils::castOutputData<float>(output_data));
123 }
124 break;
125#endif
126 default:
127 {
128 status = UnsupportedType;
129 assert(false && "Unsupported type.");
130 }
131 }
132
133 return status;
134}
135
136} // namespace execute
137} // namespace onert_micro
uint8_t * outputs_data[maxOutputSize]
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
const luci_interpreter::RuntimeShape output_shape
OMStatus execute_kernel_CircleSelectV2(const OMExecuteArgs &execute_args)
Definition SelectV2.cpp:60
@ UnsupportedType
Definition OMStatus.h:26
core::OMRuntimeContext & runtime_context
core::OMRuntimeStorage & runtime_storage