ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Gemm.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_OPTIMIZED_GEMM_H__
19#define __NNFW_CKER_OPTIMIZED_GEMM_H__
20
22#include "cker/Shape.h"
23#include "cker/Types.h"
24
25#include <ruy/context.h>
26
27namespace nnfw
28{
29namespace cker
30{
31namespace optimized
32{
33
34#if defined(CKER_X86_PLATFORM)
35
36/* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_x86.h */
37template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
38 QuantizationFlavor quantization_flavor>
39struct GemmImplX86
40{
41 static void Run(const MatrixParams<LhsScalar> &, const LhsScalar *,
42 const MatrixParams<RhsScalar> &, const RhsScalar *,
43 const MatrixParams<DstScalar> &, DstScalar *,
44 const GemmParams<AccumScalar, DstScalar, quantization_flavor> &)
45 {
46 static_assert(
47 std::is_floating_point<LhsScalar>::value && std::is_floating_point<RhsScalar>::value &&
48 std::is_floating_point<AccumScalar>::value && std::is_floating_point<DstScalar>::value &&
49 quantization_flavor != QuantizationFlavor::kFloatingPoint,
50 "GemmImplX86 does not supported types other than float yet.");
51 }
52};
53
54// For float, defer to eigen for now.
55template <> struct GemmImplX86<float, float, float, float, QuantizationFlavor::kFloatingPoint>
56{
57 static void Run(const MatrixParams<float> &lhs_params, const float *lhs_data,
58 const MatrixParams<float> &rhs_params, const float *rhs_data,
59 const MatrixParams<float> &dst_params, float *dst_data,
60 const GemmParams<float, float, QuantizationFlavor::kFloatingPoint> &params)
61 {
62 detail::GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params,
63 dst_data, params);
64 }
65};
66
67/* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm.h */
68/* GEMM dispatch implementation for x86.
69 */
70template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
71 QuantizationFlavor quantization_flavor>
72struct GemmImpl : GemmImplX86<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>
73{
74};
75
76/* From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm.h */
77template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar,
78 QuantizationFlavor quantization_flavor>
79void Gemm(const MatrixParams<LhsScalar> &lhs_params, const LhsScalar *lhs_data,
80 const MatrixParams<RhsScalar> &rhs_params, const RhsScalar *rhs_data,
81 const MatrixParams<DstScalar> &dst_params, DstScalar *dst_data,
82 const GemmParams<AccumScalar, DstScalar, quantization_flavor> &params)
83{
84 // Generic case: dispatch to any backend as a general GEMM.
85 GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>::Run(
86 lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params);
87}
88
89// From tensorflow/tensorflow/lite/kernels/cpu_backend_gemm_params.h
90inline CachePolicy DefaultCachePolicy(bool is_constant_data)
91{
93}
94#endif // CKER_X86_PLATFORM
95
96} // namespace optimized
97} // namespace cker
98} // namespace nnfw
99
100#endif // __NNFW_CKER_OPTIMIZED_GEMM_H__
void Gemm(const Eigen::MatrixBase< Lhs > &lhs, const Eigen::MatrixBase< Rhs > &rhs, Eigen::MatrixBase< Result > *result)
Definition GEMM.h:24
QuantizationFlavor
Definition Types.h:475
CachePolicy DefaultCachePolicy(bool is_constant_data)
Definition Types.h:267
Definition topk_v2.h:30