ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mir_interpreter::FullyConnectedImpl< T > Struct Template Reference

Static Public Member Functions

static void run (const mir::TensorVariant &inputv, const mir::TensorVariant &weightsv, const mir::ops::FullyConnectedOp &op, mir::TensorVariant &res, const mir::TensorVariant *biasv)
 

Detailed Description

template<typename T>
struct mir_interpreter::FullyConnectedImpl< T >

Definition at line 59 of file FullyConnected.cpp.

Member Function Documentation

◆ run()

template<typename T >
void mir_interpreter::FullyConnectedImpl< T >::run ( const mir::TensorVariant &inputv,
const mir::TensorVariant &weightsv,
const mir::ops::FullyConnectedOp &op,
mir::TensorVariant &res,
const mir::TensorVariant *biasv
)
static

Definition at line 67 of file FullyConnected.cpp.

71{
72 if (biasv)
73 {
74 throw std::runtime_error("non-quantized FullyConnected with fused bias is unsupported");
75 }
76
77 mir::Tensor<T> input{inputv};
78 mir::Tensor<T> weights{weightsv};
79
80 erase<T>(res);
81
82 if (input.getShape().rank() == 2 && weights.getShape().rank() == 2 && res.getShape().rank() == 2)
83 {
84 // optimized case for 2d matrix multiplication
85 fullyConnected2D<T>(inputv, weightsv, res);
86 return;
87 }
88
89 mir::Tensor<T> accessor(res);
90
91 const mir::Shape &in_shape = input.getShape();
92 int32_t in_rank = in_shape.rank();
93
94 const mir::Shape &w_shape = weights.getShape();
95 int32_t w_rank = w_shape.rank();
96
97 assert(in_shape.dim(in_rank - 1) == w_shape.dim(w_rank - 2));
98 (void)in_rank;
99
100 mir::ShapeRange out_range(res.getShape());
101
102 int32_t len = w_shape.dim(w_rank - 2);
103
104 for (auto &out_index : out_range)
105 {
106 mir::Index t_index = out_index;
107 T &output_element = accessor.at(out_index);
108 int32_t col = t_index.at(w_rank - 1);
109 int32_t row = t_index.at(w_rank - 2);
110 for (int32_t i = 0; i < len; ++i)
111 {
112 t_index.at(w_rank - 1) = i;
113 T in = input.at(t_index);
114 t_index.at(w_rank - 1) = col;
115 t_index.at(w_rank - 2) = i;
116 T w = weights.at(t_index);
117 t_index.at(w_rank - 2) = row;
118 output_element += in * w;
119 }
120 }
121}
int32_t & at(int32_t axis)
Returns the position on the given axis.
Definition Index.h:64
int32_t & dim(int32_t axis) noexcept
Definition Shape.h:47
int32_t rank() const
Definition Shape.h:43
const Shape & getShape() const

References mir::Tensor< T >::at(), mir::TensorVariant::at(), mir::Index::at(), mir::Shape::dim(), mir::TensorVariant::getShape(), and mir::Shape::rank().

Referenced by package.infer.session::inference().


The documentation for this struct was generated from the following file: