59 const auto &lhs_type = lhs.
getType();
60 const auto &rhs_type = rhs.
getType();
61 const auto &res_type = res.
getType();
63 assert(lhs_type.isQuantized());
64 assert(rhs_type.isQuantized());
65 assert(res_type.isQuantized());
67 int32_t lhs_offset = -lhs_type.getQuantization().getZeroPoint();
68 int32_t rhs_offset = -rhs_type.getQuantization().getZeroPoint();
69 int32_t output_offset = res_type.getQuantization().getZeroPoint();
71 double lhs_scale = lhs_type.getQuantization().getScale();
72 double rhs_scale = rhs_type.getQuantization().getScale();
73 double output_scale = res_type.getQuantization().getScale();
76 const double twice_max_input_scale = 2 * std::max(lhs_scale, rhs_scale);
77 const double real_lhs_multiplier = lhs_scale / twice_max_input_scale;
78 const double real_rhs_multiplier = rhs_scale / twice_max_input_scale;
79 const double real_output_multiplier = twice_max_input_scale / ((1 << left_shift) * output_scale);
81 int32_t lhs_multiplier = 0;
82 int32_t rhs_multiplier = 0;
83 int32_t output_multiplier = 0;
99 int32_t output_min = std::numeric_limits<uint8_t>::min();
100 int32_t output_max = std::numeric_limits<uint8_t>::max();
102 for (
const auto &index :
ShapeRange(res_type.getShape()))
104 const int32_t lhs_val = lhs_accessor.
at(index) + lhs_offset;
105 const int32_t rhs_val = rhs_accessor.
at(index) + rhs_offset;
106 const int32_t shifted_lhs_val = lhs_val * (1 << left_shift);
107 const int32_t shifted_rhs_val = rhs_val * (1 << left_shift);
108 const int32_t scaled_lhs_val =
110 const int32_t scaled_rhs_val =
112 const int32_t raw_sum = scaled_lhs_val + scaled_rhs_val;
113 const int32_t raw_output =
116 const int32_t clamped_output = std::min(output_max, std::max(output_min, raw_output));
117 res_accessor.
at(index) =
static_cast<uint8_t
>(clamped_output);