ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert_micro::train::optimizers::SGD Class Reference

#include <SGD.h>

Public Member Functions

 SGD ()=default
 
 SGD (const SGD &)=delete
 
 SGD (SGD &&)=delete
 
SGD & operator= (const SGD &)=delete
 
SGD && operator= (const SGD &&)=delete
 
 ~SGD ()
 
void reset ()
 
OMStatus handle (core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context, core::OMRuntimeStorage &storage)
 
OMStatus updateWeights (const OMTrainingContext &training_config, core::OMRuntimeContext &context, core::OMRuntimeStorage &storage, std::unordered_map< uint16_t, core::OpTrainableRankType > &tensor_index_to_rank_type_map)
 

Detailed Description

Definition at line 37 of file SGD.h.

Constructor & Destructor Documentation

◆ SGD() [1/3]

onert_micro::train::optimizers::SGD::SGD ( )
default

◆ SGD() [2/3]

onert_micro::train::optimizers::SGD::SGD ( const SGD & )
delete

◆ SGD() [3/3]

onert_micro::train::optimizers::SGD::SGD ( SGD &&  )
delete

◆ ~SGD()

onert_micro::train::optimizers::SGD::~SGD ( )
inline

Definition at line 49 of file SGD.h.

References reset().

Member Function Documentation

◆ handle()

OMStatus SGD::handle ( core::OMRuntimeStorage & backward_storage,
core::OMRuntimeContext & context,
core::OMRuntimeStorage & storage 
)

Definition at line 94 of file SGD.cpp.

96{
97 auto &backward_tensor_to_data = backward_storage.getTensorIndexToData();
98 // Check whether the helper buffers are allocated
99 if (_tensor_index_to_gradient.empty())
100 {
101 // If not - let's just move it with calculations
102 // Goes over all calculated gradients
103 // Warning: assume that backward storage at this moment contains only weights gradients -
104 // This should be done due to execution plan work
105 for (auto &tensor_to_data : backward_tensor_to_data)
106 {
107 // Move data
108 _tensor_index_to_gradient[tensor_to_data.first] = tensor_to_data.second;
109 tensor_to_data.second = nullptr;
110 }
111 backward_tensor_to_data.clear();
112 }
113 else
114 {
115 // Goes over all calculated gradients
116 // Warning: assume that backward storage at this moment contains only weights gradients -
117 // This should be done due to execution plan work
118 for (auto &tensor_to_data : backward_tensor_to_data)
119 {
120 auto tensor = context.getTensorByIndex(tensor_to_data.first);
121 auto num_elements = core::OMRuntimeShape(tensor).flatSize();
122
123#ifndef DIS_DYN_SHAPES
124 int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_to_data.first).flatSize();
125 if (dynamic_tensor_size != 0)
126 num_elements = dynamic_tensor_size;
127#endif // DIS_DYN_SHAPES
128
129 auto *grad_data = reinterpret_cast<float *>(_tensor_index_to_gradient[tensor_to_data.first]);
130 auto *calculated_data = reinterpret_cast<float *>(tensor_to_data.second);
131
132 for (uint32_t i = 0; i < num_elements; ++i)
133 {
134 grad_data[i] += calculated_data[i];
135 }
136 }
137 }
138
139 return Ok;
140}
const circle::Tensor * getTensorByIndex(int32_t tensor_index)
OMRuntimeShape getDynamicRuntimeShape(uint16_t tensor_index)
std::unordered_map< uint16_t, uint8_t * > & getTensorIndexToData()
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.
Definition Shape.h:59

References onert_micro::core::OMRuntimeShape::flatSize(), onert_micro::core::OMRuntimeStorage::getDynamicRuntimeShape(), onert_micro::core::OMRuntimeContext::getTensorByIndex(), onert_micro::core::OMRuntimeStorage::getTensorIndexToData(), and onert_micro::Ok.

◆ operator=() [1/2]

SGD && onert_micro::train::optimizers::SGD::operator= ( const SGD &&  )
delete

◆ operator=() [2/2]

SGD & onert_micro::train::optimizers::SGD::operator= ( const SGD & )
delete

◆ reset()

void SGD::reset ( )

Definition at line 79 of file SGD.cpp.

80{
81 for (auto &cur_tensor_index_data : _tensor_index_to_gradient)
82 {
83 uint8_t *allocated_data = cur_tensor_index_data.second;
84
85 core::memory::OMMemoryManager::deallocateMemory(allocated_data);
86 }
87 _tensor_index_to_gradient.clear();
88}
static OMStatus deallocateMemory(uint8_t *data)

References onert_micro::core::memory::OMMemoryManager::deallocateMemory().

Referenced by ~SGD().

◆ updateWeights()

OMStatus SGD::updateWeights ( const OMTrainingContext & training_config,
core::OMRuntimeContext & context,
core::OMRuntimeStorage & storage,
std::unordered_map< uint16_t, core::OpTrainableRankType > &  tensor_index_to_rank_type_map 
)

Definition at line 147 of file SGD.cpp.

151{
152 assert(!_tensor_index_to_gradient.empty());
153 if (_tensor_index_to_gradient.empty())
154 return UnknownError;
155
156 for (auto &tensor_to_data : _tensor_index_to_gradient)
157 {
158 auto tensor = context.getTensorByIndex(tensor_to_data.first);
159 core::OMRuntimeShape shape(tensor);
160 auto num_elements = shape.flatSize();
161
162 auto original_d = shape.dims(0);
163
164#ifndef DIS_DYN_SHAPES
165 int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_to_data.first).flatSize();
166 if (dynamic_tensor_size != 0)
167 num_elements = dynamic_tensor_size;
168#endif // DIS_DYN_SHAPES
169
170 auto *grad_data = reinterpret_cast<float *>(tensor_to_data.second);
171 uint8_t *weight_data = nullptr;
172 if (context.getConstDataByTensorIndex(&weight_data, tensor_to_data.first) != Ok)
173 return UnknownError;
174
175 assert(weight_data != nullptr);
176 if (weight_data == nullptr)
177 return UnknownError;
178
179 auto *f_weight_data = reinterpret_cast<float *>(weight_data);
180 float lambda = training_config.learning_rate;
181 const uint32_t batch_size = training_config.batch_size;
182 auto train_it = tensor_index_to_rank_type_map.find(tensor_to_data.first);
183 core::OpTrainableRankType rank = train_it == tensor_index_to_rank_type_map.end()
184 ? core::OpTrainableRankType::ALL
185 : core::OpTrainableRankType(train_it->second);
186 auto depth_bounds = getUpLowerWeightTensorDepth(rank, original_d);
187
188 assert(batch_size != 0);
189
190 for (uint32_t i = 0; i < num_elements; ++i)
191 {
192 f_weight_data[i] -= (lambda * grad_data[i]) / (static_cast<float>(batch_size));
193 }
194 }
195 return Ok;
196}
OMStatus getConstDataByTensorIndex(uint8_t **data, uint16_t tensor_index)
std::pair< uint32_t, uint32_t > getUpLowerWeightTensorDepth(core::OpTrainableRankType rank, const uint32_t output_depth)
Definition PALUtils.h:30

References onert_micro::core::ALL, onert_micro::OMTrainingContext::batch_size, onert_micro::core::OMRuntimeShape::dims(), onert_micro::core::OMRuntimeShape::flatSize(), onert_micro::core::OMRuntimeContext::getConstDataByTensorIndex(), onert_micro::core::OMRuntimeStorage::getDynamicRuntimeShape(), onert_micro::core::OMRuntimeContext::getTensorByIndex(), onert_micro::OMTrainingContext::learning_rate, onert_micro::Ok, and onert_micro::UnknownError.


The documentation for this class was generated from the following files: