# test_slurm_backend.py
import os
import subprocess

import pytest
from traitlets.config import Config

if not os.environ.get("TEST_DASK_GATEWAY_SLURM"):
    pytest.skip("Not running Slurm tests", allow_module_level=True)
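
# Note on running this module: the tests above are skipped unless
# TEST_DASK_GATEWAY_SLURM is set, and they assume a working Slurm installation
# plus the dask-gateway test environment. A typical (assumed) invocation would
# be something like:
#
#   TEST_DASK_GATEWAY_SLURM=1 pytest test_slurm_backend.py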

from dask_gateway.auth import BasicAuth
from dask_gateway_server.backends.jobqueue.slurm import (
    SlurmBackend,
    slurm_format_memory,
)

from .utils_test import temp_gateway, wait_for_workers, with_retries

pytestmark = pytest.mark.usefixtures("cleanup_jobs")

# Slurm job ids of clusters started during this module; used by the
# cleanup_jobs fixture to cancel anything the tests leave behind.
JOBIDS = set()


def kill_job(job_id):
    """Cancel a Slurm job, ignoring jobs that have already finished."""
    try:
        subprocess.check_output(
            ["/usr/local/bin/scancel", job_id], stderr=subprocess.STDOUT
        )
    except subprocess.CalledProcessError as exc:
        if b"Job has finished" not in exc.output:
            print("Failed to stop %s, output: %s" % (job_id, exc.output.decode()))


def is_job_running(job_id):
    """Return True if the job is pending, configuring, running, or completing."""
    stdout = subprocess.check_output(
        ["/usr/local/bin/squeue", "-h", "-j", job_id, "-o", "%t"]
    )
    state = stdout.decode().strip()
    return state in ("PD", "CF", "R", "CG")


@pytest.fixture(scope="module")
def cleanup_jobs():
    # After every test in this module has run, cancel any Slurm jobs that were
    # started but never stopped.
    yield
    if not JOBIDS:
        return
    for job in JOBIDS:
        kill_job(job)
    print("-- Stopped %d lost clusters --" % len(JOBIDS))


class SlurmTestingBackend(SlurmBackend):
    """SlurmBackend that records job ids so cleanup_jobs can cancel stragglers."""

    async def do_start_cluster(self, cluster):
        async for state in super().do_start_cluster(cluster):
            JOBIDS.add(state["job_id"])
            yield state

    async def do_stop_cluster(self, cluster):
        job_id = cluster.state.get("job_id")
        await super().do_stop_cluster(cluster)
        JOBIDS.discard(job_id)


@pytest.mark.asyncio
async def test_slurm_backend():
    c = Config()
    c.SlurmClusterConfig.scheduler_cmd = "/opt/miniconda/bin/dask-scheduler"
    c.SlurmClusterConfig.worker_cmd = "/opt/miniconda/bin/dask-worker"
    c.SlurmClusterConfig.scheduler_memory = "256M"
    c.SlurmClusterConfig.worker_memory = "256M"
    c.SlurmClusterConfig.scheduler_cores = 1
    c.SlurmClusterConfig.worker_cores = 1
    c.DaskGateway.backend_class = SlurmTestingBackend

    async with temp_gateway(config=c) as g:
        auth = BasicAuth(username="alice")
        async with g.gateway_client(auth=auth) as gateway:
            async with gateway.new_cluster() as cluster:
                db_cluster = g.gateway.backend.db.get_cluster(cluster.name)

                # The backend reports the freshly started cluster as running
                res = await g.gateway.backend.do_check_clusters([db_cluster])
                assert res == [True]

                # Scale up, then back down to a single worker
                await cluster.scale(2)
                await wait_for_workers(cluster, atleast=1)
                await cluster.scale(1)
                await wait_for_workers(cluster, exact=1)

                db_workers = list(db_cluster.workers.values())

                async def test():
                    res = await g.gateway.backend.do_check_workers(db_workers)
                    assert sum(res) == 1

                await with_retries(test, 30, 0.25)

                # The cluster can run real work
                async with cluster.get_client(set_as_default=False) as client:
                    res = await client.submit(lambda x: x + 1, 1)
                    assert res == 2

                # Scale to zero and confirm all worker jobs are gone
                await cluster.scale(0)
                await wait_for_workers(cluster, exact=0)

                async def test():
                    res = await g.gateway.backend.do_check_workers(db_workers)
                    assert sum(res) == 0

                await with_retries(test, 30, 0.25)

            # No-op for shutdown of already shutdown worker
            async def test():
                res = await g.gateway.backend.do_check_clusters([db_cluster])
                assert res == [False]

            await with_retries(test, 30, 0.25)


def test_slurm_format_memory():
    # Minimum is 1 K
    assert slurm_format_memory(2) == "1K"
    assert slurm_format_memory(2 ** 10) == "1K"
    assert slurm_format_memory(2 ** 20) == "1024K"
    assert slurm_format_memory(2 ** 20 + 1) == "1025K"
    assert slurm_format_memory(2 ** 30) == "1024M"
    assert slurm_format_memory(2 ** 30 + 1) == "1025M"
    assert slurm_format_memory(2 ** 40) == "1024G"
    assert slurm_format_memory(2 ** 40 + 1) == "1025G"
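

# For reference, a minimal sketch of the rounding behaviour exercised above,
# inferred purely from the assertions (the real implementation lives in
# dask_gateway_server.backends.jobqueue.slurm and may differ): values round up
# to a whole number of binary units, never below 1K, and switch to the next
# larger unit once they reach 10 of that unit.
def _slurm_format_memory_sketch(nbytes):
    import math  # local import; this helper exists only as documentation

    if nbytes >= 10 * 2 ** 30:
        return "%dG" % math.ceil(nbytes / 2 ** 30)
    if nbytes >= 10 * 2 ** 20:
        return "%dM" % math.ceil(nbytes / 2 ** 20)
    if nbytes >= 10 * 2 ** 10:
        return "%dK" % math.ceil(nbytes / 2 ** 10)
    return "1K"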