Skip to content

Commit e69ecce

Browse files
noamher authored and Pingviinituutti committed
Refactor groupby group_add from tempita to fused types (pandas-dev#24954)
1 parent aaa94d0 commit e69ecce

File tree

3 files changed

+53
-49
lines changed

3 files changed

+53
-49
lines changed

pandas/_libs/groupby.pyx

+51
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import cython
44
from cython import Py_ssize_t
5+
from cython cimport floating
56

67
from libc.stdlib cimport malloc, free
78

@@ -382,5 +383,55 @@ def group_any_all(uint8_t[:] out,
382383
out[lab] = flag_val
383384

384385

386+
@cython.wraparound(False)
@cython.boundscheck(False)
def _group_add(floating[:, :] out,
               int64_t[:] counts,
               floating[:, :] values,
               const int64_t[:] labels,
               Py_ssize_t min_count=0):
    """
    Sum the non-NaN entries of `values` per group, column by column.

    Only aggregates on axis=0.

    Parameters
    ----------
    out : floating[:, :]
        Result buffer, written in place. Entry ``[g, j]`` receives the sum
        of column ``j`` over the rows labelled ``g``, or NaN when fewer
        than `min_count` non-NaN entries were seen for that cell.
    counts : int64_t[:]
        Incremented in place once per row for that row's label. It is not
        reset here, so the caller presumably passes zeros (confirm against
        the call site). Its length decides how many rows of `out` are
        written.
    values : floating[:, :]
        Data to aggregate; must have the same length as `labels`.
    labels : const int64_t[:]
        Group index for each row of `values`. Negative labels are skipped.
        Labels are not range-checked and bounds checking is disabled, so
        every non-negative label must be a valid row index of `out` and
        `counts`.
    min_count : Py_ssize_t, default 0
        Minimum number of non-NaN entries a cell needs to get a sum rather
        than NaN.

    Raises
    ------
    AssertionError
        If ``len(values) != len(labels)``.
    """
    cdef:
        Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
        floating val
        ndarray[floating, ndim=2] sumx, nobs

    if len(values) != len(labels):
        raise AssertionError("len(index) != len(labels)")

    # Per-cell accumulators with the same shape as the output buffer.
    nobs = np.zeros_like(out)
    sumx = np.zeros_like(out)

    N, K = (<object>values).shape

    with nogil:

        for i in range(N):
            lab = labels[i]
            if lab < 0:
                # negative label: row belongs to no group
                continue

            counts[lab] += 1
            for j in range(K):
                val = values[i, j]

                # not nan (NaN is the only value that differs from itself)
                if val == val:
                    nobs[lab, j] += 1
                    sumx[lab, j] += val

        for i in range(ncounts):
            for j in range(K):
                if nobs[i, j] < min_count:
                    out[i, j] = NAN
                else:
                    out[i, j] = sumx[i, j]


group_add_float32 = _group_add['float']
group_add_float64 = _group_add['double']
435+
385436
# generated from template
386437
include "groupby_helper.pxi"

pandas/_libs/groupby_helper.pxi.in

+1-48
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ cdef extern from "numpy/npy_math.h":
99
_int64_max = np.iinfo(np.int64).max
1010

1111
# ----------------------------------------------------------------------
12-
# group_add, group_prod, group_var, group_mean, group_ohlc
12+
# group_prod, group_var, group_mean, group_ohlc
1313
# ----------------------------------------------------------------------
1414

1515
{{py:
@@ -27,53 +27,6 @@ def get_dispatch(dtypes):
2727
{{for name, c_type in get_dispatch(dtypes)}}
2828

2929

30-
@cython.wraparound(False)
31-
@cython.boundscheck(False)
32-
def group_add_{{name}}({{c_type}}[:, :] out,
33-
int64_t[:] counts,
34-
{{c_type}}[:, :] values,
35-
const int64_t[:] labels,
36-
Py_ssize_t min_count=0):
37-
"""
38-
Only aggregates on axis=0
39-
"""
40-
cdef:
41-
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
42-
{{c_type}} val, count
43-
ndarray[{{c_type}}, ndim=2] sumx, nobs
44-
45-
if not len(values) == len(labels):
46-
raise AssertionError("len(index) != len(labels)")
47-
48-
nobs = np.zeros_like(out)
49-
sumx = np.zeros_like(out)
50-
51-
N, K = (<object>values).shape
52-
53-
with nogil:
54-
55-
for i in range(N):
56-
lab = labels[i]
57-
if lab < 0:
58-
continue
59-
60-
counts[lab] += 1
61-
for j in range(K):
62-
val = values[i, j]
63-
64-
# not nan
65-
if val == val:
66-
nobs[lab, j] += 1
67-
sumx[lab, j] += val
68-
69-
for i in range(ncounts):
70-
for j in range(K):
71-
if nobs[i, j] < min_count:
72-
out[i, j] = NAN
73-
else:
74-
out[i, j] = sumx[i, j]
75-
76-
7730
@cython.wraparound(False)
7831
@cython.boundscheck(False)
7932
def group_prod_{{name}}({{c_type}}[:, :] out,

pandas/core/groupby/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,7 @@ def get_func(fname):
380380
# otherwise find dtype-specific version, falling back to object
381381
for dt in [dtype_str, 'object']:
382382
f = getattr(libgroupby, "{fname}_{dtype_str}".format(
383-
fname=fname, dtype_str=dtype_str), None)
383+
fname=fname, dtype_str=dt), None)
384384
if f is not None:
385385
return f
386386

0 commit comments

Comments
 (0)