52 _temp_lhs_shape.
Resize(rank);
54 for (int32_t i = 0; i < rank - 2; i++)
58 _temp_lhs_shape.
SetDim(rank - 2, lhs_shape.
Dims(rank - 1));
59 _temp_lhs_shape.
SetDim(rank - 1, lhs_shape.
Dims(rank - 2));
61 _temp_lhs.resize(_temp_lhs_shape.
FlatSize());
67 _temp_rhs_shape.
Resize(rank);
69 for (int32_t i = 0; i < rank - 2; i++)
73 _temp_rhs_shape.
SetDim(rank - 2, rhs_shape.
Dims(rank - 1));
74 _temp_rhs_shape.
SetDim(rank - 1, rhs_shape.
Dims(rank - 2));
76 _temp_rhs.resize(_temp_rhs_shape.
FlatSize());
81 const float *rhs_data,
bool adj_x,
bool adj_y,
const Shape & ,
89 transposeRowsCols(rhs_shape, rhs_data, _temp_rhs_shape, _temp_rhs.data());
94 transposeRowsCols(lhs_shape, lhs_data, _temp_lhs_shape, _temp_lhs.data());
97 Shape new_lhs_shape = adj_x ? lhs_shape : swapRowColDims(lhs_shape);
98 Shape new_rhs_shape = adj_y ? rhs_shape : swapRowColDims(rhs_shape);
99 const float *new_lhs_data = adj_x ? _temp_lhs.data() : lhs_data;
100 const float *new_rhs_data = adj_y ? rhs_data : _temp_rhs.data();
104 assert(Shape::ExtendedShape(5, new_rhs_shape).
Dims(4) ==
105 Shape::ExtendedShape(5, new_lhs_shape).
Dims(3));
108#if defined(CKER_X86_PLATFORM)
109 optimized::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);