ONE - On-device Neural Engine
nnfw::cker::BatchMatMul Class Reference

#include <BatchMatMul.h>

Public Member Functions

 BatchMatMul ()
 
void prepare (const Shape &lhs_shape, const Shape &rhs_shape, bool adj_x, bool adj_y, bool rhs_const)
 Prepare temporary area for calculation.
 
void operator() (const Shape &lhs_shape, const float *lhs_data, const Shape &rhs_shape, const float *rhs_data, bool adj_x, bool adj_y, const Shape &, float *output_data)
 

Detailed Description

Definition at line 36 of file BatchMatMul.h.
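BatchMatMul computes a batched matrix multiplication of two float tensors. prepare() sizes the temporary buffers used to transpose the operands according to adj_x/adj_y and records whether the RHS is constant; operator() then performs the multiplication, dispatching to the optimized kernel when CKER_X86_PLATFORM is defined and to the reference kernel otherwise. When the RHS is marked constant, its transposed copy is cached and reused across calls.

The class carries no further description, so the following is a minimal usage sketch only. It assumes the header is reachable via the include line shown above, that nnfw::cker::Shape supports initializer-list construction in addition to the FlatSize() member referenced below, and a plain row-major float layout; the shapes and values are made up for illustration.

#include <BatchMatMul.h> // include path as shown at the top of this page; may differ per build setup
#include <vector>

// Minimal sketch: multiply a batch of two 3x4 matrices by a batch of two 4x5 matrices.
int main()
{
  const nnfw::cker::Shape lhs_shape{2, 3, 4};
  const nnfw::cker::Shape rhs_shape{2, 4, 5};
  const nnfw::cker::Shape out_shape{2, 3, 5};

  std::vector<float> lhs(lhs_shape.FlatSize(), 1.0f);
  std::vector<float> rhs(rhs_shape.FlatSize(), 2.0f);
  std::vector<float> out(out_shape.FlatSize(), 0.0f);

  nnfw::cker::BatchMatMul batch_matmul;
  // Allocate the temporary transpose buffers once, before the first run.
  batch_matmul.prepare(lhs_shape, rhs_shape, /*adj_x=*/false, /*adj_y=*/false,
                       /*rhs_const=*/false);
  // Each output element is 1.0f * 2.0f summed over the contracted dimension of size 4 (= 8.0f).
  batch_matmul(lhs_shape, lhs.data(), rhs_shape, rhs.data(),
               /*adj_x=*/false, /*adj_y=*/false, out_shape, out.data());
  return 0;
}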

Constructor & Destructor Documentation

◆ BatchMatMul()

nnfw::cker::BatchMatMul::BatchMatMul ( )
inline

Definition at line 39 of file BatchMatMul.h.

{
  // DO NOTHING
}

Member Function Documentation

◆ operator()()

void nnfw::cker::BatchMatMul::operator() ( const Shape &  lhs_shape,
const float *  lhs_data,
const Shape &  rhs_shape,
const float *  rhs_data,
bool  adj_x,
bool  adj_y,
const Shape & ,
float *  output_data 
)
inline

Definition at line 83 of file BatchMatMul.h.

{
  // Don't need transpose if rhs is constant and already transposed
  if (!adj_y && !(_rhs_constant && _rhs_transposed))
  {
    transposeRowsCols(rhs_shape, rhs_data, _temp_rhs_shape, _temp_rhs.data());
    _rhs_transposed = true;
  }

  if (adj_x)
  {
    transposeRowsCols(lhs_shape, lhs_data, _temp_lhs_shape, _temp_lhs.data());
  }

  Shape new_lhs_shape = adj_x ? lhs_shape : swapRowColDims(lhs_shape);
  Shape new_rhs_shape = adj_y ? rhs_shape : swapRowColDims(rhs_shape);
  const float *new_lhs_data = adj_x ? _temp_lhs.data() : lhs_data;
  const float *new_rhs_data = adj_y ? rhs_data : _temp_rhs.data();

  // Note we pass RHS args first, LHS args second
  // Check that the accumulating dimensions of lhs and rhs are equal
  assert(Shape::ExtendedShape(5, new_rhs_shape).Dims(4) ==
         Shape::ExtendedShape(5, new_lhs_shape).Dims(3));

  const BatchMatMulParams params{new_rhs_shape, new_lhs_shape};
#if defined(CKER_X86_PLATFORM)
  optimized::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
#else
  reference::BatchMatMul(params, new_rhs_data, new_lhs_data, output_data);
#endif
}

References nnfw::cker::reference::BatchMatMul().
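One behavioral detail of the body above: when prepare() was called with rhs_const set and adj_y is false, the RHS is transposed into the temporary buffer only on the first call and the cached copy is reused afterwards (the _rhs_constant && _rhs_transposed check). The sketch below illustrates that intended call pattern; the function name, pointer arrays, and loop are hypothetical placeholders, not part of the library.

#include <BatchMatMul.h> // include path as shown at the top of this page; may differ per build setup

// Sketch only: run the same constant weight tensor against several input batches.
void run_with_constant_rhs(const nnfw::cker::Shape &lhs_shape, const float *const *lhs_batches,
                           const nnfw::cker::Shape &rhs_shape, const float *rhs_weights,
                           const nnfw::cker::Shape &out_shape, float *const *outputs,
                           int num_inputs)
{
  nnfw::cker::BatchMatMul matmul;
  matmul.prepare(lhs_shape, rhs_shape, /*adj_x=*/false, /*adj_y=*/false, /*rhs_const=*/true);

  for (int i = 0; i < num_inputs; ++i)
  {
    // The first call transposes rhs_weights into the temporary buffer;
    // subsequent calls skip the transpose and reuse the cached copy.
    matmul(lhs_shape, lhs_batches[i], rhs_shape, rhs_weights,
           /*adj_x=*/false, /*adj_y=*/false, out_shape, outputs[i]);
  }
}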

◆ prepare()

void nnfw::cker::BatchMatMul::prepare ( const Shape &  lhs_shape,
const Shape &  rhs_shape,
bool  adj_x,
bool  adj_y,
bool  rhs_const 
)
inline
inline

Prepare temporary area for calculation.

Definition at line 47 of file BatchMatMul.h.

{
  if (adj_x)
  {
    int32_t rank = lhs_shape.DimensionsCount();
    _temp_lhs_shape.Resize(rank);

    for (int32_t i = 0; i < rank - 2; i++)
    {
      _temp_lhs_shape.SetDim(i, lhs_shape.Dims(i));
    }
    _temp_lhs_shape.SetDim(rank - 2, lhs_shape.Dims(rank - 1));
    _temp_lhs_shape.SetDim(rank - 1, lhs_shape.Dims(rank - 2));

    _temp_lhs.resize(_temp_lhs_shape.FlatSize());
  }

  if (!adj_y)
  {
    int32_t rank = rhs_shape.DimensionsCount();
    _temp_rhs_shape.Resize(rank);

    for (int32_t i = 0; i < rank - 2; i++)
    {
      _temp_rhs_shape.SetDim(i, rhs_shape.Dims(i));
    }
    _temp_rhs_shape.SetDim(rank - 2, rhs_shape.Dims(rank - 1));
    _temp_rhs_shape.SetDim(rank - 1, rhs_shape.Dims(rank - 2));

    _temp_rhs.resize(_temp_rhs_shape.FlatSize());
  }

  _rhs_constant = rhs_const;
}

References nnfw::cker::Shape::DimensionsCount(), nnfw::cker::Shape::Dims(), nnfw::cker::Shape::FlatSize(), nnfw::cker::Shape::Resize(), and nnfw::cker::Shape::SetDim().

Referenced by onert::backend::cpu::ops::BatchMatMulLayer::batchMatMulFloat32().
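For reference, the shape bookkeeping above amounts to keeping the leading (batch) dimensions and swapping the last two, then sizing the temporary buffer to the resulting FlatSize(). The small standalone helper below reproduces that computation for illustration only; it is not part of the class API, and its name is an assumption.

#include <BatchMatMul.h> // pulls in nnfw::cker::Shape; include path as shown at the top of this page

// Illustrative reimplementation of the shape swap performed by prepare():
// for an input of rank >= 2, keep the batch dims and exchange the last two.
nnfw::cker::Shape swapped_last_two_dims(const nnfw::cker::Shape &shape)
{
  const int32_t rank = shape.DimensionsCount();
  nnfw::cker::Shape swapped;
  swapped.Resize(rank);
  for (int32_t i = 0; i < rank - 2; ++i)
  {
    swapped.SetDim(i, shape.Dims(i));
  }
  swapped.SetDim(rank - 2, shape.Dims(rank - 1));
  swapped.SetDim(rank - 1, shape.Dims(rank - 2));
  return swapped; // e.g. [2, 4, 5] -> [2, 5, 4]; FlatSize() stays 40
}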


The documentation for this class was generated from the following file: BatchMatMul.h