ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mir_interpreter::AvgPool2DImpl< T > Class Template Reference

Static Public Member Functions

static void run (const mir::ops::AvgPool2DOp &op, const mir::TensorVariant &input_var, mir::TensorVariant &output)
 

Detailed Description

template<typename T>
class mir_interpreter::AvgPool2DImpl< T >

Definition at line 29 of file AvgPool2D.cpp.

Member Function Documentation

◆ run()

template<typename T >
void mir_interpreter::AvgPool2DImpl< T >::run ( const mir::ops::AvgPool2DOp & op,
const mir::TensorVariant & input_var,
mir::TensorVariant & output 
)
static

Definition at line 37 of file AvgPool2D.cpp.

39{
40 const auto &input_shape = op.getInputShape(0);
41 const auto &output_shape = op.getOutputShape(0);
42 const auto &window_size = op.getWindowSize();
43 const auto &strides = op.getStrides();
44 const auto &padding_before = op.getPaddingBefore();
45 const auto &padding_after = op.getPaddingAfter();
46 (void)padding_after;
47
48 constexpr int num_spatial_dims = 2;
49 assert(input_var.getShape().rank() == 4);
50 assert(window_size.size() == num_spatial_dims);
51 assert(strides.size() == num_spatial_dims);
52 assert(padding_before.size() == num_spatial_dims);
53 assert(padding_after.size() == num_spatial_dims);
54
55 Tensor<T> res_accessor(output);
56 Tensor<T> input(input_var);
57
58 ShapeRange in_range(input_shape);
59 Index in_index(input_shape.rank());
60
61 for (const auto &out_index : ShapeRange(output_shape))
62 {
63 T result = 0;
64 size_t num_elements = 0;
65
66 // Assuming NHWC format.
67 in_index.at(0) = out_index.at(0);
68 in_index.at(3) = out_index.at(3);
69
70 for (const auto &window_index : ShapeRange(Shape(window_size)))
71 {
72 // Assuming NHWC format.
73 for (int i = 0; i < num_spatial_dims; ++i)
74 in_index.at(1 + i) =
75 out_index.at(1 + i) * strides[i] + window_index.at(i) - padding_before[i];
76
77 if (in_range.contains(in_index))
78 {
79 num_elements++;
80 result += input.at(in_index);
81 }
82 else if (op.getIncludePad())
83 {
84 num_elements++;
85 }
86 }
87
88 result /= num_elements;
89 res_accessor.at(out_index) = result;
90 }
91}
const Shape & getInputShape(std::size_t index) const
Definition Operation.h:161
const Shape & getOutputShape(std::size_t index) const
Definition Operation.h:163
int32_t rank() const
Definition Shape.h:43
const Shape & getShape() const
const std::vector< std::int32_t > & getWindowSize() const
Definition AvgPool2DOp.h:45
const std::vector< std::int32_t > & getPaddingBefore() const
Definition AvgPool2DOp.h:49
const std::vector< std::int32_t > & getPaddingAfter() const
Definition AvgPool2DOp.h:51
bool getIncludePad() const
Definition AvgPool2DOp.h:53
const std::vector< std::int32_t > & getStrides() const
Definition AvgPool2DOp.h:47
const luci_interpreter::RuntimeShape output_shape
result
Definition infer.py:103
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.
Definition Shape.h:59
Definition Shape.h:28

References mir::Tensor< T >::at(), mir::Index::at(), mir::ShapeRange::contains(), mir::ops::AvgPool2DOp::getIncludePad(), mir::Operation::getInputShape(), mir::Operation::getOutputShape(), mir::ops::AvgPool2DOp::getPaddingAfter(), mir::ops::AvgPool2DOp::getPaddingBefore(), mir::TensorVariant::getShape(), mir::ops::AvgPool2DOp::getStrides(), mir::ops::AvgPool2DOp::getWindowSize(), output_shape, and mir::Shape::rank().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following file:
AvgPool2D.cpp