ONE - On-device Neural Engine
luci::DynamicBatchToSingleBatchPass Class Reference

Pass to convert dynamic batch to single batch.

#include <DynamicBatchToSingleBatchPass.h>


Public Member Functions

virtual const char * name (void) const
bool run (loco::Graph *graph)
    Run the pass.

Public Member Functions inherited from logo::Pass
virtual ~Pass ()=default

Detailed Description

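Pass to convert dynamic batch to single batch. For each graph input whose first dimension (assumed to be the batch dimension) is unknown, the pass fixes that dimension to 1, both on the luci::CircleInput node and on the shape of the corresponding loco::GraphInput. Rank-0 inputs and inputs whose batch dimension is already known are left unchanged. Only rank-4 inputs (NHWC or NCHW) are supported: an input with an unknown first dimension and any other rank makes run() throw std::runtime_error.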
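The per-input rule can be modelled on a plain vector of optional dimensions. This is a standalone sketch under stated assumptions, not the luci API: `Shape` and `to_single_batch` are hypothetical names, and `std::nullopt` stands in for an unknown `loco::Dimension`. The checks run in the same order as in `run()`: rank 0, known batch, then the rank-4 restriction.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a tensor shape (not loco::TensorShape):
// std::nullopt marks an unknown (dynamic) dimension.
using Shape = std::vector<std::optional<uint32_t>>;

// Applies the single-batch rule to one input shape.
// Returns true if the shape was changed, false otherwise.
bool to_single_batch(Shape &shape)
{
  const std::size_t kBatchDim = 0; // assume the first dimension is the batch

  if (shape.empty())
    return false; // rank-0 inputs are skipped
  if (shape[kBatchDim].has_value())
    return false; // batch dimension is already known
  if (shape.size() != 4)
    throw std::runtime_error("First dimension of input is unknown, but its rank is not 4.");

  shape[kBatchDim] = 1u; // make the batch dimension known and equal to 1
  return true;
}
```

For example, an NHWC shape `{?, 224, 224, 3}` becomes `{1, 224, 224, 3}` and the call returns true; calling it again returns false because the batch dimension is now known.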

Definition at line 28 of file DynamicBatchToSingleBatchPass.h.

Member Function Documentation

◆ name()

virtual const char * luci::DynamicBatchToSingleBatchPass::name ( void ) const
[inline], [virtual]

Reimplemented from logo::Pass.

Definition at line 31 of file DynamicBatchToSingleBatchPass.h.

{ return "luci::DynamicBatchToSingleBatchPass"; }

◆ run()

bool luci::DynamicBatchToSingleBatchPass::run ( loco::Graph * graph )
[virtual]

Run the pass.

Returns
false if nothing was changed; true if at least one graph input had its unknown batch dimension set to 1

Implements logo::Pass.
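The boolean returned by run() lets a driver re-apply passes until none of them reports a change. The sketch below illustrates that contract only; `Graph`, `Pass`, `OneRewritePass`, and `run_to_fixed_point` are hypothetical, simplified types and helpers, not the logo API, and the driver ONE actually uses is not documented on this page.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy graph: counts how many rewrites are still pending.
struct Graph
{
  int pending_rewrites = 0;
};

// Hypothetical simplified pass interface modelled on name()/run().
struct Pass
{
  virtual ~Pass() = default;
  virtual const char *name(void) const = 0;
  // Contract: return true if the graph was modified, false otherwise.
  virtual bool run(Graph *g) = 0;
};

// Toy pass that applies one rewrite per invocation.
struct OneRewritePass final : public Pass
{
  const char *name(void) const override { return "OneRewritePass"; }
  bool run(Graph *g) override
  {
    assert(g);
    if (g->pending_rewrites == 0)
      return false; // nothing changed
    --g->pending_rewrites;
    return true;
  }
};

// Re-runs every pass until a full sweep reports no change.
// Returns the names of the passes that changed the graph, in order.
std::vector<std::string> run_to_fixed_point(Graph *g,
                                            const std::vector<std::unique_ptr<Pass>> &passes)
{
  std::vector<std::string> changed_by;
  bool changed = true;
  while (changed)
  {
    changed = false;
    for (const auto &pass : passes)
    {
      if (pass->run(g))
      {
        changed_by.emplace_back(pass->name());
        changed = true;
      }
    }
  }
  return changed_by;
}
```

Because DynamicBatchToSingleBatchPass makes every batch dimension it touches known, a second run over the same graph finds nothing to change and returns false, so it terminates under a loop of this kind.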

Definition at line 25 of file DynamicBatchToSingleBatchPass.cpp.

{
  assert(g); // FIX CALLER UNLESS

  bool changed = false;

  auto graph_inputs = g->inputs();

  // Assume the first dimension is batch dimension
  const uint32_t BATCH_DIM = 0;

  for (auto node : loco::input_nodes(g))
  {
    auto input_node = loco::must_cast<luci::CircleInput *>(node);

    if (input_node->rank() == 0)
      continue;

    // Skip if batch dimension is known
    if (input_node->dim(BATCH_DIM).known())
      continue;

    if (input_node->rank() != 4)
    {
      // Limit use only for rank 4 inputs (for NHWC and NCHW)
      // TODO Enable this if necessary
      throw std::runtime_error("First dimension of input is unknown, but its rank is not 4.");
    }

    // 'set' will make the dimension known
    input_node->dim(BATCH_DIM).set(1);

    // Update graph input
    auto graph_input = graph_inputs->at(input_node->index());
    auto graph_input_shape = graph_input->shape();
    auto tensor_shape = std::make_unique<loco::TensorShape>();
    {
      tensor_shape->rank(graph_input_shape->rank());
      for (uint32_t i = 0; i < tensor_shape->rank(); i++)
      {
        tensor_shape->dim(i) = graph_input_shape->dim(i);
      }
      tensor_shape->dim(BATCH_DIM).set(1);
    }

    graph_input->shape(std::move(tensor_shape));

    changed = true;
  }

  return changed;
}
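The "Update graph input" step exists because the input shape is recorded twice: on the luci::CircleInput node and on the loco::GraphInput found through the node's index(). The sketch below models only that bookkeeping, under assumptions: `TensorShape`, `GraphInput`, `InputNode`, and `fix_batch` are hypothetical stand-ins, not the loco/luci types. As in the listing, the graph-input shape is copied dimension by dimension, its batch dimension is pinned to 1, and the new shape object is moved in to replace the old one.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical shape type: std::nullopt marks an unknown dimension.
struct TensorShape
{
  std::vector<std::optional<uint32_t>> dims;
};

// Hypothetical graph-level input record that owns its shape.
struct GraphInput
{
  std::unique_ptr<TensorShape> shape;
};

// Hypothetical input node: its own shape copy plus the graph-input index.
struct InputNode
{
  uint32_t index;
  TensorShape shape;
};

// Pins the batch dimension to 1 on the node and on its GraphInput.
void fix_batch(InputNode &node, std::vector<GraphInput> &graph_inputs)
{
  const uint32_t kBatchDim = 0;
  node.shape.dims.at(kBatchDim) = 1u; // update the node's own shape

  GraphInput &graph_input = graph_inputs.at(node.index);
  auto tensor_shape = std::make_unique<TensorShape>();
  tensor_shape->dims = graph_input.shape->dims; // copy every dimension
  tensor_shape->dims.at(kBatchDim) = 1u;        // then fix the batch to 1
  graph_input.shape = std::move(tensor_shape);  // replace the old shape object
}
```

Keeping both records in sync matters because code that reads the graph-level input shape would otherwise still see an unknown batch after the node was fixed.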

References luci::FixedArityNode< N, Base >::at(), loco::TensorShape::dim(), luci::CircleInput::index(), luci::input_node(), loco::input_nodes(), loco::TensorShape::rank(), and loco::Dimension::set().

Referenced by package.infer.session::inference().


The documentation for this class was generated from the following files: