forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchat_stream.ts
83 lines (78 loc) · 2.37 KB
/
chat_stream.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import { InferenceEvent, InferenceMessage } from "src/types/Chat";
/**
 * Queue status of a chat request that has not started generating yet,
 * taken from a `pending` event of the chat stream.
 */
export interface QueueInfo {
  /** Size of the queue (the event's `queue_size`). */
  queueSize: number;
  /** This request's place in the queue (the event's `queue_position`). */
  queuePosition: number;
}
/**
 * Inputs for {@link handleChatEventStream}: the response body to consume and
 * the callbacks it invokes, one at a time and awaited, as events arrive.
 */
export interface ChatStreamHandlerOptions {
  /**
   * Called after every `token` event with the full text accumulated so far
   * (not just the newest token).
   */
  onToken: (partialMessage: string) => unknown;
  /** Called for every `pending` event with the current queue status. */
  onPending: (info: QueueInfo) => unknown;
  /**
   * Called with the payload of an SSE `error` event, or with the `error`
   * field of an `error` chunk.
   */
  onError: (err: unknown) => unknown;
  /** Raw bytes of a Server-Sent Events response body. */
  stream: ReadableStream<Uint8Array>;
}
/**
 * Consumes a chat inference Server-Sent Events stream, forwarding progress to
 * the supplied callbacks, and resolves with the final message.
 *
 * SSE `ping` events are skipped. Every other event's `data` is parsed as JSON
 * and dispatched on its `event_type`:
 * - `pending`: `onPending` receives the queue position and size;
 * - `token`: the text is appended and `onToken` receives the whole text
 *   accumulated so far;
 * - `message`: the chunk's final message is returned;
 * - `error`: `onError` receives the chunk's `error`, then the chunk's
 *   `message` is returned.
 * Callbacks are awaited one at a time, in stream order.
 *
 * Handling is best-effort: malformed JSON and exceptions raised while
 * dispatching a parsed event (including ones thrown by `onPending`/`onToken`/
 * `onError`) are logged, with distinct messages, and processing continues with
 * the next event.
 *
 * @returns The final message, the error chunk's message, or `null` if the
 *   stream ends without either.
 */
export async function handleChatEventStream({
  stream,
  onError,
  onPending,
  onToken,
}: ChatStreamHandlerOptions): Promise<InferenceMessage | null> {
  // Full text generated so far; onToken always gets the whole prefix.
  let tokens = "";
  for await (const { event, data } of iteratorSSE(stream)) {
    if (event === "ping") {
      continue;
    }
    if (event === "error") {
      // NOTE(review): the payload is reported here and then still parsed
      // below, in case it is itself a JSON `error` chunk — confirm against
      // what the server sends.
      await onError(data);
    }

    let chunk: InferenceEvent;
    try {
      chunk = JSON.parse(data);
    } catch (e) {
      console.error(`Error parsing data: ${data}, error: ${e}`);
      continue;
    }

    try {
      if (chunk.event_type === "pending") {
        await onPending({ queuePosition: chunk.queue_position, queueSize: chunk.queue_size });
      } else if (chunk.event_type === "token") {
        tokens += chunk.text;
        await onToken(tokens);
      } else if (chunk.event_type === "message") {
        // final message
        return chunk.message;
      } else if (chunk.event_type === "error") {
        await onError(chunk.error);
        return chunk.message;
      } else {
        console.error("Unexpected event", chunk);
      }
    } catch (e) {
      // Still best-effort, but no longer misreported as a JSON parse failure.
      console.error(`Error handling event: ${data}, error: ${e}`);
    }
  }
  return null;
}
/**
 * Decodes a `text/event-stream` (Server-Sent Events) byte stream and yields
 * one object per dispatched event.
 *
 * Parsing follows the WHATWG event-stream interpretation rules:
 * - lines end with `\n` or `\r\n` and may be split across network chunks;
 * - a blank line dispatches the event built from the preceding lines;
 * - lines starting with `:` are comments (e.g. keep-alives) and are ignored;
 * - in `field: value`, at most one space after the colon is dropped, and a
 *   line without a colon is a field with an empty value;
 * - multiple `data` lines are joined with `\n`;
 * - an event without any `data` line is not dispatched;
 * - fields other than `event` and `data` (`id`, `retry`, ...) are ignored.
 * Lone `\r` line endings are not supported.
 *
 * Deliberately more lenient than a browser `EventSource`: a final line or
 * event left unterminated at end-of-stream is still processed, so the last
 * event is not lost if the server omits the closing blank line.
 *
 * The reader lock is released when iteration finishes, fails, or is stopped
 * early by the consumer (e.g. `return`/`break` inside `for await`).
 *
 * @param stream - Raw response bytes, e.g. `(await fetch(url)).body`.
 * @yields `{ data }`, plus `event` when the event carried a non-empty
 *   `event:` field.
 */
export async function* iteratorSSE(
  stream: ReadableStream<Uint8Array>
): AsyncGenerator<{ event?: string; data: string }, void, unknown> {
  type SSEEvent = { event?: string; data: string };

  const reader = stream.pipeThrough(new TextDecoderStream()).getReader();
  // Text after the last line terminator seen so far (an incomplete line).
  let buffer = "";
  // Fields of the event currently being assembled.
  let eventType = "";
  let dataLines: string[] = [];

  // Returns the assembled event (only if it has data) and resets the builder.
  const takeEvent = (): SSEEvent | undefined => {
    const hasData = dataLines.length > 0;
    const data = dataLines.join("\n");
    const type = eventType;
    eventType = "";
    dataLines = [];
    if (!hasData) {
      return undefined;
    }
    return type ? { event: type, data } : { data };
  };

  // Applies one complete line; returns an event when the line dispatches one.
  const feedLine = (line: string): SSEEvent | undefined => {
    if (line === "") {
      return takeEvent();
    }
    if (line.startsWith(":")) {
      return undefined;
    }
    const colonIdx = line.indexOf(":");
    const field = colonIdx === -1 ? line : line.slice(0, colonIdx);
    let value = colonIdx === -1 ? "" : line.slice(colonIdx + 1);
    if (value.startsWith(" ")) {
      value = value.slice(1);
    }
    if (field === "event") {
      eventType = value;
    } else if (field === "data") {
      dataLines.push(value);
    }
    return undefined;
  };

  try {
    for (;;) {
      const { value, done } = await reader.read();
      if (done) {
        break;
      }
      if (!value) {
        continue;
      }
      const lines = (buffer + value).split(/\r?\n/);
      // The last piece has no terminator yet ("" if the text ended with one);
      // keep it so a line split across chunks is parsed whole.
      buffer = lines.pop() ?? "";
      for (const line of lines) {
        const sse = feedLine(line);
        if (sse) {
          yield sse;
        }
      }
    }
    // End of stream terminates any pending line and event (see leniency note).
    if (buffer) {
      const sse = feedLine(buffer.replace(/\r$/, ""));
      if (sse) {
        yield sse;
      }
    }
    const last = takeEvent();
    if (last) {
      yield last;
    }
  } finally {
    reader.releaseLock();
  }
}