ONE - On-device Neural Engine
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
Index.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __ONERT_IR_TRAIN_INDEX_H__
18#define __ONERT_IR_TRAIN_INDEX_H__
19
20#include "ir/Index.h"
21
22#include <cassert>
23#include <cstdint>
24#include <utility>
25
26namespace onert::ir::train
27{
28
/**
 * @brief Class that provides index of tensor for training.
 *
 * Pairs an index of type @p T with a flag telling whether it refers to a forward tensor.
 *
 * @tparam T Wrapped index type. It must be default-constructible and provide valid(),
 *           value() and operator==.
 */
template <typename T> class TrainingIndex
{
public:
  /**
   * @brief Construct a TrainingIndex holding a default-constructed T with the forward flag set
   */
  TrainingIndex() : _index{T{}}, _is_forward{true}
  {
    // DO NOTHING
  }

  /**
   * @brief Construct a TrainingIndex from an index and a forward flag
   *
   * @param index      Wrapped index
   * @param is_forward Whether the index refers to a forward tensor
   */
  TrainingIndex(const T &index, bool is_forward) : _index{index}, _is_forward{is_forward}
  {
    // DO NOTHING
  }

public:
  /**
   * @brief Get index
   * @return Wrapped index
   */
  const T &index() const { return _index; }

  /**
   * @brief Get whether the tensor is forward tensor or not
   * @return true if forward tensor, false otherwise
   */
  bool is_forward() const { return _is_forward; }

public:
  /**
   * @brief Check if the index is valid or not
   * @return true if the wrapped index is valid
   */
  bool valid() const { return _index.valid(); }

public:
  /**
   * @brief operator overloading function for ==
   *
   * All invalid indices compare equal regardless of their forward flag; valid indices are
   * equal when both the wrapped index and the forward flag match.
   */
  bool operator==(const TrainingIndex &other) const
  {
    return (!_index.valid() && !other.index().valid()) ||
           (_index == other.index() && _is_forward == other.is_forward());
  }

  /**
   * @brief operator overloading function for !=
   */
  bool operator!=(const TrainingIndex &other) const { return !(*this == other); }

  /**
   * @brief operator overloading function for <
   *
   * Orders lexicographically by (index value, forward flag), with non-forward before forward.
   * To stay consistent with operator==, all invalid indices are equivalent to each other
   * (neither is less than the other) and they sort after every valid index.
   */
  bool operator<(const TrainingIndex &other) const
  {
    const bool lhs_valid = _index.valid();
    const bool rhs_valid = other.index().valid();

    // Invalid indices are all equal under operator==, so they must be equivalent here too
    if (!lhs_valid || !rhs_valid)
      return lhs_valid && !rhs_valid;

    if (_index.value() != other.index().value())
      return _index.value() < other.index().value();

    return !_is_forward && other.is_forward();
  }

private:
  T _index;
  bool _is_forward;
};
111
120
129
130inline std::ostream &operator<<(std::ostream &o, const TrainingOperationIndex &i)
131{
132 return operator<<(o, i.index());
133}
134
135inline std::ostream &operator<<(std::ostream &o, const TrainingOperandIndex &i)
136{
137 return operator<<(o, i.index());
138}
139
140} // namespace onert::ir::train
141
142namespace std
143{
144
148template <> struct hash<onert::ir::train::TrainingOperationIndex>
149{
150 size_t operator()(const onert::ir::train::TrainingOperationIndex &index) const noexcept
151 {
152 const auto &op_index = index.index();
153 const bool is_forward = index.is_forward();
154
155 assert(sizeof(op_index) <= 4);
156 assert((op_index.undefined() || op_index.value() < (1 << 16)) &&
157 "TrainingOperationIndex's hash creation error, operand_index is too big");
158 static_assert(
159 sizeof(size_t) >= sizeof(uint32_t),
160 "TrainingOperationIndex's hash creation error, size_t size is less than uint32_t");
161
162 return (static_cast<size_t>(op_index.value())) << 16 | static_cast<size_t>(is_forward);
163 }
164};
165
166} // namespace std
167
168namespace std
169{
170
174template <> struct hash<onert::ir::train::TrainingOperandIndex>
175{
176 size_t operator()(const onert::ir::train::TrainingOperandIndex &index) const noexcept
177 {
178 const auto &operand_index = index.index();
179 const bool &is_forward = index.is_forward();
180
181 assert(sizeof(operand_index) <= 4);
182 assert((operand_index.undefined() || operand_index.value() < (1 << 16)) &&
183 "TrainingOperandIndex's hash creation error, operand_index is too big");
184 static_assert(sizeof(size_t) >= sizeof(uint32_t),
185 "TrainingOperandIndex's hash creation error, size_t size is less than uint32_t");
186
187 return (static_cast<size_t>(operand_index.value())) << 16 | static_cast<size_t>(is_forward);
188 }
189};
190
191} // namespace std
192
193#endif // __ONERT_IR_TRAIN_INDEX_H__
Class that provides index of tensor for training.
Definition Index.h:34
TrainingIndex(const T &index, bool is_forward)
Construct a TrainingIndex object from an index and a forward flag.
Definition Index.h:52
TrainingIndex()
Construct a TrainingIndex object holding a default-constructed index with the forward flag set.
Definition Index.h:41
bool operator<(const TrainingIndex &other) const
operator overloading function for <
Definition Index.h:102
bool is_forward() const
Get whether the tensor is forward tensor or not.
Definition Index.h:69
bool operator==(const TrainingIndex &other) const
operator overloading function for ==
Definition Index.h:85
const T & index() const
Get index.
Definition Index.h:63
bool operator!=(const TrainingIndex &other) const
operator overloading function for !=
Definition Index.h:95
bool valid() const
Check if the index is valid or not.
Definition Index.h:77
std::ostream & operator<<(std::ostream &o, const TrainingOperationIndex &i)
Definition Index.h:130
size_t operator()(const onert::ir::train::TrainingOperandIndex &index) const noexcept
Definition Index.h:176