ONE - On-device Neural Engine
PALMulCommon.h
/*
 * Copyright (c) 2021 Samsung Electronics Co., Ltd. All Rights Reserved
 * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ONERT_MICRO_EXECUTE_PAL_MUL_COMMON_H
#define ONERT_MICRO_EXECUTE_PAL_MUL_COMMON_H

#include "PALArithmeticOpCommon.h"

namespace onert_micro
{
namespace execute
{
namespace pal
{
namespace
{
// Maximum dimension supported by the broadcast mul operation.
constexpr int kMaxMulBroadcastDim = 6;
} // namespace

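// Elementwise Mul over flat_size elements of two same-shape tensors, delegating to the shared
// ArithmeticOp helper with MulFn<T> as the binary operation.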
template <typename T>
OMStatus Mul(const core::BinaryArithmeticBroadcastParams &params, const int flat_size,
             const T *input1_data, const T *input2_data, T *output_data)
{
  ArithmeticOp<T, MulFn<T>>(params, flat_size, input1_data, input2_data, output_data);
  return Ok;
}

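// Broadcasting Mul for shapes of rank up to 4, delegating the broadcast handling to the shared
// BroadcastArithmeticOp4DSlow helper.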
template <typename T>
OMStatus BroadcastMul4DSlow(const core::BinaryArithmeticBroadcastParams &params,
                            const core::OMRuntimeShape &input1_shape, const T *input1_data,
                            const core::OMRuntimeShape &input2_shape, const T *input2_data,
                            const core::OMRuntimeShape &output_shape, T *output_data)
{
  BroadcastArithmeticOp4DSlow<T, MulFn<T>>(params, input1_shape, input1_data, input2_shape,
                                           input2_data, output_shape, output_data);
  return Ok;
}

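// Quantized broadcasting Mul for shapes of rank up to kMaxMulBroadcastDim (6). Both inputs are
// described with per-dimension strides (a stride of 0 repeats a broadcast dimension), while the
// extended output shape drives the six nested loops below.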
template <typename T>
OMStatus BroadcastMul6DSlow(const core::ArithmeticQuantParams &params,
                            const core::OMRuntimeShape &input1_shape, const T *input1_data,
                            const core::OMRuntimeShape &input2_shape, const T *input2_data,
                            const core::OMRuntimeShape &output_shape, T *output_data)
{
  NdArrayDesc<kMaxMulBroadcastDim> desc1;
  NdArrayDesc<kMaxMulBroadcastDim> desc2;
  // The input shapes are extended as part of NdArrayDesc initialization.
  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2);
  const core::OMRuntimeShape extended_output_shape =
    core::OMRuntimeShape::extendedShape(kMaxMulBroadcastDim, output_shape);
  // Cache output shape dimensions.
  int32_t extended_output_shape_dims[kMaxMulBroadcastDim];
  std::memcpy(extended_output_shape_dims, extended_output_shape.dimsData(),
              sizeof(extended_output_shape_dims));

  size_t input1_offset_a = 0;
  size_t input2_offset_a = 0;
  size_t output_offset_a = 0;
  for (int a = 0; a < extended_output_shape_dims[0]; ++a)
  {
    size_t input1_offset_d = input1_offset_a;
    size_t input2_offset_d = input2_offset_a;
    size_t output_offset_d = output_offset_a;
    for (int d = 0; d < extended_output_shape_dims[1]; ++d)
    {
      size_t input1_offset_b = input1_offset_d;
      size_t input2_offset_b = input2_offset_d;
      size_t output_offset_b = output_offset_d;
      for (int b = 0; b < extended_output_shape_dims[2]; ++b)
      {
        size_t input1_offset_y = input1_offset_b;
        size_t input2_offset_y = input2_offset_b;
        size_t output_offset_y = output_offset_b;
        for (int y = 0; y < extended_output_shape_dims[3]; ++y)
        {
          size_t input1_offset_x = input1_offset_y;
          size_t input2_offset_x = input2_offset_y;
          size_t output_offset_x = output_offset_y;
          for (int x = 0; x < extended_output_shape_dims[4]; ++x)
          {
            size_t input1_offset_c = input1_offset_x;
            size_t input2_offset_c = input2_offset_x;
            size_t output_offset_c = output_offset_x;
            for (int c = 0; c < extended_output_shape_dims[5]; ++c)
            {
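              // Apply the input offsets, multiply in 32-bit, rescale the product with the
              // fixed-point output multiplier/shift, add the output offset, and clamp to the
              // quantized activation range before narrowing back to T.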
              const int32_t input1_val = params.input1_offset + input1_data[input1_offset_c];
              const int32_t input2_val = params.input2_offset + input2_data[input2_offset_c];
              const int32_t unclamped_result =
                params.output_offset + multiplyByQuantizedMultiplier(input1_val * input2_val,
                                                                     params.output_multiplier,
                                                                     params.output_shift);
              const int32_t clamped_output =
                std::min(params.quantized_activation_max,
                         std::max(params.quantized_activation_min, unclamped_result));
              output_data[output_offset_c] = static_cast<T>(clamped_output);
              input1_offset_c += desc1.strides[5];
              input2_offset_c += desc2.strides[5];
              ++output_offset_c;
            }
            input1_offset_x += desc1.strides[4];
            input2_offset_x += desc2.strides[4];
            output_offset_x += extended_output_shape_dims[5];
          }
          input1_offset_y += desc1.strides[3];
          input2_offset_y += desc2.strides[3];
          output_offset_y += extended_output_shape_dims[4] * extended_output_shape_dims[5];
        }
        input1_offset_b += desc1.strides[2];
        input2_offset_b += desc2.strides[2];
        output_offset_b += extended_output_shape_dims[3] * extended_output_shape_dims[4] *
                           extended_output_shape_dims[5];
      }
      input1_offset_d += desc1.strides[1];
      input2_offset_d += desc2.strides[1];
      output_offset_d += extended_output_shape_dims[2] * extended_output_shape_dims[3] *
                         extended_output_shape_dims[4] * extended_output_shape_dims[5];
    }
    input1_offset_a += desc1.strides[0];
    input2_offset_a += desc2.strides[0];
    output_offset_a += extended_output_shape_dims[1] * extended_output_shape_dims[2] *
                       extended_output_shape_dims[3] * extended_output_shape_dims[4] *
                       extended_output_shape_dims[5];
  }
  return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_MUL_COMMON_H
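
For orientation, the standalone sketch below mirrors the per-element arithmetic of the innermost BroadcastMul6DSlow loop. The helper rescale_sketch and all numeric parameters are hypothetical stand-ins chosen for illustration; the real rescaling is done by multiplyByQuantizedMultiplier, defined in PALUtils.h.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Stand-in for multiplyByQuantizedMultiplier: scales a 32-bit accumulator by roughly
// multiplier * 2^(shift - 31). Double precision keeps the sketch short; the real helper
// works in integer arithmetic.
int32_t rescale_sketch(int32_t x, int32_t quantized_multiplier, int shift)
{
  return static_cast<int32_t>(
    std::lround(std::ldexp(static_cast<double>(x) * quantized_multiplier, shift - 31)));
}

int main()
{
  // Illustrative (assumed) quantization parameters, analogous to core::ArithmeticQuantParams.
  const int32_t input1_offset = 5, input2_offset = -3, output_offset = 10;
  const int32_t output_multiplier = 1073741824; // 0.5 in Q31 fixed point
  const int output_shift = 0;
  const int32_t act_min = -128, act_max = 127;

  const int8_t in1 = 20, in2 = 40;
  const int32_t v1 = input1_offset + in1; // 25
  const int32_t v2 = input2_offset + in2; // 37
  const int32_t unclamped =
    output_offset + rescale_sketch(v1 * v2, output_multiplier, output_shift);
  const int8_t out = static_cast<int8_t>(std::min(act_max, std::max(act_min, unclamped)));
  return out == 127 ? 0 : 1; // round(925 * 0.5) + 10 = 473 -> clamped to 127
}

The rescale step is needed because the raw product of two quantized values lives on the product of the input scales; multiplying by the precomputed output multiplier/shift brings it back to the output tensor's scale before the activation clamp.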