ONE - On-device Neural Engine
Loading...
Searching...
No Matches
FullyConnected.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_RUY_FULLY_CONNECTED_H__
19#define __NNFW_RUY_FULLY_CONNECTED_H__
20
21#include "ruy/Shape.h"
22#include "ruy/Types.h"
23#include "ruy/Utils.h"
24#include "ruy/RuySupport.h"
25
26#include <ruy/ruy.h>
27#include <ruy/context.h>
28
29namespace nnfw
30{
31namespace ruy
32{
33
34inline void FullyConnected(const FullyConnectedParams &params, const Shape &input_shape,
35 const float *input_data, const Shape &weights_shape,
36 const float *weights_data, const Shape &,
37 const float *optional_bias_data, const Shape &output_shape,
38 float *output_data, ::ruy::Context *ruy_context)
39{
40 const int dims_count = weights_shape.DimensionsCount();
41 const int input_rows = weights_shape.Dims(dims_count - 1);
42 MatrixParams<float> rhs_params;
43 rhs_params.order = Order::kColMajor;
44 rhs_params.rows = input_rows;
45 rhs_params.cols = input_shape.FlatSize() / input_rows;
47 assert(input_shape.FlatSize() == (rhs_params.rows * rhs_params.cols));
48 MatrixParams<float> lhs_params;
49 lhs_params.order = Order::kRowMajor;
50 lhs_params.cols = weights_shape.Dims(dims_count - 1);
51 lhs_params.rows = FlatSizeSkipDim(weights_shape, dims_count - 1);
53 MatrixParams<float> dst_params;
54 dst_params.order = Order::kColMajor;
55 dst_params.rows = output_shape.Dims(output_shape.DimensionsCount() - 1);
56 dst_params.cols = FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
57 GemmParams<float, float> gemm_params;
58 gemm_params.bias = optional_bias_data;
59 gemm_params.clamp_min = params.float_activation_min;
60 gemm_params.clamp_max = params.float_activation_max;
61
62 // Below code was copied from tflite::cpu_backend_gemm::detail::GemmImplUsingRuy
63 ::ruy::Matrix<float> ruy_lhs;
64 ::ruy::Matrix<float> ruy_rhs;
65 ::ruy::Matrix<float> ruy_dst;
66 // Note that cache is always enabled for input and weight tensors
67 ruy_support::MakeRuyMatrix(lhs_params, weights_data, &ruy_lhs, true);
68 ruy_support::MakeRuyMatrix(rhs_params, input_data, &ruy_rhs, true);
69 ruy_support::MakeRuyMatrix(dst_params, output_data, &ruy_dst);
70
71 ::ruy::MulParams<float, float> ruy_mul_params;
72 ruy_support::MakeRuyMulParams(gemm_params, &ruy_mul_params);
73
74 ::ruy::Mul(ruy_lhs, ruy_rhs, ruy_mul_params, ruy_context, &ruy_dst);
75}
76
77} // namespace ruy
78} // namespace nnfw
79
80#endif // __NNFW_RUY_FULLY_CONNECTED_H__
int FlatSize() const
Definition Shape.h:181
int32_t DimensionsCount() const
Definition Shape.h:91
int32_t Dims(int i) const
Definition Shape.h:92
const luci_interpreter::RuntimeShape output_shape
void MakeRuyMulParams(const GemmParams< AccumScalar, DstScalar, quantization_flavor > &params, ::ruy::MulParams< AccumScalar, DstScalar > *ruy_mul_params)
Definition RuySupport.h:69
void MakeRuyMatrix(const MatrixParams< Scalar > &params, DataPointer data_ptr, ::ruy::Matrix< Scalar > *dst, bool use_caching=false)
Definition RuySupport.h:51
void FullyConnected(const FullyConnectedParams &params, const Shape &input_shape, const float *input_data, const Shape &weights_shape, const float *weights_data, const Shape &, const float *optional_bias_data, const Shape &output_shape, float *output_data, ::ruy::Context *ruy_context)
int FlatSizeSkipDim(const Shape &shape, int skip_dim)
Definition Shape.h:254
CachePolicy DefaultCachePolicy(bool is_constant_data)
Definition Types.h:267
Definition topk_v2.h:30
DstScalar clamp_max
Definition Types.h:222
DstScalar clamp_min
Definition Types.h:218
const AccumScalar * bias
Definition Types.h:216
CachePolicy cache_policy
Definition Types.h:141