Skip to content

Commit d7efa43

Browse files
committed
refactor(core): Standardize model representation as an object olimorris#1521
This commit refactors the handling of model identifiers to consistently support model objects. Previously, models were often treated as simple strings. This change ensures that if a model is represented as an object (e.g., `{ name = "model-id", ... }`), its `name` property is correctly used. Key changes: - In `http.lua`, the request building logic now explicitly extracts `model.name` if `model` is an object, ensuring the correct string identifier is sent in the API request. - Adapters for Copilot and OpenAI have been updated to access `model.name` for internal logic that relies on the model identifier, such as conditional parameter availability or message role transformations. This change improves the robustness and flexibility of model handling within the system, paving the way for more structured model metadata.
1 parent 4318576 commit d7efa43

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

lua/codecompanion/adapters/copilot.lua

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ return {
336336
if type(model) == "function" then
337337
model = model()
338338
end
339-
return not vim.startswith(model, "o1")
339+
return not vim.startswith(model.name, "o1")
340340
end,
341341
desc = "What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend altering this or top_p but not both.",
342342
},
@@ -358,7 +358,7 @@ return {
358358
if type(model) == "function" then
359359
model = model()
360360
end
361-
return not vim.startswith(model, "o1")
361+
return not vim.startswith(model.name, "o1")
362362
end,
363363
desc = "An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or temperature but not both.",
364364
},
@@ -373,7 +373,7 @@ return {
373373
if type(model) == "function" then
374374
model = model()
375375
end
376-
return not vim.startswith(model, "o1")
376+
return not vim.startswith(model.name, "o1")
377377
end,
378378
desc = "How many chat completions to generate for each prompt.",
379379
},

lua/codecompanion/adapters/openai.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ return {
8080
messages = vim
8181
.iter(messages)
8282
:map(function(m)
83-
if vim.startswith(model, "o1") and m.role == "system" then
83+
if vim.startswith(model.name, "o1") and m.role == "system" then
8484
m.role = self.roles.user
8585
end
8686

lua/codecompanion/http.lua

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,21 @@ function Client:request(payload, actions, opts)
8686

8787
adapter:get_env_vars()
8888

89-
local body = self.opts.encode(
90-
vim.tbl_extend(
91-
"keep",
92-
handlers.form_parameters
93-
and handlers.form_parameters(adapter, adapter:set_env_vars(adapter.parameters), payload.messages)
94-
or {},
95-
handlers.form_messages and handlers.form_messages(adapter, payload.messages) or {},
96-
handlers.form_tools and handlers.form_tools(adapter, payload.tools) or {},
97-
adapter.body and adapter.body or {},
98-
handlers.set_body and handlers.set_body(adapter, payload) or {}
99-
)
89+
local body_tbl = vim.tbl_extend(
90+
"keep",
91+
handlers.form_parameters
92+
and handlers.form_parameters(adapter, adapter:set_env_vars(adapter.parameters), payload.messages)
93+
or {},
94+
handlers.form_messages and handlers.form_messages(adapter, payload.messages) or {},
95+
handlers.form_tools and handlers.form_tools(adapter, payload.tools) or {},
96+
adapter.body and adapter.body or {},
97+
handlers.set_body and handlers.set_body(adapter, payload) or {}
10098
)
99+
if body_tbl.model and type(body_tbl.model) == "table" then
100+
body_tbl.model = body_tbl.model.name
101+
end
102+
103+
local body = self.opts.encode(body_tbl)
101104

102105
local body_file = Path.new(vim.fn.tempname() .. ".json")
103106
body_file:write(vim.split(body, "\n"), "w")

0 commit comments

Comments
 (0)