Skip to content

Commit cad3056

Browse files
authored
Computing message queue positions (LAION-AI#2235)
Introduces counters for work queues that allow us to track the positions of enqueued work requests without having to iterate through the queues.
1 parent df50632 commit cad3056

File tree

6 files changed

+87
-26
lines changed

6 files changed

+87
-26
lines changed

inference/server/oasst_inference_server/chat_repository.py

-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ async def start_work(
4949
message.state = inference.MessageState.in_progress
5050
message.work_begin_at = datetime.datetime.utcnow()
5151
message.worker_id = worker_id
52-
message.worker_compat_hash = worker_config.compat_hash
5352
message.worker_config = worker_config
5453
await self.session.commit()
5554
logger.debug(f"Started work on message {message_id}")

inference/server/oasst_inference_server/queueing.py

+52-12
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,76 @@
33

44

55
class RedisQueue:
6-
def __init__(self, redis_client: redis.Redis, queue_id: str) -> None:
6+
def __init__(
7+
self,
8+
redis_client: redis.Redis,
9+
queue_id: str,
10+
expire: int | None = None,
11+
with_counter: bool = False,
12+
counter_pos_expire: int = 1,
13+
) -> None:
714
self.redis_client = redis_client
815
self.queue_id = queue_id
16+
self.expire = expire
17+
self.with_counter = with_counter
18+
self.counter_pos_expire = counter_pos_expire
919

10-
async def enqueue(self, value: str, expire: int | None = None) -> None:
11-
pushed = await self.redis_client.rpush(self.queue_id, value)
12-
if expire is not None:
13-
await self.set_expire(expire)
14-
return pushed
20+
async def enqueue(self, value: str) -> int | None:
21+
await self.redis_client.rpush(self.queue_id, value)
22+
if self.expire is not None:
23+
await self.set_expire(self.expire)
24+
if self.with_counter:
25+
ctr = await self.redis_client.incr(f"ctr_enq:{self.queue_id}")
26+
await self.redis_client.set(f"pos:{value}", ctr, ex=self.counter_pos_expire)
27+
else:
28+
ctr = None
29+
return ctr
1530

16-
async def dequeue(self, timeout: int = 1) -> str:
17-
return await self.redis_client.blpop(self.queue_id, timeout=timeout)
31+
async def dequeue(self, timeout: int = 1) -> str | None:
32+
val = await self.redis_client.blpop(self.queue_id, timeout=timeout)
33+
if val is not None and self.with_counter:
34+
await self.redis_client.incr(f"ctr_deq:{self.queue_id}")
35+
return val
1836

1937
async def set_expire(self, timeout: int) -> None:
2038
return await self.redis_client.expire(self.queue_id, timeout)
2139

40+
async def get_enq_counter(self) -> int:
41+
if not self.with_counter:
42+
return 0
43+
enq = await self.redis_client.get(f"ctr_enq:{self.queue_id}")
44+
enq = int(enq) if enq is not None else 0
45+
return enq
2246

23-
def chat_queue(redis_client: redis.Redis, chat_id: str) -> RedisQueue:
24-
return RedisQueue(redis_client, f"chat:{chat_id}")
47+
async def get_deq_counter(self) -> int:
48+
if not self.with_counter:
49+
return 0
50+
deq = await self.redis_client.get(f"ctr_deq:{self.queue_id}")
51+
deq = int(deq) if deq is not None else 0
52+
return deq
53+
54+
async def get_length(self) -> int:
55+
return await self.redis_client.llen(self.queue_id)
56+
57+
58+
async def get_pos_value(redis_client: redis.Redis, message_id: str) -> int:
59+
val = await redis_client.get(f"pos:{message_id}")
60+
if val is None:
61+
return 0
62+
return int(val)
2563

2664

2765
def message_queue(redis_client: redis.Redis, message_id: str) -> RedisQueue:
28-
return RedisQueue(redis_client, f"message:{message_id}")
66+
return RedisQueue(redis_client, f"message:{message_id}", expire=settings.message_queue_expire)
2967

3068

3169
def work_queue(redis_client: redis.Redis, worker_compat_hash: str) -> RedisQueue:
3270
if settings.allowed_worker_compat_hashes != "*":
3371
if worker_compat_hash not in settings.allowed_worker_compat_hashes_list:
3472
raise ValueError(f"Worker compat hash {worker_compat_hash} not allowed")
35-
return RedisQueue(redis_client, f"work:{worker_compat_hash}")
73+
return RedisQueue(
74+
redis_client, f"work:{worker_compat_hash}", with_counter=True, counter_pos_expire=settings.message_queue_expire
75+
)
3676

3777

3878
def compliance_queue(redis_client: redis.Redis, worker_id: str) -> RedisQueue:

inference/server/oasst_inference_server/routes/chats.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
import fastapi
24
import pydantic
35
from fastapi import Depends
@@ -94,6 +96,7 @@ async def create_assistant_message(
9496
assistant_message = await ucr.initiate_assistant_message(
9597
parent_id=request.parent_id,
9698
work_parameters=work_parameters,
99+
worker_compat_hash=model_config.compat_hash,
97100
)
98101
queue = queueing.work_queue(deps.redis_client, model_config.compat_hash)
99102
logger.debug(f"Adding {assistant_message.id=} to {queue.queue_id} for {chat_id}")
@@ -133,19 +136,35 @@ async def message_events(
133136
if message.has_finished:
134137
raise fastapi.HTTPException(status_code=204, detail=message.state)
135138

136-
async def event_generator(chat_id: str, message_id: str):
139+
async def event_generator(chat_id: str, message_id: str, worker_compat_hash: str | None):
137140
redis_client = deps.make_redis_client()
138-
queue = queueing.message_queue(redis_client, message_id=message_id)
141+
message_queue = queueing.message_queue(redis_client, message_id=message_id)
142+
work_queue = (
143+
queueing.work_queue(redis_client, worker_compat_hash=worker_compat_hash)
144+
if worker_compat_hash is not None
145+
else None
146+
)
139147
has_started = False
140148
try:
141149
while True:
142-
item = await queue.dequeue(timeout=settings.pending_event_interval)
150+
item = await message_queue.dequeue(timeout=settings.pending_event_interval)
143151
if item is None:
144152
if not has_started:
153+
if work_queue is None:
154+
qpos, qlen = 0, 1
155+
else:
156+
# TODO: make more efficient, e.g. pipeline
157+
[qdeq, qenq, mpos] = await asyncio.gather(
158+
work_queue.get_deq_counter(),
159+
work_queue.get_enq_counter(),
160+
queueing.get_pos_value(redis_client, message_id),
161+
)
162+
qpos = max(mpos - qdeq, 0)
163+
qlen = max(qenq - qdeq, qpos)
145164
yield {
146165
"data": chat_schema.PendingResponseEvent(
147-
queue_position=0,
148-
queue_size=1,
166+
queue_position=qpos,
167+
queue_size=qlen,
149168
).json()
150169
}
151170
continue
@@ -188,7 +207,9 @@ async def event_generator(chat_id: str, message_id: str):
188207
finally:
189208
await redis_client.close()
190209

191-
return EventSourceResponse(event_generator(chat_id=chat_id, message_id=message_id))
210+
return EventSourceResponse(
211+
event_generator(chat_id=chat_id, message_id=message_id, worker_compat_hash=message.worker_compat_hash)
212+
)
192213

193214

194215
@router.post("/{chat_id}/messages/{message_id}/votes")

inference/server/oasst_inference_server/routes/workers.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ async def handle_token_response(
327327
deps.redis_client,
328328
message_id=work_response_container.message_id,
329329
)
330-
await message_queue.enqueue(response.json(), expire=settings.message_queue_expire)
330+
await message_queue.enqueue(response.json())
331331
work_response_container.num_responses += 1
332332

333333

@@ -352,7 +352,7 @@ async def handle_generated_text_response(
352352
deps.redis_client,
353353
message_id=message_id,
354354
)
355-
await message_queue.enqueue(message_packet.json(), expire=settings.message_queue_expire)
355+
await message_queue.enqueue(message_packet.json())
356356
finally:
357357
del work_request_map[response.request_id]
358358

@@ -365,7 +365,7 @@ async def abort_message(message_id: str, error: str):
365365
deps.redis_client,
366366
message_id=message_id,
367367
)
368-
await message_queue.enqueue(response.json(), expire=settings.message_queue_expire)
368+
await message_queue.enqueue(response.json())
369369

370370

371371
async def handle_error_response(
@@ -396,4 +396,4 @@ async def handle_timeout(message: inference.MessageRead):
396396
deps.redis_client,
397397
message_id=message.id,
398398
)
399-
await message_queue.enqueue(response.json(), expire=settings.message_queue_expire)
399+
await message_queue.enqueue(response.json())

inference/server/oasst_inference_server/user_chat_repository.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async def add_prompter_message(self, chat_id: str, parent_id: str | None, conten
110110
return message
111111

112112
async def initiate_assistant_message(
113-
self, parent_id: str, work_parameters: inference.WorkParameters
113+
self, parent_id: str, work_parameters: inference.WorkParameters, worker_compat_hash: str
114114
) -> models.DbMessage:
115115
logger.info(f"Adding stub assistant message to {parent_id=}")
116116

@@ -154,6 +154,7 @@ async def initiate_assistant_message(
154154
parent_id=parent_id,
155155
state=inference.MessageState.pending,
156156
work_parameters=work_parameters,
157+
worker_compat_hash=worker_compat_hash,
157158
)
158159
self.session.add(message)
159160
await self.session.commit()

inference/text-client/__main__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111

1212
@app.command()
13-
def main(backend_url: str = "http://127.0.0.1:8000", model_config_name="distilgpt2"):
13+
def main(backend_url: str = "http://127.0.0.1:8000", model_config_name="distilgpt2", username="test1"):
1414
"""Simple REPL client."""
1515
while True:
1616
try:
1717
# login
1818
client = utils.DebugClient(backend_url)
19-
client.login("test1")
19+
client.login(username)
2020
chat_id = client.create_chat()
2121
typer.echo(f"Chat ID: {chat_id}")
2222
while True:

0 commit comments

Comments
 (0)