ONE - On-device Neural Engine
luci_interpreter::kernels::If Class Reference

#include <If.h>

Public Member Functions

 If (const Tensor *cond, const std::vector< const Tensor * > &inputs, std::vector< Tensor * > outputs, RuntimeGraph *then_graph, RuntimeGraph *else_graph)
 
const Tensor * cond () const
 
const Tensor * input (int index) const
 
Tensor * output (int index) const
 
void configure () override
 
void execute () const override
 
- Public Member Functions inherited from luci_interpreter::Kernel
virtual ~Kernel ()=default
 
const std::vector< const Tensor * > & getInputTensors () const
 
const std::vector< Tensor * > & getOutputTensors () const
 

Additional Inherited Members

- Protected Member Functions inherited from luci_interpreter::Kernel
 Kernel (std::vector< const Tensor * > inputs, std::vector< Tensor * > outputs)
 
- Protected Attributes inherited from luci_interpreter::Kernel
const std::vector< const Tensor * > _inputs
 
const std::vector< Tensor * > _outputs
 

Detailed Description

Definition at line 28 of file If.h.
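
The If kernel implements conditional execution of subgraphs: it reads a scalar boolean condition tensor and runs exactly one of two pre-built runtime graphs, then_graph when the condition is true and else_graph otherwise. The kernel's remaining inputs are copied into the selected graph's input tensors, the graph is executed, and the graph's outputs are copied back into the kernel's output tensors (see execute() below).

The standalone sketch below illustrates the same dispatch pattern in plain standard C++. The BranchGraph alias and the run_if helper are hypothetical stand-ins used only for illustration; they are not part of the luci-interpreter API.

#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-in for a compiled branch subgraph: consumes input
// buffers and produces output buffers.
using BranchGraph =
    std::function<std::vector<std::vector<float>>(const std::vector<std::vector<float>> &)>;

// Mirrors the If dispatch: pick a branch from a scalar condition,
// feed it the inputs, and forward its outputs.
std::vector<std::vector<float>> run_if(bool cond, const std::vector<std::vector<float>> &inputs,
                                       const BranchGraph &then_graph, const BranchGraph &else_graph)
{
  const BranchGraph &active_graph = cond ? then_graph : else_graph;
  return active_graph(inputs);
}

int main()
{
  BranchGraph add_one = [](const std::vector<std::vector<float>> &in) {
    auto out = in;
    for (auto &tensor : out)
      for (auto &value : tensor)
        value += 1.0f;
    return out;
  };
  BranchGraph negate = [](const std::vector<std::vector<float>> &in) {
    auto out = in;
    for (auto &tensor : out)
      for (auto &value : tensor)
        value = -value;
    return out;
  };

  auto result = run_if(/*cond=*/true, {{1.0f, 2.0f}}, add_one, negate);
  assert(result[0][0] == 2.0f && result[0][1] == 3.0f);
  return 0;
}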

Constructor & Destructor Documentation

◆ If()

luci_interpreter::kernels::If::If ( const Tensor * cond,
const std::vector< const Tensor * > & inputs,
std::vector< Tensor * > outputs,
RuntimeGraph * then_graph,
RuntimeGraph * else_graph
)

Definition at line 35 of file If.cpp.

  : Kernel(joinInputs(cond, inputs), std::move(outputs)), _then_graph(then_graph),
    _else_graph(else_graph)
{
}
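
The constructor hands the condition and the data inputs to the Kernel base class as a single list via joinInputs(cond, inputs), so _inputs[0] holds the condition tensor and _inputs[1 + index] holds the data inputs. The cond() and input() accessors documented below rely on this layout.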

Member Function Documentation

◆ cond()

const Tensor * luci_interpreter::kernels::If::cond ( ) const
inline

Definition at line 34 of file If.h.

{ return _inputs[0]; }

References luci_interpreter::Kernel::_inputs.

Referenced by configure(), and execute().

◆ configure()

void luci_interpreter::kernels::If::configure ( )
override virtual

Implements luci_interpreter::Kernel.

Definition at line 42 of file If.cpp.

{
  LUCI_INTERPRETER_CHECK(cond()->element_type() == DataType::BOOL);
  LUCI_INTERPRETER_CHECK(cond()->shape().num_elements() == 1);

  for (RuntimeGraph *graph : {_then_graph, _else_graph})
  {
    (void)graph;
    LUCI_INTERPRETER_CHECK(graph->getInputTensors().size() == getInputTensors().size() - 1);
    LUCI_INTERPRETER_CHECK(graph->getOutputTensors().size() == getOutputTensors().size());
  }
}

References cond(), luci_interpreter::Kernel::getInputTensors(), luci_interpreter::Kernel::getOutputTensors(), and LUCI_INTERPRETER_CHECK.
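
configure() checks that the condition is a boolean scalar and that both branch graphs match the kernel's signature: each graph must declare one input per data input of the kernel (the kernel's input count minus one for the condition) and exactly as many outputs as the kernel has. For example, an If kernel constructed with a condition plus three data inputs is only valid if then_graph and else_graph each expose three input tensors.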

◆ execute()

void luci_interpreter::kernels::If::execute ( ) const
override virtual

Implements luci_interpreter::Kernel.

Definition at line 55 of file If.cpp.

{
  const bool cond_value = cond()->data<bool>()[0];

  RuntimeGraph *active_graph = cond_value ? _then_graph : _else_graph;
  const auto &graph_inputs = active_graph->getInputTensors();
  const auto &graph_outputs = active_graph->getOutputTensors();

  // Copy kernel inputs to active graph inputs.
  for (size_t i = 0; i < getInputTensors().size() - 1; ++i)
  {
    LUCI_INTERPRETER_CHECK(graph_inputs[i]->element_type() == input(i)->element_type());
    graph_inputs[i]->resize(input(i)->shape());

    const int32_t num_elements = input(i)->shape().num_elements();
    const std::size_t element_size = getDataTypeSize(input(i)->element_type());
    // TODO: Think about how to allocate memory for outputs in the main graph.
    active_graph->configureAllocations(graph_inputs[i]);
    std::memcpy(graph_inputs[i]->data<void>(), input(i)->data<void>(), num_elements * element_size);
  }

  active_graph->execute();

  // Copy graph outputs to kernel outputs.
  for (size_t i = 0; i < getOutputTensors().size(); ++i)
  {
    LUCI_INTERPRETER_CHECK(graph_outputs[i]->element_type() == output(i)->element_type());
    output(i)->resize(graph_outputs[i]->shape());
    // TODO: Think about how to allocate memory for outputs in the main graph.
    active_graph->configureAllocations(output(i));

    const int32_t num_elements = output(i)->shape().num_elements();
    const std::size_t element_size = getDataTypeSize(output(i)->element_type());
    std::memcpy(output(i)->data<void>(), graph_outputs[i]->data<void>(),
                num_elements * element_size);
  }
}

References cond(), luci_interpreter::Tensor::data(), luci_interpreter::getDataTypeSize(), luci_interpreter::Kernel::getInputTensors(), luci_interpreter::RuntimeGraph::getInputTensors(), luci_interpreter::Kernel::getOutputTensors(), input(), LUCI_INTERPRETER_CHECK, luci_interpreter::Shape::num_elements(), output(), luci_interpreter::Tensor::resize(), and luci_interpreter::Tensor::shape().
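
execute() reads the scalar condition, picks then_graph or else_graph as the active graph, and moves data across the graph boundary with std::memcpy. Each graph input is resized to the matching kernel input's shape, allocated via configureAllocations(), and filled from the kernel input; after the active graph runs, each kernel output is resized to the matching graph output's shape, allocated, and filled from the graph output. The byte count for every copy is shape().num_elements() * getDataTypeSize(element_type()).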

◆ input()

const Tensor * luci_interpreter::kernels::If::input ( int  index) const
inline

Definition at line 35 of file If.h.

{ return _inputs[1 + index]; }

References luci_interpreter::Kernel::_inputs.

Referenced by execute().

◆ output()

Tensor * luci_interpreter::kernels::If::output ( int  index) const
inline

Definition at line 36 of file If.h.

{ return _outputs[index]; }

References luci_interpreter::Kernel::_outputs.

Referenced by execute().


The documentation for this class was generated from the following files:

If.h
If.cpp