17#include "../KernelGenerator.h"
18#include "../Validator.h"
// Validator hook for TransposeConv operations: unconditionally sets
// `_supported` to true. The operation argument is unnamed, so none of its
// operands or parameters are inspected before accepting it.
// NOTE(review): the leading "25" on the next line is a line-number artifact
// fused in from the listing this text was extracted from; it is not valid C++
// and should be dropped when restoring the real source file.
25void Validator::visit(
const ir::operation::TransposeConv &) {
_supported =
true; }
27void KernelGenerator::visit(
const ir::operation::TransposeConv &node)
29 const auto ofm_index{node.getOutputs().at(0)};
33 const auto ofm_shape = _ctx.
at(ofm_index).shape().asFeature();
34 const auto ifm_shape = _ctx.
at(ifm_index).shape().asFeature();
35 const auto ker_shape = _ctx.
at(ker_index).shape().asFeature();
37 const auto stride = node.param().stride;
42 ker_shape.W, ker_shape.H);
43 uint32_t invalid_horizontal = 0;
44 uint32_t invalid_vertical = 0;
48 ofm_shape.W - (1 + (ifm_shape.W - 1) * stride.horizontal) - (ker_shape.W - 1);
49 invalid_vertical = ofm_shape.H - (1 + (ifm_shape.H - 1) * stride.vertical) - (ker_shape.H - 1);
52 auto ofm_tensor = _tensor_reg->getAclTensor(ofm_index);
53 auto ifm_tensor = _tensor_reg->getAclTensor(ifm_index);
54 auto ker_tensor = _tensor_reg->getAclTensor(ker_index);
58 auto fn = acl_common::generateLayer<arm_compute::CLTransposeConvLayer>(
59 _tensor_builder->acl_tensor_manager()->internal_buffer_manager(), ifm_tensor->handle(),
60 ker_tensor->handle(),
nullptr, ofm_tensor->handle(), tconv_info, invalid_horizontal,
std::unique_ptr< exec::IFunction > _return_fn
const Object & at(const Index &index) const
Get the object that is associated with the given index.
::arm_compute::PadStrideInfo asPadStrideInfo(const ir::ExplicitPadding &padding, const ir::Stride &stride)
std::unique_ptr< AclFunction > asAclFunction(std::unique_ptr<::arm_compute::IFunction > &&layer)
const ExplicitPadding calculatePadding(const Padding &padding, const FeatureShape &ifm_shape, const FeatureShape &ofm_shape, const Stride &stride, uint32_t kw, uint32_t kh, uint32_t dwf=1, uint32_t dhf=1)