41 const auto input_index = cur_op->inputs()->operator[](1);
42 const auto axis_index = cur_op->inputs()->operator[](0);
44 assert(input_index != -1);
45 assert(axis_index != -1);
50 assert(input !=
nullptr);
51 assert(axis !=
nullptr);
54 if (axis_data ==
nullptr)
59 int32_t axis_value = (kernels::getTensorData<int32_t>(axis_data))[0];
61 axis_value += Tensor::num_dims(input);
63 assert(axis_value >= 0);
64 assert(axis_value < Tensor::num_dims(input));
66 switch (Tensor::element_type(input))
69 case DataType::FLOAT32:
71 return splitImpl<float>(cur_op, input, axis_value, runtime_graph);
77 return splitImpl<int8_t>(cur_op, input, axis_value, runtime_graph);
81 return splitImpl<int16_t>(cur_op, input, axis_value, runtime_graph);
86 return splitImpl<int32_t>(cur_op, input, axis_value, runtime_graph);
89 assert(
false &&
"Unsupported type");