ONE - On-device Neural Engine
Loading...
Searching...
No Matches
mir::ops::SqueezeOp Class Reference

#include <SqueezeOp.h>

Collaboration diagram for mir::ops::SqueezeOp:

Public Member Functions

 SqueezeOp (Output *arg, const std::vector< std::int32_t > &dims_to_squeeze)
 
Operation * copyWithInputs (const std::vector< Output * > &inputs) override
 
void inferOutputTypes ()
 
int32_t getNumSqueezeDims () const
 
const std::vector< int32_t > & getDimsToSqueeze () const
 
- Public Member Functions inherited from mir::Operation
virtual ~Operation ()=default
 
Type getType () const
 
std::size_t getId () const
 
void setId (std::size_t id)
 
std::size_t getNumInputs () const
 
std::size_t getNumOutputs () const
 
std::deque< Output * > & getInputs ()
 
const std::deque< Output * > & getInputs () const
 
std::deque< Output > & getOutputs ()
 
const std::deque< Output > & getOutputs () const
 
Output * getInput (std::size_t index)
 
const Output * getInput (std::size_t index) const
 
Output * getOutput (std::size_t index)
 
const Output * getOutput (std::size_t index) const
 
const Shape & getInputShape (std::size_t index) const
 
const Shape & getOutputShape (std::size_t index) const
 
void accept (IVisitor *v)
 

Additional Inherited Members

- Public Types inherited from mir::Operation
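enum class Type { HANDLE_OP }  (the enumerators, e.g. Type::squeeze, are produced by expanding the HANDLE_OP macro, so they are not listed individually here)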
 
- Protected Member Functions inherited from mir::Operation
 Operation (Type type, const std::vector< Output * > &inputs, std::size_t num_outputs=1)
 
void setOutputType (std::size_t index, const TensorType &type)
 

Detailed Description

Definition at line 28 of file SqueezeOp.h.

Constructor & Destructor Documentation

◆ SqueezeOp()

mir::ops::SqueezeOp::SqueezeOp ( Output * arg,
const std::vector< std::int32_t > &  dims_to_squeeze 
)
inline

Definition at line 31 of file SqueezeOp.h.

32 : Operation(Type::squeeze, {arg}), _dims_to_squeeze(dims_to_squeeze)
33 {
34 // Infer output shape.
35 inferOutputTypes();
36 }

Member Function Documentation

◆ copyWithInputs()

Operation * mir::ops::SqueezeOp::copyWithInputs ( const std::vector< Output * > &  inputs)
inline override virtual

Implements mir::Operation.

Definition at line 38 of file SqueezeOp.h.

39 {
40 return new SqueezeOp(inputs[0], _dims_to_squeeze);
41 }
SqueezeOp(Output *arg, const std::vector< std::int32_t > &dims_to_squeeze)
Definition SqueezeOp.h:31

◆ getDimsToSqueeze()

const std::vector< int32_t > & mir::ops::SqueezeOp::getDimsToSqueeze ( ) const
inline

Definition at line 47 of file SqueezeOp.h.

47{ return _dims_to_squeeze; }

Referenced by inferOutputTypes(), and mir::DotNodeBuilder::visit().

◆ getNumSqueezeDims()

int32_t mir::ops::SqueezeOp::getNumSqueezeDims ( ) const
inline

Definition at line 45 of file SqueezeOp.h.

45{ return static_cast<int32_t>(_dims_to_squeeze.size()); }

Referenced by inferOutputTypes().

◆ inferOutputTypes()

void mir::ops::SqueezeOp::inferOutputTypes ( )

Definition at line 24 of file SqueezeOp.cpp.

25{
26 assert(getNumInputs() == 1);
27
28 const auto &input_shape = getInputShape(0);
29 auto dt = getInput(0)->getElementType();
30 int32_t input_rank = input_shape.rank();
31
32 std::vector<int32_t> dims_to_squeeze;
33
34 if (getNumSqueezeDims() == 0)
35 {
36 for (int32_t i = 0; i < input_rank; ++i)
37 {
38 if (input_shape.dim(i) == 1)
39 {
40 dims_to_squeeze.push_back(i);
41 }
42 }
43 }
44 else
45 {
46 dims_to_squeeze = getDimsToSqueeze();
47 sort(dims_to_squeeze.begin(), dims_to_squeeze.end());
48 dims_to_squeeze.erase(unique(dims_to_squeeze.begin(), dims_to_squeeze.end()),
49 dims_to_squeeze.end());
50 }
51
52 if (dims_to_squeeze.size() == static_cast<size_t>(input_rank))
53 {
54 // Input shape have 1s in all dimensions, output shape is (1,)
55 setOutputType(0, {dt, Shape{1}});
56 return;
57 }
58
59 int32_t output_rank = 0;
60 size_t squeezing_idx = 0;
61 Shape output_shape(input_rank - dims_to_squeeze.size());
62 for (int32_t i = 0; i < input_rank; ++i)
63 {
64 if (squeezing_idx < dims_to_squeeze.size() && i == dims_to_squeeze[squeezing_idx])
65 {
66 if (input_shape.dim(i) != 1)
67 throw std::invalid_argument("All squeezed dimensions should have size 1");
68
69 squeezing_idx++;
70 }
71 else
72 {
73 output_shape.dim(output_rank++) = input_shape.dim(i);
74 }
75 }
76
77 setOutputType(0, {dt, output_shape});
78}
DataType getElementType() const
Definition Operation.h:98
std::size_t getNumInputs() const
Definition Operation.h:128
Output * getInput(std::size_t index)
Definition Operation.h:137
const Shape & getInputShape(std::size_t index) const
Definition Operation.h:161
void setOutputType(std::size_t index, const TensorType &type)
Definition Operation.h:172
int32_t getNumSqueezeDims() const
Definition SqueezeOp.h:45
const std::vector< int32_t > & getDimsToSqueeze() const
Definition SqueezeOp.h:47
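Shape output_shape
Local variable declared at line 61 of SqueezeOp.cpp. (Doxygen's tooltip mis-links this name to the unrelated symbol "const luci_interpreter::RuntimeShape output_shape", Definition Shape.h:28.)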

References getDimsToSqueeze(), mir::Operation::Output::getElementType(), mir::Operation::getInput(), mir::Operation::getInputShape(), mir::Operation::getNumInputs(), getNumSqueezeDims(), output_shape, and mir::Operation::setOutputType().


The documentation for this class was generated from the following files: