ONE - On-device Neural Engine
Loading...
Searching...
No Matches
BatchMatMul.h
Go to the documentation of this file.
/*
 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#ifndef __NNFW_CKER_OPTIMIZED_BATCH_MATMUL_H__
18#define __NNFW_CKER_OPTIMIZED_BATCH_MATMUL_H__
19
20#include "cker/Shape.h"
23
namespace nnfw
{
namespace cker
{
namespace optimized
{
#if defined(CKER_X86_PLATFORM)

/**
 * @brief Batched float matrix multiplication: for every batch index (b0, b1, b2),
 *        output[b0][b1][b2] = lhs[b0][b1][b2] x rhs[b0][b1][b2].
 *
 * Each lhs matrix is described as params.lhs_rows x params.lhs_cols, each rhs matrix as
 * params.rhs_rows x params.rhs_cols, and each output matrix as
 * params.lhs_rows x params.rhs_cols. All three are described as row-major.
 * NOTE(review): this function does not check lhs_cols == rhs_rows; presumably the caller
 * that fills BatchMatMulParams guarantees it — confirm.
 *
 * @param params      Matrix dimensions, batch extents (batch_dim0..2) and per-batch
 *                    element offsets into the inputs (lhs_ext0..2 / rhs_ext0..2).
 *                    NOTE(review): an offset of 0 would reuse one input matrix across that
 *                    batch dimension (broadcasting); confirm against the code that builds
 *                    the params.
 * @param lhs_data    Left-hand input buffer.
 * @param rhs_data    Right-hand input buffer.
 * @param output_data Output buffer. The batch matrices are written contiguously, one after
 *                    another in (b0, b1, b2) row-major order.
 */
inline void BatchMatMul(const BatchMatMulParams &params, const float *lhs_data,
                        const float *rhs_data, float *output_data)
{
  // Set `order` explicitly on every operand. GemmImplUsingEigen ignores it, but other
  // Gemm backends may not, so the metadata must describe each matrix correctly.
  MatrixParams<float> lhs_params;
  lhs_params.order = Order::kRowMajor;
  lhs_params.rows = params.lhs_rows;
  lhs_params.cols = params.lhs_cols;

  MatrixParams<float> rhs_params;
  rhs_params.order = Order::kRowMajor;
  rhs_params.rows = params.rhs_rows;
  rhs_params.cols = params.rhs_cols;

  MatrixParams<float> dst_params;
  dst_params.order = Order::kRowMajor;
  dst_params.rows = params.lhs_rows;
  dst_params.cols = params.rhs_cols;

  // Number of elements in one output matrix (loop-invariant).
  const int dst_size = params.lhs_rows * params.rhs_cols;

  for (int b0 = 0; b0 < params.batch_dim0; ++b0)
  {
    for (int b1 = 0; b1 < params.batch_dim1; ++b1)
    {
      for (int b2 = 0; b2 < params.batch_dim2; ++b2)
      {
        // Inputs are addressed through the per-dimension offsets supplied in params.
        const float *lhs_ptr =
          lhs_data + b0 * params.lhs_ext0 + b1 * params.lhs_ext1 + b2 * params.lhs_ext2;
        const float *rhs_ptr =
          rhs_data + b0 * params.rhs_ext0 + b1 * params.rhs_ext1 + b2 * params.rhs_ext2;

        // The output is dense: the flattened batch index selects the dst_size-sized slot.
        const int batch_index = (b0 * params.batch_dim1 + b1) * params.batch_dim2 + b2;
        float *out_ptr = output_data + batch_index * dst_size;

        // Plain float GEMM with default-constructed GemmParams.
        optimized::Gemm(lhs_params, lhs_ptr, rhs_params, rhs_ptr, dst_params, out_ptr,
                        GemmParams<float, float>{});
      }
    }
  }
}
#endif
} // namespace optimized
} // namespace cker
} // namespace nnfw
74
75#endif // __NNFW_CKER_OPTIMIZED_BATCH_MATMUL_H__
Definition topk_v2.h:30