34 assert(
axes()->shape().num_dims() == 1);
35 assert(
input()->shape().num_dims() >=
axes()->shape().num_elements());
36 if (
input()->element_type() != DataType::S32 &&
input()->element_type() != DataType::FLOAT32 &&
37 input()->element_type() != DataType::U8 &&
input()->element_type() != DataType::S16 &&
38 input()->element_type() != DataType::S64)
40 throw std::runtime_error(
"Unsupported input type.");
42 if (
axes()->element_type() != DataType::S32)
44 throw std::runtime_error(
"Unsupported axes type.");
46 if (
axes()->shape().num_elements() > 1)
48 throw std::runtime_error(
"Current implementation does not support more than 1 axis.");
50 int axis_value = getTensorData<int32_t>(
axes())[0];
51 if (axis_value < 0 || axis_value >=
input()->shape().num_dims())
53 throw std::runtime_error(
"Invalid axes value");
55 assert(
input()->element_type() ==
output()->element_type());