ONE - On-device Neural Engine
Loading...
Searching...
No Matches
IndexEnumerator.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
23#ifndef __NNFW_MISC_TENSOR_INDEX_ENUMERATOR_H__
24#define __NNFW_MISC_TENSOR_INDEX_ENUMERATOR_H__
25
26#include "misc/tensor/Shape.h"
27#include "misc/tensor/Index.h"
28
29namespace nnfw
30{
31namespace misc
32{
33namespace tensor
34{
40{
41public:
46 explicit IndexEnumerator(const Shape &shape) : _shape(shape), _cursor(0), _index(shape.rank())
47 {
48 const uint32_t rank = _shape.rank();
49
50 for (uint32_t axis = 0; axis < rank; ++axis)
51 {
52 _index.at(axis) = 0;
53 }
54
55 for (_cursor = 0; _cursor < rank; ++_cursor)
56 {
57 if (_index.at(_cursor) < _shape.dim(_cursor))
58 {
59 break;
60 }
61 }
62 }
63
64public:
73
74public:
79 bool valid(void) const { return _cursor < _shape.rank(); }
80
81public:
86 const Index &curr(void) const { return _index; }
87
88public:
92 void advance(void)
93 {
94 const uint32_t rank = _shape.rank();
95
96 // Find axis to be updated
97 while ((_cursor < rank) && !(_index.at(_cursor) + 1 < _shape.dim(_cursor)))
98 {
99 ++_cursor;
100 }
101
102 if (_cursor == rank)
103 {
104 return;
105 }
106
107 // Update index
108 _index.at(_cursor) += 1;
109
110 for (uint32_t axis = 0; axis < _cursor; ++axis)
111 {
112 _index.at(axis) = 0;
113 }
114
115 // Update cursor
116 _cursor = 0;
117 }
118
119public:
120 const Shape _shape;
121
122private:
123 uint32_t _cursor;
124 Index _index;
125};
126
127} // namespace tensor
128} // namespace misc
129} // namespace nnfw
130
131#endif // __NNFW_MISC_TENSOR_INDEX_ENUMERATOR_H__
Class to enumerate the indices of a tensor.
bool valid(void) const
Check if more enumeration is available.
IndexEnumerator(IndexEnumerator &&)=delete
Prevent constructing an IndexEnumerator object from an rvalue reference (move construction is disabled).
void advance(void)
Advance the index to the next position.
IndexEnumerator(const IndexEnumerator &)=delete
Prevent copy constructor.
const Index & curr(void) const
Get the current index of the enumeration.
IndexEnumerator(const Shape &shape)
Construct a new IndexEnumerator object.
const Shape _shape
Shape to enumerate.
Class to represent shape of a tensor.
Definition Shape.h:45
int32_t dim(uint32_t n) const
Get specific dimension.
Definition Shape.h:100
uint32_t rank(void) const
Get the rank of this shape.
Definition Shape.h:92
Definition topk_v2.h:30
This file contains nnfw::misc::tensor::Index struct.
This file contains nnfw::misc::tensor::Shape class.
Struct to represent the index along each dimension of a tensor.
Definition Index.h:42
int32_t at(uint32_t n) const
Get the index of the n'th dimension.
Definition Index.h:75