ONE - On-device Neural Engine
onert::backend::cpu::ops::AttentionLayer Class Reference

#include <AttentionLayer.h>


Public Member Functions

 AttentionLayer ()
 
 ~AttentionLayer ()
 
void configure (const IPortableTensor *input, const IPortableTensor *wq, const IPortableTensor *wk, const IPortableTensor *wv, const IPortableTensor *wo, const IPortableTensor *cos, const IPortableTensor *sin, const IPortableTensor *mask, IPortableTensor *k_cache, IPortableTensor *v_cache, const IPortableTensor *pos, IPortableTensor *output)
 
void run () override
 
- Public Member Functions inherited from onert::exec::IFunction
virtual ~IFunction ()=default
 
virtual void prepare ()
 

Detailed Description

AttentionLayer is a CPU-backend kernel that computes attention over a single-batch input using query/key/value/output projection weights (wq, wk, wv, wo), cos/sin position tensors with a cache position index (pos), an attention mask, and K/V cache tensors for incremental decoding.

Definition at line 28 of file AttentionLayer.h.

Constructor & Destructor Documentation

◆ AttentionLayer()

onert::backend::cpu::ops::AttentionLayer::AttentionLayer ( )

Definition at line 84 of file AttentionLayer.cc.

  : _input(nullptr), _wq(nullptr), _wk(nullptr), _wv(nullptr), _wo(nullptr), _cos(nullptr),
    _sin(nullptr), _mask(nullptr), _k_cache(nullptr), _v_cache(nullptr), _cache_pos(nullptr),
    _output(nullptr)
{
  // DO NOTHING
}

◆ ~AttentionLayer()

onert::backend::cpu::ops::AttentionLayer::~AttentionLayer ( ) = default

Member Function Documentation

◆ configure()

void onert::backend::cpu::ops::AttentionLayer::configure ( const IPortableTensor * input,
const IPortableTensor * wq,
const IPortableTensor * wk,
const IPortableTensor * wv,
const IPortableTensor * wo,
const IPortableTensor * cos,
const IPortableTensor * sin,
const IPortableTensor * mask,
IPortableTensor * k_cache,
IPortableTensor * v_cache,
const IPortableTensor * pos,
IPortableTensor * output
)

Definition at line 94 of file AttentionLayer.cc.

{
  _input = input;
  _wq = wq;
  _wk = wk;
  _wv = wv;
  _wo = wo;
  _cos = cos;
  _sin = sin;
  _mask = mask;
  _k_cache = k_cache;
  _v_cache = v_cache;
  _cache_pos = pos;
  _output = output;

  // 0. Read and check inputs and params
  const auto n_batch = getShape(_input).Dims(0);
  assert(n_batch == 1); // Multi-batch is not supported.
  const auto d_model = getShape(_input).Dims(2);

  if (_cos == nullptr || _sin == nullptr || _cache_pos == nullptr)
    throw std::runtime_error{"Attention: input tensors cannot be nullptr"};

  const auto k_cache_shape = getShape(_k_cache);
  if (k_cache_shape.DimensionsCount() != 4)
    throw std::runtime_error{"K cache tensor must be 4D"};

  // 0.1 Param - Read n_head from K cache 3rd dimension
  const int32_t n_head = k_cache_shape.Dims(2);
  if (d_model % n_head != 0)
    throw std::runtime_error{"d_model must be divisible by n_head"};

  const int32_t d_head = d_model / n_head;
  const auto k_cache_dims = k_cache_shape.DimsData();
  const int32_t k_cache_n_batch = k_cache_dims[0];
  const int32_t k_cache_n_head = k_cache_dims[2];
  const int32_t k_cache_d_head = k_cache_dims[3];

  if (n_batch != k_cache_n_batch || n_head != k_cache_n_head || d_head != k_cache_d_head)
    throw std::runtime_error{"Attention: shape mismatch between inputs"};
}

References nnfw::cker::Shape::Dims(), and onert::backend::cpu::ops::getShape().
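
The checks above define the shape contract: the input is [batch, seq, d_model] with batch fixed at 1, the K cache is 4D with n_head in dimension 2 and d_head in dimension 3, and d_head is derived as d_model / n_head. The following standalone sketch reproduces that derivation with plain C++ containers so the contract can be tested in isolation; the names deriveAttentionDims and AttentionDims are hypothetical and not part of the onert API.

#include <array>
#include <cstdint>
#include <stdexcept>

// Hypothetical standalone illustration of the invariants checked in
// AttentionLayer::configure(); plain arrays stand in for nnfw::cker::Shape.
struct AttentionDims
{
  int32_t n_head;
  int32_t d_head;
};

AttentionDims deriveAttentionDims(const std::array<int32_t, 3> &input_shape,   // [batch, seq, d_model]
                                  const std::array<int32_t, 4> &k_cache_shape) // [batch, ?, n_head, d_head]
{
  const int32_t n_batch = input_shape[0];
  const int32_t d_model = input_shape[2];
  if (n_batch != 1)
    throw std::runtime_error{"Multi-batch is not supported"};

  // n_head comes from the K cache's 3rd dimension, as in configure()
  const int32_t n_head = k_cache_shape[2];
  if (d_model % n_head != 0)
    throw std::runtime_error{"d_model must be divisible by n_head"};
  const int32_t d_head = d_model / n_head;

  // Batch, head count, and head depth must agree between input and cache
  if (n_batch != k_cache_shape[0] || d_head != k_cache_shape[3])
    throw std::runtime_error{"shape mismatch between inputs"};
  return {n_head, d_head};
}

For example, d_model = 512 against a cache declaring 8 heads of depth 64 passes (d_head = 512 / 8 = 64), while a cache declaring 7 heads throws, since 512 is not divisible by 7. Note that dimension 1 of the K cache is not validated by configure().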

◆ run()

void onert::backend::cpu::ops::AttentionLayer::run ( ) override

Implements onert::exec::IFunction.

Definition at line 388 of file AttentionLayer.cc.

{
  if (_input->data_type() == OperandType::FLOAT32)
    attentionFloat32();
  else
    throw std::runtime_error{"AttentionLayer: unsupported input data type"};
}

References onert::backend::IPortableTensor::data_type().
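
run() only dispatches: float32 inputs go to the private attentionFloat32() helper, whose body is not reproduced in this reference. For orientation, here is a minimal sketch of the standard single-head, single-query cached-attention step that a layer of this kind computes. It is an illustration of the general technique, not the onert implementation; the real kernel also applies the wq/wk/wv/wo projections, the cos/sin position tensors, and the mask.

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative single-head, single-query attention step over a KV cache.
// NOT the body of attentionFloat32(); a sketch of the underlying technique.
std::vector<float> attendOneStep(const std::vector<float> &q,                    // [d_head]
                                 const std::vector<std::vector<float>> &k_cache, // [pos+1][d_head]
                                 const std::vector<std::vector<float>> &v_cache, // [pos+1][d_head]
                                 int pos)                                        // current position
{
  const int d_head = static_cast<int>(q.size());
  const float scale = 1.0f / std::sqrt(static_cast<float>(d_head));

  // Scaled dot-product scores against every cached key up to `pos`.
  std::vector<float> scores(pos + 1);
  float max_score = -INFINITY;
  for (int t = 0; t <= pos; ++t)
  {
    float s = 0.0f;
    for (int i = 0; i < d_head; ++i)
      s += q[i] * k_cache[t][i];
    scores[t] = s * scale;
    max_score = std::max(max_score, scores[t]);
  }

  // Softmax over the scores (subtract the max for numerical stability).
  float sum = 0.0f;
  for (int t = 0; t <= pos; ++t)
  {
    scores[t] = std::exp(scores[t] - max_score);
    sum += scores[t];
  }

  // Output is the softmax-weighted sum of cached values.
  std::vector<float> out(d_head, 0.0f);
  for (int t = 0; t <= pos; ++t)
  {
    const float w = scores[t] / sum;
    for (int i = 0; i < d_head; ++i)
      out[i] += w * v_cache[t][i];
  }
  return out;
}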


The documentation for this class was generated from the following files:

AttentionLayer.h
AttentionLayer.cc