ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Context.h
Go to the documentation of this file.
1/*
2 * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef __GGMA_CONTEXT_H__
18#define __GGMA_CONTEXT_H__
19
20#include "ggma_types.h"
21#include "Config.h"
22#include "KVCache.h"
23
24#include <cstdint>
25#include <memory>
26#include <string>
27#include <vector>
28
29namespace ggma
30{
31
33{
34public:
35 Context(const char *package_path);
36 GGMAConfig load_config(const std::string &package_path);
37
38 void prefill(ggma_token *tokens, size_t n_tokens, std::vector<uint8_t> &hidden_state);
39 void unemb(std::vector<uint8_t> &hidden_state, size_t n_tokens, std::vector<float> &logits);
40 ggma_token sample(const std::vector<float> &logits);
41 void decode(ggma_token token_id, std::vector<uint8_t> &hidden_state);
42 void decode(ggma_token token_id, std::vector<float> &logits);
43
44private:
45 // Template implementation to eliminate code duplication
46 template <bool ReturnLogits, typename OutputType>
47 void decode_impl(ggma_token token_id, OutputType &output);
48 void init_kv_cache();
49
50public:
51 ~Context() = default;
52
53 GGMA_STATUS generate(ggma_token *tokens, size_t n_tokens, size_t n_tokens_max, size_t *n_predict);
54
55private:
56 std::string _package_path;
58 ggma::KVCache _cache;
59};
60
61} // namespace ggma
62
63#endif // __GGMA_CONTEXT_H__
GGMAConfig load_config(const std::string &package_path)
Definition Context.cc:84
void prefill(ggma_token *tokens, size_t n_tokens, std::vector< uint8_t > &hidden_state)
Definition Context.cc:95
GGMA_STATUS generate(ggma_token *tokens, size_t n_tokens, size_t n_tokens_max, size_t *n_predict)
Definition Generate.cc:39
void decode(ggma_token token_id, std::vector< uint8_t > &hidden_state)
Definition Context.cc:276
ggma_token sample(const std::vector< float > &logits)
Definition Context.cc:294
void unemb(std::vector< uint8_t > &hidden_state, size_t n_tokens, std::vector< float > &logits)
Definition Context.cc:161
~Context()=default
This file defines the core types and status codes for the GGMA API.
GGMA_STATUS
Enumeration of status codes returned by GGMA API functions.
Definition ggma_types.h:35
int32_t ggma_token
Definition ggma_types.h:53
Definition Config.cc:24