22#include "Validation.h"
44template <
typename T> Buffer<T> calc_mat_mul(
const Buffer<T> *lhs_buf,
const Buffer<T> *rhs_buf)
46 const auto lhs_shape = lhs_buf->shape();
47 const auto rhs_shape = rhs_buf->shape();
49 assert(lhs_shape.rank() == 2 &&
"lhs rank must be 2");
50 assert(rhs_shape.rank() == 2 &&
"rhs rank must be 2");
52 assert(lhs_shape.dim(1) == rhs_shape.dim(0) &&
"height/width mismatch");
54 const uint32_t lhs_height = lhs_shape.dim(0);
55 const uint32_t lhs_width = lhs_shape.dim(1);
57 const uint32_t rhs_width = rhs_shape.dim(1);
59 const uint32_t output_height = lhs_height;
60 const uint32_t output_width = rhs_width;
63 auto output_buf = make_buffer<T, LexicalLayout>(
output_shape);
65 for (uint32_t out_y = 0; out_y < output_height; ++out_y)
67 for (uint32_t out_x = 0; out_x < output_width; ++out_x)
69 T total =
static_cast<T
>(0);
71 for (uint32_t axis = 0; axis < lhs_width; ++axis)
73 total += lhs_buf->at(
Index({out_y, axis})) * rhs_buf->at(
Index({axis, out_x}));
76 output_buf.at(
Index({out_y, out_x})) = total;
92 auto lhs_data = annot_data(mat_mul->
lhs());
93 auto rhs_data = annot_data(mat_mul->
rhs());
95 validate(lhs_data,
"Can't find left matrix data of MatMul");
96 validate(lhs_data->shape()->rank() == 2,
"lhs rank must be 2");
98 validate(rhs_data,
"Can't find right matrix data of MatMul");
99 validate(rhs_data->shape()->rank() == 2,
"rhs rank must be 2");
102 "Left matrix of MatMul is not a Matrix");
104 "Right matrix of MatMul is not a Matrix");
106 std::unique_ptr<NodeData> mat_mul_result =
nullptr;
108 if (lhs_data->dtype() == loco::DataType::FLOAT32 && rhs_data->dtype() == loco::DataType::FLOAT32)
110 const auto lhs_buf = lhs_data->as_f32_bufptr();
111 const auto rhs_buf = rhs_data->as_f32_bufptr();
113 auto mat_mul_buf = calc_mat_mul<float>(lhs_buf, rhs_buf);
117 else if (lhs_data->dtype() == loco::DataType::S32 && rhs_data->dtype() == loco::DataType::S32)
119 const auto lhs_buf = lhs_data->as_s32_bufptr();
120 const auto rhs_buf = rhs_data->as_s32_bufptr();
122 auto mat_mul_buf = calc_mat_mul<int32_t>(lhs_buf, rhs_buf);
127 throw std::runtime_error(
"NYI for these DataTypes");
129 assert(mat_mul_result !=
nullptr);
131 annot_data(mat_mul, std::move(mat_mul_result));
140void NodeExecution::execute(
loco::MatMul *mat_mul) { execute_node(mat_mul); }
Matrix multiplication of lhs and rhs.
const luci_interpreter::RuntimeShape output_shape
bool validate(Code *code)
void annot_domain(loco::Node *node, const loco::Domain &domain)
Wrapper that annotates a node with a domain. An unknown domain cannot be annotated.
std::unique_ptr< NodeData > make_data(const NodeData::Buffer< DT > &buffer)
Copies the buffer to make a NodeData.
Buffer< T > make_buffer(const Shape &shape)