Skip to content

Commit 1be9d38

Browse files
authored May 20, 2022
CLN: Move test_parallel to gil.py (pandas-dev#47068)
1 parent afcd780 commit 1be9d38

File tree

2 files changed

+56
-76
lines changed

2 files changed

+56
-76
lines changed
 

‎asv_bench/benchmarks/gil.py

+56-26
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
from functools import wraps
2+
import threading
3+
14
import numpy as np
25

36
from pandas import (
@@ -30,21 +33,57 @@
3033
from pandas._libs import algos
3134
except ImportError:
3235
from pandas import algos
33-
try:
34-
from pandas._testing import test_parallel # noqa: PDF014
3536

36-
have_real_test_parallel = True
37-
except ImportError:
38-
have_real_test_parallel = False
3937

40-
def test_parallel(num_threads=1):
41-
def wrapper(fname):
42-
return fname
38+
from .pandas_vb_common import BaseIO # isort:skip
4339

44-
return wrapper
4540

41+
def test_parallel(num_threads=2, kwargs_list=None):
    """
    Decorator to run the same function multiple times in parallel.

    Parameters
    ----------
    num_threads : int, optional
        The number of times the function is run in parallel.
    kwargs_list : list of dicts, optional
        The list of kwargs to update original
        function kwargs on different threads.

    Raises
    ------
    ValueError
        If ``num_threads`` is not positive, or if ``kwargs_list`` is given
        and its length differs from ``num_threads``.

    Notes
    -----
    This decorator does not pass the return value of the decorated function.
    Nothing here collects results or exceptions from the worker threads; the
    decorated call always returns ``None`` once every thread has been joined.

    Original from scikit-image:

    https://github.com/scikit-image/scikit-image/pull/1519

    """
    # Validate with explicit raises rather than ``assert``: assertions are
    # stripped under ``python -O``, which would let a bad configuration
    # through and fail later (or silently drop entries of ``kwargs_list``).
    if num_threads <= 0:
        raise ValueError(f"num_threads must be positive, got {num_threads!r}")
    if kwargs_list is not None and len(kwargs_list) != num_threads:
        raise ValueError(
            f"kwargs_list must have one entry per thread: expected "
            f"{num_threads}, got {len(kwargs_list)}"
        )

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            if kwargs_list is None:
                # Every thread receives the caller's kwargs unchanged.
                per_thread_kwargs = [kwargs] * num_threads
            else:
                # Thread i receives the caller's kwargs overlaid with
                # kwargs_list[i]; the per-thread entry wins on conflicts.
                per_thread_kwargs = [
                    dict(kwargs, **overrides) for overrides in kwargs_list
                ]

            threads = [
                threading.Thread(target=func, args=args, kwargs=thread_kwargs)
                for thread_kwargs in per_thread_kwargs
            ]
            # Start all threads before joining any so that they actually
            # overlap instead of running one after another.
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        return inner

    return wrapper
4887

4988

5089
class ParallelGroupbyMethods:
@@ -53,8 +92,7 @@ class ParallelGroupbyMethods:
5392
param_names = ["threads", "method"]
5493

5594
def setup(self, threads, method):
56-
if not have_real_test_parallel:
57-
raise NotImplementedError
95+
5896
N = 10**6
5997
ngroups = 10**3
6098
df = DataFrame(
@@ -86,8 +124,7 @@ class ParallelGroups:
86124
param_names = ["threads"]
87125

88126
def setup(self, threads):
89-
if not have_real_test_parallel:
90-
raise NotImplementedError
127+
91128
size = 2**22
92129
ngroups = 10**3
93130
data = Series(np.random.randint(0, ngroups, size=size))
@@ -108,8 +145,7 @@ class ParallelTake1D:
108145
param_names = ["dtype"]
109146

110147
def setup(self, dtype):
111-
if not have_real_test_parallel:
112-
raise NotImplementedError
148+
113149
N = 10**6
114150
df = DataFrame({"col": np.arange(N, dtype=dtype)})
115151
indexer = np.arange(100, len(df) - 100)
@@ -131,8 +167,7 @@ class ParallelKth:
131167
repeat = 5
132168

133169
def setup(self):
134-
if not have_real_test_parallel:
135-
raise NotImplementedError
170+
136171
N = 10**7
137172
k = 5 * 10**5
138173
kwargs_list = [{"arr": np.random.randn(N)}, {"arr": np.random.randn(N)}]
@@ -149,8 +184,7 @@ def time_kth_smallest(self):
149184

150185
class ParallelDatetimeFields:
151186
def setup(self):
152-
if not have_real_test_parallel:
153-
raise NotImplementedError
187+
154188
N = 10**6
155189
self.dti = date_range("1900-01-01", periods=N, freq="T")
156190
self.period = self.dti.to_period("D")
@@ -204,8 +238,7 @@ class ParallelRolling:
204238
param_names = ["method"]
205239

206240
def setup(self, method):
207-
if not have_real_test_parallel:
208-
raise NotImplementedError
241+
209242
win = 100
210243
arr = np.random.rand(100000)
211244
if hasattr(DataFrame, "rolling"):
@@ -248,8 +281,7 @@ class ParallelReadCSV(BaseIO):
248281
param_names = ["dtype"]
249282

250283
def setup(self, dtype):
251-
if not have_real_test_parallel:
252-
raise NotImplementedError
284+
253285
rows = 10000
254286
cols = 50
255287
data = {
@@ -284,8 +316,6 @@ class ParallelFactorize:
284316
param_names = ["threads"]
285317

286318
def setup(self, threads):
287-
if not have_real_test_parallel:
288-
raise NotImplementedError
289319

290320
strings = tm.makeStringIndex(100000)
291321

‎pandas/_testing/__init__.py

-50
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import collections
44
from datetime import datetime
55
from decimal import Decimal
6-
from functools import wraps
76
import operator
87
import os
98
import re
@@ -749,55 +748,6 @@ def makeMissingDataframe(density=0.9, random_state=None):
749748
return df
750749

751750

752-
def test_parallel(num_threads=2, kwargs_list=None):
    """
    Decorator to run the same function multiple times in parallel.

    Parameters
    ----------
    num_threads : int, optional
        The number of times the function is run in parallel.
    kwargs_list : list of dicts, optional
        The list of kwargs to update original
        function kwargs on different threads.

    Raises
    ------
    ValueError
        If ``num_threads`` is not positive, or if ``kwargs_list`` is given
        and its length differs from ``num_threads``.

    Notes
    -----
    This decorator does not pass the return value of the decorated function.

    Original from scikit-image:

    https://github.com/scikit-image/scikit-image/pull/1519

    """
    # Explicit raises instead of ``assert`` so the checks survive
    # ``python -O``, which removes assert statements.
    if num_threads <= 0:
        raise ValueError(f"num_threads must be positive, got {num_threads!r}")
    has_kwargs_list = kwargs_list is not None
    if has_kwargs_list and len(kwargs_list) != num_threads:
        raise ValueError(
            f"kwargs_list must have one entry per thread: expected "
            f"{num_threads}, got {len(kwargs_list)}"
        )
    # Imported at function scope, as in the original version of this helper.
    import threading

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            threads = []
            for i in range(num_threads):
                # With a kwargs_list, thread i gets the caller's kwargs
                # overlaid with kwargs_list[i]; otherwise kwargs pass through.
                if has_kwargs_list:
                    updated_kwargs = dict(kwargs, **kwargs_list[i])
                else:
                    updated_kwargs = kwargs
                threads.append(
                    threading.Thread(target=func, args=args, kwargs=updated_kwargs)
                )
            # Start every thread before joining any of them so the runs
            # overlap.
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        return inner

    return wrapper
799-
800-
801751
class SubclassedSeries(Series):
802752
_metadata = ["testattr", "name"]
803753

0 commit comments

Comments
 (0)