18#ifndef __NNFW_RUY_SHAPE_H__
19#define __NNFW_RUY_SHAPE_H__
42 explicit Shape(
int dimensions_count) : _size(dimensions_count)
50 Shape(
int shape_size, int32_t value) : _size(0)
53 for (
int i = 0; i < shape_size; ++i)
59 Shape(
int dimensions_count,
const int32_t *dims_data) : _size(0)
64 Shape(
const std::initializer_list<int> init_list) : _size(0) {
BuildFrom(init_list); }
79 return this->_size == comp._size &&
92 inline int32_t
Dims(
int i)
const
98 inline void SetDim(
int i, int32_t val)
123 _size = dimensions_count;
130 inline void ReplaceWith(
int dimensions_count,
const int32_t *dims_data)
134 std::memcpy(dst_dims, dims_data, dimensions_count *
sizeof(int32_t));
145 std::swap(_size, other._size);
152 template <
typename T>
inline void BuildFrom(
const T &src_iterable)
154 const int dimensions_count = std::distance(src_iterable.begin(), src_iterable.end());
157 for (
auto &&it : src_iterable)
171 return Shape(new_shape_size, shape, 1);
174 inline void BuildFrom(
const std::initializer_list<int> init_list)
176 BuildFrom<const std::initializer_list<int>>(init_list);
185 for (
int i = 0; i < _size; i++)
187 const int dim = dims_data[i];
200 Shape(
int new_shape_size,
const Shape &shape,
int pad_value) : _size(0)
206 for (
int i = 0; i < size_increase; ++i)
222 [[maybe_unused]]
int index2)
224 assert(shape1.
Dims(index1) == shape2.Dims(index2));
225 return shape1.
Dims(index1);
228template <
typename... Args>
230 [[maybe_unused]]
int index2, Args... args)
232 assert(shape1.
Dims(index1) == shape2.Dims(index2));
238inline int Offset(
const Shape &shape,
int i0,
int i1,
int i2,
int i3)
242 assert(i0 >= 0 && i0 < dims_data[0]);
243 assert(i1 >= 0 && i1 < dims_data[1]);
244 assert(i2 >= 0 && i2 < dims_data[2]);
245 assert(i3 >= 0 && i3 < dims_data[3]);
246 return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
251 return Offset(shape, index[0], index[1], index[2], index[3]);
257 assert(skip_dim >= 0 && skip_dim < dims_count);
258 const auto *dims_data = shape.
DimsData();
260 for (
int i = 0; i < dims_count; ++i)
262 flat_size *= (i == skip_dim) ? 1 : dims_data[i];
271 const Shape check_shapes_array[
sizeof...(Ts)] = {std::forward<Ts>(check_shapes)...};
272 for (
const auto &check_shape : check_shapes_array)
276 check_shape.FlatSize() != 1)
284 if (shape.
Dims(i) != check_shape.Dims(i))
301 assert(
checkMatching(shape, std::forward<Ts>(check_shapes)...));
306 [[maybe_unused]]
const Shape &check_shape_0)
309 for (
int i = 0; i < dims_count; ++i)
313 assert(shape.
Dims(i) == check_shape_0.Dims(i));
320 [[maybe_unused]]
const Shape &check_shape_0,
321 const Shape &check_shape_1)
324 for (
int i = 0; i < dims_count; ++i)
328 assert(shape.
Dims(i) == check_shape_0.Dims(i));
335 const Shape &check_shape_1)
337 const int size_1 = shape.
FlatSize();
338 [[maybe_unused]]
const int size_2 = check_shape_0.
FlatSize();
339 [[maybe_unused]]
const int size_3 = check_shape_1.
FlatSize();
340 assert(size_1 == size_2);
341 assert(size_2 == size_3);
int32_t _dims[kMaxSmallSize]
bool operator==(const Shape &comp) const
Shape(int dimensions_count, const int32_t *dims_data)
void Resize(int dimensions_count)
void BuildFrom(const T &src_iterable)
void BuildFrom(const std::initializer_list< int > init_list)
Shape(Shape const &other)
void ReplaceWith(int dimensions_count, const int32_t *dims_data)
void ReplaceWith(const Shape &other)
Shape(int shape_size, int32_t value)
Shape(const std::initializer_list< int > init_list)
const int32_t * DimsDataUpTo4D() const
int32_t DimensionsCount() const
Shape & operator=(Shape const &)=delete
int32_t Dims(int i) const
void ReplaceWith(Shape &&other)
static Shape ExtendedShape(int new_shape_size, const Shape &shape)
Shape(int dimensions_count)
const int32_t * DimsData() const
bool operator!=(const Shape &comp) const
static constexpr int kMaxSmallSize
void SetDim(int i, int32_t val)
Shape GetShape(const std::vector< int32_t > &data)
int Offset(const Shape &shape, int i0, int i1, int i2, int i3)
int MatchingFlatSizeSkipDim(const Shape &shape, int skip_dim, const Shape &check_shape_0)
int MatchingElementsSize(const Shape &shape, const Shape &check_shape_0, const Shape &check_shape_1)
int FlatSizeSkipDim(const Shape &shape, int skip_dim)
bool checkMatching(const Shape &shape, Ts... check_shapes)
int MatchingFlatSize(const Shape &shape, Ts... check_shapes)
int MatchingDim(const Shape &shape1, int index1, const Shape &shape2, int index2)
UNUSED_ALL(Args const &...)