-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
Copy pathtest_parallel.py
52 lines (45 loc) · 1.28 KB
/
test_parallel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.
import multiprocessing
import os
from contextlib import nullcontext
import pytest
from mne.parallel import parallel_func
@pytest.mark.parametrize(
    "n_jobs",
    [
        None,
        1,
        -1,
        "loky 2",
        "threading 3",
        "multiprocessing 4",
    ],
)
def test_parallel_func(n_jobs):
    """Check the job count reported by ``parallel_func``.

    ``None`` and integer parameters are handed to ``parallel_func`` directly.
    String parameters have the form ``"<backend> <count>"``: they set up a
    joblib backend context requesting ``<count>`` jobs, and ``parallel_func``
    is then called with ``n_jobs=None`` inside that context.
    """
    joblib = pytest.importorskip("joblib")
    force_serial = os.getenv("MNE_FORCE_SERIAL", "").lower()
    if force_serial in ("true", "1"):
        pytest.skip("MNE_FORCE_SERIAL is set")

    def fun(x):
        return x * 2

    if not isinstance(n_jobs, str):
        # Plain n_jobs: no joblib context is involved.
        ctx = nullcontext()
        counts_from_cpus = n_jobs is not None and n_jobs < 0
        # Negative values are offsets from the CPU count (-1 -> all CPUs);
        # None and non-negative values are expected to yield a single job.
        want_jobs = (
            (multiprocessing.cpu_count() + 1 + n_jobs) if counts_from_cpus else 1
        )
    else:
        backend, count = n_jobs.split()
        want_jobs = int(count)
        # joblib < 1.3 has no parallel_config; fall back to parallel_backend.
        if hasattr(joblib, "parallel_config"):
            make_ctx = joblib.parallel_config
        else:
            make_ctx = joblib.parallel_backend
        ctx = make_ctx(backend, n_jobs=want_jobs)
        # The job count now comes from the surrounding context, not the arg.
        n_jobs = None

    with ctx:
        _parallel, _p_fun, got_jobs = parallel_func(fun, n_jobs, verbose="debug")
    assert got_jobs == want_jobs