18#ifndef __NNFW_CKER_OPTIMIZED_GEMM_H__
19#define __NNFW_CKER_OPTIMIZED_GEMM_H__
25#include <ruy/context.h>
34#if defined(CKER_X86_PLATFORM)
37template <
typename LhsScalar,
typename RhsScalar,
typename AccumScalar,
typename DstScalar,
41 static void Run(
const MatrixParams<LhsScalar> &,
const LhsScalar *,
42 const MatrixParams<RhsScalar> &,
const RhsScalar *,
43 const MatrixParams<DstScalar> &, DstScalar *,
44 const GemmParams<AccumScalar, DstScalar, quantization_flavor> &)
47 std::is_floating_point<LhsScalar>::value && std::is_floating_point<RhsScalar>::value &&
48 std::is_floating_point<AccumScalar>::value && std::is_floating_point<DstScalar>::value &&
49 quantization_flavor != QuantizationFlavor::kFloatingPoint,
50 "GemmImplX86 does not supported types other than float yet.");
57 static void Run(
const MatrixParams<float> &lhs_params,
const float *lhs_data,
58 const MatrixParams<float> &rhs_params,
const float *rhs_data,
59 const MatrixParams<float> &dst_params,
float *dst_data,
60 const GemmParams<float, float, QuantizationFlavor::kFloatingPoint> ¶ms)
62 detail::GemmImplUsingEigen::Run(lhs_params, lhs_data, rhs_params, rhs_data, dst_params,
70template <
typename LhsScalar,
typename RhsScalar,
typename AccumScalar,
typename DstScalar,
72struct GemmImpl : GemmImplX86<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>
77template <
typename LhsScalar,
typename RhsScalar,
typename AccumScalar,
typename DstScalar,
79void Gemm(
const MatrixParams<LhsScalar> &lhs_params,
const LhsScalar *lhs_data,
80 const MatrixParams<RhsScalar> &rhs_params,
const RhsScalar *rhs_data,
81 const MatrixParams<DstScalar> &dst_params, DstScalar *dst_data,
82 const GemmParams<AccumScalar, DstScalar, quantization_flavor> ¶ms)
85 GemmImpl<LhsScalar, RhsScalar, AccumScalar, DstScalar, quantization_flavor>::Run(
86 lhs_params, lhs_data, rhs_params, rhs_data, dst_params, dst_data, params);
// Gemm overload taking Eigen matrix expressions (Eigen::MatrixBase) instead of raw pointers
// plus MatrixParams. `result` is passed by non-const pointer, presumably as the output
// destination — TODO confirm against the full definition.
// NOTE(review): only this signature is visible in this excerpt; the template header that
// declares Lhs, Rhs and Result, and the function body, are not shown here.
void Gemm(const Eigen::MatrixBase< Lhs > &lhs, const Eigen::MatrixBase< Rhs > &rhs, Eigen::MatrixBase< Result > *result)
// Returns a CachePolicy derived from the `is_constant_data` flag.
// NOTE(review): only the signature is visible in this excerpt; which CachePolicy value each
// flag value maps to is not shown here — confirm against the definition.
CachePolicy DefaultCachePolicy(bool is_constant_data)