ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert::backend::cpu::ops::MeanLayer Class Reference

#include <MeanLayer.h>

Collaboration diagram for onert::backend::cpu::ops::MeanLayer:

Public Member Functions

 MeanLayer ()
 
void MeanFloat32 ()
 
void MeanQuant8 ()
 
void configure (const IPortableTensor *input, const IPortableTensor *axes, IPortableTensor *output, bool keep_dims)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Protected Attributes

const IPortableTensor * _input
 
const IPortableTensor * _axes
 
IPortableTensor * _output
 
bool _keep_dims
 

Detailed Description

Definition at line 33 of file MeanLayer.h.

Constructor & Destructor Documentation

◆ MeanLayer()

onert::backend::cpu::ops::MeanLayer::MeanLayer ( )

Definition at line 32 of file MeanLayer.cc.

32 : _input(nullptr), _axes(nullptr), _output(nullptr), _keep_dims(false)
33{
34 // DO NOTHING
35}
const IPortableTensor * _input
Definition MeanLayer.h:49
const IPortableTensor * _axes
Definition MeanLayer.h:50

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::MeanLayer::configure ( const IPortableTensor * input,
const IPortableTensor * axes,
IPortableTensor * output,
bool  keep_dims 
)

Definition at line 64 of file MeanLayer.cc.

66{
67 _input = input;
68 _axes = axes;
69 _output = output;
70 _keep_dims = keep_dims;
71
72 if (_input->data_type() != OperandType::FLOAT32 &&
73 _input->data_type() != OperandType::QUANT_UINT8_ASYMM)
74 throw std::runtime_error{"Mean: unsupported data type"};
75}
ir::DataType data_type() const override final

References _axes, _input, _keep_dims, _output, and onert::backend::IPortableTensor::data_type().

◆ MeanFloat32()

void onert::backend::cpu::ops::MeanLayer::MeanFloat32 ( )

Definition at line 37 of file MeanLayer.cc.

38{
39 const auto inputShape = getShape(_input);
40 const auto axisVec = getReducerAxes(_axes);
41 bool axis_is_1_and_2 =
42 _keep_dims && inputShape.DimensionsCount() == 4 && axisVec.size() == 2 &&
43 ((axisVec[0] == 1 && axisVec[1] == 2) || (axisVec[0] == 2 && axisVec[1] == 1));
44
45 if (axis_is_1_and_2)
46 {
47 nnfw::cker::MeanAxis1And2(inputShape, getBuffer<float>(_input), getShape(_output),
48 getBuffer<float>(_output));
49 }
50 else
51 {
52 nnfw::cker::Mean(inputShape, getBuffer<float>(_input), getShape(_output),
53 getBuffer<float>(_output), axisVec);
54 }
55}
void MeanAxis1And2(const Shape &input_shape, const In *input_data, const Shape &output_shape, Out *output_data)
Definition ReduceMean.h:233
void Mean(const Shape &input_shape, const In *input_data, const Shape &output_shape, Out *output_data, const std::vector< int > &axes)
Definition ReduceMean.h:211
nnfw::cker::Shape getShape(const IPortableTensor *tensor)
std::vector< int32_t > getReducerAxes(const IPortableTensor *axes)

References _axes, _input, _keep_dims, _output, onert::backend::cpu::ops::getReducerAxes(), onert::backend::cpu::ops::getShape(), nnfw::cker::Mean(), and nnfw::cker::MeanAxis1And2().

Referenced by run().

◆ MeanQuant8()

void onert::backend::cpu::ops::MeanLayer::MeanQuant8 ( )

Definition at line 57 of file MeanLayer.cc.

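58{
59 nnfw::cker::MeanQ8Asymm(getShape(_input), getBuffer<uint8_t>(_input), _input->data_scale(),
60 _input->data_zero_point(), getShape(_output), getBuffer<uint8_t>(_output),
61 _output->data_scale(), _output->data_zero_point(), getReducerAxes(_axes));
62}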
float data_scale() const override final
int32_t data_zero_point() const override final
void MeanQ8Asymm(const Shape &input_shape, const In *input_data, float input_scale, int32_t input_offset, const Shape &output_shape, Out *output_data, float output_scale, int32_t output_offset, const std::vector< int > &axes)
Definition ReduceMean.h:221

References _axes, _input, _output, onert::backend::IPortableTensor::data_scale(), onert::backend::IPortableTensor::data_zero_point(), onert::backend::cpu::ops::getReducerAxes(), onert::backend::cpu::ops::getShape(), and nnfw::cker::MeanQ8Asymm().

Referenced by run().

◆ run()

void onert::backend::cpu::ops::MeanLayer::run ( )
override virtual

Implements onert::exec::IFunction.

Definition at line 77 of file MeanLayer.cc.

78{
79 if (_input->data_type() == OperandType::FLOAT32)
80 {
81 MeanFloat32();
82 }
83 else if (_input->data_type() == OperandType::QUANT_UINT8_ASYMM)
84 {
85 MeanQuant8();
86 }
87 else
88 {
89 throw std::runtime_error{"Mean: unsupported data type"};
90 }
91}

References _input, onert::backend::IPortableTensor::data_type(), MeanFloat32(), and MeanQuant8().

Referenced by onert::backend::train::ops::MeanLayer::forward(), and package.infer.session::inference().

Field Documentation

◆ _axes

const IPortableTensor* onert::backend::cpu::ops::MeanLayer::_axes
protected

◆ _input

const IPortableTensor* onert::backend::cpu::ops::MeanLayer::_input
protected

◆ _keep_dims

bool onert::backend::cpu::ops::MeanLayer::_keep_dims
protected

◆ _output

IPortableTensor* onert::backend::cpu::ops::MeanLayer::_output
protected

Definition at line 51 of file MeanLayer.h.

Referenced by configure(), MeanFloat32(), and MeanQuant8().


The documentation for this class was generated from the following files: