17#ifndef ONERT_MICRO_TRAIN_TRAIN_OPTIMIZERS_ADAM_H
18#define ONERT_MICRO_TRAIN_TRAIN_OPTIMIZERS_ADAM_H
26#include <unordered_map>
43 std::unordered_map<uint16_t, uint8_t *> _tensor_to_exponent_avg_squares;
45 std::unordered_map<uint16_t, uint8_t *> _tensor_to_exponent_avg;
47 std::unordered_map<uint16_t, uint8_t *> _tensor_index_to_gradient;
57#ifdef OM_MEMORY_ESTIMATE
75 return _tensor_to_exponent_avg_squares.empty() or _tensor_to_exponent_avg.empty();
92 std::unordered_map<uint16_t, core::OpTrainableRankType> &);
Adam & operator=(const Adam &)=delete
uint8_t * getExponentAvgSquaresDataByTensorIndex(uint16_t tensor_index)
void setExponentAvgDataByTensorIndex(uint16_t tensor_index, uint8_t *data)
OMStatus handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context, core::OMRuntimeStorage &storage)
void setExponentAvgSquaresDataByTensorIndex(uint16_t tensor_index, uint8_t *data)
Adam(const Adam &)=delete
OMStatus updateWeights(const OMTrainingContext &training_config, core::OMRuntimeContext &context, core::OMRuntimeStorage &storage, std::unordered_map< uint16_t, core::OpTrainableRankType > &)
uint8_t * getExponentAvgDataByTensorIndex(uint16_t tensor_index)
Adam && operator=(const Adam &&)=delete