41 if (tensor.element_type() == DataType::U8)
43 std::vector<uint8_t>
data = extractTensorData<uint8_t>(tensor);
46 if (tensor.element_type() == DataType::S8)
48 std::vector<int8_t>
data = extractTensorData<int8_t>(tensor);
51 else if (tensor.element_type() == DataType::S16)
54 for (
auto zp : tensor.zero_points())
60 std::vector<int16_t>
data = extractTensorData<int16_t>(tensor);
61 if (tensor.scales().size() == 1)
69 const Shape shape = tensor.shape();
70 const int32_t quantized_dimension = tensor.quantized_dimension();
71 assert(quantized_dimension < shape.
num_dims());
72 size_t outer_dims_size = 1;
73 int32_t quant_dim_size = shape.
dim(quantized_dimension);
74 size_t inner_dims_size = 1;
75 assert(quant_dim_size == tensor.scales().size());
77 for (
int i = 0; i < quantized_dimension; ++i)
78 outer_dims_size *= shape.
dim(i);
79 for (
int i = quantized_dimension + 1; i < shape.
num_dims(); ++i)
80 inner_dims_size *= shape.
dim(i);
82 assert(shape.
num_elements() == outer_dims_size * quant_dim_size * inner_dims_size);
84 std::vector<float> dequantized_data;
86 for (
size_t outer_it = 0; outer_it < outer_dims_size; ++outer_it)
87 for (int32_t channel = 0; channel < quant_dim_size; ++channel)
89 float scale = tensor.scales()[channel];
90 size_t offset = inner_dims_size * (quant_dim_size * outer_it + channel);
91 std::vector<float> part_dequantized_data =
93 dequantized_data.insert(dequantized_data.end(), part_dequantized_data.begin(),
94 part_dequantized_data.end());
96 return dequantized_data;
100 throw std::runtime_error(
"Unsupported type.");
104Matcher<std::vector<float>>
FloatArrayNear(
const std::vector<float> &values,
float max_abs_error)
106 std::vector<Matcher<float>> matchers;
107 matchers.reserve(values.size());
108 for (
const float v : values)
110 matchers.emplace_back(FloatNear(v, max_abs_error));
112 return ElementsAreArray(matchers);