85 const auto x = loco::must_cast<CircleNode *>(node->
x());
86 const auto y = loco::must_cast<CircleNode *>(node->
y());
91 uint32_t x_rank = x_shape.rank();
92 uint32_t y_rank = y_shape.rank();
95 throw_unless(x_rank >= 2,
"x_rank shoud be >= 2");
96 throw_unless(y_rank >= 2,
"y_rank shoud be >= 2");
97 throw_unless((not contain_zero(x_shape)),
"x_shape should NOT have 0");
98 throw_unless((not contain_zero(y_shape)),
"y_shape should NOT have 0");
111 uint32_t max_rank = x_rank > y_rank ? x_rank : y_rank;
116 if (x_rank > 2 || y_rank > 2)
118 const auto x_batch_dims = remove_last_two(x_shape);
119 const auto y_batch_dims = remove_last_two(y_shape);
123 const auto o_batch_rank = o_batch_dims.rank();
124 for (uint i = 0u; i < o_batch_rank; ++i)
131 const auto adj_x = node->
adj_x();
132 const auto adj_y = node->
adj_y();
134 loco::Dimension x_lhs = adj_x ? x_shape.dim(x_rank - 1) : x_shape.dim(x_rank - 2);
135 loco::Dimension x_rhs = adj_x ? x_shape.dim(x_rank - 2) : x_shape.dim(x_rank - 1);
136 loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
137 loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);
139 if (x_rhs.
known() && y_lhs.
known() && not(x_rhs == y_lhs))