Skip to content

Commit 94fbe52

Browse files
committed
Refactor
1 parent 7dc9c30 commit 94fbe52

File tree

1 file changed

+54
-34
lines changed

1 file changed

+54
-34
lines changed

pandas/core/nanops.py

+54 -34
Original file line number | Diff line number | Diff line change
@@ -200,14 +200,34 @@ def _get_fill_value(dtype, fill_value=None, fill_value_typ=None):
200200
return tslibs.iNaT
201201

202202

203-
def _get_values(values, skipna, fill_value=None, fill_value_typ=None,
204-
isfinite=False, copy=True, mask=None, compute_mask=True):
203+
def _maybe_get_mask(values, skipna, mask):
204+
""" Return a mask if and only if one is necessary; otherwise return
205+
None to signal that no mask is needed.
206+
"""
207+
if (hasattr(values, 'dtype') and is_bool_dtype(values.dtype)):
208+
# Boolean data cannot contain nulls, so signal via mask being None
209+
return None
210+
211+
mask = _get_mask(values, mask, skipna=skipna)
212+
213+
return mask
214+
215+
216+
def _get_mask(values, mask, skipna=True):
217+
if mask is None and skipna:
218+
mask = isna(values)
219+
return mask
220+
221+
222+
def _get_values(values, skipna, mask, fill_value=None, fill_value_typ=None,
223+
copy=True):
205224
""" utility to get the values view, dtype, dtype_max and fill_value
206225
if necessary copy and mask using the specified fill_value
207226
copy = True will force the copy
208227
"""
209-
if skipna:
210-
compute_mask = True
228+
if mask is None:
229+
# We're relying on _maybe_get_mask to determine if a mask is necessary
230+
skipna = False
211231

212232
if is_datetime64tz_dtype(values):
213233
# com.values_from_object returns M8[ns] dtype instead of tz-aware,
@@ -218,12 +238,6 @@ def _get_values(values, skipna, fill_value=None, fill_value_typ=None,
218238
values = com.values_from_object(values)
219239
dtype = values.dtype
220240

221-
if mask is None and compute_mask:
222-
if isfinite:
223-
mask = _isfinite(values)
224-
else:
225-
mask = isna(values)
226-
227241
if is_datetime_or_timedelta_dtype(values) or is_datetime64tz_dtype(values):
228242
# changing timedelta64/datetime64 to int64 needs to happen after
229243
# `mask` has been computed (it is now passed in by the caller)
@@ -257,7 +271,7 @@ def _get_values(values, skipna, fill_value=None, fill_value_typ=None,
257271
elif is_float_dtype(dtype):
258272
dtype_max = np.float64
259273

260-
return values, mask, dtype, dtype_max, fill_value
274+
return values, dtype, dtype_max, fill_value
261275

262276

263277
def _isfinite(values):
@@ -364,12 +378,9 @@ def nanany(values, axis=None, skipna=True, mask=None):
364378
>>> nanops.nanany(s)
365379
False
366380
"""
367-
if (hasattr(values, 'dtype') and is_bool_dtype(values.dtype) and
368-
mask is None):
369-
# Assume np.bool cannot store NaNs
370-
skipna = False
371-
values, _, _, _, _ = _get_values(values, skipna, False, copy=skipna,
372-
mask=mask, compute_mask=False)
381+
mask = _maybe_get_mask(values, skipna, mask)
382+
values, _, _, _ = _get_values(values, skipna, mask, fill_value=False,
383+
copy=skipna)
373384
return values.any(axis)
374385

375386

@@ -401,12 +412,9 @@ def nanall(values, axis=None, skipna=True, mask=None):
401412
>>> nanops.nanall(s)
402413
False
403414
"""
404-
if (hasattr(values, 'dtype') and is_bool_dtype(values.dtype) and
405-
mask is None):
406-
# Assume np.bool cannot store NaNs
407-
skipna = False
408-
values, _, _, _, _ = _get_values(values, skipna, True, copy=skipna,
409-
mask=mask, compute_mask=False)
415+
mask = _maybe_get_mask(values, skipna, mask)
416+
values, _, _, _ = _get_values(values, skipna, mask, fill_value=True,
417+
copy=skipna)
410418
return values.all(axis)
411419

412420

@@ -435,8 +443,9 @@ def nansum(values, axis=None, skipna=True, min_count=0, mask=None):
435443
>>> nanops.nansum(s)
436444
3.0
437445
"""
438-
values, mask, dtype, dtype_max, _ = _get_values(values,
439-
skipna, 0, mask=mask)
446+
mask = _maybe_get_mask(values, skipna, mask)
447+
values, dtype, dtype_max, _ = _get_values(values, skipna, mask,
448+
fill_value=0)
440449
dtype_sum = dtype_max
441450
if is_float_dtype(dtype):
442451
dtype_sum = dtype
@@ -475,8 +484,9 @@ def nanmean(values, axis=None, skipna=True, mask=None):
475484
>>> nanops.nanmean(s)
476485
1.5
477486
"""
478-
values, mask, dtype, dtype_max, _ = _get_values(
479-
values, skipna, 0, mask=mask)
487+
mask = _get_mask(values, mask)
488+
values, dtype, dtype_max, _ = _get_values(values, skipna, mask,
489+
fill_value=0)
480490
dtype_sum = dtype_max
481491
dtype_count = np.float64
482492
if (is_integer_dtype(dtype) or is_timedelta64_dtype(dtype) or
@@ -532,7 +542,8 @@ def get_median(x):
532542
return np.nan
533543
return np.nanmedian(x[mask])
534544

535-
values, mask, dtype, dtype_max, _ = _get_values(values, skipna, mask=mask)
545+
mask = _get_mask(values, mask)
546+
values, dtype, dtype_max, _ = _get_values(values, skipna, mask)
536547
if not is_float_dtype(values):
537548
values = values.astype('f8')
538549
values[mask] = np.nan
@@ -737,8 +748,9 @@ def _nanminmax(meth, fill_value_typ):
737748
@bottleneck_switch()
738749
def reduction(values, axis=None, skipna=True, mask=None):
739750

740-
values, mask, dtype, dtype_max, fill_value = _get_values(
741-
values, skipna, fill_value_typ=fill_value_typ, mask=mask)
751+
mask = _maybe_get_mask(values, skipna, mask)
752+
values, dtype, dtype_max, fill_value = _get_values(
753+
values, skipna, mask, fill_value_typ=fill_value_typ)
742754

743755
if ((axis is not None and values.shape[axis] == 0) or
744756
values.size == 0):
@@ -785,8 +797,9 @@ def nanargmax(values, axis=None, skipna=True, mask=None):
785797
>>> nanops.nanargmax(s)
786798
4
787799
"""
788-
values, mask, dtype, _, _ = _get_values(
789-
values, skipna, fill_value_typ='-inf', mask=mask)
800+
mask = _get_mask(values, mask)
801+
values, dtype, _, _ = _get_values(
802+
values, skipna, mask, fill_value_typ='-inf')
790803
result = values.argmax(axis)
791804
result = _maybe_arg_null_out(result, axis, mask, skipna)
792805
return result
@@ -815,8 +828,9 @@ def nanargmin(values, axis=None, skipna=True, mask=None):
815828
>>> nanops.nanargmin(s)
816829
0
817830
"""
818-
values, mask, dtype, _, _ = _get_values(
819-
values, skipna, fill_value_typ='+inf', mask=mask)
831+
mask = _get_mask(values, mask)
832+
values, dtype, _, _ = _get_values(
833+
values, skipna, mask, fill_value_typ='+inf')
820834
result = values.argmin(axis)
821835
result = _maybe_arg_null_out(result, axis, mask, skipna)
822836
return result
@@ -1028,6 +1042,9 @@ def nanprod(values, axis=None, skipna=True, min_count=0, mask=None):
10281042

10291043
def _maybe_arg_null_out(result, axis, mask, skipna):
10301044
# helper function for nanargmin/nanargmax
1045+
if mask is None:
1046+
return result
1047+
10311048
if axis is None or not getattr(result, 'ndim', False):
10321049
if skipna:
10331050
if mask.all():
@@ -1060,6 +1077,9 @@ def _get_counts(mask, axis, dtype=float):
10601077

10611078

10621079
def _maybe_null_out(result, axis, mask, min_count=1):
1080+
if mask is None:
1081+
return result
1082+
10631083
if axis is not None and getattr(result, 'ndim', False):
10641084
null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
10651085
if np.any(null_mask):

0 commit comments

Comments
 (0)