17#ifndef __ONERT_IR_TRAIN_INDEX_H__
18#define __ONERT_IR_TRAIN_INDEX_H__
67 const T &
index()
const {
return _index; }
81 bool valid()
const {
return _index.valid(); }
91 return (!_index.valid() && !other.
index().valid()) ||
108 return std::hash<TrainingIndex<T>>{}(*this) < std::hash<TrainingIndex<T>>{}(other);
154template <>
struct hash<
onert::ir::train::TrainingOperationIndex>
158 const auto &op_index =
index.index();
159 const bool is_forward =
index.is_forward();
161 assert(
sizeof(op_index) <= 4);
162 assert((op_index.undefined() || op_index.value() < (1 << 16)) &&
163 "TrainingOperationIndex's hash creation error, operand_index is too big");
165 sizeof(size_t) >=
sizeof(uint32_t),
166 "TrainingOperationIndex's hash creation error, size_t size is less than uint32_t");
168 return (
static_cast<size_t>(op_index.value())) << 16 |
static_cast<size_t>(is_forward);
180template <>
struct hash<
onert::ir::train::TrainingOperandIndex>
184 const auto &operand_index = index.index();
185 const bool &is_forward = index.is_forward();
187 assert(
sizeof(operand_index) <= 4);
188 assert((operand_index.undefined() || operand_index.value() < (1 << 16)) &&
189 "TrainingOperandIndex's hash creation error, operand_index is too big");
190 static_assert(
sizeof(size_t) >=
sizeof(uint32_t),
191 "TrainingOperandIndex's hash creation error, size_t size is less than uint32_t");
193 return (
static_cast<size_t>(operand_index.value())) << 16 |
static_cast<size_t>(is_forward);
Class that provides the index of a tensor for training.
TrainingIndex(const T &index, bool is_forward)
Construct a TrainingIndex object from an index and a forward flag.
TrainingIndex()
Default-construct a TrainingIndex object.
bool operator<(const TrainingIndex &other) const
Overloaded less-than operator (<).
bool is_forward() const
Get whether the tensor is a forward tensor or not.
bool operator==(const TrainingIndex &other) const
Overloaded equality operator (==).
const T & index() const
Get index.
bool operator!=(const TrainingIndex &other) const
Overloaded inequality operator (!=).
bool valid() const
Check if the index is valid or not.
loco::GraphInputIndex index(const TFPlaceholder *node)
std::ostream & operator<<(std::ostream &o, const TrainingOperationIndex &i)
size_t operator()(const onert::ir::train::TrainingOperandIndex &index) const noexcept