-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy path helper_utils.py
89 lines (81 loc) · 2.99 KB
/
helper_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# helper_utils.py
import json
import uuid
from datetime import datetime
import streamlit as st
from db_utils import conn, get_cursor
def save_session(*, max_history=10):
    """Persist the current chat session to the ``history`` table.

    Does nothing unless ``st.session_state`` has a truthy ``valid_key`` and
    contains ``current_session_id``. The owning username is looked up from
    ``api_keys`` via ``st.session_state.used_key``. The session row is then
    upserted by ``session_id``. On conflict, only ``session_data`` and
    ``updated_at`` are refreshed, so the original ``session_name`` is kept.
    Finally, only that user's ``max_history`` most recently updated sessions
    are retained.

    Any failure is reported in the UI via ``st.error`` rather than raised.

    Args:
        max_history: Number of most recent sessions to keep per user
            (keyword-only; defaults to 10, the previously hard-coded value).
    """
    if not (
        st.session_state.get("valid_key")
        and "current_session_id" in st.session_state
    ):
        return
    try:
        session_data = json.dumps(st.session_state.messages)
        with get_cursor() as c:
            row = c.execute(
                "SELECT username FROM api_keys WHERE key = ?",
                (st.session_state.used_key,),
            ).fetchone()
            if row is None:
                # Previously surfaced as "'NoneType' object is not subscriptable".
                raise LookupError("未找到与当前 API 密钥关联的用户")
            username = row[0]

            c.execute("""
                INSERT INTO history (
                    username,
                    session_id,
                    session_name,
                    session_data
                ) VALUES (?, ?, ?, ?)
                ON CONFLICT(session_id) DO UPDATE SET
                    session_data = excluded.session_data,
                    updated_at = CURRENT_TIMESTAMP
            """, (
                username,
                st.session_state.current_session_id,
                f"会话-{datetime.now().strftime('%m-%d %H:%M')}",
                session_data,
            ))

            # Prune old sessions for THIS user only. The outer WHERE must also
            # filter on username; otherwise every other user's rows are
            # "not in" this user's top-N list and would be deleted too.
            c.execute("""
                DELETE FROM history
                WHERE username = ?
                  AND id NOT IN (
                      SELECT id FROM history
                      WHERE username = ?
                      ORDER BY updated_at DESC
                      LIMIT ?
                  )
            """, (username, username, max_history))
    except Exception as e:
        st.error(f"保存会话失败: {str(e)}")
def load_session(session_id):
    """Load a stored session into ``st.session_state`` and rerun the app.

    Reads ``session_data`` for ``session_id`` from the ``history`` table and
    JSON-decodes it into ``st.session_state.messages``. It then sets
    ``st.session_state.current_session_id`` and calls ``st.rerun()`` so the
    UI shows the loaded conversation.

    If no row matches, a warning is shown and session state is left
    unchanged. Database or decoding errors are reported via ``st.error``.

    Args:
        session_id: Identifier of the session to load.
    """
    try:
        with get_cursor() as c:
            # NOTE(review): this lookup is not scoped to the current user;
            # confirm session_ids are unguessable or add a username filter.
            c.execute("""
                SELECT session_data
                FROM history
                WHERE session_id = ?
            """, (session_id,))
            row = c.fetchone()
        if row is None:
            st.warning(f"未找到会话: {session_id}")
            return
        messages = json.loads(row[0])
    except Exception as e:
        st.error(f"加载会话失败: {str(e)}")
        return

    st.session_state.messages = messages
    st.session_state.current_session_id = session_id
    # Kept outside the try/with: st.rerun() halts the current script run, so
    # it should not sit inside an error handler or an open cursor context.
    st.rerun()
def display_message(message):
    """Render a single chat message inside a Streamlit chat bubble.

    Assistant messages are delegated to ``_display_assistant_message`` so an
    embedded ``<THINKING>`` section can be shown separately. Messages from
    any other role are rendered directly as Markdown.

    Args:
        message: Mapping with ``"role"`` and ``"content"`` keys.
    """
    speaker = message["role"]
    with st.chat_message(speaker):
        render = _display_assistant_message if speaker == "assistant" else st.markdown
        render(message["content"])
def _display_assistant_message(content):
    """Render an assistant reply, separating an optional reasoning section.

    When ``content`` contains no ``<THINKING>`` tag, it is rendered as plain
    Markdown. When the tag is present, rendering proceeds in three steps:

    * any non-blank text before ``<THINKING>`` is rendered first;
    * the text between ``<THINKING>`` and ``</THINKING>`` is shown as a code
      block inside a collapsible "查看思考过程" expander;
    * everything after the first ``</THINKING>`` is rendered as Markdown.

    A missing closing tag (e.g. a truncated reply) is tolerated: the
    remainder is treated as reasoning and the answer is empty.

    Args:
        content: Raw assistant message text.
    """
    if "<THINKING>" not in content:
        st.markdown(content)
        return

    # partition() finds the tags wherever they occur. The previous version
    # assumed "<THINKING>" started at index 0 (hard-coded [10:] slice) and
    # raised IndexError when "</THINKING>" was absent.
    preface, _, remainder = content.partition("<THINKING>")
    thinking, _, answer = remainder.partition("</THINKING>")

    if preface.strip():
        st.markdown(preface)
    with st.expander("查看思考过程"):
        st.markdown(f"```\n{thinking}\n```")
    st.markdown(answer)
def display_chat_history():
    """Render the stored conversation, skipping system prompts.

    Walks ``st.session_state.messages`` in order and passes every entry
    whose role is not ``"system"`` to ``display_message``.
    """
    visible = (msg for msg in st.session_state.messages if msg["role"] != "system")
    for msg in visible:
        display_message(msg)