1
+ from functools import wraps
2
+ import threading
3
+
1
4
import numpy as np
2
5
3
6
from pandas import (
30
33
from pandas ._libs import algos
31
34
except ImportError :
32
35
from pandas import algos
33
- try :
34
- from pandas ._testing import test_parallel # noqa: PDF014
35
36
36
- have_real_test_parallel = True
37
- except ImportError :
38
- have_real_test_parallel = False
39
37
40
- def test_parallel (num_threads = 1 ):
41
- def wrapper (fname ):
42
- return fname
38
+ from .pandas_vb_common import BaseIO # isort:skip
43
39
44
- return wrapper
45
40
41
def test_parallel(num_threads=2, kwargs_list=None):
    """
    Decorator to run the same function multiple times in parallel.

    Parameters
    ----------
    num_threads : int, optional
        The number of times the function is run in parallel.
    kwargs_list : list of dicts, optional
        The list of kwargs to update original
        function kwargs on different threads.

    Raises
    ------
    ValueError
        If ``num_threads`` is not positive, or if ``kwargs_list`` is given
        and its length differs from ``num_threads``.

    Notes
    -----
    This decorator does not pass the return value of the decorated function.

    Original from scikit-image:

    https://github.com/scikit-image/scikit-image/pull/1519

    """
    # Validate with explicit raises rather than ``assert``: assertions are
    # stripped under ``python -O``, which would let bad arguments through.
    if not num_threads > 0:
        raise ValueError(f"num_threads must be positive, got {num_threads!r}")
    has_kwargs_list = kwargs_list is not None
    if has_kwargs_list and len(kwargs_list) != num_threads:
        raise ValueError(
            f"kwargs_list must have one entry per thread: expected "
            f"{num_threads}, got {len(kwargs_list)}"
        )

    def wrapper(func):
        @wraps(func)
        def inner(*args, **kwargs):
            # Build the keyword arguments for each thread: the call-time
            # kwargs, overridden by that thread's entry in kwargs_list.
            if has_kwargs_list:
                per_thread_kwargs = [
                    dict(kwargs, **kwargs_list[i]) for i in range(num_threads)
                ]
            else:
                per_thread_kwargs = [kwargs] * num_threads

            # Create every thread before starting any of them so that the
            # calls overlap as much as possible.
            threads = [
                threading.Thread(target=func, args=args, kwargs=thread_kwargs)
                for thread_kwargs in per_thread_kwargs
            ]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

        return inner

    return wrapper


# Despite its ``test_`` prefix this is a helper, not a test; keep pytest from
# collecting it when it is imported into a test module.
test_parallel.__test__ = False
48
87
49
88
50
89
class ParallelGroupbyMethods :
@@ -53,8 +92,7 @@ class ParallelGroupbyMethods:
53
92
param_names = ["threads" , "method" ]
54
93
55
94
def setup (self , threads , method ):
56
- if not have_real_test_parallel :
57
- raise NotImplementedError
95
+
58
96
N = 10 ** 6
59
97
ngroups = 10 ** 3
60
98
df = DataFrame (
@@ -86,8 +124,7 @@ class ParallelGroups:
86
124
param_names = ["threads" ]
87
125
88
126
def setup (self , threads ):
89
- if not have_real_test_parallel :
90
- raise NotImplementedError
127
+
91
128
size = 2 ** 22
92
129
ngroups = 10 ** 3
93
130
data = Series (np .random .randint (0 , ngroups , size = size ))
@@ -108,8 +145,7 @@ class ParallelTake1D:
108
145
param_names = ["dtype" ]
109
146
110
147
def setup (self , dtype ):
111
- if not have_real_test_parallel :
112
- raise NotImplementedError
148
+
113
149
N = 10 ** 6
114
150
df = DataFrame ({"col" : np .arange (N , dtype = dtype )})
115
151
indexer = np .arange (100 , len (df ) - 100 )
@@ -131,8 +167,7 @@ class ParallelKth:
131
167
repeat = 5
132
168
133
169
def setup (self ):
134
- if not have_real_test_parallel :
135
- raise NotImplementedError
170
+
136
171
N = 10 ** 7
137
172
k = 5 * 10 ** 5
138
173
kwargs_list = [{"arr" : np .random .randn (N )}, {"arr" : np .random .randn (N )}]
@@ -149,8 +184,7 @@ def time_kth_smallest(self):
149
184
150
185
class ParallelDatetimeFields :
151
186
def setup (self ):
152
- if not have_real_test_parallel :
153
- raise NotImplementedError
187
+
154
188
N = 10 ** 6
155
189
self .dti = date_range ("1900-01-01" , periods = N , freq = "T" )
156
190
self .period = self .dti .to_period ("D" )
@@ -204,8 +238,7 @@ class ParallelRolling:
204
238
param_names = ["method" ]
205
239
206
240
def setup (self , method ):
207
- if not have_real_test_parallel :
208
- raise NotImplementedError
241
+
209
242
win = 100
210
243
arr = np .random .rand (100000 )
211
244
if hasattr (DataFrame , "rolling" ):
@@ -248,8 +281,7 @@ class ParallelReadCSV(BaseIO):
248
281
param_names = ["dtype" ]
249
282
250
283
def setup (self , dtype ):
251
- if not have_real_test_parallel :
252
- raise NotImplementedError
284
+
253
285
rows = 10000
254
286
cols = 50
255
287
data = {
@@ -284,8 +316,6 @@ class ParallelFactorize:
284
316
param_names = ["threads" ]
285
317
286
318
def setup (self , threads ):
287
- if not have_real_test_parallel :
288
- raise NotImplementedError
289
319
290
320
strings = tm .makeStringIndex (100000 )
291
321
0 commit comments