ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Cast.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "OMStatus.h"
18
19#include "core/OMUtils.h"
20#include "core/OMKernelData.h"
21#include "core/OMDataType.h"
22
24#include "execute/OMUtils.h"
26
27#include "PALCast.h"
28
29using namespace onert_micro;
30using namespace onert_micro::execute;
31
namespace
{

// Position of the single input tensor in the CAST operator's input list.
constexpr uint32_t inputTensorIdx{0};
// Position of the single output tensor in the CAST operator's output list.
constexpr uint32_t outputTensorIdx{0};

} // anonymous namespace
39
40// NOTE: doesnt currently support dynamic shapes
41OMStatus onert_micro::execute::execute_kernel_CircleCast(const OMExecuteArgs &execute_args)
42{
43 const circle::Tensor *input = nullptr;
44 const circle::Tensor *output = nullptr;
45
46 uint8_t *input_data = nullptr;
47 uint8_t *output_data = nullptr;
48
49 SISOHeader(execute_args, &input, &output, &input_data, &output_data);
50
51 OMStatus status;
52
53 switch (input->type())
54 {
55#ifndef DIS_FLOAT
56 case circle::TensorType_FLOAT32:
57 {
58 switch (output->type())
59 {
60 case circle::TensorType_INT32:
61 {
62 status = pal::Cast(
63 core::OMRuntimeShape(input), core::utils::castInputData<float>(input_data),
64 core::OMRuntimeShape(output), core::utils::castOutputData<int32_t>(output_data));
65 break;
66 }
67 case circle::TensorType_INT8:
68 {
69 status = pal::Cast(
70 core::OMRuntimeShape(input), core::utils::castInputData<float>(input_data),
71 core::OMRuntimeShape(output), core::utils::castOutputData<int8_t>(output_data));
72 break;
73 }
74 case circle::TensorType_INT16:
75 {
76 status = pal::Cast(
77 core::OMRuntimeShape(input), core::utils::castInputData<float>(input_data),
78 core::OMRuntimeShape(output), core::utils::castOutputData<int16_t>(output_data));
79 break;
80 }
81 default:
82 {
83 status = UnsupportedType;
84 assert(false && "Unsupported type.");
85 break;
86 }
87 }
88 }
89 break;
90#endif // DIS_FLOAT
91 default:
92 {
93 status = UnsupportedType;
94 assert(false && "Unsupported type.");
95 break;
96 }
97 }
98
99 return status;
100}
constexpr uint32_t outputTensorIdx
list input_data
Definition infer.py:29
OMStatus Cast(const core::OMRuntimeShape &input_shape, const FromT *input_data, const core::OMRuntimeShape &output_shape, ToT *output_data)
Definition PALCast.h:34
OMStatus SISOHeader(const OMExecuteArgs &execute_args, const circle::Tensor **input, const circle::Tensor **output, uint8_t **input_data, uint8_t **output_data)
Definition OMUtils.cpp:159
@ UnsupportedType
Definition OMStatus.h:26