17#include "../KernelGenerator.h"
18#include "../Validator.h"
25void Validator::visit(
const ir::operation::DepthwiseConv2D &) {
_supported =
true; }
27void KernelGenerator::visit(
const ir::operation::DepthwiseConv2D &node)
29 using ir::operation::DepthwiseConv2D;
31 const auto ofm_index{node.getOutputs().at(0)};
32 const auto ifm_index{node.getInputs().at(DepthwiseConv2D::Input::INPUT)};
33 const auto ker_index{node.getInputs().at(DepthwiseConv2D::Input::KERNEL)};
34 const auto bias_index{node.getInputs().at(DepthwiseConv2D::Input::BIAS)};
36 const auto ifm_shape = _ctx.
at(ifm_index).shape().asFeature();
37 const auto ofm_shape = _ctx.
at(ofm_index).shape().asFeature();
39 const auto &ker_shape = _ctx.
at(ker_index).shape();
40 const auto ker_height = ker_shape.dim(1);
41 const auto ker_width = ker_shape.dim(2);
43 const auto stride = node.param().stride;
44 const auto dilation = node.param().dilation;
47 dilation.width_factor, dilation.height_factor);
48 const auto multiplier = node.param().multiplier;
49 const auto activation = node.param().activation;
51 auto ofm_tensor = _tensor_reg->getAclTensor(ofm_index);
52 auto ifm_tensor = _tensor_reg->getAclTensor(ifm_index);
53 auto ker_tensor = _tensor_reg->getAclTensor(ker_index);
54 auto bias_tensor = _tensor_reg->getAclTensor(bias_index);
60 auto fn = acl_common::generateLayer<arm_compute::CLDepthwiseConvolutionLayer>(
62 conv_info, multiplier, act_info, dilation_info);
std::unique_ptr< exec::IFunction > _return_fn
const Object & at(const Index &index) const
Get the object that is associated with the given index.
::arm_compute::ActivationLayerInfo asActivationLayerInfo(const ir::Activation act_code)
arm_compute::Size2D asDilation(uint32_t dilation_width, uint32_t dilation_height)
::arm_compute::PadStrideInfo asPadStrideInfo(const ir::ExplicitPadding &padding, const ir::Stride &stride)
std::unique_ptr< AclFunction > asAclFunction(std::unique_ptr<::arm_compute::IFunction > &&layer)
const ExplicitPadding calculatePadding(const Padding &padding, const FeatureShape &ifm_shape, const FeatureShape &ofm_shape, const Stride &stride, uint32_t kw, uint32_t kh, uint32_t dwf=1, uint32_t dhf=1)