ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci::compute::FullyConnected Class Reference

#include <FullyConnected.h>

Public Member Functions

 FullyConnected ()=default
 
FullyConnectedParams & params (void)
 
bool keep_num_dims (void) const
 
void keep_num_dims (bool knd)
 
void input (const loco::TensorShape &shape, const float *data)
 
void weights (const loco::TensorShape &shape, const float *data)
 
void bias (const loco::TensorShape &shape, const float *data)
 
void fused_act_func (FusedActFunc func)
 
void output (float *data)
 
bool prepare (void)
 
const loco::TensorShape & output_shape (void) const
 
void compute (void)
 

Detailed Description

Definition at line 29 of file FullyConnected.h.

Constructor & Destructor Documentation

◆ FullyConnected()

luci::compute::FullyConnected::FullyConnected ( )
default

Member Function Documentation

◆ bias()

void luci::compute::FullyConnected::bias ( const loco::TensorShape & shape,
const float *  data 
)
inline

Definition at line 52 of file FullyConnected.h.

53 {
54 _bias_shape = shape;
55 _bias_data = data;
56 }
const T * data(const std::vector< T, Alloc > &v)

References flatbuffers::data().

◆ compute()

void luci::compute::FullyConnected::compute ( void  )

Definition at line 76 of file FullyConnected.cpp.

77{
78 assert(_input_data != nullptr);
79 assert(_weights_data != nullptr);
80 // NOTE _bias_data can be nullptr (bias is optional; _bias_shape is a value, not a pointer)
81 assert(_output_data != nullptr);
82
83 // NOTE if this fails, structure may have changed
84 static_assert(sizeof(compute::FullyConnectedParams) == sizeof(tflite::FullyConnectedParams));
85
86 tflite::FullyConnectedParams params;
87
88 // clang-format off
101 // clang-format on
102
103 tflite::reference_ops::FullyConnected(
104 params, tflite_shape(_input_shape), _input_data, tflite_shape(_weights_shape), _weights_data,
105 tflite_shape(_bias_shape), _bias_data, tflite_shape(_output_shape), _output_data);
106}
FullyConnectedParams & params(void)
tflite::FullyConnectedWeightsFormat tflite_weights_format(const FullyConnectedWeightsFormat type)
tflite::RuntimeShape tflite_shape(const loco::TensorShape &shape)
FullyConnectedWeightsFormat weights_format
Definition Types.h:105

References luci::compute::FullyConnectedParams::float_activation_max, luci::compute::FullyConnectedParams::float_activation_min, luci::compute::FullyConnectedParams::input_offset, luci::compute::FullyConnectedParams::lhs_cacheable, luci::compute::FullyConnectedParams::output_multiplier, luci::compute::FullyConnectedParams::output_offset, luci::compute::FullyConnectedParams::output_shift, params(), luci::compute::FullyConnectedParams::quantized_activation_max, luci::compute::FullyConnectedParams::quantized_activation_min, luci::compute::FullyConnectedParams::rhs_cacheable, luci::compute::tflite_shape(), luci::compute::tflite_weights_format(), luci::compute::FullyConnectedParams::weights_format, and luci::compute::FullyConnectedParams::weights_offset.

◆ fused_act_func()

void luci::compute::FullyConnected::fused_act_func ( FusedActFunc  func)
inline

Definition at line 58 of file FullyConnected.h.

58{ _fused_act_func = func; };

◆ input()

void luci::compute::FullyConnected::input ( const loco::TensorShape & shape,
const float *  data 
)
inline

Definition at line 40 of file FullyConnected.h.

41 {
42 _input_shape = shape;
43 _input_data = data;
44 }

References flatbuffers::data().

◆ keep_num_dims() [1/2]

void luci::compute::FullyConnected::keep_num_dims ( bool  knd)
inline

Definition at line 38 of file FullyConnected.h.

38{ _keep_num_dims = knd; }

◆ keep_num_dims() [2/2]

bool luci::compute::FullyConnected::keep_num_dims ( void  ) const
inline

Definition at line 37 of file FullyConnected.h.

37{ return _keep_num_dims; }

◆ output()

void luci::compute::FullyConnected::output ( float *  data)
inline

Definition at line 60 of file FullyConnected.h.

60{ _output_data = data; }

References flatbuffers::data().

◆ output_shape()

const loco::TensorShape & luci::compute::FullyConnected::output_shape ( void  ) const
inline

Definition at line 64 of file FullyConnected.h.

64{ return _output_shape; }

◆ params()

FullyConnectedParams & luci::compute::FullyConnected::params ( void  )
inline

Definition at line 35 of file FullyConnected.h.

35{ return _params; }

Referenced by compute().

◆ prepare()

bool luci::compute::FullyConnected::prepare ( void  )

Definition at line 37 of file FullyConnected.cpp.

38{
39 if (_input_shape.rank() < 1 || _weights_shape.rank() != 2)
40 return false;
41
42 auto const input_elems = element_count(&_input_shape);
43 auto const weights_height = _weights_shape.dim(0).value();
44 auto const weights_width = _weights_shape.dim(1).value();
45 if (weights_height == 0 || weights_width == 0)
46 return false;
47 if (input_elems % weights_width != 0)
48 return false;
49 auto const batch_size = input_elems / weights_width;
50 auto const num_units = weights_height;
51 if (_bias_data)
52 {
53 if (element_count(&_bias_shape) != num_units)
54 return false;
55 }
56
57 get_act_minmax(_fused_act_func, _params.float_activation_min, _params.float_activation_max);
58
59 if (_keep_num_dims)
60 {
61 _output_shape.rank(_input_shape.rank());
62 for (uint32_t i = 0; i < _input_shape.rank(); i++)
63 _output_shape.dim(i) = _input_shape.dim(i);
64 _output_shape.dim(_input_shape.rank() - 1) = num_units;
65 }
66 else
67 {
68 _output_shape.rank(2);
69 _output_shape.dim(0) = batch_size;
70 _output_shape.dim(1) = num_units;
71 }
72
73 return true;
74}
uint32_t value(void) const
Return the value.
Definition Dimension.h:51
const Dimension & dim(uint32_t axis) const
Definition TensorShape.h:38
uint32_t rank(void) const
Definition TensorShape.h:35
uint32_t element_count(const loco::TensorShape *tensor_shape)
Return the number of elements in a tensor of given shape.
void get_act_minmax(const FusedActFunc act, float &act_min, float &act_max)

References loco::TensorShape::dim(), luci::compute::FullyConnectedParams::float_activation_max, luci::compute::FullyConnectedParams::float_activation_min, luci::compute::get_act_minmax(), loco::TensorShape::rank(), and loco::Dimension::value().

◆ weights()

void luci::compute::FullyConnected::weights ( const loco::TensorShape & shape,
const float *  data 
)
inline

Definition at line 46 of file FullyConnected.h.

47 {
48 _weights_shape = shape;
49 _weights_data = data;
50 }

References flatbuffers::data().


The documentation for this class was generated from the following files: