ONE - On-device Neural Engine
mir_interpreter Namespace Reference

Data Structures

struct  AbsImpl
 
struct  AbsImpl< uint8_t >
 
struct  AddImpl
 
struct  AddImpl< uint8_t >
 
class  AvgPool2DImpl
 
struct  AvgPool2DImpl< uint8_t >
 
struct  CappedReLUImpl
 
struct  CappedReLUImpl< uint8_t >
 
struct  ConcatImpl
 
struct  ConcatImpl< uint8_t >
 
struct  Conv2DImpl
 
struct  Conv2DImpl< uint8_t >
 
struct  DeConv2DImpl
 
struct  DepthwiseConv2DImpl
 
struct  DepthwiseConv2DImpl< uint8_t >
 
struct  DivImpl
 
struct  DivImpl< uint8_t >
 
struct  ELUImpl
 
struct  EqualImpl
 
struct  FillImpl
 
struct  FullyConnectedImpl
 
struct  FullyConnectedImpl< uint8_t >
 
struct  GatherByT
 
struct  GatherImpl
 
struct  GreaterImpl
 
struct  HardSwishImpl
 
struct  HardSwishImpl< uint8_t >
 
struct  LeakyReLUImpl
 
struct  LessImpl
 
struct  MaxImpl
 
struct  MaxImpl< uint8_t >
 
struct  MaxPool2DImpl
 
struct  MaxPool2DImpl< uint8_t >
 
class  MIRInterpreter
 
struct  MulImpl
 
struct  MulImpl< uint8_t >
 
struct  PadImpl
 
struct  ReduceMeanImpl
 
struct  ReLUImpl
 
struct  ReLUImpl< uint8_t >
 
struct  SigmoidImpl
 
struct  SigmoidImpl< uint8_t >
 
struct  SliceImpl
 
struct  SoftmaxImpl
 
struct  SoftmaxImpl< uint8_t >
 
struct  SqrtImpl
 
struct  SqrtImpl< uint8_t >
 
struct  SubImpl
 
struct  SubImpl< uint8_t >
 
struct  TanhImpl
 
struct  TanhImpl< uint8_t >
 
struct  TransposeImpl
 

Functions

void Abs (const mir::TensorVariant &arg, mir::TensorVariant &result)
 
void Add (const TensorVariant &lhs, const TensorVariant &rhs, TensorVariant &res)
 
void AvgPool2D (const mir::ops::AvgPool2DOp &op, const mir::TensorVariant &input, mir::TensorVariant &output)
 
void CappedReLU (const mir::TensorVariant &arg, float cap, mir::TensorVariant &result)
 
Index shift (const Index &in_index, const Shape &shift_from)
 
template<template< typename > class F, typename... Args>
void dispatch (mir::DataType dt, Args &&...args)
 
template<typename T >
void erase (mir::TensorVariant &tv)
 
void Concat (const std::vector< std::reference_wrapper< const mir::TensorVariant > > &inputs, int axis, mir::TensorVariant &output)
 
void Conv2D (const mir::TensorVariant &input, const mir::TensorVariant &kernel, const mir::Conv2DOpAttributes &attributes, mir::TensorVariant &result, const mir::TensorVariant *fused_bias)
 
void DeConv2D (const mir::TensorVariant &input, const mir::TensorVariant &kernel, const mir::Deconv2DOpAttributes &attributes, mir::TensorVariant &output)
 Transposed convolution (or Deconvolution)
 
void DepthwiseConv2D (const mir::ops::DepthwiseConv2DOp &op, const mir::TensorVariant &input, const mir::TensorVariant &kernel, mir::TensorVariant &output, const mir::TensorVariant *bias)
 
void Div (const TensorVariant &lhs, const TensorVariant &rhs, TensorVariant &res)
 
void ELU (const mir::TensorVariant &arg, float alpha, mir::TensorVariant &result)
 
void Equal (const TensorVariant &lhs, const TensorVariant &rhs, TensorVariant &res)
 
template<typename F >
void Fill (mir::TensorVariant &t, F f)
 
void FullyConnected (const mir::TensorVariant &input, const mir::TensorVariant &weights, const mir::ops::FullyConnectedOp &op, mir::TensorVariant &res, const mir::TensorVariant *bias)
 
void Gather (const TensorVariant &data, const TensorVariant &indices, const ops::GatherOp &op, TensorVariant &res)
 
void Greater (const TensorVariant &lhs, const TensorVariant &rhs, TensorVariant &res)
 
void HardSwish (const mir::TensorVariant &input, mir::TensorVariant &result)
 
void LeakyReLU (const mir::TensorVariant &arg, float alpha, mir::TensorVariant &result)
 
void Less (const TensorVariant &lhs, const TensorVariant &rhs, TensorVariant &res)
 
void Max (const TensorVariant &lhs, const TensorVariant &rhs, TensorVariant &res)
 
void MaxPool2D (const mir::TensorVariant &input, const mir::ops::MaxPool2DOp &op, mir::TensorVariant &result)
 
void Mul (const TensorVariant &lhs, const TensorVariant &rhs, TensorVariant &res)
 
void Pad (const mir::TensorVariant &input, const mir::ops::PadOp &op, mir::TensorVariant &result)
 Implements PadOp for the interpreter backend.
 
void Dequantize (const TensorVariant &input, TensorVariant &output)
 
void Quantize (const TensorVariant &input, TensorVariant &output)
 
void QuantizeMultiplier (double double_multiplier, int32_t *quantized_multiplier, int *shift)
 
void QuantizeMultiplierSmallerThanOneExp (double double_multiplier, int32_t *quantized_multiplier, int *left_shift)
 
int32_t MaskIfNonZero (int32_t a)
 
int32_t MaskIfZero (int32_t a)
 
int32_t MaskIfLessThan (int32_t a, int32_t b)
 
int32_t MaskIfGreaterThan (int32_t a, int32_t b)
 
int32_t RoundingDivideByPOT (int32_t x, int exponent)
 
std::int32_t SaturatingRoundingDoublingHighMul (std::int32_t a, std::int32_t b)
 
int32_t MultiplyByQuantizedMultiplier (int32_t x, int32_t quantized_multiplier, int shift)
 
int32_t MultiplyByQuantizedMultiplierSmallerThanOneExp (int32_t x, int32_t quantized_multiplier, int left_shift)
 
void ReduceMean (const mir::TensorVariant &input, const mir::ops::ReduceMeanOp &op, mir::TensorVariant &output)
 
void ReLU (const mir::TensorVariant &arg, mir::TensorVariant &result)
 
void Reshape (const mir::TensorVariant &input, mir::TensorVariant &output)
 
void Sigmoid (const mir::TensorVariant &arg, mir::TensorVariant &result)
 
void Slice (const mir::TensorVariant &arg, const mir::Shape &starts, mir::TensorVariant &res)
 
void Softmax (const mir::TensorVariant &arg, int axis, mir::TensorVariant &result)
 
void Sqrt (const mir::TensorVariant &arg, mir::TensorVariant &result)
 
void Sub (const TensorVariant &lhs, const TensorVariant &rhs, TensorVariant &res)
 
void Tanh (const mir::TensorVariant &arg, mir::TensorVariant &result)
 
void Transpose (const mir::TensorVariant &input, const mir::ops::TransposeOp &op, mir::TensorVariant &output)
 

Function Documentation

◆ Abs()

void mir_interpreter::Abs ( const mir::TensorVariant &arg,
mir::TensorVariant &result
)

Definition at line 50 of file Abs.cpp.

51{
52 dispatch<AbsImpl>(arg.getElementType(), arg, result);
53};

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Add()

void mir_interpreter::Add ( const TensorVariant &lhs,
const TensorVariant &rhs,
TensorVariant &res
)

Definition at line 121 of file Add.cpp.

122{
123 if (lhs.getElementType() != rhs.getElementType())
124 {
125 throw std::runtime_error{"Add with different input types is unsupported"};
126 }
127 dispatch<AddImpl>(res.getElementType(), lhs, rhs, res);
128}

References mir::TensorVariant::getElementType().

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ AvgPool2D()

void mir_interpreter::AvgPool2D ( const mir::ops::AvgPool2DOp &op,
const mir::TensorVariant &input,
mir::TensorVariant &output
)

Definition at line 167 of file AvgPool2D.cpp.

169{
170 dispatch<AvgPool2DImpl>(output.getElementType(), op, input, output);
171}

◆ CappedReLU()

void mir_interpreter::CappedReLU ( const mir::TensorVariant &arg,
float cap,
mir::TensorVariant &result
)

Definition at line 77 of file CappedReLU.cpp.

78{
79 dispatch<CappedReLUImpl>(arg.getElementType(), arg, cap, result);
80}

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Concat()

void mir_interpreter::Concat ( const std::vector< std::reference_wrapper< const mir::TensorVariant > > &  inputs,
int  axis,
mir::TensorVariant &output
)

Definition at line 166 of file Concat.cpp.

168{
169 dispatch<ConcatImpl>(inputs[0].get().getElementType(), inputs, axis, output);
170}

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Conv2D()

void mir_interpreter::Conv2D ( const mir::TensorVariant &input,
const mir::TensorVariant &kernel,
const mir::Conv2DOpAttributes &attributes,
mir::TensorVariant &result,
const mir::TensorVariant *fused_bias
)

Definition at line 254 of file Conv2D.cpp.

257{
258 dispatch<Conv2DImpl>(result.getElementType(), input, kernel, attributes, result, fused_bias);
259}

References Conv2D().

Referenced by Conv2D(), and mir_interpreter::MIRInterpreter::visit().

◆ DeConv2D()

void mir_interpreter::DeConv2D ( const mir::TensorVariant &input,
const mir::TensorVariant &kernel,
const mir::Deconv2DOpAttributes &attributes,
mir::TensorVariant &output
)

Transposed convolution (or Deconvolution)

Parameters
input    The input tensor
op    The DeConvolution operation object

This is essentially the backward pass of the convolution operation, so all the indexing can be deduced by expressing the input index of Conv2D in terms of its output index.

Definition at line 116 of file DeConv2D.cpp.

118{
119 dispatch<DeConv2DImpl>(output.getElementType(), input, kernel, attributes, output);
120}

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ DepthwiseConv2D()

void mir_interpreter::DepthwiseConv2D ( const mir::ops::DepthwiseConv2DOp &op,
const mir::TensorVariant &input,
const mir::TensorVariant &kernel,
mir::TensorVariant &output,
const mir::TensorVariant *bias
)

Definition at line 218 of file DepthwiseConv2D.cpp.

221{
222 dispatch<DepthwiseConv2DImpl>(output.getElementType(), op, input, kernel, bias, output);
223}

◆ Dequantize()

void mir_interpreter::Dequantize ( const TensorVariant &input,
TensorVariant &output
)

Definition at line 28 of file Quantization.cpp.

29{
30 const TensorType &input_type = input.getType();
31 assert(input_type.isQuantized());
32 assert(input_type.getElementType() == DataType::UINT8);
33
34 const float scale = input_type.getQuantization().getScale();
35 const int32_t zero_point = input_type.getQuantization().getZeroPoint();
36
37 Tensor<uint8_t> input_accessor(input);
38 Tensor<float> res_accessor(output);
39
40 for (const auto &index : ShapeRange(output.getShape()))
41 {
42 const int32_t value = input_accessor.at(index);
43 res_accessor.at(index) = scale * static_cast<float>(value - zero_point);
44 }
45}

References mir::Tensor< T >::at(), mir::TensorType::getElementType(), mir::TensorType::getQuantization(), mir::AffineQuantization::getScale(), mir::AffineQuantization::getZeroPoint(), and mir::TensorType::isQuantized().

Referenced by mir_interpreter::MIRInterpreter::visit().
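
The loop above applies the standard affine dequantization rule, real = scale * (q - zero_point). A minimal scalar sketch of the per-element mapping (the scale and zero-point values are illustrative, not taken from any model):

// Scalar form of the mapping applied to each element.
const float scale = 0.5f;        // from the input TensorType's quantization
const int32_t zero_point = 128;  // likewise
const uint8_t q = 130;
const float real = scale * static_cast<float>(q - zero_point); // 0.5f * (130 - 128) == 1.0f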

◆ dispatch()

template<template< typename > class F, typename... Args>
void mir_interpreter::dispatch ( mir::DataType  dt,
Args &&...  args 
)

Definition at line 30 of file Common.h.

31{
32 switch (dt)
33 {
34 case mir::DataType::FLOAT32:
35 return F<float>::run(std::forward<Args>(args)...);
36 case mir::DataType::FLOAT64:
37 return F<double>::run(std::forward<Args>(args)...);
38 case mir::DataType::INT32:
39 return F<int32_t>::run(std::forward<Args>(args)...);
40 case mir::DataType::INT64:
41 return F<int64_t>::run(std::forward<Args>(args)...);
42 case mir::DataType::UINT8:
43 return F<uint8_t>::run(std::forward<Args>(args)...);
44 case mir::DataType::UNKNOWN:
45 throw std::runtime_error{"Unknown datatype met during operation execution"};
46 default:
47 throw std::runtime_error{"mir::DataType enum mismatch"};
48 }
49}
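
A minimal sketch of how a per-type implementation struct pairs with dispatch (CopyImpl and Copy are hypothetical names used only for illustration; the real per-operation Impl structs listed above follow the same shape, and the Tensor/ShapeRange accessors are used as in Dequantize below):

// Each specialization exposes a static run(); dispatch<CopyImpl> selects the
// one matching the tensor's element type at runtime.
template <typename T> struct CopyImpl
{
  static void run(const mir::TensorVariant &arg, mir::TensorVariant &result)
  {
    mir::Tensor<T> arg_accessor(arg);
    mir::Tensor<T> res_accessor(result);
    for (const auto &index : mir::ShapeRange(result.getShape()))
      res_accessor.at(index) = arg_accessor.at(index);
  }
};

void Copy(const mir::TensorVariant &arg, mir::TensorVariant &result)
{
  dispatch<CopyImpl>(arg.getElementType(), arg, result);
}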

◆ Div()

void mir_interpreter::Div ( const TensorVariant &lhs,
const TensorVariant &rhs,
TensorVariant &res
)

Definition at line 57 of file Div.cpp.

58{
59 dispatch<DivImpl>(res.getElementType(), lhs, rhs, res);
60}

References mir::TensorVariant::getElementType().

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ ELU()

void mir_interpreter::ELU ( const mir::TensorVariant &arg,
float alpha,
mir::TensorVariant &result
)

Definition at line 46 of file ELU.cpp.

47{
48 dispatch<ELUImpl>(result.getElementType(), arg, alpha, result);
49}

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Equal()

void mir_interpreter::Equal ( const TensorVariant &lhs,
const TensorVariant &rhs,
TensorVariant &res
)

Definition at line 48 of file Equal.cpp.

49{
50 if (lhs.getElementType() != rhs.getElementType())
51 {
52 throw std::runtime_error{"Equal with different input types is unsupported"};
53 }
54
55 dispatch<EqualImpl>(lhs.getElementType(), lhs, rhs, res);
56}

References mir::TensorVariant::getElementType().

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ erase()

template<typename T >
void mir_interpreter::erase ( mir::TensorVariant &tv)

Definition at line 51 of file Common.h.

52{
53 size_t element_count = tv.getShape().numElements();
54 for (size_t i = 0; i < element_count; ++i)
55 {
56 auto ptr = tv.atOffset(i);
57 *reinterpret_cast<T *>(ptr) = 0;
58 }
59}

References mir::TensorVariant::atOffset(), mir::TensorVariant::getShape(), and mir::Shape::numElements().
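
A hedged usage sketch (the tensor variable name is hypothetical): erase<T> simply zeroes every element of the backing buffer in place, interpreting it as T.

// Zero out the contents of a float tensor in place; afterwards every element reads as 0.0f.
mir_interpreter::erase<float>(some_float_tensor);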

◆ Fill()

template<typename F >
void mir_interpreter::Fill ( mir::TensorVariant &t,
F f
)

Definition at line 41 of file Fill.h.

42{
43 dispatch<FillImpl>(t.getElementType(), t, f);
44}

References mir::TensorVariant::getElementType().

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ FullyConnected()

void mir_interpreter::FullyConnected ( const mir::TensorVariant &input,
const mir::TensorVariant &weights,
const mir::ops::FullyConnectedOp &op,
mir::TensorVariant &res,
const mir::TensorVariant *bias
)

Definition at line 208 of file FullyConnected.cpp.

211{
212 dispatch<FullyConnectedImpl>(res.getElementType(), input, weights, op, res, bias);
213}

References mir::TensorVariant::getElementType().

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Gather()

void mir_interpreter::Gather ( const TensorVariant &data,
const TensorVariant &indices,
const ops::GatherOp &op,
TensorVariant &res
)

Definition at line 86 of file Gather.cpp.

88{
89 dispatch<GatherByT>(data.getElementType(), data, indices, op, res);
90}

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Greater()

void mir_interpreter::Greater ( const TensorVariant &lhs,
const TensorVariant &rhs,
TensorVariant &res
)

Definition at line 48 of file Greater.cpp.

49{
50 if (lhs.getElementType() != rhs.getElementType())
51 {
52 throw std::runtime_error{"Greater with different input types is unsupported"};
53 }
54 dispatch<GreaterImpl>(lhs.getElementType(), lhs, rhs, res);
55}

References mir::TensorVariant::getElementType().

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ HardSwish()

void mir_interpreter::HardSwish ( const mir::TensorVariant &input,
mir::TensorVariant &result
)

Definition at line 49 of file HardSwish.cpp.

50{
51 dispatch<HardSwishImpl>(input.getElementType(), input, result);
52}

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ LeakyReLU()

void mir_interpreter::LeakyReLU ( const mir::TensorVariant &arg,
float alpha,
mir::TensorVariant &result
)

Definition at line 44 of file LeakyReLU.cpp.

45{
46 dispatch<LeakyReLUImpl>(result.getElementType(), arg, alpha, result);
47}

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Less()

void mir_interpreter::Less ( const TensorVariant &lhs,
const TensorVariant &rhs,
TensorVariant &res
)

Definition at line 48 of file Less.cpp.

49{
50 if (lhs.getElementType() != rhs.getElementType())
51 {
52 throw std::runtime_error{"Less with different input types is unsupported"};
53 }
54 dispatch<LessImpl>(lhs.getElementType(), lhs, rhs, res);
55}

References mir::TensorVariant::getElementType().

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ MaskIfGreaterThan()

int32_t mir_interpreter::MaskIfGreaterThan ( int32_t  a,
int32_t  b 
)
inline

Definition at line 85 of file QuantizationHelpers.h.

85{ return MaskIfNonZero(a > b); }

References MaskIfNonZero().

Referenced by RoundingDivideByPOT().

◆ MaskIfLessThan()

int32_t mir_interpreter::MaskIfLessThan ( int32_t  a,
int32_t  b 
)
inline

Definition at line 83 of file QuantizationHelpers.h.

83{ return MaskIfNonZero(a < b); }

References MaskIfNonZero().

Referenced by RoundingDivideByPOT().

◆ MaskIfNonZero()

int32_t mir_interpreter::MaskIfNonZero ( int32_t  a)
inline

Definition at line 75 of file QuantizationHelpers.h.

76{
77 static const int32_t zero = 0;
78 return a ? ~zero : zero;
79}

Referenced by MaskIfGreaterThan(), MaskIfLessThan(), and MaskIfZero().

◆ MaskIfZero()

int32_t mir_interpreter::MaskIfZero ( int32_t  a)
inline

Definition at line 81 of file QuantizationHelpers.h.

81{ return MaskIfNonZero(!a); }

References MaskIfNonZero().

◆ Max()

void mir_interpreter::Max ( const TensorVariant &lhs,
const TensorVariant &rhs,
TensorVariant &res
)

Definition at line 57 of file Max.cpp.

58{
59 if (lhs.getElementType() != rhs.getElementType())
60 {
61 throw std::runtime_error{"Max with different input types is unsupported"};
62 }
63 dispatch<MaxImpl>(lhs.getElementType(), lhs, rhs, res);
64}

References mir::TensorVariant::getElementType().

◆ MaxPool2D()

void mir_interpreter::MaxPool2D ( const mir::TensorVariant &input,
const mir::ops::MaxPool2DOp &op,
mir::TensorVariant &result
)

Definition at line 151 of file MaxPool2D.cpp.

153{
154 dispatch<MaxPool2DImpl>(input.getElementType(), input, op, result);
155};

◆ Mul()

void mir_interpreter::Mul ( const TensorVariant &lhs,
const TensorVariant &rhs,
TensorVariant &res
)

Definition at line 56 of file Mul.cpp.

57{
58 dispatch<MulImpl>(lhs.getElementType(), lhs, rhs, res);
59};

References mir::TensorVariant::getElementType().

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ MultiplyByQuantizedMultiplier()

int32_t mir_interpreter::MultiplyByQuantizedMultiplier ( int32_t  x,
int32_t  quantized_multiplier,
int  shift 
)
inline

Definition at line 108 of file QuantizationHelpers.h.

109{
110 int left_shift = shift > 0 ? shift : 0;
111 int right_shift = shift > 0 ? 0 : -shift;
112 return RoundingDivideByPOT(
113 SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier), right_shift);
114}

References RoundingDivideByPOT(), SaturatingRoundingDoublingHighMul(), and shift().

Referenced by mir_interpreter::DepthwiseConv2DImpl< uint8_t >::run(), mir_interpreter::FullyConnectedImpl< uint8_t >::run(), and mir_interpreter::Conv2DImpl< uint8_t >::run().
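
Together with QuantizeMultiplier below, this realizes y ≈ x * M for a real multiplier M encoded as quantized_multiplier / 2^31 * 2^shift. A worked sketch with M = 0.5, i.e. quantized_multiplier = 1 << 30 and shift = 0 (values checked by hand against the definitions above):

int32_t y = mir_interpreter::MultiplyByQuantizedMultiplier(100, 1 << 30, 0); // y == 50  (100 * 0.5)
int32_t z = mir_interpreter::MultiplyByQuantizedMultiplier(101, 1 << 30, 0); // z == 51  (50.5 rounds away from zero)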

◆ MultiplyByQuantizedMultiplierSmallerThanOneExp()

int32_t mir_interpreter::MultiplyByQuantizedMultiplierSmallerThanOneExp ( int32_t  x,
int32_t  quantized_multiplier,
int  left_shift 
)
inline

Definition at line 116 of file QuantizationHelpers.h.

119{
120 return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(x, quantized_multiplier),
121 -left_shift);
122}

References RoundingDivideByPOT(), and SaturatingRoundingDoublingHighMul().

Referenced by mir_interpreter::AddImpl< uint8_t >::run().

◆ Pad()

void mir_interpreter::Pad ( const mir::TensorVariant &input,
const mir::ops::PadOp &op,
mir::TensorVariant &result
)

Implements PadOp for the interpreter backend.

This operation pads a tensor according to the paddings you specify: for each dimension of the input, padding values are added before and after the contents.

Definition at line 79 of file Pad.cpp.

80{
81 dispatch<PadImpl>(input.getElementType(), input, op, result);
82};

Referenced by mir_interpreter::MIRInterpreter::visit().
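
Shape-wise, for each dimension the output extent equals the input extent plus the padding amounts before and after that dimension: e.g., a (2, 3) input padded by 1 on each side of both dimensions yields a (4, 5) output, with the original values placed starting at index (1, 1).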

◆ Quantize()

void mir_interpreter::Quantize ( const TensorVariant &input,
TensorVariant &output
)

Definition at line 47 of file Quantization.cpp.

48{
49 const TensorType &output_type = output.getType();
50 assert(output_type.isQuantized());
51 assert(input.getElementType() == DataType::FLOAT32);
52
53 const float scale = output_type.getQuantization().getScale();
54 const int32_t zero_point = output_type.getQuantization().getZeroPoint();
55
56 const int32_t min_val = std::numeric_limits<uint8_t>::min();
57 const int32_t max_val = std::numeric_limits<uint8_t>::max();
58
59 Tensor<float> input_accessor(input);
60 Tensor<uint8_t> res_accessor(output);
61
62 for (const auto &index : ShapeRange(output.getShape()))
63 {
64 const float value = input_accessor.at(index);
65 int32_t unclamped = static_cast<int32_t>(std::round(value / scale)) + zero_point;
66 int32_t clamped = std::min(std::max(unclamped, min_val), max_val);
67 res_accessor.at(index) = static_cast<uint8_t>(clamped);
68 }
69}

References mir::Tensor< T >::at().

Referenced by mir_interpreter::MIRInterpreter::visit().
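
Quantize is the inverse of Dequantize above: it rounds real / scale, adds the zero point, and clamps to the uint8 range. A scalar sketch with illustrative values:

// Per-element mapping; needs <cmath> and <algorithm>.
const float scale = 0.5f;
const int32_t zero_point = 128;
const float real = 1.2f;
const int32_t unclamped = static_cast<int32_t>(std::round(real / scale)) + zero_point; // 2 + 128
const uint8_t q = static_cast<uint8_t>(std::min(std::max(unclamped, 0), 255));         // 130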

◆ QuantizeMultiplier()

void mir_interpreter::QuantizeMultiplier ( double  double_multiplier,
int32_t *  quantized_multiplier,
int *  shift 
)
inline

Definition at line 27 of file QuantizationHelpers.h.

28{
29 if (double_multiplier == 0.)
30 {
31 *quantized_multiplier = 0;
32 *shift = 0;
33 return;
34 }
35
36 const double q = std::frexp(double_multiplier, shift);
37 auto q_fixed = static_cast<int64_t>(round(q * (1ll << 31)));
38
39 assert(q_fixed <= (1ll << 31));
40 if (q_fixed == (1ll << 31))
41 {
42 q_fixed /= 2;
43 ++*shift;
44 }
45 assert(q_fixed <= std::numeric_limits<int32_t>::max());
46 // A shift amount smaller than -31 would cause all bits to be shifted out
47 // and thus all results would be zero. We implement that instead with
48 // q_fixed==0, so as to avoid hitting issues with right-shift
49 // operations with shift amounts greater than 31. Note that this happens
50 // roughly when abs(double_multiplier) < 2^-31 and the present handling means
51 // that we're effectively flushing tiny double_multiplier's to zero.
52 // We could conceivably handle values in the range (roughly) [32, 63]
53 // as 'denormals' i.e. (shift==0, q_fixed < 2^30). In that point of view
54 // the present handling is just doing 'flush denormals to zero'. We could
55 // reconsider and actually generate nonzero denormals if a need arises.
56 if (*shift < -31)
57 {
58 *shift = 0;
59 q_fixed = 0;
60 }
61 *quantized_multiplier = static_cast<int32_t>(q_fixed);
62}

References shift().

Referenced by QuantizeMultiplierSmallerThanOneExp(), mir_interpreter::DepthwiseConv2DImpl< uint8_t >::run(), mir_interpreter::FullyConnectedImpl< uint8_t >::run(), and mir_interpreter::Conv2DImpl< uint8_t >::run().
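
A hedged usage sketch: the real multiplier M is decomposed so that M = quantized_multiplier / 2^31 * 2^shift, with quantized_multiplier normalized into [2^30, 2^31).

int32_t quantized_multiplier = 0;
int shift = 0;
mir_interpreter::QuantizeMultiplier(0.75, &quantized_multiplier, &shift);
// quantized_multiplier == 1610612736 (0.75 * 2^31) and shift == 0,
// since 0.75 already lies in [0.5, 1).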

◆ QuantizeMultiplierSmallerThanOneExp()

void mir_interpreter::QuantizeMultiplierSmallerThanOneExp ( double  double_multiplier,
int32_t *  quantized_multiplier,
int *  left_shift 
)
inline

Definition at line 64 of file QuantizationHelpers.h.

66{
67 assert(double_multiplier < 1.0);
68 assert(double_multiplier > 0.0);
69 int shift;
70 QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift);
71 assert(shift <= 0);
72 *left_shift = shift;
73}

References QuantizeMultiplier(), and shift().

Referenced by mir_interpreter::AddImpl< uint8_t >::run().

◆ ReduceMean()

void mir_interpreter::ReduceMean ( const mir::TensorVariant &input,
const mir::ops::ReduceMeanOp &op,
mir::TensorVariant &output
)

Definition at line 90 of file ReduceMean.cpp.

92{
93 dispatch<ReduceMeanImpl>(input.getElementType(), input, op, output);
94};

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ ReLU()

void mir_interpreter::ReLU ( const mir::TensorVariant &arg,
mir::TensorVariant &result
)

Definition at line 53 of file ReLU.cpp.

54{
55 dispatch<ReLUImpl>(arg.getElementType(), arg, result);
56};

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Reshape()

void mir_interpreter::Reshape ( const mir::TensorVariant &input,
mir::TensorVariant &output
)

Definition at line 26 of file Reshape.cpp.

27{
28 assert(input.getShape().numElements() == output.getShape().numElements());
29
30 mir::ShapeRange input_range(input.getShape());
31 auto in_iter = input_range.begin();
32 const size_t elem_size = input.getElementSize();
33
34 for (const auto &out_index : mir::ShapeRange(output.getShape()))
35 std::memcpy(output.at(out_index), input.at(*in_iter++), elem_size);
36}

References mir::ShapeRange::begin().

Referenced by mir_interpreter::MIRInterpreter::visit(), and mir_interpreter::MIRInterpreter::visit().

◆ RoundingDivideByPOT()

int32_t mir_interpreter::RoundingDivideByPOT ( int32_t  x,
int  exponent 
)
inline

Definition at line 87 of file QuantizationHelpers.h.

88{
89 assert(exponent >= 0);
90 assert(exponent <= 31);
91 const int32_t mask = (1ll << exponent) - 1;
92 const int32_t remainder = x & mask;
93 const int32_t threshold = (mask >> 1) + (MaskIfLessThan(x, 0) & 1);
94 return (x >> exponent) + (MaskIfGreaterThan(remainder, threshold) & 1);
95}

References MaskIfGreaterThan(), and MaskIfLessThan().

Referenced by MultiplyByQuantizedMultiplier(), and MultiplyByQuantizedMultiplierSmallerThanOneExp().
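
The masks implement round-half-away-from-zero division by 2^exponent. Worked values (checked by hand from the definition above):

// RoundingDivideByPOT(17, 2)  == 4    17 / 4 = 4.25
// RoundingDivideByPOT(18, 2)  == 5    18 / 4 = 4.5, the tie rounds away from zero
// RoundingDivideByPOT(-18, 2) == -5   likewise for negative values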

◆ SaturatingRoundingDoublingHighMul()

std::int32_t mir_interpreter::SaturatingRoundingDoublingHighMul ( std::int32_t  a,
std::int32_t  b 
)
inline

Definition at line 97 of file QuantizationHelpers.h.

98{
99 bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
100 std::int64_t a_64(a);
101 std::int64_t b_64(b);
102 std::int64_t ab_64 = a_64 * b_64;
103 std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
104 std::int32_t ab_x2_high32 = static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
105 return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
106}

Referenced by MultiplyByQuantizedMultiplier(), and MultiplyByQuantizedMultiplierSmallerThanOneExp().

◆ shift()

mir::Index mir_interpreter::shift ( const Index &in_index,
const Shape &shift_from
)

Definition at line 26 of file Common.cpp.

27{
28 Index index = in_index;
29 assert(index.rank() == shift_from.rank());
30 for (int32_t d = 0; d < in_index.rank(); ++d)
31 {
32 index.at(d) = index.at(d) + shift_from.dim(d);
33 }
34 return index;
35}

References mir::Shape::dim(), mir::Index::rank(), and mir::Shape::rank().

Referenced by MultiplyByQuantizedMultiplier(), QuantizeMultiplier(), QuantizeMultiplierSmallerThanOneExp(), and mir_interpreter::SliceImpl< T >::run().
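
A small sketch of the index arithmetic (assuming the initializer-list constructors of mir::Index and mir::Shape):

mir::Index base{1, 2};
mir::Shape offset{3, 4};
mir::Index moved = mir_interpreter::shift(base, offset); // moved == (4, 6)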

◆ Sigmoid()

void mir_interpreter::Sigmoid ( const mir::TensorVariant &arg,
mir::TensorVariant &result
)

Definition at line 53 of file Sigmoid.cpp.

54{
55 dispatch<SigmoidImpl>(arg.getElementType(), arg, result);
56};

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Slice()

void mir_interpreter::Slice ( const mir::TensorVariant &arg,
const mir::Shape &starts,
mir::TensorVariant &res
)

Definition at line 47 of file Slice.cpp.

48{
49 dispatch<SliceImpl>(arg.getElementType(), arg, starts, res);
50}

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Softmax()

void mir_interpreter::Softmax ( const mir::TensorVariant &arg,
int axis,
mir::TensorVariant &result
)

Definition at line 150 of file Softmax.cpp.

151{
152 dispatch<SoftmaxImpl>(arg.getElementType(), arg, axis, result);
153};

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Sqrt()

void mir_interpreter::Sqrt ( const mir::TensorVariant &arg,
mir::TensorVariant &result
)

Definition at line 53 of file Sqrt.cpp.

54{
55 dispatch<SqrtImpl>(arg.getElementType(), arg, result);
56};

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Sub()

void mir_interpreter::Sub ( const TensorVariant &lhs,
const TensorVariant &rhs,
TensorVariant &res
)

Definition at line 56 of file Sub.cpp.

57{
58 dispatch<SubImpl>(lhs.getElementType(), lhs, rhs, res);
59};

References mir::TensorVariant::getElementType().

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Tanh()

void mir_interpreter::Tanh ( const mir::TensorVariant &arg,
mir::TensorVariant &result
)

Definition at line 53 of file Tanh.cpp.

54{
55 dispatch<TanhImpl>(arg.getElementType(), arg, result);
56};

Referenced by mir_interpreter::MIRInterpreter::visit().

◆ Transpose()

void mir_interpreter::Transpose ( const mir::TensorVariant &input,
const mir::ops::TransposeOp &op,
mir::TensorVariant &output
)

Definition at line 58 of file Transpose.cpp.

60{
61 dispatch<TransposeImpl>(input.getElementType(), input, op, output);
62}

Referenced by mir_interpreter::MIRInterpreter::visit().