26#include "PALStridedSlice.h"
35constexpr uint32_t inputTensorIdx = 0;
36constexpr uint32_t beginTensorIdx = 1;
37constexpr uint32_t endTensorIdx = 2;
38constexpr uint32_t stridesTensorIdx = 3;
43 const int32_t *end,
const int32_t *strides,
44 const circle::StridedSliceOptions *options)
51 for (
int i = 0; i < dims; ++i)
55 op_params.
strides[i] = strides[i];
75 const circle::Tensor *
input =
nullptr;
76 const circle::Tensor *
begin =
nullptr;
77 const circle::Tensor *
end =
nullptr;
78 const circle::Tensor *strides =
nullptr;
80 const circle::Tensor *
output =
nullptr;
83 const int32_t *begin_data;
84 const int32_t *end_data;
85 const int32_t *strides_data;
89 const circle::StridedSliceOptions *
options;
93 runtime_kernel.
readKernel(op_index, runtime_context);
97 end = runtime_kernel.
inputs[endTensorIdx];
98 strides = runtime_kernel.
inputs[stridesTensorIdx];
101 assert(input !=
nullptr);
102 assert(
begin !=
nullptr);
103 assert(end !=
nullptr);
104 assert(strides !=
nullptr);
105 assert(output !=
nullptr);
107 status = runtime_kernel.
getDataFromStorage(op_index, runtime_storage, runtime_context);
112 begin_data = utils::castInputData<int32_t>(runtime_kernel.
inputs_data[beginTensorIdx]);
113 end_data = utils::castInputData<int32_t>(runtime_kernel.
inputs_data[endTensorIdx]);
114 strides_data = utils::castInputData<int32_t>(runtime_kernel.
inputs_data[stridesTensorIdx]);
117 assert(input_data !=
nullptr);
118 assert(begin_data !=
nullptr);
119 assert(end_data !=
nullptr);
120 assert(strides_data !=
nullptr);
121 assert(output_data !=
nullptr);
129 strides_data, options);
131 switch (
input->type())
134 case circle::TensorType_FLOAT32:
136 status =
pal::StridedSlice(op_params, input_shape, utils::castInputData<float>(input_data),
137 utils::castOutputData<float>(output_data));
142 case circle::TensorType_INT8:
144 status =
pal::StridedSlice(op_params, input_shape, utils::castInputData<int8_t>(input_data),
145 utils::castOutputData<int8_t>(output_data));
149 case circle::TensorType_INT32:
151 status =
pal::StridedSlice(op_params, input_shape, utils::castInputData<int32_t>(input_data),
152 utils::castOutputData<int32_t>(output_data));
158 assert(
false &&
"Unsupported type.");
uint8_t * outputs_data[maxOutputSize]
const circle::Operator * first_operator
OMStatus getDataFromStorage(uint16_t op_index, core::OMRuntimeStorage &storage, core::OMRuntimeContext &context)
uint8_t * inputs_data[maxInputSize]
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
constexpr uint32_t outputTensorIdx
ShapeIterator end(const Shape &s)
StridedSliceParams buildStridedSliceParams(const T *begin, const T *end, const T *strides, const uint32_t begin_mask, const uint32_t end_mask, const uint32_t shrink_axis_mask, const uint8_t rank)
OMStatus StridedSlice(core::StridedSliceParams &op_params, const core::OMRuntimeShape &unextended_input_shape, const T *input_data, T *output_data)
int8_t start_indices_count
int8_t stop_indices_count
core::OMRuntimeContext & runtime_context
core::OMRuntimeStorage & runtime_storage