#include "../KernelGenerator.h"
#include "../Validator.h"

#include <cassert> // assert
#include <utility> // std::move
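
// The validator accepts StridedSlice unconditionally; operand constraints
// (constant INT32 indices, rank agreement) are asserted in the kernel
// generator below.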
void Validator::visit(const ir::operation::StridedSlice &) { _supported = true; }

void KernelGenerator::visit(const ir::operation::StridedSlice &node)
{
  const auto output_index{node.getOutputs().at(0)};
  const auto input_index{node.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
  const auto starts_index{node.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
  const auto ends_index{node.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
  const auto strides_index{node.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};

  auto outputData_tensor = _tensor_reg->getAclTensor(output_index);
  auto inputData_tensor = _tensor_reg->getAclTensor(input_index);

  // Build per-axis start/end/stride vectors, initialized to zero.
  int input_rank = _ctx.at(input_index).shape().rank();
  std::vector<int32_t> starts;
  std::vector<int32_t> ends;
  std::vector<int32_t> strides;
  starts.resize(input_rank, 0);
  ends.resize(input_rank, 0);
  strides.resize(input_rank, 0);

  auto startData_base = _ctx.at(starts_index).data()->base();
  auto endData_base = _ctx.at(ends_index).data()->base();
  auto stridesData_base = _ctx.at(strides_index).data()->base();
  [[maybe_unused]] const int startData_size = _ctx.at(starts_index).shape().num_elements();
  [[maybe_unused]] const int endData_size = _ctx.at(ends_index).shape().num_elements();
  [[maybe_unused]] const int stridesData_size = _ctx.at(strides_index).shape().num_elements();
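
  // The indices operands must be constant INT32 vectors with one element per
  // input axis; the element counts are only read inside the asserts, hence
  // [[maybe_unused]] above.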
  using ir::DataType;

  assert(_ctx.at(starts_index).typeInfo().type() == DataType::INT32);
  assert(_ctx.at(ends_index).typeInfo().type() == DataType::INT32);
  assert(_ctx.at(strides_index).typeInfo().type() == DataType::INT32);
  assert(startData_size == input_rank);
  assert(endData_size == input_rank);
  assert(stridesData_size == input_rank);

  assert(startData_base != nullptr);
  for (int n = 0; n < input_rank; ++n)
  {
    // Map the model-order axis n to the corresponding ACL axis.
    auto axis = acl_common::ToARMComputeAxis(input_rank, n).value();

    int32_t start_value = *(reinterpret_cast<const int32_t *>(startData_base) + n);
    starts[axis] = start_value;

    int32_t end_value = *(reinterpret_cast<const int32_t *>(endData_base) + n);
    ends[axis] = end_value;

    int32_t strides_value = *(reinterpret_cast<const int32_t *>(stridesData_base) + n);
    strides[axis] = strides_value;
  }
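
  // The begin/end/shrink_axis masks carry one bit per axis in model order;
  // ReorderBits permutes the bits the same way the coordinates were permuted
  // above. Worked example (assuming the usual axis reversal n -> rank - 1 - n):
  // with input_rank == 4, begin_mask 0b0001 (model axis 0) becomes 0b1000
  // (ACL axis 3).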
  const auto begin_mask = acl_common::ReorderBits<int32_t>(node.param().begin_mask, input_rank);
  const auto end_mask = acl_common::ReorderBits<int32_t>(node.param().end_mask, input_rank);
  const auto shrink_axis_mask =
    acl_common::ReorderBits<int32_t>(node.param().shrink_axis_mask, input_rank);

  ::arm_compute::Coordinates starts_set;
  ::arm_compute::Coordinates ends_set;
  ::arm_compute::BiStrides strides_set;

  for (size_t i = 0; i < starts.size(); ++i)
  {
    starts_set.set(i, starts[i]);
    ends_set.set(i, ends[i]);
    strides_set.set(i, strides[i]);
  }
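
  // ACL drops trailing size-1 dimensions ("dim_correction"), so a rank
  // mismatch between the IR shape and the ACL tensor info means correction
  // was applied; disable it so the slice coordinates address the full-rank
  // shape.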
  if (static_cast<size_t>(inputData_tensor->getShape().rank()) !=
      inputData_tensor->info()->num_dimensions())
  {
    acl_common::disableDimCorrection(inputData_tensor);
  }

  auto fn = acl_common::generateLayer<arm_compute::NEStridedSlice>(
    inputData_tensor->handle(), outputData_tensor->handle(), starts_set, ends_set, strides_set,
    begin_mask, end_mask, shrink_axis_mask);

  // Revert the dim_correction setting now that the layer is configured.
  if (inputData_tensor->getShape().dim(0) == 1)
  {
    acl_common::enableDimCorrection(inputData_tensor);
  }

  _return_fn = asAclFunction(std::move(fn));
}
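
// Note: _return_fn is the exec::IFunction the executor invokes at run time;
// the asAclFunction() wrapper owns the configured NEStridedSlice layer and
// runs it.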