#ifndef LUCI_INTERPRETER_PAL_CONV2D_H
#define LUCI_INTERPRETER_PAL_CONV2D_H

#include <cstdint>
#include <limits>
#include <memory>
#include <thread>

#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/conv.h>
25static inline void Conv(
const tflite::ConvParams ¶ms,
const tflite::RuntimeShape &input_shape,
26 const float *input_data,
const tflite::RuntimeShape &filter_shape,
27 const float *filter_data,
const tflite::RuntimeShape &bias_shape,
28 const float *bias_data,
const tflite::RuntimeShape &
output_shape,
29 float *output_data,
const tflite::RuntimeShape &scratchpad_shape,
30 float *scratchpad_data)
32 (void)scratchpad_shape;
34 const int32_t batches = tflite::MatchingDim(input_shape, 0,
output_shape, 0);
35 const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
38 const int32_t filter_height = filter_shape.Dims(1);
39 const int32_t filter_width = filter_shape.Dims(2);
41 int64_t im2col_flat_size = 1;
42 im2col_flat_size *= batches;
43 im2col_flat_size *= output_height;
44 im2col_flat_size *= output_width;
45 im2col_flat_size *= input_depth;
46 im2col_flat_size *= filter_height;
47 im2col_flat_size *= filter_width;
53 bool opt_kernel_overflow = im2col_flat_size > std::numeric_limits<int32_t>::max();
55 if (scratchpad_data and not opt_kernel_overflow)
57 tflite::RuntimeShape im2col_shape{batches, output_height, output_width,
58 input_depth * filter_height * filter_width};
60 tflite::optimized_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
61 bias_shape, bias_data,
output_shape, output_data, im2col_shape,
65 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
67 tflite::RuntimeShape(),
nullptr);
70static inline void Conv(
const tflite::ConvParams ¶ms,
const tflite::RuntimeShape &input_shape,
71 const uint8 *input_data,
const tflite::RuntimeShape &filter_shape,
72 const uint8 *filter_data,
const tflite::RuntimeShape &bias_shape,
74 uint8 *output_data,
const tflite::RuntimeShape &scratchpad_shape,
75 uint8 *scratchpad_data)
79 auto gemmlowp_context = std::make_unique<gemmlowp::GemmContext>();
80 gemmlowp_context->set_max_num_threads(
static_cast<int>(std::thread::hardware_concurrency()));
82 tflite::reference_ops::Conv(params, input_shape, input_data, filter_shape, filter_data,
83 bias_shape, bias_data,
output_shape, output_data, scratchpad_shape,
84 scratchpad_data, gemmlowp_context.get());
87static inline void ConvPerChannel(
const tflite::ConvParams ¶ms,
const int32_t *mult,
88 const int32_t *shifts,
const tflite::RuntimeShape &input_shape,
89 const int8 *input_data,
const tflite::RuntimeShape &filter_shape,
90 const int8 *filter_data,
const tflite::RuntimeShape &bias_shape,
92 int8 *output_data,
const tflite::RuntimeShape &scratchpad_shape,
93 int8 *scratchpad_data)
95 (void)scratchpad_shape;
96 (void)scratchpad_data;
98 tflite::reference_integer_ops::ConvPerChannel(params, mult, shifts, input_shape, input_data,
99 filter_shape, filter_data, bias_shape, bias_data,
// NOTE(review): extraction-damaged fragment of the Conv2D scratchpad-setup routine.
// Every line carries a leaked source line number (e.g. "105"), "¶ms" is a mis-decoded
// "&params", and lines are missing. Those include the function's name, the declarations of
// `input_data_type`, `output_shape`, `scratchpad` and `scratchpad_shape`, the lines between
// the depth query and the resize call, and whatever follows the resize. Restore from the
// upstream file rather than reconstructing from this text.
105 const tflite::ConvParams ¶ms,
106 const tflite::RuntimeShape &input_shape,
107 const tflite::RuntimeShape &filter_shape,
110 const int32_t filter_height = filter_shape.Dims(1);
111 const int32_t filter_width = filter_shape.Dims(2);
// A scratchpad is requested only for non-S16 inputs, and only when the convolution is
// dilated, strided, or uses a filter other than 1x1.
115 const bool need_dilated_scratchpad =
116 params.dilation_height_factor != 1 || params.dilation_width_factor != 1;
117 const bool need_non_dilated_scratchpad = params.stride_height != 1 || params.stride_width != 1 ||
118 filter_height != 1 || filter_width != 1;
119 auto _need_scratchpad = input_data_type != luci_interpreter::DataType::S16 &&
120 (need_dilated_scratchpad || need_non_dilated_scratchpad);
122 if (_need_scratchpad)
124 const int32_t batches = tflite::MatchingDim(input_shape, 0,
output_shape, 0);
125 const int32_t input_depth = tflite::MatchingDim(input_shape, 3, filter_shape, 3);
// NOTE(review): this product is evaluated in int32_t. The float Conv in this header falls
// back to its reference kernel when the im2col size exceeds INT32_MAX, but no matching guard
// is visible here — confirm the scratchpad size cannot overflow.
133 input_depth * filter_height * filter_width,
// Grow the scratchpad tensor to the computed shape.
135 scratchpad->
resize(scratchpad_shape);
void set_allocatable(bool value)
void resize(const Shape &new_shape)
const luci_interpreter::RuntimeShape output_shape
size_t getDataTypeSize(DataType data_type)
DataType
"scalar" value type
void Conv(const ConvParams ¶ms, const Shape &input_shape, const uint8_t *input_data, const Shape &filter_shape, const uint8_t *filter_data, const Shape &bias_shape, const int32_t *bias_data, const Shape &output_shape, uint8_t *output_data, const Shape &im2col_shape, uint8_t *im2col_data)