ONE - On-device Neural Engine
PALreference_ops.h
1/*
2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
3 * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18#ifndef LUCI_INTERPRETER_PAL_REFERENCE_OPS_H
19#define LUCI_INTERPRETER_PAL_REFERENCE_OPS_H
20
21#include <stdint.h>
22#include <sys/types.h>
23
24#include <algorithm>
25#include <cmath>
26#include <cstring>
27#include <functional>
28#include <limits>
29#include <memory>
30#include <type_traits>
31
32#include "third_party/eigen3/Eigen/Core"
33#include "fixedpoint/fixedpoint.h"
34#include "ruy/profiler/instrumentation.h" // from @ruy
35#include "tensorflow/lite/c/common.h"
36#include "tensorflow/lite/kernels/internal/common.h"
37#include "tensorflow/lite/kernels/internal/quantization_util.h"
38#include "tensorflow/lite/kernels/internal/reference/add.h"
39#include "tensorflow/lite/kernels/internal/reference/add_n.h"
40#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
41#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"
42#include "tensorflow/lite/kernels/internal/reference/batch_to_space_nd.h"
43#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
44#include "tensorflow/lite/kernels/internal/reference/cast.h"
45#include "tensorflow/lite/kernels/internal/reference/ceil.h"
46#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
47#include "tensorflow/lite/kernels/internal/reference/concatenation.h"
48#include "tensorflow/lite/kernels/internal/reference/conv.h"
49#include "tensorflow/lite/kernels/internal/reference/depth_to_space.h"
50#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
51#include "tensorflow/lite/kernels/internal/reference/div.h"
52#include "tensorflow/lite/kernels/internal/reference/elu.h"
53#include "tensorflow/lite/kernels/internal/reference/exp.h"
54#include "tensorflow/lite/kernels/internal/reference/fill.h"
55#include "tensorflow/lite/kernels/internal/reference/floor.h"
56#include "tensorflow/lite/kernels/internal/reference/floor_div.h"
57#include "tensorflow/lite/kernels/internal/reference/floor_mod.h"
58#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
59#include "tensorflow/lite/kernels/internal/reference/gather.h"
60#include "tensorflow/lite/kernels/internal/reference/hard_swish.h"
61#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
62#include "tensorflow/lite/kernels/internal/reference/leaky_relu.h"
63#include "tensorflow/lite/kernels/internal/reference/log_softmax.h"
64#include "tensorflow/lite/kernels/internal/reference/logistic.h"
65#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"
66#include "tensorflow/lite/kernels/internal/reference/mul.h"
67#include "tensorflow/lite/kernels/internal/reference/neg.h"
68#include "tensorflow/lite/kernels/internal/reference/pad.h"
69#include "tensorflow/lite/kernels/internal/reference/pooling.h"
70#include "tensorflow/lite/kernels/internal/reference/prelu.h"
71#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
72#include "tensorflow/lite/kernels/internal/reference/quantize.h"
73#include "tensorflow/lite/kernels/internal/reference/reduce.h"
74#include "tensorflow/lite/kernels/internal/reference/requantize.h"
75#include "tensorflow/lite/kernels/internal/reference/resize_bilinear.h"
76#include "tensorflow/lite/kernels/internal/reference/resize_nearest_neighbor.h"
77#include "tensorflow/lite/kernels/internal/reference/round.h"
78#include "tensorflow/lite/kernels/internal/reference/softmax.h"
79#include "tensorflow/lite/kernels/internal/reference/space_to_batch_nd.h"
80#include "tensorflow/lite/kernels/internal/reference/space_to_depth.h"
81#include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
82#include "tensorflow/lite/kernels/internal/reference/string_comparisons.h"
83#include "tensorflow/lite/kernels/internal/reference/sub.h"
84#include "tensorflow/lite/kernels/internal/reference/tanh.h"
85#include "tensorflow/lite/kernels/internal/reference/transpose.h"
86#include "tensorflow/lite/kernels/internal/reference/transpose_conv.h"
87#include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
88#include "tensorflow/lite/kernels/internal/tensor.h"
89#include "tensorflow/lite/kernels/internal/types.h"
90namespace tflite
91{
92
93namespace reference_ops
94{
95
96template <typename T>
97inline void Relu(const RuntimeShape &input_shape, const T *input_data,
98 const RuntimeShape &output_shape, T *output_data)
99{
100 const int flat_size = MatchingFlatSize(input_shape, output_shape);
101 for (int i = 0; i < flat_size; ++i)
102 {
103 const T val = input_data[i];
104 const T lower = 0;
105 const T clamped = val < lower ? lower : val;
106 output_data[i] = clamped;
107 }
108}
109
110template <typename T>
111inline void Relu1(const RuntimeShape &input_shape, const T *input_data,
112 const RuntimeShape &output_shape, T *output_data)
113{
114 ruy::profiler::ScopeLabel label("Relu1 (not fused)");
115 const int flat_size = MatchingFlatSize(input_shape, output_shape);
116 for (int i = 0; i < flat_size; ++i)
117 {
118 const T val = input_data[i];
119 const T upper = 1;
120 const T lower = -1;
121 const T clamped = val > upper ? upper : val < lower ? lower : val;
122 output_data[i] = clamped;
123 }
124}
125
126inline void Relu6(const RuntimeShape &input_shape, const float *input_data,
127 const RuntimeShape &output_shape, float *output_data)
128{
129 ruy::profiler::ScopeLabel label("Relu6 (not fused)");
130 const int flat_size = MatchingFlatSize(input_shape, output_shape);
131 for (int i = 0; i < flat_size; ++i)
132 {
133 const float val = input_data[i];
134 const float upper = 6;
135 const float lower = 0;
136 const float clamped = val > upper ? upper : val < lower ? lower : val;
137 output_data[i] = clamped;
138 }
139}
140
141template <typename T>
142inline void ReluX(const tflite::ReluParams &params, const RuntimeShape &input_shape,
143 const T *input_data, const RuntimeShape &output_shape, T *output_data)
144{
145 ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
146 const int flat_size = MatchingFlatSize(input_shape, output_shape);
147 for (int i = 0; i < flat_size; ++i)
148 {
149 const int32 val = static_cast<int32_t>(input_data[i]);
150 int32 clamped = params.output_offset + MultiplyByQuantizedMultiplier(val - params.input_offset,
151 params.output_multiplier,
152 params.output_shift);
153 clamped = std::max(params.quantized_activation_min, clamped);
154 clamped = std::min(params.quantized_activation_max, clamped);
155 output_data[i] = static_cast<T>(clamped);
156 }
157}
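// Hedged illustration (not part of the upstream header): the quantized ReluX above
// performs an affine requantization,
//   q_out = output_offset + round((q_in - input_offset) * scale_in / scale_out),
// with the scale ratio scale_in / scale_out pre-baked into output_multiplier and
// output_shift, followed by the activation clamp. A float sketch of that mapping,
// using hypothetical scale parameters:
inline int32_t RequantizeSketch(int32_t q_in, float scale_in, int32_t input_offset,
                                float scale_out, int32_t output_offset)
{
  // Dequantize to a real value, then quantize into the output scale/offset.
  const float real_value = scale_in * static_cast<float>(q_in - input_offset);
  return output_offset + static_cast<int32_t>(std::round(real_value / scale_out));
}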
158
159template <typename T>
160inline void ReluX(const tflite::ActivationParams &params, const RuntimeShape &input_shape,
161 const T *input_data, const RuntimeShape &output_shape, T *output_data)
162{
163 ruy::profiler::ScopeLabel label("Quantized ReluX (not fused)");
164 const int flat_size = MatchingFlatSize(input_shape, output_shape);
165 const T max_value = params.quantized_activation_max;
166 const T min_value = params.quantized_activation_min;
167 for (int i = 0; i < flat_size; ++i)
168 {
169 const T val = input_data[i];
170 const T clamped = val > max_value ? max_value : val < min_value ? min_value : val;
171 output_data[i] = clamped;
172 }
173}
174
175// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
176// dimensionality if the runtime code does a single loop over one dimension
177// that handles broadcasting as the base case. The code generator would then
178// generate max(D1, D2) nested for loops.
179inline void BroadcastMulFivefold(const ArithmeticParams &unswitched_params,
180 const RuntimeShape &unswitched_input1_shape,
181 const uint8 *unswitched_input1_data,
182 const RuntimeShape &unswitched_input2_shape,
183 const uint8 *unswitched_input2_data,
184 const RuntimeShape &output_shape, uint8 *output_data)
185{
186 ArithmeticParams switched_params = unswitched_params;
187 switched_params.input1_offset = unswitched_params.input2_offset;
188 switched_params.input2_offset = unswitched_params.input1_offset;
189
190 const bool use_unswitched = unswitched_params.broadcast_category ==
191 tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
192
193 const ArithmeticParams &params = use_unswitched ? unswitched_params : switched_params;
194 const uint8 *input1_data = use_unswitched ? unswitched_input1_data : unswitched_input2_data;
195 const uint8 *input2_data = use_unswitched ? unswitched_input2_data : unswitched_input1_data;
196
197 // Fivefold nested loops. The second input resets its position for each
198 // iteration of the second loop. The first input resets its position at the
199 // beginning of the fourth loop. The innermost loop is an elementwise Mul of
200 // sections of the arrays.
201 uint8 *output_data_ptr = output_data;
202 const uint8 *input1_data_ptr = input1_data;
203 const uint8 *input2_data_reset = input2_data;
204 int y0 = params.broadcast_shape[0];
205 int y1 = params.broadcast_shape[1];
206 int y2 = params.broadcast_shape[2];
207 int y3 = params.broadcast_shape[3];
208 int y4 = params.broadcast_shape[4];
209 for (int i0 = 0; i0 < y0; ++i0)
210 {
211 const uint8 *input2_data_ptr;
212 for (int i1 = 0; i1 < y1; ++i1)
213 {
214 input2_data_ptr = input2_data_reset;
215 for (int i2 = 0; i2 < y2; ++i2)
216 {
217 for (int i3 = 0; i3 < y3; ++i3)
218 {
219 MulElementwise(y4, params, input1_data_ptr, input2_data_ptr, output_data_ptr);
220 input2_data_ptr += y4;
221 output_data_ptr += y4;
222 }
223 input1_data_ptr += y4;
224 }
225 }
226 input2_data_reset = input2_data_ptr;
227 }
228}
229
230inline void Mul(const ArithmeticParams &params, const RuntimeShape &input1_shape,
231 const int16 *input1_data, const RuntimeShape &input2_shape,
232 const int16 *input2_data, const RuntimeShape &output_shape, int16 *output_data)
233{
234 ruy::profiler::ScopeLabel label("Mul/Int16");
235
236 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
237
238 for (int i = 0; i < flat_size; i++)
239 {
240 // F0 uses 0 integer bits, range [-1, 1].
241 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
242
243 F0 unclamped_result = F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
244 output_data[i] = unclamped_result.raw();
245 }
246}
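// Hedged illustration (not part of the upstream header): in the Q0.15 "F0" format
// used above, a raw int16 value r represents r / 32768, and the gemmlowp product is
// a rounding-doubling high multiply. For example, 0.5 * 0.5 works out as
//   raw(0.5) = 16384,  (2 * 16384 * 16384) >> 16 = 8192,  8192 / 32768 = 0.25.
// A plain-integer sketch of that product (rounding details simplified; the real
// kernel saturates the -1 * -1 corner case, which this sketch clamps as well):
inline int16_t Q015MulSketch(int16_t a, int16_t b)
{
  const int64_t doubled = 2 * static_cast<int64_t>(a) * static_cast<int64_t>(b);
  const int64_t rounded = (doubled + (1 << 15)) >> 16; // round to nearest
  return static_cast<int16_t>(
    std::min<int64_t>(32767, std::max<int64_t>(-32768, rounded)));
}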
247
248inline void Mul(const ArithmeticParams &params, const RuntimeShape &input1_shape,
249 const int16 *input1_data, const RuntimeShape &input2_shape,
250 const int16 *input2_data, const RuntimeShape &output_shape, uint8 *output_data)
251{
252 ruy::profiler::ScopeLabel label("Mul/Int16Uint8");
253 int32 output_offset = params.output_offset;
254 int32 output_activation_min = params.quantized_activation_min;
255 int32 output_activation_max = params.quantized_activation_max;
256 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
257
258 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
259
260 for (int i = 0; i < flat_size; i++)
261 {
262 // F0 uses 0 integer bits, range [-1, 1].
263 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
264
265 F0 unclamped_result = F0::FromRaw(input1_data[i]) * F0::FromRaw(input2_data[i]);
266 int16 rescaled_result = gemmlowp::RoundingDivideByPOT(unclamped_result.raw(), 8);
267 int16 clamped_result = std::min<int16>(output_activation_max - output_offset, rescaled_result);
268 clamped_result = std::max<int16>(output_activation_min - output_offset, clamped_result);
269 output_data[i] = output_offset + clamped_result;
270 }
271}
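// Hedged note: RoundingDivideByPOT(raw, 8) above drops 8 of the 15 fractional bits,
// so the Q0.15 product in [-1, 1) lands in roughly [-128, 128) before the activation
// clamp and the uint8 output_offset are applied.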
272
273inline void Sub16(const ArithmeticParams &params, const RuntimeShape &input1_shape,
274 const int16_t *input1_data, const RuntimeShape &input2_shape,
275 const int16_t *input2_data, const RuntimeShape &output_shape,
276 int16_t *output_data)
277{
278 ruy::profiler::ScopeLabel label("Sub/Int16");
279 const int input1_shift = params.input1_shift;
280 const int flat_size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
281 const int16 output_activation_min = params.quantized_activation_min;
282 const int16 output_activation_max = params.quantized_activation_max;
283
284 TFLITE_DCHECK(input1_shift == 0 || params.input2_shift == 0);
285 TFLITE_DCHECK_LE(input1_shift, 0);
286 TFLITE_DCHECK_LE(params.input2_shift, 0);
287 const int16 *not_shift_input = input1_shift == 0 ? input1_data : input2_data;
288 const int16 *shift_input = input1_shift == 0 ? input2_data : input1_data;
289 const int input_right_shift = input1_shift == 0 ? -params.input2_shift : -input1_shift;
290
291 if (input1_shift == 0)
292 {
293 // F0 uses 0 integer bits, range [-1, 1].
294 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
295 for (int i = 0; i < flat_size; ++i)
296 {
297 F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
298 F0 scaled_input =
299 F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
300 F0 result = SaturatingSub(input_ready_scaled, scaled_input);
301 const int16 raw_output = result.raw();
302 const int16 clamped_output =
303 std::min(output_activation_max, std::max(output_activation_min, raw_output));
304 output_data[i] = clamped_output;
305 }
306 }
307 else
308 {
309 // F0 uses 0 integer bits, range [-1, 1].
310 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
311 for (int i = 0; i < flat_size; ++i)
312 {
313 F0 input_ready_scaled = F0::FromRaw(not_shift_input[i]);
314 F0 scaled_input =
315 F0::FromRaw(gemmlowp::RoundingDivideByPOT(shift_input[i], input_right_shift));
316 F0 result = SaturatingSub(scaled_input, input_ready_scaled);
317 const int16 raw_output = result.raw();
318 const int16 clamped_output =
319 std::min(output_activation_max, std::max(output_activation_min, raw_output));
320 output_data[i] = clamped_output;
321 }
322 }
323}
324
325template <typename Scalar>
326void Pack(const PackParams &params, const RuntimeShape *const *input_shapes,
327 const Scalar *const *input_data, const RuntimeShape &output_shape, Scalar *output_data)
328{
329 ruy::profiler::ScopeLabel label("Pack");
330 const int dimensions = output_shape.DimensionsCount();
331 int axis = params.axis;
332 int inputs_count = params.inputs_count;
333
334 int outer_size = 1;
335 for (int i = 0; i < axis; i++)
336 {
337 outer_size *= output_shape.Dims(i);
338 }
339 int copy_size = 1;
340 for (int i = params.axis + 1; i < dimensions; i++)
341 {
342 copy_size *= output_shape.Dims(i);
343 }
344 TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
345
346 for (int i = 0; i < inputs_count; ++i)
347 {
348 for (int k = 0; k < outer_size; k++)
349 {
350 const Scalar *input_ptr = input_data[i] + copy_size * k;
351 int loc = k * inputs_count * copy_size + i * copy_size;
352 memcpy(output_data + loc, input_ptr, copy_size * sizeof(Scalar));
353 }
354 }
355}
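// Worked example (hedged, hypothetical shapes): packing inputs_count = 3 tensors of
// shape [2, 4] along axis = 1 produces an output of shape [2, 3, 4]. Then
// outer_size = 2 (product of dims before the axis) and copy_size = 4 (product of
// dims after it), and outer slice k of input i is copied to output offset
//   k * inputs_count * copy_size + i * copy_size
// e.g. input 2, outer index 1 lands at 1 * 3 * 4 + 2 * 4 = 20.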
356
357template <typename Scalar>
358void Unpack(const UnpackParams &params, const RuntimeShape &input_shape, const Scalar *input_data,
359 const RuntimeShape &output_shape, Scalar *const *output_datas)
360{
361 ruy::profiler::ScopeLabel label("Unpack");
362 const int dimensions = input_shape.DimensionsCount();
363 const int outputs_count = params.num_split;
364
365 int outer_size = 1;
366 int axis = params.axis;
367 if (axis < 0)
368 {
369 axis += dimensions;
370 }
371 TFLITE_DCHECK_GE(axis, 0);
372 TFLITE_DCHECK_LT(axis, dimensions);
373 for (int i = 0; i < axis; ++i)
374 {
375 outer_size *= input_shape.Dims(i);
376 }
377 int copy_size = 1;
378 for (int i = axis + 1; i < dimensions; ++i)
379 {
380 copy_size *= input_shape.Dims(i);
381 }
382 TFLITE_DCHECK_EQ(output_shape.FlatSize(), copy_size * outer_size);
383
384 for (int i = 0; i < outputs_count; ++i)
385 {
386 for (int k = 0; k < outer_size; k++)
387 {
388 Scalar *output_ptr = output_datas[i] + copy_size * k;
389 int loc = k * outputs_count * copy_size + i * copy_size;
390 memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
391 }
392 }
393}
394
395template <typename Scalar>
396void PackWithScaling(const PackParams &params, const RuntimeShape *const *input_shapes,
397 const uint8 *const *input_data, const RuntimeShape &output_shape,
398 uint8 *output_data)
399{
400 ruy::profiler::ScopeLabel label("PackWithScaling");
401 const int dimensions = output_shape.DimensionsCount();
402 int axis = params.axis;
403 const int32 *input_zeropoint = params.input_zeropoint;
404 const float *input_scale = params.input_scale;
405 int inputs_count = params.inputs_count;
406 const int32 output_zeropoint = params.output_zeropoint;
407 const float output_scale = params.output_scale;
408
409 int outer_size = 1;
410 for (int i = 0; i < axis; i++)
411 {
412 outer_size *= output_shape.Dims(i);
413 }
414 int copy_size = 1;
415 for (int i = axis + 1; i < dimensions; i++)
416 {
417 copy_size *= output_shape.Dims(i);
418 }
419 TFLITE_DCHECK_EQ((**input_shapes).FlatSize(), copy_size * outer_size);
420
421 Scalar *output_ptr = output_data;
422 const float inverse_output_scale = 1.f / output_scale;
423 for (int k = 0; k < outer_size; k++)
424 {
425 for (int i = 0; i < inputs_count; ++i)
426 {
427 if (input_zeropoint[i] == output_zeropoint && input_scale[i] == output_scale)
428 {
429 memcpy(output_ptr, input_data[i] + k * copy_size, copy_size * sizeof(Scalar));
430 }
431 else
432 {
433 assert(false);
434 const float scale = input_scale[i] * inverse_output_scale;
435 const float bias = -input_zeropoint[i] * scale;
436 auto input_ptr = input_data[i];
437 for (int j = 0; j < copy_size; ++j)
438 {
439 const int value =
440 static_cast<int32_t>(std::round(input_ptr[j] * scale + bias)) + output_zeropoint;
441 output_ptr[j] = static_cast<uint8_t>(std::max(std::min(255, value), 0));
442 }
443 }
444 output_ptr += copy_size;
445 }
446 }
447}
448
449template <typename Scalar>
450void DepthConcatenation(const ConcatenationParams &params, const RuntimeShape *const *input_shapes,
451 const Scalar *const *input_data, const RuntimeShape &output_shape,
452 Scalar *output_data)
453{
454 ruy::profiler::ScopeLabel label("DepthConcatenation");
455 auto params_copy = params;
456 params_copy.axis = 3;
457 Concatenation(params_copy, input_shapes, input_data, output_shape, output_data);
458}
459
460inline void LstmCell(const LstmCellParams &params, const RuntimeShape &unextended_input_shape,
461 const float *input_data, const RuntimeShape &unextended_prev_activ_shape,
462 const float *prev_activ_data, const RuntimeShape &weights_shape,
463 const float *weights_data, const RuntimeShape &unextended_bias_shape,
464 const float *bias_data, const RuntimeShape &unextended_prev_state_shape,
465 const float *prev_state_data,
466 const RuntimeShape &unextended_output_state_shape, float *output_state_data,
467 const RuntimeShape &unextended_output_activ_shape, float *output_activ_data,
468 const RuntimeShape &unextended_concat_temp_shape, float *concat_temp_data,
469 const RuntimeShape &unextended_activ_temp_shape, float *activ_temp_data)
470{
471 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
472 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
473 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
474 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
475 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
476 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
477 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
478 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
479 const RuntimeShape input_shape = RuntimeShape::ExtendedShape(4, unextended_input_shape);
480 const RuntimeShape prev_activ_shape = RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
481 const RuntimeShape bias_shape = RuntimeShape::ExtendedShape(4, unextended_bias_shape);
482 const RuntimeShape prev_state_shape = RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
483 const RuntimeShape output_state_shape =
484 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
485 const RuntimeShape output_activ_shape =
486 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
487 const RuntimeShape concat_temp_shape =
488 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
489 const RuntimeShape activ_temp_shape = RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
490 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
491
492 const int weights_dim_count = weights_shape.DimensionsCount();
493 const int batches = MatchingDim(input_shape, 0, prev_activ_shape, 0, prev_state_shape, 0,
494 output_state_shape, 0, output_activ_shape, 0);
495 const int height = MatchingDim(input_shape, 1, prev_activ_shape, 1, prev_state_shape, 1,
496 output_state_shape, 1, output_activ_shape, 1);
497 const int width = MatchingDim(input_shape, 2, prev_activ_shape, 2, prev_state_shape, 2,
498 output_state_shape, 2, output_activ_shape, 2);
499 const int input_depth = input_shape.Dims(3);
500 const int prev_activ_depth = prev_activ_shape.Dims(3);
501 const int total_input_depth = prev_activ_depth + input_depth;
502 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1), total_input_depth);
503 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
504 const int intern_activ_depth = MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
505 TFLITE_DCHECK_EQ(weights_shape.FlatSize(), intern_activ_depth * total_input_depth);
506 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
507 const int output_depth = MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
508 3, output_activ_shape, 3);
509 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
510
511 // Concatenate prev_activ and input data together
512 std::vector<float const *> concat_input_arrays_data;
513 std::vector<RuntimeShape const *> concat_input_arrays_shapes;
514 concat_input_arrays_data.push_back(input_data);
515 concat_input_arrays_data.push_back(prev_activ_data);
516 concat_input_arrays_shapes.push_back(&input_shape);
517 concat_input_arrays_shapes.push_back(&prev_activ_shape);
518 tflite::ConcatenationParams concat_params;
519 concat_params.axis = 3;
520 concat_params.inputs_count = concat_input_arrays_data.size();
521 Concatenation(concat_params, &(concat_input_arrays_shapes[0]), &(concat_input_arrays_data[0]),
522 concat_temp_shape, concat_temp_data);
523
524 // Fully connected
525 tflite::FullyConnectedParams fc_params;
526 fc_params.float_activation_min = std::numeric_limits<float>::lowest();
527 fc_params.float_activation_max = std::numeric_limits<float>::max();
528 FullyConnected(fc_params, concat_temp_shape, concat_temp_data, weights_shape, weights_data,
529 bias_shape, bias_data, activ_temp_shape, activ_temp_data);
530
531 // Memory state update (the LSTM "guts")
532 for (int b = 0; b < batches; ++b)
533 {
534 for (int w = 0; w < width; ++w)
535 {
536 for (int h = 0; h < height; ++h)
537 {
538 for (int c = 0; c < output_depth; ++c)
539 {
540 const float input_gate =
541 1.f /
542 (1.f +
543 std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w, 0 * output_depth + c)]));
544 const float new_input =
545 std::tanh(activ_temp_data[Offset(activ_temp_shape, b, h, w, 1 * output_depth + c)]);
546 const float forget_gate =
547 1.f /
548 (1.f +
549 std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w, 2 * output_depth + c)]));
550 const float output_gate =
551 1.f /
552 (1.f +
553 std::exp(-activ_temp_data[Offset(activ_temp_shape, b, h, w, 3 * output_depth + c)]));
554 const float new_state =
555 input_gate * new_input +
556 forget_gate * prev_state_data[Offset(prev_state_shape, b, h, w, c)];
557 output_state_data[Offset(output_state_shape, b, h, w, c)] = new_state;
558 output_activ_data[Offset(output_activ_shape, b, h, w, c)] =
559 output_gate * std::tanh(new_state);
560 }
561 }
562 }
563 }
564}
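// Hedged summary of the loop above: it is the standard (peephole-free) LSTM update
//   i = sigmoid(a_i), g = tanh(a_g), f = sigmoid(a_f), o = sigmoid(a_o)
//   c_new = i * g + f * c_prev
//   h_new = o * tanh(c_new)
// where the pre-activations a_i, a_g, a_f, a_o are the four consecutive
// output_depth-wide quarters of the fully-connected output in activ_temp_data.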
565
566// Quantized LSTM cell implementation.
567// The quantization of the input and output arrays is as follows:
568// - The input activations are quantized as uint8 on the interval
569// [-1, 127/128].
570// The rationale for that is that this is the natural interval for output
571// activations (see next point) and these need to be concatenated together.
572// We could accommodate different ranges by re-scaling, but we empirically
573// found that setting the input activations range to be [-1, 127/128] in the
574// first place, removing the need for re-scaling, greatly improves accuracy.
575// - The output activations are quantized as uint8 on the interval
576// [-1, 127/128].
577// The rationale for that is that the definition of an LSTM cell makes them
578// intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128]
579// makes for simpler, more accurate fixed-point arithmetic.
580// - The output-at-previous-timestep state array is quantized in the same way
581// as the output activations.
582// - The internal LSTM memory (not the output-at-previous-timestep, the other
583// internal state array) is int16-quantized and may use any power-of-two,
584// symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call
585// StateIntegerBits below, see the below discussion of that template
586// parameter ("The StateIntegerBits template parameter").
587// - The output of the internal fully-connected node is int16-quantized
588// on the interval [-8, 8 * 32767/32768], the rationale for which is
589// explained just below ("Why [-8, 8] for fully-connected output?").
590//
591//
592// === The StateIntegerBits template parameter ===
593//
594// The StateIntegerBits template parameter controls the fixed-point format used
595// to represent the internal memory of the LSTM cell (not the
596// output-at-previous-timestep, the other internal state array). It's currently
597// a template parameter so that the model can control that. The most typical
598// value for StateIntegerBits is 4. Other plausible values are anywhere between
599// 3 and 5. We might eventually standardize on a single supported value, e.g. 4,
600// and drop that template parameter. The reason why it can't be a runtime
601// parameter is that this controls the fixed-point format used, i.e. we need to
602// generate actually different code based on it. In particular, we generate code
603// for a fixed-point tanh() implementation for that format, which internally
604// uses a fixed-point exp() implementation, which internally uses a
605// barrel-shifter with a number of steps that depends on StateIntegerBits.
606// Another consequence of that is that a higher value of StateIntegerBits
607// results in a more expensive implementation (more barrel shifter steps
608// needed).
609//
610//
611// === Why [-8, 8] for fully-connected output? ===
612//
613// This array is only fed to Logistic and Tanh functions, for which
614// the quantized implementation will want to use fixed-point arithmetic,
615// requiring a power-of-two representation interval. Thus, we should right
616// away quantize this array to a power-of-two interval; otherwise,
617// implementation will need to rescale that, losing any benefit that a tighter
618// representation interval might otherwise yield, while introducing some
619// numerical error and computational overhead.
620//
621// Now, Logistic and Tanh
622// are nearly constant (nearly equal to their horizontal asymptotes)
623// outside of a small bounded interval around 0:
624//
625// Logistic(4) = 1 - 1.8e-2 Tanh(4) = 1 - 6.7e-4
626// Logistic(8) = 1 - 3.4e-4 Tanh(8) = 1 - 2.3e-7
627// Logistic(16) = 1 - 1.1e-7 Tanh(16) = 1 - 2.5e-14
628//
629// From this, we see that clamping to [-4, 4] would be too inaccurate
630// (the error of 1.8e-2 on Logistic would be felt even in 8-bit precision)
631// while clamping to [-16, 16] would make no difference even in float32.
632// However, for a fixed-point implementation in 16-bit integers, using 5
633// integer bits to represent the [-16, 16] range would leave only 11
634// fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive
635// representable values. Notice that this is higher than the
636// worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic.
637// Using [-8, 8] thus seems like the better compromise overall, enjoying
638// an increment of 2.4e-4 between representable values and a worst-case
639// clamping error of 3.4e-4, both better than the increment of 4.9e-4 with
640// [-16, 16].
641//
642// Moreover, all other things being equal, it is nice to choose the narrower
643// representation range, as that makes the implementation of fixed-point
644// math functions a little cheaper (each integer bit requires an additional
645// barrel-shifter step in the implementation of exp(-x)). That is further
646// reason to prefer [-8, 8] over [-16, 16]. The choice of [-16, 16] would make
647// sense for 32-bit float or 32-bit fixed-point quantization, but we are
648// aiming for 16-bit fixed-point quantization of these internal nodes here.
649//
650template <int StateIntegerBits>
651inline void
652LstmCell(const LstmCellParams &params, const RuntimeShape &unextended_input_shape,
653 const uint8 *input_data_uint8, const RuntimeShape &unextended_prev_activ_shape,
654 const uint8 *prev_activ_data_uint8, const RuntimeShape &weights_shape,
655 const uint8 *weights_data_uint8, const RuntimeShape &unextended_bias_shape,
656 const int32 *bias_data_int32, const RuntimeShape &unextended_prev_state_shape,
657 const int16 *prev_state_data_int16, const RuntimeShape &unextended_output_state_shape,
658 int16 *output_state_data_int16, const RuntimeShape &unextended_output_activ_shape,
659 uint8 *output_activ_data_uint8, const RuntimeShape &unextended_concat_temp_shape,
660 uint8 *concat_temp_data_uint8, const RuntimeShape &unextended_activ_temp_shape,
661 int16 *activ_temp_data_int16, void *gemmlowp_context)
662{
663 (void)gemmlowp_context; // only used in optimized code.
664 int32 weights_zero_point = params.weights_zero_point;
665 int32 accum_multiplier = params.accum_multiplier;
666 int accum_shift = params.accum_shift;
667 TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
668 TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
669 TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
670 TFLITE_DCHECK_LE(unextended_prev_state_shape.DimensionsCount(), 4);
671 TFLITE_DCHECK_LE(unextended_output_state_shape.DimensionsCount(), 4);
672 TFLITE_DCHECK_LE(unextended_output_activ_shape.DimensionsCount(), 4);
673 TFLITE_DCHECK_LE(unextended_concat_temp_shape.DimensionsCount(), 4);
674 TFLITE_DCHECK_LE(unextended_activ_temp_shape.DimensionsCount(), 4);
675 const RuntimeShape input_shape = RuntimeShape::ExtendedShape(4, unextended_input_shape);
676 const RuntimeShape prev_activ_shape = RuntimeShape::ExtendedShape(4, unextended_prev_activ_shape);
677 const RuntimeShape bias_shape = RuntimeShape::ExtendedShape(4, unextended_bias_shape);
678 const RuntimeShape prev_state_shape = RuntimeShape::ExtendedShape(4, unextended_prev_state_shape);
679 const RuntimeShape output_state_shape =
680 RuntimeShape::ExtendedShape(4, unextended_output_state_shape);
681 const RuntimeShape output_activ_shape =
682 RuntimeShape::ExtendedShape(4, unextended_output_activ_shape);
683 const RuntimeShape concat_temp_shape =
684 RuntimeShape::ExtendedShape(4, unextended_concat_temp_shape);
685 const RuntimeShape activ_temp_shape = RuntimeShape::ExtendedShape(4, unextended_activ_temp_shape);
686 TFLITE_DCHECK_GE(weights_shape.DimensionsCount(), 2);
687
688 // Gather dimension information and perform consistency checks.
689 const int weights_dim_count = weights_shape.DimensionsCount();
690 const int outer_size = MatchingFlatSizeSkipDim(input_shape, 3, prev_activ_shape, prev_state_shape,
691 output_state_shape, output_activ_shape);
692 const int input_depth = input_shape.Dims(3);
693 const int prev_activ_depth = prev_activ_shape.Dims(3);
694 const int total_input_depth = prev_activ_depth + input_depth;
695 TFLITE_DCHECK_EQ(weights_shape.Dims(weights_dim_count - 1), total_input_depth);
696 const int intern_activ_depth = MatchingDim(weights_shape, weights_dim_count - 2, bias_shape, 3);
697 TFLITE_DCHECK_EQ(weights_shape.FlatSize(), intern_activ_depth * total_input_depth);
698 TFLITE_DCHECK_EQ(FlatSizeSkipDim(bias_shape, 3), 1);
699 TFLITE_DCHECK_EQ(intern_activ_depth % 4, 0);
700 const int output_depth = MatchingDim(prev_state_shape, 3, prev_activ_shape, 3, output_state_shape,
701 3, output_activ_shape, 3);
702 TFLITE_DCHECK_EQ(output_depth, intern_activ_depth / 4);
703 const int fc_batches = FlatSizeSkipDim(activ_temp_shape, 3);
704 const int fc_output_depth =
705 MatchingDim(weights_shape, weights_dim_count - 2, activ_temp_shape, 3);
706 const int fc_accum_depth = total_input_depth;
707 TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth);
708
709 // Depth-concatenate prev_activ and input data together.
710 uint8 const *concat_input_arrays_data[2] = {input_data_uint8, prev_activ_data_uint8};
711 const RuntimeShape *concat_input_arrays_shapes[2] = {&input_shape, &prev_activ_shape};
712 tflite::ConcatenationParams concat_params;
713 concat_params.axis = 3;
714 concat_params.inputs_count = 2;
715 Concatenation(concat_params, concat_input_arrays_shapes, concat_input_arrays_data,
716 concat_temp_shape, concat_temp_data_uint8);
717
718 // Implementation of the fully connected node inside the LSTM cell.
719 // The operands are 8-bit integers, the accumulators are internally 32bit
720 // integers, and the output is 16-bit fixed-point with 3 integer bits so
721 // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that
722 // is explained in the function comment above.
723 for (int b = 0; b < fc_batches; ++b)
724 {
725 for (int out_c = 0; out_c < fc_output_depth; ++out_c)
726 {
727 // Internal accumulation.
728 // Initialize accumulator with the bias-value.
729 int32 accum = bias_data_int32[out_c];
730 // Accumulation loop.
731 for (int d = 0; d < fc_accum_depth; ++d)
732 {
733 int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
734 int16 weights_val = weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
735 accum += input_val * weights_val;
736 }
737 // Down-scale the final int32 accumulator to the scale used by our
738 // (16-bit, using 3 integer bits) fixed-point format. The quantized
739 // multiplier and shift here have been pre-computed offline
740 // (e.g. by toco).
741 accum = MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
742 // Saturate, cast to int16, and store to the temporary activations array.
743 accum = std::max(-32768, std::min(32767, static_cast<int>(accum)));
744 activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
745 }
746 }
747
748 // Rest of the LSTM cell: tanh and logistic math functions, and some adds
749 // and muls, all done in 16-bit fixed-point.
750 for (int b = 0; b < outer_size; ++b)
751 {
752 for (int c = 0; c < output_depth; ++c)
753 {
754 // Define the fixed-point data types that we will use here. All use
755 // int16 as the underlying integer type i.e. all are 16-bit fixed-point.
756 // They only differ by the number of integral vs. fractional bits,
757 // determining the range of values that they can represent.
758 //
759 // F0 uses 0 integer bits, range [-1, 1].
760 // This is the return type of math functions such as tanh, logistic,
761 // whose range is in [-1, 1].
762 using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
763 // F3 uses 3 integer bits, range [-8, 8].
764 // This is the range of the previous fully-connected node's output,
765 // which is our input here.
766 using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
767 // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
768 // 2^StateIntegerBits]. It's used to represent the internal state, whose
769 // number of integer bits is currently dictated by the model. See comment
770 // on the StateIntegerBits template parameter above.
771 using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
772 // Implementation of input gate, using fixed-point logistic function.
773 F3 input_gate_input =
774 F3::FromRaw(activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
775 F0 input_gate_output = gemmlowp::logistic(input_gate_input);
776 // Implementation of input modulation gate, using fixed-point tanh
777 // function.
778 F3 input_modulation_gate_input =
779 F3::FromRaw(activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
780 F0 input_modulation_gate_output = gemmlowp::tanh(input_modulation_gate_input);
781 // Implementation of forget gate, using fixed-point logistic function.
782 F3 forget_gate_input =
783 F3::FromRaw(activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
784 F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
785 // Implementation of output gate, using fixed-point logistic function.
786 F3 output_gate_input =
787 F3::FromRaw(activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
788 F0 output_gate_output = gemmlowp::logistic(output_gate_input);
789 // Implementation of internal multiplication nodes, still in fixed-point.
790 F0 input_times_input_modulation = input_gate_output * input_modulation_gate_output;
791 FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]);
792 FS prev_state_times_forget_state = forget_gate_output * prev_state;
793 // Implementation of internal addition node, saturating.
794 FS new_state =
795 gemmlowp::SaturatingAdd(gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
796 prev_state_times_forget_state);
797 // Implementation of last internal Tanh node, still in fixed-point.
798 // Since a Tanh fixed-point implementation is specialized for a given
799 // number of integer bits, and each specialization can have a substantial
800 // code size, and we already used above a Tanh on an input with 3 integer
801 // bits, and per the table in the above function comment there is no
802 // significant accuracy to be lost by clamping to [-8, +8] for a
803 // 3-integer-bits representation, let us just do that. This helps people
804 // porting this to targets where code footprint must be minimized.
805 F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
806 F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
807 // Store the new internal state back to memory, as 16-bit integers.
808 // Note: here we store the original value with StateIntegerBits, not
809 // the rescaled 3-integer-bits value fed to tanh.
810 output_state_data_int16[b * output_depth + c] = new_state.raw();
811 // Down-scale the output activations to 8-bit integers, saturating,
812 // and store back to memory.
813 int16 rescaled_output_activ = gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
814 int16 clamped_output_activ =
815 std::max<int16>(-128, std::min<int16>(127, rescaled_output_activ));
816 output_activ_data_uint8[b * output_depth + c] = 128 + clamped_output_activ;
817 }
818 }
819}
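// Hedged illustration (not part of the upstream header): for
// gemmlowp::FixedPoint<int16_t, IntegerBits>, a raw value r represents
// r * 2^IntegerBits / 32768, which is how the F0, F3 and FS formats above relate to
// real values. A decoding sketch, useful only for reasoning about those formats:
inline float FixedPointToFloatSketch(int16_t raw, int integer_bits)
{
  // 2^(integer_bits - 15) is the real value of one least-significant bit.
  return static_cast<float>(raw) * std::ldexp(1.0f, integer_bits - 15);
}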
820
821template <typename Scalar>
822void Split(const SplitParams &params, const RuntimeShape &input_shape, const Scalar *input_data,
823 const RuntimeShape *const *output_shapes, Scalar *const *output_data)
824{
825 ruy::profiler::ScopeLabel label("Split");
826 const int split_dimensions = input_shape.DimensionsCount();
827 int axis = params.axis < 0 ? params.axis + split_dimensions : params.axis;
828 int outputs_count = params.num_split;
829 TFLITE_DCHECK_LT(axis, split_dimensions);
830
831 int64_t split_size = 0;
832 for (int i = 0; i < outputs_count; i++)
833 {
834 TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), split_dimensions);
835 for (int j = 0; j < split_dimensions; j++)
836 {
837 if (j != axis)
838 {
839 MatchingDim(*output_shapes[i], j, input_shape, j);
840 }
841 }
842 split_size += output_shapes[i]->Dims(axis);
843 }
844 TFLITE_DCHECK_EQ(split_size, input_shape.Dims(axis));
845 int64_t outer_size = 1;
846 for (int i = 0; i < axis; ++i)
847 {
848 outer_size *= input_shape.Dims(i);
849 }
850 // For all output arrays,
851 // FlatSize() = outer_size * Dims(axis) * base_inner_size;
852 int64_t base_inner_size = 1;
853 for (int i = axis + 1; i < split_dimensions; ++i)
854 {
855 base_inner_size *= input_shape.Dims(i);
856 }
857
858 const Scalar *input_ptr = input_data;
859 for (int k = 0; k < outer_size; k++)
860 {
861 for (int i = 0; i < outputs_count; ++i)
862 {
863 const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
864 memcpy(output_data[i] + k * copy_size, input_ptr, copy_size * sizeof(Scalar));
865 input_ptr += copy_size;
866 }
867 }
868}
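// Worked example (hedged, hypothetical shapes): splitting a [2, 6, 3] input along
// axis = 1 into outputs with axis sizes {2, 4} gives outer_size = 2 and
// base_inner_size = 3, so for each outer slice the loop copies 2 * 3 = 6 elements
// to the first output and 4 * 3 = 12 to the second, consuming the input contiguously.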
869
870inline int NodeOffset(int b, int h, int w, int height, int width)
871{
872 return (b * height + h) * width + w;
873}
874
875inline void LocalResponseNormalization(const tflite::LocalResponseNormalizationParams &op_params,
876 const RuntimeShape &input_shape, const float *input_data,
877 const RuntimeShape &output_shape, float *output_data)
878{
879 const int trailing_dim = input_shape.DimensionsCount() - 1;
880 const int outer_size = MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
881 const int depth = MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
882
883 for (int i = 0; i < outer_size; ++i)
884 {
885 for (int c = 0; c < depth; ++c)
886 {
887 const int begin_input_c = std::max(0, static_cast<int>(c - op_params.range));
888 const int end_input_c = std::min(depth, static_cast<int>(c + op_params.range));
889 float accum = 0.f;
890 for (int input_c = begin_input_c; input_c < end_input_c; ++input_c)
891 {
892 const float input_val = input_data[i * depth + input_c];
893 accum += input_val * input_val;
894 }
895 const float multiplier = std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
896 output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
897 }
898 }
899}
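// Hedged restatement of the loop above: for each position and channel c,
//   output[c] = input[c] * (bias + alpha * sum_{k in W(c)} input[k]^2)^(-beta)
// with the window W(c) = [max(0, c - range), min(depth, c + range)).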
900
901inline void Dequantize(const RuntimeShape &input_shape, const Eigen::half *input_data,
902 const RuntimeShape &output_shape, float *output_data)
903{
904 const int flat_size = MatchingFlatSize(input_shape, output_shape);
905 for (int i = 0; i < flat_size; i++)
906 {
907 output_data[i] = static_cast<float>(input_data[i]);
908 }
909}
910
911inline void FakeQuant(const tflite::FakeQuantParams &op_params, const RuntimeShape &input_shape,
912 const float *input_data, const RuntimeShape &output_shape, float *output_data)
913{
914 ruy::profiler::ScopeLabel label("FakeQuant");
915 float rmin = op_params.minmax.min;
916 float rmax = op_params.minmax.max;
917 int num_bits = op_params.num_bits;
918 // 0 should always be a representable value. Let's assume that the initial
919 // min,max range contains 0.
920 TFLITE_DCHECK_LE(rmin, 0.0f);
921 TFLITE_DCHECK_GE(rmax, 0.0f);
922 TFLITE_DCHECK_LT(rmin, rmax);
923
924 // Code matches tensorflow's FakeQuantWithMinMaxArgsFunctor.
925 int quant_min = 0;
926 int quant_max = (1 << num_bits) - 1;
927 float nudged_min, nudged_max, nudged_scale;
928 NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, &nudged_max, &nudged_scale);
929 const int flat_size = MatchingFlatSize(input_shape, output_shape);
930 FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data, output_data, flat_size);
931}
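// Hedged sketch (see the real NudgeQuantizationRange in quantization_util.h for the
// authoritative version): the nudging step adjusts [rmin, rmax] so that zero maps
// exactly onto an integer quantized value, roughly as follows:
inline void NudgeRangeSketch(float rmin, float rmax, int quant_min, int quant_max,
                             float *nudged_min, float *nudged_max, float *nudged_scale)
{
  const float scale = (rmax - rmin) / static_cast<float>(quant_max - quant_min);
  const float zero_point_from_min = quant_min - rmin / scale;
  // Round the zero point to the nearest integer and keep it inside [quant_min, quant_max].
  const float nudged_zero_point =
    std::max(static_cast<float>(quant_min),
             std::min(static_cast<float>(quant_max), std::round(zero_point_from_min)));
  *nudged_scale = scale;
  *nudged_min = (quant_min - nudged_zero_point) * scale;
  *nudged_max = (quant_max - nudged_zero_point) * scale;
}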
932
933// Common subroutine for both `GatherNd` and `GatherNdString`.
934struct GatherNdHelperResult
935{
936 int n_slices;
937 int slice_size;
938 int indices_nd;
939 std::vector<int> dims_to_count;
940};
941
942// Returns common values being used on both `GatherNd` and `GatherNdString`.
943inline GatherNdHelperResult GatherNdHelper(const RuntimeShape &params_shape,
944 const RuntimeShape &indices_shape)
945{
946 GatherNdHelperResult ret;
947 ret.n_slices = 1;
948 ret.slice_size = 1;
949 const int indices_dims = indices_shape.DimensionsCount();
950 ret.indices_nd = indices_shape.Dims(indices_dims - 1);
951 const int params_dims = params_shape.DimensionsCount();
952 for (int i = 0; i < indices_dims - 1; ++i)
953 {
954 ret.n_slices *= indices_shape.Dims(i);
955 }
956 for (int i = ret.indices_nd; i < params_dims; ++i)
957 {
958 ret.slice_size *= params_shape.Dims(i);
959 }
960
961 int remain_flat_size = params_shape.FlatSize();
962 ret.dims_to_count = std::vector<int>(ret.indices_nd, 0);
963 for (int i = 0; i < ret.indices_nd; ++i)
964 {
965 ret.dims_to_count[i] = remain_flat_size / params_shape.Dims(i);
966 remain_flat_size = ret.dims_to_count[i];
967 }
968
969 return ret;
970}
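// Worked example (hedged, hypothetical shapes): for params of shape [3, 4, 5] and
// indices of shape [6, 2] (so indices_nd = 2), the helper returns n_slices = 6,
// slice_size = 5 and dims_to_count = {20, 5} (the row-major strides of the first
// two params dimensions), so an index pair (i0, i1) addresses the contiguous slice
// starting at i0 * 20 + i1 * 5.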
971
972template <typename ParamsT, typename IndicesT = int32>
973inline void GatherNd(const RuntimeShape &params_shape, const ParamsT *params_data,
974 const RuntimeShape &indices_shape, const IndicesT *indices_data,
975 const RuntimeShape &output_shape, ParamsT *output_data)
976{
977 ruy::profiler::ScopeLabel label("GatherNd");
978
979 const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
980 for (int i = 0; i < res.n_slices; ++i)
981 {
982 int from_pos = 0;
983 for (int j = 0; j < res.indices_nd; ++j)
984 {
985 from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
986 }
987 std::memcpy(output_data + i * res.slice_size, params_data + from_pos,
988 sizeof(ParamsT) * res.slice_size);
989 }
990}
991
992#ifndef TF_LITE_STATIC_MEMORY
993template <typename IndicesT = int32>
994inline void GatherNdString(const RuntimeShape &params_shape, const TfLiteTensor *params_data,
995 const RuntimeShape &indices_shape, const IndicesT *indices_data,
996 const RuntimeShape &output_shape, TfLiteTensor *output_data)
997{
998 ruy::profiler::ScopeLabel label("GatherNdString");
999
1000 const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
1001 DynamicBuffer buffer;
1002 for (int i = 0; i < res.n_slices; ++i)
1003 {
1004 int from_pos = 0;
1005 for (int j = 0; j < res.indices_nd; ++j)
1006 {
1007 from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
1008 }
1009 for (int j = 0; j < res.slice_size; ++j)
1010 {
1011 buffer.AddString(GetString(params_data, from_pos + j));
1012 }
1013 }
1014 buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
1015}
1016#endif
1017
1018template <typename IndicesT, typename UpdatesT>
1019inline void ScatterNd(const RuntimeShape &indices_shape, const IndicesT *indices_data,
1020 const RuntimeShape &updates_shape, const UpdatesT *updates_data,
1021 const RuntimeShape &output_shape, UpdatesT *output_data)
1022{
1023 ruy::profiler::ScopeLabel label("ScatterNd");
1024
1025 int n_slices = 1;
1026 int slice_size = 1;
1027 const int outer_dims = indices_shape.DimensionsCount() - 1;
1028 const int indices_nd = indices_shape.Dims(outer_dims);
1029 const int updates_dims = updates_shape.DimensionsCount();
1030 for (int i = 0; i < outer_dims; ++i)
1031 {
1032 n_slices *= indices_shape.Dims(i);
1033 }
1034 for (int i = outer_dims; i < updates_dims; ++i)
1035 {
1036 slice_size *= updates_shape.Dims(i);
1037 }
1038
1039 int output_flat_size = output_shape.FlatSize();
1040 int remain_flat_size = output_flat_size;
1041 std::vector<int> dims_to_count(indices_nd, 0);
1042 for (int i = 0; i < indices_nd; ++i)
1043 {
1044 dims_to_count[i] = remain_flat_size / output_shape.Dims(i);
1045 remain_flat_size = dims_to_count[i];
1046 }
1047
1048 memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
1049 for (int i = 0; i < n_slices; ++i)
1050 {
1051 int to_pos = 0;
1052 for (int j = 0; j < indices_nd; ++j)
1053 {
1054 IndicesT idx = indices_data[i * indices_nd + j];
1055 TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
1056 to_pos += idx * dims_to_count[j];
1057 }
1058 for (int j = 0; j < slice_size; j++)
1059 {
1060 output_data[to_pos + j] += updates_data[i * slice_size + j];
1061 }
1062 }
1063}
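// Hedged note: because the innermost loop accumulates with +=, duplicate indices in
// ScatterNd add their update slices together rather than overwriting one another.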
1064
1065template <typename T>
1066inline void Slice(const tflite::SliceParams &op_params, const RuntimeShape &input_shape,
1067 const RuntimeShape &output_shape, SequentialTensorWriter<T> *writer)
1068{
1069 const RuntimeShape ext_shape = RuntimeShape::ExtendedShape(5, input_shape);
1070 TFLITE_DCHECK_LE(op_params.begin_count, 5);
1071 TFLITE_DCHECK_LE(op_params.size_count, 5);
1072 const int begin_count = op_params.begin_count;
1073 const int size_count = op_params.size_count;
1074 // We front-pad the begin and size vectors.
1075 std::array<int, 5> start;
1076 std::array<int, 5> stop;
1077 for (int i = 0; i < 5; ++i)
1078 {
1079 int padded_i = 5 - i;
1080 start[i] = begin_count < padded_i ? 0 : op_params.begin[begin_count - padded_i];
1081 stop[i] = (size_count < padded_i || op_params.size[size_count - padded_i] == -1)
1082 ? ext_shape.Dims(i)
1083 : start[i] + op_params.size[size_count - padded_i];
1084 }
1085
1086 for (int i0 = start[0]; i0 < stop[0]; ++i0)
1087 {
1088 for (int i1 = start[1]; i1 < stop[1]; ++i1)
1089 {
1090 for (int i2 = start[2]; i2 < stop[2]; ++i2)
1091 {
1092 for (int i3 = start[3]; i3 < stop[3]; ++i3)
1093 {
1094 for (int i4 = start[4]; i4 < stop[4]; ++i4)
1095 {
1096 writer->Write(Offset(ext_shape, i0, i1, i2, i3, i4));
1097 }
1098 }
1099 }
1100 }
1101 }
1102}
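// Worked example (hedged, hypothetical values): slicing a [4, 6] input with
// begin = {1, 2} and size = {2, -1} first extends the shape to [1, 1, 1, 4, 6]; the
// three padded leading dimensions get start = 0, stop = 1, while the real ones get
// start = {1, 2} and stop = {3, 6} (size -1 meaning "to the end of that dimension").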
1103
1104template <typename T>
1105inline void Slice(const tflite::SliceParams &op_params, const RuntimeShape &input_shape,
1106 const T *input_data, const RuntimeShape &output_shape, T *output_data)
1107{
1108 SequentialTensorWriter<T> writer(input_data, output_data);
1109 return Slice(op_params, input_shape, output_shape, &writer);
1110}
1111
1112template <typename T>
1113inline void Slice(const tflite::SliceParams &op_params, const RuntimeShape &input_shape,
1114 const TfLiteTensor *input, const RuntimeShape &output_shape, TfLiteTensor *output)
1115{
1116 SequentialTensorWriter<T> writer(input, output);
1117 return Slice(op_params, input_shape, output_shape, &writer);
1118}
1119
1120template <typename T>
1121void Minimum(const RuntimeShape &input1_shape, const T *input1_data, const T *input2_data,
1122 const RuntimeShape &output_shape, T *output_data)
1123{
1124 const int flat_size = MatchingFlatSize(input1_shape, output_shape);
1125
1126 auto min_value = input2_data[0];
1127 for (int i = 0; i < flat_size; i++)
1128 {
1129 output_data[i] = input1_data[i] > min_value ? min_value : input1_data[i];
1130 }
1131}
1132
1133// Convenience version that allows, for example, generated-code calls to be
1134// the same as other binary ops.
1135template <typename T>
1136inline void Minimum(const RuntimeShape &input1_shape, const T *input1_data, const RuntimeShape &,
1137 const T *input2_data, const RuntimeShape &output_shape, T *output_data)
1138{
1139 // Drop shape of second input: not needed.
1140 Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
1141}
1142
1143template <typename T>
1144void Maximum(const RuntimeShape &input1_shape, const T *input1_data, const T *input2_data,
1145 const RuntimeShape &output_shape, T *output_data)
1146{
1147 const int flat_size = MatchingFlatSize(input1_shape, output_shape);
1148
1149 auto max_value = input2_data[0];
1150 for (int i = 0; i < flat_size; i++)
1151 {
1152 output_data[i] = input1_data[i] < max_value ? max_value : input1_data[i];
1153 }
1154}
1155
1156// Convenience version that allows, for example, generated-code calls to be
1157// the same as other binary ops.
1158template <typename T>
1159inline void Maximum(const RuntimeShape &input1_shape, const T *input1_data, const RuntimeShape &,
1160 const T *input2_data, const RuntimeShape &output_shape, T *output_data)
1161{
1162 // Drop shape of second input: not needed.
1163 Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
1164}
1165
1166template <typename T1, typename T2, typename T3>
1167void ArgMax(const RuntimeShape &input1_shape, const T1 *input1_data, const T3 *input2_data,
1168 const RuntimeShape &output_shape, T2 *output_data)
1169{
1170 ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data, std::greater<T1>());
1171}
1172
1173// Convenience version that allows, for example, generated-code calls to be
1174// the same as other binary ops.
1175template <typename T1, typename T2, typename T3>
1176inline void ArgMax(const RuntimeShape &input1_shape, const T1 *input1_data,
1177 const RuntimeShape &input2_shape, const T3 *input2_data,
1178 const RuntimeShape &output_shape, T2 *output_data)
1179{
1180 // Drop shape of second input: not needed.
1181 ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
1182}
1183
1184template <typename D, typename T>
1185void Select(const RuntimeShape &input_condition_shape, const D *input_condition_data,
1186 const RuntimeShape &input_x_shape, const T *input_x_data,
1187 const RuntimeShape &input_y_shape, const T *input_y_data,
1188 const RuntimeShape &output_shape, T *output_data)
1189{
1190 int64_t flatsize;
1191 // Allow the select operator to run on a mix of scalar and one-element
1192 // tensors.
1193 if (input_condition_shape.FlatSize() == 1 && input_x_shape.FlatSize() == 1 &&
1194 input_y_shape.FlatSize() == 1 && output_shape.FlatSize() == 1)
1195 {
1196 flatsize = 1;
1197 }
1198 else
1199 {
1200 flatsize = MatchingFlatSize(input_condition_shape, input_x_shape, input_y_shape, output_shape);
1201 }
1202 for (int64_t i = 0; i < flatsize; ++i)
1203 {
1204 output_data[i] = input_condition_data[i] ? input_x_data[i] : input_y_data[i];
1205 }
1206}
1207
1208template <typename D, typename T>
1209void RankOneSelect(const RuntimeShape &input_condition_shape, const D *input_condition_data,
1210 const RuntimeShape &input_x_shape, const T *input_x_data,
1211 const RuntimeShape &input_y_shape, const T *input_y_data,
1212 const RuntimeShape &output_shape, T *output_data)
1213{
1214 const int64_t outer_size = input_condition_shape.FlatSize();
1215 int64_t inner_size;
1216 if (input_condition_shape.DimensionsCount() == 0)
1217 {
1218 inner_size = MatchingFlatSize(input_x_shape, input_y_shape, output_shape);
1219 }
1220 else
1221 {
1222 TFLITE_DCHECK_EQ(MatchingDim(input_x_shape, 0, input_y_shape, 0, output_shape, 0), outer_size);
1223 inner_size = MatchingFlatSizeSkipDim(input_x_shape, 0, input_y_shape, output_shape);
1224 }
1225
1226 int64_t offset = 0;
1227 for (int64_t i = 0; i < outer_size; i++)
1228 {
1229 const T *input_data = input_condition_data[i] ? input_x_data : input_y_data;
1230 memcpy(output_data + offset, input_data + offset, inner_size * sizeof(T));
1231 offset += inner_size;
1232 }
1233}
1234
1235template <typename D, typename T>
1236void BroadcastSelect4DSlow(const RuntimeShape &input_condition_shape, const D *input_condition_data,
1237 const RuntimeShape &input_x_shape, const T *input_x_data,
1238 const RuntimeShape &input_y_shape, const T *input_y_data,
1239 const RuntimeShape &output_shape, T *output_data)
1240{
1241 TFLITE_DCHECK_LE(input_condition_shape.DimensionsCount(), 4);
1242 TFLITE_DCHECK_LE(input_x_shape.DimensionsCount(), 4);
1243 TFLITE_DCHECK_LE(input_y_shape.DimensionsCount(), 4);
1244 TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
1245
1246 const RuntimeShape extended_output_shape = RuntimeShape::ExtendedShape(4, output_shape);
1247
1248 NdArrayDesc<4> desc_condition;
1249 NdArrayDesc<4> desc_x;
1250 NdArrayDesc<4> desc_y;
1251 NdArrayDescsForElementwiseBroadcast(input_condition_shape, input_x_shape, input_y_shape,
1252 &desc_condition, &desc_x, &desc_y);
1253
1254 // In TensorFlow, the dimensions are canonically named (batch_number, row,
1255 // col, channel), with extents (batches, height, width, depth), with the
1256 // trailing dimension changing most rapidly (channels has the smallest
1257 // stride, typically 1 element).
1258 //
1259 // In generated C code, we store arrays with the dimensions reversed. The
1260 // first dimension has smallest stride.
1261 //
1262 // We name our variables by their TensorFlow convention, but generate C code
1263 // nesting loops such that the innermost loop has the smallest stride for
1264 // the best cache behavior.
1265 for (int b = 0; b < extended_output_shape.Dims(0); ++b)
1266 {
1267 for (int y = 0; y < extended_output_shape.Dims(1); ++y)
1268 {
1269 for (int x = 0; x < extended_output_shape.Dims(2); ++x)
1270 {
1271 for (int c = 0; c < extended_output_shape.Dims(3); ++c)
1272 {
1273 const int condition_index = SubscriptToIndex(desc_condition, b, y, x, c);
1274 const int x_index = SubscriptToIndex(desc_x, b, y, x, c);
1275 const int y_index = SubscriptToIndex(desc_y, b, y, x, c);
1276 output_data[Offset(extended_output_shape, b, y, x, c)] =
1277 input_condition_data[condition_index] ? input_x_data[x_index] : input_y_data[y_index];
1278 }
1279 }
1280 }
1281 }
1282}
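// --- Illustrative sketch (not part of the original header; helper name is hypothetical) ---
// BroadcastSelect4DSlow maps every 4-D output coordinate back into the (possibly size-1)
// dimensions of condition/x/y, so a size-1 dimension is reused for every output index
// along that axis. Plain rendering of one common case: a condition of shape [rows, 1]
// broadcast against x/y of shape [rows, cols]:
template <typename D, typename T>
void BroadcastSelectRowsSketch(const D *cond, const T *x, const T *y, T *out, int rows, int cols)
{
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
    {
      const int i = r * cols + c;
      out[i] = cond[r] ? x[i] : y[i]; // the condition's size-1 column is reused for every c
    }
}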
1283
1284template <typename D, typename T>
1285void SelectTrueCoords(const RuntimeShape &input_condition_shape, const D *input_condition_data,
1286 T *output_data)
1287{
1288 const size_t size = input_condition_shape.FlatSize();
1289 if (size == 0)
1290 {
1291 // The flattened size is zero, so there is nothing to output.
1292 return;
1293 }
1294 const size_t cond_rank = input_condition_shape.DimensionsCount();
1295
1296 std::vector<int> dims_to_count(cond_rank, 0);
1297 int cur_flat_size = size;
1298 for (int i = 0; i < cond_rank; ++i)
1299 {
1300 dims_to_count[i] = cur_flat_size / input_condition_shape.Dims(i);
1301 cur_flat_size = dims_to_count[i];
1302 }
1303
1304 int output_index = 0;
1305 for (int i = 0; i < size; ++i)
1306 {
1307 if (input_condition_data[i])
1308 {
1309 // Insert the coordinate of the current item (row major) into output.
1310 int flat_index = i;
1311 for (int j = 0; j < cond_rank; ++j)
1312 {
1313 int coord_j = flat_index / dims_to_count[j];
1314 output_data[output_index * cond_rank + j] = coord_j;
1315 flat_index %= dims_to_count[j];
1316 }
1317 output_index++;
1318 }
1319 }
1320}
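// --- Illustrative sketch (not part of the original header; helper name is hypothetical) ---
// SelectTrueCoords emits, for every true element of the condition tensor, its row-major
// coordinate (the semantics of tf.where). Plain 2-D rendering that writes (row, col)
// pairs and returns how many were written:
template <typename D, typename T>
int SelectTrueCoords2DSketch(const D *cond, int rows, int cols, T *out_coords)
{
  int count = 0;
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      if (cond[r * cols + c])
      {
        out_coords[2 * count + 0] = static_cast<T>(r);
        out_coords[2 * count + 1] = static_cast<T>(c);
        ++count;
      }
  return count;
}
// E.g. cond = {{0, 1}, {1, 0}} -> coords = {0, 1, 1, 0} (two pairs).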
1321
1322 // For ease of implementation, the indices are always given as a vector of size-4 vectors.
1323template <typename T, typename TI>
1324inline void SparseToDense(const std::vector<std::vector<TI>> &indices, const T *values,
1325 T default_value, bool value_is_scalar,
1326 const RuntimeShape &unextended_output_shape, T *output_data)
1327{
1328 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1329 const RuntimeShape output_shape = RuntimeShape::ExtendedShape(4, unextended_output_shape);
1330 const int value_count = indices.size();
1331
1332 // First, fill output_data with the default value.
1333 const int num_elements = output_shape.FlatSize();
1334 for (int i = 0; i < num_elements; ++i)
1335 {
1336 output_data[i] = default_value;
1337 }
1338
1339 // Special-case the scalar value to avoid checking the boolean condition
1340 // inside the loop on every iteration.
1341 if (value_is_scalar)
1342 {
1343 for (int i = 0; i < value_count; ++i)
1344 {
1345 const std::vector<TI> &index = indices[i];
1346 TFLITE_DCHECK_EQ(index.size(), 4);
1347 const T value = *values; // just use the first value.
1348 output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] = value;
1349 }
1350 return;
1351 }
1352
1353 // Go through the values and indices to fill the sparse values.
1354 for (int i = 0; i < value_count; ++i)
1355 {
1356 const std::vector<TI> &index = indices[i];
1357 TFLITE_DCHECK_EQ(index.size(), 4);
1358 const T value = values[i];
1359 output_data[Offset(output_shape, index[0], index[1], index[2], index[3])] = value;
1360 }
1361}
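// --- Illustrative sketch (not part of the original header; helper name is hypothetical) ---
// SparseToDense first fills the dense output with default_value and then scatters the
// values at the given coordinates. Plain 1-D rendering where the indices are flat
// offsets instead of 4-element coordinate vectors:
template <typename T>
void SparseToDense1DSketch(const int *flat_indices, int index_count, const T *values,
                           T default_value, T *output, int output_size)
{
  for (int i = 0; i < output_size; ++i)
    output[i] = default_value;
  for (int i = 0; i < index_count; ++i)
    output[flat_indices[i]] = values[i];
}
// E.g. flat_indices = {1, 3}, values = {7, 9}, default_value = 0, output_size = 5
// -> output = {0, 7, 0, 9, 0}.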
1362
1363template <typename T>
1364inline void Pow(const RuntimeShape &input1_shape, const T *input1_data,
1365 const RuntimeShape &input2_shape, const T *input2_data,
1366 const RuntimeShape &output_shape, T *output_data)
1367{
1368 const int flat_size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
1369 for (int i = 0; i < flat_size; ++i)
1370 {
1371 output_data[i] = std::pow(input1_data[i], input2_data[i]);
1372 }
1373}
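// --- Illustrative example (not part of the original header) ---
// Pow above is an elementwise std::pow over the matching flat size, e.g.
//   input1 = {2.0f, 3.0f, 4.0f}, input2 = {3.0f, 2.0f, 0.5f}
//   output = {8.0f, 9.0f, 2.0f}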
1374
1375template <typename T>
1376inline void BroadcastPow4DSlow(const RuntimeShape &unextended_input1_shape, const T *input1_data,
1377 const RuntimeShape &unextended_input2_shape, const T *input2_data,
1378 const RuntimeShape &unextended_output_shape, T *output_data)
1379{
1380 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
1381 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
1382 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
1383 const RuntimeShape output_shape = RuntimeShape::ExtendedShape(4, unextended_output_shape);
1384
1385 NdArrayDesc<4> desc1;
1386 NdArrayDesc<4> desc2;
1387 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, unextended_input2_shape, &desc1,
1388 &desc2);
1389
1390 for (int b = 0; b < output_shape.Dims(0); ++b)
1391 {
1392 for (int y = 0; y < output_shape.Dims(1); ++y)
1393 {
1394 for (int x = 0; x < output_shape.Dims(2); ++x)
1395 {
1396 for (int c = 0; c < output_shape.Dims(3); ++c)
1397 {
1398 auto out_idx = Offset(output_shape, b, y, x, c);
1399 auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
1400 auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
1401 auto in1_val = input1_data[in1_idx];
1402 auto in2_val = input2_data[in2_idx];
1403 output_data[out_idx] = std::pow(in1_val, in2_val);
1404 }
1405 }
1406 }
1407 }
1408}
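// --- Illustrative sketch (not part of the original header; helper name is hypothetical) ---
// A common BroadcastPow4DSlow case is a single scalar exponent broadcast over the whole
// base tensor; plain-array rendering (std::pow comes from <cmath>, which this header
// already includes):
template <typename T>
void PowScalarExponentSketch(const T *base, T exponent, T *out, int n)
{
  for (int i = 0; i < n; ++i)
    out[i] = std::pow(base[i], exponent);
}
// E.g. base = {1.0f, 2.0f, 3.0f}, exponent = 2.0f -> out = {1.0f, 4.0f, 9.0f}.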
1409
1410template <typename Scalar>
1411void Reverse(int axis, const RuntimeShape &input_shape, const Scalar *input_data,
1412 const RuntimeShape &output_shape, Scalar *output_data)
1413{
1414 ruy::profiler::ScopeLabel label("Reverse");
1415
1416 int outer_size = 1;
1417 for (int i = 0; i < axis; ++i)
1418 {
1419 outer_size *= input_shape.Dims(i);
1420 }
1421
1422 int copy_size = 1;
1423 for (int i = axis + 1; i < input_shape.DimensionsCount(); ++i)
1424 {
1425 copy_size *= input_shape.Dims(i);
1426 }
1427
1428 const int dims_at_axis = input_shape.Dims(axis);
1429 for (int i = 0; i < outer_size; ++i)
1430 {
1431 for (int j = 0; j < dims_at_axis; ++j)
1432 {
1433 const int start_pos = (i * dims_at_axis + j) * copy_size;
1434 Scalar *output_ptr = output_data + start_pos;
1435 int loc = (i * dims_at_axis + dims_at_axis - j - 1) * copy_size;
1436 memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
1437 }
1438 }
1439}
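// --- Illustrative sketch (not part of the original header; helper name is hypothetical) ---
// Reverse flips the order of the slices along `axis`: everything before the axis forms
// the outer loop and everything after it is moved as one contiguous block per memcpy.
// Plain rendering of the simplest case, where the reversed axis is the only dimension
// (outer_size == copy_size == 1):
template <typename Scalar>
void Reverse1DSketch(const Scalar *in, Scalar *out, int n)
{
  for (int j = 0; j < n; ++j)
    out[j] = in[n - 1 - j];
}
// E.g. {1, 2, 3, 4} -> {4, 3, 2, 1}; the kernel above does the same per outer slice,
// copy_size elements at a time.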
1440
1441template <typename Scalar, typename TS>
1442void ReverseSequence(const TS *seq_lengths, const int seq_dim, const int batch_dim,
1443 const RuntimeShape &input_shape, const Scalar *input_data,
1444 const RuntimeShape &output_shape, Scalar *output_data)
1445{
1446 ruy::profiler::ScopeLabel label("ReverseSequence");
1447
1448 int outer_size = 1;
1449 int outer_dim = std::min(batch_dim, seq_dim);
1450 int medium_dim = std::max(batch_dim, seq_dim);
1451 for (int i = 0; i < outer_dim; ++i)
1452 {
1453 outer_size *= input_shape.Dims(i);
1454 }
1455
1456 int medium_size = 1;
1457 for (int i = outer_dim + 1; i < medium_dim; ++i)
1458 {
1459 medium_size *= input_shape.Dims(i);
1460 }
1461
1462 int copy_size = 1;
1463 for (int i = medium_dim + 1; i < input_shape.DimensionsCount(); ++i)
1464 {
1465 copy_size *= input_shape.Dims(i);
1466 }
1467
1468 const int dims_at_outer_dim = input_shape.Dims(outer_dim);
1469 const int dims_at_medium_dim = input_shape.Dims(medium_dim);
1470
1471 Scalar *output_ptr;
1472 if (batch_dim > seq_dim)
1473 {
1474 for (int i = 0; i < outer_size; ++i)
1475 {
1476 for (int j = 0; j < dims_at_outer_dim; ++j)
1477 {
1478 const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1479 for (int p = 0; p < medium_size; ++p)
1480 {
1481 for (int q = 0; q < dims_at_medium_dim; ++q)
1482 {
1483 const int in_pos = ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1484 const Scalar *in_ptr = input_data + in_pos;
1485 int sl = seq_lengths[q] - 1;
1486 if (j > sl)
1487 {
1488 output_ptr = output_data + in_pos;
1489 }
1490 else
1491 {
1492 const int out_pos_base = (i * dims_at_outer_dim + sl - j) * medium_size;
1493 const int out_pos = ((out_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1494 output_ptr = output_data + out_pos;
1495 }
1496 memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
1497 }
1498 }
1499 }
1500 }
1501 }
1502 else if (batch_dim < seq_dim)
1503 {
1504 for (int i = 0; i < outer_size; ++i)
1505 {
1506 for (int j = 0; j < dims_at_outer_dim; ++j)
1507 {
1508 const int in_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1509 int sl = seq_lengths[j] - 1;
1510 const int out_pos_base = (i * dims_at_outer_dim + j) * medium_size;
1511 for (int p = 0; p < medium_size; ++p)
1512 {
1513 for (int q = 0; q < dims_at_medium_dim; ++q)
1514 {
1515 const int in_pos = ((in_pos_base + p) * dims_at_medium_dim + q) * copy_size;
1516 const Scalar *in_ptr = input_data + in_pos;
1517 if (q > sl)
1518 {
1519 output_ptr = output_data + in_pos;
1520 }
1521 else
1522 {
1523 const int out_pos = ((out_pos_base + p) * dims_at_medium_dim + sl - q) * copy_size;
1524 output_ptr = output_data + out_pos;
1525 }
1526 memcpy(output_ptr, in_ptr, copy_size * sizeof(Scalar));
1527 }
1528 }
1529 }
1530 }
1531 }
1532}
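// --- Illustrative sketch (not part of the original header; helper name is hypothetical) ---
// ReverseSequence reverses only the first seq_lengths[b] entries along the sequence
// dimension of each batch row and leaves the tail untouched. Plain 2-D rendering for
// shape [batch, time] with batch_dim = 0 and seq_dim = 1:
template <typename Scalar, typename TS>
void ReverseSequence2DSketch(const TS *seq_lengths, const Scalar *in, Scalar *out,
                             int batch, int time)
{
  for (int b = 0; b < batch; ++b)
  {
    const int len = static_cast<int>(seq_lengths[b]);
    for (int t = 0; t < time; ++t)
    {
      const int src = b * time + t;
      const int dst = (t < len) ? b * time + (len - 1 - t) : src; // tail copied as-is
      out[dst] = in[src];
    }
  }
}
// E.g. seq_lengths = {3}, row {10, 20, 30, 40} -> {30, 20, 10, 40}.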
1533
1534template <typename T>
1535inline void SegmentSum(const RuntimeShape &input_shape, const T *input_data,
1536 const RuntimeShape &segment_ids_shape, const int32_t *segment_ids_data,
1537 const RuntimeShape &output_shape, T *output_data)
1538{
1539 const int segment_flat_size = MatchingFlatSizeSkipDim(input_shape, 0, output_shape);
1540
1541 memset(output_data, 0, sizeof(T) * output_shape.FlatSize());
1542
1543 for (int i = 0; i < input_shape.Dims(0); i++)
1544 {
1545 int output_index = segment_ids_data[i];
1546 for (int j = 0; j < segment_flat_size; ++j)
1547 {
1548 output_data[output_index * segment_flat_size + j] += input_data[i * segment_flat_size + j];
1549 }
1550 }
1551}
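// --- Illustrative sketch (not part of the original header; helper name is hypothetical) ---
// SegmentSum zero-initialises the output and accumulates every input row into the output
// row named by its segment id. Plain 2-D rendering with `rows` input rows of `row_size`
// elements and `num_segments` output rows:
template <typename T>
void SegmentSumSketch(const T *input, const int32_t *segment_ids, T *output,
                      int rows, int row_size, int num_segments)
{
  for (int i = 0; i < num_segments * row_size; ++i)
    output[i] = T(0);
  for (int i = 0; i < rows; ++i)
    for (int j = 0; j < row_size; ++j)
      output[segment_ids[i] * row_size + j] += input[i * row_size + j];
}
// E.g. segment_ids = {0, 0, 1} over rows {1, 2}, {3, 4}, {5, 6}
// -> output rows {4, 6} and {5, 6}.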
1552
1553} // namespace reference_ops
1554} // namespace tflite
1555
1556#endif // LUCI_INTERPRETER_PAL_REFERENCE_OPS_H