ONE - On-device Neural Engine
Loading...
Searching...
No Matches
onert_micro::train::optimizers::Adam Class Reference

#include <Adam.h>

Public Member Functions

 Adam ()=default
 
 Adam (const Adam &)=delete
 
 Adam (Adam &&)=delete
 
Adam & operator= (const Adam &)=delete
 
Adam && operator= (const Adam &&)=delete
 
 ~Adam ()
 
void fullReset ()
 
void reset ()
 
bool isReset ()
 
uint8_t * getExponentAvgDataByTensorIndex (uint16_t tensor_index)
 
uint8_t * getExponentAvgSquaresDataByTensorIndex (uint16_t tensor_index)
 
void setExponentAvgDataByTensorIndex (uint16_t tensor_index, uint8_t *data)
 
void setExponentAvgSquaresDataByTensorIndex (uint16_t tensor_index, uint8_t *data)
 
OMStatus handle (core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context, core::OMRuntimeStorage &storage)
 
OMStatus updateWeights (const OMTrainingContext &training_config, core::OMRuntimeContext &context, core::OMRuntimeStorage &storage, std::unordered_map< uint16_t, core::OpTrainableRankType > &)
 

Detailed Description

Definition at line 38 of file Adam.h.

Constructor & Destructor Documentation

◆ Adam() [1/3]

onert_micro::train::optimizers::Adam::Adam ( )
default

◆ Adam() [2/3]

onert_micro::train::optimizers::Adam::Adam ( const Adam &  )
delete

◆ Adam() [3/3]

onert_micro::train::optimizers::Adam::Adam ( Adam &&  )
delete

◆ ~Adam()

onert_micro::train::optimizers::Adam::~Adam ( )
inline

Definition at line 55 of file Adam.h.

References fullReset().

Member Function Documentation

◆ fullReset()

void Adam::fullReset ( )

Definition at line 147 of file Adam.cpp.

148{
149 for (auto &cur_tensor_index_data : _tensor_to_exponent_avg)
150 {
151 uint8_t *allocated_data = cur_tensor_index_data.second;
152
154 }
155 _tensor_to_exponent_avg.clear();
156
157 for (auto &cur_tensor_index_data : _tensor_to_exponent_avg_squares)
158 {
159 uint8_t *allocated_data = cur_tensor_index_data.second;
160
162 }
163 _tensor_to_exponent_avg_squares.clear();
164
165 for (auto &cur_tensor_index_data : _tensor_index_to_gradient)
166 {
167 uint8_t *allocated_data = cur_tensor_index_data.second;
168
170 }
171 _tensor_index_to_gradient.clear();
172}
static OMStatus deallocateMemory(uint8_t *data)

References onert_micro::core::memory::OMMemoryManager::deallocateMemory().

Referenced by ~Adam().

◆ getExponentAvgDataByTensorIndex()

uint8_t * Adam::getExponentAvgDataByTensorIndex ( uint16_t  tensor_index)

Definition at line 185 of file Adam.cpp.

186{
187 auto it = _tensor_to_exponent_avg.find(tensor_index);
188 if (it == _tensor_to_exponent_avg.end())
189 return nullptr;
190
191 return it->second;
192}

◆ getExponentAvgSquaresDataByTensorIndex()

uint8_t * Adam::getExponentAvgSquaresDataByTensorIndex ( uint16_t  tensor_index)

Definition at line 194 of file Adam.cpp.

195{
196 auto it = _tensor_to_exponent_avg_squares.find(tensor_index);
197 if (it == _tensor_to_exponent_avg_squares.end())
198 return nullptr;
199
200 return it->second;
201}

◆ handle()

OMStatus Adam::handle ( core::OMRuntimeStorage &  backward_storage,
core::OMRuntimeContext &  context,
core::OMRuntimeStorage &  storage 
)

Definition at line 224 of file Adam.cpp.

226{
227 auto &backward_tensor_to_data = backward_storage.getTensorIndexToData();
228
229 // Check is allocated or not helper buffers
230 if (_tensor_to_exponent_avg_squares.empty())
231 {
232 // If not - let's allocate it
233 assert(_tensor_to_exponent_avg.empty() == true);
234 // Goes over all calculated gradients
235 // Warning: assume that backward storage at this moment contains only weighs gradients -
236 // This should be done due to execution plan work
237 for (auto &tensor_to_data : backward_tensor_to_data)
238 {
239 auto tensor_index = tensor_to_data.first;
240 auto tensor = context.getTensorByIndex(tensor_index);
242
243#ifndef DIS_DYN_SHAPES
244 int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_index).flatSize();
245 if (dynamic_tensor_size != 0)
246 num_elements = dynamic_tensor_size;
247#endif // DIS_DYN_SHAPES
248
249 auto tensor_size = num_elements * sizeof(core::OMDataType(tensor->type()));
250
251 // Allocate data for exponent calculation
252 uint8_t *exponent_data = nullptr;
253 OMStatus status = core::memory::OMMemoryManager::allocateMemory(tensor_size, &exponent_data);
254 assert(status == Ok);
255 if (status != Ok)
256 return UnknownError;
257 // Set to zeros
258 std::memset(exponent_data, 0, tensor_size);
259 _tensor_to_exponent_avg[tensor_to_data.first] = exponent_data;
260
261 // Allocate data for exponent square calculation
262 uint8_t *exponent_square_data = nullptr;
263 status = core::memory::OMMemoryManager::allocateMemory(tensor_size, &exponent_square_data);
264 assert(status == Ok);
265 if (status != Ok)
266 return UnknownError;
267 // Set to zeros
268 std::memset(exponent_square_data, 0, tensor_size);
269 _tensor_to_exponent_avg_squares[tensor_to_data.first] = exponent_square_data;
270 }
271 }
272
273 // Check is allocated or not helper buffer
274 if (_tensor_index_to_gradient.empty())
275 {
276 // If not - let's just move it with calculations
277 // Goes over all calculated gradients
278 // Warning: assume that backward storage at this moment contains only weights gradients -
279 // This should be done due to execution plan work
280 for (auto &tensor_to_data : backward_tensor_to_data)
281 {
282 // Move data
283 _tensor_index_to_gradient[tensor_to_data.first] = tensor_to_data.second;
284 tensor_to_data.second = nullptr;
285 }
286 backward_tensor_to_data.clear();
287 }
288 else
289 {
290 // Goes over all calculated gradients
291 // Warning: assume that backward storage at this moment contains only weighs gradients -
292 // This should be done due to execution plan work
293 for (auto &tensor_to_data : backward_tensor_to_data)
294 {
295 auto tensor = context.getTensorByIndex(tensor_to_data.first);
297
298#ifndef DIS_DYN_SHAPES
299 int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_to_data.first).flatSize();
300 if (dynamic_tensor_size != 0)
301 num_elements = dynamic_tensor_size;
302#endif // DIS_DYN_SHAPES
303
304 auto *grad_data = reinterpret_cast<float *>(_tensor_index_to_gradient[tensor_to_data.first]);
305 auto *calculated_data = reinterpret_cast<float *>(tensor_to_data.second);
306
307 for (uint32_t i = 0; i < num_elements; ++i)
308 {
309 grad_data[i] += calculated_data[i];
310 }
311 }
312 }
313
314 return Ok;
315}
const circle::Tensor * getTensorByIndex(int32_t tensor_index)
OMRuntimeShape getDynamicRuntimeShape(uint16_t tensor_index)
std::unordered_map< uint16_t, uint8_t * > & getTensorIndexToData()
uint32_t num_elements(const Shape &shape)
The number of elements of a feature map of a given shape.
Definition Shape.h:59
OMDataType
"scalar" value type
Definition OMDataType.h:35
static OMStatus allocateMemory(uint32_t size, uint8_t **data)

References onert_micro::core::memory::OMMemoryManager::allocateMemory(), onert_micro::core::OMRuntimeShape::flatSize(), onert_micro::core::OMRuntimeStorage::getDynamicRuntimeShape(), onert_micro::core::OMRuntimeContext::getTensorByIndex(), onert_micro::core::OMRuntimeStorage::getTensorIndexToData(), onert_micro::Ok, and onert_micro::UnknownError.

◆ isReset()

bool onert_micro::train::optimizers::Adam::isReset ( )
inline

Definition at line 73 of file Adam.h.

74 {
75 return _tensor_to_exponent_avg_squares.empty() or _tensor_to_exponent_avg.empty();
76 }

◆ operator=() [1/2]

Adam && onert_micro::train::optimizers::Adam::operator= ( const Adam &&  )
delete

◆ operator=() [2/2]

Adam & onert_micro::train::optimizers::Adam::operator= ( const Adam &  )
delete

◆ reset()

void Adam::reset ( )

Definition at line 174 of file Adam.cpp.

175{
176 for (auto &cur_tensor_index_data : _tensor_index_to_gradient)
177 {
178 uint8_t *allocated_data = cur_tensor_index_data.second;
179
181 }
182 _tensor_index_to_gradient.clear();
183}

References onert_micro::core::memory::OMMemoryManager::deallocateMemory().

◆ setExponentAvgDataByTensorIndex()

void Adam::setExponentAvgDataByTensorIndex ( uint16_t  tensor_index,
uint8_t *  data 
)

Definition at line 203 of file Adam.cpp.

204{
205 assert(_tensor_to_exponent_avg.find(tensor_index) == _tensor_to_exponent_avg.end());
206 assert(data != nullptr);
207
208 _tensor_to_exponent_avg[tensor_index] = data;
209}

◆ setExponentAvgSquaresDataByTensorIndex()

void Adam::setExponentAvgSquaresDataByTensorIndex ( uint16_t  tensor_index,
uint8_t *  data 
)

Definition at line 211 of file Adam.cpp.

212{
213 assert(_tensor_to_exponent_avg_squares.find(tensor_index) ==
214 _tensor_to_exponent_avg_squares.end());
215 assert(data != nullptr);
216
217 _tensor_to_exponent_avg_squares[tensor_index] = data;
218}

◆ updateWeights()

OMStatus Adam::updateWeights ( const OMTrainingContext &  training_config,
core::OMRuntimeContext &  context,
core::OMRuntimeStorage &  storage,
std::unordered_map< uint16_t, core::OpTrainableRankType > &  tensor_index_to_rank_type_map 
)

Definition at line 328 of file Adam.cpp.

332{
333 assert(!_tensor_index_to_gradient.empty());
334 for (auto &tensor_to_data : _tensor_index_to_gradient)
335 {
336 auto exponent_squares_it = _tensor_to_exponent_avg_squares.find(tensor_to_data.first);
337 if (exponent_squares_it == _tensor_to_exponent_avg_squares.end())
338 return UnknownError;
339
340 auto exponent_it = _tensor_to_exponent_avg.find(tensor_to_data.first);
341 if (exponent_it == _tensor_to_exponent_avg.end())
342 return UnknownError;
343
344 auto tensor = context.getTensorByIndex(tensor_to_data.first);
345 core::OMRuntimeShape shape(tensor);
346
347 auto original_d = shape.dims(0);
348
350
351#ifndef DIS_DYN_SHAPES
352 int32_t dynamic_tensor_size = storage.getDynamicRuntimeShape(tensor_to_data.first).flatSize();
353 if (dynamic_tensor_size != 0)
354 num_elements = dynamic_tensor_size;
355#endif // DIS_DYN_SHAPES
356
357 auto *exponent_data = reinterpret_cast<float *>(exponent_it->second);
358 auto *exponent_square_data = reinterpret_cast<float *>(exponent_squares_it->second);
359 auto *calculated_data = reinterpret_cast<float *>(tensor_to_data.second);
360 float beta = training_config.beta;
361 float beta_squares = training_config.beta_squares;
362 auto batches = static_cast<float>(training_config.batch_size);
363 for (uint32_t i = 0; i < num_elements; ++i)
364 {
365 const auto cur_val = calculated_data[i];
366 exponent_data[i] = beta * exponent_data[i] + (1 - beta) * cur_val;
367 exponent_square_data[i] =
368 beta_squares * exponent_square_data[i] + (1 - beta_squares) * cur_val * cur_val;
369 }
370
371 uint8_t *weight_data = nullptr;
372 if (context.getConstDataByTensorIndex(&weight_data, tensor_to_data.first) != Ok)
373 return UnknownError;
374
375 assert(weight_data != nullptr);
376 if (weight_data == nullptr)
377 return UnknownError;
378
379 auto *f_weight_data = reinterpret_cast<float *>(weight_data);
380 float lambda = training_config.learning_rate;
381 auto num_step = static_cast<float>(training_config.num_step);
382 float beta_in_pow_batch = std::pow(beta, num_step);
383 float beta_square_in_pow_batch = std::pow(beta_squares, num_step);
384 float epsilon = training_config.epsilon;
385
386 assert((1.f - beta_in_pow_batch) != 0);
387 assert((1.f - beta_square_in_pow_batch) != 0);
388 auto train_it = tensor_index_to_rank_type_map.find(tensor_to_data.first);
389 core::OpTrainableRankType rank = train_it == tensor_index_to_rank_type_map.end()
391 : core::OpTrainableRankType(train_it->second);
392 auto depth_bounds = getUpLowerWeightTensorDepth(rank, original_d);
393
394 for (uint32_t i = 0; i < num_elements; ++i)
395 {
396 float exponent_corrected = exponent_data[i] / (1.f - beta_in_pow_batch);
397 float exponent_square_corrected = exponent_square_data[i] / (1.f - beta_square_in_pow_batch);
398 f_weight_data[i + depth_bounds.first] -=
399 lambda * (exponent_corrected / (std::sqrt(exponent_square_corrected + epsilon)));
400 }
401 }
402
403 return Ok;
404}
OMStatus getConstDataByTensorIndex(uint8_t **data, uint16_t tensor_index)
std::pair< uint32_t, uint32_t > getUpLowerWeightTensorDepth(core::OpTrainableRankType rank, const uint32_t output_depth)
Definition PALUtils.h:30

References onert_micro::core::ALL, onert_micro::OMTrainingContext::batch_size, onert_micro::OMTrainingContext::beta, onert_micro::OMTrainingContext::beta_squares, onert_micro::core::OMRuntimeShape::dims(), onert_micro::OMTrainingContext::epsilon, onert_micro::core::OMRuntimeShape::flatSize(), onert_micro::core::OMRuntimeContext::getConstDataByTensorIndex(), onert_micro::core::OMRuntimeStorage::getDynamicRuntimeShape(), onert_micro::core::OMRuntimeContext::getTensorByIndex(), onert_micro::OMTrainingContext::learning_rate, onert_micro::OMTrainingContext::num_step, onert_micro::Ok, and onert_micro::UnknownError.


The documentation for this class was generated from the following files: