ONE - On-device Neural Engine
Loading...
Searching...
No Matches
SGD.h
Go to the documentation of this file.
1
/*
2
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3
*
4
* Licensed under the Apache License, Version 2.0 (the "License");
5
* you may not use this file except in compliance with the License.
6
* You may obtain a copy of the License at
7
*
8
* http://www.apache.org/licenses/LICENSE-2.0
9
*
10
* Unless required by applicable law or agreed to in writing, software
11
* distributed under the License is distributed on an "AS IS" BASIS,
12
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
* See the License for the specific language governing permissions and
14
* limitations under the License.
15
*/
16
17
#ifndef ONERT_MICRO_TRAIN_TRAIN_OPTIMIZERS_SGD_H
18
#define ONERT_MICRO_TRAIN_TRAIN_OPTIMIZERS_SGD_H
19
20
#include "
OMStatus.h
"
21
#include "
core/OMRuntimeStorage.h
"
22
#include "
core/OMRuntimeContext.h
"
23
24
#include <cstdint>
25
#include <unordered_map>
26
27
namespace
onert_micro
28
{
29
namespace
train
30
{
31
namespace
optimizers
32
{
33
34
/*
35
* Class to handle Adam optimizer
36
*/
37
class
SGD
38
{
39
private
:
40
// Save mapping between tensor index and internal state data with calculated gradients
41
std::unordered_map<uint16_t, uint8_t *> _tensor_index_to_gradient;
42
43
public
:
44
SGD
() =
default
;
45
SGD
(
const
SGD
&) =
delete
;
46
SGD
(
SGD
&&) =
delete
;
47
SGD
&
operator=
(
const
SGD
&) =
delete
;
48
SGD
&&
operator=
(
const
SGD
&&) =
delete
;
49
~SGD
() {
reset
(); }
50
51
#ifdef OM_MEMORY_ESTIMATE
52
// Reset and deallocate all internal states
53
void
reset
(
core::OMRuntimeContext
&context,
core::OMRuntimeStorage
&storage);
54
#endif
// OM_MEMORY_ESTIMATE
55
56
// Reset and deallocate all internal states
57
void
reset
();
58
59
// Update internal states according to SGD theory
60
OMStatus
handle
(
core::OMRuntimeStorage
&backward_storage,
core::OMRuntimeContext
&context,
61
core::OMRuntimeStorage
&storage);
62
63
// Update weights according to SGD theory
64
OMStatus
updateWeights
(
65
const
OMTrainingContext
&training_config,
core::OMRuntimeContext
&context,
66
core::OMRuntimeStorage
&storage,
67
std::unordered_map<uint16_t, core::OpTrainableRankType> &tensor_index_to_rank_type_map);
68
};
69
70
}
// namespace optimizers
71
}
// namespace train
72
}
// namespace onert_micro
73
74
#endif
// ONERT_MICRO_TRAIN_TRAIN_OPTIMIZERS_SGD_H
OMRuntimeContext.h
OMRuntimeStorage.h
OMStatus.h
onert_micro::core::OMRuntimeContext
Definition
OMRuntimeContext.h:37
onert_micro::core::OMRuntimeStorage
Definition
OMRuntimeStorage.h:34
onert_micro::train::optimizers::SGD
Definition
SGD.h:38
onert_micro::train::optimizers::SGD::handle
OMStatus handle(core::OMRuntimeStorage &backward_storage, core::OMRuntimeContext &context, core::OMRuntimeStorage &storage)
Definition
SGD.cpp:94
onert_micro::train::optimizers::SGD::~SGD
~SGD()
Definition
SGD.h:49
onert_micro::train::optimizers::SGD::SGD
SGD(const SGD &)=delete
onert_micro::train::optimizers::SGD::operator=
SGD && operator=(const SGD &&)=delete
onert_micro::train::optimizers::SGD::reset
void reset()
Definition
SGD.cpp:79
onert_micro::train::optimizers::SGD::updateWeights
OMStatus updateWeights(const OMTrainingContext &training_config, core::OMRuntimeContext &context, core::OMRuntimeStorage &storage, std::unordered_map< uint16_t, core::OpTrainableRankType > &tensor_index_to_rank_type_map)
Definition
SGD.cpp:147
onert_micro::train::optimizers::SGD::SGD
SGD(SGD &&)=delete
onert_micro::train::optimizers::SGD::SGD
SGD()=default
onert_micro::train::optimizers::SGD::operator=
SGD & operator=(const SGD &)=delete
onert_micro
Definition
OMMemoryManager.h:26
onert_micro::OMStatus
OMStatus
Definition
OMStatus.h:24
onert_micro::OMTrainingContext
Definition
OMConfig.h:75
onert-micro
onert-micro
include
train
train_optimizers
SGD.h
Generated by
1.9.8