ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Shape.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_RUY_SHAPE_H__
19#define __NNFW_RUY_SHAPE_H__
20
21#include <algorithm>
22#include <cstring>
23#include <cassert>
24#include <vector>
25
26namespace nnfw
27{
28namespace ruy
29{
30
// Dimension list of a tensor.
//
// Up to kMaxSmallSize dimensions are stored inline in the object; longer
// dimension lists are stored in a heap array that the Shape owns and releases
// in its destructor. Which storage is in use is decided solely by _size.
class Shape
{
public:
  // Shapes with dimensions up to 5 are stored directly in the structure, while
  // larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 5;

  Shape &operator=(Shape const &) = delete;

  // Creates an empty (0-dimensional) shape.
  Shape() : _size(0) {}

  // Creates a shape with `dimensions_count` dimensions whose values are left
  // unset; the caller is expected to fill them with SetDim() or DimsData().
  explicit Shape(int dimensions_count) : _size(dimensions_count)
  {
    if (dimensions_count > kMaxSmallSize)
    {
      _dims_pointer = new int32_t[dimensions_count];
    }
  }

  // Creates a shape with `shape_size` dimensions, each set to `value`.
  Shape(int shape_size, int32_t value) : _size(0)
  {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i)
    {
      SetDim(i, value);
    }
  }

  // Creates a shape by copying `dimensions_count` values from `dims_data`.
  Shape(int dimensions_count, const int32_t *dims_data) : _size(0)
  {
    ReplaceWith(dimensions_count, dims_data);
  }

  // Creates a shape from a braced list of dimensions, e.g. Shape{1, 2, 3}.
  Shape(const std::initializer_list<int> init_list) : _size(0) { BuildFrom(init_list); }

  // Avoid using this constructor. We should be able to delete it when C++17
  // rolls out.
  Shape(Shape const &other) : _size(other.DimensionsCount())
  {
    if (_size > kMaxSmallSize)
    {
      _dims_pointer = new int32_t[_size];
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * _size);
  }

  // Two shapes are equal when they have the same rank and identical dimensions.
  bool operator==(const Shape &comp) const
  {
    return this->_size == comp._size &&
           std::memcmp(DimsData(), comp.DimsData(), _size * sizeof(int32_t)) == 0;
  }

  // Releases the heap array when the shape is too large for inline storage.
  ~Shape()
  {
    if (_size > kMaxSmallSize)
    {
      delete[] _dims_pointer;
    }
  }

  // Number of dimensions (rank).
  inline int32_t DimensionsCount() const { return _size; }

  // Returns dimension `i`; `i` must be in [0, DimensionsCount()).
  inline int32_t Dims(int i) const
  {
    assert(i >= 0);
    assert(i < _size);
    return _size > kMaxSmallSize ? _dims_pointer[i] : _dims[i];
  }

  // Sets dimension `i` to `val`; `i` must be in [0, DimensionsCount()).
  inline void SetDim(int i, int32_t val)
  {
    assert(i >= 0);
    assert(i < _size);
    if (_size > kMaxSmallSize)
    {
      _dims_pointer[i] = val;
    }
    else
    {
      _dims[i] = val;
    }
  }

  // Pointer to the first dimension, in whichever storage is currently active.
  inline int32_t *DimsData() { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
  inline const int32_t *DimsData() const { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
  // The caller must ensure that the shape is no bigger than 4-D.
  // Always returns the inline storage, without checking the current size.
  inline const int32_t *DimsDataUpTo4D() const { return _dims; }

  // Changes the rank to `dimensions_count`. Any heap array is released first
  // and a new one is allocated when needed, so existing dimension values must
  // not be relied upon after this call.
  inline void Resize(int dimensions_count)
  {
    if (_size > kMaxSmallSize)
    {
      delete[] _dims_pointer;
    }
    _size = dimensions_count;
    if (dimensions_count > kMaxSmallSize)
    {
      _dims_pointer = new int32_t[dimensions_count];
    }
  }

  // Replaces the contents with `dimensions_count` values copied from
  // `dims_data`. `dims_data` must not point into this shape's own storage.
  inline void ReplaceWith(int dimensions_count, const int32_t *dims_data)
  {
    Resize(dimensions_count);
    // memcpy requires valid pointers even for a zero-byte copy, and an empty
    // source (e.g. an empty std::vector) may legitimately hand us nullptr.
    if (dimensions_count > 0)
    {
      std::memcpy(DimsData(), dims_data, dimensions_count * sizeof(int32_t));
    }
  }

  // Replaces the contents with a copy of `other`.
  inline void ReplaceWith(const Shape &other)
  {
    // Resize() would free the very buffer we are about to copy from.
    if (this == &other)
    {
      return;
    }
    ReplaceWith(other.DimensionsCount(), other.DimsData());
  }

  // Takes over the contents of `other`, leaving `other` empty.
  inline void ReplaceWith(Shape &&other)
  {
    // Resize(0) below would throw away the data we are supposed to keep.
    if (this == &other)
    {
      return;
    }
    Resize(0);
    // After the swap `other` has size 0, so its destructor will not free the
    // heap array whose ownership is transferred below.
    std::swap(_size, other._size);
    if (_size <= kMaxSmallSize)
      std::copy_n(other._dims, _size, _dims); // only the live entries are initialised
    else
      _dims_pointer = other._dims_pointer;
  }

  // Rebuilds the shape from any iterable exposing begin()/end() whose elements
  // are convertible to int32_t.
  template <typename T> inline void BuildFrom(const T &src_iterable)
  {
    const int dimensions_count =
      static_cast<int>(std::distance(src_iterable.begin(), src_iterable.end()));
    Resize(dimensions_count);
    int32_t *data = DimsData();
    for (auto &&it : src_iterable)
    {
      *data = it;
      ++data;
    }
  }

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
  // reduced, and (b) some kernels are strictly 4-D, but then the shapes of their
  // inputs should already be 4-D, so this function should not be needed.
  //
  // Returns `shape` left-padded with 1s up to `new_shape_size` dimensions.
  inline static Shape ExtendedShape(int new_shape_size, const Shape &shape)
  {
    return Shape(new_shape_size, shape, 1);
  }

  // Rebuilds the shape from a braced list of dimensions.
  inline void BuildFrom(const std::initializer_list<int> init_list)
  {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened into a
  // vector. An empty shape yields 1. Every dimension is asserted to be >= 1.
  inline int FlatSize() const
  {
    int buffer_size = 1;
    const int *dims_data = DimsData();
    for (int i = 0; i < _size; i++)
    {
      const int dim = dims_data[i];
      assert(dim >= 1);
      buffer_size *= dim;
    }
    return buffer_size;
  }

  bool operator!=(const Shape &comp) const { return !((*this) == comp); }

private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  Shape(int new_shape_size, const Shape &shape, int pad_value) : _size(0)
  {
    assert(new_shape_size >= shape.DimensionsCount());
    assert(new_shape_size <= kMaxSmallSize);
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    // Padding goes in front; the original dimensions keep their relative order.
    for (int i = 0; i < size_increase; ++i)
    {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32_t) * shape.DimensionsCount());
  }

  int32_t _size; // Number of dimensions; also selects the active union member.
  union {
    int32_t _dims[kMaxSmallSize];    // Inline storage, used when _size <= kMaxSmallSize.
    int32_t *_dims_pointer{nullptr}; // Owned heap array, used when _size > kMaxSmallSize.
  };
};
220
221inline int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
222 [[maybe_unused]] int index2)
223{
224 assert(shape1.Dims(index1) == shape2.Dims(index2));
225 return shape1.Dims(index1);
226}
227
228template <typename... Args>
229int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
230 [[maybe_unused]] int index2, Args... args)
231{
232 assert(shape1.Dims(index1) == shape2.Dims(index2));
233 return MatchingDim(shape1, index1, args...);
234}
235
236inline Shape GetShape(const std::vector<int32_t> &data) { return Shape(data.size(), data.data()); }
237
238inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
239{
240 assert(shape.DimensionsCount() == 4);
241 const int *dims_data = shape.DimsDataUpTo4D();
242 assert(i0 >= 0 && i0 < dims_data[0]);
243 assert(i1 >= 0 && i1 < dims_data[1]);
244 assert(i2 >= 0 && i2 < dims_data[2]);
245 assert(i3 >= 0 && i3 < dims_data[3]);
246 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
247}
248
249inline int Offset(const Shape &shape, int *index)
250{
251 return Offset(shape, index[0], index[1], index[2], index[3]);
252}
253
254inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
255{
256 const int dims_count = shape.DimensionsCount();
257 assert(skip_dim >= 0 && skip_dim < dims_count);
258 const auto *dims_data = shape.DimsData();
259 int flat_size = 1;
260 for (int i = 0; i < dims_count; ++i)
261 {
262 flat_size *= (i == skip_dim) ? 1 : dims_data[i];
263 }
264 return flat_size;
265}
266
267// Flat size calculation, checking that dimensions match with one or more other
268// arrays.
269template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes)
270{
271 const Shape check_shapes_array[sizeof...(Ts)] = {std::forward<Ts>(check_shapes)...};
272 for (const auto &check_shape : check_shapes_array)
273 {
274 // Check matching of shapes except the case of that two shapes can be scalar
275 if (shape.DimensionsCount() > 1 || check_shape.DimensionsCount() > 1 || shape.FlatSize() != 1 ||
276 check_shape.FlatSize() != 1)
277 {
278 if (shape.DimensionsCount() != check_shape.DimensionsCount())
279 {
280 return false;
281 }
282 for (int i = 0; i < shape.DimensionsCount(); ++i)
283 {
284 if (shape.Dims(i) != check_shape.Dims(i))
285 {
286 return false;
287 }
288 }
289 }
290 }
291 return true;
292}
293
// Helper for silencing "unused parameter" warnings on a whole parameter pack:
// constructing a temporary UNUSED_ALL{args...} references every argument and
// does nothing else.
struct UNUSED_ALL
{
  template <typename... Args> UNUSED_ALL(Args const &...) {}
};
298template <typename... Ts> inline int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
299{
300 UNUSED_ALL{check_shapes...};
301 assert(checkMatching(shape, std::forward<Ts>(check_shapes)...));
302 return shape.FlatSize();
303}
304
305inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim,
306 [[maybe_unused]] const Shape &check_shape_0)
307{
308 const int dims_count = shape.DimensionsCount();
309 for (int i = 0; i < dims_count; ++i)
310 {
311 if (i != skip_dim)
312 {
313 assert(shape.Dims(i) == check_shape_0.Dims(i));
314 }
315 }
316 return FlatSizeSkipDim(shape, skip_dim);
317}
318
319inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim,
320 [[maybe_unused]] const Shape &check_shape_0,
321 const Shape &check_shape_1)
322{
323 const int dims_count = shape.DimensionsCount();
324 for (int i = 0; i < dims_count; ++i)
325 {
326 if (i != skip_dim)
327 {
328 assert(shape.Dims(i) == check_shape_0.Dims(i));
329 }
330 }
331 return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
332}
333
334inline int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0,
335 const Shape &check_shape_1)
336{
337 const int size_1 = shape.FlatSize();
338 [[maybe_unused]] const int size_2 = check_shape_0.FlatSize();
339 [[maybe_unused]] const int size_3 = check_shape_1.FlatSize();
340 assert(size_1 == size_2);
341 assert(size_2 == size_3);
342 return size_1;
343}
344
345} // namespace ruy
346} // namespace nnfw
347
348#endif // __NNFW_RUY_SHAPE_H__
int32_t * DimsData()
Definition Shape.h:112
int FlatSize() const
Definition Shape.h:181
int32_t _dims[kMaxSmallSize]
Definition Shape.h:216
bool operator==(const Shape &comp) const
Definition Shape.h:77
Shape(int dimensions_count, const int32_t *dims_data)
Definition Shape.h:59
void Resize(int dimensions_count)
Definition Shape.h:117
void BuildFrom(const T &src_iterable)
Definition Shape.h:152
void BuildFrom(const std::initializer_list< int > init_list)
Definition Shape.h:174
Shape(Shape const &other)
Definition Shape.h:68
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
Definition Shape.h:130
void ReplaceWith(const Shape &other)
Definition Shape.h:137
Shape(int shape_size, int32_t value)
Definition Shape.h:50
Shape(const std::initializer_list< int > init_list)
Definition Shape.h:64
const int32_t * DimsDataUpTo4D() const
Definition Shape.h:115
int32_t DimensionsCount() const
Definition Shape.h:91
Shape & operator=(Shape const &)=delete
int32_t Dims(int i) const
Definition Shape.h:92
int32_t * _dims_pointer
Definition Shape.h:217
void ReplaceWith(Shape &&other)
Definition Shape.h:142
static Shape ExtendedShape(int new_shape_size, const Shape &shape)
Definition Shape.h:169
Shape(int dimensions_count)
Definition Shape.h:42
const int32_t * DimsData() const
Definition Shape.h:113
bool operator!=(const Shape &comp) const
Definition Shape.h:194
static constexpr int kMaxSmallSize
Definition Shape.h:36
void SetDim(int i, int32_t val)
Definition Shape.h:98
Shape GetShape(const std::vector< int32_t > &data)
Definition Shape.h:236
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
Definition Shape.h:238
int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
Definition Shape.h:305
int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0, const Shape &check_shape_1)
Definition Shape.h:334
int FlatSizeSkipDim(const Shape &shape, int skip_dim)
Definition Shape.h:254
bool checkMatching(const Shape &shape, Ts... check_shapes)
Definition Shape.h:269
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
Definition Shape.h:298
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
Definition Shape.h:221
Definition topk_v2.h:30
UNUSED_ALL(Args const &...)
Definition Shape.h:296