Skip to content

Commit 8b858cd

Browse files
authored
ENH: Add numba engine to rolling/expanding.std/var (pandas-dev#44461)
1 parent 5e2bf77 commit 8b858cd

File tree

11 files changed

+282
-62
lines changed

11 files changed

+282
-62
lines changed

asv_bench/benchmarks/rolling.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class NumbaEngineMethods:
5353
["DataFrame", "Series"],
5454
["int", "float"],
5555
[("rolling", {"window": 10}), ("expanding", {})],
56-
["sum", "max", "min", "median", "mean"],
56+
["sum", "max", "min", "median", "mean", "var", "std"],
5757
[True, False],
5858
[None, 100],
5959
)

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ Other enhancements
214214
- :meth:`Timestamp.isoformat` now handles the ``timespec`` argument from the base :class:`datetime` class (:issue:`26131`)
215215
- :meth:`NaT.to_numpy` ``dtype`` argument is now respected, so ``np.timedelta64`` can be returned (:issue:`44460`)
216216
- New option ``display.max_dir_items`` customizes the number of columns added to :meth:`DataFrame.__dir__` and suggested for tab completion (:issue:`37996`)
217+
- :meth:`.Rolling.var`, :meth:`.Expanding.var`, :meth:`.Rolling.std`, :meth:`.Expanding.std` now support `Numba <http://numba.pydata.org/>`_ execution with the ``engine`` keyword (:issue:`44461`)
217218

218219

219220
.. ---------------------------------------------------------------------------

pandas/core/_numba/executor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,11 @@ def column_looper(
5151
start: np.ndarray,
5252
end: np.ndarray,
5353
min_periods: int,
54+
*args,
5455
):
5556
result = np.empty((len(start), values.shape[1]), dtype=np.float64)
5657
for i in numba.prange(values.shape[1]):
57-
result[:, i] = func(values[:, i], start, end, min_periods)
58+
result[:, i] = func(values[:, i], start, end, min_periods, *args)
5859
return result
5960

6061
return column_looper
+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from pandas.core._numba.kernels.mean_ import sliding_mean
22
from pandas.core._numba.kernels.sum_ import sliding_sum
3+
from pandas.core._numba.kernels.var_ import sliding_var
34

4-
__all__ = ["sliding_mean", "sliding_sum"]
5+
__all__ = ["sliding_mean", "sliding_sum", "sliding_var"]

pandas/core/_numba/kernels/var_.py

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
Numba 1D var kernels that can be shared by
3+
* DataFrame / Series
4+
* groupby
5+
* rolling / expanding
6+
7+
Mirrors pandas/_libs/window/aggregation.pyx
8+
"""
9+
from __future__ import annotations
10+
11+
import numba
12+
import numpy as np
13+
14+
from pandas.core._numba.kernels.shared import is_monotonic_increasing
15+
16+
17+
@numba.jit(nopython=True, nogil=True, parallel=False)
def add_var(
    val: float, nobs: int, mean_x: float, ssqdm_x: float, compensation: float
) -> tuple[int, float, float, float]:
    """
    Fold one observation into the running variance state.

    A NaN observation is ignored and the state is returned unchanged.

    Parameters
    ----------
    val : float
        Observation entering the window.
    nobs : int
        Count of non-NaN observations currently accumulated.
    mean_x : float
        Running mean of the accumulated observations.
    ssqdm_x : float
        Running sum of squared differences from the mean.
    compensation : float
        Rounding-error carry produced by the previous call that used this
        accumulator; pass ``0.0`` when starting from a fresh state.

    Returns
    -------
    tuple[int, float, float, float]
        The updated ``(nobs, mean_x, ssqdm_x, compensation)``.
    """
    if np.isnan(val):
        return nobs, mean_x, ssqdm_x, compensation

    nobs += 1
    # Apply the carried rounding error to both the previous mean and the
    # incoming value before taking their difference, then record the new
    # rounding error for the next call.
    prev_mean = mean_x - compensation
    y = val - compensation
    t = y - mean_x
    compensation = t + mean_x - y
    if nobs:
        mean_x += t / nobs
    else:
        # Only reachable when the caller handed in nobs == -1; keep the
        # accumulator a float so the return type stays stable.
        mean_x = 0.0
    ssqdm_x += (val - prev_mean) * (val - mean_x)
    return nobs, mean_x, ssqdm_x, compensation
34+
35+
36+
@numba.jit(nopython=True, nogil=True, parallel=False)
def remove_var(
    val: float, nobs: int, mean_x: float, ssqdm_x: float, compensation: float
) -> tuple[int, float, float, float]:
    """
    Take one observation out of the running variance state.

    A NaN observation is ignored and the state is returned unchanged.

    Parameters
    ----------
    val : float
        Observation leaving the window.
    nobs : int
        Count of non-NaN observations currently accumulated (including
        ``val``).
    mean_x : float
        Running mean of the accumulated observations.
    ssqdm_x : float
        Running sum of squared differences from the mean.
    compensation : float
        Rounding-error carry produced by the previous call that used this
        accumulator; pass ``0.0`` when starting from a fresh state.

    Returns
    -------
    tuple[int, float, float, float]
        The updated ``(nobs, mean_x, ssqdm_x, compensation)``. When the last
        observation is removed, ``mean_x`` and ``ssqdm_x`` are reset to
        ``0.0`` and ``compensation`` is returned as passed in.
    """
    if np.isnan(val):
        return nobs, mean_x, ssqdm_x, compensation

    nobs -= 1
    if nobs:
        # Same compensated difference as in the add step, applied with the
        # opposite sign to the mean and the squared-deviation sum.
        prev_mean = mean_x - compensation
        y = val - compensation
        t = y - mean_x
        compensation = t + mean_x - y
        mean_x -= t / nobs
        ssqdm_x -= (val - prev_mean) * (val - mean_x)
    else:
        # Window is now empty: restart from a clean float state instead of
        # dividing by zero.
        mean_x = 0.0
        ssqdm_x = 0.0
    return nobs, mean_x, ssqdm_x, compensation
54+
55+
56+
@numba.jit(nopython=True, nogil=True, parallel=False)
def sliding_var(
    values: np.ndarray,
    start: np.ndarray,
    end: np.ndarray,
    min_periods: int,
    ddof: int = 1,
) -> np.ndarray:
    """
    Compute the variance of ``values`` over a sequence of windows.

    Window ``i`` covers ``values[start[i]:end[i]]``. NaN entries are skipped
    by the add/remove helpers and do not count as observations.

    Parameters
    ----------
    values : np.ndarray
        1D data to aggregate.
    start : np.ndarray
        Inclusive start position of each window.
    end : np.ndarray
        Exclusive end position of each window.
    min_periods : int
        Minimum number of non-NaN observations required for a non-NaN
        result; values below 1 are treated as 1.
    ddof : int, default 1
        Delta degrees of freedom; the divisor is ``nobs - ddof``.

    Returns
    -------
    np.ndarray
        float64 array of length ``len(start)``. An entry is NaN when the
        window holds fewer than ``min_periods`` observations or no more than
        ``ddof`` observations.
    """
    N = len(start)
    nobs = 0
    mean_x = 0.0
    ssqdm_x = 0.0
    compensation_add = 0.0
    compensation_remove = 0.0

    min_periods = max(min_periods, 1)
    is_monotonic_increasing_bounds = is_monotonic_increasing(
        start
    ) and is_monotonic_increasing(end)

    output = np.empty(N, dtype=np.float64)

    for i in range(N):
        s = start[i]
        e = end[i]
        if i == 0 or not is_monotonic_increasing_bounds:
            # First window, or bounds that do not slide forward: build the
            # window from scratch.
            for j in range(s, e):
                nobs, mean_x, ssqdm_x, compensation_add = add_var(
                    values[j], nobs, mean_x, ssqdm_x, compensation_add
                )
        else:
            # Bounds only move forward: drop what fell off the left edge...
            for j in range(start[i - 1], s):
                nobs, mean_x, ssqdm_x, compensation_remove = remove_var(
                    values[j], nobs, mean_x, ssqdm_x, compensation_remove
                )

            # ...and take in what appeared on the right edge.
            for j in range(end[i - 1], e):
                nobs, mean_x, ssqdm_x, compensation_add = add_var(
                    values[j], nobs, mean_x, ssqdm_x, compensation_add
                )

        if nobs >= min_periods and nobs > ddof:
            if nobs == 1:
                # A single observation has no spread.
                result = 0.0
            else:
                result = ssqdm_x / (nobs - ddof)
        else:
            result = np.nan

        output[i] = result

        if not is_monotonic_increasing_bounds:
            # The next window is rebuilt from nothing, so clear ALL running
            # state. Leaving compensation_add set would leak this window's
            # rounding carry into the first mean update of the next one.
            nobs = 0
            mean_x = 0.0
            ssqdm_x = 0.0
            compensation_add = 0.0
            compensation_remove = 0.0

    return output

pandas/core/window/doc.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,17 @@ def create_section_header(header: str) -> str:
9898
"extended documentation and performance considerations for the Numba engine.\n\n"
9999
)
100100

101-
window_agg_numba_parameters = dedent(
102-
"""
101+
102+
def window_agg_numba_parameters(version: str = "1.3") -> str:
103+
return (
104+
dedent(
105+
"""
103106
engine : str, default None
104107
* ``'cython'`` : Runs the operation through C-extensions from cython.
105108
* ``'numba'`` : Runs the operation through JIT compiled code from numba.
106109
* ``None`` : Defaults to ``'cython'`` or globally setting ``compute.use_numba``
107110
108-
.. versionadded:: 1.3.0
111+
.. versionadded:: {version}.0
109112
110113
engine_kwargs : dict, default None
111114
* For ``'cython'`` engine, there are no accepted ``engine_kwargs``
@@ -114,6 +117,9 @@ def create_section_header(header: str) -> str:
114117
``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
115118
``{{'nopython': True, 'nogil': False, 'parallel': False}}``
116119
117-
.. versionadded:: 1.3.0\n
120+
.. versionadded:: {version}.0\n
118121
"""
119-
).replace("\n", "", 1)
122+
)
123+
.replace("\n", "", 1)
124+
.replace("{version}", version)
125+
)

pandas/core/window/ewm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ def aggregate(self, func, *args, **kwargs):
511511
template_header,
512512
create_section_header("Parameters"),
513513
args_compat,
514-
window_agg_numba_parameters,
514+
window_agg_numba_parameters(),
515515
kwargs_compat,
516516
create_section_header("Returns"),
517517
template_returns,
@@ -565,7 +565,7 @@ def mean(self, *args, engine=None, engine_kwargs=None, **kwargs):
565565
template_header,
566566
create_section_header("Parameters"),
567567
args_compat,
568-
window_agg_numba_parameters,
568+
window_agg_numba_parameters(),
569569
kwargs_compat,
570570
create_section_header("Returns"),
571571
template_returns,

pandas/core/window/expanding.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def apply(
227227
template_header,
228228
create_section_header("Parameters"),
229229
args_compat,
230-
window_agg_numba_parameters,
230+
window_agg_numba_parameters(),
231231
kwargs_compat,
232232
create_section_header("Returns"),
233233
template_returns,
@@ -253,7 +253,7 @@ def sum(
253253
template_header,
254254
create_section_header("Parameters"),
255255
args_compat,
256-
window_agg_numba_parameters,
256+
window_agg_numba_parameters(),
257257
kwargs_compat,
258258
create_section_header("Returns"),
259259
template_returns,
@@ -279,7 +279,7 @@ def max(
279279
template_header,
280280
create_section_header("Parameters"),
281281
args_compat,
282-
window_agg_numba_parameters,
282+
window_agg_numba_parameters(),
283283
kwargs_compat,
284284
create_section_header("Returns"),
285285
template_returns,
@@ -305,7 +305,7 @@ def min(
305305
template_header,
306306
create_section_header("Parameters"),
307307
args_compat,
308-
window_agg_numba_parameters,
308+
window_agg_numba_parameters(),
309309
kwargs_compat,
310310
create_section_header("Returns"),
311311
template_returns,
@@ -330,7 +330,7 @@ def mean(
330330
@doc(
331331
template_header,
332332
create_section_header("Parameters"),
333-
window_agg_numba_parameters,
333+
window_agg_numba_parameters(),
334334
kwargs_compat,
335335
create_section_header("Returns"),
336336
template_returns,
@@ -361,6 +361,7 @@ def median(
361361
"""
362362
).replace("\n", "", 1),
363363
args_compat,
364+
window_agg_numba_parameters("1.4"),
364365
kwargs_compat,
365366
create_section_header("Returns"),
366367
template_returns,
@@ -396,9 +397,18 @@ def median(
396397
aggregation_description="standard deviation",
397398
agg_method="std",
398399
)
399-
def std(self, ddof: int = 1, *args, **kwargs):
400+
def std(
401+
self,
402+
ddof: int = 1,
403+
*args,
404+
engine: str | None = None,
405+
engine_kwargs: dict[str, bool] | None = None,
406+
**kwargs,
407+
):
400408
nv.validate_expanding_func("std", args, kwargs)
401-
return super().std(ddof=ddof, **kwargs)
409+
return super().std(
410+
ddof=ddof, engine=engine, engine_kwargs=engine_kwargs, **kwargs
411+
)
402412

403413
@doc(
404414
template_header,
@@ -411,6 +421,7 @@ def std(self, ddof: int = 1, *args, **kwargs):
411421
"""
412422
).replace("\n", "", 1),
413423
args_compat,
424+
window_agg_numba_parameters("1.4"),
414425
kwargs_compat,
415426
create_section_header("Returns"),
416427
template_returns,
@@ -446,9 +457,18 @@ def std(self, ddof: int = 1, *args, **kwargs):
446457
aggregation_description="variance",
447458
agg_method="var",
448459
)
449-
def var(self, ddof: int = 1, *args, **kwargs):
460+
def var(
461+
self,
462+
ddof: int = 1,
463+
*args,
464+
engine: str | None = None,
465+
engine_kwargs: dict[str, bool] | None = None,
466+
**kwargs,
467+
):
450468
nv.validate_expanding_func("var", args, kwargs)
451-
return super().var(ddof=ddof, **kwargs)
469+
return super().var(
470+
ddof=ddof, engine=engine, engine_kwargs=engine_kwargs, **kwargs
471+
)
452472

453473
@doc(
454474
template_header,

0 commit comments

Comments
 (0)