ONE - On-device Neural Engine
Loading...
Searching...
No Matches
BatchMatMul.h
Go to the documentation of this file.
1
/*
2
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3
*
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
* you may not use this file except in compliance with the License.
6
* You may obtain a copy of the License at
7
*
8
* http://www.apache.org/licenses/LICENSE-2.0
9
*
10
* Unless required by applicable law or agreed to in writing, software
11
* distributed under the License is distributed on an "AS IS" BASIS,
12
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
* See the License for the specific language governing permissions and
14
* limitations under the License.
15
*/
16
17
#ifndef __NNFW_CKER_OPTIMIZED_BATCH_MATMUL_H__
18
#define __NNFW_CKER_OPTIMIZED_BATCH_MATMUL_H__
19
20
#include "
cker/Shape.h
"
21
#include "
cker/operation/Helper/BatchMatMulParams.h
"
22
#include "
cker/operation/optimized/Gemm.h
"
23
24
namespace nnfw
{
namespace cker
{
namespace optimized
{
#if defined(CKER_X86_PLATFORM)

/**
 * @brief Batched float matrix multiplication.
 *
 * For every batch index (b0, b1, b2) in
 * [0, batch_dim0) x [0, batch_dim1) x [0, batch_dim2), calls optimized::Gemm
 * to compute one (lhs_rows x lhs_cols) by (rhs_rows x rhs_cols) product.
 *
 * @param params      Shapes and per-batch offsets. Each input slice is located
 *                    by adding b0 * ext0 + b1 * ext1 + b2 * ext2 to the base
 *                    pointer. NOTE(review): the ext values are presumably
 *                    element strides, with 0 meaning a broadcast batch
 *                    dimension; confirm in BatchMatMulParams.
 * @param lhs_data    LHS input buffer.
 * @param rhs_data    RHS input buffer.
 * @param output_data Output buffer. Each batch writes lhs_rows * rhs_cols
 *                    elements contiguously. Batches are laid out in
 *                    (b0, b1, b2) row-major order.
 *
 * Only compiled when CKER_X86_PLATFORM is defined.
 */
inline void BatchMatMul(const BatchMatMulParams &params, const float *lhs_data,
                        const float *rhs_data, float *output_data)
{
  // All three operands are described as row-major. `order` is ignored by
  // GemmImplUsingEigen, but each matrix's own params are still set so that
  // every MatrixParams handed to Gemm is fully specified.
  MatrixParams<float> lhs_params;
  lhs_params.order = Order::kRowMajor;
  lhs_params.rows = params.lhs_rows;
  lhs_params.cols = params.lhs_cols;

  MatrixParams<float> rhs_params;
  rhs_params.order = Order::kRowMajor;
  rhs_params.rows = params.rhs_rows;
  rhs_params.cols = params.rhs_cols;

  MatrixParams<float> dst_params;
  dst_params.order = Order::kRowMajor;
  dst_params.rows = params.lhs_rows;
  dst_params.cols = params.rhs_cols;

  // Number of output elements written by a single batch's Gemm.
  const int dst_batch_size = params.lhs_rows * params.rhs_cols;

  for (int b0 = 0; b0 < params.batch_dim0; ++b0)
  {
    for (int b1 = 0; b1 < params.batch_dim1; ++b1)
    {
      for (int b2 = 0; b2 < params.batch_dim2; ++b2)
      {
        const float *lhs_ptr =
          lhs_data + b0 * params.lhs_ext0 + b1 * params.lhs_ext1 + b2 * params.lhs_ext2;
        const float *rhs_ptr =
          rhs_data + b0 * params.rhs_ext0 + b1 * params.rhs_ext1 + b2 * params.rhs_ext2;

        // Flattened (b0, b1, b2) batch index in row-major order; equal to
        // b0 * batch_dim1 * batch_dim2 + b1 * batch_dim2 + b2.
        const int batch_index = (b0 * params.batch_dim1 + b1) * params.batch_dim2 + b2;
        float *out_ptr = output_data + batch_index * dst_batch_size;

        optimized::Gemm(lhs_params, lhs_ptr, rhs_params, rhs_ptr, dst_params, out_ptr,
                        GemmParams<float, float>{});
      }
    }
  }
}

#endif
} // namespace optimized
} // namespace cker
} // namespace nnfw
74
75
#endif
// __NNFW_CKER_OPTIMIZED_BATCH_MATMUL_H__
BatchMatMulParams.h
nnfw::cker::BatchMatMul
Definition
BatchMatMul.h:37
Shape.h
Gemm.h
nnfw::cker::Order::kRowMajor
@ kRowMajor
nnfw
Definition
topk_v2.h:30
nnfw::cker::BatchMatMulParams
Definition
BatchMatMulParams.h:27
nnfw::cker::BatchMatMulParams::lhs_rows
int lhs_rows
Definition
BatchMatMulParams.h:61
nnfw::cker::BatchMatMulParams::rhs_rows
int rhs_rows
Definition
BatchMatMulParams.h:63
nnfw::cker::BatchMatMulParams::rhs_ext2
int rhs_ext2
Definition
BatchMatMulParams.h:60
nnfw::cker::BatchMatMulParams::rhs_ext1
int rhs_ext1
Definition
BatchMatMulParams.h:59
nnfw::cker::BatchMatMulParams::batch_dim0
int batch_dim0
Definition
BatchMatMulParams.h:52
nnfw::cker::BatchMatMulParams::batch_dim2
int batch_dim2
Definition
BatchMatMulParams.h:54
nnfw::cker::BatchMatMulParams::rhs_cols
int rhs_cols
Definition
BatchMatMulParams.h:64
nnfw::cker::BatchMatMulParams::lhs_ext1
int lhs_ext1
Definition
BatchMatMulParams.h:56
nnfw::cker::BatchMatMulParams::lhs_cols
int lhs_cols
Definition
BatchMatMulParams.h:62
nnfw::cker::BatchMatMulParams::lhs_ext2
int lhs_ext2
Definition
BatchMatMulParams.h:57
nnfw::cker::BatchMatMulParams::rhs_ext0
int rhs_ext0
Definition
BatchMatMulParams.h:58
nnfw::cker::BatchMatMulParams::batch_dim1
int batch_dim1
Definition
BatchMatMulParams.h:53
nnfw::cker::BatchMatMulParams::lhs_ext0
int lhs_ext0
Definition
BatchMatMulParams.h:55
nnfw::cker::GemmParams
Definition
Types.h:509
nnfw::cker::MatrixParams
Definition
Types.h:439
nnfw::cker::MatrixParams::cols
int cols
Definition
Types.h:446
nnfw::cker::MatrixParams::rows
int rows
Definition
Types.h:444
nnfw::cker::MatrixParams::order
Order order
Definition
Types.h:442
compute
cker
include
cker
operation
optimized
BatchMatMul.h
Generated by
1.9.8