17#ifndef __ONERT_IR_TRAIN_INDEX_H__
18#define __ONERT_IR_TRAIN_INDEX_H__
// Get the wrapped index.
// Returns a const reference to the member _index, so no copy of the index is made.
// NOTE(review): the leading "63" appears to be a source-listing line number fused in by
// extraction, not part of the declaration — confirm and remove.
63 const T &
index()
const {
return _index; }
// Check whether the index is valid.
// Delegates entirely to _index.valid(); the forward flag plays no part in validity.
// NOTE(review): the leading "77" appears to be a source-listing line number fused in by
// extraction, not part of the declaration — confirm and remove.
77 bool valid()
const {
return _index.valid(); }
87 return (!_index.valid() && !other.
index().valid()) ||
104 return std::hash<TrainingIndex<T>>{}(*this) < std::hash<TrainingIndex<T>>{}(other);
152 const auto &op_index = index.index();
153 const bool is_forward = index.is_forward();
155 assert(
sizeof(op_index) <= 4);
156 assert((op_index.undefined() || op_index.value() < (1 << 16)) &&
157 "TrainingOperationIndex's hash creation error, operand_index is too big");
159 sizeof(size_t) >=
sizeof(uint32_t),
160 "TrainingOperationIndex's hash creation error, size_t size is less than uint32_t");
162 return (
static_cast<size_t>(op_index.value())) << 16 |
static_cast<size_t>(is_forward);
174template <>
struct hash<
onert::ir::train::TrainingOperandIndex>
178 const auto &operand_index = index.index();
179 const bool &is_forward = index.is_forward();
181 assert(
sizeof(operand_index) <= 4);
182 assert((operand_index.undefined() || operand_index.value() < (1 << 16)) &&
183 "TrainingOperandIndex's hash creation error, operand_index is too big");
184 static_assert(
sizeof(size_t) >=
sizeof(uint32_t),
185 "TrainingOperandIndex's hash creation error, size_t size is less than uint32_t");
187 return (
static_cast<size_t>(operand_index.value())) << 16 |
static_cast<size_t>(is_forward);
174template <>
struct hash<
onert::ir::train::TrainingOperandIndex> {
…};
Class template that provides the index of a tensor for training, together with whether that tensor is a forward tensor.
TrainingIndex(const T &index, bool is_forward)
Construct a TrainingIndex object from the given index and forward flag.
TrainingIndex()
Construct a default TrainingIndex object.
bool operator<(const TrainingIndex &other) const
Less-than operator; orders indices by their std::hash values.
bool is_forward() const
Get whether the tensor is a forward tensor.
bool operator==(const TrainingIndex &other) const
Equality comparison operator.
const T & index() const
Get the wrapped index.
bool operator!=(const TrainingIndex &other) const
Inequality comparison operator.
bool valid() const
Check whether the index is valid.
std::ostream & operator<<(std::ostream &o, const TrainingOperationIndex &i)
size_t operator()(const onert::ir::train::TrainingOperandIndex &index) const noexcept