import collections
import random
import threading
import time
from typing import Iterable, Literal

import lorem
import requests
import websocket
from loguru import logger

import interface
from oasst_shared.schemas import inference


class TokenBuffer:
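    """Queue of streamed tokens that withholds the most recent tokens (at
    least ``longest_stop_len`` characters' worth) so that a stop sequence
    split across token boundaries can still be stripped in ``finish``."""
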
def __init__(self, stop_sequences: list[str]) -> None:
self.stop_sequences = stop_sequences
self.longest_stop_len = max((len(stop) for stop in stop_sequences), default=0)
self.tokens = collections.deque()
self.token_lens = collections.deque()
self.total_len = 0

    def add(self, token: interface.Token) -> Iterable[interface.Token]:
self.tokens.append(token)
self.token_lens.append(len(token))
self.total_len += len(token)
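        # Release tokens from the front of the queue, keeping at least
        # longest_stop_len characters buffered so a stop sequence that spans
        # token boundaries can still be matched later in finish().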
while True:
if not self.tokens:
break
head_len = self.token_lens[0]
if self.total_len - head_len >= self.longest_stop_len:
                head = self.tokens.popleft()
                self.token_lens.popleft()
                self.total_len -= head_len
                yield head
else:
break

    def finish(self, reason: Literal["length", "eos_token", "stop_sequence"]) -> Iterable[interface.Token]:
if reason == "stop_sequence":
end_sequence = ""
end_tokens = []
while self.tokens:
token = self.tokens.pop()
end_tokens.append(token)
end_sequence = token.text + end_sequence
if end_sequence in self.stop_sequences:
break
else:
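                # The loop exhausted the buffer without matching a stop
                # sequence; restore the popped tokens in their original order.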
self.tokens.extend(reversed(end_tokens))
yield from self.tokens
elif reason == "eos_token":
self.tokens.pop()
yield from self.tokens
else:
yield from self.tokens
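
# A minimal usage sketch (hypothetical `stream` and `emit` stand-ins, not part
# of this module): tokens are drained through the buffer as they arrive, and
# finish() flushes whatever remains once generation stops.
#
#     buffer = TokenBuffer(stop_sequences=["</s>"])
#     for token in stream:
#         for released in buffer.add(token):
#             emit(released)
#     for released in buffer.finish("stop_sequence"):
#         emit(released)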


def wait_for_inference_server(inference_server_url: str, timeout: int = 600):
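    """Block until the inference server's /health endpoint responds
    successfully, retrying after a random delay and re-raising the last
    error once `timeout` seconds have elapsed."""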
health_url = f"{inference_server_url}/health"
time_limit = time.time() + timeout
while True:
        try:
            # Bound each probe so a hung connection cannot stall the loop.
            response = requests.get(health_url, timeout=10)
            response.raise_for_status()
        except (requests.HTTPError, requests.ConnectionError, requests.Timeout):
if time.time() > time_limit:
raise
sleep_duration = random.uniform(0, 10)
logger.warning(f"Inference server not ready. Retrying in {sleep_duration:.2f} seconds")
time.sleep(sleep_duration)
else:
logger.info("Inference server is ready")
break
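
# Example (sketch; the URL is a placeholder): block until the server is
# healthy before entering the main work loop.
#
#     wait_for_inference_server("http://localhost:8000", timeout=600)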


def lorem_events(seed):
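    """Yield a fake token stream for a single lorem-ipsum sentence, ending
    with a final response carrying the full generated text and details."""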
sentence = lorem.sentence()
tokens = sentence.split()
for token in tokens[:-1]:
yield interface.GenerateStreamResponse(
token=interface.Token(
text=token + " ",
logprob=0.1,
id=0,
),
)
yield interface.GenerateStreamResponse(
token=interface.Token(
text=tokens[-1],
logprob=0.1,
id=0,
),
generated_text=sentence,
details=interface.StreamDetails(
finish_reason="length",
generated_tokens=len(tokens),
seed=seed,
),
)
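
# Example (sketch): consume the fake stream exactly like a real one, e.g. to
# exercise the worker without a running inference server.
#
#     for event in lorem_events(seed=42):
#         print(event.token.text, end="")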


# Serializes concurrent writes to the shared websocket from send_response.
ws_lock = threading.Lock()


def send_response(
    ws: websocket.WebSocket,
    response: inference.WorkerResponse | inference.WorkerConfig,
):
    msg = response.json()
with ws_lock:
ws.send(msg)
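
# Example (sketch; `ws` and `work_response` are hypothetical): the lock makes
# this safe to call from multiple worker threads sharing one connection.
#
#     send_response(ws, work_response)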