17#include "../KernelGenerator.h"
18#include "../Validator.h"
25void Validator::visit(
const ir::operation::EmbeddingLookup &) {
_supported =
true; }
27void KernelGenerator::visit(
const ir::operation::EmbeddingLookup &node)
29 const auto output_index{node.getOutputs().at(0)};
34 auto lookups_tensor = _tensor_reg->getAclTensor(lookups_index);
35 auto values_tensor = _tensor_reg->getAclTensor(values_index);
37 size_t n = _ctx.
at(values_index).shape().rank();
38 assert(n == values_tensor->num_dimensions());
39 size_t k = _ctx.
at(lookups_index).shape().rank();
40 assert(k == lookups_tensor->num_dimensions());
45 if (n != values_tensor->info()->num_dimensions())
50 if (k != lookups_tensor->info()->num_dimensions())
56 auto fn = acl_common::generateLayer<arm_compute::CLGather>(
57 values_tensor->handle(), lookups_tensor->handle(),
output_tensor->handle(), axis);
60 if (values_tensor->dimension(0) == 1)
64 if (lookups_tensor->dimension(0) == 1)
uint32_t value(void) const
std::unique_ptr< exec::IFunction > _return_fn
const Object & at(const Index &index) const
Get the object that is associated with the given index.
ARMComputeAxis ToARMComputeAxis(uint32_t rank, uint32_t axis)
void enableDimCorrection(IACLTensor *tensor)
std::unique_ptr< AclFunction > asAclFunction(std::unique_ptr<::arm_compute::IFunction > &&layer)
void disableDimCorrection(IACLTensor *tensor)