ONE - On-device Neural Engine
BatchMatMul.cpp
/*
 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "kernels/BatchMatMul.h"
#include "kernels/Utils.h"

#include "PALBatchMatMul.h"

#include <tensorflow/lite/kernels/internal/reference/transpose.h>

namespace
{

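// Returns a copy of `shape` with its last two dimensions swapped; the result describes an
// operand whose rows and columns have been transposed.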
tflite::RuntimeShape SwapRowColumnDims(const tflite::RuntimeShape &shape)
{
  tflite::RuntimeShape swapped_shape(shape);
  const int32_t dims = shape.DimensionsCount();
  swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
  swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
  return swapped_shape;
}

} // namespace

namespace luci_interpreter
{
namespace kernels
{

BatchMatMul::BatchMatMul(const Tensor *x, const Tensor *y, Tensor *output, Tensor *x_tmp,
                         Tensor *y_tmp, const BatchMatMulParams &params)
  : KernelWithParams({x, y}, {output, x_tmp, y_tmp}, params)
{
}

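// configure() validates the operand types and ranks, lets the PAL layer size the transpose
// scratchpads, and infers the output shape: batch dimensions follow broadcasting rules, while the
// last two dimensions come from the (possibly adjointed) matrix dimensions.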
void BatchMatMul::configure()
{
  auto lhs = x();
  auto rhs = y();
  auto adj_x = params().adj_x;
  auto adj_y = params().adj_y;

  // TODO Support non-float types
  if (lhs->element_type() != DataType::FLOAT32 || rhs->element_type() != DataType::FLOAT32)
    assert(false && "Unsupported type.");

  LUCI_INTERPRETER_CHECK(lhs->element_type() == rhs->element_type());

  auto lhs_rank = lhs->shape().num_dims();
  auto rhs_rank = rhs->shape().num_dims();
  LUCI_INTERPRETER_CHECK(lhs_rank >= 2 && lhs_rank <= 4);
  LUCI_INTERPRETER_CHECK(rhs_rank >= 2 && rhs_rank <= 4);

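  // The backend may need temporary tensors to hold row/column-transposed copies of the operands
  // (see execute() below); let the PAL implementation size them here.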
  auto lhs_scratchpad = temp_lhs();
  auto rhs_scratchpad = temp_rhs();
  luci_interpreter_pal::SetupScratchpadTensor(lhs_scratchpad, rhs_scratchpad, getTensorShape(lhs),
                                              getTensorShape(rhs));

  auto output_rank = std::max(lhs_rank, rhs_rank);

  auto extended_lhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank, getTensorShape(lhs));
  auto extended_rhs_shape = tflite::RuntimeShape::ExtendedShape(output_rank, getTensorShape(rhs));

  // Ensure any batch dimensions obey broadcasting rules.
  for (int i = 0; i < output_rank - 2; ++i)
  {
    const int lhs_dim = extended_lhs_shape.Dims(i);
    const int rhs_dim = extended_rhs_shape.Dims(i);
    if (lhs_dim != rhs_dim)
    {
      if (lhs_dim != 1)
      {
        LUCI_INTERPRETER_CHECK(rhs_dim == 1);
      }
    }
  }

  // Ensure other dimensions work for matrix multiplication.
  int accum_dim_lhs =
    adj_x ? extended_lhs_shape.Dims(output_rank - 2) : extended_lhs_shape.Dims(output_rank - 1);
  int accum_dim_rhs =
    adj_y ? extended_rhs_shape.Dims(output_rank - 1) : extended_rhs_shape.Dims(output_rank - 2);
  LUCI_INTERPRETER_CHECK(accum_dim_lhs == accum_dim_rhs);

  Shape output_shape(output_rank);
  // Fill in any broadcast dimensions.
  for (int i = 0; i < output_rank - 2; ++i)
  {
    const int lhs_dim = extended_lhs_shape.Dims(i);
    const int rhs_dim = extended_rhs_shape.Dims(i);
    int broadcast_dim = lhs_dim;
    if ((lhs_dim != rhs_dim) && (lhs_dim == 1))
    {
      broadcast_dim = rhs_dim;
    }
    output_shape.dim(i) = broadcast_dim;
  }
  // Fill in the matmul dimensions.
  int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
  int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;

  output_shape.dim(output_rank - 2) = extended_lhs_shape.Dims(lhs_rows_index);
  output_shape.dim(output_rank - 1) = extended_rhs_shape.Dims(rhs_cols_index);

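  // Example (adj_x = adj_y = false): lhs [2, 1, 3, 4] and rhs [5, 4, 6] are extended to rank 4,
  // the batch dimensions broadcast to [2, 5], and the matrix dimensions contribute 3 rows and
  // 6 columns, so the output is resized to [2, 5, 3, 6].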
  output()->resize(output_shape);
}

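// Transposes the last two dimensions of `tensor_in` into `tensor_out`, leaving any leading batch
// dimensions untouched. Only FLOAT32 data is handled, matching the checks in configure().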
void TransposeRowsColumns(const Tensor *tensor_in, Tensor *tensor_out)
{
  tflite::RuntimeShape transposed_shape(getTensorShape(tensor_in));
  tflite::RuntimeShape shape(getTensorShape(tensor_in));
  tflite::TransposeParams params;
  int rank = shape.DimensionsCount();
  params.perm_count = rank;
  for (int i = 0; i < rank - 2; ++i)
  {
    params.perm[i] = i;
  }
  // Transpose the last two dimensions.
  params.perm[rank - 2] = rank - 1;
  params.perm[rank - 1] = rank - 2;
  transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
  transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
  switch (tensor_in->element_type())
  {
    case DataType::FLOAT32:
      tflite::reference_ops::Transpose(params, shape, getTensorData<float>(tensor_in),
                                       transposed_shape, getTensorData<float>(tensor_out));
      break;
    default:
      assert(false && "Only support fp32 BatchMatMul for now.");
  }
}

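// execute() hands the operands to the PAL BatchMatMul kernel, first transposing into the
// scratchpad tensors whichever operands are not already in the layout the kernel consumes.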
void BatchMatMul::execute() const
{
  auto lhs = x();
  auto rhs = y();

  bool adj_x = params().adj_x;
  bool adj_y = params().adj_y;

  auto orig_lhs_shape = getTensorShape(lhs);
  auto orig_rhs_shape = getTensorShape(rhs);

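  // Mirroring the upstream TFLite BatchMatMul kernel this file is derived from, the PAL call
  // below receives the RHS operand first; a non-adjoint RHS and an adjoint LHS are transposed
  // into their scratchpad tensors beforehand.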
  auto rhs_tensor = adj_y ? rhs : temp_rhs();
  auto lhs_tensor = adj_x ? temp_lhs() : lhs;
  if (not adj_y)
  {
    TransposeRowsColumns(rhs, temp_rhs());
  }
  if (adj_x)
  {
    TransposeRowsColumns(lhs, temp_lhs());
  }
  tflite::RuntimeShape rhs_shape = adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape);
  tflite::RuntimeShape lhs_shape = adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape);

  switch (x()->element_type())
  {
    case DataType::FLOAT32:
      luci_interpreter_pal::BatchMatMul(rhs_shape, getTensorData<float>(rhs_tensor), lhs_shape,
                                        getTensorData<float>(lhs_tensor), getTensorShape(output()),
                                        getTensorData<float>(output()));
      break;
    default:
      assert(false && "Unsupported type.");
  }
}

} // namespace kernels
} // namespace luci_interpreter