19#include "kernels/Utils.h"
21#include "PALTransposeConv.h"
// Positions used to look up this operator's tensors: kFilterTensor, kInputTensor
// and kBiasTensor index into cur_op->inputs(), kOutputTensor indexes into
// cur_op->outputs(). The bias slot is only read when the operator has exactly
// four inputs.
33constexpr int kFilterTensor = 1;
34constexpr int kInputTensor = 2;
35constexpr int kBiasTensor = 3;
36constexpr int kOutputTensor = 0;
// Returns the padding value used along the height axis for a TransposeConv
// operator described by `input`, `filter` and `options`.
//
// The filter height is read from dimension 1 of `filter`.
// NOTE(review): the expression that produces `padding_height` is not visible in
// this excerpt; presumably it combines the stride from `options` with the
// input/filter/output heights (computePadding / computeOutputSize) - confirm
// against the full file.
39int32_t compute_padding_h(
const circle::Tensor *input,
const circle::Tensor *filter,
40 const circle::TransposeConvOptions *options)
44 const int32_t filter_height =
Tensor::dim(filter, 1);
48 const auto padding_height =
// The computed value is handed back to the caller unchanged.
50 return padding_height;
// Width-axis counterpart of compute_padding_h: its result is stored in
// params.padding_values.width by evalFloat.
//
// The filter width is read from dimension 2 of `filter`.
// NOTE(review): the right-hand side of `padding_width` and the return statement
// are not visible in this excerpt; presumably they mirror the height variant -
// confirm against the full file.
53int32_t compute_padding_w(
const circle::Tensor *input,
const circle::Tensor *filter,
54 const circle::TransposeConvOptions *options)
58 const int32_t filter_width =
Tensor::dim(filter, 2);
62 const auto padding_width =
// Float32 evaluation path for TransposeConv.
//
// Fills a parameter struct (`params`) from `options`, resolves the data buffers
// of the four tensors through `runtime_graph`, and passes them on as float data.
// NOTE(review): the declaration of `params`, the height padding assignment and
// the kernel invocation itself are not visible in this excerpt.
70void evalFloat(
const circle::Tensor *input,
const circle::Tensor *filter,
71 const circle::Tensor *bias,
const circle::Tensor *output,
72 const circle::TransposeConvOptions *options,
BaseRuntimeGraph *runtime_graph)
// Both bounds start value-initialized; the code that sets them is not shown
// (presumably calculateActivationRange - TODO confirm).
74 float activation_min{};
75 float activation_max{};
// Width padding comes from the helper; strides are taken from the options.
80 params.padding_values.width = compute_padding_w(input, filter, options);
81 params.stride_height =
options->stride_h();
82 params.stride_width =
options->stride_w();
// Dilation is fixed to 1 on both axes here.
83 params.dilation_height_factor = 1;
84 params.dilation_width_factor = 1;
85 params.float_activation_min = activation_min;
86 params.float_activation_max = activation_max;
// Input and output buffers are fetched with getDataByTensor; filter and bias
// are fetched with getConstDataByTensor.
88 auto *
input_data = runtime_graph->getDataByTensor(input);
89 auto *
output_data = runtime_graph->getDataByTensor(output);
91 auto *filter_data = runtime_graph->getConstDataByTensor(filter);
// NOTE(review): the caller's bias index can be -1 when the operator has fewer
// than four inputs; verify that `bias` may be null here and that
// getConstDataByTensor tolerates it.
92 auto *bias_data = runtime_graph->getConstDataByTensor(bias);
// The raw buffers are reinterpreted as float data for the kernel call.
96 kernels::getTensorData<float>(input_data),
100 kernels::getTensorData<float>(output_data));
// NOTE(review): the enclosing function header is not visible in this excerpt;
// presumably this is the body of configure_kernel_CircleTransposeConv - confirm.
//
// Look up the tensor indices of the input, filter and output of `cur_op`.
110 const auto input_index = cur_op->inputs()->operator[](kInputTensor);
111 const auto filter_index = cur_op->inputs()->operator[](kFilterTensor);
112 const auto output_index = cur_op->outputs()->operator[](kOutputTensor);
// All three tensors are mandatory: an index of -1 is rejected.
114 assert(input_index != -1);
115 assert(filter_index != -1);
116 assert(output_index != -1);
// The resolved tensors and the filter data must exist (how they are obtained
// is not shown in this excerpt).
122 assert(input !=
nullptr);
123 assert(filter !=
nullptr);
127 assert(filter_data !=
nullptr);
// NOTE(review): the enclosing function header is not visible in this excerpt;
// presumably this is the body of execute_kernel_CircleTransposeConv - confirm.
//
// Look up the tensor indices of the input, weights, optional bias and output.
138 const auto input_index = cur_op->inputs()->operator[](kInputTensor);
139 const auto weight_index = cur_op->inputs()->operator[](kFilterTensor);
// The bias slot is read only when the operator has exactly four inputs;
// otherwise the bias index is -1.
140 const auto bias_index =
141 cur_op->inputs()->size() == 4 ? cur_op->inputs()->operator[](kBiasTensor) : -1;
142 const auto output_index = cur_op->outputs()->operator[](kOutputTensor);
// Input, weights and output are mandatory; the bias index is not asserted.
144 assert(input_index != -1);
145 assert(weight_index != -1);
146 assert(output_index != -1);
// The resolved tensors must exist (how they are obtained from the indices is
// not shown in this excerpt).
153 assert(input !=
nullptr);
154 assert(weights !=
nullptr);
155 assert(output !=
nullptr);
157 const auto *options = cur_op->builtin_options_as_TransposeConvOptions();
// Dispatch on the element type of the input tensor.
159 const auto type = Tensor::element_type(input);
// FLOAT32 input is forwarded to evalFloat when the weights are FLOAT32 too.
163 case DataType::FLOAT32:
164 if (Tensor::element_type(weights) == DataType::FLOAT32)
166 evalFloat(input, weights, bias, output, options, runtime_graph);
// NOTE(review): which branch holds this assert is not visible here; presumably
// it is the fallback for unsupported element types - confirm.
171 assert(
false &&
"Unsupported type.");
uint8_t * getConstDataByTensor(const circle::Tensor *raw_tensor)
const circle::Tensor * getCircleTensorByIndex(int32_t index)
#define LUCI_INTERPRETER_CHECK(cond)
int32_t computePadding(int32_t stride, int32_t dilation_rate, int32_t in_size, int32_t filter_size, int32_t out_size)
void calculateActivationRange(Activation activation, T *activation_min, T *activation_max)
int32_t computeOutputSize(Padding padding, int32_t image_size, int32_t filter_size, int32_t stride, int32_t dilation_rate=1)
luci_interpreter::RuntimeShape getTensorRuntimeShape(const circle::Tensor *circle_tensor, BaseRuntimeGraph *runtime_graph)
void TransposeConv(const ConvParams &params, const luci_interpreter::RuntimeShape &input_shape, const float *input_data, const luci_interpreter::RuntimeShape &filter_shape, const float *filter_data, const luci_interpreter::RuntimeShape &bias_shape, const float *bias_data, const luci_interpreter::RuntimeShape &output_shape, float *output_data)
RuntimeGraph BaseRuntimeGraph
void execute_kernel_CircleTransposeConv(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
void configure_kernel_CircleTransposeConv(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
Padding luci_padding(const circle::Padding padding)
const loco::Dimension & dim(uint32_t axis) const
PaddingValues padding_values