59 const circle::Tensor *input;
60 const circle::Tensor *weight;
61 const circle::Tensor *output;
68 const circle::TransposeConvOptions *options;
76 input = runtime_kernel.
inputs[kInputTensorIdx];
77 weight = runtime_kernel.
inputs[kWeightTensorIdx];
78 output = runtime_kernel.
outputs[kOutputTensorIdx];
79 assert(input !=
nullptr);
80 assert(weight !=
nullptr);
82 assert(output !=
nullptr);
84 status = runtime_kernel.
getDataFromStorage(op_index, runtime_storage, runtime_context);
88 input_data = runtime_kernel.
inputs_data[kInputTensorIdx];
89 weight_data = runtime_kernel.
inputs_data[kWeightTensorIdx];
90 bias_data = runtime_kernel.
inputs_data[kBiasTensorIdx];
91 output_data = runtime_kernel.
outputs_data[kOutputTensorIdx];
92 assert(input_data !=
nullptr);
93 assert(weight_data !=
nullptr);
95 assert(output_data !=
nullptr);
97 options = runtime_kernel.
first_operator->builtin_options_as_TransposeConvOptions();
102 int32_t padding_h = 0;
103 int32_t padding_w = 0;
108 const int input_width = input_shape.
dims(2);
109 const int input_height = input_shape.
dims(1);
110 const int weight_width = weight_shape.
dims(2);
111 const int weight_height = weight_shape.
dims(1);
115 input_width, weight_height, weight_width, options->padding(),
116 &padding_h, &padding_w);
118 switch (input->type())
121 case circle::TensorType_FLOAT32:
126 ¶ms.activation_min, ¶ms.activation_max);
127 params.stride_w = options->stride_w();
128 params.stride_h = options->stride_h();
129 params.dilation_width_factor = 1;
130 params.dilation_height_factor = 1;
131 params.pad_h = padding_h;
132 params.pad_w = padding_w;
138 ¶ms, input_shape, core::utils::castInputData<float>(input_data), weight_shape,
139 core::utils::castInputData<float>(weight_data),
140 core::utils::castInputData<float>(bias_data),
OMRuntimeShape(output),
141 core::utils::castOutputData<float>(output_data));
142 assert(status ==
Ok);
149 assert(
false &&
"Unsupported type.");
OMStatus TransposeConv< float >(const core::FloatConv2D *params, const core::OMRuntimeShape &input_shape, const float *input_data, const core::OMRuntimeShape &filter_shape, const float *filter_data, const float *bias_data, const core::OMRuntimeShape &output_shape, float *output_data)
void computePaddingHeightWidth(int32_t stride_height, int32_t stride_width, int32_t dilation_rate_height, int32_t dilation_rate_width, int32_t in_height, int32_t in_width, int32_t filter_height, int32_t filter_width, circle::Padding padding, int32_t *padding_h, int32_t *padding_w)