Skip to content

Commit fcc91de

Browse files
buntwoKrzysztof Chomski
authored and
Krzysztof Chomski
committed
ENH: tolerance now takes list-like argument for reindex and get_indexer. (pandas-dev#17367)
1 parent c0aacb7 commit fcc91de

File tree

17 files changed

+222
-56
lines changed

17 files changed

+222
-56
lines changed

doc/source/whatsnew/v0.21.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ Other Enhancements
234234
- :meth:`DataFrame.assign` will preserve the original order of ``**kwargs`` for Python 3.6+ users instead of sorting the column names. (:issue:`14207`)
235235
- Improved the import time of pandas by about 2.25x. (:issue:`16764`)
236236
- :func:`read_json` and :func:`to_json` now accept a ``compression`` argument which allows them to transparently handle compressed files. (:issue:`17798`)
237+
- :func:`Series.reindex`, :func:`DataFrame.reindex`, and :func:`Index.get_indexer` now support a list-like argument for ``tolerance``. (:issue:`17367`)
237238

238239
.. _whatsnew_0210.api_breaking:
239240

pandas/core/generic.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -2470,9 +2470,10 @@ def reindex_like(self, other, method=None, copy=True, limit=None,
24702470
Maximum number of consecutive labels to fill for inexact matches.
24712471
tolerance : optional
24722472
Maximum distance between labels of the other object and this
2473-
object for inexact matches.
2473+
object for inexact matches. Can be list-like.
24742474
24752475
.. versionadded:: 0.17.0
2476+
.. versionadded:: 0.21.0 (list-like tolerance)
24762477
24772478
Notes
24782479
-----
@@ -2860,7 +2861,14 @@ def sort_index(self, axis=0, level=None, ascending=True, inplace=False,
28602861
matches. The values of the index at the matching locations must
28612862
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
28622863
2864+
Tolerance may be a scalar value, which applies the same tolerance
2865+
to all values, or list-like, which applies variable tolerance per
2866+
element. List-like includes list, tuple, array, Series, and must be
2867+
the same size as the index and its dtype must exactly match the
2868+
index's type.
2869+
28632870
.. versionadded:: 0.17.0
2871+
.. versionadded:: 0.21.0 (list-like tolerance)
28642872
28652873
Examples
28662874
--------
@@ -3120,7 +3128,14 @@ def _reindex_multi(self, axes, copy, fill_value):
31203128
matches. The values of the index at the matching locations must
31213129
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
31223130
3131+
Tolerance may be a scalar value, which applies the same tolerance
3132+
to all values, or list-like, which applies variable tolerance per
3133+
element. List-like includes list, tuple, array, Series, and must be
3134+
the same size as the index and its dtype must exactly match the
3135+
index's type.
3136+
31233137
.. versionadded:: 0.17.0
3138+
.. versionadded:: 0.21.0 (list-like tolerance)
31243139
31253140
Examples
31263141
--------

pandas/core/indexes/base.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -2484,7 +2484,14 @@ def _get_unique_index(self, dropna=False):
24842484
the index at the matching location must satisfy the equation
24852485
``abs(index[loc] - key) <= tolerance``.
24862486
2487+
Tolerance may be a scalar
2488+
value, which applies the same tolerance to all values, or
2489+
list-like, which applies variable tolerance per element. List-like
2490+
includes list, tuple, array, Series, and must be the same size as
2491+
the index and its dtype must exactly match the index's type.
2492+
24872493
.. versionadded:: 0.17.0
2494+
.. versionadded:: 0.21.0 (list-like tolerance)
24882495
24892496
Returns
24902497
-------
@@ -2627,7 +2634,14 @@ def _get_level_values(self, level):
26272634
matches. The values of the index at the matching locations must
26282635
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
26292636
2637+
Tolerance may be a scalar value, which applies the same tolerance
2638+
to all values, or list-like, which applies variable tolerance per
2639+
element. List-like includes list, tuple, array, Series, and must be
2640+
the same size as the index and its dtype must exactly match the
2641+
index's type.
2642+
26302643
.. versionadded:: 0.17.0
2644+
.. versionadded:: 0.21.0 (list-like tolerance)
26312645
26322646
Examples
26332647
--------
@@ -2647,7 +2661,7 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None):
26472661
method = missing.clean_reindex_fill_method(method)
26482662
target = _ensure_index(target)
26492663
if tolerance is not None:
2650-
tolerance = self._convert_tolerance(tolerance)
2664+
tolerance = self._convert_tolerance(tolerance, target)
26512665

26522666
# Treat boolean labels passed to a numeric index as not found. Without
26532667
# this fix False and True would be treated as 0 and 1 respectively.
@@ -2683,10 +2697,15 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None):
26832697
'backfill or nearest reindexing')
26842698

26852699
indexer = self._engine.get_indexer(target._values)
2700+
26862701
return _ensure_platform_int(indexer)
26872702

2688-
def _convert_tolerance(self, tolerance):
2703+
def _convert_tolerance(self, tolerance, target):
26892704
# override this method on subclasses
2705+
tolerance = np.asarray(tolerance)
2706+
if target.size != tolerance.size and tolerance.size > 1:
2707+
raise ValueError('list-like tolerance size must match '
2708+
'target index size')
26902709
return tolerance
26912710

26922711
def _get_fill_indexer(self, target, method, limit=None, tolerance=None):

pandas/core/indexes/datetimelike.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from pandas import compat
99
from pandas.compat.numpy import function as nv
10+
from pandas.core.tools.timedeltas import to_timedelta
1011

1112
import numpy as np
1213
from pandas.core.dtypes.common import (
@@ -431,13 +432,12 @@ def asobject(self):
431432
from pandas.core.index import Index
432433
return Index(self._box_values(self.asi8), name=self.name, dtype=object)
433434

434-
def _convert_tolerance(self, tolerance):
435-
try:
436-
return Timedelta(tolerance).to_timedelta64()
437-
except ValueError:
438-
raise ValueError('tolerance argument for %s must be convertible '
439-
'to Timedelta: %r'
440-
% (type(self).__name__, tolerance))
435+
def _convert_tolerance(self, tolerance, target):
436+
tolerance = np.asarray(to_timedelta(tolerance, box=False))
437+
if target.size != tolerance.size and tolerance.size > 1:
438+
raise ValueError('list-like tolerance size must match '
439+
'target index size')
440+
return tolerance
441441

442442
def _maybe_mask_results(self, result, fill_value=None, convert=None):
443443
"""

pandas/core/indexes/datetimes.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1423,7 +1423,7 @@ def get_loc(self, key, method=None, tolerance=None):
14231423
if tolerance is not None:
14241424
# try converting tolerance now, so errors don't get swallowed by
14251425
# the try/except clauses below
1426-
tolerance = self._convert_tolerance(tolerance)
1426+
tolerance = self._convert_tolerance(tolerance, np.asarray(key))
14271427

14281428
if isinstance(key, datetime):
14291429
# needed to localize naive datetimes
@@ -1447,7 +1447,12 @@ def get_loc(self, key, method=None, tolerance=None):
14471447
try:
14481448
stamp = Timestamp(key, tz=self.tz)
14491449
return Index.get_loc(self, stamp, method, tolerance)
1450-
except (KeyError, ValueError):
1450+
except KeyError:
1451+
raise KeyError(key)
1452+
except ValueError as e:
1453+
# list-like tolerance size must match target index size
1454+
if 'list-like' in str(e):
1455+
raise e
14511456
raise KeyError(key)
14521457

14531458
def _maybe_cast_slice_bound(self, label, side, kind):

pandas/core/indexes/numeric.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,21 @@ def _convert_for_op(self, value):
7171

7272
return value
7373

74-
def _convert_tolerance(self, tolerance):
75-
try:
76-
return float(tolerance)
77-
except ValueError:
78-
raise ValueError('tolerance argument for %s must be numeric: %r' %
79-
(type(self).__name__, tolerance))
74+
def _convert_tolerance(self, tolerance, target):
75+
tolerance = np.asarray(tolerance)
76+
if target.size != tolerance.size and tolerance.size > 1:
77+
raise ValueError('list-like tolerance size must match '
78+
'target index size')
79+
if not np.issubdtype(tolerance.dtype, np.number):
80+
if tolerance.ndim > 0:
81+
raise ValueError(('tolerance argument for %s must contain '
82+
'numeric elements if it is list type') %
83+
(type(self).__name__,))
84+
else:
85+
raise ValueError(('tolerance argument for %s must be numeric '
86+
'if it is a scalar: %r') %
87+
(type(self).__name__, tolerance))
88+
return tolerance
8089

8190
@classmethod
8291
def _assert_safe_casting(cls, data, subarr):

pandas/core/indexes/period.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -641,12 +641,17 @@ def to_timestamp(self, freq=None, how='start'):
641641
return DatetimeIndex(new_data, freq='infer', name=self.name)
642642

643643
def _maybe_convert_timedelta(self, other):
644-
if isinstance(other, (timedelta, np.timedelta64, offsets.Tick)):
644+
if isinstance(
645+
other, (timedelta, np.timedelta64, offsets.Tick, np.ndarray)):
645646
offset = frequencies.to_offset(self.freq.rule_code)
646647
if isinstance(offset, offsets.Tick):
647-
nanos = tslib._delta_to_nanoseconds(other)
648+
if isinstance(other, np.ndarray):
649+
nanos = np.vectorize(tslib._delta_to_nanoseconds)(other)
650+
else:
651+
nanos = tslib._delta_to_nanoseconds(other)
648652
offset_nanos = tslib._delta_to_nanoseconds(offset)
649-
if nanos % offset_nanos == 0:
653+
check = np.all(nanos % offset_nanos == 0)
654+
if check:
650655
return nanos // offset_nanos
651656
elif isinstance(other, offsets.DateOffset):
652657
freqstr = other.rule_code
@@ -782,7 +787,7 @@ def get_indexer(self, target, method=None, limit=None, tolerance=None):
782787
target = target.asi8
783788

784789
if tolerance is not None:
785-
tolerance = self._convert_tolerance(tolerance)
790+
tolerance = self._convert_tolerance(tolerance, target)
786791
return Index.get_indexer(self._int64index, target, method,
787792
limit, tolerance)
788793

@@ -825,7 +830,8 @@ def get_loc(self, key, method=None, tolerance=None):
825830
try:
826831
ordinal = tslib.iNaT if key is tslib.NaT else key.ordinal
827832
if tolerance is not None:
828-
tolerance = self._convert_tolerance(tolerance)
833+
tolerance = self._convert_tolerance(tolerance,
834+
np.asarray(key))
829835
return self._int64index.get_loc(ordinal, method, tolerance)
830836

831837
except KeyError:
@@ -908,8 +914,12 @@ def _get_string_slice(self, key):
908914
return slice(self.searchsorted(t1.ordinal, side='left'),
909915
self.searchsorted(t2.ordinal, side='right'))
910916

911-
def _convert_tolerance(self, tolerance):
912-
tolerance = DatetimeIndexOpsMixin._convert_tolerance(self, tolerance)
917+
def _convert_tolerance(self, tolerance, target):
918+
tolerance = DatetimeIndexOpsMixin._convert_tolerance(self, tolerance,
919+
target)
920+
if target.size != tolerance.size and tolerance.size > 1:
921+
raise ValueError('list-like tolerance size must match '
922+
'target index size')
913923
return self._maybe_convert_timedelta(tolerance)
914924

915925
def insert(self, loc, item):

pandas/core/indexes/timedeltas.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ def get_loc(self, key, method=None, tolerance=None):
699699
if tolerance is not None:
700700
# try converting tolerance now, so errors don't get swallowed by
701701
# the try/except clauses below
702-
tolerance = self._convert_tolerance(tolerance)
702+
tolerance = self._convert_tolerance(tolerance, np.asarray(key))
703703

704704
if _is_convertible_to_td(key):
705705
key = Timedelta(key)

pandas/core/tools/timedeltas.py

+3
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def to_timedelta(arg, unit='ns', box=True, errors='raise'):
8383
elif isinstance(arg, ABCIndexClass):
8484
return _convert_listlike(arg, unit=unit, box=box,
8585
errors=errors, name=arg.name)
86+
elif is_list_like(arg) and getattr(arg, 'ndim', 1) == 0:
87+
# extract array scalar and process below
88+
arg = arg.item()
8689
elif is_list_like(arg) and getattr(arg, 'ndim', 1) == 1:
8790
return _convert_listlike(arg, unit=unit, box=box, errors=errors)
8891
elif getattr(arg, 'ndim', 1) > 1:

pandas/tests/frame/test_indexing.py

+9
Original file line numberDiff line numberDiff line change
@@ -1935,9 +1935,13 @@ def test_reindex_methods(self):
19351935

19361936
actual = df.reindex_like(df, method=method, tolerance=0)
19371937
assert_frame_equal(df, actual)
1938+
actual = df.reindex_like(df, method=method, tolerance=[0, 0, 0, 0])
1939+
assert_frame_equal(df, actual)
19381940

19391941
actual = df.reindex(target, method=method, tolerance=1)
19401942
assert_frame_equal(expected, actual)
1943+
actual = df.reindex(target, method=method, tolerance=[1, 1, 1, 1])
1944+
assert_frame_equal(expected, actual)
19411945

19421946
e2 = expected[::-1]
19431947
actual = df.reindex(target[::-1], method=method)
@@ -1958,6 +1962,11 @@ def test_reindex_methods(self):
19581962
actual = df.reindex(target, method='nearest', tolerance=0.2)
19591963
assert_frame_equal(expected, actual)
19601964

1965+
expected = pd.DataFrame({'x': [0, np.nan, 1, np.nan]}, index=target)
1966+
actual = df.reindex(target, method='nearest',
1967+
tolerance=[0.5, 0.01, 0.4, 0.1])
1968+
assert_frame_equal(expected, actual)
1969+
19611970
def test_reindex_frame_add_nat(self):
19621971
rng = date_range('1/1/2000 00:00:00', periods=10, freq='10s')
19631972
df = DataFrame({'A': np.random.randn(len(rng)), 'B': rng})

pandas/tests/indexes/datetimes/test_datetime.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,17 @@ def test_get_loc(self):
4141
tolerance=np.timedelta64(1, 'D')) == 1
4242
assert idx.get_loc('2000-01-01T12', method='nearest',
4343
tolerance=timedelta(1)) == 1
44-
with tm.assert_raises_regex(ValueError, 'must be convertible'):
44+
with tm.assert_raises_regex(ValueError,
45+
'unit abbreviation w/o a number'):
4546
idx.get_loc('2000-01-01T12', method='nearest', tolerance='foo')
4647
with pytest.raises(KeyError):
4748
idx.get_loc('2000-01-01T03', method='nearest', tolerance='2 hours')
49+
with pytest.raises(
50+
ValueError,
51+
match='tolerance size must match target index size'):
52+
idx.get_loc('2000-01-01', method='nearest',
53+
tolerance=[pd.Timedelta('1day').to_timedelta64(),
54+
pd.Timedelta('1day').to_timedelta64()])
4855

4956
assert idx.get_loc('2000', method='nearest') == slice(0, 3)
5057
assert idx.get_loc('2000-01', method='nearest') == slice(0, 3)
@@ -93,6 +100,19 @@ def test_get_indexer(self):
93100
idx.get_indexer(target, 'nearest',
94101
tolerance=pd.Timedelta('1 hour')),
95102
np.array([0, -1, 1], dtype=np.intp))
103+
tol_raw = [pd.Timedelta('1 hour'),
104+
pd.Timedelta('1 hour'),
105+
pd.Timedelta('1 hour').to_timedelta64(), ]
106+
tm.assert_numpy_array_equal(
107+
idx.get_indexer(target, 'nearest',
108+
tolerance=[np.timedelta64(x) for x in tol_raw]),
109+
np.array([0, -1, 1], dtype=np.intp))
110+
tol_bad = [pd.Timedelta('2 hour').to_timedelta64(),
111+
pd.Timedelta('1 hour').to_timedelta64(),
112+
'foo', ]
113+
with pytest.raises(
114+
ValueError, match='abbreviation w/o a number'):
115+
idx.get_indexer(target, 'nearest', tolerance=tol_bad)
96116
with pytest.raises(ValueError):
97117
idx.get_indexer(idx[[0]], method='nearest', tolerance='foo')
98118

pandas/tests/indexes/period/test_period.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pandas import (PeriodIndex, period_range, notna, DatetimeIndex, NaT,
1010
Index, Period, Int64Index, Series, DataFrame, date_range,
1111
offsets, compat)
12+
from pandas.core.indexes.period import IncompatibleFrequency
1213

1314
from ..datetimelike import DatetimeLike
1415

@@ -83,14 +84,21 @@ def test_get_loc(self):
8384
tolerance=np.timedelta64(1, 'D')) == 1
8485
assert idx.get_loc('2000-01-02T12', method='nearest',
8586
tolerance=timedelta(1)) == 1
86-
with tm.assert_raises_regex(ValueError, 'must be convertible'):
87+
with tm.assert_raises_regex(ValueError,
88+
'unit abbreviation w/o a number'):
8789
idx.get_loc('2000-01-10', method='nearest', tolerance='foo')
8890

8991
msg = 'Input has different freq from PeriodIndex\\(freq=D\\)'
9092
with tm.assert_raises_regex(ValueError, msg):
9193
idx.get_loc('2000-01-10', method='nearest', tolerance='1 hour')
9294
with pytest.raises(KeyError):
9395
idx.get_loc('2000-01-10', method='nearest', tolerance='1 day')
96+
with pytest.raises(
97+
ValueError,
98+
match='list-like tolerance size must match target index size'):
99+
idx.get_loc('2000-01-10', method='nearest',
100+
tolerance=[pd.Timedelta('1 day').to_timedelta64(),
101+
pd.Timedelta('1 day').to_timedelta64()])
94102

95103
def test_where(self):
96104
i = self.create_index()
@@ -158,6 +166,20 @@ def test_get_indexer(self):
158166
tm.assert_numpy_array_equal(idx.get_indexer(target, 'nearest',
159167
tolerance='1 day'),
160168
np.array([0, 1, 1], dtype=np.intp))
169+
tol_raw = [pd.Timedelta('1 hour'),
170+
pd.Timedelta('1 hour'),
171+
np.timedelta64(1, 'D'), ]
172+
tm.assert_numpy_array_equal(
173+
idx.get_indexer(target, 'nearest',
174+
tolerance=[np.timedelta64(x) for x in tol_raw]),
175+
np.array([0, -1, 1], dtype=np.intp))
176+
tol_bad = [pd.Timedelta('2 hour').to_timedelta64(),
177+
pd.Timedelta('1 hour').to_timedelta64(),
178+
np.timedelta64(1, 'M'), ]
179+
with pytest.raises(
180+
IncompatibleFrequency,
181+
match='Input has different freq from'):
182+
idx.get_indexer(target, 'nearest', tolerance=tol_bad)
161183

162184
def test_repeat(self):
163185
# GH10183

0 commit comments

Comments
 (0)