50 expsum_shape.
dim(axis) = 1;
59 std::int32_t axis_size = arg.getShape().dim(axis);
60 for (std::int32_t i = 0; i < axis_size; ++i)
62 arg_index.
at(axis) = i;
63 sum += std::exp(arg_accessor.
at(arg_index));
65 expsum_accessor.
at(expsum_index) = sum;
71 expsum_index.
at(axis) = 0;
72 res_accessor.
at(res_index) =
73 std::exp(arg_accessor.
at(res_index)) / expsum_accessor.
at(expsum_index);
85 const auto &input_type = input.getType();
86 const auto &output_type = output.getType();
88 assert(input_type.isQuantized());
89 assert(output_type.isQuantized());
91 const auto input_shape = input_type.getShape();
93 assert(input_type.getElementType() == mir::DataType::UINT8);
94 assert(axis == input_shape.rank() - 1);
97 double input_scale = input_type.getQuantization().getScale();
98 double output_scale = output_type.getQuantization().getScale();
100 const int trailing_dim = input_shape.rank() - 1;
101 int excluding_last_dim = 1;
102 for (int32_t i = 0; i < input_shape.rank() - 1; i++)
104 excluding_last_dim *= input_shape.dim(i);
106 const int last_dim = input_shape.dim(trailing_dim);
108 const int32_t clamp_max = std::numeric_limits<uint8_t>::max();
109 const int32_t clamp_min = std::numeric_limits<uint8_t>::min();
111 uint8_t *input_data =
reinterpret_cast<uint8_t *
>(input.atOffset(0));
114 PopulateSoftmaxLookupTable(table, input_scale, 1.f);
116 uint8_t *output_data =
reinterpret_cast<uint8_t *
>(output.atOffset(0));
118 for (
int i = 0; i < excluding_last_dim; ++i)
120 int32_t max_val = std::numeric_limits<uint8_t>::min();
122 for (
int j = 0; j < last_dim; ++j)
124 max_val = std::max(max_val,
static_cast<int32_t
>(input_data[j]));
127 float sum_exp = 0.0f;
128 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
129 const float *table_offset = &table[max_uint8 - max_val];
131 for (
int j = 0; j < last_dim; ++j)
133 sum_exp += table_offset[input_data[j]];
136 const float inv_sum_exp = 1.0f / (sum_exp * output_scale);
138 for (
int j = 0; j < last_dim; ++j)
140 const float prob_rescaled = table_offset[input_data[j]] * inv_sum_exp;
141 const int32_t prob_quantized =
static_cast<int32_t
>(prob_rescaled + 0.5);
143 static_cast<uint8_t
>(std::max(std::min(clamp_max, prob_quantized), clamp_min));
145 input_data += last_dim;
146 output_data += last_dim;