#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>

SplitV::SplitV(const Tensor *input, const Tensor *size_splits, const Tensor *axis,
               std::vector<Tensor *> outputs)
// ...

void SplitV::configure()
{
  // The axis tensor must hold exactly one element.
  assert(axis()->shape().num_elements() == 1);
  // ...
  // After normalization, the split axis must index a valid input dimension.
  assert(_axis_value >= 0 && _axis_value < input()->shape().num_dims());

  // One split per output tensor.
  auto num_split = static_cast<int32_t>(_outputs.size());
  // ...
  // size_splits must provide exactly one size per output.
  assert(size_splits()->shape().num_elements() == num_split);
  // ...
  for (int32_t i = 0; i < num_split; ++i)
  {
    // ...
  }
  // ...
}

void SplitV::execute() const
{
  tflite::SplitParams params{};
  // ...
  params.axis = _axis_value;

// Invokes the TFLite optimized Split implementation for the given scalar type.
#define TF_LITE_SPLIT(scalar)                                                                     \
  {                                                                                               \
    VectorOfTensors<scalar, false> all_outputs(_outputs);                                         \
    tflite::optimized_ops::Split(params, getTensorShape(input()), getTensorData<scalar>(input()), \
                                 all_outputs.shapes(), all_outputs.data());                       \
  }

  switch (input()->element_type())
  {
    case DataType::FLOAT32:
      TF_LITE_SPLIT(float);
      break;
    // ...
    default:
      throw std::runtime_error("luci-intp SplitV Unsupported type.");
  }
}
Referenced declarations:

  SplitV(const Tensor *input, const Tensor *size_splits, const Tensor *axis, std::vector<Tensor *> outputs)
  void configure() override
  void execute() const override
  const Tensor *input() const
  const Tensor *size_splits() const
  const Tensor *axis() const
  const std::vector<Tensor *> _outputs
  const Shape &shape() const
  void resize(int dimensions_count)
  const luci_interpreter::RuntimeShape output_shape
  T must_cast(loco::Node *node)
  #define TF_LITE_SPLIT(scalar)
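For context, the sketch below mimics in plain standard C++ the splitting behavior that execute() delegates to tflite::optimized_ops::Split: a row-major buffer is cut along one axis into pieces whose extents come from size_splits. This is a minimal illustration of the operator's semantics, not luci-interpreter or TFLite API; splitRows and its signature are hypothetical names used only here.

#include <cassert>
#include <cstdint>
#include <vector>

// Split a row-major [rows x cols] buffer along axis 0 into pieces whose row
// counts are given by size_splits (one entry per output, summing to rows).
std::vector<std::vector<float>> splitRows(const std::vector<float> &input, int32_t rows,
                                          int32_t cols, const std::vector<int32_t> &size_splits)
{
  // Mirror the configure()-time checks: the sizes must cover the split axis exactly.
  int32_t total = 0;
  for (int32_t size : size_splits)
    total += size;
  assert(total == rows);
  assert(static_cast<int32_t>(input.size()) == rows * cols);

  std::vector<std::vector<float>> outputs;
  int32_t row = 0;
  for (int32_t size : size_splits)
  {
    // Each output keeps `size` consecutive rows of the input.
    outputs.emplace_back(input.begin() + row * cols, input.begin() + (row + size) * cols);
    row += size;
  }
  return outputs;
}

int main()
{
  // A 4x2 tensor split with size_splits = {1, 3} yields a 1x2 piece and a 3x2 piece.
  const std::vector<float> input{0, 1, 2, 3, 4, 5, 6, 7};
  const auto pieces = splitRows(input, /*rows=*/4, /*cols=*/2, /*size_splits=*/{1, 3});
  assert(pieces.size() == 2);
  assert(pieces[0].size() == 1 * 2);
  assert(pieces[1].size() == 3 * 2);
  return 0;
}

In the kernel itself this per-piece loop is replaced by a single call to tflite::optimized_ops::Split, which handles arbitrary rank and any split axis.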