ONE - On-device Neural Engine
locoex::COpShapeInferenceRule Struct Reference (final)

Shape inference rule for COpDialect.

#include <COpShapeInferenceRule.h>


Public Member Functions

bool recognize (const loco::Dialect *) const final
 Return true if this rule recognizes a given dialect.
 
bool infer (const loco::Node *, loco::NodeShape &) const final
 Infer node's shape.
 
- Public Member Functions inherited from loco::ShapeInferenceRule
virtual ~ShapeInferenceRule ()=default
 
virtual bool support (const API &api) const
 Check whether a given API is available or not.
 
virtual void infer (const Context *, const Node *, Sink *) const
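The two overridden members are meant to be used together: a caller first asks a rule whether it recognizes a node's dialect and only then calls infer(). The sketch below models that from the caller's side. It is a simplified, self-contained illustration, not loco's actual dispatch code: `Dialect`, `Node`, `NodeShape`, `Rule`, `FixedRankRule` and `infer_with` are hypothetical stand-ins.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-ins for loco::Dialect, loco::Node and loco::NodeShape.
struct Dialect
{
};

struct Node
{
  const Dialect *dialect; // dialect this node belongs to
};

struct NodeShape
{
  bool known = false;
  uint32_t rank = 0;
};

// Simplified model of the rule interface: recognize() filters by dialect,
// infer() writes the shape and returns true only on success.
struct Rule
{
  virtual ~Rule() = default;
  virtual bool recognize(const Dialect *d) const = 0;
  virtual bool infer(const Node *node, NodeShape &shape) const = 0;
};

// Hypothetical rule bound to one dialect; it always reports a rank-1 shape.
struct FixedRankRule final : Rule
{
  explicit FixedRankRule(const Dialect *d) : owned(d) {}

  bool recognize(const Dialect *d) const final { return d == owned; }

  bool infer(const Node *node, NodeShape &shape) const final
  {
    assert(recognize(node->dialect)); // callers are expected to check recognize() first
    shape.known = true;
    shape.rank = 1;
    return true;
  }

private:
  const Dialect *owned;
};

// Hypothetical driver: hand the node to the first rule that recognizes its dialect.
bool infer_with(const std::vector<const Rule *> &rules, const Node *node, NodeShape &shape)
{
  for (const Rule *rule : rules)
  {
    if (rule->recognize(node->dialect))
      return rule->infer(node, shape);
  }
  return false; // no rule recognizes this dialect; shape is left untouched
}
```

Returning false when no rule matches keeps the driver consistent with the infer() contract that true means the shape was actually inferred.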
 

Additional Inherited Members

- Public Types inherited from loco::ShapeInferenceRule
enum class API { V1, V2 }
 

Detailed Description

Shape inference rule for COpDialect.

Note
The shapes of all inputs and of the output of COpCall must belong to loco::Domain::Tensor.

Definition at line 33 of file COpShapeInferenceRule.h.

Member Function Documentation

◆ infer()

bool locoex::COpShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
final, virtual

Infer node's shape.

Warning

An implementation should return true only when inference succeeds.
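The contract can be illustrated with a hypothetical element-wise node whose input shape may not be known yet. This is a sketch under stated assumptions, not a loco rule: `Shape`, `UnaryNode` and `infer_unary` are invented for illustration. The function writes the output and returns true only when it can determine the shape; otherwise it returns false and leaves the output untouched.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical shape representation: one entry per dimension.
using Shape = std::vector<uint32_t>;

// Hypothetical element-wise node; its input shape may not be inferred yet.
struct UnaryNode
{
  std::optional<Shape> input_shape;
};

// Follows the infer() contract: write `out` and return true only on success;
// otherwise return false and leave `out` unchanged.
bool infer_unary(const UnaryNode &node, Shape &out)
{
  if (!node.input_shape.has_value())
    return false; // not enough information yet; a caller may retry later

  out = *node.input_shape; // element-wise: output shape equals input shape
  return true;
}
```

COpShapeInferenceRule::infer() handles invalid inputs differently: its definition throws std::runtime_error when an input is outside the Tensor domain, and a TODO in the code notes that this error handling is not yet settled.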

Implements loco::ShapeInferenceRule.

Definition at line 35 of file COpShapeInferenceRule.cpp.

```cpp
bool COpShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
{
  assert(node->dialect() == COpDialect::get());
  assert(dynamic_cast<const COpNode *>(node) != nullptr);

  auto cop_call = loco::must_cast<const COpCall *>(node);

  // Note that the shape of a custom op is treated as a TensorShape
  // TODO Decide how to deal with these shape error cases
  for (uint32_t n = 0; n < cop_call->arity(); n++)
    if (loco::shape_get(cop_call->input(n)).domain() != loco::Domain::Tensor)
      throw std::runtime_error("Input of custom op must belong to Tensor domain.");

  // The output shape is not derived from the inputs; it is copied from the
  // rank and dimensions stored on the COpCall node itself.
  loco::TensorShape out_shape;

  out_shape.rank(cop_call->rank());
  for (uint32_t d = 0; d < cop_call->rank(); d++)
    out_shape.dim(d) = cop_call->dim(d);

  shape.set(out_shape);

  return true;
}
```

References loco::Node::dialect(), loco::TensorShape::dim(), loco::NodeShape::domain(), locoex::COpDialect::get(), loco::TensorShape::rank(), loco::NodeShape::set(), loco::shape_get(), and loco::Tensor.
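The logic of the listing can be modeled without loco. In the sketch below, `Domain`, `NodeShape`, `CustomCall` and `infer_custom_call` are simplified, hypothetical stand-ins: input shapes are stored directly on the call instead of being looked up with loco::shape_get(), and `declared_dims` stands in for COpCall::rank()/dim(). It keeps the two observable behaviors of the listing: a non-Tensor input raises std::runtime_error, and the output shape is copied from the node rather than computed from the inputs.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical, simplified stand-ins for loco::Domain and loco::NodeShape.
enum class Domain
{
  Unknown,
  Tensor,
  Feature
};

struct NodeShape
{
  Domain domain = Domain::Unknown;
  std::vector<uint32_t> dims; // meaningful only when domain == Domain::Tensor
};

// Hypothetical custom-op call. In the listing, input shapes come from
// loco::shape_get() and the output shape from COpCall::rank()/dim();
// here both are stored directly on the struct.
struct CustomCall
{
  std::vector<NodeShape> input_shapes;
  std::vector<uint32_t> declared_dims;
};

bool infer_custom_call(const CustomCall &call, NodeShape &shape)
{
  // Every input must already be in the Tensor domain; otherwise fail loudly.
  for (const NodeShape &in : call.input_shapes)
    if (in.domain != Domain::Tensor)
      throw std::runtime_error("Input of custom op must belong to Tensor domain.");

  // The output shape is copied from the node, not computed from the inputs.
  NodeShape out;
  out.domain = Domain::Tensor;
  out.dims = call.declared_dims;

  shape = out;
  return true;
}
```

Because the output shape is taken from the node, whoever builds the COpCall is responsible for recording a correct rank and dimensions; the rule only validates the input domains.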

◆ recognize()

bool locoex::COpShapeInferenceRule::recognize(const loco::Dialect *d) const
final, virtual

Return true if this rule recognizes a given dialect.

Implements loco::ShapeInferenceRule.

Definition at line 30 of file COpShapeInferenceRule.cpp.

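```cpp
bool COpShapeInferenceRule::recognize(const loco::Dialect *d) const
{
  // Dialects are compared by identity: COpDialect::get() returns the shared COpDialect instance.
  return COpDialect::get() == d;
}
```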

References locoex::COpDialect::get().
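recognize() compares pointers, so it relies on each dialect being exposed as a single shared instance through a static get(). A minimal sketch of that pattern, assuming a function-local static singleton; `Dialect`, `MyDialect`, `OtherDialect` and `recognize_my_dialect` are hypothetical names, and loco's own Dialect class is not reproduced here:

```cpp
#include <cassert>

// Hypothetical base class standing in for loco::Dialect.
struct Dialect
{
  virtual ~Dialect() = default;
};

// Hypothetical dialect exposed through a function-local static singleton,
// mirroring the static get() accessor that COpDialect provides.
struct MyDialect final : Dialect
{
  static const Dialect *get()
  {
    static MyDialect instance; // created once, on first call
    return &instance;
  }
};

struct OtherDialect final : Dialect
{
  static const Dialect *get()
  {
    static OtherDialect instance;
    return &instance;
  }
};

// Same shape as COpShapeInferenceRule::recognize(): a plain pointer comparison.
bool recognize_my_dialect(const Dialect *d) { return MyDialect::get() == d; }
```

Identity comparison is cheap and needs no RTTI, but it only works if no code ever creates a second instance of the dialect.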


The documentation for this struct was generated from the following files: