20#include "kernels/Utils.h"
22#include <tensorflow/lite/kernels/internal/reference/strided_slice.h>
40 assert(
begin()->shape().num_dims() == 1);
41 assert(
end()->shape().num_dims() == 1);
42 assert(
strides()->shape().num_dims() == 1);
43 assert(
input()->element_type() ==
output()->element_type());
44 assert(
begin()->element_type() == DataType::S32);
45 assert(
end()->element_type() == DataType::S32);
46 assert(
strides()->element_type() == DataType::S32);
47 assert(
input()->shape().num_dims() <= 5);
48 if (
params().ellipsis_mask != 0)
50 throw std::runtime_error(
"ellipsis_mask is not implemented yet.");
52 if (
params().new_axis_mask != 0)
54 throw std::runtime_error(
"new_axis_mask is not implemented yet.");
56 if (
input()->element_type() == DataType::U8)
59 assert(
input()->zero_point() ==
output()->zero_point());
127 switch (
input()->element_type())
129 case DataType::FLOAT32:
155 throw std::runtime_error(
"luci-intp StridedSlice Unsupported type.");
const StridedSliceParams & params() const
void resize(const Shape &new_shape)
const Shape & shape() const
const Tensor * begin() const
StridedSlice(const Tensor *input, const Tensor *begin, const Tensor *end, const Tensor *strides, Tensor *output, const StridedSliceParams ¶ms)
void configure() override
const Tensor * strides() const
void execute() const override
const Tensor * input() const
const Tensor * end() const
const luci_interpreter::RuntimeShape output_shape
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
T must_cast(loco::Node *node)