ONE - On-device Neural Engine
Loading...
Searching...
No Matches
PALSplit.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef ONERT_MICRO_EXECUTE_PAL_SPLIT_COMMON_H
18#define ONERT_MICRO_EXECUTE_PAL_SPLIT_COMMON_H
19
20#include "core/OMRuntimeShape.h"
21#include "core/OMKernelData.h"
23#include "OMStatus.h"
24#include "PALSISOOperation.h"
25#include "PALUtils.h"
26#include <cmath>
27
28namespace onert_micro
29{
30namespace execute
31{
32namespace pal
33{
34
35template <typename T>
36OMStatus Split(const core::SplitParams &params, const core::OMRuntimeShape &input_shape,
37 const T *input_data, const core::OMRuntimeShape &output_shape, int32_t axis_value)
38{
39 const auto output_count = params.num_outputs;
40
41 const auto split_dimensions = input_shape.dimensionsCount();
42
43 assert(axis_value < split_dimensions);
44 assert(output_shape.dimensionsCount() == split_dimensions);
45
46 int64_t outer_size = 1;
47 for (uint32_t i = 0; i < axis_value; ++i)
48 {
49 outer_size *= input_shape.dims(i);
50 }
51
52 int64_t base_inner_size = 1;
53 for (uint32_t i = axis_value + 1; i < split_dimensions; ++i)
54 {
55 base_inner_size *= input_shape.dims(i);
56 }
57
58 assert(input_data != nullptr);
59 for (int64_t k = 0; k < outer_size; ++k)
60 {
61 for (uint32_t i = 0; i < output_count; ++i)
62 {
63 T *output_data = core::utils::castOutputData<T>(params.output_data[i]);
64 assert(output_data != nullptr);
65 const auto copy_size = output_shape.dims(axis_value) * base_inner_size;
66 T *output_ptr = output_data + k * copy_size;
67 assert(output_ptr != nullptr);
68 for (int64_t j = 0; j < copy_size; ++j)
69 output_ptr[j] = input_data[j];
70 input_data += copy_size;
71 }
72 }
73 return Ok;
74}
75
76} // namespace pal
77} // namespace execute
78} // namespace onert_micro
79
80#endif // ONERT_MICRO_EXECUTE_PAL_SPLIT_COMMON_H
int32_t dimensionsCount() const
Definition Tensor.h:106
int32_t dims(int i) const
Definition Tensor.h:108
const luci_interpreter::RuntimeShape output_shape
OMStatus Split(const core::SplitParams &params, const core::OMRuntimeShape &input_shape, const T *input_data, const core::OMRuntimeShape &output_shape, int32_t axis_value)
Definition PALSplit.h:36