ONE - On-device Neural Engine
Loading...
Searching...
No Matches
nnfw::cker::BCastList< N > Class Template Reference

#include <BCast.h>

Collaboration diagram for nnfw::cker::BCastList< N >:

Public Types

typedef std::vector< int32_t > Vec
 

Public Member Functions

 BCastList (const Vec(&x)[N], const bool fewer_dims_optimization=true, const bool return_flattened_batch_indices=false)
 
 ~BCastList ()
 
bool IsValid () const
 
bool IsBroadcastingRequired () const
 
const Vec & reshape (int i) const
 
const Vec & bcast (int i) const
 
const Vec & result_shape () const
 
const Vec & output_shape () const
 
const Vec & grad_reduce_idx (int i) const
 
int32_t output_batch_size () const
 
const std::vector< int32_t > & batch_indices (int i) const
 

Static Protected Member Functions

static void Reverse (Vec *shape)
 

Protected Attributes

bool valid_ = true
 
bool broadcasting_required_ = true
 
Vec reshape_ [N]
 
Vec bcast_ [N]
 
Vec result_
 
Vec output_
 
Vec grad_reduce_idx_ [N]
 
int32_t output_batch_size_
 
std::vector< int32_t > batch_indices_ [N]
 

Detailed Description

template<int N>
class nnfw::cker::BCastList< N >

Definition at line 63 of file BCast.h.

Member Typedef Documentation

◆ Vec

template<int N>
typedef std::vector<int32_t> nnfw::cker::BCastList< N >::Vec

Definition at line 70 of file BCast.h.

Constructor & Destructor Documentation

◆ BCastList()

template<int N>
nnfw::cker::BCastList< N >::BCastList ( const Vec(&x)[N],
const bool  fewer_dims_optimization = true,
const bool  return_flattened_batch_indices = false 
)
explicit

Definition at line 127 of file BCast.h.

129{
130 typedef BCastList::Vec Vec;
131 bool all_equal = true;
132 size_t largest_rank = 0;
133 output_batch_size_ = 1;
134 for (int i = 0; i < N; ++i)
135 {
136 if (x[i] != x[0])
137 {
138 all_equal = false;
139 }
140 if (x[i].size() > largest_rank)
141 {
142 largest_rank = x[i].size();
143 }
144 }
145 if (all_equal)
146 {
147 broadcasting_required_ = false;
148 }
149 if (all_equal && fewer_dims_optimization)
150 {
151 // Fast path for common case of identical shapes.
152 int32_t elements = 1;
153 const int rank = x[0].size();
154 output_.resize(rank);
155 for (int i = 0; i < rank; i++)
156 {
157 const int32_t dim = x[0][i];
158 elements *= dim;
159 output_[i] = dim;
160 }
161 result_.push_back(elements);
162 output_batch_size_ = elements;
163 for (int i = 0; i < N; ++i)
164 {
165 reshape_[i].push_back(elements);
166 bcast_[i].push_back(1);
167 }
168 // grad_reduce_ is left as empty
169 return;
170 }
171
172 // Reverse all the shapes for convenience
173 // After the reverse, 0-th is the inner-most dimension.
174 Vec copy[N];
175 for (int i = 0; i < N; ++i)
176 {
177 copy[i] = x[i];
178 Reverse(&copy[i]);
179 }
180
181 // 1-extend and align all vectors.
182 for (int i = 0; i < N; ++i)
183 {
184 if (copy[i].size() < largest_rank)
185 {
186 copy[i].resize(largest_rank, 1);
187 }
188 }
189 // Going through each dimension starting from the inner-most
190 // dimension, compares dimension of x and y. They are compatible if
191 // they are equal or either is 1.
192
193 // indices of j-th component of each input.
194 bool prev_is_one[N];
195 bool current_is_one[N];
196 for (int i = 0; i < N; ++i)
197 {
198 prev_is_one[i] = false;
199 current_is_one[i] = false;
200 }
201 Vec output;
202 bool output_dim_set = false;
203 int output_dim = -1;
204 bool none_is_one = true;
205 bool set_one = false;
206 for (size_t j = 0; j < largest_rank; ++j)
207 {
208 output_dim = -1;
209 output_dim_set = false;
210 none_is_one = true;
211 // Find which indices are 1.
212 for (int i = 0; i < N; ++i)
213 {
214 // Keep track of which indices are 1.
215 if (copy[i][j] == 1)
216 {
217 current_is_one[i] = true;
218 none_is_one = false;
219 }
220 else
221 {
222 current_is_one[i] = false;
223 if (!output_dim_set || copy[i][j] == output_dim)
224 {
225 output_dim = copy[i][j];
226 output_dim_set = true;
227 }
228 else
229 {
230 valid_ = false;
231 return;
232 }
233 }
234 }
235 output_.push_back(output_dim_set ? output_dim : 1);
236 output_batch_size_ *= output_.back();
237 // All dimensions are 1.
238 if (!output_dim_set)
239 {
240 if (!fewer_dims_optimization)
241 {
242 for (int i = 0; i < N; ++i)
243 {
244 bcast_[i].push_back(1);
245 reshape_[i].push_back(1);
246 }
247 result_.push_back(1);
248 }
249 for (int i = 0; i < N; ++i)
250 {
251 grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
252 }
253 // This will skip updating the previous state to the current one. We'll
254 // explain why this is safe below.
255 // Consider the previous state P, current state C and the next state N.
256 // In the case where N also is all ones (N == C), we'll do the same
257 // optimization here (push back one dimensions if we need to), which is
258 // safe and is expected.
259 //
260 // When N != C, we'll continue as usual. However, we might trigger the
261 // next block if N == P (because we didn't update the previous state).
262 // We trigger the next block if `fewer_dims_optimization` is true.
263 // This means that we did not modify and broadcast / reshapes in this
264 // block (we skipped updating, since the one dimensions can be ignored).
265 // In essence, we only need to check whether the previous non-one state is
266 // equal to the current non-one state.
267
268 continue;
269 }
270 else if ((fewer_dims_optimization) &&
271 std::equal(current_is_one, current_is_one + N, prev_is_one) && set_one)
272 {
273 // It is a run of the same broadcasting case as last time.
274 // We can reshape the input so that fewer dimensions
275 // are involved in the intermediate computation.
276 result_.back() *= output_dim;
277 for (int i = 0; i < N; ++i)
278 {
279 reshape_[i].back() *= copy[i][j];
280 bcast_[i].back() *= current_is_one[i] ? output_dim : 1;
281 if (current_is_one[i] && !none_is_one)
282 {
283 grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
284 }
285 }
286 }
287 else
288 {
289 result_.push_back(output_dim);
290 for (int i = 0; i < N; ++i)
291 {
292 reshape_[i].push_back(copy[i][j]);
293 bcast_[i].push_back(current_is_one[i] ? output_dim : 1);
294 if (current_is_one[i] && !none_is_one)
295 {
296 grad_reduce_idx_[i].push_back(largest_rank - 1 - j);
297 }
298 }
299 }
300 set_one = true;
301 for (int i = 0; i < N; ++i)
302 {
303 prev_is_one[i] = current_is_one[i];
304 }
305 }
306 if (result_.empty())
307 {
308 result_.push_back(1);
309 for (int i = 0; i < N; ++i)
310 {
311 reshape_[i].push_back(1);
312 bcast_[i].push_back(1);
313 }
314 }
315 // Do something about batches.
316 for (int i = 0; i < N; ++i)
317 {
318 Reverse(&reshape_[i]);
319 Reverse(&bcast_[i]);
320 Reverse(&grad_reduce_idx_[i]);
321 }
322 Reverse(&result_);
323 Reverse(&output_);
324 // Only compute batch indices when we need broadcasting, and we aren't doing
325 // needless work (when the output size is 0 or the
326 // return_flattened_batch_indices isn't enabled).
327 if (return_flattened_batch_indices && broadcasting_required_ && output_batch_size_ > 0)
328 {
329 for (int i = 0; i < N; ++i)
330 {
331 ComputeBatchIndices(output_batch_size_, reshape_[i], bcast_[i], &batch_indices_[i]);
332 }
333 }
334}
bool broadcasting_required_
Definition BCast.h:113
static void Reverse(Vec *shape)
Definition BCast.h:123
std::vector< int32_t > Vec
Definition BCast.h:70
int32_t output_batch_size_
Definition BCast.h:120
Vec grad_reduce_idx_[N]
Definition BCast.h:118
std::vector< int32_t > batch_indices_[N]
Definition BCast.h:121
void ComputeBatchIndices(const int32_t output_batch_size, const std::vector< int32_t > &reshape, const std::vector< int32_t > &bcast, std::vector< int32_t > *out_indices)
Definition BCast.h:34
int32_t size[5]
Definition Slice.cpp:35

References nnfw::cker::ComputeBatchIndices(), nnfw::cker::Reverse(), and size.

◆ ~BCastList()

template<int N>
nnfw::cker::BCastList< N >::~BCastList ( )
inline

Definition at line 86 of file BCast.h.

86{}

Member Function Documentation

◆ batch_indices()

template<int N>
const std::vector< int32_t > & nnfw::cker::BCastList< N >::batch_indices ( int  i) const
inline

Definition at line 109 of file BCast.h.

109{ return batch_indices_[i]; }

References nnfw::cker::BCastList< N >::batch_indices_.

◆ bcast()

template<int N>
const Vec & nnfw::cker::BCastList< N >::bcast ( int  i) const
inline

Definition at line 97 of file BCast.h.

97{ return bcast_[i]; }

References nnfw::cker::BCastList< N >::bcast_.

◆ grad_reduce_idx()

template<int N>
const Vec & nnfw::cker::BCastList< N >::grad_reduce_idx ( int  i) const
inline

Definition at line 100 of file BCast.h.

100{ return grad_reduce_idx_[i]; }

References nnfw::cker::BCastList< N >::grad_reduce_idx_.

◆ IsBroadcastingRequired()

template<int N>
bool nnfw::cker::BCastList< N >::IsBroadcastingRequired ( ) const
inline

Definition at line 91 of file BCast.h.

91{ return broadcasting_required_; }

References nnfw::cker::BCastList< N >::broadcasting_required_.

◆ IsValid()

template<int N>
bool nnfw::cker::BCastList< N >::IsValid ( ) const
inline

Definition at line 90 of file BCast.h.

90{ return valid_; }

References nnfw::cker::BCastList< N >::valid_.

Referenced by nnfw::cker::BroadcastTo().

◆ output_batch_size()

template<int N>
int32_t nnfw::cker::BCastList< N >::output_batch_size ( ) const
inline

Definition at line 101 of file BCast.h.

101{ return output_batch_size_; }

References nnfw::cker::BCastList< N >::output_batch_size_.

◆ output_shape()

template<int N>
const Vec & nnfw::cker::BCastList< N >::output_shape ( ) const
inline

Definition at line 99 of file BCast.h.

99{ return output_; }

References nnfw::cker::BCastList< N >::output_.

◆ reshape()

template<int N>
const Vec & nnfw::cker::BCastList< N >::reshape ( int  i) const
inline

Definition at line 96 of file BCast.h.

96{ return reshape_[i]; }

References nnfw::cker::BCastList< N >::reshape_.

◆ result_shape()

template<int N>
const Vec & nnfw::cker::BCastList< N >::result_shape ( ) const
inline

Definition at line 98 of file BCast.h.

98{ return result_; }

References nnfw::cker::BCastList< N >::result_.

◆ Reverse()

template<int N>
static void nnfw::cker::BCastList< N >::Reverse ( Vec * shape)
inline static protected

Definition at line 123 of file BCast.h.

123{ std::reverse(shape->begin(), shape->end()); }

Field Documentation

◆ batch_indices_

template<int N>
std::vector<int32_t> nnfw::cker::BCastList< N >::batch_indices_[N]
protected

Definition at line 121 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::batch_indices().

◆ bcast_

template<int N>
Vec nnfw::cker::BCastList< N >::bcast_[N]
protected

Definition at line 115 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::bcast().

◆ broadcasting_required_

template<int N>
bool nnfw::cker::BCastList< N >::broadcasting_required_ = true
protected

Definition at line 113 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::IsBroadcastingRequired().

◆ grad_reduce_idx_

template<int N>
Vec nnfw::cker::BCastList< N >::grad_reduce_idx_[N]
protected

Definition at line 118 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::grad_reduce_idx().

◆ output_

template<int N>
Vec nnfw::cker::BCastList< N >::output_
protected

Definition at line 117 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::output_shape().

◆ output_batch_size_

template<int N>
int32_t nnfw::cker::BCastList< N >::output_batch_size_
protected

Definition at line 120 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::output_batch_size().

◆ reshape_

template<int N>
Vec nnfw::cker::BCastList< N >::reshape_[N]
protected

Definition at line 114 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::reshape().

◆ result_

template<int N>
Vec nnfw::cker::BCastList< N >::result_
protected

Definition at line 116 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::result_shape().

◆ valid_

template<int N>
bool nnfw::cker::BCastList< N >::valid_ = true
protected

Definition at line 112 of file BCast.h.

Referenced by nnfw::cker::BCastList< N >::IsValid().


The documentation for this class was generated from the following file: