ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci::ConstantFoldingTestGraph Class Reference (abstract)

#include <PassTestGraphs.h>

Collaboration diagram for luci::ConstantFoldingTestGraph:

Public Member Functions

 ConstantFoldingTestGraph (std::vector< uint32_t > input_shape, loco::DataType input_dtype)
 
virtual void init ()=0
 
virtual ~ConstantFoldingTestGraph ()=default
 
virtual loco::Node * createFoldedPattern ()=0
 
virtual luci::CircleConst * getFoldedPattern ()=0
 
loco::Graph * graph ()
 

Protected Attributes

loco::Graph _g
 
luci::CircleInput * _input = nullptr
 
luci::CircleOutput * _output = nullptr
 

Detailed Description

ConstantFoldingTestGraph is a base class for testing constant folding passes. It creates the Input and Output nodes of the graph shown below; child classes must implement the Connector and the Folded pattern.

[Input]   [Folded pattern] (Implemented by child class)
     \    /
   [Connector] (Implemented by child class)
        |
     [Output]

The Connector must satisfy the following conditions:

  • Input type == Output type == Folded pattern type
  • Input shape == Output shape == Folded pattern shape

For example, Add, Mul, or Sub can serve as the Connector.

Definition at line 47 of file PassTestGraphs.h.

Constructor & Destructor Documentation

◆ ConstantFoldingTestGraph()

luci::ConstantFoldingTestGraph::ConstantFoldingTestGraph ( std::vector< uint32_t >  input_shape,
loco::DataType  input_dtype 
)
inline

Definition at line 50 of file PassTestGraphs.h.

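51 {
52 _input = _g.nodes()->create<luci::CircleInput>();
53 _output = _g.nodes()->create<luci::CircleOutput>();
54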
55 auto graph_input = _g.inputs()->create();
56 _input->index(graph_input->index());
57 auto graph_output = _g.outputs()->create();
58 _output->index(graph_output->index());
59
60 graph_input->dtype(input_dtype);
61 graph_output->dtype(input_dtype);
62 _input->dtype(input_dtype);
63 _output->dtype(input_dtype);
64
65 auto input_tensor_shape = std::make_unique<loco::TensorShape>();
66 input_tensor_shape->rank(input_shape.size());
67 for (int i = 0; i < input_shape.size(); i++)
68 input_tensor_shape->dim(i).set(input_shape[i]);
69 graph_input->shape(std::move(input_tensor_shape));
70
71 auto output_tensor_shape = std::make_unique<loco::TensorShape>();
72 output_tensor_shape->rank(input_shape.size());
73 for (int i = 0; i < input_shape.size(); i++)
74 output_tensor_shape->dim(i).set(input_shape[i]);
75 graph_output->shape(std::move(output_tensor_shape));
76
77 _input->rank(input_shape.size());
78 for (int i = 0; i < input_shape.size(); i++)
79 _input->dim(i).set(input_shape[i]);
80
81 _output->rank(input_shape.size());
82 for (int i = 0; i < input_shape.size(); i++)
83 _output->dim(i).set(input_shape[i]);
84
85 _input->name("input");
86 _output->name("output");
87 }
InputContext * inputs(void)
Definition Graph.h:220
NodeContext * nodes(void)
Definition Graph.h:218
OutputContext * outputs(void)
Definition Graph.h:222
Derived * create(Args &&...args)
Definition NodePool.h:37
CircleNode used for Input of the Graph.
Definition CircleInput.h:36
void index(const loco::GraphInputIndex &index)
CircleNode for Output of the Graph.
void index(const loco::GraphOutputIndex &index)
GraphInput * create(void)
Definition Graph.cpp:52
GraphOutput * create(void)
Definition Graph.cpp:54
NodeName name(void) const

References _g, _input, _output, loco::NodePool::create(), loco::Graph::InputContext::create(), loco::Graph::OutputContext::create(), luci::CircleInput::index(), luci::CircleOutput::index(), loco::Graph::inputs(), luci::CircleNode::name(), loco::Graph::nodes(), and loco::Graph::outputs().

◆ ~ConstantFoldingTestGraph()

virtual luci::ConstantFoldingTestGraph::~ConstantFoldingTestGraph ( )
virtual, default

Member Function Documentation

◆ createFoldedPattern()

virtual loco::Node * luci::ConstantFoldingTestGraph::createFoldedPattern ( )
pure virtual

◆ getFoldedPattern()

virtual luci::CircleConst * luci::ConstantFoldingTestGraph::getFoldedPattern ( )
pure virtual

◆ graph()

loco::Graph * luci::ConstantFoldingTestGraph::graph ( )
inline

◆ init()

virtual void luci::ConstantFoldingTestGraph::init ( )
pure virtual

Field Documentation

◆ _g

loco::Graph luci::ConstantFoldingTestGraph::_g
protected

◆ _input

luci::CircleInput* luci::ConstantFoldingTestGraph::_input = nullptr
protected

◆ _output

luci::CircleOutput* luci::ConstantFoldingTestGraph::_output = nullptr
protected

The documentation for this class was generated from the following file: