40 assert(
begin()->shape().num_dims() == 1);
41 assert(
end()->shape().num_dims() == 1);
42 assert(
strides()->shape().num_dims() == 1);
43 assert(
input()->element_type() ==
output()->element_type());
44 assert(
begin()->element_type() == DataType::S32);
45 assert(
end()->element_type() == DataType::S32);
46 assert(
strides()->element_type() == DataType::S32);
47 assert(
input()->shape().num_dims() <= 4);
48 if (
params().ellipsis_mask != 0)
50 throw std::runtime_error(
"ellipsis_mask is not implemented yet.");
52 if (
params().new_axis_mask != 0)
54 throw std::runtime_error(
"new_axis_mask is not implemented yet.");
56 if (
input()->element_type() == DataType::U8)
59 assert(
input()->zero_point() ==
output()->zero_point());
61 tflite::StridedSliceParams op_params{};
68 op_params.start_indices[i] = getTensorData<int32_t>(
begin())[i];
69 op_params.stop_indices[i] = getTensorData<int32_t>(
end())[i];
70 op_params.strides[i] = getTensorData<int32_t>(
strides())[i];
73 op_params.ellipsis_mask = 0;
75 op_params.new_axis_mask = 0;
77 std::vector<int32_t> output_shape_vector;
81 int32_t stride = getTensorData<int32_t>(
strides())[idx];
93 int32_t dim_shape = std::ceil((
end -
begin) /
static_cast<float>(stride));
94 dim_shape = dim_shape < 0 ? 0 : dim_shape;
97 output_shape_vector.push_back(dim_shape);
101 for (
size_t i = 0; i < output_shape_vector.size(); i++)
103 output_shape.dim(i) = output_shape_vector[output_shape_vector.size() - i - 1];
110 tflite::StridedSliceParams op_params{};
117 op_params.start_indices[i] = getTensorData<int32_t>(
begin())[i];
118 op_params.stop_indices[i] = getTensorData<int32_t>(
end())[i];
119 op_params.strides[i] = getTensorData<int32_t>(
strides())[i];
122 op_params.ellipsis_mask = 0;
124 op_params.new_axis_mask = 0;
127 switch (
input()->element_type())
129 case DataType::FLOAT32:
132 getTensorData<float>(
output()));
137 getTensorData<uint8_t>(
output()));
142 getTensorData<int32_t>(
output()));
147 getTensorData<int64_t>(
output()));
152 getTensorData<bool>(
output()));
155 throw std::runtime_error(
"luci-intp StridedSlice Unsupported type.");