ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert_micro::train::pal Namespace Reference

Functions

void Conv2DInputGrad (const core::FloatConv2D &params, const core::OMRuntimeShape &weight_shape, const float *weight_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_dinput_shape, float *dloss_dinput_data)
 
void Conv2DBiasGrad (const core::OMRuntimeShape &dloss_doutput_shape, const float *dloss_doutput_data, float *dloss_dbias_data)
 
void Conv2DWeightGrad (const core::FloatConv2D &params, const core::OMRuntimeShape &input_shape, const float *input_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_dweight_shape, float *dloss_dweight_data, core::OpTrainableRankType rank)
 
void FullyConnectedInputGrad (const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *weight_data, const core::OMRuntimeShape &weight_shape, float *dloss_dinput_data)
 
void FullyConnectedWeightGrad (const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *input_data, const core::OMRuntimeShape &input_shape, float *dloss_dweight_data, const core::OMRuntimeShape &weight_shape, core::OpTrainableRankType rank)
 
OMStatus GRUWeightGrads (const float *output_grad_data, const float *weight_input_data, float *weight_input_grad_data, const float *weight_hidden_data, float *weight_hidden_grad_data, const float *bias_input_data, float *bias_input_grad_data, const float *bias_hidden_data, float *bias_hidden_grad_data, const float *input_data, float *input_grad_data, float *state_grad_data, const core::OMRuntimeShape &input_shape, const core::OMRuntimeShape &output_shape, const core::OMRuntimeShape &weight_input_shape, const core::OMRuntimeShape &weight_hidden_shape, const core::OMRuntimeShape &output_shape_fc, float *intermediate_buffer, float *left_fc_output_grad_buffer, float *right_fc_output_grad_buffer)
 
void MaxPool2D (const core::Pool2DParams &params, const core::OMRuntimeShape &input_shape, const float *input_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *dloss_doutput_data, float *dloss_dinput_data)
 
void ReluInputGrad (const float *input_relu_data, float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape)
 
void SoftmaxInputGrad (const float *dloss_doutput_data, const core::OMRuntimeShape &dloss_doutput_shape, const float *calculated_data, float *jacobian_row_data, float *dloss_dinput_data)
 

Function Documentation

◆ Conv2DBiasGrad()

void onert_micro::train::pal::Conv2DBiasGrad ( const core::OMRuntimeShape &  dloss_doutput_shape,
const float *  dloss_doutput_data,
float *  dloss_dbias_data 
)

Definition at line 31 of file PALConv2DWeightGrad.h.

33{
34 assert(dloss_doutput_shape.dimensionsCount() == 4);
35 assert(dloss_doutput_shape.dims(0) == 1);
36 const int dloss_doutput_h = dloss_doutput_shape.dims(1);
37 const int dloss_doutput_w = dloss_doutput_shape.dims(2);
38 const int dloss_doutput_d = dloss_doutput_shape.dims(3);
39
40 // Reduce sum over last dim
41 for (uint32_t oc = 0; oc < dloss_doutput_d; ++oc)
42 {
43 float total = 0.f;
44 for (uint32_t h = 0; h < dloss_doutput_h; ++h)
45 {
46 for (uint32_t w = 0; w < dloss_doutput_w; ++w)
47 {
48 uint32_t offset = oc + w * dloss_doutput_d + h * dloss_doutput_w * dloss_doutput_d;
49 assert(offset < dloss_doutput_shape.flatSize());
50 total +=
51 dloss_doutput_data[oc + w * dloss_doutput_d + h * dloss_doutput_w * dloss_doutput_d];
52 }
53 }
54 dloss_dbias_data[oc] = total;
55 }
56}
__global uchar * offset(const Image *img, int x, int y)
Definition helpers.h:540

References onert_micro::core::OMRuntimeShape::dimensionsCount(), onert_micro::core::OMRuntimeShape::dims(), onert_micro::core::OMRuntimeShape::flatSize(), and offset().

◆ Conv2DInputGrad()

void onert_micro::train::pal::Conv2DInputGrad ( const core::FloatConv2D &  params,
const core::OMRuntimeShape &  weight_shape,
const float *  weight_data,
const core::OMRuntimeShape &  dloss_doutput_shape,
const float *  dloss_doutput_data,
const core::OMRuntimeShape &  dloss_dinput_shape,
float *  dloss_dinput_data 
)

Definition at line 60 of file PALConv2DInputGrad.h.

64{
65 const int stride_width = params.stride_w;
66 const int stride_height = params.stride_h;
67 const int dilation_width_factor = params.dilation_width_factor;
68 const int dilation_height_factor = params.dilation_height_factor;
69 const int pad_width = 0;
70 const int pad_height = 0;
71
72 const int weight_h = weight_shape.dims(1);
73 const int weight_w = weight_shape.dims(2);
74 const int weight_d = weight_shape.dims(3);
75 const int dloss_doutput_h = dloss_doutput_shape.dims(1);
76 const int dloss_doutput_w = dloss_doutput_shape.dims(2);
77 const int dloss_doutput_d = dloss_doutput_shape.dims(3);
78 const int dloss_dinput_h = dloss_dinput_shape.dims(1);
79 const int dloss_dinput_w = dloss_dinput_shape.dims(2);
80 const int dloss_dinput_d = dloss_dinput_shape.dims(3);
81
82 auto *n_c_weight_data = const_cast<float *>(weight_data);
83
84 for (uint32_t oc = 0; oc < dloss_dinput_d; ++oc)
85 {
86 for (uint32_t ic = 0; ic < dloss_doutput_d; ++ic)
87 {
88 rotate_180(n_c_weight_data, weight_h, weight_w, ic, oc, dloss_dinput_d);
89 for (int out_y = 0; out_y < dloss_dinput_h; ++out_y)
90 {
91 for (int out_x = 0; out_x < dloss_dinput_w; ++out_x)
92 {
93 const int in_x_origin = (out_x * stride_width) - pad_width;
94 const int in_y_origin = (out_y * stride_height) - pad_height;
95 float total = 0.f;
96
97 for (int filter_y = 0; filter_y < weight_h; ++filter_y)
98 {
99 for (int filter_x = 0; filter_x < weight_w; ++filter_x)
100 {
101 const int in_x = in_x_origin + dilation_width_factor * filter_x;
102 const int in_y = in_y_origin + dilation_height_factor * filter_y;
103 // If the location is outside the bounds of the input image,
104 // use zero as a default value.
105 if ((in_x >= 0) && (in_x < dloss_doutput_w) && (in_y >= 0) &&
106 (in_y < dloss_doutput_h))
107 {
108 uint32_t input_offset =
109 in_x * dloss_doutput_d + in_y * dloss_doutput_w * dloss_doutput_d + ic;
110 uint32_t filter_offset = oc + filter_x * dloss_dinput_d +
111 filter_y * weight_w * dloss_dinput_d +
112 ic * weight_w * dloss_dinput_d * weight_h;
113 assert(input_offset < dloss_doutput_shape.flatSize());
114 float input_value = dloss_doutput_data[input_offset];
115 float filter_value = n_c_weight_data[filter_offset];
116 total += (input_value * filter_value);
117 }
118 }
119 }
120 uint32_t output_offset =
121 oc + dloss_dinput_d * out_x + out_y * dloss_dinput_d * dloss_dinput_w;
122 assert(output_offset < dloss_dinput_shape.flatSize());
123 dloss_dinput_data[output_offset] = total;
124 }
125 }
126 // Rotate back
127 rotate_180(n_c_weight_data, weight_h, weight_w, ic, oc, dloss_dinput_d);
128 }
129 }
130}

References onert_micro::core::FloatConv2D::dilation_height_factor, onert_micro::core::FloatConv2D::dilation_width_factor, onert_micro::core::OMRuntimeShape::dims(), onert_micro::core::OMRuntimeShape::flatSize(), onert_micro::core::FloatConv2D::stride_h, and onert_micro::core::FloatConv2D::stride_w.

◆ Conv2DWeightGrad()

void onert_micro::train::pal::Conv2DWeightGrad ( const core::FloatConv2D &  params,
const core::OMRuntimeShape &  input_shape,
const float *  input_data,
const core::OMRuntimeShape &  dloss_doutput_shape,
const float *  dloss_doutput_data,
const core::OMRuntimeShape &  dloss_dweight_shape,
float *  dloss_dweight_data,
core::OpTrainableRankType  rank 
)

Definition at line 58 of file PALConv2DWeightGrad.h.

63{
64 const int stride_width = params.stride_w;
65 const int stride_height = params.stride_h;
66 const int dilation_width_factor = params.dilation_width_factor;
67 const int dilation_height_factor = params.dilation_height_factor;
68 const int pad_width = 0;
69 const int pad_height = 0;
70
71 const int batches = dloss_doutput_shape.dims(0);
72 const int input_h = input_shape.dims(1);
73 const int input_w = input_shape.dims(2);
74 const int input_d = input_shape.dims(3);
75 const int dloss_doutput_h = dloss_doutput_shape.dims(1);
76 const int dloss_doutput_w = dloss_doutput_shape.dims(2);
77 const int dloss_doutput_d = dloss_doutput_shape.dims(3);
78 const int dloss_dweight_h = dloss_dweight_shape.dims(1);
79 const int dloss_dweight_w = dloss_dweight_shape.dims(2);
80 const int dloss_dweight_d = dloss_dweight_shape.dims(3);
81 const int dloss_dweight_o = dloss_dweight_shape.dims(0);
82
83 auto depth_bounds = execute::pal::getUpLowerWeightTensorDepth(rank, dloss_doutput_d);
84
85 for (uint32_t oc = 0; oc < dloss_dweight_o; ++oc)
86 {
87 for (uint32_t ic = 0; ic < input_d; ++ic)
88 {
89 for (int out_y = 0; out_y < dloss_dweight_h; ++out_y)
90 {
91 for (int out_x = 0; out_x < dloss_dweight_w; ++out_x)
92 {
93 const int in_x_origin = (out_x * stride_width) - pad_width;
94 const int in_y_origin = (out_y * stride_height) - pad_height;
95 float total = 0.f;
96
97 for (int filter_y = 0; filter_y < dloss_doutput_h; ++filter_y)
98 {
99 for (int filter_x = 0; filter_x < dloss_doutput_w; ++filter_x)
100 {
101 const int in_x = in_x_origin + dilation_width_factor * filter_x;
102 const int in_y = in_y_origin + dilation_height_factor * filter_y;
103 // If the location is outside the bounds of the input image,
104 // use zero as a default value.
105 if ((in_x >= 0) && (in_x < input_w) && (in_y >= 0) && (in_y < input_h))
106 {
107 uint32_t input_offset = in_x * input_d + in_y * input_w * input_d + ic;
108 uint32_t filter_offset =
109 oc + filter_x * dloss_doutput_d + filter_y * dloss_doutput_w * dloss_doutput_d;
110 assert(input_offset < input_shape.flatSize());
111 assert(filter_offset < dloss_doutput_shape.flatSize());
112 float input_value = input_data[input_offset];
113 float filter_value = dloss_doutput_data[filter_offset];
114 total += (input_value * filter_value);
115 }
116 }
117 }
118 uint32_t output_offset = ic + input_d * out_x + input_d * dloss_dweight_w * out_y +
119 input_d * dloss_dweight_w * dloss_dweight_h * oc;
120 assert(output_offset < dloss_dweight_shape.flatSize());
121 dloss_dweight_data[output_offset] = total;
122 }
123 }
124 }
125 }
126}

References onert_micro::core::FloatConv2D::dilation_height_factor, onert_micro::core::FloatConv2D::dilation_width_factor, onert_micro::core::OMRuntimeShape::dims(), onert_micro::core::OMRuntimeShape::flatSize(), onert_micro::execute::pal::getUpLowerWeightTensorDepth(), onert_micro::core::FloatConv2D::stride_h, and onert_micro::core::FloatConv2D::stride_w.

◆ FullyConnectedInputGrad()

void onert_micro::train::pal::FullyConnectedInputGrad ( const float *  dloss_doutput_data,
const core::OMRuntimeShape &  dloss_doutput_shape,
const float *  weight_data,
const core::OMRuntimeShape &  weight_shape,
float *  dloss_dinput_data 
)
inline

Definition at line 33 of file PALFullyConnectedInputGrad.h.

38{
39 const uint32_t input_rows = dloss_doutput_shape.dims(0);
40 const uint32_t input_col = weight_shape.dims(1);
41 const uint32_t output_cols = dloss_doutput_shape.dims(1);
42
43 for (uint32_t i = 0; i < input_rows; ++i)
44 {
45 for (uint32_t j = 0; j < input_col; ++j)
46 {
47 float total = 0.f;
48 for (uint32_t o = 0; o < output_cols; ++o)
49 {
50 total += weight_data[o * input_col + j] * dloss_doutput_data[o + i * output_cols];
51 }
52 dloss_dinput_data[j + i * input_col] = total;
53 }
54 }
55}

References onert_micro::core::OMRuntimeShape::dims().

◆ FullyConnectedWeightGrad()

void onert_micro::train::pal::FullyConnectedWeightGrad ( const float *  dloss_doutput_data,
const core::OMRuntimeShape &  dloss_doutput_shape,
const float *  input_data,
const core::OMRuntimeShape &  input_shape,
float *  dloss_dweight_data,
const core::OMRuntimeShape &  weight_shape,
core::OpTrainableRankType  rank 
)
inline

Definition at line 34 of file PALFullyConnectedWeightGrad.h.

38{
39 const uint32_t batches = input_shape.dims(0);
40 const uint32_t output_depth = dloss_doutput_shape.dims(1);
41 const uint32_t accum_depth = input_shape.dims(1);
42
43 auto depth_bounds = execute::pal::getUpLowerWeightTensorDepth(rank, output_depth);
44
45 auto weight_depth = weight_shape.dims(0);
46
47 for (uint32_t o = 0; o < weight_depth; ++o)
48 {
49 float cur_dloss_doutput = dloss_doutput_data[o + depth_bounds.first];
50 for (uint32_t i = 0; i < accum_depth; ++i)
51 {
52 dloss_dweight_data[i + o * accum_depth] += cur_dloss_doutput * input_data[i];
53 }
54 }
55
56 for (int b = 1; b < batches; ++b)
57 {
58 for (uint32_t o = depth_bounds.first; o < depth_bounds.second; ++o)
59 {
60 float cur_dloss_doutput = dloss_doutput_data[o + b * output_depth];
61 for (uint32_t i = 0; i < accum_depth; ++i)
62 {
63 dloss_dweight_data[i + o * accum_depth] +=
64 cur_dloss_doutput * input_data[i + b * accum_depth];
65 }
66 }
67 }
68}

References onert_micro::core::OMRuntimeShape::dims(), and onert_micro::execute::pal::getUpLowerWeightTensorDepth().

◆ GRUWeightGrads()

OMStatus onert_micro::train::pal::GRUWeightGrads ( const float *  output_grad_data,
const float *  weight_input_data,
float *  weight_input_grad_data,
const float *  weight_hidden_data,
float *  weight_hidden_grad_data,
const float *  bias_input_data,
float *  bias_input_grad_data,
const float *  bias_hidden_data,
float *  bias_hidden_grad_data,
const float *  input_data,
float *  input_grad_data,
float *  state_grad_data,
const core::OMRuntimeShape &  input_shape,
const core::OMRuntimeShape &  output_shape,
const core::OMRuntimeShape &  weight_input_shape,
const core::OMRuntimeShape &  weight_hidden_shape,
const core::OMRuntimeShape &  output_shape_fc,
float *  intermediate_buffer,
float *  left_fc_output_grad_buffer,
float *  right_fc_output_grad_buffer 
)

Definition at line 130 of file PALGRUWeightGrad.h.

139{
140 const int32_t time = input_shape.dims(0);
141
142 // Init pointers to intermediate values
143 size_t offset = output_shape.flatSize();
144
145 size_t data_type_size = sizeof(float);
146 const int32_t num_of_intermediate_tensors = 9;
147 size_t time_offset = num_of_intermediate_tensors * offset;
148
149 core::OMRuntimeShape two_dim_input_shape(2);
150 auto dim_count = input_shape.dimensionsCount();
151 if (dim_count < 2)
152 return UnsupportedType;
153 two_dim_input_shape.setDim(0, input_shape.dims(dim_count - 2));
154 two_dim_input_shape.setDim(1, input_shape.dims(dim_count - 1));
155
156 core::OMRuntimeShape two_dim_output_shape(2);
157 dim_count = output_shape.dimensionsCount();
158 if (dim_count < 2)
159 return UnsupportedType;
160 two_dim_output_shape.setDim(0, output_shape.dims(dim_count - 2));
161 two_dim_output_shape.setDim(1, output_shape.dims(dim_count - 1));
162
163 std::memset(weight_input_grad_data, 0, weight_input_shape.flatSize() * sizeof(float));
164 std::memset(weight_hidden_grad_data, 0, weight_hidden_shape.flatSize() * sizeof(float));
165
166 for (int i = 0; i < time; ++i)
167 {
168 float *output_data = intermediate_buffer;
169 float *left_logistic_data = output_data + offset;
170 float *left_mul_data = left_logistic_data + offset;
171 float *right_logistic_data = left_mul_data + offset;
172 float *right_mul_left_input_data = right_logistic_data + offset;
173 float *right_mul_right_input_data = right_mul_left_input_data + offset;
174 float *tanh_data = right_mul_right_input_data + offset;
175 float *middle_mul_left_input = tanh_data + offset;
176 float *middle_mul_right_input = middle_mul_left_input + offset;
177
178 calculateGRUWeightGrads(
179 output_grad_data, weight_input_data, weight_input_grad_data, weight_hidden_data,
180 weight_hidden_grad_data, bias_input_data, bias_input_grad_data, bias_hidden_data,
181 bias_hidden_grad_data, input_data, input_grad_data, state_grad_data, two_dim_input_shape,
182 output_shape_fc, two_dim_output_shape, weight_input_shape, weight_hidden_shape, output_data,
183 left_logistic_data, left_mul_data, right_logistic_data, right_mul_left_input_data,
184 right_mul_right_input_data, tanh_data, middle_mul_left_input, middle_mul_right_input,
185 left_fc_output_grad_buffer, right_fc_output_grad_buffer);
186 input_data += input_shape.dims(2);
187 intermediate_buffer += time_offset;
188 }
189 return Ok;
190}
int32_t dimensionsCount() const
Definition Tensor.h:106
int32_t dims(int i) const
Definition Tensor.h:108
const luci_interpreter::RuntimeShape output_shape
@ UnsupportedType
Definition OMStatus.h:26

References luci_interpreter::RuntimeShape::dimensionsCount(), onert_micro::core::OMRuntimeShape::dimensionsCount(), luci_interpreter::RuntimeShape::dims(), onert_micro::core::OMRuntimeShape::dims(), luci_interpreter::RuntimeShape::flatSize(), onert_micro::core::OMRuntimeShape::flatSize(), offset(), onert_micro::Ok, output_shape, onert_micro::core::OMRuntimeShape::setDim(), and onert_micro::UnsupportedType.

◆ MaxPool2D()

void onert_micro::train::pal::MaxPool2D ( const core::Pool2DParams &  params,
const core::OMRuntimeShape &  input_shape,
const float *  input_data,
const core::OMRuntimeShape &  dloss_doutput_shape,
const float *  dloss_doutput_data,
float *  dloss_dinput_data 
)
inline

Definition at line 33 of file PALMaxPool2DInputGrad.h.

36{
37 const int32_t batches = input_shape.dims(0);
38 const int32_t depth = dloss_doutput_shape.dims(3);
39 const int32_t input_height = input_shape.dims(1);
40 const int32_t input_width = input_shape.dims(2);
41 const int32_t output_height = dloss_doutput_shape.dims(1);
42 const int32_t output_width = dloss_doutput_shape.dims(2);
43 const int32_t stride_height = params.stride_h;
44 const int32_t stride_width = params.stride_w;
45 for (int batch = 0; batch < batches; ++batch)
46 {
47 for (int out_y = 0; out_y < output_height; ++out_y)
48 {
49 for (int out_x = 0; out_x < output_width; ++out_x)
50 {
51 for (int channel = 0; channel < depth; ++channel)
52 {
53 const int in_x_origin = (out_x * stride_width) - params.pad_w;
54 const int in_y_origin = (out_y * stride_height) - params.pad_h;
55 // Compute the boundaries of the filter region clamped so as to
56 // ensure that the filter window fits in the input array.
57 const int filter_x_start = std::max(0, -in_x_origin);
58 const int filter_x_end = std::min(params.filter_w, input_width - in_x_origin);
59 const int filter_y_start = std::max(0, -in_y_origin);
60 const int filter_y_end = std::min(params.filter_h, input_height - in_y_origin);
61
62 const int output_data_offset =
63 ((batch * output_height + out_y) * output_width + out_x) * depth + channel;
64
65 float max = std::numeric_limits<float>::lowest();
66 int max_index = 0;
67
68 for (int filter_y = filter_y_start; filter_y < filter_y_end; ++filter_y)
69 {
70 for (int filter_x = filter_x_start; filter_x < filter_x_end; ++filter_x)
71 {
72 const int in_x = in_x_origin + filter_x;
73 const int in_y = in_y_origin + filter_y;
74
75 const int input_data_offset =
76 ((batch * input_height + in_y) * input_width + in_x) * depth + channel;
77
78 if (input_data[input_data_offset] > max)
79 {
80 max = input_data[input_data_offset];
81 max_index = input_data_offset;
82 }
83 }
84 }
85 dloss_dinput_data[max_index] = dloss_doutput_data[output_data_offset];
86 }
87 }
88 }
89 }
90}

References onert_micro::core::OMRuntimeShape::dims(), onert_micro::core::Pool2DParams::filter_h, onert_micro::core::Pool2DParams::filter_w, onert_micro::core::Pool2DParams::pad_h, onert_micro::core::Pool2DParams::pad_w, onert_micro::core::Pool2DParams::stride_h, and onert_micro::core::Pool2DParams::stride_w.

◆ ReluInputGrad()

void onert_micro::train::pal::ReluInputGrad ( const float *  input_relu_data,
float *  dloss_doutput_data,
const core::OMRuntimeShape &  dloss_doutput_shape 
)
inline

Definition at line 34 of file PALReluInputGrad.h.

36{
37 const uint32_t flat_size = dloss_doutput_shape.flatSize();
38
39 for (uint32_t i = 0; i < flat_size; ++i)
40 {
41 dloss_doutput_data[i] = input_relu_data[i] > 0 ? dloss_doutput_data[i] : 0.f;
42 }
43}

References onert_micro::core::OMRuntimeShape::flatSize().

◆ SoftmaxInputGrad()

void onert_micro::train::pal::SoftmaxInputGrad ( const float *  dloss_doutput_data,
const core::OMRuntimeShape &  dloss_doutput_shape,
const float *  calculated_data,
float *  jacobian_row_data,
float *  dloss_dinput_data 
)
inline

Definition at line 33 of file PALSoftmaxInputGrad.h.

37{
38 assert(dloss_doutput_shape.dimensionsCount() == 2);
39 assert(dloss_doutput_shape.dims(0) == 1);
40 const uint32_t width = dloss_doutput_shape.dims(dloss_doutput_shape.dimensionsCount() - 1);
41 for (int w1 = 0; w1 < width; ++w1)
42 {
43 float sum = 0.0f;
44 for (int w2 = 0; w2 < width; ++w2)
45 {
46 float val;
47 if (w1 == w2)
48 {
49 val = calculated_data[w2] * (1.f - calculated_data[w2]);
50 }
51 else
52 {
53 val = -calculated_data[w2] * calculated_data[w1];
54 }
55 val *= dloss_doutput_data[w2];
56 sum += val;
57 }
58 dloss_dinput_data[w1] = sum;
59 }
60}

References onert_micro::core::OMRuntimeShape::dimensionsCount(), and onert_micro::core::OMRuntimeShape::dims().