ONE - On-device Neural Engine
Loading...
Searching...
No Matches
BatchMatMul.cpp
Go to the documentation of this file.
1/*
2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#include "kernels/BatchMatMul.h"
19#include "kernels/Utils.h"
20
21#include "PALBatchMatMul.h"
22
23#include <tensorflow/lite/kernels/internal/reference/transpose.h>
24
25#include <stdexcept>
26
27namespace
28{
29
30tflite::RuntimeShape SwapRowColumnDims(const tflite::RuntimeShape &shape)
31{
32 tflite::RuntimeShape swapped_shape(shape);
33 const int32_t dims = shape.DimensionsCount();
34 swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
35 swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
36 return swapped_shape;
37}
38
39} // namespace
40
41namespace luci_interpreter
42{
43namespace kernels
44{
45
46BatchMatMul::BatchMatMul(const Tensor *x, const Tensor *y, Tensor *output, Tensor *x_tmp,
47 Tensor *y_tmp, const BatchMatMulParams &params)
48 : KernelWithParams({x, y}, {output, x_tmp, y_tmp}, params)
49{
50}
51
53{
54 auto lhs = x();
55 auto rhs = y();
56 auto adj_x = params().adj_x;
57 auto adj_y = params().adj_y;
58
59 // TODO Support non-float types
60 if (lhs->element_type() != DataType::FLOAT32 || rhs->element_type() != DataType::FLOAT32)
61 throw std::runtime_error("luci-intp BatchMatMul(1) Unsupported type.");
62
63 LUCI_INTERPRETER_CHECK(lhs->element_type() == rhs->element_type());
64
65 auto lhs_rank = lhs->shape().num_dims();
66 auto rhs_rank = rhs->shape().num_dims();
67 LUCI_INTERPRETER_CHECK(lhs_rank >= 2 && lhs_rank <= 4);
68 LUCI_INTERPRETER_CHECK(rhs_rank >= 2 && rhs_rank <= 4);
69
70 auto lhs_scratchpad = temp_lhs();
71 auto rhs_scratchpad = temp_rhs();
72 luci_interpreter_pal::SetupScratchpadTensor(lhs_scratchpad, rhs_scratchpad, getTensorShape(lhs),
73 getTensorShape(rhs));
74
75 auto output_rank = std::max(lhs_rank, rhs_rank);
76
77 auto extended_lhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank, getTensorShape(lhs));
78 auto extended_rhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank, getTensorShape(rhs));
79
80 // Ensure any batch dimensions obey broacasting rules.
81 for (int i = 0; i < output_rank - 2; ++i)
82 {
83 const int lhs_dim = extended_lhs_shape.Dims(i);
84 const int rhs_dim = extended_rhs_shape.Dims(i);
85 if (lhs_dim != rhs_dim)
86 {
87 if (lhs_dim != 1)
88 {
89 LUCI_INTERPRETER_CHECK(rhs_dim == 1);
90 }
91 }
92 }
93
94 // Ensure other dimensions work for matrix multiplication.
95 int accum_dim_lhs =
96 adj_x ? extended_lhs_shape.Dims(output_rank - 2) : extended_lhs_shape.Dims(output_rank - 1);
97 int accum_dim_rhs =
98 adj_y ? extended_rhs_shape.Dims(output_rank - 1) : extended_rhs_shape.Dims(output_rank - 2);
99 LUCI_INTERPRETER_CHECK(accum_dim_lhs == accum_dim_rhs);
100
101 Shape output_shape(output_rank);
102 // Fill in any broadcast dimensions.
103 for (int i = 0; i < output_rank - 2; ++i)
104 {
105 const int lhs_dim = extended_lhs_shape.Dims(i);
106 const int rhs_dim = extended_rhs_shape.Dims(i);
107 int broadcast_dim = lhs_dim;
108 if ((lhs_dim != rhs_dim) && (lhs_dim == 1))
109 {
110 broadcast_dim = rhs_dim;
111 }
112 output_shape.dim(i) = broadcast_dim;
113 }
114 // Fill in the matmul dimensions.
115 int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
116 int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
117
118 output_shape.dim(output_rank - 2) = extended_lhs_shape.Dims(lhs_rows_index);
119 output_shape.dim(output_rank - 1) = extended_rhs_shape.Dims(rhs_cols_index);
120
122}
123
124void TransposeRowsColumns(const Tensor *tensor_in, Tensor *tensor_out)
125{
126 tflite::RuntimeShape transposed_shape(getTensorShape(tensor_in));
127 tflite::RuntimeShape shape(getTensorShape(tensor_in));
128 tflite::TransposeParams params;
129 int rank = shape.DimensionsCount();
130 params.perm_count = rank;
131 for (int i = 0; i < rank - 2; ++i)
132 {
133 params.perm[i] = i;
134 }
135 // Transpose the last two dimensions.
136 params.perm[rank - 2] = rank - 1;
137 params.perm[rank - 1] = rank - 2;
138 transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
139 transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
140 switch (tensor_in->element_type())
141 {
142 case DataType::FLOAT32:
143 tflite::reference_ops::Transpose(params, shape, getTensorData<float>(tensor_in),
144 transposed_shape, getTensorData<float>(tensor_out));
145 break;
146 default:
147 throw std::runtime_error("Only suppport fp32 BatchMatMul for now.");
148 }
149}
150
152{
153 auto lhs = x();
154 auto rhs = y();
155
156 bool adj_x = params().adj_x;
157 bool adj_y = params().adj_y;
158
159 auto orig_lhs_shape = getTensorShape(lhs);
160 auto orig_rhs_shape = getTensorShape(rhs);
161
162 auto rhs_tensor = adj_y ? rhs : temp_rhs();
163 auto lhs_tensor = adj_x ? temp_lhs() : lhs;
164 if (not adj_y)
165 {
166 TransposeRowsColumns(rhs, temp_rhs());
167 }
168 if (adj_x)
169 {
170 TransposeRowsColumns(lhs, temp_lhs());
171 }
172 tflite::RuntimeShape rhs_shape = adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape);
173 tflite::RuntimeShape lhs_shape = adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape);
174
175 switch (x()->element_type())
176 {
177 case DataType::FLOAT32:
178 luci_interpreter_pal::BatchMatMul(rhs_shape, getTensorData<float>(rhs_tensor), lhs_shape,
179 getTensorData<float>(lhs_tensor), getTensorShape(output()),
180 getTensorData<float>(output()));
181 break;
182 default:
183 throw std::runtime_error("luci-intp BatchMatMul(2) Unsupported type.");
184 }
185}
186
187} // namespace kernels
188} // namespace luci_interpreter
const BatchMatMulParams & params() const
Definition Kernel.h:67
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
DataType element_type() const
Definition Tensor.h:105
BatchMatMul(const Tensor *x, const Tensor *y, Tensor *output, Tensor *x_tmp, Tensor *y_tmp, const BatchMatMulParams &params)
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
void TransposeRowsColumns(const Tensor *tensor_in, Tensor *tensor_out)
tflite::RuntimeShape getTensorShape(const Tensor *tensor)
Definition Utils.h:194
void BatchMatMul(const tflite::RuntimeShape &lhs_shape, const float *lhs_data, const tflite::RuntimeShape &rhs_shape, const float *rhs_data, const tflite::RuntimeShape &output_shape, float *output_data)