ONE - On-device Neural Engine
nnfw::cker::BatchMatMul Class Reference

#include <BatchMatMul.h>

Public Member Functions

 BatchMatMul ()
 
void prepare (const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y)
 Prepare temporary area for calculation.
 
void operator() (const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape, const float *rhs_data, bool adj_x, bool adj_y, const Shape &, float *output_data)
 

Detailed Description

Definition at line 36 of file BatchMatMul.h.

Constructor & Destructor Documentation

◆ BatchMatMul()

nnfw::cker::BatchMatMul::BatchMatMul ( ) [inline]

Definition at line 39 of file BatchMatMul.h.

40 {
41 // DO NOTHING
42 }

Member Function Documentation

◆ operator()()

void nnfw::cker::BatchMatMul::operator() (const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape, const float *rhs_data, bool adj_x, bool adj_y, const Shape &, float *output_data) [inline]

Definition at line 80 of file BatchMatMul.h.

83 {
84 // Assume lhs and rhs are not constant
85 // TODO Handle constant input
86
87 if (!adj_y)
88 {
89 transposeRowsCols(rhs_shape, rhs_data, _temp_rhs_shape, _temp_rhs.data());
90 }
91
92 if (adj_x)
93 {
94 transposeRowsCols(lhs_shape, lhs_data, _temp_lhs_shape, _temp_lhs.data());
95 }
96
97 Shape new_lhs_shape = adj_x ? lhs_shape : swapRowColDims(lhs_shape);
98 Shape new_rhs_shape = adj_y ? rhs_shape : swapRowColDims(rhs_shape);
99 const float *new_lhs_data = adj_x ? _temp_lhs.data() : lhs_data;
100 const float *new_rhs_data = adj_y ? rhs_data : _temp_rhs.data();
101
102 // Note we pass RHS args first, LHS args second
103 // Check that the accumulation dimensions of lhs and rhs are equal
104 assert(Shape::ExtendedShape(5, new_rhs_shape).Dims(4) ==
105 Shape::ExtendedShape(5, new_lhs_shape).Dims(3));
106
107 const BatchMatMulParams params{new_rhs_shape, new_lhs_shape};
108 #if defined(CKER_X86_PLATFORM)
109 optimized::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
110 #else
111 reference::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
112 #endif
113 }

References nnfw::cker::reference::BatchMatMul().
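
A minimal usage sketch follows, assuming the cker Shape type can be brace-initialized from a dimension list and that the include path shown is correct for the compute/cker headers (both are assumptions, not taken from this page; buffer contents are illustrative only). prepare() is called first because it is what sizes the internal transposed temporaries that operator() fills before dispatching to optimized::BatchMatMul or reference::BatchMatMul.

#include <cker/operation/BatchMatMul.h> // assumed include path
#include <vector>

int main()
{
  using nnfw::cker::Shape;

  // lhs: [batch, M, K] = [1, 2, 3], rhs: [batch, K, N] = [1, 3, 4]
  const Shape lhs_shape{1, 2, 3}; // assumes Shape has an initializer-list constructor
  const Shape rhs_shape{1, 3, 4};
  const Shape output_shape{1, 2, 4};

  std::vector<float> lhs(lhs_shape.FlatSize(), 1.0f);
  std::vector<float> rhs(rhs_shape.FlatSize(), 2.0f);
  std::vector<float> output(output_shape.FlatSize(), 0.0f);

  nnfw::cker::BatchMatMul batch_matmul;

  // Size the internal transposed temporaries (adj_x = false, adj_y = false).
  batch_matmul.prepare(lhs_shape, rhs_shape, /*adj_x=*/false, /*adj_y=*/false);

  // Run the batched multiplication; each output element is 1 * 2 summed over K = 3, i.e. 6.
  batch_matmul(lhs_shape, lhs.data(), rhs_shape, rhs.data(),
               /*adj_x=*/false, /*adj_y=*/false, output_shape, output.data());

  return 0;
}

Note that, as the comment in the listing above says, operator() passes the RHS arguments before the LHS ones when building BatchMatMulParams and calling the optimized or reference kernel.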

◆ prepare()

void nnfw::cker::BatchMatMul::prepare (const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y) [inline]

Prepare temporary area for calculation.

Definition at line 47 of file BatchMatMul.h.

48 {
49 if (adj_x)
50 {
51 int32_t rank = lhs_shape.DimensionsCount();
52 _temp_lhs_shape.Resize(rank);
53
54 for (int32_t i = 0; i < rank - 2; i++)
55 {
56 _temp_lhs_shape.SetDim(i, lhs_shape.Dims(i));
57 }
58 _temp_lhs_shape.SetDim(rank - 2, lhs_shape.Dims(rank - 1));
59 _temp_lhs_shape.SetDim(rank - 1, lhs_shape.Dims(rank - 2));
60
61 _temp_lhs.resize(_temp_lhs_shape.FlatSize());
62 }
63
64 if (!adj_y)
65 {
66 int32_t rank = rhs_shape.DimensionsCount();
67 _temp_rhs_shape.Resize(rank);
68
69 for (int32_t i = 0; i < rank - 2; i++)
70 {
71 _temp_rhs_shape.SetDim(i, rhs_shape.Dims(i));
72 }
73 _temp_rhs_shape.SetDim(rank - 2, rhs_shape.Dims(rank - 1));
74 _temp_rhs_shape.SetDim(rank - 1, rhs_shape.Dims(rank - 2));
75
76 _temp_rhs.resize(_temp_rhs_shape.FlatSize());
77 }
78 }

References nnfw::cker::Shape::DimensionsCount(), nnfw::cker::Shape::Dims(), nnfw::cker::Shape::FlatSize(), nnfw::cker::Shape::Resize(), and nnfw::cker::Shape::SetDim().

Referenced by onert::backend::cpu::ops::BatchMatMulLayer::batchMatMulFloat32().
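
The shape bookkeeping in prepare() amounts to copying the leading (batch) dimensions and swapping the last two dimensions of whichever operand operator() will transpose (lhs when adj_x is true, rhs when adj_y is false), then resizing the matching temporary buffer to the flat size of that swapped shape. The standalone sketch below mirrors that loop; transposed_dims is a hypothetical helper for illustration, not part of the cker API.

#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper mirroring the loop in prepare(): keep the leading
// (batch) dimensions and swap the last two dimensions.
std::vector<int32_t> transposed_dims(const std::vector<int32_t> &dims)
{
  std::vector<int32_t> out(dims);
  const auto rank = out.size();
  std::swap(out[rank - 2], out[rank - 1]);
  return out;
}

int main()
{
  // Example: an lhs of shape [2, 3, 5] with adj_x == true gets a temporary
  // shape of [2, 5, 3], so _temp_lhs is resized to 2 * 5 * 3 = 30 floats
  // (the FlatSize of the swapped shape).
  const std::vector<int32_t> lhs_dims{2, 3, 5};
  const auto temp_dims = transposed_dims(lhs_dims); // {2, 5, 3}
  return 0;
}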


The documentation for this class was generated from the following file:
BatchMatMul.h