33 std::vector<mir::Operation::Output *> inputs = context->
getNodeInputs(onnx_node);
37 if (shape_attr && shape_attr->ints_size() > 0)
41 for (int32_t index = 0; index < out_shape.
rank(); index++)
43 const auto dim_value = shape_attr->ints(index);
45 out_shape.
dim(index) = in_shape.
dim(index);
47 out_shape.
dim(index) = dim_value;
50 auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
62 std::vector<mir::Operation::Output *> inputs = context->
getNodeInputs(onnx_node);
65 const auto &in_shape = inputs[0]->getShape();
69 assert(op &&
"We support only constant shape input");
70 auto shape_tensor = op->getValue();
71 mir::Shape shape_tensor_shape = (shape_tensor).getShape();
72 assert(shape_tensor_shape.
rank() == 1);
76 std::vector<int32_t> shape_vector(cnt);
81 for (
auto idx : out_range)
83 if (tensor_accessor.
at(idx) == 0)
84 shape_vector[i] = in_shape.dim(i);
85 else if (tensor_accessor.
at(idx) == -1)
86 shape_vector[i] = mir::Shape::autoDim;
88 shape_vector[i] = tensor_accessor.
at(idx);
92 auto result = createOp<mir::ops::ReshapeOp>(graph, inputs[0], out_shape)->getOutput(0);
void setNodeOutputs(const onnx::NodeProto &onnx_node, const std::vector< mir::Operation::Output * > &outputs)
std::vector< mir::Operation::Output * > getNodeInputs(const onnx::NodeProto &onnx_node) const
mir::Graph * getGraph() const