ONE - On-device Neural Engine
Loading...
Searching...
No Matches
StridedSlice.cc
Go to the documentation of this file.
1/*
2 * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "../KernelGenerator.h"
18#include "../Validator.h"
19
20#include <AclKernelGen.h>
21
22namespace onert::backend::acl_cl
23{
24
25void Validator::visit(const ir::operation::StridedSlice &) { _supported = true; }
26
27void KernelGenerator::visit(const ir::operation::StridedSlice &node)
28{
29 const auto output_index{node.getOutputs().at(0)};
30 const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
31 const auto starts_index{node.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
32 const auto ends_index{node.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
33 const auto strides_index{node.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
34
35 auto outputData_tensor = _tensor_reg->getAclTensor(output_index);
36 auto inputData_tensor = _tensor_reg->getAclTensor(input_index);
37
38 // Set initializers for indices data such as order of inputData
39 int input_rank = _ctx.at(input_index).shape().rank();
40 std::vector<int32_t> starts;
41 std::vector<int32_t> ends;
42 std::vector<int32_t> strides;
43 starts.resize(input_rank, 0);
44 ends.resize(input_rank, 0);
45 strides.resize(input_rank, 0);
46 {
47 assert(_ctx.at(starts_index).data());
48 assert(_ctx.at(ends_index).data());
49 assert(_ctx.at(strides_index).data());
50 auto startData_base = _ctx.at(starts_index).data()->base();
51 auto endData_base = _ctx.at(ends_index).data()->base();
52 auto stridesData_base = _ctx.at(strides_index).data()->base();
53 [[maybe_unused]] const int startData_size = _ctx.at(starts_index).shape().num_elements();
54 [[maybe_unused]] const int endData_size = _ctx.at(ends_index).shape().num_elements();
55 [[maybe_unused]] const int stridesData_size = _ctx.at(strides_index).shape().num_elements();
56
57 using ir::DataType;
58
59 assert(_ctx.at(starts_index).typeInfo().type() == DataType::INT32);
60 assert(_ctx.at(ends_index).typeInfo().type() == DataType::INT32);
61 assert(_ctx.at(strides_index).typeInfo().type() == DataType::INT32);
62 assert(startData_size == input_rank);
63 assert(endData_size == input_rank);
64 assert(stridesData_size == input_rank);
65
66 assert(startData_base != nullptr);
67 for (int n = 0; n < input_rank; ++n)
68 {
69 auto axis = ::onert::backend::acl_common::ToARMComputeAxis(input_rank, n).value();
70
71 int32_t start_value = *(reinterpret_cast<const int32_t *>(startData_base) + n);
72 starts[axis] = start_value;
73
74 int32_t end_value = *(reinterpret_cast<const int32_t *>(endData_base) + n);
75 ends[axis] = end_value;
76
77 int32_t strides_value = *(reinterpret_cast<const int32_t *>(stridesData_base) + n);
78 strides[axis] = strides_value;
79 }
80 }
81
82 // Set mask bits such as order of inputData
83 const auto begin_mask = acl_common::ReorderBits<int32_t>(node.param().begin_mask, input_rank);
84 const auto end_mask = acl_common::ReorderBits<int32_t>(node.param().end_mask, input_rank);
85 const auto shrink_axis_mask =
86 acl_common::ReorderBits<int32_t>(node.param().shrink_axis_mask, input_rank);
87
88 ::arm_compute::Coordinates starts_set;
89 ::arm_compute::Coordinates ends_set;
90 ::arm_compute::BiStrides strides_set;
91
92 for (size_t i = 0; i < starts.size(); ++i)
93 {
94 starts_set.set(i, starts[i]);
95 ends_set.set(i, ends[i]);
96 strides_set.set(i, strides[i]);
97 }
98
99 // Disable applied dim_correction
100 if (inputData_tensor->num_dimensions() != inputData_tensor->info()->num_dimensions())
101 {
102 // This means that high dimension's value is 1 and input tensor is applied dim_correction
103 acl_common::disableDimCorrection(inputData_tensor);
104 }
105
106 auto fn = acl_common::generateLayer<arm_compute::CLStridedSlice>(
107 inputData_tensor->handle(), outputData_tensor->handle(), starts_set, ends_set, strides_set,
108 begin_mask, end_mask, shrink_axis_mask);
109
110 // Revert disabling applied dim_correction
111 if (inputData_tensor->dimension(0) == 1)
112 {
113 acl_common::enableDimCorrection(inputData_tensor);
114 }
115
116 _return_fn = acl_common::asAclFunction(std::move(fn));
117}
118
119} // namespace onert::backend::acl_cl
std::unique_ptr< exec::IFunction > _return_fn
const Object & at(const Index &index) const
Get the object that is associated with the given index.
ARMComputeAxis ToARMComputeAxis(uint32_t rank, uint32_t axis)
Definition Swizzle.h:45
void enableDimCorrection(IACLTensor *tensor)
std::unique_ptr< AclFunction > asAclFunction(std::unique_ptr<::arm_compute::IFunction > &&layer)
Definition Convert.cc:246
void disableDimCorrection(IACLTensor *tensor)
OperandType type
Definition Operand.h:42