forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
139 lines (119 loc) · 4.06 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import collections
import random
import threading
import time
from typing import Iterable, Literal
import interface
import lorem
import pydantic
import requests
import websocket
from loguru import logger
from oasst_shared.schemas import inference
class TokenBuffer:
    """Holds back recently streamed tokens so a stop sequence can be detected
    before the tokens that might contain it are emitted downstream.

    Tokens are released from the front only once the text still buffered is
    long enough to cover the longest configured stop sequence.
    """

    def __init__(self, stop_sequences: list[str]) -> None:
        self.stop_sequences = stop_sequences
        # With no stop sequences, keep a minimal 1-char lookback.
        self.longest_stop_len = max((len(s) for s in stop_sequences), default=1)
        self.tokens = collections.deque()
        self.token_lens = collections.deque()
        self.total_len = 0

    def add(self, token: interface.Token):
        """Buffer *token*; yield any tokens that can no longer be part of a
        stop sequence. This is a generator — the caller must iterate it."""
        self.tokens.append(token)
        self.token_lens.append(len(token))
        self.total_len += len(token)
        # Drain the front while the remaining buffered text would still be
        # long enough to contain the longest stop sequence.
        while self.tokens and self.total_len - self.token_lens[0] >= self.longest_stop_len:
            released = self.tokens.popleft()
            self.total_len -= self.token_lens.popleft()
            yield released

    def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]) -> Iterable[interface.Token]:
        """Flush the buffer at end of generation.

        - "stop_sequence": strip a trailing stop sequence (if actually
          present) before flushing the rest.
        - "eos_token": drop the final token (the EOS token) before flushing.
        - anything else: flush everything unchanged.
        """
        if reason == "stop_sequence":
            tail_text = ""
            tail_tokens = []
            while self.tokens:
                tok = self.tokens.pop()
                tail_tokens.append(tok)
                tail_text = tok.text + tail_text
                if tail_text in self.stop_sequences:
                    # Matched: the popped tail (the stop sequence) is discarded.
                    break
            else:
                # Exhausted without a match: restore the tail in original order.
                self.tokens.extend(reversed(tail_tokens))
            yield from self.tokens
        elif reason == "eos_token":
            if self.tokens:
                self.tokens.pop()
            yield from self.tokens
        else:
            yield from self.tokens
def wait_for_inference_server(http: "HttpClient", timeout: int = 600):
    """Block until the inference server's /health endpoint answers successfully.

    Polls with a random 0-10s delay between attempts; once *timeout* seconds
    have elapsed, the last HTTP/connection error is re-raised to the caller.
    """
    deadline = time.time() + timeout
    while True:
        try:
            http.get("/health").raise_for_status()
        except (requests.HTTPError, requests.ConnectionError):
            if time.time() > deadline:
                raise
            delay = random.uniform(0, 10)
            logger.warning(f"Inference server not ready. Retrying in {delay:.2f} seconds")
            time.sleep(delay)
        else:
            logger.info("Inference server is ready")
            break
def text_to_events(text: str, seed: int | None = None, pause: float = 0.0):
tokens = text.split()
for token in tokens[:-1]:
yield interface.GenerateStreamResponse(
token=interface.Token(
text=token + " ",
logprob=0.1,
id=0,
),
)
if pause > 0:
time.sleep(pause)
yield interface.GenerateStreamResponse(
token=interface.Token(
text=tokens[-1],
logprob=0.1,
id=0,
),
generated_text=text,
details=interface.StreamDetails(
finish_reason="length",
generated_tokens=len(tokens),
seed=seed,
),
)
def lorem_events(seed):
    """Produce a fake streamed response for one lorem-ipsum sentence.

    Sleeps one second first (simulating model latency), then streams the
    sentence word by word with a 0.5s pause between tokens.
    """
    sentence = lorem.sentence()
    time.sleep(1)
    yield from text_to_events(sentence, seed=seed, pause=0.5)
# Serializes all writes to the worker's websocket — send_response may be
# reached from multiple threads, presumably because the underlying
# websocket-client send is not thread-safe (NOTE(review): confirm).
ws_lock = threading.Lock()


def send_response(
    ws: websocket.WebSocket,
    # NOTE(review): parameter name "repsonse" is a typo, kept because
    # renaming would break any keyword-argument callers.
    repsonse: inference.WorkerResponse | inference.WorkerInfo,
):
    """JSON-serialize the message and send it over the socket under ws_lock."""
    payload = repsonse.json()
    with ws_lock:
        ws.send(payload)
class HttpClient(pydantic.BaseModel):
    """Thin wrapper around ``requests`` that prefixes every path with a base
    URL and applies optional HTTP basic auth to each call."""

    base_url: str
    basic_auth_username: str | None = None
    basic_auth_password: str | None = None

    @property
    def auth(self):
        """(username, password) tuple when both credentials are set, else None."""
        have_credentials = self.basic_auth_username and self.basic_auth_password
        return (self.basic_auth_username, self.basic_auth_password) if have_credentials else None

    def get(self, path: str, **kwargs):
        return requests.get(self.base_url + path, auth=self.auth, **kwargs)

    def post(self, path: str, **kwargs):
        return requests.post(self.base_url + path, auth=self.auth, **kwargs)