32 assert(inputs.size() == 1);
34 const auto axes = getAttributeValue<std::vector<std::int64_t>>(onnx_node,
"axes");
35 const auto keepdims = getAttributeValue<int64_t>(onnx_node,
"keepdims", 1);
37 std::vector<int32_t> reduce_dims;
40 reduce_dims.resize(inputs[0]->getShape().rank());
41 std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
45 auto rank = inputs[0]->getShape().rank();
47 std::transform(axes.begin(), axes.end(), std::back_inserter(reduce_dims),
48 [rank](int64_t axis) { return axis < 0 ? axis + rank : axis; });
51 bool keep_dims =
static_cast<bool>(keepdims);
55 createOp<mir::ops::ReduceMeanOp>(graph, inputs[0], reduce_dims, keep_dims)->getOutput(0);
void setNodeOutputs(const onnx::NodeProto &onnx_node, const std::vector< mir::Operation::Output * > &outputs)
std::vector< mir::Operation::Output * > getNodeInputs(const onnx::NodeProto &onnx_node) const
mir::Graph * getGraph() const