ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mir_interpreter::Conv2DImpl< T > Struct Template Reference

Static Public Member Functions

static void run (const TensorVariant &input, const TensorVariant &kernel, const Conv2DOpAttributes &attributes, TensorVariant &result, const TensorVariant *fused_bias)
 

Detailed Description

template<typename T>
struct mir_interpreter::Conv2DImpl< T >

Definition at line 37 of file Conv2D.cpp.

Member Function Documentation

◆ run()

template<typename T >
void mir_interpreter::Conv2DImpl< T >::run ( const TensorVariant & input,
const TensorVariant & kernel,
const Conv2DOpAttributes & attributes,
TensorVariant & result,
const TensorVariant * fused_bias 
)
)
static

Definition at line 45 of file Conv2D.cpp.

// Body of Conv2DImpl<T>::run: a direct (nested-loop) grouped 2-D convolution.
// For every output element it accumulates kernel * input products over the
// kernel window and the input channels of the element's group, then stores the
// sum into `result`.
//
// Inputs read:  input, kernel, attributes.strides, attributes.padding_before,
//               attributes.num_groups, attributes.data_format.
// Output:       every element of `result` is overwritten.
// NOTE(review): `fused_bias` is part of the signature but is never read in this
//               body -- confirm that ignoring it is intended for this variant.
// Preconditions are checked with assert() only, so they are not enforced when
// the build defines NDEBUG.
48{
// View each tensor's storage, starting at offset 0, as a flat array of T.
49 const auto *input_data = reinterpret_cast<const T *>(input.atOffset(0));
50 const auto *kernel_data = reinterpret_cast<const T *>(kernel.atOffset(0));
51 auto *result_data = reinterpret_cast<T *>(result.atOffset(0));
52
53 const Shape &input_shape = input.getShape();
54 const Shape &output_shape = result.getShape();
55 const Shape &kernel_shape = kernel.getShape();
56
// Index 0 of strides / padding_before is applied to the y (height) axis and
// index 1 to the x (width) axis; see in_y_origin / in_x_origin below.
57 const std::vector<std::int32_t> &strides = attributes.strides;
58 const std::vector<std::int32_t> &padding_before = attributes.padding_before;
59 const std::int32_t num_groups = attributes.num_groups;
// Only the NHWC layout is handled; the dimension indices used below rely on it.
60 assert(attributes.data_format == DataFormat::NHWC);
61
// Dimensions are read as: dim(0) = batch, dim(1) = height, dim(2) = width,
// dim(3) = channels. The batch size is taken from the output shape.
62 const std::int32_t batch_size = output_shape.dim(0);
63 const std::int32_t output_height = output_shape.dim(1);
64 const std::int32_t output_width = output_shape.dim(2);
65 const std::int32_t kernel_height = kernel_shape.dim(1);
66 const std::int32_t kernel_width = kernel_shape.dim(2);
67 const std::int32_t input_height = input_shape.dim(1);
68 const std::int32_t input_width = input_shape.dim(2);
69
70 const std::int32_t num_in_channels = input_shape.dim(3);
71 const std::int32_t num_out_channels = output_shape.dim(3);
72
// Both channel counts must split evenly across the groups.
73 assert(num_in_channels % num_groups == 0);
74 assert(num_out_channels % num_groups == 0);
75
76 const std::int32_t out_group_size = num_out_channels / num_groups;
77 const std::int32_t in_group_size = num_in_channels / num_groups;
78
// The kernel is expected as [out_channels, height, width, in_channels / groups].
79 assert(kernel_shape.dim(3) == in_group_size);
80 assert(kernel_shape.dim(0) == num_out_channels);
81
82 for (std::int32_t batch = 0; batch < batch_size; ++batch)
83 {
84 for (std::int32_t out_y = 0; out_y < output_height; ++out_y)
85 {
86 for (std::int32_t out_x = 0; out_x < output_width; ++out_x)
87 {
88 for (std::int32_t group = 0; group < num_groups; ++group)
89 {
// First output / input channel that belongs to this group.
90 const std::int32_t out_group_offset = group * out_group_size;
91 const std::int32_t in_group_offset = group * in_group_size;
92
93 for (std::int32_t out_c = 0; out_c < out_group_size; ++out_c)
94 {
// Input coordinate under the kernel's (0, 0) tap; it is negative when the
// window starts inside the leading padding.
95 const std::int32_t in_y_origin = (out_y * strides[0]) - padding_before[0];
96 const std::int32_t in_x_origin = (out_x * strides[1]) - padding_before[1];
97
// Accumulator for this output element; the float literal is converted to T.
98 T sum = 0.0f;
99
100 for (std::int32_t kernel_y = 0; kernel_y < kernel_height; ++kernel_y)
101 {
102 for (std::int32_t kernel_x = 0; kernel_x < kernel_width; ++kernel_x)
103 {
104 for (std::int32_t in_c = 0; in_c < in_group_size; ++in_c)
105 {
// No dilation term: consecutive kernel taps read adjacent input rows/columns.
// NOTE(review): in_y / in_x do not depend on in_c, so they and the bounds
// check below are recomputed for every channel of the same tap.
106 const std::int32_t in_y = in_y_origin + kernel_y;
107 const std::int32_t in_x = in_x_origin + kernel_x;
108
// Taps that fall outside the input are skipped and so add nothing to the sum,
// which has the same effect as zero padding.
109 if ((in_y >= 0 && in_y < input_height) && (in_x >= 0 && in_x < input_width))
110 {
// calcOffset presumably maps a 4-D index to a flat element offset for the
// given shape -- its definition (Utils.h) is not part of this listing.
// The input is indexed with the absolute channel; the kernel is indexed with
// the absolute output channel and the group-relative input channel.
111 const std::int32_t in_offset =
112 calcOffset(input_shape, batch, in_y, in_x, in_group_offset + in_c);
113 const std::int32_t kernel_offset =
114 calcOffset(kernel_shape, out_group_offset + out_c, kernel_y, kernel_x, in_c);
115 const T input_val = input_data[in_offset];
116 const T kernel_val = kernel_data[kernel_offset];
117 sum += kernel_val * input_val;
118 }
119 }
120 }
121 }
122
// Store the finished sum at (batch, out_y, out_x, absolute output channel).
123 const std::int32_t out_offset =
124 calcOffset(output_shape, batch, out_y, out_x, out_group_offset + out_c);
125 result_data[out_offset] = sum;
126 }
127 }
128 }
129 }
130 }
131}
const luci_interpreter::RuntimeShape output_shape
result
Definition infer.py:103
list input_data
Definition infer.py:29
int32_t calcOffset(const Shape &shape, int32_t d0, int32_t d1, int32_t d2, int32_t d3)
Definition Utils.h:75
Definition Shape.h:28

References mir::TensorVariant::atOffset(), mir::Conv2DOpAttributes::data_format, mir::Shape::dim(), mir::TensorVariant::getShape(), mir::Conv2DOpAttributes::num_groups, output_shape, mir::Conv2DOpAttributes::padding_before, and mir::Conv2DOpAttributes::strides.

Referenced by package.infer.session::inference().


The documentation for this struct was generated from the following file: