95{
96 const size_t inputs_count =
inputs.size();
97 std::vector<int32_t> input_zeropoints(inputs_count);
98 std::vector<float> input_scales(inputs_count);
101 int64_t concat_size = 0;
102 for (size_t i = 0; i < inputs_count; i++)
103 {
104 const auto &input_type =
inputs[i].get().getType();
105 assert(input_type.isQuantized());
106 assert(input_type.getElementType() == mir::DataType::UINT8);
107 const auto &input_shape = input_type.getShape();
108 assert(input_shape.rank() == concat_dimensions);
109
110 for (int32_t j = 0; j < concat_dimensions; j++)
111 if (j != axis)
113
114 concat_size += input_shape.dim(axis);
115 input_zeropoints[i] = input_type.getQuantization().getZeroPoint();
116 input_scales[i] = input_type.getQuantization().getScale();
117 }
119
122 int32_t output_zeropoint =
output_type.getQuantization().getZeroPoint();
123 float output_scale =
output_type.getQuantization().getScale();
124
125
126 int32_t outer_size = 1;
127 for (int32_t i = 0; i < axis; i++)
129
130 int32_t base_inner_size = 1;
131 for (int32_t i = axis + 1; i < concat_dimensions; i++)
133
134
135 uint8_t *output_ptr =
reinterpret_cast<uint8_t *
>(
output.atOffset(0));
136
137 const float inverse_output_scale = 1.f / output_scale;
138 for (int k = 0; k < outer_size; k++)
139 {
140 for (size_t i = 0; i < inputs_count; ++i)
141 {
143 const int copy_size =
input.getShape().dim(axis) * base_inner_size;
145 const uint8_t *input_ptr =
reinterpret_cast<const uint8_t *
>(
input_data);
146 if (input_zeropoints[i] == output_zeropoint && input_scales[i] == output_scale)
147 {
148 std::memcpy(output_ptr, input_ptr, copy_size);
149 }
150 else
151 {
152 const float scale = input_scales[i] * inverse_output_scale;
153 const float bias = -input_zeropoints[i] *
scale;
154 for (int j = 0; j < copy_size; ++j)
155 {
156 const int32_t value =
157 static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) + output_zeropoint;
158 output_ptr[j] = static_cast<uint8_t>(std::max(std::min(255, value), 0));
159 }
160 }
161 output_ptr += copy_size;
162 }
163 }
164}
const luci_interpreter::RuntimeShape output_shape