ONE - On-device Neural Engine
record_hessian::HessianObserver Class Reference

#include <HessianObserver.h>


Public Member Functions

 HessianObserver ()=default
 
void postTensorWrite (const luci::CircleNode *node, const luci_interpreter::Tensor *tensor) override
 
std::unique_ptr< HessianMap > hessianData ()
 
- Public Member Functions inherited from luci_interpreter::ExecutionObserver
virtual ~ExecutionObserver ()
 
virtual void preOperatorExecute (const luci::CircleNode *node)
 
virtual void postOperatorExecute (const luci::CircleNode *node)
 

Detailed Description

Definition at line 30 of file HessianObserver.h.
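HessianObserver overrides only postTensorWrite() and inherits the other ExecutionObserver hooks unchanged. Here is a minimal, self-contained sketch of that observer pattern. It assumes that the inherited hooks default to no-ops, which this page does not state. ObsNode, ObsTensor, ToyExecutionObserver, runWithObservers() and WriteCounter are hypothetical stand-ins, not luci_interpreter APIs, and the order in which the toy loop fires the hooks is also an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-ins for luci::CircleNode and luci_interpreter::Tensor.
struct ObsNode { std::string name; };
struct ObsTensor { std::vector<float> data; };

// Simplified mirror of an execution-observer interface (assumption: every
// hook has an empty default body, so a subclass overrides only what it needs).
class ToyExecutionObserver
{
public:
  virtual ~ToyExecutionObserver() = default;
  virtual void preOperatorExecute(const ObsNode *) {}
  virtual void postOperatorExecute(const ObsNode *) {}
  virtual void postTensorWrite(const ObsNode *, const ObsTensor *) {}
};

// Hypothetical interpreter loop that notifies every attached observer.
// The hook order here is an assumption, not taken from luci_interpreter.
void runWithObservers(const std::vector<ObsNode> &nodes, const std::vector<ObsTensor> &outputs,
                      const std::vector<ToyExecutionObserver *> &observers)
{
  assert(nodes.size() == outputs.size());
  for (std::size_t i = 0; i < nodes.size(); ++i)
  {
    for (auto *obs : observers)
      obs->preOperatorExecute(&nodes[i]);
    // ... the operator kernel would run and write outputs[i] here ...
    for (auto *obs : observers)
      obs->postTensorWrite(&nodes[i], &outputs[i]);
    for (auto *obs : observers)
      obs->postOperatorExecute(&nodes[i]);
  }
}

// Like HessianObserver, this observer overrides only postTensorWrite().
class WriteCounter : public ToyExecutionObserver
{
public:
  void postTensorWrite(const ObsNode *node, const ObsTensor *tensor) override
  {
    assert(node != nullptr && tensor != nullptr);
    ++writes;
    elements += tensor->data.size();
  }
  int writes = 0;
  std::size_t elements = 0;
};
```

Because the other hooks fall back to the base-class bodies in this sketch, WriteCounter reacts only to tensor writes and ignores the pre/post operator events.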

Constructor & Destructor Documentation

◆ HessianObserver()

record_hessian::HessianObserver::HessianObserver ( )
default

Member Function Documentation

◆ hessianData()

std::unique_ptr< HessianMap > record_hessian::HessianObserver::hessianData ( )
inline

Definition at line 38 of file HessianObserver.h.

{ return _hessian_computer.getMap(); }

References record_hessian::HessianComputer::getMap().
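hessianData() only forwards HessianComputer::getMap(), which hands the accumulated map to the caller as a std::unique_ptr. Here is a minimal sketch of that ownership transfer. It models HessianMap as a hypothetical map from node name to a flattened matrix, and ToyHessianComputer and ToyObserver are illustrative stand-ins. This page does not document whether the real getMap() leaves its member empty or null after the move, so this sketch makes its own choice and resets the member to an empty map.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for the real HessianMap type: node name -> flattened matrix.
using HessianMap = std::unordered_map<std::string, std::vector<float>>;

// Hypothetical accumulator playing the role of HessianComputer.
class ToyHessianComputer
{
public:
  void record(const std::string &node_name, std::vector<float> hessian)
  {
    assert(_map != nullptr);
    (*_map)[node_name] = std::move(hessian);
  }

  // Moves the accumulated map out to the caller. This sketch then starts
  // over with an empty map (an assumption; the real class may differ).
  std::unique_ptr<HessianMap> getMap()
  {
    auto out = std::move(_map);
    _map = std::make_unique<HessianMap>();
    return out;
  }

private:
  std::unique_ptr<HessianMap> _map = std::make_unique<HessianMap>();
};

// Observer-side wrapper mirroring the one-line hessianData() accessor.
class ToyObserver
{
public:
  ToyHessianComputer &computer() { return _hessian_computer; }
  std::unique_ptr<HessianMap> hessianData() { return _hessian_computer.getMap(); }

private:
  ToyHessianComputer _hessian_computer;
};
```

In this sketch, a second call to hessianData() returns an empty map. The data is therefore best collected with a single call after all inputs have been run.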

◆ postTensorWrite()

void record_hessian::HessianObserver::postTensorWrite ( const luci::CircleNode * node,
const luci_interpreter::Tensor * tensor 
)
override virtual

Reimplemented from luci_interpreter::ExecutionObserver.

Definition at line 22 of file HessianObserver.cpp.

{
  assert(node != nullptr);
  assert(tensor != nullptr);

  auto node_outputs = loco::succs(node);
  for (auto node_output : node_outputs)
  {
    auto cur_node = dynamic_cast<luci::CircleNode *>(node_output);
    if (cur_node == nullptr)
    {
      throw std::runtime_error("Record Hessian: node output shouldn't be null.");
    }
    // TODO : ADD TCONV/DepthCONV cases
    if (cur_node->opcode() == luci::CircleOpcode::FULLY_CONNECTED ||
        cur_node->opcode() == luci::CircleOpcode::CONV_2D)
    {
      _hessian_computer.recordHessian(cur_node, tensor);
    }
  }
}

References record_hessian::HessianComputer::recordHessian(), and loco::succs().
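postTensorWrite() treats the freshly written tensor as the input of each successor node and records a Hessian only for FULLY_CONNECTED and CONV_2D consumers. Transposed and depthwise convolutions are still a TODO. Here is a sketch that reproduces that control flow on a hypothetical toy graph. ToyNode::successors stands in for loco::succs(), and Recorder::recordHessian() only logs which consumer would receive the tensor instead of computing anything. Because the toy successors are already typed, the dynamic_cast and null check of the original are not needed here.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-ins for luci::CircleOpcode, luci::CircleNode and Tensor.
enum class Opcode { INPUT, FULLY_CONNECTED, CONV_2D, DEPTHWISE_CONV_2D, ADD };

struct ToyNode
{
  std::string name;
  Opcode opcode;
  std::set<ToyNode *> successors; // same container type that loco::succs() returns
};

struct ToyTensor { std::vector<float> data; };

// Logs the consumers that would be handed to recordHessian().
struct Recorder
{
  std::vector<std::string> recorded;
  void recordHessian(const ToyNode *consumer, const ToyTensor *) { recorded.push_back(consumer->name); }
};

// Mirrors postTensorWrite(): the written tensor is the input of each successor,
// so it is forwarded only to FULLY_CONNECTED and CONV_2D consumers.
void postTensorWrite(Recorder &rec, const ToyNode *node, const ToyTensor *tensor)
{
  assert(node != nullptr);
  assert(tensor != nullptr);
  for (const ToyNode *succ : node->successors)
  {
    if (succ->opcode == Opcode::FULLY_CONNECTED || succ->opcode == Opcode::CONV_2D)
      rec.recordHessian(succ, tensor);
  }
}
```

loco::succs() returns a std::set of pointers, and the toy graph uses the same container. Successors are therefore visited in pointer order, not graph order, so nothing should depend on the order in which consumers are recorded.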


The documentation for this class was generated from the following files: