
Commit 400fc6c

fix: model election algorithm

Model election order, highest to lowest priority: override -> command options -> defaults.

Signed-off-by: blob42 <contact@blob42.xyz>
1 parent 61ac567 commit 400fc6c
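The commit body summarizes the new election order. As a rough standalone sketch (function and argument names here are illustrative only; the actual implementation lives in lua/codegpt/models.lua below), the resolution reads:

-- Illustrative reduction of "override -> command options -> defaults".
-- Names are hypothetical, not the plugin's API.
local function elect_model(override, cmd_opts, provider_config)
    if override ~= nil then
        return override -- 1. session-wide override (set via select_model) always wins
    end
    if cmd_opts.model ~= nil then
        return cmd_opts.model -- 2. then a model pinned on the command
    end
    return provider_config.default -- 3. finally the configured default
end

assert(elect_model(nil, { model = "foomodel" }, { default = "barmodel" }) == "foomodel")
assert(elect_model("m1", { model = "foomodel" }, { default = "barmodel" }) == "m1")
assert(elect_model(nil, {}, { default = "barmodel" }) == "barmodel")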

File tree

7 files changed: +98 −41 lines changed

lua/codegpt/commands.lua
Lines changed: 15 additions & 7 deletions

@@ -2,20 +2,20 @@ local Utils = require("codegpt.utils")
 local Ui = require("codegpt.ui")
 local Providers = require("codegpt.providers")
 local Api = require("codegpt.api")
-local config = require("codegpt.config")
+local Config = require("codegpt.config")
 local models = require("codegpt.models")
 
 local M = {}
 
 local text_popup_stream = function(stream, bufnr, start_row, start_col, end_row, end_col)
-    local popup_filetype = config.opts.ui.text_popup_filetype
+    local popup_filetype = Config.opts.ui.text_popup_filetype
     Ui.popup_stream(stream, popup_filetype, bufnr, start_row, start_col, end_row, end_col)
 end
 
 M.CallbackTypes = {
     ["text_popup_stream"] = text_popup_stream,
     ["text_popup"] = function(lines, bufnr, start_row, start_col, end_row, end_col)
-        local popup_filetype = config.opts.ui.text_popup_filetype
+        local popup_filetype = Config.opts.ui.text_popup_filetype
         Ui.popup(lines, popup_filetype, bufnr, start_row, start_col, end_row, end_col)
     end,
     ["code_popup"] = function(lines, bufnr, start_row, start_col, end_row, end_col)
@@ -45,15 +45,21 @@ M.CallbackTypes = {
 ---@return table opts parsed options
 ---@return boolean is_stream streaming enabled
 local function get_cmd_opts(cmd)
-    local opts = config.opts.commands[cmd]
-    local cmd_defaults = config.opts.global_defaults
+    local opts = Config.opts.commands[cmd]
+    -- print(vim.inspect(opts))
+    local cmd_defaults = Config.opts.global_defaults
     local is_stream = false
 
     -- print(vim.inspect(cmd))
     -- print(vim.inspect(config.opts.commands))
     -- print(vim.inspect(opts))
 
-    local _, model = models.get_model()
+    local model
+    if opts.model then
+        _, model = models.get_model_by_name(opts.model)
+    else
+        _, model = models.get_model()
+    end
 
     ---@type codegpt.CommandOpts
     --- options priority highest->lowest: cmd options, model options, global
@@ -63,7 +69,7 @@ local function get_cmd_opts(cmd)
         opts.callback = opts.callback_type
     else
         if
-            (config.opts.ui.stream_output and opts.callback_type == "text_popup")
+            (Config.opts.ui.stream_output and opts.callback_type == "text_popup")
             or opts.callback_type == "test_popup_stream"
         then
             opts.callback = text_popup_stream
@@ -72,6 +78,8 @@ local function get_cmd_opts(cmd)
             opts.callback = M.CallbackTypes[opts.callback_type]
         end
     end
+    -- print(vim.inspect(opts))
+    -- error(1)
 
     return opts, is_stream
 end
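The "options priority highest->lowest" comment above describes a highest-wins merge. The merge itself sits outside these hunks; as a hedged sketch of the usual Neovim idiom only (not code from this commit), vim.tbl_extend("force", ...) keeps the rightmost value on key conflicts, so tables are listed from lowest to highest priority:

-- Hedged sketch, not from this commit: a highest-wins merge of the three
-- option sources named in the comment, lowest priority first.
local merged = vim.tbl_extend(
    "force",
    cmd_defaults, -- lowest priority: global defaults
    model or {},  -- then options attached to the elected model
    opts          -- highest priority: the command's own options
)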

lua/codegpt/config.lua
Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ local M = {}
 ---@field temperature? number Custom temperature for this command
 ---@field max_tokens? number Custom max_tokens for this command
 ---@field append_string? string String to append to prompt -- ex: /no_think
+---@field model? string Model to always use with this command
 
 ---@type { [string]: codegpt.CommandOpts }
 local default_commands = {
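The new model field pins a model to a single command. A minimal setup sketch, using the same placeholder command and model names as tests/command_spec.lua below:

-- "foocmd" always elects "foomodel" unless a session-wide override is active;
-- every other command falls back to the provider default "barmodel".
require("codegpt").setup({
    commands = {
        foocmd = {
            model = "foomodel",
        },
    },
    models = {
        openai = {
            default = "barmodel",
            barmodel = { alias = "llamabar" },
            foomodel = { alias = "llamafoo" },
        },
    },
})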

lua/codegpt/models.lua
Lines changed: 56 additions & 27 deletions

@@ -1,7 +1,38 @@
-local config = require("codegpt.config")
-local providers = require("codegpt.providers")
+local Config = require("codegpt.config")
+local Providers = require("codegpt.providers")
 local M = {}
 
+function M.get_model_by_name(name)
+    ---@type codegpt.Model
+
+    local provider_name = vim.fn.tolower(Config.opts.connection.api_provider)
+    local provider_config = Config.opts.models[provider_name]
+
+    if type(provider_config) == "string" and #provider_config > 0 then
+        return provider_config, nil
+    end
+
+    assert(type(provider_config) == "table")
+
+    if provider_config == nil then
+        error("no models defined for " .. provider_name)
+    end
+
+    local selected = name
+    local result = provider_config[selected]
+
+    if result == nil then
+        for model_name, model in pairs(provider_config) do
+            if model.alias == selected then
+                result = model
+                selected = model_name
+                break
+            end
+        end
+    end
+
+    return selected, result
+end
 
 --- default model selection order from highest to lowest priority
 --- 1. global model_override (manual selection, always temporary for an nvim session)
 --- 2. provider default model
@@ -12,10 +43,10 @@ function M.get_model()
     ---@type codegpt.Model
     local result
 
-    local provider_name = vim.fn.tolower(config.opts.connection.api_provider)
+    local provider_name = vim.fn.tolower(Config.opts.connection.api_provider)
     -- local selected = config.model_override or config.opts.models[provider].default or config.opts.models.default
     local selected
-    local provider_config = config.opts.models[provider_name]
+    local provider_config = Config.opts.models[provider_name]
 
     -- if provider config is a string, then it must be just a model name
     if type(provider_config) == "string" and #provider_config > 0 then
@@ -27,7 +58,7 @@ function M.get_model()
         error("no models defined for " .. provider_name)
     end
 
-    selected = config.model_override or provider_config.default or config.opts.models.default
+    selected = Config.model_override or provider_config.default or Config.opts.models.default
     assert(type(selected) == "string")
 
     result = provider_config[selected]
@@ -52,8 +83,8 @@ end
 function M.get_model_by_alias(alias)
     assert(alias and #alias > 0)
 
-    local provider_name = vim.fn.tolower(config.opts.connection.api_provider)
-    local provider_config = config.opts.models[provider_name]
+    local provider_name = vim.fn.tolower(Config.opts.connection.api_provider)
+    local provider_config = Config.opts.models[provider_name]
     if provider_config == nil then
         error("no models defined for " .. provider_name)
     end
@@ -68,7 +99,7 @@ end
 
 ---@return table[]
 function M.get_remote_models()
-    local models = providers.get_provider().get_models()
+    local models = Providers.get_provider().get_models()
     if models ~= nil then
         models = vim.tbl_map(function(remote)
             remote.model_source = "remote"
@@ -83,7 +114,7 @@ end
 ---@param provider string
 ---@return table[] models list of locally defined models
 function M.get_local_models(provider)
-    local provider_config = config.opts.models[provider]
+    local provider_config = Config.opts.models[provider]
 
     -- models defined by name only are skipped since they must be a remote one
     if type(provider_config) == "string" then
@@ -107,12 +138,12 @@ function M.get_local_models(provider)
 end
 
 --- List available models
-function M.list_models()
+function M.select_model()
     local remote_models = M.get_remote_models()
     local models = vim.tbl_extend("force", {}, remote_models)
 
     -- get local defined models
-    local used_provider = vim.fn.tolower(config.opts.connection.api_provider)
+    local used_provider = vim.fn.tolower(Config.opts.connection.api_provider)
     local local_models = M.get_local_models(used_provider)
     models = vim.tbl_extend("force", models, local_models)
 
@@ -137,28 +168,26 @@ function M.list_models()
     }, function(choice)
         if choice ~= nil then
             if choice.name ~= nil and #choice.name > 0 then
-                print("selected <" .. choice.name .. "> (" .. choice.model_source .. " defined)")
+                Config.model_override = choice.name
+                print("model override = <" .. choice.name .. "> (" .. choice.model_source .. " defined)")
             end
         end
     end)
 end
 
-function M.select_model()
-    local models = providers.get_provider().get_models()
-    if models == nil then
-        error("querying models")
+---@param cmd_opts table
+---@return string name
+---@return table? model
+function M.get_model_for_cmdopts(cmd_opts)
+    local model_name, model
+    if cmd_opts.model ~= nil and Config.model_override == nil then
+        model_name, model = M.get_model_by_name(cmd_opts.model)
+    else
+        model_name, model = M.get_model()
     end
-    vim.ui.select(models, {
-        prompt = "ollama: available models",
-        format_item = function(item)
-            return item.name
-        end,
-    }, function(choice)
-        if choice ~= nil and #choice.name > 0 then
-            config.model_override = choice.name
-            print("model override = <" .. choice.name .. ">")
-        end
-    end)
+    assert(model_name and #model_name > 0, "undefined model")
+
+    return model_name, model
 end
 
 return M
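For reference, the lookup order inside get_model_by_name (direct table key first, alias scan second) can be reproduced as standalone Lua; this sketch reuses the placeholder model names from the test spec below:

-- Standalone illustration of get_model_by_name's lookup order:
-- try the config key first, then fall back to scanning for a matching alias.
local provider_config = {
    barmodel = { alias = "llamabar" },
    foomodel = { alias = "llamafoo" },
}

local function lookup(selected)
    local result = provider_config[selected]
    if result == nil then
        for model_name, model in pairs(provider_config) do
            if model.alias == selected then
                return model_name, model -- canonical name, not the alias
            end
        end
    end
    return selected, result
end

print(lookup("foomodel")) --> foomodel  (direct key hit)
print(lookup("llamafoo")) --> foomodel  (resolved via alias scan)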

lua/codegpt/providers/ollama.lua
Lines changed: 1 addition & 4 deletions

@@ -46,14 +46,11 @@ end
 ---@param is_stream? boolean
 function M.make_request(command, cmd_opts, command_args, text_selection, is_stream)
     local models = require("codegpt.models")
-
     local messages = generate_messages(command, cmd_opts, command_args, text_selection)
 
     -- max # of tokens to generate
     local max_tokens = get_max_tokens(cmd_opts.max_tokens, messages)
-
-    local model_name, model = models.get_model()
-    assert(model_name and #model_name > 0, "undefined model")
+    local model_name, model = models.get_model_for_cmdopts(cmd_opts)
 
     local model_opts = {}
lua/codegpt/providers/openai.lua
Lines changed: 1 addition & 2 deletions

@@ -58,8 +58,7 @@ function M.make_request(command, cmd_opts, command_args, text_selection, is_stream)
         max_tokens = get_max_tokens(cmd_opts.max_tokens, messages)
     end
 
-    local model_name, model = models.get_model()
-    assert(model_name and #model_name > 0, "undefined model")
+    local model_name, model = models.get_model_for_cmdopts(cmd_opts)
 
     local request = {
         temperature = cmd_opts.temperature,

lua/codegpt/ui.lua
Lines changed: 0 additions & 1 deletion

@@ -151,7 +151,6 @@ function M.popup_stream(stream, filetype, bufnr, start_row, start_col, end_row,
     if stream == nil and #buffer > 0 then
         table.insert(lines, buffer)
         buffer = ""
-        print("trailing buffer !")
         streaming = false
     elseif stream == nil then
         streaming = false
tests/command_spec.lua
Lines changed: 24 additions & 0 deletions

@@ -87,4 +87,28 @@ describe("cmd opts", function()
         assert(opts.max_tokens == 4242)
     end)
 end)
+
+it("should prioritize command model ", function()
+    codegpt.setup({
+        commands = {
+            foocmd = {
+                model = "foomodel",
+            },
+        },
+        models = {
+            openai = {
+                default = "barmodel",
+                barmodel = {
+                    alias = "llamabar",
+                },
+                foomodel = {
+                    alias = "llamafoo",
+                },
+            },
+        },
+    })
+    local cmds = require("codegpt.commands")
+    local opts = cmds.get_cmd_opts("foocmd")
+    assert(opts.model == "foomodel")
+end)
 end)
