ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mir_interpreter::DepthwiseConv2DImpl< T > Struct Template Reference

Static Public Member Functions

static void run (const mir::ops::DepthwiseConv2DOp &op, const mir::TensorVariant &inputv, const mir::TensorVariant &kernelv, const mir::TensorVariant *biasv, mir::TensorVariant &output)
 

Detailed Description

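template<typename T>
struct mir_interpreter::DepthwiseConv2DImpl< T >

Element-type-generic interpreter kernel for mir::ops::DepthwiseConv2DOp: its static run() member computes the depthwise 2-D convolution of inputv with kernelv and writes the result into output.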

Definition at line 32 of file DepthwiseConv2D.cpp.

Member Function Documentation

◆ run()

template<typename T >
void mir_interpreter::DepthwiseConv2DImpl< T >::run ( const mir::ops::DepthwiseConv2DOp & op,
const mir::TensorVariant & inputv,
const mir::TensorVariant & kernelv,
const mir::TensorVariant * biasv,
mir::TensorVariant & output 
)
static

Definition at line 40 of file DepthwiseConv2D.cpp.

44{
45 const Shape &in_shape = op.getInputShape(0);
46 const Shape &kernel_shape = op.getInputShape(1);
47 const Shape &out_shape = op.getOutputShape(0);
48 const auto &strides = op.getStrides();
49 const std::vector<int32_t> &pads = op.getPaddingBefore();
50
51 assert(in_shape.rank() == 4);
52 assert(kernel_shape.rank() == 4);
53 assert(kernel_shape.dim(2) == in_shape.dim(3));
54 assert(in_shape.dim(3) * kernel_shape.dim(3) == out_shape.dim(3));
55 assert(strides.size() == 2);
56 assert(pads.size() == 2);
57
58 int32_t channel_multiplier = kernel_shape.dim(3);
59
60 Tensor<T> res_accessor(output);
61 Tensor<T> input(inputv);
62 Tensor<T> bias(*biasv);
63 Tensor<T> kernel(kernelv);
64
65 ShapeRange in_range(in_shape);
66 ShapeRange kernel_range(kernel_shape);
67 ShapeRange out_range(Shape{out_shape.dim(0), out_shape.dim(1), out_shape.dim(2), 1});
68
69 Index in_index;
70 in_index.resize(4);
71
72 erase<T>(output);
73
74 for (const auto &out_index : out_range)
75 {
76 Index out_index_k = out_index;
77 for (const auto &kernel_index : kernel_range)
78 {
79 in_index.at(0) = out_index.at(0);
80 for (int i = 0; i < 2; ++i)
81 in_index.at(1 + i) = out_index.at(1 + i) * strides[i] + kernel_index.at(i) - pads[i];
82 in_index.at(3) = kernel_index.at(2);
83
84 if (in_range.contains(in_index))
85 {
86 out_index_k.at(3) = kernel_index.at(2) * channel_multiplier + kernel_index.at(3);
87 res_accessor.at(out_index_k) += input.at(in_index) * kernel.at(kernel_index);
88 }
89 }
90 }
91}
const Shape & getInputShape(std::size_t index) const
Definition Operation.h:161
const Shape & getOutputShape(std::size_t index) const
Definition Operation.h:163
const std::vector< std::int32_t > & getStrides() const
const std::vector< std::int32_t > & getPaddingBefore() const
Definition Shape.h:28

References mir::Tensor< T >::at(), mir::Index::at(), mir::ShapeRange::contains(), mir::Shape::dim(), mir::Operation::getInputShape(), mir::Operation::getOutputShape(), mir::ops::DepthwiseConv2DOp::getPaddingBefore(), mir::ops::DepthwiseConv2DOp::getStrides(), mir::Shape::rank(), and mir::Index::resize().

Referenced by package.infer.session::inference().


The documentation for this struct was generated from the following file: