ONE - On-device Neural Engine
Loading...
Searching...
No Matches
dalgona::PythonHooks Class Reference

#include <PythonHooks.h>

Collaboration diagram for dalgona::PythonHooks:

Public Member Functions

 PythonHooks (luci_interpreter::Interpreter *interpreter)
 
void importAnalysis (const std::string &analysis_path, py::object &globals, const std::string &analysis_args)
 
void endAnalysis ()
 
void startNetworkExecution (loco::Graph *graph)
 
void endNetworkExecution (loco::Graph *graph)
 
void preOperatorExecute (const luci::CircleNode *node) override
 
void postOperatorExecute (const luci::CircleNode *node) override
 
- Public Member Functions inherited from luci_interpreter::ExecutionObserver
virtual ~ExecutionObserver ()
 
virtual void postTensorWrite (const luci::CircleNode *node, const Tensor *tensor)
 

Detailed Description

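PythonHooks is a luci_interpreter::ExecutionObserver that forwards interpreter events (network execution start/end and pre/post operator execution) to a user-supplied Python analysis object loaded by importAnalysis(). The analysis-level (StartAnalysis/EndAnalysis) and network-level (StartNetworkExecution/EndNetworkExecution) callbacks are invoked only if the analysis object defines them.

Definition at line 32 of file PythonHooks.h.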

Constructor & Destructor Documentation

◆ PythonHooks()

dalgona::PythonHooks::PythonHooks ( luci_interpreter::Interpreter * interpreter)
inline

Definition at line 35 of file PythonHooks.h.

35 : _interpreter(interpreter)
36 {
37 // Do nothing
38 }

Member Function Documentation

◆ endAnalysis()

void dalgona::PythonHooks::endAnalysis ( )

Definition at line 103 of file PythonHooks.cpp.

104{
105 if (py::hasattr(_analysis, "EndAnalysis"))
106 pySafeCall(_analysis.attr("EndAnalysis"));
107}
void pySafeCall(py::object func, Args... args)
Definition Utils.h:29

References dalgona::pySafeCall().

◆ endNetworkExecution()

void dalgona::PythonHooks::endNetworkExecution ( loco::Graph * graph)

Definition at line 84 of file PythonHooks.cpp.

85{
86 if (!py::hasattr(_analysis, "EndNetworkExecution"))
87 return;
88
89 assert(graph != nullptr); // FIX_CALLER_UNLESS
90
91 const auto output_nodes = loco::output_nodes(graph);
92 py::list outputs;
93 // Assumption: output_nodes is iterated in the same order of model outputs
94 for (const auto output_node : output_nodes)
95 {
96 auto circle_node = loco::must_cast<luci::CircleOutput *>(output_node);
97 outputs.append(
98 outputPyArray(loco::must_cast<luci::CircleNode *>(circle_node->from()), _interpreter));
99 }
100 pySafeCall(_analysis.attr("EndNetworkExecution"), outputs);
101}
py::dict outputPyArray(const luci::CircleNode *node, luci_interpreter::Interpreter *interpreter)
Definition Utils.cpp:160
std::vector< Node * > output_nodes(Graph *)
Definition Graph.cpp:101

References loco::output_nodes(), dalgona::outputPyArray(), and dalgona::pySafeCall().

◆ importAnalysis()

void dalgona::PythonHooks::importAnalysis ( const std::string &  analysis_path,
py::object &  globals,
const std::string &  analysis_args 
)

Definition at line 39 of file PythonHooks.cpp.

41{
42 const auto base_filename = analysis_path.substr(analysis_path.find_last_of("/\\") + 1);
43 // module name must be the same with the python code
44 // ex: base_filename = MyAnalysis.py -> module_name = MyAnalysis
45 const auto module_name = base_filename.substr(0, base_filename.find_last_of('.'));
46
47 py::dict locals;
48 locals["path"] = py::cast(analysis_path);
49
50 py::eval<py::eval_statements>("import sys\n"
51 "import os\n"
52 "sys.path.append(os.path.dirname(path))\n"
53 "import " +
54 module_name +
55 "\n"
56 "analysis = " +
57 module_name + "." + module_name + "()",
58 globals, locals);
59
60 _analysis = locals["analysis"];
61
62 if (py::hasattr(_analysis, "StartAnalysis"))
63 pySafeCall(_analysis.attr("StartAnalysis"), analysis_args);
64}

References dalgona::pySafeCall().

◆ postOperatorExecute()

void dalgona::PythonHooks::postOperatorExecute ( const luci::CircleNode * node)
override virtual

Reimplemented from luci_interpreter::ExecutionObserver.

Definition at line 33 of file PythonHooks.cpp.

34{
35 PostOperatorHook hook(_analysis, _interpreter);
36 node->accept(&hook);
37}
T accept(CircleNodeVisitorBase< T > *) const

References luci::CircleNode::accept().

◆ preOperatorExecute()

void dalgona::PythonHooks::preOperatorExecute ( const luci::CircleNode * node)
override virtual

Reimplemented from luci_interpreter::ExecutionObserver.

Definition at line 27 of file PythonHooks.cpp.

28{
29 PreOperatorHook hook(_analysis, _interpreter);
30 node->accept(&hook);
31}

References luci::CircleNode::accept().

◆ startNetworkExecution()

void dalgona::PythonHooks::startNetworkExecution ( loco::Graph * graph)

Definition at line 66 of file PythonHooks.cpp.

67{
68 if (!py::hasattr(_analysis, "StartNetworkExecution"))
69 return;
70
71 assert(graph != nullptr); // FIX_CALLER_UNLESS
72
73 const auto input_nodes = loco::input_nodes(graph);
74 py::list inputs;
75 // Assumption: input_nodes is iterated in the same order of model inputs
76 for (const auto input_node : input_nodes)
77 {
78 auto circle_node = loco::must_cast<luci::CircleInput *>(input_node);
79 inputs.append(outputPyArray(circle_node, _interpreter));
80 }
81 pySafeCall(_analysis.attr("StartNetworkExecution"), inputs);
82}
std::vector< Node * > input_nodes(const Graph *)
Definition Graph.cpp:71

References loco::input_nodes(), dalgona::outputPyArray(), and dalgona::pySafeCall().


The documentation for this class was generated from the following files: