Skip to content

Commit af1a409

Browse files
qtnxRobitx
authored andcommitted
feat: handle gpt o1-preview, o1-mini models
1 parent f4cbbf4 commit af1a409

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

lua/gp/dispatcher.lua

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,29 @@ D.prepare_payload = function(messages, model, provider)
165165
model.model = "gpt-4o-2024-05-13"
166166
end
167167

168-
return {
168+
local output = {
169169
model = model.model,
170170
stream = true,
171171
messages = messages,
172172
max_tokens = model.max_tokens or 4096,
173173
temperature = math.max(0, math.min(2, model.temperature or 1)),
174174
top_p = math.max(0, math.min(1, model.top_p or 1)),
175175
}
176+
177+
if provider == "openai" and model.model:sub(1, 2) == "o1" then
178+
for i = #messages, 1, -1 do
179+
if messages[i].role == "system" then
180+
table.remove(messages, i)
181+
end
182+
end
183+
-- remove max_tokens, top_p, temperature for o1 models. https://platform.openai.com/docs/guides/reasoning/beta-limitations
184+
output.max_tokens = nil
185+
output.temperature = nil
186+
output.top_p = nil
187+
output.stream = false
188+
end
189+
190+
return output
176191
end
177192

178193
-- gpt query
@@ -249,6 +264,7 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
249264
end
250265
end
251266

267+
252268
if content and type(content) == "string" then
253269
qt.response = qt.response .. content
254270
handler(qid, content)
@@ -282,6 +298,19 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
282298
if #buffer > 0 then
283299
process_lines(buffer)
284300
end
301+
local raw_response = qt.raw_response
302+
local content = qt.response
303+
if qt.provider == 'openai' and content == "" and raw_response:match('choices') and raw_response:match("content") then
304+
local response = vim.json.decode(raw_response)
305+
if response.choices and response.choices[1] and response.choices[1].message and response.choices[1].message.content then
306+
content = response.choices[1].message.content
307+
end
308+
if content and type(content) == "string" then
309+
qt.response = qt.response .. content
310+
handler(qid, content)
311+
end
312+
end
313+
285314

286315
if qt.response == "" then
287316
logger.error(qt.provider .. " response is empty: \n" .. vim.inspect(qt.raw_response))
@@ -363,7 +392,8 @@ local query = function(buf, provider, payload, handler, on_exit, callback)
363392
}
364393
end
365394

366-
local temp_file = D.query_dir .. "/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
395+
local temp_file = D.query_dir ..
396+
"/" .. logger.now() .. "." .. string.format("%x", math.random(0, 0xFFFFFF)) .. ".json"
367397
helpers.table_to_file(payload, temp_file)
368398

369399
local curl_params = vim.deepcopy(D.config.curl_params or {})

0 commit comments

Comments
 (0)