ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Shape.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef __NNFW_CKER_SHAPE_H__
19#define __NNFW_CKER_SHAPE_H__
20
21#include <algorithm>
22#include <cstring>
23#include <cassert>
24#include <vector>
25
26namespace nnfw
27{
28namespace cker
29{
30
/**
 * @brief Runtime shape of a tensor: a dimension count plus one int32_t extent
 *        per dimension.
 *
 * Shapes with at most kMaxSmallSize dimensions keep their extents in an inline
 * array; larger shapes keep them in a buffer allocated with new[] and released
 * with delete[] (by Resize() and the destructor). Copy assignment is deleted;
 * use ReplaceWith() to overwrite an existing shape.
 */
class Shape
{
public:
  // Shapes with dimensions up to 5 are stored directly in the structure, while
  // larger shapes are separately allocated.
  static constexpr int kMaxSmallSize = 5;

  Shape &operator=(Shape const &) = delete;

  // Creates an empty (rank-0) shape.
  Shape() : _size(0) {}

  // Creates a shape of rank dimensions_count; the extents are left uninitialized.
  explicit Shape(int dimensions_count) : _size(dimensions_count)
  {
    if (dimensions_count > kMaxSmallSize)
    {
      _dims_pointer = new int32_t[dimensions_count];
    }
  }

  // Creates a shape of rank shape_size whose every extent equals value.
  Shape(int shape_size, int32_t value) : _size(0)
  {
    Resize(shape_size);
    for (int i = 0; i < shape_size; ++i)
    {
      SetDim(i, value);
    }
  }

  // Creates a shape by copying dimensions_count extents from dims_data.
  Shape(int dimensions_count, const int32_t *dims_data) : _size(0)
  {
    ReplaceWith(dimensions_count, dims_data);
  }

  // Creates a shape from a brace list of extents, e.g. Shape{1, 224, 224, 3}.
  Shape(const std::initializer_list<int> init_list) : _size(0) { BuildFrom(init_list); }

  // Avoid using this constructor. We should be able to delete it when C++17
  // rolls out.
  Shape(Shape const &other) : _size(other.DimensionsCount())
  {
    if (_size > kMaxSmallSize)
    {
      _dims_pointer = new int32_t[_size];
    }
    std::memcpy(DimsData(), other.DimsData(), sizeof(int32_t) * _size);
  }

  // Shapes are equal when they have the same rank and identical extents.
  bool operator==(const Shape &comp) const
  {
    return this->_size == comp._size &&
           std::memcmp(DimsData(), comp.DimsData(), _size * sizeof(int32_t)) == 0;
  }

  ~Shape()
  {
    if (_size > kMaxSmallSize)
    {
      delete[] _dims_pointer;
    }
  }

  inline int32_t DimensionsCount() const { return _size; }
  inline int32_t Dims(int i) const
  {
    assert(i >= 0);
    assert(i < _size);
    return _size > kMaxSmallSize ? _dims_pointer[i] : _dims[i];
  }
  inline void SetDim(int i, int32_t val)
  {
    assert(i >= 0);
    assert(i < _size);
    if (_size > kMaxSmallSize)
    {
      _dims_pointer[i] = val;
    }
    else
    {
      _dims[i] = val;
    }
  }

  inline int32_t *DimsData() { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
  inline const int32_t *DimsData() const { return _size > kMaxSmallSize ? _dims_pointer : _dims; }
  // The caller must ensure that the shape is no bigger than 4-D.
  inline const int32_t *DimsDataUpTo4D() const
  {
    // Above kMaxSmallSize the inline array is not in use (its storage holds the
    // heap pointer instead), so the returned data would be meaningless.
    assert(_size <= kMaxSmallSize);
    return _dims;
  }

  // Sets the rank to dimensions_count. Inline extents are kept when both the old
  // and the new rank fit inline; otherwise the new extents are uninitialized.
  inline void Resize(int dimensions_count)
  {
    // Allocate before releasing: if new[] throws, *this is left untouched
    // instead of holding a dangling pointer the destructor would delete again.
    int32_t *new_heap_dims = nullptr;
    if (dimensions_count > kMaxSmallSize)
    {
      new_heap_dims = new int32_t[dimensions_count];
    }
    if (_size > kMaxSmallSize)
    {
      delete[] _dims_pointer;
    }
    _size = dimensions_count;
    if (new_heap_dims != nullptr)
    {
      _dims_pointer = new_heap_dims;
    }
  }

  // Replaces this shape with dimensions_count extents copied from dims_data.
  // dims_data may point into this shape's own storage (e.g. a slice of
  // DimsData()); the source is copied before the old storage is released.
  inline void ReplaceWith(int dimensions_count, const int32_t *dims_data)
  {
    assert(dimensions_count >= 0);
    if (dimensions_count > kMaxSmallSize)
    {
      int32_t *new_heap_dims = new int32_t[dimensions_count];
      std::copy(dims_data, dims_data + dimensions_count, new_heap_dims);
      Resize(0); // releases the previous heap buffer, if any
      _size = dimensions_count;
      _dims_pointer = new_heap_dims;
    }
    else
    {
      // Stage through a local buffer: Resize() may free the memory dims_data
      // points into, and copying _dims onto itself with memcpy would overlap.
      int32_t staged[kMaxSmallSize];
      std::copy(dims_data, dims_data + dimensions_count, staged);
      Resize(dimensions_count);
      std::copy(staged, staged + dimensions_count, _dims);
    }
  }

  // Replaces this shape with a copy of other. Replacing a shape with itself is a no-op.
  inline void ReplaceWith(const Shape &other)
  {
    if (this == &other)
    {
      return;
    }
    ReplaceWith(other.DimensionsCount(), other.DimsData());
  }

  // Takes over other's extents (and heap buffer, if any); other becomes rank 0.
  inline void ReplaceWith(Shape &&other)
  {
    if (this == &other)
    {
      return;
    }
    Resize(0);
    std::swap(_size, other._size);
    if (_size <= kMaxSmallSize)
      std::copy(other._dims, other._dims + _size, _dims);
    else
      _dims_pointer = other._dims_pointer; // other._size is now 0, so it will not free it
  }

  // Replaces the extents with the elements of src_iterable (anything exposing
  // begin()/end() whose elements convert to int32_t).
  template <typename T> inline void BuildFrom(const T &src_iterable)
  {
    const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
    Resize(dimensions_count);
    int32_t *data = DimsData();
    for (auto &&it : src_iterable)
    {
      *data = it;
      ++data;
    }
  }

  // This will probably be factored out. Old code made substantial use of 4-D
  // shapes, and so this function is used to extend smaller shapes. Note that
  // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
  // reduced, and (b) some kernels are strictly 4-D, but then the shapes of their
  // inputs should already be 4-D, so this function should not be needed.
  //
  // Returns shape left-padded with 1s up to rank new_shape_size.
  inline static Shape ExtendedShape(int new_shape_size, const Shape &shape)
  {
    return Shape(new_shape_size, shape, 1);
  }

  inline void BuildFrom(const std::initializer_list<int> init_list)
  {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened into a
  // vector. A rank-0 shape yields 1.
  inline int FlatSize() const
  {
    int buffer_size = 1;
    const int *dims_data = DimsData();
    for (int i = 0; i < _size; i++)
    {
      const int dim = dims_data[i];
      buffer_size *= dim;
    }
    return buffer_size;
  }

  bool operator!=(const Shape &comp) const { return !((*this) == comp); }

private:
  // For use only by ExtendedShape(), written to guarantee (return-value) copy
  // elision in C++17.
  // This creates a shape padded to the desired size with the specified value.
  Shape(int new_shape_size, const Shape &shape, int pad_value) : _size(0)
  {
    assert(new_shape_size >= shape.DimensionsCount());
    assert(new_shape_size <= kMaxSmallSize);
    Resize(new_shape_size);
    const int size_increase = new_shape_size - shape.DimensionsCount();
    for (int i = 0; i < size_increase; ++i)
    {
      SetDim(i, pad_value);
    }
    std::memcpy(DimsData() + size_increase, shape.DimsData(),
                sizeof(int32_t) * shape.DimensionsCount());
  }

  int32_t _size;
  // Inline extents when _size <= kMaxSmallSize, otherwise an owned heap buffer.
  union {
    int32_t _dims[kMaxSmallSize];
    int32_t *_dims_pointer{nullptr};
  };
};
219
220inline int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
221 [[maybe_unused]] int index2)
222{
223 assert(shape1.Dims(index1) == shape2.Dims(index2));
224 return shape1.Dims(index1);
225}
226
227template <typename... Args>
228int MatchingDim(const Shape &shape1, int index1, [[maybe_unused]] const Shape &shape2,
229 [[maybe_unused]] int index2, Args... args)
230{
231 assert(shape1.Dims(index1) == shape2.Dims(index2));
232 return MatchingDim(shape1, index1, args...);
233}
234
235inline Shape GetShape(const std::vector<int32_t> &data) { return Shape(data.size(), data.data()); }
236
237inline int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
238{
239 assert(shape.DimensionsCount() == 4);
240 const int *dims_data = shape.DimsDataUpTo4D();
241 assert(i0 >= 0 && i0 < dims_data[0]);
242 assert(i1 >= 0 && i1 < dims_data[1]);
243 assert(i2 >= 0 && i2 < dims_data[2]);
244 assert(i3 >= 0 && i3 < dims_data[3]);
245 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
246}
247
248inline int Offset(const Shape &shape, int *index)
249{
250 return Offset(shape, index[0], index[1], index[2], index[3]);
251}
252
253inline int FlatSizeSkipDim(const Shape &shape, int skip_dim)
254{
255 const int dims_count = shape.DimensionsCount();
256 assert(skip_dim >= 0 && skip_dim < dims_count);
257 const auto *dims_data = shape.DimsData();
258 int flat_size = 1;
259 for (int i = 0; i < dims_count; ++i)
260 {
261 flat_size *= (i == skip_dim) ? 1 : dims_data[i];
262 }
263 return flat_size;
264}
265
266// Flat size calculation, checking that dimensions match with one or more other
267// arrays.
268template <typename... Ts> inline bool checkMatching(const Shape &shape, Ts... check_shapes)
269{
270 const Shape check_shapes_array[sizeof...(Ts)] = {std::forward<Ts>(check_shapes)...};
271 for (const auto &check_shape : check_shapes_array)
272 {
273 // Check matching of shapes except the case of that two shapes can be scalar
274 if (shape.DimensionsCount() > 1 || check_shape.DimensionsCount() > 1 || shape.FlatSize() != 1 ||
275 check_shape.FlatSize() != 1)
276 {
277 if (shape.DimensionsCount() != check_shape.DimensionsCount())
278 {
279 return false;
280 }
281 for (int i = 0; i < shape.DimensionsCount(); ++i)
282 {
283 if (shape.Dims(i) != check_shape.Dims(i))
284 {
285 return false;
286 }
287 }
288 }
289 }
290 return true;
291}
292
// Accepts and discards any number of arguments of any type. Constructing one
// from variables that are otherwise referenced only inside assert() keeps them
// "used" when NDEBUG compiles the assertion away.
struct UNUSED_ALL
{
  template <class... Args> UNUSED_ALL(const Args &...) {}
};
297template <typename... Ts> inline int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
298{
299 UNUSED_ALL{check_shapes...};
300 assert(checkMatching(shape, std::forward<Ts>(check_shapes)...));
301 return shape.FlatSize();
302}
303
304inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim,
305 [[maybe_unused]] const Shape &check_shape_0)
306{
307 const int dims_count = shape.DimensionsCount();
308 for (int i = 0; i < dims_count; ++i)
309 {
310 if (i != skip_dim)
311 {
312 assert(shape.Dims(i) == check_shape_0.Dims(i));
313 }
314 }
315 return FlatSizeSkipDim(shape, skip_dim);
316}
317
318inline int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim,
319 [[maybe_unused]] const Shape &check_shape_0,
320 const Shape &check_shape_1)
321{
322 const int dims_count = shape.DimensionsCount();
323 for (int i = 0; i < dims_count; ++i)
324 {
325 if (i != skip_dim)
326 {
327 assert(shape.Dims(i) == check_shape_0.Dims(i));
328 }
329 }
330 return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
331}
332
333inline int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0,
334 const Shape &check_shape_1)
335{
336 const int size_1 = shape.FlatSize();
337 [[maybe_unused]] const int size_2 = check_shape_0.FlatSize();
338 [[maybe_unused]] const int size_3 = check_shape_1.FlatSize();
339 assert(size_1 == size_2);
340 assert(size_2 == size_3);
341 return size_1;
342}
343
344} // namespace cker
345} // namespace nnfw
346
347#endif // __NNFW_CKER_SHAPE_H__
int32_t DimensionsCount() const
Definition Shape.h:91
Shape(const std::initializer_list< int > init_list)
Definition Shape.h:64
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
Definition Shape.h:130
Shape(int shape_size, int32_t value)
Definition Shape.h:50
int32_t Dims(int i) const
Definition Shape.h:92
void ReplaceWith(const Shape &other)
Definition Shape.h:137
void BuildFrom(const T &src_iterable)
Definition Shape.h:152
void ReplaceWith(Shape &&other)
Definition Shape.h:142
static constexpr int kMaxSmallSize
Definition Shape.h:36
Shape(int dimensions_count)
Definition Shape.h:42
bool operator==(const Shape &comp) const
Definition Shape.h:77
Shape(Shape const &other)
Definition Shape.h:68
void BuildFrom(const std::initializer_list< int > init_list)
Definition Shape.h:174
bool operator!=(const Shape &comp) const
Definition Shape.h:193
int FlatSize() const
Definition Shape.h:181
Shape(int dimensions_count, const int32_t *dims_data)
Definition Shape.h:59
int32_t * _dims_pointer
Definition Shape.h:216
void Resize(int dimensions_count)
Definition Shape.h:117
void SetDim(int i, int32_t val)
Definition Shape.h:98
int32_t * DimsData()
Definition Shape.h:112
static Shape ExtendedShape(int new_shape_size, const Shape &shape)
Definition Shape.h:169
const int32_t * DimsData() const
Definition Shape.h:113
Shape & operator=(Shape const &)=delete
const int32_t * DimsDataUpTo4D() const
Definition Shape.h:115
int32_t _dims[kMaxSmallSize]
Definition Shape.h:215
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
Definition Shape.h:220
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
Definition Shape.h:237
int FlatSizeSkipDim(const Shape &shape, int skip_dim)
Definition Shape.h:253
Shape GetShape(const std::vector< int32_t > &data)
Definition Shape.h:235
int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
Definition Shape.h:304
int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0, const Shape &check_shape_1)
Definition Shape.h:333
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
Definition Shape.h:297
bool checkMatching(const Shape &shape, Ts... check_shapes)
Definition Shape.h:268
Definition topk_v2.h:30
UNUSED_ALL(Args const &...)
Definition Shape.h:295