ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
BatchMatMul.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_BATCH_MATMUL_H__
19#define __NNFW_CKER_BATCH_MATMUL_H__
20
21#include "Transpose.h"
22
23#include "cker/Types.h"
24#include "cker/Shape.h"
25#include "cker/Utils.h"
28
29#include <vector>
30
31namespace nnfw
32{
33namespace cker
34{
35
37{
38public:
40 {
41 // DO NOTHING
42 }
43
47 void prepare(const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y,
48 bool rhs_const)
49 {
50 if (adj_x)
51 {
52 int32_t rank = lhs_shape.DimensionsCount();
53 _temp_lhs_shape.Resize(rank);
54
55 for (int32_t i = 0; i < rank - 2; i++)
56 {
57 _temp_lhs_shape.SetDim(i, lhs_shape.Dims(i));
58 }
59 _temp_lhs_shape.SetDim(rank - 2, lhs_shape.Dims(rank - 1));
60 _temp_lhs_shape.SetDim(rank - 1, lhs_shape.Dims(rank - 2));
61
62 _temp_lhs.resize(_temp_lhs_shape.FlatSize());
63 }
64
65 if (!adj_y)
66 {
67 int32_t rank = rhs_shape.DimensionsCount();
68 _temp_rhs_shape.Resize(rank);
69
70 for (int32_t i = 0; i < rank - 2; i++)
71 {
72 _temp_rhs_shape.SetDim(i, rhs_shape.Dims(i));
73 }
74 _temp_rhs_shape.SetDim(rank - 2, rhs_shape.Dims(rank - 1));
75 _temp_rhs_shape.SetDim(rank - 1, rhs_shape.Dims(rank - 2));
76
77 _temp_rhs.resize(_temp_rhs_shape.FlatSize());
78 }
79
80 _rhs_constant = rhs_const;
81 }
82
83 void operator()(const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape,
84 const float *rhs_data, bool adj_x, bool adj_y, const Shape & /*output_shape*/,
85 float *output_data)
86 {
87 // Don't need transpose if rhs is constant and already transposed
88 if (!adj_y && !(_rhs_constant && _rhs_transposed))
89 {
90 transposeRowsCols(rhs_shape, rhs_data, _temp_rhs_shape, _temp_rhs.data());
91 _rhs_transposed = true;
92 }
93
94 if (adj_x)
95 {
96 transposeRowsCols(lhs_shape, lhs_data, _temp_lhs_shape, _temp_lhs.data());
97 }
98
99 Shape new_lhs_shape = adj_x ? lhs_shape : swapRowColDims(lhs_shape);
100 Shape new_rhs_shape = adj_y ? rhs_shape : swapRowColDims(rhs_shape);
101 const float *new_lhs_data = adj_x ? _temp_lhs.data() : lhs_data;
102 const float *new_rhs_data = adj_y ? rhs_data : _temp_rhs.data();
103
104 // Note we pass RHS args first, LHS args second
105 // Check accumulative dimensions of lhs and rhs of are equal
106 assert(Shape::ExtendedShape(5, new_rhs_shape).Dims(4) ==
107 Shape::ExtendedShape(5, new_lhs_shape).Dims(3));
108
109 const BatchMatMulParams params{new_rhs_shape, new_lhs_shape};
110#if defined(CKER_X86_PLATFORM)
111 optimized::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
112#else
113 reference::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
114#endif
115 }
116
117private:
118 Shape swapRowColDims(const Shape &shape)
119 {
120 Shape swapped_shape(shape);
121 const uint32_t dims = shape.DimensionsCount();
122 swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
123 swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
124
125 return swapped_shape;
126 }
127
128 void transposeRowsCols(const Shape &input_shape, const float *input_data,
129 const Shape &output_shape, float *output_data)
130 {
131 TransposeParams params;
132 int rank = input_shape.DimensionsCount();
133 params.perm_count = rank;
134 for (int i = 0; i < 2; i++)
135 {
136 params.perm[i] = i;
137 }
138 params.perm[rank - 2] = rank - 1;
139 params.perm[rank - 1] = rank - 2;
140
141 Transpose<float>(params, input_shape, input_data, output_shape, output_data);
142 }
143
private:
  std::vector<float> _temp_lhs;  // scratch buffer holding the transposed LHS (used when adj_x)
  Shape _temp_lhs_shape;         // shape of _temp_lhs (innermost dims swapped vs. lhs_shape)
  std::vector<float> _temp_rhs;  // scratch buffer holding the transposed RHS (used when !adj_y)
  Shape _temp_rhs_shape;         // shape of _temp_rhs (innermost dims swapped vs. rhs_shape)
  bool _rhs_constant = false;    // set by prepare(); true when RHS data never changes
  bool _rhs_transposed = false;  // true once a constant RHS has been transposed and cached
};
152
153} // namespace cker
154} // namespace nnfw
155
156#endif // __NNFW_CKER_BATCH_MATMUL_H__
void prepare(const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y, bool rhs_const)
Prepare temporary area for calculation.
Definition BatchMatMul.h:47
void operator()(const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape, const float *rhs_data, bool adj_x, bool adj_y, const Shape &, float *output_data)
Definition BatchMatMul.h:83
int32_t DimensionsCount() const
Definition Shape.h:107
int32_t Dims(int i) const
Definition Shape.h:110
int FlatSize() const
Definition Shape.h:249
void Resize(int dimensions_count)
Definition Shape.h:170
void SetDim(int i, int32_t val)
Definition Shape.h:124
const luci_interpreter::RuntimeShape output_shape
void BatchMatMul(const BatchMatMulParams &params, const float *lhs_data, const float *rhs_data, float *output_data)
Definition BatchMatMul.h:32
Definition topk_v2.h:30
Definition Dims.h:26
Definition Shape.h:28