Skip to content

Commit a07c686

Browse files
committed
ENH: Allow keep='all' for nlargest/nsmallest
Closes gh-16818. Closes gh-18656.
1 parent a620e72 commit a07c686

File tree

7 files changed

+116
-23
lines changed

7 files changed

+116
-23
lines changed

asv_bench/benchmarks/frame_methods.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def time_info(self):
501501
class NSort(object):
502502

503503
goal_time = 0.2
504-
params = ['first', 'last']
504+
params = ['first', 'last', 'all']
505505
param_names = ['keep']
506506

507507
def setup(self, keep):

asv_bench/benchmarks/series_methods.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def time_isin(self, dtypes):
4141
class NSort(object):
4242

4343
goal_time = 0.2
44-
params = ['last', 'first']
44+
params = ['first', 'last', 'all']
4545
param_names = ['keep']
4646

4747
def setup(self, keep):

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ Other Enhancements
2424
<https://pandas-gbq.readthedocs.io/en/latest/changelog.html#changelog-0-5-0>`__.
2525
(:issue:`21627`)
2626
- New method :meth:`HDFStore.walk` will recursively walk the group hierarchy of an HDF5 file (:issue:`10932`)
27+
- :meth:`Series.nlargest`, :meth:`Series.nsmallest`, :meth:`DataFrame.nlargest`, and :meth:`DataFrame.nsmallest` now accept the value ``"all"`` for the ``keep` argument. This keeps all ties for the nth largest/smallest value (:issue:`16818`)
2728
-
2829

2930
.. _whatsnew_0240.api_breaking:

pandas/core/algorithms.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1076,8 +1076,8 @@ def __init__(self, obj, n, keep):
10761076
self.n = n
10771077
self.keep = keep
10781078

1079-
if self.keep not in ('first', 'last'):
1080-
raise ValueError('keep must be either "first", "last"')
1079+
if self.keep not in ('first', 'last', 'all'):
1080+
raise ValueError('keep must be either "first", "last" or "all"')
10811081

10821082
def nlargest(self):
10831083
return self.compute('nlargest')
@@ -1148,7 +1148,11 @@ def compute(self, method):
11481148

11491149
kth_val = algos.kth_smallest(arr.copy(), n - 1)
11501150
ns, = np.nonzero(arr <= kth_val)
1151-
inds = ns[arr[ns].argsort(kind='mergesort')][:n]
1151+
inds = ns[arr[ns].argsort(kind='mergesort')]
1152+
1153+
if self.keep != 'all':
1154+
inds = inds[:n]
1155+
11521156
if self.keep == 'last':
11531157
# reverse indices
11541158
inds = narr - 1 - inds

pandas/core/frame.py

+79-18
Original file line numberDiff line numberDiff line change
@@ -4559,11 +4559,15 @@ def nlargest(self, n, columns, keep='first'):
45594559
Number of rows to return.
45604560
columns : label or list of labels
45614561
Column label(s) to order by.
4562-
keep : {'first', 'last'}, default 'first'
4562+
keep : {'first', 'last', 'all'}, default 'first'
45634563
Where there are duplicate values:
45644564
45654565
- `first` : prioritize the first occurrence(s)
45664566
- `last` : prioritize the last occurrence(s)
4567+
- ``all`` : do not drop any duplicates, even it means
4568+
selecting more than `n` items.
4569+
4570+
.. versionadded:: 0.24.0
45674571
45684572
Returns
45694573
-------
@@ -4586,42 +4590,51 @@ def nlargest(self, n, columns, keep='first'):
45864590
45874591
Examples
45884592
--------
4589-
>>> df = pd.DataFrame({'a': [1, 10, 8, 10, -1],
4590-
... 'b': list('abdce'),
4591-
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]})
4593+
>>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2],
4594+
... 'b': list('abdcef'),
4595+
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]})
45924596
>>> df
45934597
a b c
45944598
0 1 a 1.0
45954599
1 10 b 2.0
45964600
2 8 d NaN
4597-
3 10 c 3.0
4598-
4 -1 e 4.0
4601+
3 11 c 3.0
4602+
4 8 e 4.0
4603+
5 2 f 9.0
45994604
46004605
In the following example, we will use ``nlargest`` to select the three
46014606
rows having the largest values in column "a".
46024607
46034608
>>> df.nlargest(3, 'a')
46044609
a b c
4610+
3 11 c 3.0
46054611
1 10 b 2.0
4606-
3 10 c 3.0
46074612
2 8 d NaN
46084613
46094614
When using ``keep='last'``, ties are resolved in reverse order:
46104615
46114616
>>> df.nlargest(3, 'a', keep='last')
46124617
a b c
4613-
3 10 c 3.0
4618+
3 11 c 3.0
4619+
1 10 b 2.0
4620+
4 8 e 4.0
4621+
4622+
When using ``keep='all'``, all duplicate items are maintained
4623+
>>> df.nlargest(3, 'a', keep='all')
4624+
a b c
4625+
3 11 c 3.0
46144626
1 10 b 2.0
46154627
2 8 d NaN
4628+
4 8 e 4.0
46164629
46174630
To order by the largest values in column "a" and then "c", we can
46184631
specify multiple columns like in the next example.
46194632
46204633
>>> df.nlargest(3, ['a', 'c'])
46214634
a b c
4622-
3 10 c 3.0
4635+
4 8 e 4.0
4636+
3 11 c 3.0
46234637
1 10 b 2.0
4624-
2 8 d NaN
46254638
46264639
Attempting to use ``nlargest`` on non-numeric dtypes will raise a
46274640
``TypeError``:
@@ -4645,25 +4658,73 @@ def nsmallest(self, n, columns, keep='first'):
46454658
Number of items to retrieve
46464659
columns : list or str
46474660
Column name or names to order by
4648-
keep : {'first', 'last'}, default 'first'
4661+
keep : {'first', 'last', 'all'}, default 'first'
46494662
Where there are duplicate values:
46504663
- ``first`` : take the first occurrence.
46514664
- ``last`` : take the last occurrence.
4665+
- ``all`` : do not drop any duplicates, even it means
4666+
selecting more than `n` items.
4667+
4668+
.. versionadded:: 0.24.0
46524669
46534670
Returns
46544671
-------
46554672
DataFrame
46564673
46574674
Examples
46584675
--------
4659-
>>> df = pd.DataFrame({'a': [1, 10, 8, 11, -1],
4660-
... 'b': list('abdce'),
4661-
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0]})
4676+
>>> df = pd.DataFrame({'a': [1, 10, 8, 11, 8, 2],
4677+
... 'b': list('abdcef'),
4678+
... 'c': [1.0, 2.0, np.nan, 3.0, 4.0, 9.0]})
4679+
>>> df
4680+
a b c
4681+
0 1 a 1.0
4682+
1 10 b 2.0
4683+
2 8 d NaN
4684+
3 11 c 3.0
4685+
4 8 e 4.0
4686+
5 2 f 9.0
4687+
4688+
In the following example, we will use ``nsmallest`` to select the
4689+
three rows having the smallest values in column "a".
4690+
46624691
>>> df.nsmallest(3, 'a')
4663-
a b c
4664-
4 -1 e 4
4665-
0 1 a 1
4666-
2 8 d NaN
4692+
a b c
4693+
0 1 a 1.0
4694+
5 2 f 9.0
4695+
2 8 d NaN
4696+
4697+
When using ``keep='last'``, ties are resolved in reverse order:
4698+
4699+
>>> df.nsmallest(3, 'a', keep='last')
4700+
a b c
4701+
0 1 a 1.0
4702+
5 2 f 9.0
4703+
4 8 e 4.0
4704+
4705+
When using ``keep='all'``, all duplicate items are maintained
4706+
>>> df.nsmallest(3, 'a', keep='all')
4707+
a b c
4708+
0 1 a 1.0
4709+
5 2 f 9.0
4710+
2 8 d NaN
4711+
4 8 e 4.0
4712+
4713+
To order by the largest values in column "a" and then "c", we can
4714+
specify multiple columns like in the next example.
4715+
4716+
>>> df.nsmallest(3, ['a', 'c'])
4717+
a b c
4718+
0 1 a 1.0
4719+
5 2 f 9.0
4720+
4 8 e 4.0
4721+
4722+
Attempting to use ``nsmallest`` on non-numeric dtypes will raise a
4723+
``TypeError``:
4724+
4725+
>>> df.nsmallest(3, 'b')
4726+
Traceback (most recent call last):
4727+
TypeError: Column 'b' has dtype object, cannot use method 'nsmallest'
46674728
"""
46684729
return algorithms.SelectNFrame(self,
46694730
n=n,

pandas/tests/frame/test_analytics.py

+16
Original file line numberDiff line numberDiff line change
@@ -2461,6 +2461,22 @@ def test_n_duplicate_index(self, df_duplicates, n, order):
24612461
expected = df.sort_values(order, ascending=False).head(n)
24622462
tm.assert_frame_equal(result, expected)
24632463

2464+
def test_keep_all_ties(self):
2465+
# see gh-16818
2466+
df = pd.DataFrame({'a': [5, 4, 4, 2, 3, 3, 3, 3],
2467+
'b': [10, 9, 8, 7, 5, 50, 10, 20]})
2468+
result = df.nlargest(4, 'a', keep='all')
2469+
expected = pd.DataFrame({'a': {0: 5, 1: 4, 2: 4, 4: 3,
2470+
5: 3, 6: 3, 7: 3},
2471+
'b': {0: 10, 1: 9, 2: 8, 4: 5,
2472+
5: 50, 6: 10, 7: 20}})
2473+
tm.assert_frame_equal(result, expected)
2474+
2475+
result = df.nsmallest(2, 'a', keep='all')
2476+
expected = pd.DataFrame({'a': {3: 2, 4: 3, 5: 3, 6: 3, 7: 3},
2477+
'b': {3: 7, 4: 5, 5: 50, 6: 10, 7: 20}})
2478+
tm.assert_frame_equal(result, expected)
2479+
24642480
def test_series_broadcasting(self):
24652481
# smoke test for numpy warnings
24662482
# GH 16378, GH 16306

pandas/tests/series/test_analytics.py

+11
Original file line numberDiff line numberDiff line change
@@ -2082,6 +2082,17 @@ def test_boundary_datetimelike(self, nselect_method, dtype):
20822082
vals = [min_val + 1, min_val + 2, max_val - 1, max_val, min_val]
20832083
assert_check_nselect_boundary(vals, dtype, nselect_method)
20842084

2085+
def test_keep_all_ties(self):
2086+
# see gh-16818
2087+
s = Series([10, 9, 8, 7, 7, 7, 7, 6])
2088+
result = s.nlargest(4, keep='all')
2089+
expected = Series([10, 9, 8, 7, 7, 7, 7])
2090+
assert_series_equal(result, expected)
2091+
2092+
result = s.nsmallest(2, keep='all')
2093+
expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6])
2094+
assert_series_equal(result, expected)
2095+
20852096

20862097
class TestCategoricalSeriesAnalytics(object):
20872098

0 commit comments

Comments
 (0)