ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Index.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __ONERT_IR_TRAIN_INDEX_H__
18#define __ONERT_IR_TRAIN_INDEX_H__
19
20#include "ir/Index.h"
21
22#include <cassert>
23#include <cstdint>
24#include <utility>
25
26namespace onert
27{
28namespace ir
29{
30namespace train
31{
32
/**
 * @brief Class that provides index of tensor for training.
 *
 * Pairs an index of type @p T with a flag telling whether it refers to the forward
 * (true) or backward (false) variant.
 *
 * @tparam T Wrapped index type. It must provide valid(), value() and operator==.
 */
template <typename T> class TrainingIndex
{
public:
  /**
   * @brief Construct a TrainingIndex holding a default-constructed (T{}) index, marked forward
   */
  TrainingIndex() : _index{T{}}, _is_forward{true}
  {
    // DO NOTHING
  }

  /**
   * @brief Construct a TrainingIndex object
   * @param index      Wrapped index
   * @param is_forward true for the forward variant, false for the backward variant
   */
  TrainingIndex(const T &index, bool is_forward) : _index{index}, _is_forward{is_forward}
  {
    // DO NOTHING
  }

public:
  /**
   * @brief Get the wrapped index
   * @return Reference to the wrapped index
   */
  const T &index() const { return _index; }

  /**
   * @brief Get whether the tensor is forward tensor or not
   * @return true if forward, false if backward
   */
  bool is_forward() const { return _is_forward; }

public:
  /**
   * @brief Check if the wrapped index is valid
   * @return Result of the wrapped index's valid()
   */
  bool valid() const { return _index.valid(); }

public:
  /**
   * @brief Equality: two invalid indices are always equal (direction is ignored);
   *        otherwise both the index and the direction must match
   */
  bool operator==(const TrainingIndex &other) const
  {
    return (!_index.valid() && !other.index().valid()) ||
           (_index == other.index() && _is_forward == other.is_forward());
  }

  /**
   * @brief Inequality, defined as the negation of operator==
   */
  bool operator!=(const TrainingIndex &other) const { return !(*this == other); }

  /**
   * @brief Strict weak ordering consistent with operator==
   *
   * Valid indices are ordered by (value, direction), with backward before forward.
   * All invalid indices are equivalent to one another and order after every valid index.
   * The comparison is done directly rather than through std::hash, so it neither depends on a
   * hash specialization nor suffers from hash truncation on 32-bit size_t.
   */
  bool operator<(const TrainingIndex &other) const
  {
    const bool lhs_valid = valid();
    const bool rhs_valid = other.valid();

    // operator== treats every invalid index as equal, so they must be equivalent here too.
    if (!lhs_valid || !rhs_valid)
      return lhs_valid && !rhs_valid;

    const auto lhs_value = _index.value();
    const auto rhs_value = other._index.value();
    if (lhs_value != rhs_value)
      return lhs_value < rhs_value;

    // Same index value: backward (false) sorts before forward (true).
    return !_is_forward && other._is_forward;
  }

private:
  T _index;
  bool _is_forward;
};
115
124
133
134inline std::ostream &operator<<(std::ostream &o, const TrainingOperationIndex &i)
135{
136 return operator<<(o, i.index());
137}
138
139inline std::ostream &operator<<(std::ostream &o, const TrainingOperandIndex &i)
140{
141 return operator<<(o, i.index());
142}
143
144} // namespace train
145} // namespace ir
146} // namespace onert
147
148namespace std
149{
150
154template <> struct hash<onert::ir::train::TrainingOperationIndex>
155{
156 size_t operator()(const onert::ir::train::TrainingOperationIndex &index) const noexcept
157 {
158 const auto &op_index = index.index();
159 const bool is_forward = index.is_forward();
160
161 assert(sizeof(op_index) <= 4);
162 assert((op_index.undefined() || op_index.value() < (1 << 16)) &&
163 "TrainingOperationIndex's hash creation error, operand_index is too big");
164 static_assert(
165 sizeof(size_t) >= sizeof(uint32_t),
166 "TrainingOperationIndex's hash creation error, size_t size is less than uint32_t");
167
168 return (static_cast<size_t>(op_index.value())) << 16 | static_cast<size_t>(is_forward);
169 }
170};
171
172} // namespace std
173
174namespace std
175{
176
180template <> struct hash<onert::ir::train::TrainingOperandIndex>
181{
182 size_t operator()(const onert::ir::train::TrainingOperandIndex &index) const noexcept
183 {
184 const auto &operand_index = index.index();
185 const bool &is_forward = index.is_forward();
186
187 assert(sizeof(operand_index) <= 4);
188 assert((operand_index.undefined() || operand_index.value() < (1 << 16)) &&
189 "TrainingOperandIndex's hash creation error, operand_index is too big");
190 static_assert(sizeof(size_t) >= sizeof(uint32_t),
191 "TrainingOperandIndex's hash creation error, size_t size is less than uint32_t");
192
193 return (static_cast<size_t>(operand_index.value())) << 16 | static_cast<size_t>(is_forward);
194 }
195};
196
197} // namespace std
198
199#endif // __ONERT_IR_TRAIN_INDEX_H__
Class that provides index of tensor for training.
Definition Index.h:38
TrainingIndex(const T &index, bool is_forward)
Construct TrainingIndex object from an index and a forward/backward flag.
Definition Index.h:56
TrainingIndex()
Construct TrainingIndex object with a default index, marked as forward.
Definition Index.h:45
bool operator<(const TrainingIndex &other) const
operator overloading function for <
Definition Index.h:106
bool is_forward() const
Get whether the tensor is forward tensor or not.
Definition Index.h:73
bool operator==(const TrainingIndex &other) const
operator overloading function for ==
Definition Index.h:89
const T & index() const
Get index.
Definition Index.h:67
bool operator!=(const TrainingIndex &other) const
operator overloading function for !=
Definition Index.h:99
bool valid() const
Check if the index is valid or not.
Definition Index.h:81
loco::GraphInputIndex index(const TFPlaceholder *node)
Definition TFNode.cpp:54
std::ostream & operator<<(std::ostream &o, const TrainingOperationIndex &i)
Definition Index.h:134
size_t operator()(const onert::ir::train::TrainingOperandIndex &index) const noexcept
Definition Index.h:182