17#include "../KernelGenerator.h"
18#include "../Validator.h"
// Validator hook for RNN operations: unconditionally sets `_supported` to true.
// The node parameter is unnamed and is not inspected.
// NOTE(review): the leading "25" looks like a listing line number fused into the text
// during extraction, not part of the program — confirm against the real file.
25void Validator::visit(
const ir::operation::RNN &) {
_supported =
true; }
// Kernel generation for an RNN operation node. From the visible lines, this function:
//   1. reads the activation from the node's parameters,
//   2. looks up the ACL tensors for the operand indices in `_tensor_reg`,
//   3. generates an arm_compute::NECopy layer over the hidden-state in/out handles,
//   4. generates an arm_compute::NERNNLayer over the input, weights, recurrent weights,
//      bias, hidden-state-out and output handles plus `act_info`.
// NOTE(review): this listing is incomplete. The function braces, the initialisers of the
// index variables, and the declarations of `input_index`, `weights_index`, `bias_index`,
// `hidden_state_in_index`, `output_tensor` and `act_info` are not visible here, and the
// leading numbers ("27", "30", ...) look like fused listing line numbers. Confirm against
// the complete file before changing any code.
27void KernelGenerator::visit(
const ir::operation::RNN &node)
30 const auto hidden_state_out_index{
35 const auto recurrent_weights_index{
// Activation requested by the node's parameters.
40 const auto activation = node.param().activation;
// Resolve each operand index to its registered ACL tensor.
43 auto hidden_state_out_tensor = _tensor_reg->getAclTensor(hidden_state_out_index);
45 auto input_tensor = _tensor_reg->getAclTensor(input_index);
46 auto weights_tensor = _tensor_reg->getAclTensor(weights_index);
47 auto recurrent_weights_tensor = _tensor_reg->getAclTensor(recurrent_weights_index);
48 auto bias_tensor = _tensor_reg->getAclTensor(bias_index);
49 auto hidden_state_in_tensor = _tensor_reg->getAclTensor(hidden_state_in_index);
// Copy layer built from (hidden_state_in, hidden_state_out) handles; presumably it seeds the
// output state from the input state before the RNN step — TODO confirm.
// NOTE(review): `copy_layer` is not referenced again in the visible lines; verify that it is
// actually stored/run in the full source.
52 auto copy_layer = acl_common::generateLayer<arm_compute::NECopy>(
53 hidden_state_in_tensor->handle(), hidden_state_out_tensor->handle());
// RNN layer: the first argument is the internal buffer manager obtained from the tensor
// builder's ACL tensor manager. Note that the hidden-state *output* tensor is what is passed
// to the layer; in the visible lines `hidden_state_in_tensor` is used only by the copy above.
56 auto fn = acl_common::generateLayer<arm_compute::NERNNLayer>(
57 _tensor_builder->acl_tensor_manager()->internal_buffer_manager(), input_tensor->handle(),
58 weights_tensor->handle(), recurrent_weights_tensor->handle(),
bias_tensor->handle(),
59 hidden_state_out_tensor->handle(),
output_tensor->handle(), act_info);
// NOTE(review): the three lines below have no terminator or body and look like
// cross-reference stubs emitted by the listing tool (a member declaration and two helper
// signatures) rather than code belonging to this file — confirm before relying on them.
std::unique_ptr< exec::IFunction > _return_fn
::arm_compute::ActivationLayerInfo asActivationLayerInfo(const ir::Activation act_code)
std::unique_ptr< AclFunction > asAclFunction(std::unique_ptr<::arm_compute::IFunction > &&layer)