Skip to content

Commit 913dd30

Browse files
committed
Refactor
1 parent 23438cf commit 913dd30

File tree

1 file changed

+49
-34
lines changed

1 file changed

+49
-34
lines changed

pandas/core/nanops.py

+49-34
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,34 @@ def _get_fill_value(dtype, fill_value=None, fill_value_typ=None):
199199
else:
200200
return tslibs.iNaT
201201

202+
def _maybe_get_mask(values, skipna, mask):
    """
    Return the null mask for `values`, or None when no mask is needed.

    Parameters
    ----------
    values : array-like
        Input data. Its ``dtype`` attribute is inspected when present.
    skipna : bool
        Whether nulls are to be skipped; a mask is only computed when this
        is true.
    mask : array-like or None
        Pre-computed mask. When one is supplied it is returned unchanged,
        whatever the dtype of `values`.

    Returns
    -------
    array-like or None
        The supplied mask, a freshly computed ``isna(values)``, or None when
        the data is boolean or `skipna` is false and no mask was supplied.
    """
    if mask is not None:
        # A caller-supplied mask is authoritative. The boolean shortcut below
        # must only apply when no mask was given, otherwise it is discarded.
        return mask

    if hasattr(values, 'dtype') and is_bool_dtype(values.dtype):
        # Boolean data cannot contain nulls, so signal this via the lack of
        # a mask
        return None

    # Same rule as _get_mask with mask=None: only compute when skipping nulls.
    return isna(values) if skipna else None
213+
214+
215+
def _get_mask(values, mask, skipna=True):
    """
    Return `mask`, computing it from `values` only when it is missing.

    A supplied mask is always returned as-is. When `mask` is None, the
    result is ``isna(values)`` if `skipna` is true and None otherwise.
    """
    if mask is not None or not skipna:
        # Nothing to compute: either the caller provided a mask, or nulls
        # are not being skipped and the (possibly None) mask passes through.
        return mask
    return isna(values)
219+
202220

203-
def _get_values(values, skipna, fill_value=None, fill_value_typ=None,
204-
isfinite=False, copy=True, mask=None, compute_mask=True):
221+
def _get_values(values, skipna, mask, fill_value=None, fill_value_typ=None,
222+
copy=True):
205223
""" utility to get the values view, mask, dtype
206224
if necessary copy and mask using the specified fill_value
207225
copy = True will force the copy
208226
"""
209-
if skipna:
210-
compute_mask = True
227+
if mask is None:
228+
# The caller (via _maybe_get_mask) supplies a mask whenever one is needed; with no mask there is nothing to skip
229+
skipna = False
211230

212231
if is_datetime64tz_dtype(values):
213232
# com.values_from_object returns M8[ns] dtype instead of tz-aware,
@@ -218,12 +237,6 @@ def _get_values(values, skipna, fill_value=None, fill_value_typ=None,
218237
values = com.values_from_object(values)
219238
dtype = values.dtype
220239

221-
if mask is None and compute_mask:
222-
if isfinite:
223-
mask = _isfinite(values)
224-
else:
225-
mask = isna(values)
226-
227240
if is_datetime_or_timedelta_dtype(values) or is_datetime64tz_dtype(values):
228241
# changing timedelta64/datetime64 to int64 needs to happen after
229242
# finding `mask` above
@@ -257,7 +270,7 @@ def _get_values(values, skipna, fill_value=None, fill_value_typ=None,
257270
elif is_float_dtype(dtype):
258271
dtype_max = np.float64
259272

260-
return values, mask, dtype, dtype_max, fill_value
273+
return values, dtype, dtype_max, fill_value
261274

262275

263276
def _isfinite(values):
@@ -364,12 +377,8 @@ def nanany(values, axis=None, skipna=True, mask=None):
364377
>>> nanops.nanany(s)
365378
False
366379
"""
367-
if (hasattr(values, 'dtype') and is_bool_dtype(values.dtype) and
368-
mask is None):
369-
# Assume np.bool cannot store NaNs
370-
skipna = False
371-
values, _, _, _, _ = _get_values(values, skipna, False, copy=skipna,
372-
mask=mask, compute_mask=False)
380+
mask = _maybe_get_mask(values, skipna, mask)
381+
values, _, _, _ = _get_values(values, skipna, mask, fill_value=False, copy=skipna)
373382
return values.any(axis)
374383

375384

@@ -401,12 +410,8 @@ def nanall(values, axis=None, skipna=True, mask=None):
401410
>>> nanops.nanall(s)
402411
False
403412
"""
404-
if (hasattr(values, 'dtype') and is_bool_dtype(values.dtype) and
405-
mask is None):
406-
# Assume np.bool cannot store NaNs
407-
skipna = False
408-
values, _, _, _, _ = _get_values(values, skipna, True, copy=skipna,
409-
mask=mask, compute_mask=False)
413+
mask = _maybe_get_mask(values, skipna, mask)
414+
values, _, _, _ = _get_values(values, skipna, mask, fill_value=True, copy=skipna)
410415
return values.all(axis)
411416

412417

@@ -435,8 +440,8 @@ def nansum(values, axis=None, skipna=True, min_count=0, mask=None):
435440
>>> nanops.nansum(s)
436441
3.0
437442
"""
438-
values, mask, dtype, dtype_max, _ = _get_values(values,
439-
skipna, 0, mask=mask)
443+
mask = _maybe_get_mask(values, skipna, mask)
444+
values, dtype, dtype_max, _ = _get_values(values, skipna, mask, fill_value=0)
440445
dtype_sum = dtype_max
441446
if is_float_dtype(dtype):
442447
dtype_sum = dtype
@@ -475,8 +480,8 @@ def nanmean(values, axis=None, skipna=True, mask=None):
475480
>>> nanops.nanmean(s)
476481
1.5
477482
"""
478-
values, mask, dtype, dtype_max, _ = _get_values(
479-
values, skipna, 0, mask=mask)
483+
mask = _get_mask(values, mask)
484+
values, dtype, dtype_max, _ = _get_values(values, skipna, mask, fill_value=0)
480485
dtype_sum = dtype_max
481486
dtype_count = np.float64
482487
if (is_integer_dtype(dtype) or is_timedelta64_dtype(dtype) or
@@ -532,7 +537,8 @@ def get_median(x):
532537
return np.nan
533538
return np.nanmedian(x[mask])
534539

535-
values, mask, dtype, dtype_max, _ = _get_values(values, skipna, mask=mask)
540+
mask = _get_mask(values, mask)
541+
values, dtype, dtype_max, _ = _get_values(values, skipna, mask)
536542
if not is_float_dtype(values):
537543
values = values.astype('f8')
538544
values[mask] = np.nan
@@ -737,8 +743,9 @@ def _nanminmax(meth, fill_value_typ):
737743
@bottleneck_switch()
738744
def reduction(values, axis=None, skipna=True, mask=None):
739745

740-
values, mask, dtype, dtype_max, fill_value = _get_values(
741-
values, skipna, fill_value_typ=fill_value_typ, mask=mask)
746+
mask = _maybe_get_mask(values, skipna, mask)
747+
values, dtype, dtype_max, fill_value = _get_values(
748+
values, skipna, mask, fill_value_typ=fill_value_typ)
742749

743750
if ((axis is not None and values.shape[axis] == 0) or
744751
values.size == 0):
@@ -785,8 +792,9 @@ def nanargmax(values, axis=None, skipna=True, mask=None):
785792
>>> nanops.nanargmax(s)
786793
4
787794
"""
788-
values, mask, dtype, _, _ = _get_values(
789-
values, skipna, fill_value_typ='-inf', mask=mask)
795+
mask = _get_mask(values, mask)
796+
values, dtype, _, _ = _get_values(
797+
values, skipna, mask, fill_value_typ='-inf')
790798
result = values.argmax(axis)
791799
result = _maybe_arg_null_out(result, axis, mask, skipna)
792800
return result
@@ -815,8 +823,9 @@ def nanargmin(values, axis=None, skipna=True, mask=None):
815823
>>> nanops.nanargmin(s)
816824
0
817825
"""
818-
values, mask, dtype, _, _ = _get_values(
819-
values, skipna, fill_value_typ='+inf', mask=mask)
826+
mask = _get_mask(values, mask)
827+
values, dtype, _, _ = _get_values(
828+
values, skipna, mask, fill_value_typ='+inf')
820829
result = values.argmin(axis)
821830
result = _maybe_arg_null_out(result, axis, mask, skipna)
822831
return result
@@ -1028,6 +1037,9 @@ def nanprod(values, axis=None, skipna=True, min_count=0, mask=None):
10281037

10291038
def _maybe_arg_null_out(result, axis, mask, skipna):
10301039
# helper function for nanargmin/nanargmax
1040+
if mask is None:
1041+
return result
1042+
10311043
if axis is None or not getattr(result, 'ndim', False):
10321044
if skipna:
10331045
if mask.all():
@@ -1060,6 +1072,9 @@ def _get_counts(mask, axis, dtype=float):
10601072

10611073

10621074
def _maybe_null_out(result, axis, mask, min_count=1):
1075+
if mask is None:
1076+
return result
1077+
10631078
if axis is not None and getattr(result, 'ndim', False):
10641079
null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0
10651080
if np.any(null_mask):

0 commit comments

Comments
 (0)