ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Shape.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_SHAPE_H__
19#define __NNFW_CKER_SHAPE_H__
20
21#include <algorithm>
22#include <array>
23#include <cassert>
24#include <cstring>
25#include <iterator>
26#include <variant>
27#include <vector>
28
29namespace nnfw
30{
31namespace cker
32{
33
class Shape
{
public:
  // Shapes with up to kMaxSmallSize dimensions store their extents inline in a
  // std::array; larger shapes store them in a heap-allocated std::vector.
  //
  // Class invariant relied upon by every accessor below: dims_ holds the
  // std::array alternative exactly when _size <= kMaxSmallSize, and the
  // std::vector alternative (holding _size elements) otherwise.
  static constexpr int kMaxSmallSize = 6;

  // Copy assignment is deleted; use ReplaceWith() to overwrite a shape.
  Shape &operator=(Shape const &) = delete;

  // Creates an empty (rank-0) shape backed by the inline array.
  Shape() : _size(0), dims_(std::array<int32_t, kMaxSmallSize>{}) {}

  // Creates a shape of rank `dimensions_count` whose extents are all 0.
  explicit Shape(int dimensions_count) : _size(dimensions_count) { initStorage(dimensions_count); }

  // Creates a shape of rank `shape_size` whose extents are all `value`.
  Shape(int shape_size, int32_t value) : _size(shape_size)
  {
    initStorage(shape_size);
    for (int i = 0; i < shape_size; ++i)
    {
      SetDim(i, value);
    }
  }

  // Creates a shape of rank `dimensions_count` by copying extents from
  // `dims_data`, which may be nullptr only when dimensions_count is 0.
  Shape(int dimensions_count, const int32_t *dims_data) : _size(0)
  {
    // dims_ starts as the (zeroed) inline array, consistent with _size == 0.
    ReplaceWith(dimensions_count, dims_data);
  }

  // Creates a shape from a braced list of extents, e.g. Shape{1, 224, 224, 3}.
  // Any rank is accepted, including ranks above kMaxSmallSize. This
  // constructor is not explicit, so `Shape s = {1, 2, 3};` also compiles.
  Shape(const std::initializer_list<int> init_list) : _size(0)
  {
    // Start from a consistent rank-0 state and let Resize() (inside
    // BuildFrom) pick the storage alternative matching the list length.
    BuildFrom(init_list);
  }

  // Copy constructor: duplicates the rank and the extents (inline or heap).
  Shape(const Shape &other) : _size(other._size), dims_(other.dims_) {}

  // Move constructor. The moved-from shape is reset to a valid rank-0 shape so
  // that it never reports a rank its (moved-out) storage cannot back.
  Shape(Shape &&other) noexcept : _size(other._size), dims_(std::move(other.dims_))
  {
    other._size = 0;
    other.dims_ = std::array<int32_t, kMaxSmallSize>{};
  }

  // Two shapes are equal when they have the same rank and identical extents.
  bool operator==(const Shape &comp) const
  {
    return this->_size == comp._size &&
           std::memcmp(DimsData(), comp.DimsData(), _size * sizeof(int32_t)) == 0;
  }

  ~Shape() = default;

  // Returns the rank (number of dimensions).
  inline int32_t DimensionsCount() const { return _size; }

  // Returns the extent of dimension `i` (0 <= i < DimensionsCount()).
  inline int32_t Dims(int i) const
  {
    assert(i >= 0 && i < _size);
    return DimsData()[i];
  }

  // Sets the extent of dimension `i` (0 <= i < DimensionsCount()) to `val`.
  inline void SetDim(int i, int32_t val)
  {
    assert(i >= 0 && i < _size);
    DimsData()[i] = val;
  }

  // Returns a pointer to the DimensionsCount() extents (mutable).
  inline int32_t *DimsData()
  {
    if (_size <= kMaxSmallSize)
    {
      return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
    }
    return std::get<std::vector<int32_t>>(dims_).data();
  }

  // Returns a pointer to the DimensionsCount() extents (read-only).
  inline const int32_t *DimsData() const
  {
    if (_size <= kMaxSmallSize)
    {
      return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
    }
    return std::get<std::vector<int32_t>>(dims_).data();
  }

  // Read-only access for callers that already know the rank is at most
  // kMaxSmallSize; avoids the rank branch in DimsData(). The caller must
  // guarantee the rank bound (checked only in debug builds).
  inline const int32_t *DimsDataUpTo6D() const
  {
    assert(_size <= kMaxSmallSize);
    return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
  }

  // Changes the rank to `dimensions_count`. The first
  // min(old rank, dimensions_count) extents are preserved and every newly
  // exposed extent is 0 (extents dropped by a shrink are not restored by a
  // later grow).
  inline void Resize(int dimensions_count)
  {
    assert(dimensions_count >= 0);
    // A valueless variant (possible only after an exception escaped an earlier
    // assignment) has no extents worth preserving.
    const int keep = dims_.valueless_by_exception() ? 0 : std::min<int>(_size, dimensions_count);

    // Build the new storage from the current one first, then swap it in; this
    // avoids the intermediate heap copy for the common small-rank case.
    if (dimensions_count <= kMaxSmallSize)
    {
      std::array<int32_t, kMaxSmallSize> dims{};
      if (keep > 0)
      {
        std::copy_n(DimsData(), keep, dims.begin());
      }
      dims_ = dims;
    }
    else
    {
      std::vector<int32_t> dims(dimensions_count, 0);
      if (keep > 0)
      {
        std::copy_n(DimsData(), keep, dims.begin());
      }
      dims_ = std::move(dims);
    }

    _size = dimensions_count;
  }

  // Overwrites this shape with rank `dimensions_count` and the extents pointed
  // to by `dims_data` (nullptr is allowed only when dimensions_count is 0).
  // Safe even when `dims_data` points into this shape's own storage: the new
  // extents are copied out before the old storage is released.
  inline void ReplaceWith(int dimensions_count, const int32_t *dims_data)
  {
    assert(dimensions_count >= 0);
    assert(dimensions_count == 0 || dims_data != nullptr);
    if (dimensions_count <= kMaxSmallSize)
    {
      std::array<int32_t, kMaxSmallSize> dims{};
      // Guarded so a nullptr source is never handed to a copy routine.
      if (dimensions_count > 0)
      {
        std::copy_n(dims_data, dimensions_count, dims.begin());
      }
      dims_ = dims;
    }
    else
    {
      dims_ = std::vector<int32_t>(dims_data, dims_data + dimensions_count);
    }
    _size = dimensions_count;
  }

  // Overwrites this shape with a copy of `other` (self-replacement is a no-op).
  inline void ReplaceWith(const Shape &other)
  {
    ReplaceWith(other.DimensionsCount(), other.DimsData());
  }

  // Takes over the contents of `other`; `other` receives this shape's previous
  // contents in exchange, so both shapes stay internally consistent.
  inline void ReplaceWith(Shape &&other)
  {
    std::swap(_size, other._size);
    dims_.swap(other.dims_);
  }

  // Replaces the extents with the elements of `src_iterable`, each converted
  // to int32_t. The iterable is traversed twice (once to count, once to copy).
  template <typename Iterable> inline void BuildFrom(const Iterable &src_iterable)
  {
    const int dimensions_count =
      static_cast<int>(std::distance(src_iterable.begin(), src_iterable.end()));
    Resize(dimensions_count);
    int32_t *data = DimsData();
    for (auto it = src_iterable.begin(); it != src_iterable.end(); ++it)
    {
      *data++ = static_cast<int32_t>(*it);
    }
  }

  // Returns a copy of `shape` widened to `new_shape_size` dimensions by
  // prepending extents of 1, e.g. ExtendedShape(4, {3, 5}) yields {1, 1, 3, 5}.
  // Requires new_shape_size >= shape.DimensionsCount().
  inline static Shape ExtendedShape(int new_shape_size, const Shape &shape)
  {
    return Shape(new_shape_size, shape, 1);
  }

  // Overload so that BuildFrom({1, 2, 3}) works with a braced list.
  inline void BuildFrom(const std::initializer_list<int> init_list)
  {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total number of elements (product of all extents); 1 for a
  // rank-0 shape.
  inline int FlatSize() const
  {
    int buffer_size = 1;
    const int *dims_data = DimsData();
    for (int i = 0; i < _size; i++)
    {
      buffer_size *= dims_data[i];
    }
    return buffer_size;
  }

  bool operator!=(const Shape &comp) const { return !((*this) == comp); }

private:
  // Sets dims_ to zero-filled storage of the alternative that matches
  // `dimensions_count` (inline array or heap vector of that length).
  inline void initStorage(int dimensions_count)
  {
    assert(dimensions_count >= 0);
    if (dimensions_count <= kMaxSmallSize)
      dims_ = std::array<int32_t, kMaxSmallSize>{};
    else
      dims_ = std::vector<int32_t>(dimensions_count);
  }

  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17. Builds `shape` left-padded with `pad_value` up to
  // `new_shape_size` dimensions; any target rank is supported.
  Shape(int new_shape_size, const Shape &shape, int pad_value) : _size(0)
  {
    assert(new_shape_size >= shape.DimensionsCount());
    // Starting from a consistent rank-0 state lets Resize() choose either
    // storage alternative.
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i)
    {
      SetDim(i, pad_value);
    }
    std::copy_n(shape.DimsData(), shape.DimensionsCount(), DimsData() + size_increase);
  }

  int32_t _size;
  // Internal storage: std::array for shapes with up to kMaxSmallSize
  // dimensions, std::vector for larger shapes (see the invariant above).
  std::variant<std::array<int32_t, kMaxSmallSize>, std::vector<int32_t>> dims_;
};
302
303// Utility functions below.
304
305inline int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
306 [[maybe_unused]] int index2)
307{
308 assert(shape1.Dims(index1) == shape2.Dims(index2));
309 return shape1.Dims(index1);
310}
311
312template <typename... Args>
313int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
314 [[maybe_unused]] int index2, Args... args)
315{
316 assert(shape1.Dims(index1) == shape2.Dims(index2));
317 return MatchingDim(shape1, index1, args...);
318}
319
320inline Shape GetShape(const std::vector<int32_t> &data)
321{
322 return Shape(static_cast<int>(data.size()), data.data());
323}
324
325inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
326{
327 assert(shape.DimensionsCount() == 4);
328 const int *dims_data = shape.DimsDataUpTo6D();
329 assert(i0 >= 0 && i0 < dims_data[0]);
330 assert(i1 >= 0 && i1 < dims_data[1]);
331 assert(i2 >= 0 && i2 < dims_data[2]);
332 assert(i3 >= 0 && i3 < dims_data[3]);
333 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
334}
335
336inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3, int i4)
337{
338 assert(shape.DimensionsCount() == 5);
339 const int *dim = shape.DimsDataUpTo6D();
340 assert(i0 >= 0 && i0 < dim[0]);
341 assert(i1 >= 0 && i1 < dim[1]);
342 assert(i2 >= 0 && i2 < dim[2]);
343 assert(i3 >= 0 && i3 < dim[3]);
344 assert(i4 >= 0 && i4 < dim[4]);
345 return ((((i0 * dim[1] + i1) * dim[2] + i2) * dim[3] + i3) * dim[4]) + i4;
346}
347
348inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3, int i4, int i5)
349{
350 assert(shape.DimensionsCount() == 6);
351 const int *dim = shape.DimsDataUpTo6D();
352 assert(i0 >= 0 && i0 < dim[0]);
353 assert(i1 >= 0 && i1 < dim[1]);
354 assert(i2 >= 0 && i2 < dim[2]);
355 assert(i3 >= 0 && i3 < dim[3]);
356 assert(i4 >= 0 && i4 < dim[4]);
357 assert(i5 >= 0 && i5 < dim[5]);
358 // clang format off
359 return (((((i0 * dim[1] + i1) * dim[2] + i2) * dim[3] + i3) * dim[4]) + i4) * dim[5] + i5;
360 // clang format on
361}
362
363inline int Offset(const Shape &shape, int *index)
364{
365 return Offset(shape, index[0], index[1], index[2], index[3], index[4], index[5]);
366}
367
368inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
369{
370 const int dims_count = shape.DimensionsCount();
371 assert(skip_dim >= 0 && skip_dim < dims_count);
372 const auto *dims_data = shape.DimsData();
373 int flat_size = 1;
374 for (int i = 0; i < dims_count; ++i)
375 {
376 flat_size *= (i == skip_dim) ? 1 : dims_data[i];
377 }
378 return flat_size;
379}
380
381// Flat size calculation, checking that dimensions match with one or more other shapes.
382template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes)
383{
384 auto match = [&shape](const Shape &s) -> bool {
385 // Check matching of shapes except the case that both shapes are scalars.
386 if (shape.DimensionsCount() > 1 || s.DimensionsCount() > 1 || shape.FlatSize() != 1 ||
387 s.FlatSize() != 1)
388 {
389 if (shape.DimensionsCount() != s.DimensionsCount())
390 {
391 return false;
392 }
393 for (int i = 0; i < shape.DimensionsCount(); ++i)
394 {
395 if (shape.Dims(i) != s.Dims(i))
396 {
397 return false;
398 }
399 }
400 }
401 return true;
402 };
403
404 // Apply the lambda to each check shape and combine with &&
405 return (match(check_shapes) && ...);
406}
407
// Swallows any number of arguments of any type. Constructing a temporary
// from values that are otherwise referenced only inside assert() keeps
// unused-parameter warnings quiet when NDEBUG removes those asserts.
struct UNUSED_ALL
{
  template <typename... Args> UNUSED_ALL(Args const &...) {}
};
412template <typename... Ts> inline int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
413{
414 UNUSED_ALL{check_shapes...};
415 assert(checkMatching(shape, std::forward<Ts>(check_shapes)...));
416 return shape.FlatSize();
417}
418
419inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim,
420 [[maybe_unused]] const Shape &check_shape_0)
421{
422 const int dims_count = shape.DimensionsCount();
423 for (int i = 0; i < dims_count; ++i)
424 {
425 if (i != skip_dim)
426 {
427 assert(shape.Dims(i) == check_shape_0.Dims(i));
428 }
429 }
430 return FlatSizeSkipDim(shape, skip_dim);
431}
432
433inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim,
434 [[maybe_unused]] const Shape &check_shape_0,
435 const Shape &check_shape_1)
436{
437 const int dims_count = shape.DimensionsCount();
438 for (int i = 0; i < dims_count; ++i)
439 {
440 if (i != skip_dim)
441 {
442 assert(shape.Dims(i) == check_shape_0.Dims(i));
443 }
444 }
445 return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
446}
447
448inline int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0,
449 const Shape &check_shape_1)
450{
451 const int size_1 = shape.FlatSize();
452 [[maybe_unused]] const int size_2 = check_shape_0.FlatSize();
453 [[maybe_unused]] const int size_3 = check_shape_1.FlatSize();
454 assert(size_1 == size_2);
455 assert(size_2 == size_3);
456 return size_1;
457}
458
459} // namespace cker
460} // namespace nnfw
461
462#endif // __NNFW_CKER_SHAPE_H__
int32_t DimensionsCount() const
Definition Shape.h:103
Shape(const std::initializer_list< int > init_list)
Definition Shape.h:71
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
Definition Shape.h:206
Shape(int shape_size, int32_t value)
Definition Shape.h:53
int32_t Dims(int i) const
Definition Shape.h:106
const int32_t * DimsDataUpTo6D() const
Definition Shape.h:160
void ReplaceWith(const Shape &other)
Definition Shape.h:217
void BuildFrom(const Iterable &src_iterable)
Definition Shape.h:230
Shape(const Shape &other)
Definition Shape.h:79
void ReplaceWith(Shape &&other)
Definition Shape.h:223
Shape(Shape &&other)=default
static constexpr int kMaxSmallSize
Definition Shape.h:39
Shape(int dimensions_count)
Definition Shape.h:50
bool operator==(const Shape &comp) const
Definition Shape.h:94
void BuildFrom(const std::initializer_list< int > init_list)
Definition Shape.h:250
bool operator!=(const Shape &comp) const
Definition Shape.h:267
int FlatSize() const
Definition Shape.h:256
Shape(int dimensions_count, const int32_t *dims_data)
Definition Shape.h:63
void Resize(int dimensions_count)
Definition Shape.h:166
void SetDim(int i, int32_t val)
Definition Shape.h:120
int32_t * DimsData()
Definition Shape.h:134
static Shape ExtendedShape(int new_shape_size, const Shape &shape)
Definition Shape.h:244
const int32_t * DimsData() const
Definition Shape.h:147
Shape & operator=(Shape const &)=delete
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
Definition Shape.h:305
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
Definition Shape.h:325
int FlatSizeSkipDim(const Shape &shape, int skip_dim)
Definition Shape.h:368
Shape GetShape(const std::vector< int32_t > &data)
Definition Shape.h:320
int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
Definition Shape.h:419
int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0, const Shape &check_shape_1)
Definition Shape.h:448
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
Definition Shape.h:412
bool checkMatching(const Shape &shape, Ts... check_shapes)
Definition Shape.h:382
Definition topk_v2.h:30
int32_t size[5]
Definition Slice.cpp:35
Definition Shape.h:28
UNUSED_ALL(Args const &...)
Definition Shape.h:410