ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
FullyConnected.h
Go to the documentation of this file.
1/* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
2 *
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16#ifndef __LUCI_COMPUTE_FULLY_CONNECTED_H__
17#define __LUCI_COMPUTE_FULLY_CONNECTED_H__
18
19#include "Types.h"
20
21#include <loco/IR/TensorShape.h>
22
23namespace luci
24{
25namespace compute
26{
27
28// TODO extract some common for multiple Ops
30{
31public:
32 FullyConnected() = default;
33
34public:
35 FullyConnectedParams &params(void) { return _params; }
36
37 bool keep_num_dims(void) const { return _keep_num_dims; }
38 void keep_num_dims(bool knd) { _keep_num_dims = knd; }
39
40 void input(const loco::TensorShape &shape, const float *data)
41 {
42 _input_shape = shape;
43 _input_data = data;
44 }
45
46 void weights(const loco::TensorShape &shape, const float *data)
47 {
48 _weights_shape = shape;
49 _weights_data = data;
50 }
51
52 void bias(const loco::TensorShape &shape, const float *data)
53 {
54 _bias_shape = shape;
55 _bias_data = data;
56 }
57
58 void fused_act_func(FusedActFunc func) { _fused_act_func = func; };
59
60 void output(float *data) { _output_data = data; }
61
62public:
63 bool prepare(void);
64 const loco::TensorShape &output_shape(void) const { return _output_shape; }
65 void compute(void);
66
67private:
68 // param to pass to compute kernel
69 FullyConnectedParams _params = {};
70 // new param from tflite version 5
71 bool _keep_num_dims = false;
72 // shape and data for inputs
73 loco::TensorShape _input_shape;
74 loco::TensorShape _weights_shape;
75 loco::TensorShape _bias_shape;
76 const float *_input_data = nullptr;
77 const float *_weights_data = nullptr;
78 const float *_bias_data = nullptr;
79 FusedActFunc _fused_act_func = FusedActFunc::UNDEFINED;
80
81 // compute results
82 loco::TensorShape _output_shape;
83 float *_output_data = nullptr;
84};
85
86} // namespace compute
87} // namespace luci
88
89#endif // __LUCI_COMPUTE_FULLY_CONNECTED_H__
void weights(const loco::TensorShape &shape, const float *data)
void input(const loco::TensorShape &shape, const float *data)
const loco::TensorShape & output_shape(void) const
void fused_act_func(FusedActFunc func)
void bias(const loco::TensorShape &shape, const float *data)
FullyConnectedParams & params(void)
const T * data(const std::vector< T, Alloc > &v)