18#ifndef __NNFW_CKER_SHAPE_H__
19#define __NNFW_CKER_SHAPE_H__
42 explicit Shape(
int dimensions_count) : _size(dimensions_count)
50 Shape(
int shape_size, int32_t value) : _size(0)
53 for (
int i = 0; i < shape_size; ++i)
59 Shape(
int dimensions_count,
const int32_t *dims_data) : _size(0)
64 Shape(
const std::initializer_list<int> init_list) : _size(0) {
BuildFrom(init_list); }
79 return this->_size == comp._size &&
92 inline int32_t
Dims(
int i)
const
98 inline void SetDim(
int i, int32_t val)
123 _size = dimensions_count;
130 inline void ReplaceWith(
int dimensions_count,
const int32_t *dims_data)
134 std::memcpy(dst_dims, dims_data, dimensions_count *
sizeof(int32_t));
145 std::swap(_size, other._size);
152 template <
typename T>
inline void BuildFrom(
const T &src_iterable)
154 const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
157 for (
auto &&it : src_iterable)
171 return Shape(new_shape_size, shape, 1);
174 inline void BuildFrom(
const std::initializer_list<int> init_list)
176 BuildFrom<const std::initializer_list<int>>(init_list);
185 for (
int i = 0; i < _size; i++)
187 const int dim = dims_data[i];
199 Shape(
int new_shape_size,
const Shape &shape,
int pad_value) : _size(0)
205 for (
int i = 0; i < size_increase; ++i)
221 [[maybe_unused]]
int index2)
223 assert(shape1.
Dims(index1) == shape2.Dims(index2));
224 return shape1.
Dims(index1);
227template <
typename... Args>
229 [[maybe_unused]]
int index2, Args... args)
231 assert(shape1.
Dims(index1) == shape2.Dims(index2));
237inline int Offset(
const Shape &shape,
int i0,
int i1,
int i2,
int i3)
241 assert(i0 >= 0 && i0 < dims_data[0]);
242 assert(i1 >= 0 && i1 < dims_data[1]);
243 assert(i2 >= 0 && i2 < dims_data[2]);
244 assert(i3 >= 0 && i3 < dims_data[3]);
245 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
250 return Offset(shape, index[0], index[1], index[2], index[3]);
256 assert(skip_dim >= 0 && skip_dim < dims_count);
257 const auto *dims_data = shape.
DimsData();
259 for (
int i = 0; i < dims_count; ++i)
261 flat_size *= (i == skip_dim) ? 1 : dims_data[i];
270 const Shape check_shapes_array[
sizeof...(Ts)] = {std::forward<Ts>(check_shapes)...};
271 for (
const auto &check_shape : check_shapes_array)
275 check_shape.FlatSize() != 1)
283 if (shape.
Dims(i) != check_shape.Dims(i))
300 assert(
checkMatching(shape, std::forward<Ts>(check_shapes)...));
305 [[maybe_unused]]
const Shape &check_shape_0)
308 for (
int i = 0; i < dims_count; ++i)
312 assert(shape.
Dims(i) == check_shape_0.Dims(i));
319 [[maybe_unused]]
const Shape &check_shape_0,
320 const Shape &check_shape_1)
323 for (
int i = 0; i < dims_count; ++i)
327 assert(shape.
Dims(i) == check_shape_0.Dims(i));
334 const Shape &check_shape_1)
336 const int size_1 = shape.
FlatSize();
337 [[maybe_unused]]
const int size_2 = check_shape_0.
FlatSize();
338 [[maybe_unused]]
const int size_3 = check_shape_1.
FlatSize();
339 assert(size_1 == size_2);
340 assert(size_2 == size_3);
int32_t DimensionsCount() const
Shape(const std::initializer_list< int > init_list)
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
Shape(int shape_size, int32_t value)
int32_t Dims(int i) const
void ReplaceWith(const Shape &other)
void BuildFrom(const T &src_iterable)
void ReplaceWith(Shape &&other)
static constexpr int kMaxSmallSize
Shape(int dimensions_count)
bool operator==(const Shape &comp) const
Shape(Shape const &other)
void BuildFrom(const std::initializer_list< int > init_list)
bool operator!=(const Shape &comp) const
Shape(int dimensions_count, const int32_t *dims_data)
void Resize(int dimensions_count)
void SetDim(int i, int32_t val)
static Shape ExtendedShape(int new_shape_size, const Shape &shape)
const int32_t * DimsData() const
Shape & operator=(Shape const &)=delete
const int32_t * DimsDataUpTo4D() const
int32_t _dims[kMaxSmallSize]
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
int FlatSizeSkipDim(const Shape &shape, int skip_dim)
Shape GetShape(const std::vector< int32_t > &data)
int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0, const Shape &check_shape_1)
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
bool checkMatching(const Shape &shape, Ts... check_shapes)
UNUSED_ALL(Args const &...)