40 const auto input_index = cur_op->inputs()->operator[](0);
41 const auto axis_index = cur_op->inputs()->operator[](2);
43 assert(input_index != -1);
44 assert(axis_index != -1);
49 assert(input !=
nullptr);
50 assert(axis !=
nullptr);
53 if (axis_data ==
nullptr)
58 int32_t axis_value = (kernels::getTensorData<int32_t>(axis_data))[0];
60 axis_value += Tensor::num_dims(input);
62 assert(axis_value >= 0);
63 assert(axis_value < Tensor::num_dims(input));
65 switch (Tensor::element_type(input))
68 case DataType::FLOAT32:
70 return splitImpl<float>(cur_op, input, axis_value, runtime_graph);
76 return splitImpl<int8_t>(cur_op, input, axis_value, runtime_graph);
80 return splitImpl<int16_t>(cur_op, input, axis_value, runtime_graph);
85 return splitImpl<int32_t>(cur_op, input, axis_value, runtime_graph);
88 assert(
false &&
"Unsupported type");