ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Array.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef _NDARRAY_ARRAY_H_
18#define _NDARRAY_ARRAY_H_
19
20#include "Common.h"
21
22#include "ContiguousSpan.h"
23#include "Shape.h"
24
25#include <algorithm>
26#include <array>
27#include <cassert>
28#include <cstddef>
29#include <tuple>
30#include <type_traits>
31#include <utility>
32
33namespace ndarray
34{
35
/// Local alias for std::index_sequence so the rest of this header can name it unqualified.
template <size_t... Nums> using index_sequence = std::index_sequence<Nums...>;

/// Local alias for std::make_index_sequence.
/// The parameter is named `Num` because identifiers that start with an underscore followed by
/// an uppercase letter (the former `_Num`) are reserved for the implementation.
template <size_t Num> using make_index_sequence = std::make_index_sequence<Num>;
39
40struct Strides
41{
42 explicit Strides(Shape s) : _strides{} { fillStrides(s); }
43
44 int operator[](size_t idx) const noexcept { return _strides[idx]; }
45
46 // since we don't have c++14 fold expression
47 template <typename Seq, typename... Ts> struct _calc_offset;
48
49 template <size_t Num, size_t... Nums, typename T, typename... Ts>
50 struct _calc_offset<index_sequence<Num, Nums...>, T, Ts...>
51 {
52 static constexpr size_t get(const std::array<int, 8> &strides, int x, Ts... xs)
53 {
54 return _calc_offset<index_sequence<Nums...>, Ts...>::get(strides, xs...) +
55 x * std::get<Num>(strides);
56 }
57 };
58
59 template <size_t Num, typename T> struct _calc_offset<index_sequence<Num>, T>
60 {
61 static constexpr size_t get(const std::array<int, 8> &strides, int x)
62 {
63 return x * std::get<Num>(strides);
64 }
65 };
66
67 template <typename Seq, typename... Ts> constexpr size_t offset(Seq, Ts... x) const noexcept
68 {
69 // return ( 0 + ... + (std::get<Nums>(_strides) * x)); in c++14
70 return _calc_offset<Seq, Ts...>::get(_strides, x...);
71 }
72
73private:
74 void fillStrides(const Shape &s) noexcept
75 {
76 int rank = s.rank();
77 _strides[rank - 1] = 1;
78 for (int d = rank - 2; d >= 0; --d)
79 {
80 _strides[d] = _strides[d + 1] * s.dim(d + 1);
81 }
82 }
83
84 std::array<int, NDARRAY_MAX_DIMENSION_COUNT> _strides;
85};
86
87template <typename T> class Array
88{
89public:
90 Array(T *data, Shape shape) noexcept : _data(data), _shape(shape), _strides(shape) {}
91
92 Array(const Array &) = delete;
93
94 Array(Array &&a) noexcept : _data(a._data), _shape(a._shape), _strides(a._strides)
95 {
96 a._data = nullptr;
97 }
98
99 template <typename... Ts> T &at(Ts... x) const noexcept { return _at(static_cast<size_t>(x)...); }
100
106 template <typename... Ts> ContiguousSpan<T, std::is_const<T>::value> slice(Ts... x) noexcept
107 {
108 assert(sizeof...(Ts) == _shape.rank() - 1);
109 return {&at(x..., 0ul), _shape.dim(_shape.rank() - 1)};
110 }
111
117 template <typename... Ts> ContiguousSpan<T, true> slice(Ts... x) const noexcept
118 {
119 assert(sizeof...(Ts) == _shape.rank() - 1);
120 return {&at(x..., 0ul), _shape.dim(_shape.rank() - 1)};
121 }
122
124 {
125 return {_data, _shape.element_count()};
126 }
127
128 ContiguousSpan<T, true> flat() const noexcept { return {_data, _shape.element_count()}; }
129
130 const Shape &shape() const noexcept { return _shape; }
131
132private:
133 template <typename... Ts> T &_at(Ts... x) const noexcept
134 {
135 assert(sizeof...(x) == _shape.rank());
136 using Indices = make_index_sequence<sizeof...(Ts)>;
137 return _data[offset(Indices{}, x...)];
138 }
139
140 template <typename... Ts, size_t... Nums>
141 size_t offset(index_sequence<Nums...> seq, Ts... x) const noexcept
142 {
143 static_assert(
144 sizeof...(Ts) == sizeof...(Nums),
145 "Sanity check failed. Generated index sequence size is not equal to argument count");
146
147 return _strides.offset(seq, x...);
148 }
149
150 T *_data;
151 Shape _shape;
152 Strides _strides;
153};
154
155template <typename To, typename From> Array<To> array_cast(Array<From> &&from, Shape newShape)
156{
157 assert(from.shape().element_count() / (sizeof(To) / sizeof(From)) == newShape.element_count());
158 return Array<To>(reinterpret_cast<To *>(from.flat().data()), newShape);
159}
160
161template <typename To, typename From>
163{
164 assert(from.shape().element_count() / (sizeof(To) / sizeof(From)) == newShape.element_count());
165 return Array<To>(reinterpret_cast<const To *>(from.flat().data()), newShape);
166}
167
// Unless NDARRAY_INLINE_TEMPLATES is defined, declare the Array instantiations below as
// `extern template`, so translation units including this header do not instantiate them
// implicitly. Presumably the matching explicit instantiations live in a source file of this
// library - not visible here, confirm.
168#ifndef NDARRAY_INLINE_TEMPLATES
169
170extern template class Array<float>;
171extern template class Array<int32_t>;
172extern template class Array<uint32_t>;
173extern template class Array<uint8_t>;
174
175#endif // NDARRAY_INLINE_TEMPLATES
176
177} // namespace ndarray
178
179#endif //_NDARRAY_ARRAY_H_
T & at(Ts... x) const noexcept
Definition Array.h:99
const Shape & shape() const noexcept
Definition Array.h:130
ContiguousSpan< T, true > flat() const noexcept
Definition Array.h:128
ContiguousSpan< T, std::is_const< T >::value > slice(Ts... x) noexcept
Returns the last dimension as a ContiguousSpan
Definition Array.h:106
ContiguousSpan< T, true > slice(Ts... x) const noexcept
Returns the last dimension as a read-only ContiguousSpan
Definition Array.h:117
Array(T *data, Shape shape) noexcept
Definition Array.h:90
Array(const Array &)=delete
Array(Array &&a) noexcept
Definition Array.h:94
ContiguousSpan< T, std::is_const< T >::value > flat() noexcept
Definition Array.h:123
size_t dim(int i) const noexcept
Definition Shape.h:44
size_t rank() const noexcept
Definition Shape.h:57
size_t element_count() const noexcept
Definition Shape.h:48
Array< To > array_cast(Array< From > &&from, Shape newShape)
Definition Array.h:155
std::make_index_sequence< _Num > make_index_sequence
Definition Array.h:38
std::index_sequence< Nums... > index_sequence
Definition Array.h:36
Definition Shape.h:28
static constexpr size_t get(const std::array< int, 8 > &strides, int x, Ts... xs)
Definition Array.h:52
static constexpr size_t get(const std::array< int, 8 > &strides, int x)
Definition Array.h:61
Strides(Shape s)
Definition Array.h:42
int operator[](size_t idx) const noexcept
Definition Array.h:44
constexpr size_t offset(Seq, Ts... x) const noexcept
Definition Array.h:67