ONE - On-device Neural Engine
Loading...
Searching...
No Matches
RuySupport.h
Go to the documentation of this file.
/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef __NNFW_RUY_RUY_SUPPORT_H__
#define __NNFW_RUY_RUY_SUPPORT_H__

#include <util/ConfigSource.h>
#include <ruy/matrix.h>
#include <ruy/ruy.h>
#include <cassert>
#include <cstdint>
#include "Types.h"

27namespace nnfw
28{
29namespace ruy
30{
31namespace ruy_support
32{
33
34inline ::ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy)
35{
36 switch (cache_policy)
37 {
39 return ::ruy::CachePolicy::kNeverCache;
41 return ::ruy::CachePolicy::kCacheIfLargeSpeedup;
43 return ::ruy::CachePolicy::kAlwaysCache;
44 default:
45 assert(false);
46 return ::ruy::CachePolicy::kNeverCache;
47 }
48}
49
50template <typename Scalar, typename DataPointer>
51void MakeRuyMatrix(const MatrixParams<Scalar> &params, DataPointer data_ptr,
52 ::ruy::Matrix<Scalar> *dst, bool use_caching = false)
53{
54 ::ruy::Order ruy_order =
56 ::ruy::MakeSimpleLayout(params.rows, params.cols, ruy_order, dst->mutable_layout());
57 // Note that ruy::Matrix::data is a ConstCheckingPtr, not a plain pointer.
58 // It does care whether we assign to it a Scalar* or a const Scalar*.
59 dst->set_data(data_ptr);
60 dst->set_zero_point(params.zero_point);
61 if (use_caching)
62 {
63 dst->set_cache_policy(ToRuyCachePolicy(params.cache_policy));
64 }
65}
66
67// Floating-point case.
68template <typename AccumScalar, typename DstScalar, QuantizationFlavor quantization_flavor>
70 ::ruy::MulParams<AccumScalar, DstScalar> *ruy_mul_params)
71{
72 static_assert(quantization_flavor == QuantizationFlavor::kFloatingPoint, "");
73 ruy_mul_params->set_bias(params.bias);
74 ruy_mul_params->set_clamp_min(params.clamp_min);
75 ruy_mul_params->set_clamp_max(params.clamp_max);
76}
77
78// Integer-quantized case with destination type narrower than int32
79template <typename DstScalar, QuantizationFlavor quantization_flavor>
81 ::ruy::MulParams<std::int32_t, DstScalar> *ruy_mul_params)
82{
83 static_assert(sizeof(DstScalar) < sizeof(std::int32_t), "");
85 {
86 ruy_mul_params->set_multiplier_fixedpoint(params.multiplier_fixedpoint);
87 ruy_mul_params->set_multiplier_exponent(params.multiplier_exponent);
88 }
89 if (quantization_flavor == QuantizationFlavor::kIntegerWithPerRowMultiplier)
90 {
91 ruy_mul_params->set_multiplier_fixedpoint_perchannel(params.multiplier_fixedpoint_perchannel);
92 ruy_mul_params->set_multiplier_exponent_perchannel(params.multiplier_exponent_perchannel);
93 }
94 ruy_mul_params->set_bias(params.bias);
95 ruy_mul_params->set_clamp_min(params.clamp_min);
96 ruy_mul_params->set_clamp_max(params.clamp_max);
97}
98
99// Raw-integer case with destination type int32.
100template <QuantizationFlavor quantization_flavor>
102 ::ruy::MulParams<std::int32_t, std::int32_t> *ruy_mul_params)
103{
104 ruy_mul_params->set_bias(params.bias);
105}
106
107} // namespace ruy_support
108} // namespace ruy
109} // namespace nnfw

#endif // __NNFW_RUY_RUY_SUPPORT_H__
void MakeRuyMulParams(const GemmParams< AccumScalar, DstScalar, quantization_flavor > &params, ::ruy::MulParams< AccumScalar, DstScalar > *ruy_mul_params)
Definition RuySupport.h:69
void MakeRuyMatrix(const MatrixParams< Scalar > &params, DataPointer data_ptr, ::ruy::Matrix< Scalar > *dst, bool use_caching=false)
Definition RuySupport.h:51
inline ::ruy::CachePolicy ToRuyCachePolicy(CachePolicy cache_policy)
Definition RuySupport.h:34
CachePolicy
Definition Types.h:111
Definition topk_v2.h:30
AccumScalar multiplier_fixedpoint
Definition Types.h:198
DstScalar clamp_max
Definition Types.h:222
DstScalar clamp_min
Definition Types.h:218
const int * multiplier_exponent_perchannel
Definition Types.h:214
const AccumScalar * multiplier_fixedpoint_perchannel
Definition Types.h:206
const AccumScalar * bias
Definition Types.h:216
CachePolicy cache_policy
Definition Types.h:141