ONE - On-device Neural Engine
Loading...
Searching...
No Matches
nnfw::cker::MatMulBCast Class Reference

#include <MatmulBCast.h>

Public Member Functions

 MatMulBCast (Shape &shape_x, Shape &shape_y)
 
bool IsValid () const
 
int32_t x_batch_size () const
 
int32_t y_batch_size () const
 
int32_t output_batch_size () const
 
const Shape & output_batch_shape () const
 

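Detailed Description

Computes broadcast information for the batch dimensions of the two MatMul operands.

- **Batch dimensions.** These are every dimension of shape_x and shape_y except the last two. They are handed to a BCast object.
- **When IsValid() is false.** IsValid() returns false in two cases:
  - either shape has fewer than two dimensions (no BCast is created);
  - BCast::IsValid() fails for the batch dimensions.
- **When the batch sizes are computed.** x_batch_size(), y_batch_size(), output_batch_size() and output_batch_shape() are only computed in the constructor when both checks pass.

Definition at line 38 of file MatmulBCast.h.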

Constructor & Destructor Documentation

◆ MatMulBCast()

nnfw::cker::MatMulBCast::MatMulBCast ( Shape & shape_x,
Shape & shape_y 
)
inline

Definition at line 41 of file MatmulBCast.h.

42 {
43 if (shape_x.DimensionsCount() < 2 || shape_y.DimensionsCount() < 2)
44 return;
45
46 std::vector<int32_t> x;
47 std::vector<int32_t> y;
48
49 x.resize(shape_x.DimensionsCount() - 2);
50 y.resize(shape_y.DimensionsCount() - 2);
51
52 for (size_t i = 0; i < x.size(); i++)
53 {
54 x[i] = shape_x.Dims(i);
55 }
56 for (size_t i = 0; i < y.size(); i++)
57 {
58 y[i] = shape_y.Dims(i);
59 }
60
61 _batch_bcast = std::make_unique<BCast>(std::move(x), std::move(y));
62 if (!_batch_bcast->IsValid())
63 return;
64
65 const auto &x_reshaped = _batch_bcast->x_reshape();
66 const auto &y_reshaped = _batch_bcast->y_reshape();
67 auto output_shape = _batch_bcast->output_shape();
68
69 _x_batch_size = std::accumulate(x_reshaped.cbegin(), x_reshaped.cend(), INT32_C(1),
70 std::multiplies<int32_t>());
71 _y_batch_size = std::accumulate(y_reshaped.cbegin(), y_reshaped.cend(), INT32_C(1),
72 std::multiplies<int32_t>());
73 _output_shape.ReplaceWith(output_shape.size(), output_shape.data());
74 _output_batch_size = _output_shape.FlatSize();
75 }
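void ReplaceWith(int dimensions_count, const int32_t *dims_data)
Definition Shape.h:130
int FlatSize() const
Definition Shape.h:181
output_shape
Local variable declared at line 67 as `auto output_shape = _batch_bcast->output_shape();`. Doxygen cross-linked this name to the unrelated symbol `const luci_interpreter::RuntimeShape output_shape`.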

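References nnfw::cker::Shape::DimensionsCount(), nnfw::cker::Shape::Dims(), nnfw::cker::Shape::FlatSize(), output_shape, and nnfw::cker::Shape::ReplaceWith(). (Note: output_shape is a local variable of this constructor, not an external symbol. Doxygen's cross-reference for it is spurious.)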

Member Function Documentation

◆ IsValid()

bool nnfw::cker::MatMulBCast::IsValid ( ) const
inline

Definition at line 77 of file MatmulBCast.h.

77 { return (_batch_bcast != nullptr) && _batch_bcast->IsValid(); }

◆ output_batch_shape()

const Shape & nnfw::cker::MatMulBCast::output_batch_shape ( ) const
inline

Definition at line 81 of file MatmulBCast.h.

81 { return _output_shape; }

◆ output_batch_size()

int32_t nnfw::cker::MatMulBCast::output_batch_size ( ) const
inline

Definition at line 80 of file MatmulBCast.h.

80 { return _output_batch_size; }

◆ x_batch_size()

int32_t nnfw::cker::MatMulBCast::x_batch_size ( ) const
inline

Definition at line 78 of file MatmulBCast.h.

78 { return _x_batch_size; }

◆ y_batch_size()

int32_t nnfw::cker::MatMulBCast::y_batch_size ( ) const
inline

Definition at line 79 of file MatmulBCast.h.

79 { return _y_batch_size; }

The documentation for this class was generated from the following file:

- MatmulBCast.h