Skip to content

Commit 09b73fc

Browse files
committed
Force mask usage
1 parent 293dc6e commit 09b73fc

File tree

2 files changed

+26
-39
lines changed

2 files changed

+26
-39
lines changed

pandas/_libs/groupby.pyx

+4-28
Original file line numberDiff line numberDiff line change
@@ -1256,8 +1256,7 @@ cdef group_cummin_max(groupby_t[:, ::1] out,
12561256
values : np.ndarray[groupby_t, ndim=2]
12571257
Values to take cummin/max of.
12581258
mask : array[uint8_t] or None
1259-
If not None, indices represent missing values,
1260-
otherwise the mask will not be used
1259+
Indices representing missing values,
12611260
labels : np.ndarray[np.intp]
12621261
Labels to group by.
12631262
ngroups : int
@@ -1277,9 +1276,7 @@ cdef group_cummin_max(groupby_t[:, ::1] out,
12771276
groupby_t val, mval
12781277
groupby_t[:, ::1] accum
12791278
intp_t lab
1280-
bint val_is_nan, use_mask
1281-
1282-
use_mask = mask is not None
1279+
bint val_is_nan
12831280

12841281
N, K = (<object>values).shape
12851282
accum = np.empty((ngroups, K), dtype=np.asarray(values).dtype, order='C')
@@ -1293,33 +1290,12 @@ cdef group_cummin_max(groupby_t[:, ::1] out,
12931290
with nogil:
12941291
for i in range(N):
12951292
lab = labels[i]
1296-
12971293
if lab < 0:
12981294
continue
1299-
for j in range(K):
1300-
val_is_nan = False
1301-
1302-
if use_mask:
1303-
if mask[i, j]:
1304-
1305-
# `out` does not need to be set since it
1306-
# will be masked anyway
1307-
val_is_nan = True
1308-
else:
1309-
1310-
# If using the mask, we can avoid grabbing the
1311-
# value unless necessary
1312-
val = values[i, j]
13131295

1314-
# Otherwise, `out` must be set accordingly if the
1315-
# value is missing
1316-
else:
1296+
for j in range(K):
1297+
if not mask[i, j]:
13171298
val = values[i, j]
1318-
if _treat_as_na(val, is_datetimelike):
1319-
val_is_nan = True
1320-
out[i, j] = val
1321-
1322-
if not val_is_nan:
13231299
mval = accum[lab, j]
13241300
if compute_max:
13251301
if val > mval:

pandas/core/groupby/ops.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,14 @@ def _is_builtin_func(self, arg):
575575

576576
@final
577577
def _ea_wrap_cython_operation(
578-
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
578+
self,
579+
kind: str,
580+
values,
581+
how: str,
582+
axis: int,
583+
min_count: int = -1,
584+
mask: np.ndarray | None = None,
585+
**kwargs,
579586
) -> ArrayLike:
580587
"""
581588
If we have an ExtensionArray, unwrap, call _cython_operation, and
@@ -589,7 +596,7 @@ def _ea_wrap_cython_operation(
589596
# operate on the tz-naive equivalents
590597
values = values.view("M8[ns]")
591598
res_values = self._cython_operation(
592-
kind, values, how, axis, min_count, **kwargs
599+
kind, values, how, axis, min_count, mask=mask, **kwargs
593600
)
594601
if how in ["rank"]:
595602
# preserve float64 dtype
@@ -603,7 +610,7 @@ def _ea_wrap_cython_operation(
603610
# IntegerArray or BooleanArray
604611
values = ensure_int_or_float(values)
605612
res_values = self._cython_operation(
606-
kind, values, how, axis, min_count, **kwargs
613+
kind, values, how, axis, min_count, mask=mask, **kwargs
607614
)
608615
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
609616
if isinstance(dtype, ExtensionDtype):
@@ -616,7 +623,7 @@ def _ea_wrap_cython_operation(
616623
# FloatingArray
617624
values = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
618625
res_values = self._cython_operation(
619-
kind, values, how, axis, min_count, **kwargs
626+
kind, values, how, axis, min_count, mask=mask, **kwargs
620627
)
621628
result = type(orig_values)._from_sequence(res_values)
622629
return result
@@ -632,7 +639,8 @@ def _masked_ea_wrap_cython_operation(
632639
values: BaseMaskedArray,
633640
how: str,
634641
axis: int,
635-
min_count: int = -1,
642+
min_count: int,
643+
mask: np.ndarray,
636644
**kwargs,
637645
) -> BaseMaskedArray:
638646
"""
@@ -641,9 +649,6 @@ def _masked_ea_wrap_cython_operation(
641649
"""
642650
orig_values = values
643651

644-
# isna just directly returns self._mask, so copy here to prevent
645-
# modifying the original
646-
mask = isna(values).copy()
647652
arr = values._data
648653

649654
if is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
@@ -658,7 +663,7 @@ def _masked_ea_wrap_cython_operation(
658663
cls = dtype.construct_array_type()
659664

660665
return cls(
661-
res_values.astype(dtype.type, copy=False), mask.astype(bool, copy=False)
666+
res_values.astype(dtype.type, copy=False), mask.astype(bool, copy=True)
662667
)
663668

664669
@final
@@ -695,14 +700,20 @@ def _cython_operation(
695700
cy_op.disallow_invalid_ops(dtype, is_numeric)
696701

697702
func_uses_mask = cy_op.uses_mask()
703+
704+
# Only compute the mask if we haven't yet
705+
if func_uses_mask and mask is None:
706+
mask = isna(values)
707+
698708
if is_extension_array_dtype(dtype):
699709
if isinstance(values, BaseMaskedArray) and func_uses_mask:
710+
assert mask is not None
700711
return self._masked_ea_wrap_cython_operation(
701-
kind, values, how, axis, min_count, **kwargs
712+
kind, values, how, axis, min_count, mask=mask, **kwargs
702713
)
703714
else:
704715
return self._ea_wrap_cython_operation(
705-
kind, values, how, axis, min_count, **kwargs
716+
kind, values, how, axis, min_count, mask=mask, **kwargs
706717
)
707718

708719
elif values.ndim == 1:

0 commit comments

Comments
 (0)