17#include "../KernelGenerator.h"
18#include "../Validator.h"
25void Validator::visit(
const ir::operation::Reduce &) {
_supported =
true; }
// Visits a Reduce operation node and assigns an ACL NEON reduction layer to `fn`.
// NOTE(review): this listing is fragmentary. The function braces, the conditions
// that select between the three generateLayer calls, and the definitions of
// input_index, axes_index, reduce_axes and output_tensor are not visible here,
// so the comments below describe only the statements that are shown.
27void KernelGenerator::visit(
const ir::operation::Reduce &node)
// Operand index of the node's output at position 0.
29 const auto output_index{node.getOutputs().at(0)};
// ACL tensor looked up in the tensor registry for input_index.
34 auto input_tensor = _tensor_reg->getAclTensor(input_index);
// Object stored in _ctx for axes_index (presumably the operand holding the
// reduction axes -- confirm against the missing lines).
37 const auto &axes = _ctx.
at(axes_index);
// Rank of the shape of the object stored in _ctx for input_index.
38 const auto input_rank = _ctx.
at(input_index).shape().rank();
// Parameters read from the node: the reduction kind and the keep_dims flag
// (presumably whether reduced dimensions are retained -- TODO confirm).
40 const auto reduce_type = node.param().reduce_type;
41 const auto keep_dims = node.param().keep_dims;
// Receives the layer produced by one of the generateLayer calls below.
43 std::unique_ptr<::arm_compute::IFunction> fn;
// Candidate 1: arm_compute::NEReduceMean (remaining arguments not visible).
46 fn = acl_common::generateLayer<arm_compute::NEReduceMean>(input_tensor->handle(), reduce_axes,
// Candidate 2: arm_compute::NEReduceSum (remaining arguments not visible).
51 fn = acl_common::generateLayer<arm_compute::NEReduceSum>(input_tensor->handle(), reduce_axes,
// Candidate 3: arm_compute::NEReduceOperation, given the input handle, the axes,
// keep_dims and the output handle; any trailing arguments are cut off in this view.
56 fn = acl_common::generateLayer<arm_compute::NEReduceOperation>(
57 input_tensor->handle(), reduce_axes, keep_dims,
output_tensor->handle(),
std::unique_ptr< exec::IFunction > _return_fn
const Object & at(const Index &index) const
Get the object that is associated with the given index.
arm_compute::ReductionOperation convertReduceType(ir::operation::Reduce::ReduceType reduce_type_ir)
arm_compute::Coordinates asCoordinates(const ir::Operand &operand, int32_t rank)
std::unique_ptr< AclFunction > asAclFunction(std::unique_ptr<::arm_compute::IFunction > &&layer)