26 const auto input_index = cur_op->inputs()->operator[](0);
27 const auto axis_index = cur_op->inputs()->operator[](1);
28 const auto output_index = cur_op->outputs()->operator[](0);
30 assert(input_index != -1);
31 assert(axis_index != -1);
32 assert(output_index != -1);
38 assert(input !=
nullptr);
39 assert(axis !=
nullptr);
40 assert(output !=
nullptr);
43 assert(axis_data !=
nullptr);
47 switch (Tensor::element_type(axis))
50 axis_value = *
reinterpret_cast<int32_t *
>(axis_data);
53 axis_value =
static_cast<int32_t
>(*
reinterpret_cast<int64_t *
>(axis_data));
56 assert(
false &&
"Unsupported type.");
61 axis_value += Tensor::num_dims(input) + 1;
70 const auto input_index = cur_op->inputs()->operator[](0);
71 const auto output_index = cur_op->outputs()->operator[](0);
73 assert(input_index != -1);
74 assert(output_index != -1);
91 assert(input_data !=
nullptr);
92 assert(output_data !=
nullptr);
94 const size_t element_size =
getDataTypeSize(Tensor::element_type(input));
95 const int32_t num_elements = Tensor::num_elements(input);
96 std::memcpy(output_data, input_data, num_elements * element_size);