1+ /* ***********************************************************************************
2+ * Copyright (c) 2023, xeus-cpp contributors *
3+ * Copyright (c) 2023, Johan Mabille, Loic Gouarin, Sylvain Corlay, Wolf Vollprecht *
4+ * *
5+ * Distributed under the terms of the BSD 3-Clause License. *
6+ * *
7+ * The full license is in the file LICENSE, distributed with this software. *
8+ ************************************************************************************/
9+ #include " xassist.hpp"
10+
11+ #include < curl/curl.h>
12+ #include < fstream>
13+ #include < iostream>
14+ #include < nlohmann/json.hpp>
15+ #include < string>
16+ #include < unordered_set>
17+
18+ using json = nlohmann::json;
19+
20+ namespace xcpp
21+ {
22+ class APIKeyManager
23+ {
24+ public:
25+
26+ static void saveApiKey (const std::string& model, const std::string& apiKey)
27+ {
28+ std::string apiKeyFilePath = model + " _api_key.txt" ;
29+ std::ofstream out (apiKeyFilePath);
30+ if (out)
31+ {
32+ out << apiKey;
33+ out.close ();
34+ std::cout << " API key saved for model " << model << std::endl;
35+ }
36+ else
37+ {
38+ std::cerr << " Failed to open file for writing API key for model " << model << std::endl;
39+ }
40+ }
41+
42+ // Method to load the API key for a specific model
43+ static std::string loadApiKey (const std::string& model)
44+ {
45+ std::string apiKeyFilePath = model + " _api_key.txt" ;
46+ std::ifstream in (apiKeyFilePath);
47+ std::string apiKey;
48+ if (in)
49+ {
50+ std::getline (in, apiKey);
51+ in.close ();
52+ return apiKey;
53+ }
54+
55+ std::cerr << " Failed to open file for reading API key for model " << model << std::endl;
56+ return " " ;
57+ }
58+ };
59+
60+ class CurlHelper
61+ {
62+ private:
63+
64+ CURL* m_curl;
65+ curl_slist* m_headers;
66+
67+ public:
68+
69+ CurlHelper ()
70+ : m_curl(curl_easy_init())
71+ , m_headers(curl_slist_append(nullptr , " Content-Type: application/json" ))
72+ {
73+ }
74+
75+ ~CurlHelper ()
76+ {
77+ if (m_curl)
78+ {
79+ curl_easy_cleanup (m_curl);
80+ }
81+ if (m_headers)
82+ {
83+ curl_slist_free_all (m_headers);
84+ }
85+ }
86+
87+ // Delete copy constructor and copy assignment operator
88+ CurlHelper (const CurlHelper&) = delete ;
89+ CurlHelper& operator =(const CurlHelper&) = delete ;
90+
91+ // Delete move constructor and move assignment operator
92+ CurlHelper (CurlHelper&&) = delete ;
93+ CurlHelper& operator =(CurlHelper&&) = delete ;
94+
95+ std::string
96+ performRequest (const std::string& url, const std::string& postData, const std::string& authHeader = " " )
97+ {
98+ if (!authHeader.empty ())
99+ {
100+ m_headers = curl_slist_append (m_headers, authHeader.c_str ());
101+ }
102+
103+ curl_easy_setopt (m_curl, CURLOPT_URL, url.c_str ());
104+ curl_easy_setopt (m_curl, CURLOPT_HTTPHEADER, m_headers);
105+ curl_easy_setopt (m_curl, CURLOPT_POSTFIELDS, postData.c_str ());
106+
107+ std::string response;
108+ curl_easy_setopt (
109+ m_curl,
110+ CURLOPT_WRITEFUNCTION,
111+ +[](const char * in, size_t size, size_t num, std::string* out)
112+ {
113+ const size_t totalBytes (size * num);
114+ out->append (in, totalBytes);
115+ return totalBytes;
116+ }
117+ );
118+ curl_easy_setopt (m_curl, CURLOPT_WRITEDATA, &response);
119+
120+ CURLcode res = curl_easy_perform (m_curl);
121+ if (res != CURLE_OK)
122+ {
123+ std::cerr << " CURL request failed: " << curl_easy_strerror (res) << std::endl;
124+ return " " ;
125+ }
126+
127+ return response;
128+ }
129+ };
130+
131+ std::string gemini (const std::string& cell, const std::string& key)
132+ {
133+ CurlHelper curlHelper;
134+ const std::string url = " https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key="
135+ + key;
136+ const std::string postData = R"( {"contents": [{"parts":[{"text": ")" + cell + R"( "}]}]})" ;
137+
138+ std::string response = curlHelper.performRequest (url, postData);
139+
140+ json j = json::parse (response);
141+ if (j.find (" error" ) != j.end ())
142+ {
143+ std::cerr << " Error: " << j[" error" ][" message" ] << std::endl;
144+ return " " ;
145+ }
146+
147+ return j[" candidates" ][0 ][" content" ][" parts" ][0 ][" text" ];
148+ }
149+
150+ std::string openai (const std::string& cell, const std::string& key)
151+ {
152+ CurlHelper curlHelper;
153+ const std::string url = " https://api.openai.com/v1/chat/completions" ;
154+ const std::string postData = R"( {
155+ "model": "gpt-3.5-turbo-16k",
156+ "messages": [{"role": "user", "content": ")"
157+ + cell + R"( "}],
158+ "temperature": 0.7
159+ })" ;
160+ std::string authHeader = " Authorization: Bearer " + key;
161+
162+ std::string response = curlHelper.performRequest (url, postData, authHeader);
163+
164+ json j = json::parse (response);
165+
166+ if (j.find (" error" ) != j.end ())
167+ {
168+ std::cerr << " Error: " << j[" error" ][" message" ] << std::endl;
169+ return " " ;
170+ }
171+
172+ return j[" choices" ][0 ][" message" ][" content" ];
173+ }
174+
175+ void xassist::operator ()(const std::string& line, const std::string& cell)
176+ {
177+ try
178+ {
179+ std::istringstream iss (line);
180+ std::vector<std::string> tokens (
181+ std::istream_iterator<std::string>{iss},
182+ std::istream_iterator<std::string>()
183+ );
184+
185+ std::vector<std::string> models = {" gemini" , " openai" };
186+ std::string model = tokens[1 ];
187+
188+ if (std::find (models.begin (), models.end (), model) == models.end ())
189+ {
190+ std::cerr << " Model not found." << std::endl;
191+ return ;
192+ }
193+
194+ APIKeyManager api;
195+ if (tokens[2 ] == " --save-key" )
196+ {
197+ xcpp::APIKeyManager::saveApiKey (model, cell);
198+ return ;
199+ }
200+
201+ std::string key = xcpp::APIKeyManager::loadApiKey (model);
202+ if (key.empty ())
203+ {
204+ std::cerr << " API key for model " << model << " is not available." << std::endl;
205+ return ;
206+ }
207+
208+ std::string response;
209+ if (model == " gemini" )
210+ {
211+ response = gemini (cell, key);
212+ }
213+ else if (model == " openai" )
214+ {
215+ response = openai (cell, key);
216+ }
217+
218+ std::cout << response;
219+ }
220+ catch (const std::runtime_error& e)
221+ {
222+ std::cerr << " Caught an exception: " << e.what () << std::endl;
223+ }
224+ catch (...)
225+ {
226+ std::cerr << " Caught an unknown exception" << std::endl;
227+ }
228+ }
229+ } // namespace xcpp
0 commit comments