ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::FullyConnected Class Reference

#include <FullyConnected.h>

Collaboration diagram for luci_interpreter::kernels::FullyConnected:

Public Member Functions

 FullyConnected (const Tensor *input, const Tensor *weights, const Tensor *bias, Tensor *output, const FullyConnectedParams &params)
 
 FullyConnected (const Tensor *input, const Tensor *weights, const Tensor *bias, Tensor *output, Tensor *scratch, const FullyConnectedParams &params)
 
const Tensor * input () const
 
const Tensor * weights () const
 
const Tensor * bias () const
 
Tensor * output () const
 
Tensor * scratch () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< FullyConnectedParams >
const FullyConnectedParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< FullyConnectedParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const FullyConnectedParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< FullyConnectedParams >
const FullyConnectedParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file FullyConnected.h.

Constructor & Destructor Documentation

◆ FullyConnected() [1/2]

luci_interpreter::kernels::FullyConnected::FullyConnected ( const Tensor * input,
const Tensor * weights,
const Tensor * bias,
Tensor * output,
const FullyConnectedParams & params 
)

Definition at line 31 of file FullyConnected.cpp.

33 : KernelWithParams<FullyConnectedParams>({input, weights, bias}, {output}, params)
34{
35}
const FullyConnectedParams & params() const
Definition Kernel.h:67

References bias(), input(), and weights().

◆ FullyConnected() [2/2]

luci_interpreter::kernels::FullyConnected::FullyConnected ( const Tensor * input,
const Tensor * weights,
const Tensor * bias,
Tensor * output,
Tensor * scratch,
const FullyConnectedParams & params 
)
inline

Definition at line 33 of file FullyConnected.h.

36 {
37 _scratch = scratch;
38 }
FullyConnected(const Tensor *input, const Tensor *weights, const Tensor *bias, Tensor *output, const FullyConnectedParams &params)

References scratch().

Member Function Documentation

◆ bias()

const Tensor * luci_interpreter::kernels::FullyConnected::bias ( ) const
inline

Definition at line 41 of file FullyConnected.h.

41{ return _inputs[2]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and FullyConnected().

◆ configure()

void luci_interpreter::kernels::FullyConnected::configure ( )
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 37 of file FullyConnected.cpp.

38{
39 if (weights()->element_type() == DataType::U8)
40 {
41 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::U8);
42 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::U8);
43 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32)
44 }
45 else if (weights()->element_type() == DataType::FLOAT32)
46 {
47 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
48 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
49 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32)
50 }
51 else if (weights()->element_type() == DataType::S8)
52 {
53 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::S8);
54 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S8);
55 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32)
56 }
57 else if (weights()->element_type() == DataType::S4)
58 {
59 // TODO support other combinations when needed
60 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
61 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
62 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32)
63 }
64 else if (weights()->element_type() == DataType::U4)
65 {
66 // TODO support other combinations when needed
67 LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
68 LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
69 LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32)
70 }
71 else
72 {
73 throw std::runtime_error("luci-intp FullyConnected(1) Unsupported type.");
74 }
75
76 const Shape &input_shape = input()->shape();
77 const Shape &weights_shape = weights()->shape();
78
79 LUCI_INTERPRETER_CHECK(weights_shape.num_dims() == 2);
80 LUCI_INTERPRETER_CHECK(bias() == nullptr ||
81 bias()->shape().num_elements() == weights_shape.dim(0));
82
83 LUCI_INTERPRETER_CHECK(input_shape.num_elements() % weights_shape.dim(1) == 0);
84 const int32_t batch_size = input_shape.num_elements() / weights_shape.dim(1);
85 const int32_t num_units = weights_shape.dim(0);
86
87 if (params().keep_num_dims == false)
88 {
89 output()->resize({batch_size, num_units});
90 }
91 else
92 {
93 luci_interpreter::Shape output_shape(input_shape.num_dims());
94 for (int i = 0; i < input_shape.num_dims(); ++i)
95 output_shape.dim(i) = input_shape.dim(i);
96 output_shape.dim(input_shape.num_dims() - 1) = num_units;
97 output()->resize(output_shape);
98 }
99}
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.
Definition Shape.h:59
Definition Shape.h:28

References bias(), luci_interpreter::Shape::dim(), input(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), luci_interpreter::Shape::num_elements(), output(), output_shape, luci_interpreter::KernelWithParams< FullyConnectedParams >::params(), luci_interpreter::Tensor::resize(), luci_interpreter::Tensor::shape(), and weights().

◆ execute()

void luci_interpreter::kernels::FullyConnected::execute ( ) const
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 101 of file FullyConnected.cpp.

102{
103 const bool is_hybrid =
104 (input()->element_type() == DataType::FLOAT32 &&
105 (weights()->element_type() == DataType::S4 || weights()->element_type() == DataType::U4) &&
106 output()->element_type() == DataType::FLOAT32 &&
107 (!bias() || bias()->element_type() == DataType::FLOAT32));
108 if (is_hybrid)
109 {
110 switch (weights()->element_type())
111 {
112 case DataType::S4:
113 evalHybridWI4AF32();
114 break;
115 case DataType::U4:
116 evalHybridWU4AF32();
117 break;
118 default:
119 throw std::runtime_error("luci-intp FullyConnected(3) Unsupported type.");
120 }
121 }
122 else
123 {
124 switch (input()->element_type())
125 {
126 case DataType::U8:
127 evalQuantized();
128 break;
129 case DataType::S8:
130 evalQuantizedS8();
131 break;
132 case DataType::FLOAT32:
133 evalFloat();
134 break;
135 default:
136 throw std::runtime_error("luci-intp FullyConnected(2) Unsupported type.");
137 }
138 }
139}
DataType element_type() const
Definition Tensor.h:105

References bias(), luci_interpreter::Tensor::element_type(), input(), output(), and weights().

◆ input()

const Tensor * luci_interpreter::kernels::FullyConnected::input ( ) const
inline

Definition at line 39 of file FullyConnected.h.

39{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and FullyConnected().

◆ output()

Tensor * luci_interpreter::kernels::FullyConnected::output ( ) const
inline

Definition at line 42 of file FullyConnected.h.

42{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure(), and execute().

◆ scratch()

Tensor * luci_interpreter::kernels::FullyConnected::scratch ( ) const
inline

Definition at line 43 of file FullyConnected.h.

43{ return _scratch; }

Referenced by FullyConnected().

◆ weights()

const Tensor * luci_interpreter::kernels::FullyConnected::weights ( ) const
inline

Definition at line 40 of file FullyConnected.h.

40{ return _inputs[1]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and FullyConnected().


The documentation for this class was generated from the following files: