ONE - On-device Neural Engine
Loading...
Searching...
No Matches
Config.cc
Go to the documentation of this file.
1/*
2 * Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Config.h"
18
19#include <fstream>
20#include <json/json.h>
21#include <sstream>
22
23namespace ggma
24{
25
26// Helper functions for config loading with template specialization
27template <typename T>
28void load_config_field(const Json::Value &root, const std::string &field_name, T &target,
29 bool is_optional = false)
30{
31 if (root.isMember(field_name))
32 target = root[field_name].asInt();
33 else if (!is_optional)
34 throw std::runtime_error(field_name + " not found in config.json");
35}
36
37// Template specialization for bool type
38template <>
39void load_config_field<bool>(const Json::Value &root, const std::string &field_name, bool &target,
40 bool is_optional)
41{
42 if (root.isMember(field_name))
43 target = root[field_name].asBool();
44 else if (!is_optional)
45 throw std::runtime_error(field_name + " not found in config.json");
46}
47
48// Template specialization for std::optional<int> type
49template <>
50void load_config_field<std::optional<int>>(const Json::Value &root, const std::string &field_name,
51 std::optional<int> &target, bool is_optional)
52{
53 if (root.isMember(field_name))
54 target = root[field_name].asInt();
55}
56
57// Constructor with default values
59
60// Load configuration from JSON file
61void ModelConfig::load_from_file(const std::string &config_path)
62{
63 std::ifstream config_file(config_path);
64
65 if (!config_file.is_open())
66 throw std::runtime_error(
67 "config.json not found in package. This file is required for ggma_context.");
68
69 try
70 {
71 Json::Value root;
72 Json::Reader reader;
73
74 if (!reader.parse(config_file, root, false))
75 throw std::runtime_error("Failed to parse JSON: " + reader.getFormattedErrorMessages());
76
77 load_from_json(root);
78 }
79 catch (const std::exception &e)
80 {
81 // Re-throw the exception to ensure session creation fails
82 throw std::runtime_error("Failed to load config.json: " + std::string(e.what()));
83 }
84}
85
86// Load configuration from JSON value
87void ModelConfig::load_from_json(const Json::Value &root)
88{
89 // Load model configuration from Hugging Face config.json
90 load_config_field(root, "num_hidden_layers", n_layers);
91 load_config_field(root, "hidden_size", hidden_size);
92 load_config_field(root, "num_attention_heads", num_attention_heads);
93 load_config_field(root, "vocab_size", vocab_size);
94 load_config_field(root, "max_position_embeddings", max_position_embeddings);
95 load_config_field(root, "bos_token_id", bos_token_id);
96 load_config_field(root, "eos_token_id", eos_token_id);
97}
98
99// Validate configuration
101{
102 // Check required fields are positive
103 if (n_layers <= 0)
104 return false;
105 if (hidden_size <= 0)
106 return false;
107 if (num_attention_heads <= 0)
108 return false;
109 if (vocab_size <= 0)
110 return false;
112 return false;
113
114 // Check token IDs are non-negative (only if they have values)
115 if (bos_token_id.has_value() && bos_token_id.value() < 0)
116 return false;
117 if (eos_token_id.has_value() && eos_token_id.value() < 0)
118 return false;
119
120 return true;
121}
122
123// Get configuration as string (for debugging)
124std::string ModelConfig::to_string() const
125{
126 std::ostringstream oss;
127 oss << "ModelConfig {\n";
128 oss << " n_layers: " << n_layers << "\n";
129 oss << " hidden_size: " << hidden_size << "\n";
130 oss << " num_attention_heads: " << num_attention_heads << "\n";
131 oss << " vocab_size: " << vocab_size << "\n";
132 oss << " max_position_embeddings: " << max_position_embeddings << "\n";
133 oss << " bos_token_id: "
134 << (bos_token_id.has_value() ? std::to_string(bos_token_id.value()) : "undefined") << "\n";
135 oss << " eos_token_id: "
136 << (eos_token_id.has_value() ? std::to_string(eos_token_id.value()) : "undefined") << "\n";
137 oss << "}";
138 return oss.str();
139}
140
141// Utility functions for ModelConfig
142bool validate_model_config(const ModelConfig &config) { return config.is_valid(); }
143
144std::string to_string(const ModelConfig &config) { return config.to_string(); }
145
146} // namespace ggma
Definition Config.cc:24
void load_config_field(const Json::Value &root, const std::string &field_name, T &target, bool is_optional=false)
Definition Config.cc:28
std::string to_string(const ModelConfig &config)
Definition Config.cc:144
void load_config_field< bool >(const Json::Value &root, const std::string &field_name, bool &target, bool is_optional)
Definition Config.cc:39
bool validate_model_config(const ModelConfig &config)
Definition Config.cc:142
std::optional< int > eos_token_id
Definition Config.h:45
int max_position_embeddings
Definition Config.h:43
int num_attention_heads
Definition Config.h:41
bool is_valid() const
Definition Config.cc:100
void load_from_json(const Json::Value &root)
Definition Config.cc:87
void load_from_file(const std::string &config_path)
Definition Config.cc:61
std::optional< int > bos_token_id
Definition Config.h:44
std::string to_string() const
Definition Config.cc:124