Skip to content

Commit 0d65515

Browse files
committed
refactor: reduce code duplication on providers
1 parent 2a4f0ba commit 0d65515

File tree

7 files changed

+36
-68
lines changed

7 files changed

+36
-68
lines changed

lua/codegpt/providers/anthropic.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
-- TODO: review and update to NG
12
local curl = require("plenary.curl")
23
local Render = require("codegpt.template_render")
34
local Utils = require("codegpt.utils")

lua/codegpt/providers/azure.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
-- TODO: review and update to NG
12
local curl = require("plenary.curl")
23
local Render = require("codegpt.template_render")
34
local Utils = require("codegpt.utils")

lua/codegpt/providers/groq.lua

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,12 @@
11
local curl = require("plenary.curl")
2-
local Render = require("codegpt.template_render")
32
local Utils = require("codegpt.utils")
43
local Api = require("codegpt.api")
54
local Config = require("codegpt.config")
65
local errors = require("codegpt.errors")
6+
local Messages = require("codegpt.providers.messages")
77

88
local M = {}
99

10-
local function generate_messages(command, cmd_opts, command_args, text_selection)
11-
local system_message =
12-
Render.render(command, cmd_opts.system_message_template, command_args, text_selection, cmd_opts)
13-
local user_message = Render.render(command, cmd_opts.user_message_template, command_args, text_selection, cmd_opts)
14-
15-
local messages = {}
16-
if system_message ~= nil and system_message ~= "" then
17-
table.insert(messages, { role = "system", content = system_message })
18-
end
19-
20-
if user_message ~= nil and user_message ~= "" then
21-
table.insert(messages, { role = "user", content = user_message })
22-
end
23-
24-
return messages
25-
end
26-
2710
local function get_max_tokens(max_tokens, messages)
2811
local ok, total_length = Utils.get_accurate_tokens(vim.fn.json_encode(messages))
2912

@@ -42,7 +25,7 @@ local function get_max_tokens(max_tokens, messages)
4225
end
4326

4427
function M.make_request(command, cmd_opts, command_args, text_selection)
45-
local messages = generate_messages(command, cmd_opts, command_args, text_selection)
28+
local messages = Messages.generate_messages(command, cmd_opts, command_args, text_selection)
4629

4730
local max_tokens = cmd_opts.max_tokens
4831
if cmd_opts.max_output_tokens ~= nil then

lua/codegpt/providers/messages.lua

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
local Render = require("codegpt.template_render")
2+
local M = {}
3+
4+
---@param command string
5+
---@param cmd_opts codegpt.CommandOpts
6+
---@param command_args string
7+
function M.generate_messages(command, cmd_opts, command_args, text_selection)
8+
local system_message =
9+
Render.render(command, cmd_opts.system_message_template, command_args, text_selection, cmd_opts)
10+
local user_message = Render.render(command, cmd_opts.user_message_template, command_args, text_selection, cmd_opts)
11+
if cmd_opts.append_string then
12+
user_message = user_message .. " " .. cmd_opts.append_string
13+
end
14+
15+
local messages = {}
16+
if system_message ~= nil and system_message ~= "" then
17+
table.insert(messages, { role = "system", content = system_message })
18+
end
19+
20+
if user_message ~= nil and user_message ~= "" then
21+
table.insert(messages, { role = "user", content = user_message })
22+
end
23+
24+
return messages
25+
end
26+
27+
return M

lua/codegpt/providers/ollama.lua

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,13 @@
11
local curl = require("plenary.curl")
2-
local Render = require("codegpt.template_render")
32
local Utils = require("codegpt.utils")
43
local Api = require("codegpt.api")
54
local Config = require("codegpt.config")
65
local tokens = require("codegpt.tokens")
76
local errors = require("codegpt.errors")
7+
local Messages = require("codegpt.providers.messages")
88

99
local M = {}
1010

11-
---@param command string
12-
---@param cmd_opts codegpt.CommandOpts
13-
---@param command_args string[]
14-
local function generate_messages(command, cmd_opts, command_args, text_selection)
15-
local system_message =
16-
Render.render(command, cmd_opts.system_message_template, command_args, text_selection, cmd_opts)
17-
local user_message = Render.render(command, cmd_opts.user_message_template, command_args, text_selection, cmd_opts)
18-
if cmd_opts.append_string then
19-
user_message = user_message .. " " .. cmd_opts.append_string
20-
end
21-
22-
local messages = {}
23-
if system_message ~= nil and system_message ~= "" then
24-
table.insert(messages, { role = "system", content = system_message })
25-
end
26-
27-
if user_message ~= nil and user_message ~= "" then
28-
table.insert(messages, { role = "user", content = user_message })
29-
end
30-
31-
return messages
32-
end
33-
3411
local function get_max_tokens(max_tokens, prompt)
3512
local total_length = tokens.get_tokens(prompt)
3613

@@ -48,7 +25,7 @@ end
4825
---@param is_stream? boolean
4926
function M.make_request(command, cmd_opts, command_args, text_selection, is_stream)
5027
local models = require("codegpt.models")
51-
local messages = generate_messages(command, cmd_opts, command_args, text_selection)
28+
local messages = Messages.generate_messages(command, cmd_opts, command_args, text_selection)
5229

5330
-- max # of tokens to generate
5431
local max_tokens = get_max_tokens(cmd_opts.max_tokens, messages)

lua/codegpt/providers/openai.lua

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,15 @@
11
local curl = require("plenary.curl")
2-
local Render = require("codegpt.template_render")
32
local Utils = require("codegpt.utils")
43
local Api = require("codegpt.api")
54
local Config = require("codegpt.config")
65
local tokens = require("codegpt.tokens")
76
local errors = require("codegpt.errors")
7+
local Messages = require("codegpt.providers.messages")
88

99
-- TODO: handle streaming mode
1010

1111
local M = {}
1212

13-
---@param cmd_opts codegpt.CommandOpts
14-
local function generate_messages(command, cmd_opts, command_args, text_selection)
15-
local system_message =
16-
Render.render(command, cmd_opts.system_message_template, command_args, text_selection, cmd_opts)
17-
local user_message = Render.render(command, cmd_opts.user_message_template, command_args, text_selection, cmd_opts)
18-
if cmd_opts.append_string then
19-
user_message = user_message .. " " .. cmd_opts.append_string
20-
end
21-
22-
local messages = {}
23-
if system_message ~= nil and system_message ~= "" then
24-
table.insert(messages, { role = "system", content = system_message })
25-
end
26-
27-
if user_message ~= nil and user_message ~= "" then
28-
table.insert(messages, { role = "user", content = user_message })
29-
end
30-
31-
return messages
32-
end
33-
3413
local function get_max_tokens(max_tokens, messages)
3514
local total_length = tokens.get_tokens(messages)
3615

@@ -48,7 +27,7 @@ end
4827
---@param is_stream? boolean
4928
function M.make_request(command, cmd_opts, command_args, text_selection, is_stream)
5029
local models = require("codegpt.models")
51-
local messages = generate_messages(command, cmd_opts, command_args, text_selection)
30+
local messages = Messages.generate_messages(command, cmd_opts, command_args, text_selection)
5231

5332
local max_tokens = cmd_opts.max_tokens
5433
if cmd_opts.max_output_tokens ~= nil then

lua/codegpt/template_render.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ end
4040

4141
---@param cmd string
4242
---@param template string
43-
---@param command_args string[]
43+
---@param command_args string
4444
---@param cmd_opts table
4545
function Render.render(cmd, template, command_args, text_selection, cmd_opts)
4646
local language = get_language()

0 commit comments

Comments
 (0)