Skip to content

Commit fedab7f

Browse files
committed
Refactor groupby helper from tempita to fused types
1 parent 95f8dca commit fedab7f

File tree

3 files changed

+58
-49
lines changed

3 files changed

+58
-49
lines changed

pandas/_libs/groupby.pyx

+56
Original file line numberDiff line numberDiff line change
@@ -382,5 +382,61 @@ def group_any_all(uint8_t[:] out,
382382
out[lab] = flag_val
383383

384384

385+
ctypedef fused floating:
386+
float32_t
387+
float64_t
388+
389+
390+
@cython.wraparound(False)
391+
@cython.boundscheck(False)
392+
def group_add_floating(floating[:, :] out,
393+
int64_t[:] counts,
394+
floating[:, :] values,
395+
const int64_t[:] labels,
396+
Py_ssize_t min_count=0):
397+
"""
398+
Only aggregates on axis=0
399+
"""
400+
cdef:
401+
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
402+
floating val, count
403+
ndarray[floating, ndim=2] sumx, nobs
404+
405+
if not len(values) == len(labels):
406+
raise AssertionError("len(index) != len(labels)")
407+
408+
nobs = np.zeros_like(out)
409+
sumx = np.zeros_like(out)
410+
411+
N, K = (<object>values).shape
412+
413+
with nogil:
414+
415+
for i in range(N):
416+
lab = labels[i]
417+
if lab < 0:
418+
continue
419+
420+
counts[lab] += 1
421+
for j in range(K):
422+
val = values[i, j]
423+
424+
# not nan
425+
if val == val:
426+
nobs[lab, j] += 1
427+
sumx[lab, j] += val
428+
429+
for i in range(ncounts):
430+
for j in range(K):
431+
if nobs[i, j] < min_count:
432+
out[i, j] = NAN
433+
else:
434+
out[i, j] = sumx[i, j]
435+
436+
437+
group_add_float32 = group_add_floating
438+
group_add_float64 = group_add_floating
439+
440+
385441
# generated from template
386442
include "groupby_helper.pxi"

pandas/_libs/groupby_helper.pxi.in

+1-48
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ cdef extern from "numpy/npy_math.h":
99
_int64_max = np.iinfo(np.int64).max
1010

1111
# ----------------------------------------------------------------------
12-
# group_add, group_prod, group_var, group_mean, group_ohlc
12+
# group_prod, group_var, group_mean, group_ohlc
1313
# ----------------------------------------------------------------------
1414

1515
{{py:
@@ -27,53 +27,6 @@ def get_dispatch(dtypes):
2727
{{for name, c_type in get_dispatch(dtypes)}}
2828

2929

30-
@cython.wraparound(False)
31-
@cython.boundscheck(False)
32-
def group_add_{{name}}({{c_type}}[:, :] out,
33-
int64_t[:] counts,
34-
{{c_type}}[:, :] values,
35-
const int64_t[:] labels,
36-
Py_ssize_t min_count=0):
37-
"""
38-
Only aggregates on axis=0
39-
"""
40-
cdef:
41-
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
42-
{{c_type}} val, count
43-
ndarray[{{c_type}}, ndim=2] sumx, nobs
44-
45-
if not len(values) == len(labels):
46-
raise AssertionError("len(index) != len(labels)")
47-
48-
nobs = np.zeros_like(out)
49-
sumx = np.zeros_like(out)
50-
51-
N, K = (<object>values).shape
52-
53-
with nogil:
54-
55-
for i in range(N):
56-
lab = labels[i]
57-
if lab < 0:
58-
continue
59-
60-
counts[lab] += 1
61-
for j in range(K):
62-
val = values[i, j]
63-
64-
# not nan
65-
if val == val:
66-
nobs[lab, j] += 1
67-
sumx[lab, j] += val
68-
69-
for i in range(ncounts):
70-
for j in range(K):
71-
if nobs[i, j] < min_count:
72-
out[i, j] = NAN
73-
else:
74-
out[i, j] = sumx[i, j]
75-
76-
7730
@cython.wraparound(False)
7831
@cython.boundscheck(False)
7932
def group_prod_{{name}}({{c_type}}[:, :] out,

pandas/core/groupby/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def get_func(fname):
380380
# otherwise find dtype-specific version, falling back to object
381381
for dt in [dtype_str, 'object']:
382382
f = getattr(libgroupby, "{fname}_{dtype_str}".format(
383-
fname=fname, dtype_str=dtype_str), None)
383+
fname=fname, dtype_str=dt), None)
384384
if f is not None:
385385
return f
386386

0 commit comments

Comments
 (0)