ONE - On-device Neural Engine
BatchMatMul.h
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_BATCH_MATMUL_H__
19#define __NNFW_CKER_BATCH_MATMUL_H__
20
21#include "Transpose.h"
22
23#include "cker/Types.h"
24#include "cker/Shape.h"
25#include "cker/Utils.h"
26#include "cker/operation/optimized/BatchMatMul.h"
27#include "cker/operation/reference/BatchMatMul.h"
28
29#include <vector>
30
31namespace nnfw
32{
33namespace cker
34{
35
36class BatchMatMul
37{
38public:
39 BatchMatMul()
40 {
41 // DO NOTHING
42 }
43
44 /**
45 * @brief Prepare temporary area for calculation
46 */
47 void prepare(const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y)
48 {
49 if (adj_x)
50 {
51 int32_t rank = lhs_shape.DimensionsCount();
52 _temp_lhs_shape.Resize(rank);
53
54 for (int32_t i = 0; i < rank - 2; i++)
55 {
56 _temp_lhs_shape.SetDim(i, lhs_shape.Dims(i));
57 }
58 _temp_lhs_shape.SetDim(rank - 2, lhs_shape.Dims(rank - 1));
59 _temp_lhs_shape.SetDim(rank - 1, lhs_shape.Dims(rank - 2));
60
61 _temp_lhs.resize(_temp_lhs_shape.FlatSize());
62 }
63
64 if (!adj_y)
65 {
66 int32_t rank = rhs_shape.DimensionsCount();
67 _temp_rhs_shape.Resize(rank);
68
69 for (int32_t i = 0; i < rank - 2; i++)
70 {
71 _temp_rhs_shape.SetDim(i, rhs_shape.Dims(i));
72 }
73 _temp_rhs_shape.SetDim(rank - 2, rhs_shape.Dims(rank - 1));
74 _temp_rhs_shape.SetDim(rank - 1, rhs_shape.Dims(rank - 2));
75
76 _temp_rhs.resize(_temp_rhs_shape.FlatSize());
77 }
78 }
79
80 void operator()(const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape,
81 const float *rhs_data, bool adj_x, bool adj_y, const Shape & /*output_shape*/,
82 float *output_data)
83 {
84 // Assume lhs and rhs are not constant
85 // TODO Handle constant input
86
87 if (!adj_y)
88 {
89 transposeRowsCols(rhs_shape, rhs_data, _temp_rhs_shape, _temp_rhs.data());
90 }
91
92 if (adj_x)
93 {
94 transposeRowsCols(lhs_shape, lhs_data, _temp_lhs_shape, _temp_lhs.data());
95 }
96
97 Shape new_lhs_shape = adj_x ? lhs_shape : swapRowColDims(lhs_shape);
98 Shape new_rhs_shape = adj_y ? rhs_shape : swapRowColDims(rhs_shape);
99 const float *new_lhs_data = adj_x ? _temp_lhs.data() : lhs_data;
100 const float *new_rhs_data = adj_y ? rhs_data : _temp_rhs.data();
101
102 // Note we pass RHS args first, LHS args second
103 // Check that the contracting (accumulated) dimensions of lhs and rhs are equal
104 assert(Shape::ExtendedShape(5, new_rhs_shape).Dims(4) ==
105 Shape::ExtendedShape(5, new_lhs_shape).Dims(3));
106
107 const BatchMatMulParams params{new_rhs_shape, new_lhs_shape};
108#if defined(CKER_X86_PLATFORM)
109 optimized::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
110#else
111 reference::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
112#endif
113 }
114
115private:
116 Shape swapRowColDims(const Shape &shape)
117 {
118 Shape swapped_shape(shape);
119 const uint32_t dims = shape.DimensionsCount();
120 swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
121 swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
122
123 return swapped_shape;
124 }
125
126 void transposeRowsCols(const Shape &input_shape, const float *input_data,
127 const Shape &output_shape, float *output_data)
128 {
129 TransposeParams params;
130 int rank = input_shape.DimensionsCount();
131 params.perm_count = rank;
132 for (int i = 0; i < 2; i++)
133 {
134 params.perm[i] = i;
135 }
136 params.perm[rank - 2] = rank - 1;
137 params.perm[rank - 1] = rank - 2;
138
139 Transpose<float>(params, input_shape, input_data, output_shape, output_data);
140 }
141
142private:
143 std::vector<float> _temp_lhs;
144 Shape _temp_lhs_shape;
145 std::vector<float> _temp_rhs;
146 Shape _temp_rhs_shape;
147};
148
149} // namespace cker
150} // namespace nnfw
151
152#endif // __NNFW_CKER_BATCH_MATMUL_H__
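Both the optimized and the reference kernels dispatched above compute an independent matrix product per batch, with adj_x/adj_y selecting whether the corresponding operand is read transposed. Below is a minimal, illustrative sketch of that per-slice computation; it is not the library's kernel, and the function name, row-major layout, and loop order are assumptions made for clarity.

// Illustrative only: one batch slice of C = op(A) * op(B), where op(X) = X^T
// when the corresponding adj_ flag is set. Buffers are assumed row-major float.
// op(A) is m x k, op(B) is k x n, C is m x n.
inline void matmul_slice(const float *a, const float *b, float *c, int m, int k, int n,
                         bool adj_x, bool adj_y)
{
  for (int i = 0; i < m; ++i)
  {
    for (int j = 0; j < n; ++j)
    {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p)
      {
        // adj_* selects the transposed read of the stored operand.
        const float a_val = adj_x ? a[p * m + i] : a[i * k + p];
        const float b_val = adj_y ? b[j * k + p] : b[p * n + j];
        acc += a_val * b_val;
      }
      c[i * n + j] = acc;
    }
  }
}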
Referenced symbols:

void prepare(const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y)
    Prepare temporary area for calculation. Definition: BatchMatMul.h:47
void operator()(const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape, const float *rhs_data, bool adj_x, bool adj_y, const Shape &, float *output_data)
    Definition: BatchMatMul.h:80
int32_t DimensionsCount() const
    Definition: Shape.h:91
int32_t Dims(int i) const
    Definition: Shape.h:92
int FlatSize() const
    Definition: Shape.h:181
void Resize(int dimensions_count)
    Definition: Shape.h:117
void SetDim(int i, int32_t val)
    Definition: Shape.h:98
void BatchMatMul(const BatchMatMulParams &params, const float *lhs_data, const float *rhs_data, float *output_data)
    Definition: BatchMatMul.h:32 (the optimized/reference kernel header, not this file)
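For orientation, here is a rough, hypothetical driver showing how the class above might be exercised. The include path, the default-constructed Shape, and the concrete dimensions are assumptions for illustration, not taken from the library's own tests.

#include "cker/operation/BatchMatMul.h" // assumed include path for this header

#include <vector>

int main()
{
  using nnfw::cker::Shape;

  // lhs: [batch=2, 3, 4] x rhs: [batch=2, 4, 5] -> output: [batch=2, 3, 5]
  Shape lhs_shape, rhs_shape, output_shape;
  lhs_shape.Resize(3);
  lhs_shape.SetDim(0, 2);
  lhs_shape.SetDim(1, 3);
  lhs_shape.SetDim(2, 4);
  rhs_shape.Resize(3);
  rhs_shape.SetDim(0, 2);
  rhs_shape.SetDim(1, 4);
  rhs_shape.SetDim(2, 5);
  output_shape.Resize(3);
  output_shape.SetDim(0, 2);
  output_shape.SetDim(1, 3);
  output_shape.SetDim(2, 5);

  std::vector<float> lhs(lhs_shape.FlatSize(), 1.0f);
  std::vector<float> rhs(rhs_shape.FlatSize(), 1.0f);
  std::vector<float> output(output_shape.FlatSize(), 0.0f);

  nnfw::cker::BatchMatMul batch_matmul;
  // prepare() sizes the temporary transpose buffers for the given adj_x/adj_y.
  batch_matmul.prepare(lhs_shape, rhs_shape, /*adj_x=*/false, /*adj_y=*/false);
  // operator() runs the batched matrix multiplication into 'output';
  // with all-ones inputs every output element becomes 4.0f (the contraction length).
  batch_matmul(lhs_shape, lhs.data(), rhs_shape, rhs.data(),
               /*adj_x=*/false, /*adj_y=*/false, output_shape, output.data());
  return 0;
}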