ONE - On-device Neural Engine
Loading...
Searching...
No Matches
ggma::KVCache Struct Reference

#include <KVCache.h>

Public Member Functions

size_t element_size () const
 
NNFW_TYPE to_nnfw_type () const
 
bool is_valid () const
 
int64_t pos () const
 
void set_pos (int pos)
 
void reset_pos ()
 
void advance_pos ()
 
void init (const ggma::GGMAConfig &cfg, int cache_size)
 
void transpose (bool is_k_cache, const char *perm, size_t seq_len, size_t num_heads, size_t head_dim)
 Transpose cache with "0213" permutation [0,2,1,3].
 

Data Fields

KVCacheDataType data_type
 
std::vector< std::vector< uint8_t > > k
 
std::vector< std::vector< uint8_t > > v
 
int64_t _pos = 0
 

Detailed Description

Definition at line 43 of file KVCache.h.

Member Function Documentation

◆ advance_pos()

void ggma::KVCache::advance_pos ( )
inline

Definition at line 95 of file KVCache.h.

95{ _pos++; }
int64_t _pos
Definition KVCache.h:48

References _pos.

◆ element_size()

size_t ggma::KVCache::element_size ( ) const
inline

Definition at line 51 of file KVCache.h.

52 {
53 switch (data_type)
54 {
55 case KVCacheDataType::FLOAT32:
56 return sizeof(float);
57 case KVCacheDataType::UINT8:
58 return 1;
59 default:
60 return sizeof(float);
61 }
62 }
KVCacheDataType data_type
Definition KVCache.h:45

References data_type, ggma::FLOAT32, and ggma::UINT8.

Referenced by init(), and transpose().

◆ init()

void ggma::KVCache::init ( const ggma::GGMAConfig &  cfg,
int  cache_size 
)

Definition at line 100 of file KVCache.cc.

101{
102 if (cfg.model.n_layers <= 0)
103 throw std::runtime_error("n_layers not properly initialized");
104
105 // Set KV cache data type from config
106 data_type = cfg.kv_cache_type;
107
108 // Allocate space for K and V caches for each layer
109 // Total: n_layers * 2 vectors (K and V for each layer)
110 k.resize(cfg.model.n_layers);
111 v.resize(cfg.model.n_layers);
112
113 for (int i = 0; i < cfg.model.n_layers; ++i)
114 {
115 size_t buffer_size = cfg.model.hidden_size * cache_size * element_size();
116 k[i].resize(buffer_size, 0);
117 v[i].resize(buffer_size, 0);
118 }
119}
ModelConfig model
Definition Config.h:67
KVCacheDataType kv_cache_type
Definition Config.h:70
size_t element_size() const
Definition KVCache.h:51
std::vector< std::vector< uint8_t > > k
Definition KVCache.h:46
std::vector< std::vector< uint8_t > > v
Definition KVCache.h:47

References data_type, element_size(), ggma::ModelConfig::hidden_size, k, ggma::GGMAConfig::kv_cache_type, ggma::GGMAConfig::model, ggma::ModelConfig::n_layers, and v.

Referenced by ggma::Context::Context().

◆ is_valid()

bool ggma::KVCache::is_valid ( ) const
inline

Definition at line 79 of file KVCache.h.

80 {
81 if (k.size() != v.size())
82 return false;
83
84 for (size_t i = 0; i < k.size(); ++i)
85 if (k[i].size() != v[i].size())
86 return false;
87
88 return true;
89 }
int32_t size[5]
Definition Slice.cpp:35

References k, size, and v.

◆ pos()

int64_t ggma::KVCache::pos ( ) const
inline

Definition at line 92 of file KVCache.h.

92{ return _pos; }

References _pos.

Referenced by ggma::Context::generate(), and set_pos().

◆ reset_pos()

void ggma::KVCache::reset_pos ( )
inline

Definition at line 94 of file KVCache.h.

94{ _pos = 0; }

References _pos.

Referenced by ggma::Context::generate().

◆ set_pos()

void ggma::KVCache::set_pos ( int  pos)
inline

Definition at line 93 of file KVCache.h.

93{ _pos = pos; }
int64_t pos() const
Definition KVCache.h:92

References _pos, and pos().

Referenced by ggma::Context::generate().

◆ to_nnfw_type()

NNFW_TYPE ggma::KVCache::to_nnfw_type ( ) const
inline

Definition at line 65 of file KVCache.h.

66 {
67 switch (data_type)
68 {
73 default:
75 }
76 }
@ NNFW_TYPE_TENSOR_UINT8
Definition nnfw.h:83
@ NNFW_TYPE_TENSOR_FLOAT32
Definition onert-micro.h:77

References data_type, ggma::FLOAT32, NNFW_TYPE_TENSOR_FLOAT32, NNFW_TYPE_TENSOR_UINT8, and ggma::UINT8.

Referenced by ggma::Context::prefill().

◆ transpose()

void ggma::KVCache::transpose ( bool  is_k_cache,
const char *  perm,
size_t  seq_len,
size_t  num_heads,
size_t  head_dim 
)

Transpose cache with "0213" permutation [0,2,1,3].

Parameters
is_k_cache: true for K cache, false for V cache
perm: Permutation string (must be "0213")
seq_len: Sequence length dimension
num_heads: Number of attention heads
head_dim: Head dimension

Definition at line 68 of file KVCache.cc.

70{
71 if (perm == nullptr || strcmp(perm, "0213") != 0)
72 throw std::runtime_error("Only \"0213\" permutation is supported");
73
74 std::vector<std::vector<uint8_t>> &cache_vector = is_k_cache ? k : v;
75 const size_t element_bytes = element_size();
76 const size_t head_bytes = head_dim * element_bytes;
77
78 for (size_t i = 0; i < cache_vector.size(); ++i)
79 {
80 std::vector<uint8_t> transposed_cache = cache_vector[i];
81 uint8_t *input_data = cache_vector[i].data();
82 uint8_t *output_data = transposed_cache.data();
83
84 for (size_t s = 0; s < seq_len; ++s) // seq_len
85 {
86 for (size_t h = 0; h < num_heads; ++h) // num_heads
87 {
88 // source offset: s * (num_heads * head_bytes) + h * head_bytes
89 // target offset: h * (seq_len * head_bytes) + s * head_bytes
90 uint8_t *src_ptr = input_data + s * (num_heads * head_bytes) + h * head_bytes;
91 uint8_t *dst_ptr = output_data + h * (seq_len * head_bytes) + s * head_bytes;
92 memcpy(dst_ptr, src_ptr, head_bytes);
93 }
94 }
95
96 cache_vector[i] = std::move(transposed_cache);
97 }
98}
list input_data
Definition infer.py:29

References element_size(), k, and v.

Referenced by ggma::Context::generate().

Field Documentation

◆ _pos

int64_t ggma::KVCache::_pos = 0

Definition at line 48 of file KVCache.h.

Referenced by advance_pos(), pos(), reset_pos(), and set_pos().

◆ data_type

KVCacheDataType ggma::KVCache::data_type

Definition at line 45 of file KVCache.h.

Referenced by element_size(), init(), and to_nnfw_type().

◆ k

std::vector<std::vector<uint8_t> > ggma::KVCache::k

Definition at line 46 of file KVCache.h.

Referenced by init(), is_valid(), ggma::Context::prefill(), and transpose().

◆ v

std::vector<std::vector<uint8_t> > ggma::KVCache::v

Definition at line 47 of file KVCache.h.

Referenced by init(), is_valid(), ggma::Context::prefill(), and transpose().


The documentation for this struct was generated from the following files: