45 const int num_dims = std::max(lhs_shape.
rank(), rhs_shape.
rank());
46 Shape result_shape(num_dims);
48 for (
int i = 0; i < num_dims; ++i)
50 const std::int32_t lhs_dim =
51 (i >= num_dims - lhs_shape.
rank()) ? lhs_shape.
dim(i - (num_dims - lhs_shape.
rank())) : 1;
52 const std::int32_t rhs_dim =
53 (i >= num_dims - rhs_shape.
rank()) ? rhs_shape.
dim(i - (num_dims - rhs_shape.
rank())) : 1;
56 result_shape.
dim(i) = rhs_dim;
60 assert(rhs_dim == 1 || rhs_dim == lhs_dim);
61 result_shape.
dim(i) = lhs_dim;