ONE - On-device Neural Engine
Loading...
Searching...
No Matches
ggma::Context Class Reference

#include <Context.h>

Public Member Functions

 Context (const char *package_path)
 
GGMAConfig load_config (const std::string &package_path)
 
void prefill (ggma_token *tokens, size_t n_tokens, std::vector< uint8_t > &hidden_state)
 
void unemb (std::vector< uint8_t > &hidden_state, size_t n_tokens, std::vector< float > &logits)
 
ggma_token sample (const std::vector< float > &logits)
 
void decode (ggma_token token_id, std::vector< uint8_t > &hidden_state)
 
void decode (ggma_token token_id, std::vector< float > &logits)
 
 ~Context ()=default
 
GGMA_STATUS generate (ggma_token *tokens, size_t n_tokens, size_t n_tokens_max, size_t *n_predict)
 

Detailed Description

Definition at line 32 of file Context.h.

Constructor & Destructor Documentation

◆ Context()

ggma::Context::Context ( const char *  package_path)

Definition at line 78 of file Context.cc.

78 : _package_path(package_path)
79{
80 _cfg = load_config(_package_path);
81 _cache.init(_cfg, _cfg.cache_size);
82}
GGMAConfig load_config(const std::string &package_path)
Definition Context.cc:84
void init(const ggma::GGMAConfig &cfg, int cache_size)
Definition KVCache.cc:100

References ggma::GGMAConfig::cache_size, ggma::KVCache::init(), and load_config().

◆ ~Context()

ggma::Context::~Context ( )
default

Member Function Documentation

◆ decode() [1/2]

void ggma::Context::decode ( ggma_token  token_id,
std::vector< float > &  logits 
)

Definition at line 281 of file Context.cc.

282{
283 decode_impl<true, std::vector<float>>(token_id, logits);
284}

◆ decode() [2/2]

void ggma::Context::decode ( ggma_token  token_id,
std::vector< uint8_t > &  hidden_state 
)

Definition at line 276 of file Context.cc.

277{
278 decode_impl<false, std::vector<uint8_t>>(token_id, hidden_state);
279}

Referenced by generate().

◆ generate()

GGMA_STATUS ggma::Context::generate ( ggma_token tokens,
size_t  n_tokens,
size_t  n_tokens_max,
size_t *  n_predict 
)

Definition at line 39 of file Generate.cc.

41{
42 try
43 {
44 _cache.reset_pos();
45
46 std::vector<uint8_t> hidden;
47 std::vector<float> logits;
48 ggma_token new_token;
49
50 // 1. Prefill: run the model on the initial prompt to obtain the initial hidden state.
51 prefill(tokens, n_tokens, hidden); // hidden = prefill(tokens)
52
53 // 2. Set cache position to the length of the prompt.
54 _cache.set_pos(n_tokens);
55
56 // 3. Transpose KV caches to the layout expected by the decoder.
 57   _cache.transpose(true /* k */, "0213", _cfg.model.num_attention_heads, _cfg.cache_size,
 58                    _cfg.model.hidden_size / _cfg.model.num_attention_heads);
 59   _cache.transpose(false /* v */, "0213", _cfg.model.num_attention_heads, _cfg.cache_size,
 60                    _cfg.model.hidden_size / _cfg.model.num_attention_heads);
62 // 4. Unembed: obtain logits from the hidden state.
63 unemb(hidden, n_tokens, logits); // logits = unemb(hidden)
64
65 // 5. Determine how many tokens we can actually generate.
66 size_t n_possible = n_tokens_max - n_tokens;
67 if (*n_predict > n_possible)
68 *n_predict = n_possible;
69
70 auto is_end_token = [this](ggma_token token) {
71 return token == _cfg.model.eos_token_id.value_or(-1) || token == 0;
72 };
73
74 // 6. Autoregressive generation loop.
75 while ((_cache.pos() - n_tokens) < *n_predict)
76 {
77 // Sample the most probable token from the logits of the last position.
78 new_token = sample(logits);
79 tokens[n_tokens + (_cache.pos() - n_tokens)] = new_token;
80
81 // Stop if we hit an EOS or padding token.
82 if (is_end_token(new_token))
83 break;
84
85 // Decode: run the model for the newly generated token to update hidden state.
86 decode(new_token, hidden); // hidden = decode(new_token)
87
88 // Unembed: get logits for the next step.
89 unemb(hidden, 1, logits); // logits = unemb(hidden)
90 }
91
92 // Report how many tokens were actually generated.
93 *n_predict = _cache.pos() - n_tokens;
94 }
95 catch (const std::exception &e)
96 {
97 std::cerr << "Error in generate: " << e.what() << std::endl;
98 return GGMA_STATUS_ERROR;
99 }
100  return GGMA_STATUS_NO_ERROR;
101}
void prefill(ggma_token *tokens, size_t n_tokens, std::vector< uint8_t > &hidden_state)
Definition Context.cc:95
void decode(ggma_token token_id, std::vector< uint8_t > &hidden_state)
Definition Context.cc:276
ggma_token sample(const std::vector< float > &logits)
Definition Context.cc:294
void unemb(std::vector< uint8_t > &hidden_state, size_t n_tokens, std::vector< float > &logits)
Definition Context.cc:161
@ GGMA_STATUS_NO_ERROR
Definition ggma_types.h:37
@ GGMA_STATUS_ERROR
Definition ggma_types.h:42
int32_t ggma_token
Definition ggma_types.h:53
ModelConfig model
Definition Config.h:67
void set_pos(int pos)
Definition KVCache.h:93
void transpose(bool is_k_cache, const char *perm, size_t seq_len, size_t num_heads, size_t head_dim)
Transpose cache with "0213" permutation [0,2,1,3].
Definition KVCache.cc:68
void reset_pos()
Definition KVCache.h:94
int64_t pos() const
Definition KVCache.h:92
std::optional< int > eos_token_id
Definition Config.h:45
int num_attention_heads
Definition Config.h:41

References ggma::GGMAConfig::cache_size, decode(), ggma::ModelConfig::eos_token_id, GGMA_STATUS_ERROR, GGMA_STATUS_NO_ERROR, ggma::ModelConfig::hidden_size, ggma::GGMAConfig::model, ggma::ModelConfig::num_attention_heads, ggma::KVCache::pos(), prefill(), ggma::KVCache::reset_pos(), sample(), ggma::KVCache::set_pos(), ggma::KVCache::transpose(), and unemb().

◆ load_config()

ggma::GGMAConfig ggma::Context::load_config ( const std::string &  package_path)

Definition at line 84 of file Context.cc.

85{
86 GGMAConfig config;
87
88 // Load config from package path/config.json
89 std::filesystem::path config_path = std::filesystem::path(package_path) / "config.json";
90 config.model.load_from_file(config_path.string());
91
92 return config;
93}

Referenced by Context().

◆ prefill()

void ggma::Context::prefill ( ggma_token tokens,
size_t  n_tokens,
std::vector< uint8_t > &  hidden_state 
)

Definition at line 95 of file Context.cc.

96{
97 std::filesystem::path nnpkg_path = std::filesystem::path(_package_path) / "prefill";
98 nnfw_session *session = create_and_prepare_session(nnpkg_path.string());
99
100  nnfw_tensorinfo ti;
101
102 // Input 0: token_id
103 // shape = [n_batch, n_seq]
104 // n_batch = 1
105  NNFW_ENSURE_STATUS(nnfw_input_tensorinfo(session, 0, &ti));
106  if (ti.rank != 2 || ti.dims[0] != 1)
107 throw std::runtime_error("prefill : invalid input shape");
108
109 // TODO: Check ubatch from model is same to runtime config
110 int ubatch = ti.dims[1]; // Number of tokens after padding to align to 32 multiples
111 // Use tokens as input without copying (zero-copy)
112 NNFW_ENSURE_STATUS(nnfw_set_input(session, 0, ti.dtype, tokens, ubatch * sizeof(ggma_token)));
113
114 // Expected Output:
115 //
116 // Index | Name | Description
117 // ------|----------|---------------------------
118 // 0 | k0 | key cache for layer 0
119 // 1 | v0 | value cache for layer 0
120 // ... | ... | ...
121 // 2n-2 | k{n-1} | key cache for layer n-1
122 // 2n-1 | v{n-1} | value cache for layer n-1
123 // 2n | hidden | hidden state
124 //
125 // where n = number of layers
126
127 uint32_t num_outputs;
128 NNFW_ENSURE_STATUS(nnfw_output_size(session, &num_outputs));
129 if (num_outputs != _cfg.model.n_layers * 2 + 1)
130 throw std::runtime_error("prefill : number of outputs mismatch");
131
132 // Output 0~2n-1: KV caches
133 for (int i = 0; i < _cfg.model.n_layers; ++i)
134 {
135 if (!_cache.v[i].empty())
136 NNFW_ENSURE_STATUS(nnfw_set_output(session, 2 * i, _cache.to_nnfw_type(), _cache.v[i].data(),
137 _cache.v[i].size()));
138 if (!_cache.k[i].empty())
139 NNFW_ENSURE_STATUS(nnfw_set_output(session, 2 * i + 1, _cache.to_nnfw_type(),
140 _cache.k[i].data(), _cache.k[i].size()));
141 }
142
143 // Output 2n: hidden_state
144 // shape = [n_batch, n_seq, n_emb]
145 // n_batch = 1
146
147  NNFW_ENSURE_STATUS(nnfw_output_tensorinfo(session, num_outputs - 1, &ti));
148  if (ti.rank != 3 || ti.dims[0] != 1)
149 throw std::runtime_error("prefill : invalid hidden shape");
150
151 // Allocate output buffer
152 hidden_state.resize(bufsize_for(&ti), 0);
153 // Output buffer setup - use externally allocated hidden_state (single output for single token)
154  NNFW_ENSURE_STATUS(
155    nnfw_set_output(session, num_outputs - 1, ti.dtype, hidden_state.data(), hidden_state.size()));
156
157  NNFW_ENSURE_STATUS(nnfw_run(session));
158  nnfw_close_session(session);
159}
#define NNFW_ENSURE_STATUS(a)
Definition Context.cc:35
SessionID session(const coco::Module *m)
Definition Session.cpp:48
uint64_t bufsize_for(const nnfw_tensorinfo *ti)
Definition Context.cc:63
nnfw_session * create_and_prepare_session(const std::string &model_path)
Definition Context.cc:45
NNFW_STATUS nnfw_set_input(nnfw_session *session, uint32_t index, NNFW_TYPE type, const void *buffer, size_t length)
Set input buffer.
Definition APIImpl.cc:102
NNFW_STATUS nnfw_output_tensorinfo(nnfw_session *session, uint32_t index, nnfw_tensorinfo *tensor_info)
Get i-th output tensor info.
Definition APIImpl.cc:159
NNFW_STATUS nnfw_input_tensorinfo(nnfw_session *session, uint32_t index, nnfw_tensorinfo *tensor_info)
Get i-th input tensor info.
Definition APIImpl.cc:152
NNFW_STATUS nnfw_output_size(nnfw_session *session, uint32_t *number)
Get the number of outputs.
Definition APIImpl.cc:122
NNFW_STATUS nnfw_run(nnfw_session *session)
Run inference.
Definition APIImpl.cc:84
NNFW_STATUS nnfw_set_output(nnfw_session *session, uint32_t index, NNFW_TYPE type, void *buffer, size_t length)
Set output buffer.
Definition APIImpl.cc:109
NNFW_STATUS nnfw_close_session(nnfw_session *session)
Close a session instance.
Definition APIImpl.cc:66
std::vector< std::vector< uint8_t > > k
Definition KVCache.h:46
std::vector< std::vector< uint8_t > > v
Definition KVCache.h:47
NNFW_TYPE to_nnfw_type() const
Definition KVCache.h:65
tensor info describes the type and shape of tensors
NNFW_TYPE dtype
int32_t dims[NNFW_MAX_RANK]

References ggma::bufsize_for(), ggma::create_and_prepare_session(), nnfw_tensorinfo::dims, nnfw_tensorinfo::dtype, ggma::KVCache::k, ggma::GGMAConfig::model, ggma::ModelConfig::n_layers, nnfw_close_session(), NNFW_ENSURE_STATUS, nnfw_input_tensorinfo(), nnfw_output_size(), nnfw_output_tensorinfo(), nnfw_run(), nnfw_set_input(), nnfw_set_output(), nnfw_tensorinfo::rank, ggma::KVCache::to_nnfw_type(), and ggma::KVCache::v.

Referenced by generate().

◆ sample()

ggma_token ggma::Context::sample ( const std::vector< float > &  logits)

Definition at line 294 of file Context.cc.

295{
296 if (logits.empty())
297 throw std::runtime_error("Empty logits tensor");
298
299 // Calculate total number of float elements in logits tensor
300 size_t total_elements = logits.size();
301
302 if (total_elements % _cfg.model.vocab_size != 0)
303 throw std::runtime_error("Invalid sequence length in logits tensor");
304
305 const float *last_logits = logits.data() + (total_elements - _cfg.model.vocab_size);
306
307 // Find the token with maximum logit value from the last token's logits
308 const float *max_elem_iter = std::max_element(last_logits, last_logits + _cfg.model.vocab_size);
309
310 return std::distance(last_logits, max_elem_iter);
311}

References ggma::GGMAConfig::model, and ggma::ModelConfig::vocab_size.

Referenced by generate().

◆ unemb()

void ggma::Context::unemb ( std::vector< uint8_t > &  hidden_state,
size_t  n_tokens,
std::vector< float > &  logits 
)

Definition at line 161 of file Context.cc.

162{
163 std::filesystem::path nnpkg_path = std::filesystem::path(_package_path) / "unemb";
164 nnfw_session *session = create_and_prepare_session(nnpkg_path.string());
165
166 // Input buffer setup - use externally allocated hidden_state
167  nnfw_tensorinfo ti;
168  NNFW_ENSURE_STATUS(nnfw_input_tensorinfo(session, 0, &ti));
169  // ti[0] : n_batch
170 // ti[1] : n_seq = ubatch if padded
171 // = n_tokens if not padded
172 if (ti.rank != 3 || ti.dims[0] != 1)
173 throw std::runtime_error("unemb : invalid input shape");
174 assert(ti.dims[1] == _cfg.ubatch); // Previously, it was padded to ubatch.
175 // Handle effective (actual) tokens only.
176 ti.dims[1] = n_tokens;
177 // Update buffer and nnfw input tensor info as sequence length is adjusted.
178 hidden_state.resize(bufsize_for(&ti), 0);
179  NNFW_ENSURE_STATUS(nnfw_set_input_tensorinfo(session, 0, &ti));
180  NNFW_ENSURE_STATUS(
181    nnfw_set_input(session, 0, ti.dtype, hidden_state.data(), hidden_state.size()));
182
183 // Output buffer setup - use externally allocated logits
184  NNFW_ENSURE_STATUS(nnfw_output_tensorinfo(session, 0, &ti));
185  // Check if output data type is float
186  if (ti.dtype != NNFW_TYPE_TENSOR_FLOAT32)
187    throw std::runtime_error("unemb: output tensor must be float type");
188 // Allocate output buffer
189 // ti[0] : n_batch
190 // ti[1] : n_seq = ubatch if padded
191 // = n_tokens if not padded
192 if (ti.rank != 3 || ti.dims[0] != 1)
193 throw std::runtime_error("unemb : invalid output shape");
194 assert(ti.dims[1] == _cfg.ubatch); // Previously, it was padded to ubatch.
195 // Handle effective (actual) tokens only.
196 ti.dims[1] = n_tokens;
197 logits.resize(num_elems(&ti), 0);
198  NNFW_ENSURE_STATUS(
199    nnfw_set_output(session, 0, ti.dtype, logits.data(), logits.size() * sizeof(logits[0])));
200
201  NNFW_ENSURE_STATUS(nnfw_run(session));
202  nnfw_close_session(session);
203}
uint64_t num_elems(const nnfw_tensorinfo *tensor_info)
Definition Context.cc:55
NNFW_STATUS nnfw_set_input_tensorinfo(nnfw_session *session, uint32_t index, const nnfw_tensorinfo *tensor_info)
Set input model's tensor info for resizing.
Definition APIImpl.cc:178
@ NNFW_TYPE_TENSOR_FLOAT32
Definition onert-micro.h:77

References ggma::bufsize_for(), ggma::create_and_prepare_session(), nnfw_tensorinfo::dims, nnfw_tensorinfo::dtype, nnfw_close_session(), NNFW_ENSURE_STATUS, nnfw_input_tensorinfo(), nnfw_output_tensorinfo(), nnfw_run(), nnfw_set_input(), nnfw_set_input_tensorinfo(), nnfw_set_output(), NNFW_TYPE_TENSOR_FLOAT32, ggma::num_elems(), nnfw_tensorinfo::rank, and ggma::GGMAConfig::ubatch.

Referenced by generate().


The documentation for this class was generated from the following files: