forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtext_client_utils.py
111 lines (102 loc) · 4.1 KB
/
text_client_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import requests
import sseclient
from loguru import logger
class DebugClient:
    """Minimal client for an Open-Assistant-style inference backend.

    Uses the backend's debug auth callback to obtain a bearer token, then
    creates chats and streams assistant replies over server-sent events (SSE).
    """

    # Default sampling parameters sent with every assistant_message request
    # unless the caller supplies their own (see send_message).
    DEFAULT_SAMPLING_PARAMETERS = {
        "top_p": 0.95,
        "top_k": 50,
        "repetition_penalty": 1.2,
        "temperature": 1.0,
    }

    def __init__(self, backend_url, http_client=requests):
        """Store connection settings and fetch the model list.

        :param backend_url: base URL of the inference backend (no trailing slash).
        :param http_client: requests-compatible HTTP client (injectable for tests).
        """
        self.backend_url = backend_url
        self.http_client = http_client
        self.auth_headers = None
        # NOTE(review): fetched before login (auth_headers is still None), so this
        # request is unauthenticated — assumes the backend serves model configs
        # without auth; confirm against the backend's route definitions.
        self.available_models = self.get_available_models()

    def login(self, username):
        """Authenticate via the debug callback and remember the bearer header."""
        auth_data = self.http_client.get(
            f"{self.backend_url}/auth/callback/debug", params={"code": username}
        ).json()
        assert auth_data["access_token"]["token_type"] == "bearer"
        bearer_token = auth_data["access_token"]["access_token"]
        logger.debug(f"Logged in as {username} with token {bearer_token}")
        self.auth_headers = {"Authorization": f"Bearer {bearer_token}"}

    def create_chat(self):
        """Create a new chat; reset the message chain and return the chat id."""
        response = self.http_client.post(
            f"{self.backend_url}/chats",
            json={},
            headers=self.auth_headers,
        )
        response.raise_for_status()
        self.chat_id = response.json()["id"]
        # No parent yet: the first prompter message starts the thread.
        self.message_id = None
        return self.chat_id

    def get_available_models(self):
        """Return the list of model config names the backend advertises."""
        response = self.http_client.get(
            f"{self.backend_url}/configs/model_configs",
            headers=self.auth_headers,
        )
        response.raise_for_status()
        return [model["name"] for model in response.json()]

    def send_message(self, message, model_config_name, sampling_parameters=None):
        """Send a prompter message and stream the assistant's reply.

        This is a generator: no HTTP request is issued until iteration begins.

        :param message: the prompter's message content.
        :param model_config_name: must be one of get_available_models().
        :param sampling_parameters: optional dict overriding
            DEFAULT_SAMPLING_PARAMETERS (backward-compatible extension).
        :yields: response text tokens; or the full message content when the
            backend answers the events request with HTTP 204 (no stream).
        :raises ValueError: for an unknown model config name.
        :raises RuntimeError: on backend-reported errors or undecodable
            event payloads.
        """
        # Re-fetch rather than trusting the cached list from __init__, in case
        # the backend's model configs changed since construction.
        available_models = self.get_available_models()
        if model_config_name not in available_models:
            raise ValueError(f"Invalid model config name: {model_config_name}")
        if sampling_parameters is None:
            sampling_parameters = self.DEFAULT_SAMPLING_PARAMETERS
        response = self.http_client.post(
            f"{self.backend_url}/chats/{self.chat_id}/prompter_message",
            json={
                "parent_id": self.message_id,
                "content": message,
            },
            headers=self.auth_headers,
        )
        response.raise_for_status()
        prompter_message_id = response.json()["id"]
        response = self.http_client.post(
            f"{self.backend_url}/chats/{self.chat_id}/assistant_message",
            json={
                "parent_id": prompter_message_id,
                "model_config_name": model_config_name,
                "sampling_parameters": sampling_parameters,
            },
            headers=self.auth_headers,
        )
        response.raise_for_status()
        # Remember the assistant message id so the next prompter message
        # chains onto this reply.
        self.message_id = response.json()["id"]
        response = self.http_client.get(
            f"{self.backend_url}/chats/{self.chat_id}/messages/{self.message_id}/events",
            stream=True,
            headers={
                "Accept": "text/event-stream",
                **self.auth_headers,
            },
        )
        response.raise_for_status()
        if response.status_code == 204:
            # No event stream available: fetch the finished message directly.
            response = self.http_client.get(
                f"{self.backend_url}/chats/{self.chat_id}/messages/{self.message_id}",
                headers=self.auth_headers,
            )
            response.raise_for_status()
            data = response.json()
            yield data["content"]
        else:
            client = sseclient.SSEClient(response)
            for event in client.events():
                if event.event == "error":
                    raise RuntimeError(event.data)
                if event.event == "ping":
                    continue
                try:
                    data = json.loads(event.data)
                except json.JSONDecodeError as e:
                    # Chain the cause so the original decode error is visible.
                    raise RuntimeError(f"Failed to decode {event.data=}") from e
                event_type = data["event_type"]
                if event_type == "token":
                    yield data["text"]
                elif event_type == "message":
                    # full message content, can be ignored here
                    break
                elif event_type == "error":
                    raise RuntimeError(data["error"])
                elif event_type == "pending":
                    logger.debug(f"Message pending. {data=}")