ONE - On-device Neural Engine
Loading...
Searching...
No Matches
record_hessian::HessianComputer Class Reference

#include <HessianComputer.h>

Public Member Functions

void recordHessian (const luci::CircleNode *node, const luci_interpreter::Tensor *input_tensor)
 
std::unique_ptr< HessianMap > getMap ()
 

Detailed Description

Definition at line 38 of file HessianComputer.h.

Member Function Documentation

◆ getMap()

std::unique_ptr< HessianMap > record_hessian::HessianComputer::getMap ( )

Definition at line 202 of file HessianComputer.cpp.

203{
204 auto hessian_map = std::make_unique<HessianMap>();
205
206 for (auto item : _hessian_map)
207 {
208 auto &vec = (*hessian_map)[item.first];
209 vec = item.second.hessian;
210 }
211
212 return hessian_map;
213}

Referenced by record_hessian::HessianObserver::hessianData().

◆ recordHessian()

void record_hessian::HessianComputer::recordHessian ( const luci::CircleNode * node,
const luci_interpreter::Tensor * input_tensor 
)

Definition at line 178 of file HessianComputer.cpp.

180{
181 if (node == nullptr || input_tensor == nullptr)
182 throw std::invalid_argument("RecordHessian: node or input_tensor is null.");
183
184 if (input_tensor->element_type() != luci_interpreter::DataType::FLOAT32)
185 throw std::runtime_error("RecordHessian: Unsupported dtype: only FLOAT32 is supported.");
186
187 _input_tensor = input_tensor;
188
189 switch (node->opcode())
190 {
191 case luci::CircleOpcode::FULLY_CONNECTED:
192 recordHessianForFullyConnected(node);
193 break;
194 case luci::CircleOpcode::CONV_2D:
195 recordHessianForConv2D(node);
196 break;
197 default:
198 throw std::runtime_error("RecordHessian: " + node->name() + " is unsupported op.");
199 }
200}
DataType element_type() const
Definition Tensor.h:105
NodeName name(void) const
virtual CircleOpcode opcode(void) const =0

References luci_interpreter::Tensor::element_type(), luci::CircleNode::name(), and luci::CircleNode::opcode().

Referenced by record_hessian::HessianObserver::postTensorWrite().


The documentation for this class was generated from the following files: