Commit f7064fc

Use WorkQueue everywhere (dask#218)

Previously we used `UniqueQueue` for managing the reconciliation queues. When writing the Kubernetes backend, I added a more featureful task queue in `dask_gateway_server.workqueue.WorkQueue`. This supports concurrent workers without hashing (any worker is free to work on any task, and no task will be worked on twice concurrently), retries with backoff, and cleaner shutdowns. Here we drop the `UniqueQueue` abstraction entirely and use `WorkQueue` instead. Also cleans up the test suite a bit to silence some warnings.

1 parent b656eaf · commit f7064fc
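As a reading aid for the diff below, here is a minimal sketch of the consumer pattern the new queue enables, using only the `WorkQueue` methods that appear in this commit (`put`, `get`, `put_backoff`, `reset_backoff`, `task_done`, `close`, and the `WorkQueueClosed` exception). This is a sketch mirroring the `reconciler_loop` rewrite, not the exact implementation; precise signatures live in `dask_gateway_server/workqueue.py`:

```python
import asyncio

from dask_gateway_server.workqueue import Backoff, WorkQueue, WorkQueueClosed


async def consumer_loop(queue, handle):
    """Any number of these loops may run concurrently against one queue.

    The queue hands each task to at most one consumer at a time, so tasks
    no longer need to be hashed onto dedicated per-worker queues.
    """
    while True:
        try:
            obj = await queue.get()
        except WorkQueueClosed:
            return  # queue.close() was called -> clean shutdown
        try:
            await handle(obj)
        except Exception:
            queue.put_backoff(obj)    # failed -> retry later with backoff
        else:
            queue.reset_backoff(obj)  # succeeded -> forget past failures
        finally:
            queue.task_done(obj)      # always release the task


# Setup mirrors DBBackendBase.setup in the diff below:
# queue = WorkQueue(backoff=Backoff(base_delay=0.1, max_delay=300))
# consumers = [asyncio.ensure_future(consumer_loop(queue, handle))
#              for _ in range(parallelism)]
```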

File tree: 5 files changed, +71 -82 lines changed

dask-gateway-server/dask_gateway_server/backends/db_base.py (+66, -44)
```diff
@@ -15,15 +15,8 @@
 from .. import models
 from ..proxy import Proxy
 from ..tls import new_keypair
-from ..utils import (
-    FrozenAttrDict,
-    TaskPool,
-    Flag,
-    normalize_address,
-    UniqueQueue,
-    CancelGroup,
-    timestamp,
-)
+from ..workqueue import WorkQueue, Backoff, WorkQueueClosed
+from ..utils import FrozenAttrDict, TaskPool, Flag, normalize_address, timestamp


 __all__ = ("DBBackendBase", "Cluster", "Worker")
@@ -778,6 +771,28 @@ def _default_check_timeouts_period(self):
         config=True,
     )

+    backoff_base_delay = Float(
+        0.1,
+        help="""
+        Base delay (in seconds) for backoff when retrying after failures.
+
+        If an operation fails, it is retried after a backoff computed as:
+
+        ```
+        min(backoff_max_delay, backoff_base_delay * 2 ** num_failures)
+        ```
+        """,
+        config=True,
+    )
+
+    backoff_max_delay = Float(
+        300,
+        help="""
+        Max delay (in seconds) for backoff policy when retrying after failures.
+        """,
+        config=True,
+    )
+
     api_url = Unicode(
         help="""
         The address that internal components (e.g. dask clusters)
```
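To make that formula concrete, here is a quick check of the retry delays with the defaults above (`backoff_base_delay=0.1`, `backoff_max_delay=300`). This is a standalone sketch of the documented formula only, not the actual `Backoff` class, which may differ in details:

```python
def backoff_delay(num_failures, base_delay=0.1, max_delay=300.0):
    # The formula from the help text above.
    return min(max_delay, base_delay * 2 ** num_failures)

# Delays double per consecutive failure, then saturate at max_delay:
assert backoff_delay(1) == 0.2      # 0.1 * 2
assert backoff_delay(5) == 3.2      # 0.1 * 32
assert backoff_delay(12) == 300.0   # 0.1 * 4096 = 409.6, capped at 300
```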
```diff
@@ -800,11 +815,14 @@ async def setup(self, app):
         await super().setup(app)

         # Setup reconcilation queues
-        self.cg = CancelGroup()
-
-        self.queues = [UniqueQueue() for _ in range(self.parallelism)]
+        self.queue = WorkQueue(
+            backoff=Backoff(
+                base_delay=self.backoff_base_delay, max_delay=self.backoff_max_delay
+            )
+        )
         self.reconcilers = [
-            asyncio.ensure_future(self.reconciler_loop(q)) for q in self.queues
+            asyncio.ensure_future(self.reconciler_loop())
+            for _ in range(self.parallelism)
         ]

         # Start the proxy
@@ -826,10 +844,10 @@ async def setup(self, app):
         # Load all active clusters/workers into reconcilation queues
         for cluster in self.db.name_to_cluster.values():
             if cluster.status < JobStatus.STOPPED:
-                await self.enqueue(cluster)
+                self.queue.put(cluster)
             for worker in cluster.workers.values():
                 if worker.status < JobStatus.STOPPED:
-                    await self.enqueue(worker)
+                    self.queue.put(worker)

         # Further backend-specific setup
         await self.do_setup()
@@ -853,7 +871,7 @@ async def cleanup(self):
                 [(c, {"target": JobStatus.FAILED}) for c in active]
             )
             for c in active:
-                await self.enqueue(c)
+                self.queue.put(c)

             # Wait until all clusters are shutdown
             pending_shutdown = [
@@ -864,9 +882,9 @@ async def cleanup(self):
             if pending_shutdown:
                 await asyncio.wait([c.shutdown for c in pending_shutdown])

-        if hasattr(self, "cg"):
-            # Stop reconcilation queues
-            await self.cg.cancel()
+        # Stop reconcilation queues
+        if hasattr(self, "reconcilers"):
+            self.queue.close()
             await asyncio.gather(*self.reconcilers, return_exceptions=True)

         await self.do_cleanup()
@@ -895,7 +913,7 @@ async def start_cluster(self, user, cluster_options):
         options, config = await self.process_cluster_options(user, cluster_options)
         cluster = self.db.create_cluster(user.name, options, config.to_dict())
         self.log.info("Created cluster %s for user %s", cluster.name, user.name)
-        await self.enqueue(cluster)
+        self.queue.put(cluster)
         return cluster.name

     async def stop_cluster(self, cluster_name, failed=False):
@@ -906,7 +924,7 @@ async def stop_cluster(self, cluster_name, failed=False):
         self.log.info("Stopping cluster %s", cluster.name)
         target = JobStatus.FAILED if failed else JobStatus.STOPPED
         self.db.update_cluster(cluster, target=target)
-        await self.enqueue(cluster)
+        self.queue.put(cluster)

     async def on_cluster_heartbeat(self, cluster_name, msg):
         cluster = self.db.get_cluster(cluster_name)
@@ -976,11 +994,11 @@ async def on_cluster_heartbeat(self, cluster_name, msg):

         if cluster_update:
             self.db.update_cluster(cluster, **cluster_update)
-            await self.enqueue(cluster)
+            self.queue.put(cluster)

         self.db.update_workers(target_updates)
         for w, u in target_updates:
-            await self.enqueue(w)
+            self.queue.put(w)

         if newly_running:
             # At least one worker successfully started, reset failure count
@@ -1037,10 +1055,10 @@ async def _check_timeouts(self):
                     worker_updates.append((w, {"target": JobStatus.FAILED}))
         self.db.update_clusters(cluster_updates)
         for c, _ in cluster_updates:
-            await self.enqueue(c)
+            self.queue.put(c)
         self.db.update_workers(worker_updates)
         for w, _ in worker_updates:
-            await self.enqueue(w)
+            self.queue.put(w)

     async def check_clusters_loop(self):
         while True:
@@ -1061,7 +1079,7 @@ async def check_clusters_loop(self):
                 self.db.update_clusters(updates)
                 for c, _ in updates:
                     self.log.info("Cluster %s failed during startup", c.name)
-                    await self.enqueue(c)
+                    self.queue.put(c)
             except asyncio.CancelledError:
                 raise
             except Exception as exc:
@@ -1095,7 +1113,7 @@ async def check_workers_loop(self):
                 for w, _ in updates:
                     self.log.info("Worker %s failed during startup", w.name)
                     w.cluster.worker_start_failure_count += 1
-                    await self.enqueue(w)
+                    self.queue.put(w)
             except asyncio.CancelledError:
                 raise
             except Exception as exc:
@@ -1115,14 +1133,13 @@ async def cleanup_db_loop(self):
                 self.log.debug("Removed %d expired clusters from the database", n)
             await asyncio.sleep(self.db_cleanup_period)

-    async def enqueue(self, obj):
-        ind = hash(obj) % self.parallelism
-        await self.queues[ind].put(obj)
-
-    async def reconciler_loop(self, queue):
+    async def reconciler_loop(self):
         while True:
-            async with self.cg.cancellable():
-                obj = await queue.get()
+            try:
+                obj = await self.queue.get()
+            except WorkQueueClosed:
+                return
+
             if isinstance(obj, Cluster):
                 method = self.reconcile_cluster
                 kind = "cluster"
@@ -1144,6 +1161,11 @@ async def reconciler_loop(self, queue):
                 self.log.warning(
                     "Error while reconciling %s %s", kind, obj.name, exc_info=True
                 )
+                self.queue.put_backoff(obj)
+            else:
+                self.queue.reset_backoff(obj)
+            finally:
+                self.queue.task_done(obj)

     async def reconcile_cluster(self, cluster):
         if cluster.status >= JobStatus.STOPPED:
@@ -1177,17 +1199,17 @@ async def reconcile_worker(self, worker):
             if worker.status != JobStatus.CLOSING:
                 self.db.update_worker(worker, status=JobStatus.CLOSING)
                 if self.is_cluster_ready_to_close(worker.cluster):
-                    await self.enqueue(worker.cluster)
+                    self.queue.put(worker.cluster)
             return

         if worker.target in (JobStatus.STOPPED, JobStatus.FAILED):
             await self._worker_to_stopped(worker)
             if self.is_cluster_ready_to_close(worker.cluster):
-                await self.enqueue(worker.cluster)
+                self.queue.put(worker.cluster)
             elif (
                 worker.cluster.target == JobStatus.RUNNING and not worker.close_expected
             ):
-                await self.enqueue(worker.cluster)
+                self.queue.put(worker.cluster)
             return

         if worker.status == JobStatus.CREATED and worker.target == JobStatus.RUNNING:
@@ -1225,20 +1247,20 @@ async def _cluster_to_submitted(self, cluster):
             self.db.update_cluster(
                 cluster, status=JobStatus.SUBMITTED, target=JobStatus.FAILED
             )
-            await self.enqueue(cluster)
+            self.queue.put(cluster)

     async def _cluster_to_closing(self, cluster):
         self.log.debug("Preparing to stop cluster %s", cluster.name)
         target = JobStatus.CLOSING if self.supports_bulk_shutdown else JobStatus.STOPPED
         workers = [w for w in cluster.workers.values() if w.target < target]
         self.db.update_workers([(w, {"target": target}) for w in workers])
         for w in workers:
-            await self.enqueue(w)
+            self.queue.put(w)
         self.db.update_cluster(cluster, status=JobStatus.CLOSING)
         if not workers:
             # If there are workers, the cluster will be enqueued after the last one closed
-            # If there are no workers, re-enqueue now
-            await self.enqueue(cluster)
+            # If there are no workers, requeue now
+            self.queue.put(cluster)
         cluster.ready.set()

     async def _cluster_to_stopped(self, cluster):
@@ -1291,7 +1313,7 @@ async def _check_cluster_scale(self, cluster):
                 cluster.worker_start_failure_count,
             )
             self.db.update_cluster(cluster, target=JobStatus.FAILED)
-            await self.enqueue(cluster)
+            self.queue.put(cluster)
             return

         active = cluster.active_workers()
@@ -1301,7 +1323,7 @@ async def _check_cluster_scale(self, cluster):
             self.log.info(
                 "Created worker %s for cluster %s", worker.name, cluster.name
             )
-            await self.enqueue(worker)
+            self.queue.put(worker)

     async def _worker_to_submitted(self, worker):
         self.log.info("Submitting worker %s...", worker.name)
@@ -1325,7 +1347,7 @@ async def _worker_to_submitted(self, worker):
                 worker, status=JobStatus.SUBMITTED, target=JobStatus.FAILED
             )
             worker.cluster.worker_start_failure_count += 1
-            await self.enqueue(worker)
+            self.queue.put(worker)

     async def _worker_to_stopped(self, worker):
         self.log.info("Stopping worker %s...", worker.name)
```

dask-gateway-server/dask_gateway_server/utils.py (-18)
```diff
@@ -267,24 +267,6 @@ def discard(self, key):
             pass


-class UniqueQueue(asyncio.Queue):
-    """A queue that may only contain each item once."""
-
-    def __init__(self, maxsize=0, *, loop=None):
-        super().__init__(maxsize=maxsize, loop=loop)
-        self._items = set()
-
-    def _put(self, item):
-        if item not in self._items:
-            self._items.add(item)
-            super()._put(item)
-
-    def _get(self):
-        item = super()._get()
-        self._items.discard(item)
-        return item
-
-
 class Flag(object):
     """A simpler version of asyncio.Event"""
```
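Note that nothing `UniqueQueue` provided is lost here: its one guarantee (an item can sit in the queue at most once) is subsumed by the stronger `WorkQueue` guarantee described in the commit message, that no task is ever worked on by two reconcilers concurrently.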

tests/test_client.py (+1, -1)
```diff
@@ -304,7 +304,7 @@ async def test_create_cluster_with_GatewayCluster_constructor():

         await cluster.scale(1)

-        with cluster.get_client(set_as_default=False) as client:
+        async with cluster.get_client(set_as_default=False) as client:
             res = await client.submit(lambda x: x + 1, 1)
             assert res == 2

```
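The test-suite changes here and in `test_db_backend.py` below all make the same substitution: these clusters run in asynchronous mode, where the client returned by `get_client` supports async context manager usage, so `async with` is the appropriate form inside a coroutine. A plausible reading is that the synchronous `with` form in async code was the source of the warnings the commit message mentions silencing.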

tests/test_db_backend.py (+4, -4)
```diff
@@ -615,15 +615,15 @@ async def test_successful_cluster():
         # Scale up, connect, and compute
         await cluster.scale(2)

-        with cluster.get_client(set_as_default=False) as client:
+        async with cluster.get_client(set_as_default=False) as client:
             res = await client.submit(lambda x: x + 1, 1)
             assert res == 2

         # Scale down
         await cluster.scale(1)

         # Can still compute
-        with cluster.get_client(set_as_default=False) as client:
+        async with cluster.get_client(set_as_default=False) as client:
             res = await client.submit(lambda x: x + 1, 1)
             assert res == 2

@@ -801,7 +801,7 @@ async def test_gateway_resume_clusters_after_shutdown(tmpdir):
             async with gateway.connect(
                 cluster1_name, shutdown_on_close=True
             ) as cluster:
-                with cluster.get_client(set_as_default=False) as client:
+                async with cluster.get_client(set_as_default=False) as client:
                     res = await client.submit(lambda x: x + 1, 1)
                     assert res == 2

@@ -833,7 +833,7 @@ async def test_adaptive_scaling():
         await cluster.adapt()

         # Worker is automatically requested
-        with cluster.get_client(set_as_default=False) as client:
+        async with cluster.get_client(set_as_default=False) as client:
             res = await client.submit(lambda x: x + 1, 1)
             assert res == 2

```

tests/test_utils.py (-15)
```diff
@@ -12,7 +12,6 @@
     cancel_task,
     TaskPool,
     LRUCache,
-    UniqueQueue,
     Flag,
     FrozenAttrDict,
     CancelGroup,
@@ -120,20 +119,6 @@ def test_lru_cache():
     assert cache.get(7) == 8


-@pytest.mark.asyncio
-async def test_unique_queue():
-    queue = UniqueQueue()
-
-    for data in [1, 3, 1, 2, 1, 2, 1, 3]:
-        await queue.put(data)
-
-    out = []
-    while not queue.empty():
-        out.append(await queue.get())
-
-    assert out == [1, 3, 2]
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize("use_wait", [False, True])
 async def test_flag(use_wait):
```
