1
- local config = require (" codegpt.config" )
2
- local providers = require (" codegpt.providers" )
1
+ local Config = require (" codegpt.config" )
2
+ local Providers = require (" codegpt.providers" )
3
3
local M = {}
4
4
5
+ function M .get_model_by_name (name )
6
+ --- @type codegpt.Model
7
+
8
+ local provider_name = vim .fn .tolower (Config .opts .connection .api_provider )
9
+ local provider_config = Config .opts .models [provider_name ]
10
+
11
+ if type (provider_config ) == " string" and # provider_config > 0 then
12
+ return provider_config , nil
13
+ end
14
+
15
+ assert (type (provider_config ) == " table" )
16
+
17
+ if provider_config == nil then
18
+ error (" no models defined for " .. provider_name )
19
+ end
20
+
21
+ local selected = name
22
+ local result = provider_config [selected ]
23
+
24
+ if result == nil then
25
+ for model_name , model in pairs (provider_config ) do
26
+ if model .alias == selected then
27
+ result = model
28
+ selected = model_name
29
+ break
30
+ end
31
+ end
32
+ end
33
+
34
+ return selected , result
35
+ end
5
36
--- default model selection order from highest to lowest priority
6
37
--- 1. global model_override (manual selection, always temporary for an nvim session)
7
38
--- 2. provider default model
@@ -12,10 +43,10 @@ function M.get_model()
12
43
--- @type codegpt.Model
13
44
local result
14
45
15
- local provider_name = vim .fn .tolower (config .opts .connection .api_provider )
46
+ local provider_name = vim .fn .tolower (Config .opts .connection .api_provider )
16
47
-- local selected = config.model_override or config.opts.models[provider].default or config.opts.models.default
17
48
local selected
18
- local provider_config = config .opts .models [provider_name ]
49
+ local provider_config = Config .opts .models [provider_name ]
19
50
20
51
-- of provider config is a string, then it must be just a model name
21
52
if type (provider_config ) == " string" and # provider_config > 0 then
@@ -27,7 +58,7 @@ function M.get_model()
27
58
error (" no models defined for " .. provider_name )
28
59
end
29
60
30
- selected = config .model_override or provider_config .default or config .opts .models .default
61
+ selected = Config .model_override or provider_config .default or Config .opts .models .default
31
62
assert (type (selected ) == " string" )
32
63
33
64
result = provider_config [selected ]
52
83
function M .get_model_by_alias (alias )
53
84
assert (alias and # alias > 0 )
54
85
55
- local provider_name = vim .fn .tolower (config .opts .connection .api_provider )
56
- local provider_config = config .opts .models [provider_name ]
86
+ local provider_name = vim .fn .tolower (Config .opts .connection .api_provider )
87
+ local provider_config = Config .opts .models [provider_name ]
57
88
if provider_config == nil then
58
89
error (" no models defined for " .. provider_name )
59
90
end
68
99
69
100
--- @return table[]
70
101
function M .get_remote_models ()
71
- local models = providers .get_provider ().get_models ()
102
+ local models = Providers .get_provider ().get_models ()
72
103
if models ~= nil then
73
104
models = vim .tbl_map (function (remote )
74
105
remote .model_source = " remote"
83
114
--- @param provider string
84
115
--- @return table[] models list of locally defined models
85
116
function M .get_local_models (provider )
86
- local provider_config = config .opts .models [provider ]
117
+ local provider_config = Config .opts .models [provider ]
87
118
88
119
-- models defined by name only are skipped since they must be a remote one
89
120
if type (provider_config ) == " string" then
@@ -107,12 +138,12 @@ function M.get_local_models(provider)
107
138
end
108
139
109
140
--- List available models
110
- function M .list_models ()
141
+ function M .select_model ()
111
142
local remote_models = M .get_remote_models ()
112
143
local models = vim .tbl_extend (" force" , {}, remote_models )
113
144
114
145
-- get local defined models
115
- local used_provider = vim .fn .tolower (config .opts .connection .api_provider )
146
+ local used_provider = vim .fn .tolower (Config .opts .connection .api_provider )
116
147
local local_models = M .get_local_models (used_provider )
117
148
models = vim .tbl_extend (" force" , models , local_models )
118
149
@@ -137,28 +168,26 @@ function M.list_models()
137
168
}, function (choice )
138
169
if choice ~= nil then
139
170
if choice .name ~= nil and # choice .name > 0 then
140
- print (" selected <" .. choice .name .. " > (" .. choice .model_source .. " defined)" )
171
+ Config .model_override = choice .name
172
+ print (" model override = <" .. choice .name .. " > (" .. choice .model_source .. " defined)" )
141
173
end
142
174
end
143
175
end )
144
176
end
145
177
146
- function M .select_model ()
147
- local models = providers .get_provider ().get_models ()
148
- if models == nil then
149
- error (" querying models" )
178
+ --- @param cmd_opts table
179
+ --- @return string name
180
+ --- @return table ? model
181
+ function M .get_model_for_cmdopts (cmd_opts )
182
+ local model_name , model
183
+ if cmd_opts .model ~= nil and Config .model_override == nil then
184
+ model_name , model = M .get_model_by_name (cmd_opts .model )
185
+ else
186
+ model_name , model = M .get_model ()
150
187
end
151
- vim .ui .select (models , {
152
- prompt = " ollama: available models" ,
153
- format_item = function (item )
154
- return item .name
155
- end ,
156
- }, function (choice )
157
- if choice ~= nil and # choice .name > 0 then
158
- config .model_override = choice .name
159
- print (" model override = <" .. choice .name .. " >" )
160
- end
161
- end )
188
+ assert (model_name and # model_name > 0 , " undefined model" )
189
+
190
+ return model_name , model
162
191
end
163
192
164
193
return M
0 commit comments