30 std::vector<mir::Operation::Output *> inputs = context->
getNodeInputs(onnx_node);
33 assert(inputs.size() == 1);
34 auto input = inputs[0];
36 const auto &input_shape = input->getShape();
37 if (input_shape.rank() != 4)
38 throw std::runtime_error(
"MaxPool: only 2-D input is supported.");
40 constexpr int num_spatial_dims = 2;
43 getAttributeValue(onnx_node,
"strides", std::vector<std::int32_t>(num_spatial_dims, 1));
44 if (strides.size() != num_spatial_dims)
45 throw std::runtime_error(
"MaxPool: attribute 'strides' has incorrect size.");
47 const auto kernel_shape = getAttributeValue<std::vector<std::int32_t>>(onnx_node,
"kernel_shape");
48 if (kernel_shape.size() != num_spatial_dims)
49 throw std::runtime_error(
"MaxPool: attribute 'kernel_shape' has incorrect size.");
51 std::vector<std::int32_t> padding_before;
52 std::vector<std::int32_t> padding_after;
55 const auto pads = getAttributeValue<std::vector<std::int32_t>>(*pads_attr);
56 if (pads.size() != num_spatial_dims * 2)
57 throw std::runtime_error(
"MaxPool: attribute 'pads' has incorrect size.");
58 padding_before.assign(pads.cbegin(), std::next(pads.cbegin(), num_spatial_dims));
59 padding_after.assign(std::next(pads.cbegin(), num_spatial_dims), pads.cend());
63 const auto auto_pad = getAttributeValue<std::string>(onnx_node,
"auto_pad",
"NOTSET");
64 const std::vector<std::int32_t> dilations(num_spatial_dims, 1);
65 inferAutoPadding(auto_pad, input_shape, dilations, strides, kernel_shape, padding_before,
70 attributes.
window = kernel_shape;
75 auto result = createOp<mir::ops::MaxPool2DOp>(graph, input, attributes)->getOutput(0);
void setNodeOutputs(const onnx::NodeProto &onnx_node, const std::vector< mir::Operation::Output * > &outputs)
std::vector< mir::Operation::Output * > getNodeInputs(const onnx::NodeProto &onnx_node) const
mir::Graph * getGraph() const
void inferAutoPadding(const std::string &pad_type, const mir::Shape &input_shape, const std::vector< std::int32_t > &dilations, const std::vector< std::int32_t > &strides, const std::vector< std::int32_t > &window_size, std::vector< std::int32_t > &padding_before, std::vector< std::int32_t > &padding_after)