1
+ /* ***********************************************************************************
2
+ * Copyright (c) 2023, xeus-cpp contributors *
3
+ * Copyright (c) 2023, Johan Mabille, Loic Gouarin, Sylvain Corlay, Wolf Vollprecht *
4
+ * *
5
+ * Distributed under the terms of the BSD 3-Clause License. *
6
+ * *
7
+ * The full license is in the file LICENSE, distributed with this software. *
8
+ ************************************************************************************/
9
+ #include " xassist.hpp"
10
+
11
+ #include < curl/curl.h>
12
+ #include < fstream>
13
+ #include < iostream>
14
+ #include < nlohmann/json.hpp>
15
+ #include < string>
16
+ #include < unordered_set>
17
+
18
+ using json = nlohmann::json;
19
+
20
+ namespace xcpp
21
+ {
22
+ class APIKeyManager
23
+ {
24
+ public:
25
+
26
+ static void saveApiKey (const std::string& model, const std::string& apiKey)
27
+ {
28
+ std::string apiKeyFilePath = model + " _api_key.txt" ;
29
+ std::ofstream out (apiKeyFilePath);
30
+ if (out)
31
+ {
32
+ out << apiKey;
33
+ out.close ();
34
+ std::cout << " API key saved for model " << model << std::endl;
35
+ }
36
+ else
37
+ {
38
+ std::cerr << " Failed to open file for writing API key for model " << model << std::endl;
39
+ }
40
+ }
41
+
42
+ // Method to load the API key for a specific model
43
+ static std::string loadApiKey (const std::string& model)
44
+ {
45
+ std::string apiKeyFilePath = model + " _api_key.txt" ;
46
+ std::ifstream in (apiKeyFilePath);
47
+ std::string apiKey;
48
+ if (in)
49
+ {
50
+ std::getline (in, apiKey);
51
+ in.close ();
52
+ return apiKey;
53
+ }
54
+
55
+ std::cerr << " Failed to open file for reading API key for model " << model << std::endl;
56
+ return " " ;
57
+ }
58
+ };
59
+
60
+ class CurlHelper
61
+ {
62
+ private:
63
+
64
+ CURL* m_curl;
65
+ curl_slist* m_headers;
66
+
67
+ public:
68
+
69
+ CurlHelper ()
70
+ : m_curl(curl_easy_init())
71
+ , m_headers(curl_slist_append(nullptr , " Content-Type: application/json" ))
72
+ {
73
+ }
74
+
75
+ ~CurlHelper ()
76
+ {
77
+ if (m_curl)
78
+ {
79
+ curl_easy_cleanup (m_curl);
80
+ }
81
+ if (m_headers)
82
+ {
83
+ curl_slist_free_all (m_headers);
84
+ }
85
+ }
86
+
87
+ // Delete copy constructor and copy assignment operator
88
+ CurlHelper (const CurlHelper&) = delete ;
89
+ CurlHelper& operator =(const CurlHelper&) = delete ;
90
+
91
+ // Delete move constructor and move assignment operator
92
+ CurlHelper (CurlHelper&&) = delete ;
93
+ CurlHelper& operator =(CurlHelper&&) = delete ;
94
+
95
+ std::string
96
+ performRequest (const std::string& url, const std::string& postData, const std::string& authHeader = " " )
97
+ {
98
+ if (!authHeader.empty ())
99
+ {
100
+ m_headers = curl_slist_append (m_headers, authHeader.c_str ());
101
+ }
102
+
103
+ curl_easy_setopt (m_curl, CURLOPT_URL, url.c_str ());
104
+ curl_easy_setopt (m_curl, CURLOPT_HTTPHEADER, m_headers);
105
+ curl_easy_setopt (m_curl, CURLOPT_POSTFIELDS, postData.c_str ());
106
+
107
+ std::string response;
108
+ curl_easy_setopt (
109
+ m_curl,
110
+ CURLOPT_WRITEFUNCTION,
111
+ +[](const char * in, size_t size, size_t num, std::string* out)
112
+ {
113
+ const size_t totalBytes (size * num);
114
+ out->append (in, totalBytes);
115
+ return totalBytes;
116
+ }
117
+ );
118
+ curl_easy_setopt (m_curl, CURLOPT_WRITEDATA, &response);
119
+
120
+ CURLcode res = curl_easy_perform (m_curl);
121
+ if (res != CURLE_OK)
122
+ {
123
+ std::cerr << " CURL request failed: " << curl_easy_strerror (res) << std::endl;
124
+ return " " ;
125
+ }
126
+
127
+ return response;
128
+ }
129
+ };
130
+
131
+ std::string gemini (const std::string& cell, const std::string& key)
132
+ {
133
+ CurlHelper curlHelper;
134
+ const std::string url = " https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key="
135
+ + key;
136
+ const std::string postData = R"( {"contents": [{"parts":[{"text": ")" + cell + R"( "}]}]})" ;
137
+
138
+ std::string response = curlHelper.performRequest (url, postData);
139
+
140
+ json j = json::parse (response);
141
+ if (j.find (" error" ) != j.end ())
142
+ {
143
+ std::cerr << " Error: " << j[" error" ][" message" ] << std::endl;
144
+ return " " ;
145
+ }
146
+
147
+ return j[" candidates" ][0 ][" content" ][" parts" ][0 ][" text" ];
148
+ }
149
+
150
+ std::string openai (const std::string& cell, const std::string& key)
151
+ {
152
+ CurlHelper curlHelper;
153
+ const std::string url = " https://api.openai.com/v1/chat/completions" ;
154
+ const std::string postData = R"( {
155
+ "model": "gpt-3.5-turbo-16k",
156
+ "messages": [{"role": "user", "content": ")"
157
+ + cell + R"( "}],
158
+ "temperature": 0.7
159
+ })" ;
160
+ std::string authHeader = " Authorization: Bearer " + key;
161
+
162
+ std::string response = curlHelper.performRequest (url, postData, authHeader);
163
+
164
+ json j = json::parse (response);
165
+
166
+ if (j.find (" error" ) != j.end ())
167
+ {
168
+ std::cerr << " Error: " << j[" error" ][" message" ] << std::endl;
169
+ return " " ;
170
+ }
171
+
172
+ return j[" choices" ][0 ][" message" ][" content" ];
173
+ }
174
+
175
+ void xassist::operator ()(const std::string& line, const std::string& cell)
176
+ {
177
+ try
178
+ {
179
+ std::istringstream iss (line);
180
+ std::vector<std::string> tokens (
181
+ std::istream_iterator<std::string>{iss},
182
+ std::istream_iterator<std::string>()
183
+ );
184
+
185
+ std::vector<std::string> models = {" gemini" , " openai" };
186
+ std::string model = tokens[1 ];
187
+
188
+ if (std::find (models.begin (), models.end (), model) == models.end ())
189
+ {
190
+ std::cerr << " Model not found." << std::endl;
191
+ return ;
192
+ }
193
+
194
+ APIKeyManager api;
195
+ if (tokens[2 ] == " --save-key" )
196
+ {
197
+ xcpp::APIKeyManager::saveApiKey (model, cell);
198
+ return ;
199
+ }
200
+
201
+ std::string key = xcpp::APIKeyManager::loadApiKey (model);
202
+ if (key.empty ())
203
+ {
204
+ std::cerr << " API key for model " << model << " is not available." << std::endl;
205
+ return ;
206
+ }
207
+
208
+ std::string response;
209
+ if (model == " gemini" )
210
+ {
211
+ response = gemini (cell, key);
212
+ }
213
+ else if (model == " openai" )
214
+ {
215
+ response = openai (cell, key);
216
+ }
217
+
218
+ std::cout << response;
219
+ }
220
+ catch (const std::runtime_error& e)
221
+ {
222
+ std::cerr << " Caught an exception: " << e.what () << std::endl;
223
+ }
224
+ catch (...)
225
+ {
226
+ std::cerr << " Caught an unknown exception" << std::endl;
227
+ }
228
+ }
229
+ } // namespace xcpp
0 commit comments