forked from LAION-AI/Open-Assistant
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathqueueing.py
92 lines (74 loc) · 3.2 KB
/
queueing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import redis.asyncio as redis
from oasst_inference_server.settings import settings
class QueueFullException(Exception):
    """Raised by ``RedisQueue.enqueue`` when the queue has reached its ``max_size``."""
class RedisQueue:
    """FIFO queue stored in a Redis list, with optional enqueue/dequeue counters.

    Items are appended with RPUSH and removed with BLPOP. When ``with_counter``
    is enabled, the keys ``ctr_enq:<queue_id>`` and ``ctr_deq:<queue_id>`` are
    incremented on every enqueue/dequeue. Each enqueued value's enqueue-counter
    value is also stored under ``pos:<value>`` (see ``get_pos_value``).
    """

    def __init__(
        self,
        redis_client: redis.Redis,
        queue_id: str,
        expire: int | None = None,
        with_counter: bool = False,
        counter_pos_expire: int = 1,
        max_size: int | None = None,
    ) -> None:
        """Create a handle to the queue; no Redis commands are issued here.

        Args:
            redis_client: Async Redis client used for every command.
            queue_id: Redis key of the list that holds the queue items.
            expire: If not None, TTL in seconds re-applied to the queue key
                after every enqueue.
            with_counter: Maintain enqueue/dequeue counters and per-value
                ``pos:<value>`` entries.
            counter_pos_expire: TTL in seconds for the ``pos:<value>`` keys.
            max_size: Maximum length enforced by ``enqueue``. None or any
                value <= 0 disables the limit.
        """
        self.redis_client = redis_client
        self.queue_id = queue_id
        self.expire = expire
        self.with_counter = with_counter
        self.counter_pos_expire = counter_pos_expire
        # Normalise None to 0 so the "limit enabled" test is a simple `> 0`.
        self.max_size = max_size or 0

    def _counter_key(self, kind: str) -> str:
        """Return the Redis key of the ``kind`` counter ("enq" or "deq") for this queue."""
        return f"ctr_{kind}:{self.queue_id}"

    async def _get_counter(self, kind: str) -> int:
        """Read the ``kind`` counter; 0 when counters are disabled or the key is unset."""
        if not self.with_counter:
            return 0
        raw = await self.redis_client.get(self._counter_key(kind))
        return int(raw) if raw is not None else 0

    async def enqueue(self, value: str, enforce_max_size: bool = True) -> int | None:
        """Append ``value`` to the tail of the queue.

        Args:
            value: Item to push.
            enforce_max_size: When true and a positive ``max_size`` is
                configured, refuse to push into a full queue.

        Returns:
            The new enqueue-counter value when ``with_counter`` is enabled,
            otherwise None.

        Raises:
            QueueFullException: The size limit is enforced and the queue
                already holds ``max_size`` items.
        """
        if enforce_max_size and self.max_size > 0:
            # NOTE: LLEN and RPUSH are separate commands, so concurrent
            # producers can briefly push the queue past max_size.
            if await self.get_length() >= self.max_size:
                raise QueueFullException()
        await self.redis_client.rpush(self.queue_id, value)
        if self.expire is not None:
            await self.set_expire(self.expire)
        if not self.with_counter:
            return None
        ctr = await self.redis_client.incr(self._counter_key("enq"))
        # Record this item's counter value under pos:<value>, with a TTL.
        await self.redis_client.set(f"pos:{value}", ctr, ex=self.counter_pos_expire)
        return ctr

    async def dequeue(self, timeout: int = 1) -> tuple[str, str] | None:
        """Pop the head of the queue, blocking for up to ``timeout`` seconds.

        A timeout of 0 blocks indefinitely (Redis BLPOP semantics).

        Returns:
            The raw BLPOP result, a ``(key, value)`` pair, or None if the
            timeout elapsed. NOTE(review): the elements are ``str`` only if the
            client decodes responses (``decode_responses=True``); otherwise
            they are ``bytes``. Confirm against the client configuration.
        """
        val = await self.redis_client.blpop(self.queue_id, timeout=timeout)
        if val is not None and self.with_counter:
            await self.redis_client.incr(self._counter_key("deq"))
        return val

    async def set_expire(self, timeout: int) -> bool:
        """Set a TTL of ``timeout`` seconds on the queue key.

        Returns the result of the Redis EXPIRE command.
        """
        return await self.redis_client.expire(self.queue_id, timeout)

    async def get_enq_counter(self) -> int:
        """Total number of enqueues recorded; 0 when counters are disabled."""
        return await self._get_counter("enq")

    async def get_deq_counter(self) -> int:
        """Total number of dequeues recorded; 0 when counters are disabled."""
        return await self._get_counter("deq")

    async def get_length(self) -> int:
        """Current number of items in the queue (Redis LLEN)."""
        return await self.redis_client.llen(self.queue_id)
async def get_pos_value(redis_client: redis.Redis, message_id: str) -> int:
    """Return the integer stored under ``pos:<message_id>``, or 0 if the key is absent.

    The ``pos:`` keys are written by ``RedisQueue.enqueue`` when counters are
    enabled.
    """
    raw_pos = await redis_client.get(f"pos:{message_id}")
    return 0 if raw_pos is None else int(raw_pos)
def message_queue(redis_client: redis.Redis, message_id: str) -> RedisQueue:
    """Build the per-message queue ``message:<message_id>``.

    The queue key's TTL is refreshed to ``settings.message_queue_expire`` on
    every enqueue.
    """
    key = f"message:{message_id}"
    return RedisQueue(redis_client, key, expire=settings.message_queue_expire)
def work_queue(redis_client: redis.Redis, worker_compat_hash: str) -> RedisQueue:
    """Build the counted, size-limited work queue ``work:<worker_compat_hash>``.

    Raises:
        ValueError: ``settings.allowed_worker_compat_hashes`` is not ``"*"``
            and the hash is not in ``settings.allowed_worker_compat_hashes_list``.
    """
    restricted = settings.allowed_worker_compat_hashes != "*"
    # Short-circuit keeps the list lookup skipped when every hash is allowed.
    if restricted and worker_compat_hash not in settings.allowed_worker_compat_hashes_list:
        raise ValueError(f"Worker compat hash {worker_compat_hash} not allowed")

    queue_key = f"work:{worker_compat_hash}"
    return RedisQueue(
        redis_client,
        queue_key,
        max_size=settings.work_queue_max_size,
        with_counter=True,
        counter_pos_expire=settings.message_queue_expire,
    )
def compliance_queue(redis_client: redis.Redis, worker_id: str) -> RedisQueue:
    """Build the plain per-worker queue ``compliance:<worker_id>``.

    The queue has no TTL, no counters and no size limit.
    """
    key = f"compliance:{worker_id}"
    return RedisQueue(redis_client, key)