35 assert(
axis()->shape().num_elements() == 1);
36 _axis_value = getTensorData<int32_t>(
axis())[0];
39 assert(_axis_value >= 0 && _axis_value <
input()->shape().num_dims());
42 assert(input_size %
_outputs.size() == 0);
43 const int32_t slice_size = input_size /
_outputs.size();
55 tflite::SplitParams params{};
57 params.axis = _axis_value;
59#define TF_LITE_SPLIT(scalar) \
61 VectorOfTensors<scalar, false> all_outputs(_outputs); \
62 luci_interpreter_pal::Split(params, getTensorShape(input()), getTensorData<scalar>(input()), \
63 all_outputs.shapes(), all_outputs.data()); \
66 switch (
input()->element_type())
68 case DataType::FLOAT32:
75 throw std::runtime_error(
"luci-intp Split Unsupported type.");
const std::vector< Tensor * > _outputs
void resize(const Shape &new_shape)
const Shape & shape() const
Tensor * output(int index) const
const Tensor * input() const
const Tensor * axis() const
Split(const Tensor *axis, const Tensor *input, std::vector< Tensor * > outputs)
void configure() override
void execute() const override
#define TF_LITE_SPLIT(scalar)
const luci_interpreter::RuntimeShape output_shape
This file contains utility macros.