ONE - On-device Neural Engine
Loading...
Searching...
No Matches
MatMul.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "NodeExecution.h"
19
20#include "NodeDataImpl.h"
21#include "NodeDomain.h"
22#include "Validation.h"
23
29
30#include <cassert>
31#include <stdexcept>
32
33namespace
34{
40
44template <typename T> Buffer<T> calc_mat_mul(const Buffer<T> *lhs_buf, const Buffer<T> *rhs_buf)
45{
46 const auto lhs_shape = lhs_buf->shape();
47 const auto rhs_shape = rhs_buf->shape();
48
49 assert(lhs_shape.rank() == 2 && "lhs rank must be 2");
50 assert(rhs_shape.rank() == 2 && "rhs rank must be 2");
51 // lhs width should be the same as rhs height
52 assert(lhs_shape.dim(1) == rhs_shape.dim(0) && "height/width mismatch");
53
54 const uint32_t lhs_height = lhs_shape.dim(0);
55 const uint32_t lhs_width = lhs_shape.dim(1);
56
57 const uint32_t rhs_width = rhs_shape.dim(1);
58
59 const uint32_t output_height = lhs_height;
60 const uint32_t output_width = rhs_width;
61
62 Shape output_shape{output_height, output_width};
63 auto output_buf = make_buffer<T, LexicalLayout>(output_shape);
64
65 for (uint32_t out_y = 0; out_y < output_height; ++out_y)
66 {
67 for (uint32_t out_x = 0; out_x < output_width; ++out_x)
68 {
69 T total = static_cast<T>(0); // accumulator
70 // Accumulate through axis
71 for (uint32_t axis = 0; axis < lhs_width; ++axis)
72 {
73 total += lhs_buf->at(Index({out_y, axis})) * rhs_buf->at(Index({axis, out_x}));
74 }
75 // Set output value
76 output_buf.at(Index({out_y, out_x})) = total;
77 }
78 }
79
80 return output_buf;
81}
82
83} // namespace
84
85namespace
86{
87
88using namespace locomotiv;
89
// Execute a loco::MatMul node.
//
// Looks up the data annotated on the lhs and rhs input nodes, checks that both exist and are
// rank 2, computes the product with calc_mat_mul for FLOAT32/FLOAT32 or S32/S32 operands, and
// annotates the result on mat_mul.
//
// Throws std::runtime_error for any other dtype combination. validate() is declared outside
// this file; presumably it rejects a false condition with the given message - TODO confirm.
//
// NOTE(review): the embedded listing numbers skip 101, 103 and 132, so this extract appears to
// be missing the opening halves of two validate(...) calls and one statement just before the
// closing brace. Recover them from the upstream file before building this.
90void execute_node(loco::MatMul *mat_mul)
91{
   // Data previously annotated on the two input nodes
92 auto lhs_data = annot_data(mat_mul->lhs());
93 auto rhs_data = annot_data(mat_mul->rhs());
94
   // Presence is checked before the data pointer is dereferenced for the rank check
95 validate(lhs_data, "Can't find left matrix data of MatMul");
96 validate(lhs_data->shape()->rank() == 2, "lhs rank must be 2");
97
98 validate(rhs_data, "Can't find right matrix data of MatMul");
99 validate(rhs_data->shape()->rank() == 2, "rhs rank must be 2");
100
   // NOTE(review): the two lines below are only the trailing message arguments of calls whose
   // first lines (101 and 103) are absent from this extract
102 "Left matrix of MatMul is not a Matrix");
104 "Right matrix of MatMul is not a Matrix");
105
106 std::unique_ptr<NodeData> mat_mul_result = nullptr;
107
   // Both operands must share a dtype; only FLOAT32 and S32 pairs are handled
108 if (lhs_data->dtype() == loco::DataType::FLOAT32 && rhs_data->dtype() == loco::DataType::FLOAT32)
109 {
110 const auto lhs_buf = lhs_data->as_f32_bufptr();
111 const auto rhs_buf = rhs_data->as_f32_bufptr();
112
113 auto mat_mul_buf = calc_mat_mul<float>(lhs_buf, rhs_buf);
114
115 mat_mul_result = make_data(mat_mul_buf);
116 }
117 else if (lhs_data->dtype() == loco::DataType::S32 && rhs_data->dtype() == loco::DataType::S32)
118 {
119 const auto lhs_buf = lhs_data->as_s32_bufptr();
120 const auto rhs_buf = rhs_data->as_s32_bufptr();
121
122 auto mat_mul_buf = calc_mat_mul<int32_t>(lhs_buf, rhs_buf);
123
124 mat_mul_result = make_data(mat_mul_buf);
125 }
126 else
127 throw std::runtime_error("NYI for these DataTypes");
128
   // Every non-throwing branch above assigns the result
129 assert(mat_mul_result != nullptr);
130
   // Hand the result over to annot_data for this node
131 annot_data(mat_mul, std::move(mat_mul_result));
133}
134
135} // namespace
136
137namespace locomotiv
138{
139
140void NodeExecution::execute(loco::MatMul *mat_mul) { execute_node(mat_mul); }
141
142} // namespace locomotiv
Matrix multiplication of lhs and rhs.
Definition Nodes.h:1065
Node * rhs(void) const
Definition Nodes.h:1073
Node * lhs(void) const
Definition Nodes.h:1070
const luci_interpreter::RuntimeShape output_shape
bool validate(Code *code)
void annot_domain(loco::Node *node, const loco::Domain &domain)
Wrapper to annotate domain to node. Cannot annotate unknown domain.
std::unique_ptr< NodeData > make_data(const NodeData::Buffer< DT > &buffer)
Copy buffer to make NodeData.
Buffer< T > make_buffer(const Shape &shape)
Definition Buffer.h:47
Definition Shape.h:28