ONE - On-device Neural Engine
Loading...
Searching...
No Matches
eigen_gemm_eigen.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_EGIEN_EIGEN_GEMM_EIGEN_H__
19#define __NNFW_CKER_EGIEN_EIGEN_GEMM_EIGEN_H__
20
21// See b/131835803: in TFLite code, because eigen_spatial_convolutions.h does
22// #define Eigen EigenForTFLite, it is difficult to have any #include of Eigen
23// headers in a header file, as that results in name clashes (compilation
24// errors) depending on the order in which these headers are #included.
25// So we have moved the #include of Eigen here, in a .cc file, where we have
26// control over the header #include sequence.
27// #include "third_party/eigen3/Eigen/Core"
28// #include "tensorflow/lite/kernels/cpu_backend_context.h"
29// #include "tensorflow/lite/kernels/cpu_backend_gemm_params.h"
30// #include "tensorflow/lite/kernels/internal/common.h"
31// #include "cker/eigen/eigen_convolution_helpers.h"
33#include "cker/Types.h"
34
35#include <Eigen/Core>
36
37namespace nnfw
38{
39namespace cker
40{
41namespace detail
42{
43
44// tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_eigen.h and cpu_backend_gemm_eigen.cc
46{
47 static void Run(const MatrixParams<float> &lhs_params, const float *lhs_data,
48 const MatrixParams<float> &rhs_params, const float *rhs_data,
49 const MatrixParams<float> &dst_params, float *dst_data,
50 const GemmParams<float, float> &params)
51 {
52 // This code assumes specific storage orders, encoded in these Eigen types.
53 // These assumptions have been checked by TF_LITE_ASSERT's in the public
54 // Gemm entry point already, before the implementation gets to this point.
55 using EigenMatrixMapRowMajorConst =
56 Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
57 using EigenMatrixMapColMajorConst =
58 Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>;
59 using EigenMatrixMapColMajorMutable =
60 Eigen::Map<Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>;
61
62 EigenMatrixMapRowMajorConst eigen_lhs(lhs_data, lhs_params.rows, lhs_params.cols);
63 EigenMatrixMapColMajorConst eigen_rhs(rhs_data, rhs_params.rows, rhs_params.cols);
64 EigenMatrixMapColMajorMutable eigen_dst(dst_data, dst_params.rows, dst_params.cols);
65
66 if (rhs_params.cols == 1)
67 {
68 eigen_dst.col(0).noalias() = eigen_lhs * eigen_rhs.col(0);
69 }
70 else if (lhs_params.rows == 1)
71 {
72 eigen_dst.row(0).noalias() = eigen_lhs.row(0) * eigen_rhs;
73 }
74 else
75 {
76 eigen_dst.noalias() = eigen_lhs * eigen_rhs;
77 }
78
79 if (params.bias)
80 {
81 BiasAndClamp(params.clamp_min, params.clamp_max, dst_params.rows, params.bias,
82 dst_params.rows * dst_params.cols, dst_data);
83 }
84 else
85 {
86 eigen_dst = eigen_dst.cwiseMin(params.clamp_max).cwiseMax(params.clamp_min);
87 }
88 }
89};
90
91} // namespace detail
92} // namespace cker
93} // namespace nnfw
94
95#endif // __NNFW_CKER_EGIEN_EIGEN_GEMM_EIGEN_H__
void BiasAndClamp(float clamp_min, float clamp_max, int bias_size, const float *bias_data, int array_size, float *array_data)
Definition Common.h:29
Definition topk_v2.h:30
DstScalar clamp_max
Definition Types.h:537
const AccumScalar * bias
Definition Types.h:531
DstScalar clamp_min
Definition Types.h:533
static void Run(const MatrixParams< float > &lhs_params, const float *lhs_data, const MatrixParams< float > &rhs_params, const float *rhs_data, const MatrixParams< float > &dst_params, float *dst_data, const GemmParams< float, float > &params)