ONE - On-device Neural Engine
Loading...
Searching...
No Matches
nnfw::cker::BCastList< N > Class Template Reference

#include <BCast.h>

Collaboration diagram for nnfw::cker::BCastList< N >:

Public Types

typedef std::vector< int32_t > Vec
 

Public Member Functions

 BCastList (const Vec(&x)[N], const bool fewer_dims_optimization=true, const bool return_flattened_batch_indices=false)
 
 ~BCastList ()
 
bool IsValid () const
 
bool IsBroadcastingRequired () const
 
const Vec & reshape (int i) const
 
const Vec & bcast (int i) const
 
const Vec & result_shape () const
 
const Vec & output_shape () const
 
const Vec & grad_reduce_idx (int i) const
 
int32_t output_batch_size () const
 
const std::vector< int32_t > & batch_indices (int i) const
 

Static Protected Member Functions

static void Reverse (Vec *shape)
 

Protected Attributes

bool valid_ = true
 
bool broadcasting_required_ = true
 
Vec reshape_ [N]
 
Vec bcast_ [N]
 
Vec result_
 
Vec output_
 
Vec grad_reduce_idx_ [N]
 
int32_t output_batch_size_
 
std::vector< int32_t > batch_indices_ [N]
 

Detailed Description

template<int N>
class nnfw::cker::BCastList< N >

Definition at line 69 of file BCast.h.

Member Typedef Documentation

◆ Vec

template<int N>
typedef std::vector<int32_t> nnfw::cker::BCastList< N >::Vec

Definition at line 76 of file BCast.h.

Constructor & Destructor Documentation

◆ BCastList()

template<int N>
nnfw::cker::BCastList< N >::BCastList ( const Vec(&x)[N],
const bool  fewer_dims_optimization = true,
const bool  return_flattened_batch_indices = false 
)
explicit

Definition at line 133 of file BCast.h.

135{
136 typedef BCastList::Vec Vec;
137 bool all_equal = true;
138 size_t largest_rank = 0;
139 output_batch_size_ = 1;
140 for (int i = 0; i < N; ++i)
141 {
142 if (x[i] != x[0])
143 {
144 all_equal = false;
145 }
146 if (x[i].size() > largest_rank)
147 {
148 largest_rank = x[i].size();
149 }
150 }
151 if (all_equal)
152 {
153 broadcasting_required_ = false;
154 }
155 if (all_equal && fewer_dims_optimization)
156 {
157 // Fast path for common case of identical shapes.
158 int32_t elements = 1;
159 const int rank = x[0].size();
160 output_.resize(rank);
161 for (int i = 0; i < rank; i++)
162 {
163 const int32_t dim = x[0][i];
164 elements *= dim;
165 output_[i] = dim;
166 }
167 result_.push_back(elements);
168 output_batch_size_ = elements;
169 for (int i = 0; i < N; ++i)
170 {
171 reshape_[i].push_back(elements);
172 bcast_[i].push_back(1);
173 }
174 // grad_reduce_ is left as empty
175 return;
176 }
177
178 // Reverse all the shapes for convenience
179 // After the reverse, 0-th is the inner-most dimension.
180 Vec copy[N];
181 for (int i = 0; i < N; ++i)
182 {
183 copy[i] = x[i];
184 Reverse(&copy[i]);
185 }
186
187 // 1-extend and align all vectors.
188 for (int i = 0; i < N; ++i)
189 {
190 if (copy[i].size() < largest_rank)
191 {
192 copy[i].resize(largest_rank, 1);
193 }
194 }
195 // Going through each dimension starting from the inner-most
196 // dimension, compares dimension of x and y. They are compatible if
197 // they are equal or either is 1.
198
199 // indices of j-th component of each input.
200 bool prev_is_one[N];
201 bool current_is_one[N];
202 for (int i = 0; i < N; ++i)
203 {
204 prev_is_one[i] = false;
205 current_is_one[i] = false;
206 }
207 Vec output;
208 bool output_dim_set = false;
209 int output_dim = -1;
210 bool none_is_one = true;
211 bool set_one = false;
212 for (size_t j = 0; j < largest_rank; ++j)
213 {
214 output_dim = -1;
215 output_dim_set = false;
216 none_is_one = true;
217 // Find which indices are 1.
218 for (int i = 0; i < N; ++i)
219 {
220 // Keep track of which indices are 1.
221 if (copy[i][j] == 1)
222 {
223 current_is_one[i] = true;
224 none_is_one = false;
225 }
226 else
227 {
228 current_is_one[i] = false;
229 if (!output_dim_set || copy[i][j] == output_dim)
230 {
231 output_dim = copy[i][j];
232 output_dim_set = true;
233 }
234 else
235 {
236 valid_ = false;
237 return;
238 }
239 }
240 }
241 output_.push_back(output_dim_set ? output_dim : 1);
242 output_batch_size_ *= output_.back();
243 // All dimensions are 1.
244 if (!output_dim_set)
245 {
246 if (!fewer_dims_optimization)
247 {
248 for (int i = 0; i < N; ++i)
249 {
250 bcast_[i].push_back(1);
251 reshape_[i].push_back(1);
252 }
253 result_.push_back(1);
254 }
255 for (int i = 0; i < N; ++i)
256 {
257 grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
258 }
259 // This will skip updating the previous state to the current one. We'll
260 // explain why this is safe below.
261 // Consider the previous state P, current state C and the next state N.
262 // In the case where N also is all ones (N == C), we'll do the same
263 // optimization here (push back one dimensions if we need to), which is
264 // safe and is expected.
265 //
266 // When N != C, we'll continue as usual. However, we might trigger the
267 // next block if N == P (because we didn't update the previous state).
268 // We trigger the next block if `fewer_dims_optimization` is true.
269 // This means that we did not modify and broadcast / reshapes in this
270 // block (we skipped updating, since the one dimensions can be ignored).
271 // In essence, we only need to check whether the previous non-one state is
272 // equal to the current non-one state.
273
274 continue;
275 }
276 else if ((fewer_dims_optimization) &&
277 std::equal(current_is_one, current_is_one + N, prev_is_one) && set_one)
278 {
279 // It is a run of the same broadcasting case as last time.
280 // We can reshape the input so that fewer dimensions
281 // are involved in the intermediate computation.
282 result_.back() *= output_dim;
283 for (int i = 0; i < N; ++i)
284 {
285 reshape_[i].back() *= copy[i][j];
286 bcast_[i].back() *= current_is_one[i] ? output_dim : 1;
287 if (current_is_one[i] && !none_is_one)
288 {
289 grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
290 }
291 }
292 }
293 else
294 {
295 result_.push_back(output_dim);
296 for (int i = 0; i < N; ++i)
297 {
298 reshape_[i].push_back(copy[i][j]);
299 bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
300 if (current_is_one[i] && !none_is_one)
301 {
302 grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
303 }
304 }
305 }
306 set_one = true;
307 for (int i = 0; i < N; ++i)
308 {
309 prev_is_one[i] = current_is_one[i];
310 }
311 }
312 if (result_.empty())
313 {
314 result_.push_back(1);
315 for (int i = 0; i < N; ++i)
316 {
317 reshape_[i].push_back(1);
318 bcast_[i].push_back(1);
319 }
320 }
321 // Do something about batches.
322 for (int i = 0; i < N; ++i)
323 {
324 Reverse(&reshape_[i]);
325 Reverse(&bcast_[i]);
326 Reverse(&grad_reduce_idx_[i]);
327 }
328 Reverse(&result_);
329 Reverse(&output_);
330 // Only compute batch indices when we need broadcasting, and we aren't doing
331 // needless work (when the output size is 0 or the
332 // return_flattened_batch_indices isn't enabled).
333 if (return_flattened_batch_indices && broadcasting_required_ && output_batch_size_ > 0)
334 {
335 for (int i = 0; i < N; ++i)
336 {
337 ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i], &batch_indices_[i]);
338 }
339 }
340}
bool broadcasting_required_
Definition BCast.h:119
static void Reverse(Vec *shape)
Definition BCast.h:129
std::vector< int32_t > Vec
Definition BCast.h:76
int32_t output_batch_size_
Definition BCast.h:126
Vec grad_reduce_idx_[N]
Definition BCast.h:124
std::vector< int32_t > batch_indices_[N]
Definition BCast.h:127
void ComputeBatchIndices(const int32_t output_batch_size, const std::vector< int32_t > &reshape, const std::vector< int32_t > &bcast, std::vector< int32_t > *out_indices)
Definition BCast.h:40
int32_t size[5]
Definition Slice.cpp:35

References nnfw::cker::ComputeBatchIndices(), nnfw::cker::Reverse(), and size.

◆ ~BCastList()

template<int N>
nnfw::cker::BCastList< N >::~BCastList ( )
inline

Definition at line 92 of file BCast.h.

92{}

Member Function Documentation

◆ batch_indices()

template<int N>
const std::vector< int32_t > & nnfw::cker::BCastList< N >::batch_indices ( int  i) const
inline

Definition at line 115 of file BCast.h.

115{ return batch_indices_[i]; }

References nnfw::cker::BCastList< N >::batch_indices_.

◆ bcast()

template<int N>
const Vec & nnfw::cker::BCastList< N >::bcast ( int  i) const
inline

Definition at line 103 of file BCast.h.

103{ return bcast_[i]; }

References nnfw::cker::BCastList< N >::bcast_.

◆ grad_reduce_idx()

template<int N>
const Vec & nnfw::cker::BCastList< N >::grad_reduce_idx ( int  i) const
inline

Definition at line 106 of file BCast.h.

106{ return grad_reduce_idx_[i]; }

References nnfw::cker::BCastList< N >::grad_reduce_idx_.

◆ IsBroadcastingRequired()

template<int N>
bool nnfw::cker::BCastList< N >::IsBroadcastingRequired ( ) const
inline

Definition at line 97 of file BCast.h.

97{ return broadcasting_required_; }

References nnfw::cker::BCastList< N >::broadcasting_required_.

◆ IsValid()

template<int N>
bool nnfw::cker::BCastList< N >::IsValid ( ) const
inline

Definition at line 96 of file BCast.h.

96{ return valid_; }

References nnfw::cker::BCastList< N >::valid_.

Referenced by nnfw::cker::BroadcastTo().

◆ output_batch_size()

template<int N>
int32_t nnfw::cker::BCastList< N >::output_batch_size ( ) const
inline

Definition at line 107 of file BCast.h.

107{ return output_batch_size_; }

References nnfw::cker::BCastList< N >::output_batch_size_.

◆ output_shape()

template<int N>
const Vec & nnfw::cker::BCastList< N >::output_shape ( ) const
inline

Definition at line 105 of file BCast.h.

105{ return output_; }

References nnfw::cker::BCastList< N >::output_.

◆ reshape()

template<int N>
const Vec & nnfw::cker::BCastList< N >::reshape ( int  i) const
inline

Definition at line 102 of file BCast.h.

102{ return reshape_[i]; }

References nnfw::cker::BCastList< N >::reshape_.

◆ result_shape()

template<int N>
const Vec & nnfw::cker::BCastList< N >::result_shape ( ) const
inline

Definition at line 104 of file BCast.h.

104{ return result_; }

References nnfw::cker::BCastList< N >::result_.

◆ Reverse()

template<int N>
static void nnfw::cker::BCastList< N >::Reverse ( Vec *  shape)
inline static protected

Definition at line 129 of file BCast.h.

129{ std::reverse(shape->begin(), shape->end()); }

Field Documentation

◆ batch_indices_

template<int N>
std::vector<int32_t> nnfw::cker::BCastList< N >::batch_indices_[N]
protected

Definition at line 127 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::batch_indices().

◆ bcast_

template<int N>
Vec nnfw::cker::BCastList< N >::bcast_[N]
protected

Definition at line 121 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::bcast().

◆ broadcasting_required_

template<int N>
bool nnfw::cker::BCastList< N >::broadcasting_required_ = true
protected

Definition at line 119 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::IsBroadcastingRequired().

◆ grad_reduce_idx_

template<int N>
Vec nnfw::cker::BCastList< N >::grad_reduce_idx_[N]
protected

Definition at line 124 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::grad_reduce_idx().

◆ output_

template<int N>
Vec nnfw::cker::BCastList< N >::output_
protected

Definition at line 123 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::output_shape().

◆ output_batch_size_

template<int N>
int32_t nnfw::cker::BCastList< N >::output_batch_size_
protected

Definition at line 126 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::output_batch_size().

◆ reshape_

template<int N>
Vec nnfw::cker::BCastList< N >::reshape_[N]
protected

Definition at line 120 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::reshape().

◆ result_

template<int N>
Vec nnfw::cker::BCastList< N >::result_
protected

Definition at line 122 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::result_shape().

◆ valid_

template<int N>
bool nnfw::cker::BCastList< N >::valid_ = true
protected

Definition at line 118 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::IsValid().


The documentation for this class was generated from the following file: