17#include "luci_interpreter/Interpreter.h"
20#include "loader/ModuleLoader.h"
30class EventNotifierImpl final :
public EventNotifier
33 EventNotifierImpl(
const RuntimeToIR &runtime_to_ir,
34 const std::vector<ExecutionObserver *> &observers)
35 : _runtime_to_ir(runtime_to_ir), _observers(observers)
39 void postTensorWrite(
const Tensor *tensor)
override
41 assert(tensor !=
nullptr);
42 for (
const auto &observer : _observers)
44 observer->postTensorWrite(_runtime_to_ir.tensor_to_node.at(tensor), tensor);
48 void preOperatorExecute(
const Kernel *kernel)
override
50 assert(kernel !=
nullptr);
51 for (
const auto &observer : _observers)
53 observer->preOperatorExecute(_runtime_to_ir.kernel_to_node.at(kernel));
57 void postOperatorExecute(
const Kernel *kernel)
override
59 assert(kernel !=
nullptr);
60 for (
const auto &observer : _observers)
62 observer->postOperatorExecute(_runtime_to_ir.kernel_to_node.at(kernel));
67 const RuntimeToIR &_runtime_to_ir;
68 const std::vector<ExecutionObserver *> &_observers;
75 _runtime_to_ir = std::make_unique<RuntimeToIR>();
76 _event_notifier = std::make_unique<EventNotifierImpl>(*_runtime_to_ir, _observers);
77 _runtime_module = std::make_unique<RuntimeModule>(_event_notifier.get());
79 _default_memory_manager = std::make_unique<SimpleMemoryManager>();
81 ModuleLoader loader(module, _runtime_module.get(), *_runtime_to_ir, _node_to_tensor,
82 _default_memory_manager.get());
89 assert(memory_manager &&
"Use Interpreter::Interpreter(module) constructor instead");
91 _runtime_to_ir = std::make_unique<RuntimeToIR>();
92 _event_notifier = std::make_unique<EventNotifierImpl>(*_runtime_to_ir, _observers);
93 _runtime_module = std::make_unique<RuntimeModule>(_event_notifier.get());
95 ModuleLoader loader(module, _runtime_module.get(), *_runtime_to_ir, _node_to_tensor,
106 if (tensor ==
nullptr)
109 throw std::runtime_error(
"Cannot find tensor for input node named \"" + name +
"\".");
112 tensor->writeData(
data, data_size);
119 if (tensor ==
nullptr)
122 throw std::runtime_error(
"Cannot find tensor for output node named \"" + name +
"\".");
125 tensor->readData(
data, data_size);
131 if (tensor ==
nullptr)
134 throw std::runtime_error(
"Cannot find tensor size for output node named \"" + name +
"\".");
138 tensor_size *= tensor->shape().num_elements();
146 if (std::find(_observers.cbegin(), _observers.cend(), observer) != _observers.cend())
147 throw std::runtime_error(
"Observer is already attached.");
148 _observers.push_back(observer);
CircleNode for Output of the Graph.
void index(const loco::GraphOutputIndex &index)
Collection of 'loco::Graph's.
virtual void postOperatorExecute(const luci::CircleNode *node)
virtual ~ExecutionObserver()
virtual void preOperatorExecute(const luci::CircleNode *node)
virtual void postTensorWrite(const luci::CircleNode *node, const Tensor *tensor)
void attachObserver(ExecutionObserver *observer)
size_t getOutputTensorSize(const luci::CircleOutput *output_node)
void writeInputTensor(const luci::CircleInput *input_node, const void *data, size_t data_size)
Interpreter(const luci::Module *module)
void readOutputTensor(const luci::CircleOutput *output_node, void *data, size_t data_size)
const T * data(const std::vector< T, Alloc > &v)
size_t getDataTypeSize(DataType data_type)
CircleOutput * output_node(loco::Graph *g, const loco::GraphOutputIndex &index)
Find a CircleOutput node with a given output index.
CircleInput * input_node(loco::Graph *g, const loco::GraphInputIndex &index)
Find a CircleInput node with a given input index.
NodeName name(void) const