18#ifndef __NNFW_CKER_SHAPE_H__
19#define __NNFW_CKER_SHAPE_H__
50 explicit Shape(
int dimensions_count) : _size(dimensions_count) { initStorage(dimensions_count); }
53 Shape(
int shape_size, int32_t value) : _size(shape_size)
55 initStorage(shape_size);
56 for (
int i = 0; i < shape_size; ++i)
63 Shape(
int dimensions_count,
const int32_t *dims_data) : _size(dimensions_count)
65 initStorage(dimensions_count);
71 Shape(
const std::initializer_list<int> init_list) : _size(0)
73 const auto size =
static_cast<int>(std::distance(init_list.begin(), init_list.end()));
84 dims_ = std::get<std::array<int32_t, kMaxSmallSize>>(other.dims_);
89 dims_ = std::get<std::vector<int32_t>>(other.dims_);
96 return this->_size == comp._size &&
106 inline int32_t
Dims(
int i)
const
108 assert(i >= 0 && i < _size);
111 return std::get<std::array<int32_t, kMaxSmallSize>>(dims_)[i];
115 return std::get<std::vector<int32_t>>(dims_)[i];
122 assert(i >= 0 && i < _size);
125 std::get<std::array<int32_t, kMaxSmallSize>>(dims_)[i] = val;
129 std::get<std::vector<int32_t>>(dims_)[i] = val;
138 return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
142 return std::get<std::vector<int32_t>>(dims_).data();
151 return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
155 return std::get<std::vector<int32_t>>(dims_).data();
162 return std::get<std::array<int32_t, kMaxSmallSize>>(dims_).data();
170 if (dims_.valueless_by_exception())
172 initStorage(dimensions_count);
175 std::vector<int32_t> oldDims;
176 oldDims.reserve(_size);
179 const auto &arr = std::get<std::array<int32_t, kMaxSmallSize>>(dims_);
180 oldDims.assign(arr.begin(), arr.begin() + _size);
184 oldDims = std::get<std::vector<int32_t>>(dims_);
187 int count = std::min(_size, dimensions_count);
191 std::array<int32_t, kMaxSmallSize> dims = {};
192 std::copy_n(oldDims.begin(), count, dims.begin());
197 std::vector<int32_t> dims(dimensions_count, 0);
198 std::copy_n(oldDims.begin(), count, dims.begin());
202 _size = dimensions_count;
206 inline void ReplaceWith(
int dimensions_count,
const int32_t *dims_data)
211 assert(dimensions_count == 0 || dims_data !=
nullptr);
213 std::memcpy(
DimsData(), dims_data, dimensions_count *
sizeof(int32_t));
225 std::swap(_size, other._size);
226 dims_ = std::move(other.dims_);
230 template <
typename Iterable>
inline void BuildFrom(
const Iterable &src_iterable)
232 const int dimensions_count =
233 static_cast<int>(std::distance(src_iterable.begin(), src_iterable.end()));
236 for (
auto it = src_iterable.begin(); it != src_iterable.end(); ++it)
238 *data++ =
static_cast<int32_t
>(*it);
246 return Shape(new_shape_size, shape, 1);
250 inline void BuildFrom(
const std::initializer_list<int> init_list)
252 BuildFrom<const std::initializer_list<int>>(init_list);
260 for (
int i = 0; i < _size; i++)
262 buffer_size *= dims_data[i];
271 inline void initStorage(
int dimensions_count)
273 assert(dimensions_count >= 0);
275 dims_ = std::array<int32_t, kMaxSmallSize>{};
277 dims_ = std::vector<int32_t>(dimensions_count);
283 Shape(
int new_shape_size,
const Shape &shape,
int pad_value) : _size(new_shape_size)
285 assert(new_shape_size >= shape.DimensionsCount());
288 const int size_increase = new_shape_size - shape.DimensionsCount();
289 for (
int i = 0; i < size_increase; ++i)
293 std::memcpy(
DimsData() + size_increase, shape.DimsData(),
294 sizeof(int32_t) * shape.DimensionsCount());
300 std::variant<std::array<int32_t, kMaxSmallSize>, std::vector<int32_t>> dims_;
306 [[maybe_unused]]
int index2)
308 assert(shape1.
Dims(index1) == shape2.Dims(index2));
309 return shape1.
Dims(index1);
312template <
typename... Args>
314 [[maybe_unused]]
int index2, Args... args)
316 assert(shape1.
Dims(index1) == shape2.Dims(index2));
322 return Shape(
static_cast<int>(data.size()), data.data());
325inline int Offset(
const Shape &shape,
int i0,
int i1,
int i2,
int i3)
329 assert(i0 >= 0 && i0 < dims_data[0]);
330 assert(i1 >= 0 && i1 < dims_data[1]);
331 assert(i2 >= 0 && i2 < dims_data[2]);
332 assert(i3 >= 0 && i3 < dims_data[3]);
333 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
336inline int Offset(
const Shape &shape,
int i0,
int i1,
int i2,
int i3,
int i4)
340 assert(i0 >= 0 && i0 < dim[0]);
341 assert(i1 >= 0 && i1 < dim[1]);
342 assert(i2 >= 0 && i2 < dim[2]);
343 assert(i3 >= 0 && i3 < dim[3]);
344 assert(i4 >= 0 && i4 < dim[4]);
345 return ((((i0 * dim[1] + i1) * dim[2] + i2) * dim[3] + i3) * dim[4]) + i4;
348inline int Offset(
const Shape &shape,
int i0,
int i1,
int i2,
int i3,
int i4,
int i5)
352 assert(i0 >= 0 && i0 < dim[0]);
353 assert(i1 >= 0 && i1 < dim[1]);
354 assert(i2 >= 0 && i2 < dim[2]);
355 assert(i3 >= 0 && i3 < dim[3]);
356 assert(i4 >= 0 && i4 < dim[4]);
357 assert(i5 >= 0 && i5 < dim[5]);
359 return (((((i0 * dim[1] + i1) * dim[2] + i2) * dim[3] + i3) * dim[4]) + i4) * dim[5] + i5;
365 return Offset(shape, index[0], index[1], index[2], index[3], index[4], index[5]);
371 assert(skip_dim >= 0 && skip_dim < dims_count);
372 const auto *dims_data = shape.
DimsData();
374 for (
int i = 0; i < dims_count; ++i)
376 flat_size *= (i == skip_dim) ? 1 : dims_data[i];
384 auto match = [&shape](
const Shape &s) ->
bool {
395 if (shape.
Dims(i) != s.Dims(i))
405 return (match(check_shapes) && ...);
415 assert(
checkMatching(shape, std::forward<Ts>(check_shapes)...));
420 [[maybe_unused]]
const Shape &check_shape_0)
423 for (
int i = 0; i < dims_count; ++i)
427 assert(shape.
Dims(i) == check_shape_0.Dims(i));
434 [[maybe_unused]]
const Shape &check_shape_0,
435 const Shape &check_shape_1)
438 for (
int i = 0; i < dims_count; ++i)
442 assert(shape.
Dims(i) == check_shape_0.Dims(i));
449 const Shape &check_shape_1)
451 const int size_1 = shape.
FlatSize();
452 [[maybe_unused]]
const int size_2 = check_shape_0.
FlatSize();
453 [[maybe_unused]]
const int size_3 = check_shape_1.
FlatSize();
454 assert(size_1 == size_2);
455 assert(size_2 == size_3);
int32_t DimensionsCount() const
Shape(const std::initializer_list< int > init_list)
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
Shape(int shape_size, int32_t value)
int32_t Dims(int i) const
const int32_t * DimsDataUpTo6D() const
void ReplaceWith(const Shape &other)
void BuildFrom(const Iterable &src_iterable)
Shape(const Shape &other)
void ReplaceWith(Shape &&other)
Shape(Shape &&other)=default
static constexpr int kMaxSmallSize
Shape(int dimensions_count)
bool operator==(const Shape &comp) const
void BuildFrom(const std::initializer_list< int > init_list)
bool operator!=(const Shape &comp) const
Shape(int dimensions_count, const int32_t *dims_data)
void Resize(int dimensions_count)
void SetDim(int i, int32_t val)
static Shape ExtendedShape(int new_shape_size, const Shape &shape)
const int32_t * DimsData() const
Shape & operator=(Shape const &)=delete
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
int FlatSizeSkipDim(const Shape &shape, int skip_dim)
Shape GetShape(const std::vector< int32_t > &data)
int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0, const Shape &check_shape_1)
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
bool checkMatching(const Shape &shape, Ts... check_shapes)
UNUSED_ALL(Args const &...)