53 _temp_lhs_shape.
Resize(rank);
55 for (int32_t i = 0; i < rank - 2; i++)
59 _temp_lhs_shape.
SetDim(rank - 2, lhs_shape.
Dims(rank - 1));
60 _temp_lhs_shape.
SetDim(rank - 1, lhs_shape.
Dims(rank - 2));
62 _temp_lhs.resize(_temp_lhs_shape.
FlatSize());
68 _temp_rhs_shape.
Resize(rank);
70 for (int32_t i = 0; i < rank - 2; i++)
74 _temp_rhs_shape.
SetDim(rank - 2, rhs_shape.
Dims(rank - 1));
75 _temp_rhs_shape.
SetDim(rank - 1, rhs_shape.
Dims(rank - 2));
77 _temp_rhs.resize(_temp_rhs_shape.
FlatSize());
80 _rhs_constant = rhs_const;
84 const float *rhs_data,
bool adj_x,
bool adj_y,
const Shape & ,
88 if (!adj_y && !(_rhs_constant && _rhs_transposed))
90 transposeRowsCols(rhs_shape, rhs_data, _temp_rhs_shape, _temp_rhs.data());
91 _rhs_transposed =
true;
96 transposeRowsCols(lhs_shape, lhs_data, _temp_lhs_shape, _temp_lhs.data());
99 Shape new_lhs_shape = adj_x ? lhs_shape : swapRowColDims(lhs_shape);
100 Shape new_rhs_shape = adj_y ? rhs_shape : swapRowColDims(rhs_shape);
101 const float *new_lhs_data = adj_x ? _temp_lhs.data() : lhs_data;
102 const float *new_rhs_data = adj_y ? rhs_data : _temp_rhs.data();
106 assert(Shape::ExtendedShape(5, new_rhs_shape).
Dims(4) ==
107 Shape::ExtendedShape(5, new_lhs_shape).
Dims(3));
110#if defined(CKER_X86_PLATFORM)
111 optimized::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);