84{
85 const auto &input_type =
input.getType();
87
88 assert(input_type.isQuantized());
90
91 const auto input_shape = input_type.getShape();
92
93 assert(input_type.getElementType() == mir::DataType::UINT8);
94 assert(axis == input_shape.rank() - 1);
95 (void)axis;
96
97 double input_scale = input_type.getQuantization().getScale();
98 double output_scale =
output_type.getQuantization().getScale();
99
100 const int trailing_dim = input_shape.rank() - 1;
101 int excluding_last_dim = 1;
102 for (int32_t i = 0; i < input_shape.rank() - 1; i++)
103 {
104 excluding_last_dim *= input_shape.dim(i);
105 }
106 const int last_dim = input_shape.dim(trailing_dim);
107
108 const int32_t clamp_max = std::numeric_limits<uint8_t>::max();
109 const int32_t clamp_min = std::numeric_limits<uint8_t>::min();
110
112
113 float table[256];
115
117
118 for (int i = 0; i < excluding_last_dim; ++i)
119 {
120 int32_t max_val = std::numeric_limits<uint8_t>::min();
121
122 for (int j = 0; j < last_dim; ++j)
123 {
124 max_val = std::max(max_val, static_cast<int32_t>(input_data[j]));
125 }
126
127 float sum_exp = 0.0f;
128 const int32_t max_uint8 = std::numeric_limits<uint8_t>::max();
129 const float *table_offset = &table[max_uint8 - max_val];
130
131 for (int j = 0; j < last_dim; ++j)
132 {
134 }
135
136 const float inv_sum_exp = 1.0f / (sum_exp * output_scale);
137
138 for (int j = 0; j < last_dim; ++j)
139 {
140 const float prob_rescaled = table_offset[
input_data[j]] * inv_sum_exp;
141 const int32_t prob_quantized = static_cast<int32_t>(prob_rescaled + 0.5);
143 static_cast<uint8_t>(std::max(std::min(clamp_max, prob_quantized), clamp_min));
144 }
147 }
148}
void PopulateSoftmaxLookupTable(float *table, float input_scale, float beta)