Skip to content

Commit 8020aae

Browse files
committed
fix chat_history and add tests
1 parent 0d65515 commit 8020aae

File tree

4 files changed

+96
-1
lines changed

4 files changed

+96
-1
lines changed

lua/codegpt/config.lua

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
---@class codegpt.Config
22
local M = {}
33

4+
---@class codegpt.Chatmsg
5+
---@field role "system"|"user"|"assistant"
6+
---@field content string
7+
48
---@class codegpt.CommandOpts
59
---@field user_message_template? string
610
---@field language_instructions? table<string, string> language instruction in the form lang = instruction
@@ -10,6 +14,7 @@ local M = {}
1014
---@field max_tokens? number Custom max_tokens for this command
1115
---@field append_string? string String to append to prompt -- ex: /no_think
1216
---@field model? string Model to always use with this command
17+
---@field chat_history? codegpt.Chatmsg[]
1318
---@field [string] any -- merged command parameters
1419

1520
---@type { [string]: codegpt.CommandOpts }

lua/codegpt/providers/messages.lua

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ local M = {}
44
---@param command string
55
---@param cmd_opts codegpt.CommandOpts
66
---@param command_args string
7+
---@param text_selection string
78
function M.generate_messages(command, cmd_opts, command_args, text_selection)
89
local system_message =
910
Render.render(command, cmd_opts.system_message_template, command_args, text_selection, cmd_opts)
@@ -13,10 +14,17 @@ function M.generate_messages(command, cmd_opts, command_args, text_selection)
1314
end
1415

1516
local messages = {}
17+
1618
if system_message ~= nil and system_message ~= "" then
1719
table.insert(messages, { role = "system", content = system_message })
1820
end
1921

22+
if cmd_opts.chat_history then
23+
for _, msg in ipairs(cmd_opts.chat_history) do
24+
table.insert(messages, msg)
25+
end
26+
end
27+
2028
if user_message ~= nil and user_message ~= "" then
2129
table.insert(messages, { role = "user", content = user_message })
2230
end

tests/command_spec.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ describe("command parsing: ", function()
3535
end)
3636
end)
3737

38-
describe("cmd opts", function()
38+
describe("command options", function()
3939
before_each(function()
4040
codegpt.setup()
4141
end)

tests/message_render_spec.lua

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
local codegpt = require("codegpt")
2+
local config = require("codegpt.config")
3+
local Messages = require("codegpt.providers.messages")
4+
local Commands = require("codegpt.commands")
5+
6+
local function should_fail(fun)
7+
local stat = pcall(fun)
8+
assert(not stat, "Function should have errored")
9+
end
10+
11+
describe("message templates", function()
12+
before_each(function()
13+
codegpt.setup()
14+
end)
15+
16+
it("should render default system msg", function()
17+
vim.o.filetype = "lua"
18+
codegpt.setup({})
19+
local cmd_opts = Commands.get_cmd_opts("completion")
20+
21+
local messages = Messages.generate_messages("completion", cmd_opts, "", "")
22+
23+
assert(#messages == 2)
24+
local default_systpl = config.opts.global_defaults.system_message_template
25+
assert(messages[1].role == "system")
26+
assert(messages[1].content == default_systpl:gsub("{{language}}", "lua"))
27+
end)
28+
29+
it("should render user msg", function()
30+
local testcmd = {
31+
user_message_template = "Foo user message",
32+
}
33+
codegpt.setup({
34+
commands = {
35+
testcmd = testcmd,
36+
},
37+
})
38+
local cmd_opts = Commands.get_cmd_opts("testcmd")
39+
40+
vim.o.filetype = "lua"
41+
42+
local messages = Messages.generate_messages("testcmd", cmd_opts, "", "")
43+
44+
assert(#messages == 2)
45+
local default_systpl = config.opts.global_defaults.system_message_template
46+
assert(messages[1].role == "system")
47+
assert(messages[1].content == default_systpl:gsub("{{language}}", "lua"))
48+
assert(messages[2].content == testcmd.user_message_template)
49+
end)
50+
51+
it("should handle message history", function()
52+
local testcmd = {
53+
user_message_template = "Foo user message",
54+
---@type codegpt.Chatmsg[]
55+
chat_history = {
56+
{ role = "user", content = "Hist user msg" },
57+
{ role = "assistant", content = "Hist assistant response" },
58+
},
59+
}
60+
codegpt.setup({
61+
commands = {
62+
testcmd = testcmd,
63+
},
64+
global_defaults = {
65+
system_message_template = "Default system message",
66+
},
67+
})
68+
local cmd_opts = Commands.get_cmd_opts("testcmd")
69+
70+
local messages = Messages.generate_messages("testcmd", cmd_opts, "", "")
71+
72+
assert(#messages == 4)
73+
assert(messages[1].role == "system")
74+
assert(messages[1].content == "Default system message")
75+
assert(messages[2].role == "user")
76+
assert(messages[2].content == "Hist user msg")
77+
assert(messages[3].role == "assistant")
78+
assert(messages[3].content == "Hist assistant response")
79+
assert(messages[4].role == "user")
80+
assert(messages[4].content == "Foo user message")
81+
end)
82+
end)

0 commit comments

Comments
 (0)