ONE - On-device Neural Engine
Loading...
Searching...
No Matches
luci_interpreter::kernels::While Class Reference

#include <While.h>

Collaboration diagram for luci_interpreter::kernels::While:

Public Member Functions

 While (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs, RuntimeGraph *cond_graph, RuntimeGraph *body_graph)
 
const Tensor * input (int index) const
 
Tensor * output (int index) const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file While.h.

Constructor & Destructor Documentation

◆ While()

luci_interpreter::kernels::While::While ( std::vector< const Tensor * >  inputs,
std::vector< Tensor * >  outputs,
RuntimeGraph * cond_graph,
RuntimeGraph * body_graph 
)

Definition at line 61 of file While.cpp.

63 : Kernel(std::move(inputs), std::move(outputs)), _cond_graph(cond_graph), _body_graph(body_graph)
64{
65}
Kernel(std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
Definition Kernel.h:31

Member Function Documentation

◆ configure()

void luci_interpreter::kernels::While::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 67 of file While.cpp.

68{
69 LUCI_INTERPRETER_CHECK(_body_graph->getInputTensors().size() == getInputTensors().size());
71 LUCI_INTERPRETER_CHECK(_body_graph->getOutputTensors().size() == getInputTensors().size());
72
73 LUCI_INTERPRETER_CHECK(_cond_graph->getInputTensors().size() == getInputTensors().size());
74
75 const auto &cond_outputs = _cond_graph->getOutputTensors();
76 LUCI_INTERPRETER_CHECK(cond_outputs.size() == 1)
77 LUCI_INTERPRETER_CHECK(cond_outputs[0]->element_type() == DataType::BOOL);
78}
const std::vector< Tensor * > & getOutputTensors() const
Definition Kernel.h:40
const std::vector< const Tensor * > & getInputTensors() const
Definition Kernel.h:39
const std::vector< Tensor * > & getOutputTensors() const
const std::vector< Tensor * > & getInputTensors() const
#define LUCI_INTERPRETER_CHECK(cond)
Definition Utils.h:36
DataType
"scalar" value type
Definition DataType.h:27
int32_t size[5]
Definition Slice.cpp:35

References luci_interpreter::Kernel::getInputTensors(), luci_interpreter::RuntimeGraph::getInputTensors(), luci_interpreter::Kernel::getOutputTensors(), luci_interpreter::RuntimeGraph::getOutputTensors(), LUCI_INTERPRETER_CHECK, and size.

◆ execute()

void luci_interpreter::kernels::While::execute ( ) const
override virtual
Note
Dynamic shape such as {1, 0, 8} may fail in tensor->data()

Implements luci_interpreter::Kernel.

Definition at line 83 of file While.cpp.

84{
85 const auto &cond_inputs = _cond_graph->getInputTensors();
86 const auto &cond_outputs = _cond_graph->getOutputTensors();
87
88 configureTensorsAllocations(cond_inputs, _cond_graph);
89
90 copy(getInputTensors(), cond_inputs);
91
92 const auto &body_inputs = _body_graph->getInputTensors();
93 const auto &body_outputs = _body_graph->getOutputTensors();
94
95 configureTensorsAllocations(body_inputs, _body_graph);
96
97 while (true)
98 {
99 _cond_graph->execute();
100
101 bool cond_value = cond_outputs[0]->data<bool>()[0];
102 if (!cond_value)
103 break;
104
105 copy(cond_inputs, body_inputs);
106
107 _body_graph->execute();
108
109 copy(body_outputs, cond_inputs);
110 }
111
112 copy(cond_inputs, getOutputTensors());
113}

References luci_interpreter::RuntimeGraph::execute(), luci_interpreter::Kernel::getInputTensors(), luci_interpreter::RuntimeGraph::getInputTensors(), luci_interpreter::Kernel::getOutputTensors(), and luci_interpreter::RuntimeGraph::getOutputTensors().

◆ input()

const Tensor * luci_interpreter::kernels::While::input ( int  index) const
inline

Definition at line 34 of file While.h.

34{ return _inputs[index]; }
const std::vector< const Tensor * > _inputs
Definition Kernel.h:52
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54

References luci_interpreter::Kernel::_inputs.

◆ output()

Tensor * luci_interpreter::kernels::While::output ( int  index) const
inline

Definition at line 35 of file While.h.

35{ return _outputs[index]; }
const std::vector< Tensor * > _outputs
Definition Kernel.h:53

References luci_interpreter::Kernel::_outputs.


The documentation for this class was generated from the following files: