17#include "../KernelGenerator.h"
18#include "../Validator.h"
25void Validator::visit(
const ir::operation::L2Normalization &) {
_supported =
true; }
27void KernelGenerator::visit(
const ir::operation::L2Normalization &node)
29 const auto ofm_index{node.getOutputs().at(0)};
37 const auto &ifm_shape = _ctx.
at(ifm_index).shape();
39 const auto normalization_axis = _ctx.
at(ifm_index).shape().rank() - 1;
41 2 * ifm_shape.dim(normalization_axis) + 1;
46 auto ofm_tensor = _tensor_reg->getAclTensor(ofm_index);
47 auto ifm_tensor = _tensor_reg->getAclTensor(ifm_index);
49 const auto norm_info = ::arm_compute::NormalizationLayerInfo(::arm_compute::NormType::CROSS_MAP,
50 radius, alpha, beta, bias,
false);
52 auto fn = acl_common::generateLayer<arm_compute::CLNormalizationLayer>(
53 ifm_tensor->handle(), ofm_tensor->handle(), norm_info);
std::unique_ptr< exec::IFunction > _return_fn
const Object & at(const Index &index) const
Get the object that is associated with the given index.
std::unique_ptr< AclFunction > asAclFunction(std::unique_ptr<::arm_compute::IFunction > &&layer)