86 const auto x = loco::must_cast<CircleNode *>(node->
x());
87 const auto y = loco::must_cast<CircleNode *>(node->
y());
92 uint32_t x_rank = x_shape.rank();
93 uint32_t y_rank = y_shape.rank();
96 throw_unless(x_rank >= 2,
"x_rank shoud be >= 2");
97 throw_unless(y_rank >= 2,
"y_rank shoud be >= 2");
98 throw_unless((not contain_zero(x_shape)),
"x_shape should NOT have 0");
99 throw_unless((not contain_zero(y_shape)),
"y_shape should NOT have 0");
112 uint32_t max_rank = x_rank > y_rank ? x_rank : y_rank;
117 if (x_rank > 2 || y_rank > 2)
119 const auto x_batch_dims = remove_last_two(x_shape);
120 const auto y_batch_dims = remove_last_two(y_shape);
124 const auto o_batch_rank = o_batch_dims.rank();
125 for (uint i = 0u; i < o_batch_rank; ++i)
132 const auto adj_x = node->
adj_x();
133 const auto adj_y = node->
adj_y();
135 loco::Dimension x_lhs = adj_x ? x_shape.dim(x_rank - 1) : x_shape.dim(x_rank - 2);
136 loco::Dimension x_rhs = adj_x ? x_shape.dim(x_rank - 2) : x_shape.dim(x_rank - 1);
137 loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
138 loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);
140 if (x_rhs.
known() && y_lhs.
known() && not(x_rhs == y_lhs))