17#include "../KernelGenerator.h"
18#include "../Validator.h"
25void Validator::visit(
const ir::operation::SplitV &) {
_supported =
true; }
// Lowers a SplitV IR node to an ACL CL split layer (arm_compute::CLSplit).
//
// Preconditions visible below:
//  - node.param().num_splits must equal the number of node outputs (assert).
//  - the split_dim and size_split operands must both be constants; otherwise
//    std::runtime_error is thrown.
//
// NOTE(review): this listing is an extract. The declarations of ifm_index,
// split_dim_index and size_split_index, and the statements between the axis
// read and the generateLayer call, are not shown — confirm against the full file.
27void KernelGenerator::visit(
const ir::operation::SplitV &node)
// The declared split count must match the actual number of outputs.
33 assert(node.param().num_splits ==
static_cast<int>(node.getOutputs().size()));
// Only constant split_dim / size_split operands are handled. The condition
// rejects the node if EITHER operand is non-constant, although the message
// reads "split_dim and size_split".
34 if (!_ctx.
at(split_dim_index).isConstant() || !_ctx.
at(size_split_index).isConstant())
36 throw std::runtime_error(
37 "Non-constant split_dim and size_split is not supported in acl_cl backend");
// Rank of the input operand. It is not used in the visible statements;
// presumably the omitted lines use it to normalize/convert `axis`
// (e.g. via ToARMComputeAxis) — TODO confirm.
42 const size_t ifm_rank = _ctx.
at(ifm_index).shape().rank();
// Copy the node's output operand indices, preserving their order.
43 std::vector<ir::OperandIndex> output_indexes;
44 for (
const auto &output : node.getOutputs())
45 output_indexes.emplace_back(
output);
// Resolve the input ACL tensor and the CL tensor handle of every output,
// in the same order as output_indexes.
47 auto ifm_tensor = _tensor_reg->getAclTensor(ifm_index);
48 std::vector<arm_compute::ICLTensor *> output_tensors;
49 for (
const auto &ofm_ind : output_indexes)
50 output_tensors.emplace_back(_tensor_reg->getAclTensor(ofm_ind)->handle());
// Split axis, read as an int32 scalar from the (constant) split_dim operand.
52 auto axis = _ctx.
at(split_dim_index).asScalar<int32_t>();
// NOTE(review): only the axis is passed to CLSplit; the size_split values are
// never read in the visible code, so per-output split sizes presumably come
// from the output tensors' shapes — confirm against the ACL CLSplit docs.
// The returned layer is not visibly stored here (e.g. into _return_fn); the
// surrounding code is missing from this extract.
58 acl_common::generateLayer<arm_compute::CLSplit>(ifm_tensor->handle(), output_tensors, axis);
uint32_t value(void) const
std::unique_ptr< exec::IFunction > _return_fn
const Object & at(const Index &index) const
Get the object that is associated with the given index.
ARMComputeAxis ToARMComputeAxis(uint32_t rank, uint32_t axis)
std::unique_ptr< AclFunction > asAclFunction(std::unique_ptr<::arm_compute::IFunction > &&layer)