// Positions of the Conv2D operator's tensors within the runtime kernel's
// `inputs` array, used by the configure routine as
// runtime_kernel.inputs[<idx>]:
//   inputTensorIdx  - the activation/input tensor
//   weightTensorIdx - the filter (weight) tensor
//   biasTensorIdx   - the bias tensor; this slot may hold nullptr, and the
//                     configure code skips the bias-size check in that case.
constexpr uint32_t inputTensorIdx = 0;
constexpr uint32_t weightTensorIdx = 1;
constexpr uint32_t biasTensorIdx = 2;
41OMStatus onert_micro::import::configure_kernel_CircleConv2D(
const OMConfigureArgs &config_args)
44 uint16_t op_index = config_args.kernel_index;
47 runtime_kernel.
readKernel(op_index, runtime_context);
49 const circle::Tensor *
input = runtime_kernel.
inputs[inputTensorIdx];
50 const circle::Tensor *weight = runtime_kernel.
inputs[weightTensorIdx];
51 const circle::Tensor *
bias = runtime_kernel.
inputs[biasTensorIdx];
55 assert(input !=
nullptr);
56 assert(weight !=
nullptr);
58 assert(output !=
nullptr);
62 if ((
input->type() == circle::TensorType_FLOAT32 &&
63 weight->type() != circle::TensorType_FLOAT32) or
64 (
input->type() == circle::TensorType_INT8 && weight->type() != circle::TensorType_INT8) or
65 (
input->type() == circle::TensorType_INT16 && weight->type() != circle::TensorType_INT16))
75 status = utils::checkCondition(input_shape.dimensionsCount() == 4);
83 status = utils::checkCondition(input_shape.dimensionsCount() == weight_shape.dimensionsCount());
87 status = utils::checkCondition(bias ==
nullptr or weight_shape.dims(0) == bias_shape.flatSize());
89 if (
input->type() == circle::TensorType_FLOAT32)
92 auto input_quant =
input->quantization();
93 auto filter_quant = weight->quantization();
94 auto output_quant =
output->quantization();
96 status = utils::checkCondition(input_quant !=
nullptr and filter_quant !=
nullptr and
97 output_quant !=
nullptr);
101 auto input_scales = input_quant->scale();
102 auto filter_scales = filter_quant->scale();
103 auto output_scales = output_quant->scale();
105 status = utils::checkCondition(input_scales !=
nullptr and filter_scales !=
nullptr and
106 output_scales !=
nullptr);
111 status = utils::checkCondition(filter_scales->size() > 1);
int32_t dimensionsCount() const
OMStatus readKernel(uint16_t op_index, core::OMRuntimeContext &runtime_context)
const circle::Tensor * outputs[maxOutputSize]
const circle::Tensor * inputs[maxInputSize]
const luci_interpreter::RuntimeShape output_shape
constexpr uint32_t outputTensorIdx