ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::Mean Class Reference

#include <Mean.h>

Collaboration diagram for luci_interpreter::kernels::Mean:

Public Member Functions

 Mean (const Tensor *input, const Tensor *axes, Tensor *output, Tensor *temp_index, Tensor *resolved_axes, Tensor *temp_sum, const ReducerParams &params)
 
const Tensor * input () const
 
const Tensor * axes () const
 
Tensor * output () const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::KernelWithParams< ReducerParams >
const ReducerParams & params () const
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::KernelWithParams< ReducerParams >
 KernelWithParams (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, const ReducerParams &params)
 
- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::KernelWithParams< ReducerParams >
const ReducerParams _params
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 30 of file Mean.h.

Constructor & Destructor Documentation

◆ Mean()

luci_interpreter::kernels::Mean::Mean ( const Tensor * input,
const Tensor * axes,
Tensor * output,
Tensor * temp_index,
Tensor * resolved_axes,
Tensor * temp_sum,
const ReducerParams & params 
)

Definition at line 126 of file Mean.cpp.

128 : KernelWithParams<ReducerParams>({input, axes}, {output, temp_index, resolved_axes, temp_sum},
129 params)
130{
131}
const ReducerParams & params() const
Definition Kernel.h:67
const Tensor * axes() const
Definition Mean.h:37
Tensor * output() const
Definition Mean.h:38
const Tensor * input() const
Definition Mean.h:36

References axes(), and input().

Member Function Documentation

◆ axes()

const Tensor * luci_interpreter::kernels::Mean::axes ( ) const
inline

Definition at line 37 of file Mean.h.

37{ return _inputs[1]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and Mean().

◆ configure()

void luci_interpreter::kernels::Mean::configure ( )
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 133 of file Mean.cpp.

134{
135 LUCI_INTERPRETER_CHECK(input()->element_type() == output()->element_type());
136 LUCI_INTERPRETER_CHECK(axes()->element_type() == DataType::S32);
137 if (input()->element_type() == DataType::S16)
138 {
139 LUCI_INTERPRETER_CHECK(input()->zero_point() == 0 && output()->zero_point() == 0);
140 }
141
142 const Shape &input_shape = input()->shape();
143 int input_num_dims = input_shape.num_dims();
144
145 const auto *axes_data = getTensorData<int32_t>(axes());
146 int num_axes = axes()->shape().num_elements();
147 assert(num_axes <= 4);
148
149 Shape output_shape = getOutputShape(input_shape, axes_data, num_axes, _params.keep_dims);
151
152 tflite::MeanParams params{};
153 resolveAxes(axes_data, num_axes, &params);
154 _need_temporaries = !(
155 _params.keep_dims && input_num_dims == 4 && params.axis_count == 2 &&
156 ((params.axis[0] == 1 && params.axis[1] == 2) || (params.axis[0] == 2 && params.axis[1] == 1)));
157 if (_need_temporaries)
158 {
159 auto temp_index = getOutputTensors()[1];
160 auto resolved_axes = getOutputTensors()[2];
161 auto temp_sum = getOutputTensors()[3];
162
163 temp_index->resize(Shape(input_num_dims));
164 resolved_axes->resize(Shape(num_axes));
165 temp_sum->resize(output()->shape());
166 }
167 else
168 {
169 auto temp_index = getOutputTensors()[1];
170 auto resolved_axes = getOutputTensors()[2];
171 auto temp_sum = getOutputTensors()[3];
172
173 temp_index->set_allocatable(false);
174 resolved_axes->set_allocatable(false);
175 temp_sum->set_allocatable(false);
176 }
177}
const std::vector< Tensor * > & getOutputTensors() const
Definition Kernel.h:40
int32_t num_elements() const
Definition Tensor.h:53
void resize(const Shape &new_shape)
Definition Tensor.cpp:56
const Shape & shape() const
Definition Tensor.h:107
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
const luci_interpreter::RuntimeShape output_shape
Definition Shape.h:28

References luci_interpreter::KernelWithParams< ReducerParams >::_params, axes(), luci_interpreter::Kernel::getOutputTensors(), input(), luci_interpreter::ReducerParams::keep_dims, LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_dims(), luci_interpreter::Shape::num_elements(), output(), output_shape, luci_interpreter::KernelWithParams< ReducerParams >::params(), luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().

◆ execute()

void luci_interpreter::kernels::Mean::execute ( ) const
overridevirtual

Implements luci_interpreter::Kernel.

Definition at line 179 of file Mean.cpp.

180{
181 switch (input()->element_type())
182 {
183 case DataType::FLOAT32:
184 evalFloat();
185 break;
186 case DataType::U8:
187 evalQuantized();
188 break;
189 case DataType::S16:
190 evalQuantizedS16();
191 break;
192 default:
193 throw std::runtime_error("luci-intp Mean Unsupported type.");
194 }
195}

References input().

◆ input()

const Tensor * luci_interpreter::kernels::Mean::input ( ) const
inline

Definition at line 36 of file Mean.h.

36{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), execute(), and Mean().

◆ output()

Tensor * luci_interpreter::kernels::Mean::output ( ) const
inline

Definition at line 38 of file Mean.h.

38{ return _outputs[0]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.

Referenced by configure().


The documentation for this class was generated from the following files: