diff --git a/asv_bench/benchmarks/frame_methods.py b/asv_bench/benchmarks/frame_methods.py index 12e4824b2dd2a..1819cfa2725db 100644 --- a/asv_bench/benchmarks/frame_methods.py +++ b/asv_bench/benchmarks/frame_methods.py @@ -501,7 +501,7 @@ def time_info(self): class NSort(object): goal_time = 0.2 - params = ['first', 'last'] + params = ['first', 'last', 'all'] param_names = ['keep'] def setup(self, keep): diff --git a/asv_bench/benchmarks/series_methods.py b/asv_bench/benchmarks/series_methods.py index 3f6522c3403d9..a5ccf5c32b876 100644 --- a/asv_bench/benchmarks/series_methods.py +++ b/asv_bench/benchmarks/series_methods.py @@ -41,7 +41,7 @@ def time_isin(self, dtypes): class NSort(object): goal_time = 0.2 - params = ['last', 'first'] + params = ['first', 'last', 'all'] param_names = ['keep'] def setup(self, keep): diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index 1105acda067d3..1ab67bd80a5e8 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -24,6 +24,7 @@ Other Enhancements `__. (:issue:`21627`) - New method :meth:`HDFStore.walk` will recursively walk the group hierarchy of an HDF5 file (:issue:`10932`) +- :meth:`Series.nlargest`, :meth:`Series.nsmallest`, :meth:`DataFrame.nlargest`, and :meth:`DataFrame.nsmallest` now accept the value ``"all"`` for the ``keep` argument. This keeps all ties for the nth largest/smallest value (:issue:`16818`) - .. _whatsnew_0240.api_breaking: diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 9e34b8eb55ccb..dc726a736d34f 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -1076,8 +1076,8 @@ def __init__(self, obj, n, keep): self.n = n self.keep = keep - if self.keep not in ('first', 'last'): - raise ValueError('keep must be either "first", "last"') + if self.keep not in ('first', 'last', 'all'): + raise ValueError('keep must be either "first", "last" or "all"') def nlargest(self): return self.compute('nlargest') @@ -1148,7 +1148,11 @@ def compute(self, method): kth_val = algos.kth_smallest(arr.copy(), n - 1) ns, = np.nonzero(arr <= kth_val) - inds = ns[arr[ns].argsort(kind='mergesort')][:n] + inds = ns[arr[ns].argsort(kind='mergesort')] + + if self.keep != 'all': + inds = inds[:n] + if self.keep == 'last': # reverse indices inds = narr - 1 - inds diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 42a68de52a3c4..a420266561c5a 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -4559,11 +4559,15 @@ def nlargest(self, n, columns, keep='first'): Number of rows to return. columns : label or list of labels Column label(s) to order by. - keep : {'first', 'last'}, default 'first' + keep : {'first', 'last', 'all'}, default 'first' Where there are duplicate values: - `first` : prioritize the first occurrence(s) - `last` : prioritize the last occurrence(s) + - ``all`` : do not drop any duplicates, even it means + selecting more than `n` items. + + .. versionadded:: 0.24.0 Returns ------- @@ -4586,47 +4590,58 @@ def nlargest(self, n, columns, keep='first'): Examples -------- - >>> df = pd.DataFrame({'a': [1, 10, 8, 10, -1], - ... 'b': list('abdce'), - ... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]}) + >>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2], + ... 'b': list('abdcef'), + ... 'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]}) >>> df a b c 0 1 a 1.0 1 10 b 2.0 2 8 d NaN - 3 10 c 3.0 - 4 -1 e 4.0 + 3 11 c 3.0 + 4 8 e 4.0 + 5 2 f 9.0 In the following example, we will use ``nlargest`` to select the three rows having the largest values in column "a". >>> df.nlargest(3, 'a') a b c + 3 11 c 3.0 1 10 b 2.0 - 3 10 c 3.0 2 8 d NaN When using ``keep='last'``, ties are resolved in reverse order: >>> df.nlargest(3, 'a', keep='last') a b c - 3 10 c 3.0 + 3 11 c 3.0 + 1 10 b 2.0 + 4 8 e 4.0 + + When using ``keep='all'``, all duplicate items are maintained: + + >>> df.nlargest(3, 'a', keep='all') + a b c + 3 11 c 3.0 1 10 b 2.0 2 8 d NaN + 4 8 e 4.0 To order by the largest values in column "a" and then "c", we can specify multiple columns like in the next example. >>> df.nlargest(3, ['a', 'c']) a b c - 3 10 c 3.0 + 4 8 e 4.0 + 3 11 c 3.0 1 10 b 2.0 - 2 8 d NaN Attempting to use ``nlargest`` on non-numeric dtypes will raise a ``TypeError``: >>> df.nlargest(3, 'b') + Traceback (most recent call last): TypeError: Column 'b' has dtype object, cannot use method 'nlargest' """ @@ -4645,10 +4660,14 @@ def nsmallest(self, n, columns, keep='first'): Number of items to retrieve columns : list or str Column name or names to order by - keep : {'first', 'last'}, default 'first' + keep : {'first', 'last', 'all'}, default 'first' Where there are duplicate values: - ``first`` : take the first occurrence. - ``last`` : take the last occurrence. + - ``all`` : do not drop any duplicates, even it means + selecting more than `n` items. + + .. versionadded:: 0.24.0 Returns ------- @@ -4656,14 +4675,60 @@ def nsmallest(self, n, columns, keep='first'): Examples -------- - >>> df = pd.DataFrame({'a': [1, 10, 8, 11, -1], - ... 'b': list('abdce'), - ... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]}) + >>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2], + ... 'b': list('abdcef'), + ... 'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]}) + >>> df + a b c + 0 1 a 1.0 + 1 10 b 2.0 + 2 8 d NaN + 3 11 c 3.0 + 4 8 e 4.0 + 5 2 f 9.0 + + In the following example, we will use ``nsmallest`` to select the + three rows having the smallest values in column "a". + >>> df.nsmallest(3, 'a') - a b c - 4 -1 e 4 - 0 1 a 1 - 2 8 d NaN + a b c + 0 1 a 1.0 + 5 2 f 9.0 + 2 8 d NaN + + When using ``keep='last'``, ties are resolved in reverse order: + + >>> df.nsmallest(3, 'a', keep='last') + a b c + 0 1 a 1.0 + 5 2 f 9.0 + 4 8 e 4.0 + + When using ``keep='all'``, all duplicate items are maintained: + + >>> df.nsmallest(3, 'a', keep='all') + a b c + 0 1 a 1.0 + 5 2 f 9.0 + 2 8 d NaN + 4 8 e 4.0 + + To order by the largest values in column "a" and then "c", we can + specify multiple columns like in the next example. + + >>> df.nsmallest(3, ['a', 'c']) + a b c + 0 1 a 1.0 + 5 2 f 9.0 + 4 8 e 4.0 + + Attempting to use ``nsmallest`` on non-numeric dtypes will raise a + ``TypeError``: + + >>> df.nsmallest(3, 'b') + + Traceback (most recent call last): + TypeError: Column 'b' has dtype object, cannot use method 'nsmallest' """ return algorithms.SelectNFrame(self, n=n, diff --git a/pandas/tests/frame/test_analytics.py b/pandas/tests/frame/test_analytics.py index 84873659ac931..d357208813dd8 100644 --- a/pandas/tests/frame/test_analytics.py +++ b/pandas/tests/frame/test_analytics.py @@ -2461,6 +2461,22 @@ def test_n_duplicate_index(self, df_duplicates, n, order): expected = df.sort_values(order, ascending=False).head(n) tm.assert_frame_equal(result, expected) + def test_duplicate_keep_all_ties(self): + # see gh-16818 + df = pd.DataFrame({'a': [5, 4, 4, 2, 3, 3, 3, 3], + 'b': [10, 9, 8, 7, 5, 50, 10, 20]}) + result = df.nlargest(4, 'a', keep='all') + expected = pd.DataFrame({'a': {0: 5, 1: 4, 2: 4, 4: 3, + 5: 3, 6: 3, 7: 3}, + 'b': {0: 10, 1: 9, 2: 8, 4: 5, + 5: 50, 6: 10, 7: 20}}) + tm.assert_frame_equal(result, expected) + + result = df.nsmallest(2, 'a', keep='all') + expected = pd.DataFrame({'a': {3: 2, 4: 3, 5: 3, 6: 3, 7: 3}, + 'b': {3: 7, 4: 5, 5: 50, 6: 10, 7: 20}}) + tm.assert_frame_equal(result, expected) + def test_series_broadcasting(self): # smoke test for numpy warnings # GH 16378, GH 16306 diff --git a/pandas/tests/series/test_analytics.py b/pandas/tests/series/test_analytics.py index 36342b5ba4ee1..fcfaff9b11002 100644 --- a/pandas/tests/series/test_analytics.py +++ b/pandas/tests/series/test_analytics.py @@ -2082,6 +2082,17 @@ def test_boundary_datetimelike(self, nselect_method, dtype): vals = [min_val + 1, min_val + 2, max_val - 1, max_val, min_val] assert_check_nselect_boundary(vals, dtype, nselect_method) + def test_duplicate_keep_all_ties(self): + # see gh-16818 + s = Series([10, 9, 8, 7, 7, 7, 7, 6]) + result = s.nlargest(4, keep='all') + expected = Series([10, 9, 8, 7, 7, 7, 7]) + assert_series_equal(result, expected) + + result = s.nsmallest(2, keep='all') + expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6]) + assert_series_equal(result, expected) + class TestCategoricalSeriesAnalytics(object):