Skip to content

Commit 2224cf3

Browse files
authored
REF: share ExtensionIndex astype, __getitem__ with Index (#44059)
1 parent 8a51e68 commit 2224cf3

File tree

2 files changed

+49
-70
lines changed

2 files changed

+49
-70
lines changed

pandas/core/indexes/base.py

+49-21
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@
5959
deprecate_nonkeyword_arguments,
6060
doc,
6161
)
62-
from pandas.util._exceptions import find_stack_level
62+
from pandas.util._exceptions import (
63+
find_stack_level,
64+
rewrite_exception,
65+
)
6366

6467
from pandas.core.dtypes.cast import (
6568
can_hold_element,
@@ -985,20 +988,40 @@ def astype(self, dtype, copy=True):
985988
dtype = pandas_dtype(dtype)
986989

987990
if is_dtype_equal(self.dtype, dtype):
991+
# Ensure that self.astype(self.dtype) is self
988992
return self.copy() if copy else self
989993

994+
if (
995+
self.dtype == np.dtype("M8[ns]")
996+
and isinstance(dtype, np.dtype)
997+
and dtype.kind == "M"
998+
and dtype != np.dtype("M8[ns]")
999+
):
1000+
# For now DatetimeArray supports this by unwrapping ndarray,
1001+
# but DatetimeIndex doesn't
1002+
raise TypeError(f"Cannot cast {type(self).__name__} to dtype")
1003+
1004+
values = self._data
1005+
if isinstance(values, ExtensionArray):
1006+
with rewrite_exception(type(values).__name__, type(self).__name__):
1007+
new_values = values.astype(dtype, copy=copy)
1008+
9901009
elif isinstance(dtype, ExtensionDtype):
9911010
cls = dtype.construct_array_type()
992-
new_values = cls._from_sequence(self, dtype=dtype, copy=False)
993-
return Index(new_values, dtype=dtype, copy=copy, name=self.name)
1011+
# Note: for RangeIndex and CategoricalDtype self vs self._values
1012+
# behaves differently here.
1013+
new_values = cls._from_sequence(self, dtype=dtype, copy=copy)
9941014

995-
try:
996-
casted = self._values.astype(dtype, copy=copy)
997-
except (TypeError, ValueError) as err:
998-
raise TypeError(
999-
f"Cannot cast {type(self).__name__} to dtype {dtype}"
1000-
) from err
1001-
return Index(casted, name=self.name, dtype=dtype)
1015+
else:
1016+
try:
1017+
new_values = values.astype(dtype, copy=copy)
1018+
except (TypeError, ValueError) as err:
1019+
raise TypeError(
1020+
f"Cannot cast {type(self).__name__} to dtype {dtype}"
1021+
) from err
1022+
1023+
# pass copy=False because any copying will be done in the astype above
1024+
return Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
10021025

10031026
_index_shared_docs[
10041027
"take"
@@ -4875,8 +4898,6 @@ def __getitem__(self, key):
48754898
corresponding `Index` subclass.
48764899
48774900
"""
4878-
# There's no custom logic to be implemented in __getslice__, so it's
4879-
# not overloaded intentionally.
48804901
getitem = self._data.__getitem__
48814902

48824903
if is_scalar(key):
@@ -4885,25 +4906,32 @@ def __getitem__(self, key):
48854906

48864907
if isinstance(key, slice):
48874908
# This case is separated from the conditional above to avoid
4888-
# pessimization of basic indexing.
4909+
# pessimization com.is_bool_indexer and ndim checks.
48894910
result = getitem(key)
48904911
# Going through simple_new for performance.
48914912
return type(self)._simple_new(result, name=self._name)
48924913

48934914
if com.is_bool_indexer(key):
4915+
# if we have list[bools, length=1e5] then doing this check+convert
4916+
# takes 166 µs + 2.1 ms and cuts the ndarray.__getitem__
4917+
# time below from 3.8 ms to 496 µs
4918+
# if we already have ndarray[bool], the overhead is 1.4 µs or .25%
48944919
key = np.asarray(key, dtype=bool)
48954920

48964921
result = getitem(key)
4897-
if not is_scalar(result):
4898-
if np.ndim(result) > 1:
4899-
deprecate_ndim_indexing(result)
4900-
return result
4901-
# NB: Using _constructor._simple_new would break if MultiIndex
4902-
# didn't override __getitem__
4903-
return self._constructor._simple_new(result, name=self._name)
4904-
else:
4922+
# Because we ruled out integer above, we always get an arraylike here
4923+
if result.ndim > 1:
4924+
deprecate_ndim_indexing(result)
4925+
if hasattr(result, "_ndarray"):
4926+
# i.e. NDArrayBackedExtensionArray
4927+
# Unpack to ndarray for MPL compat
4928+
return result._ndarray
49054929
return result
49064930

4931+
# NB: Using _constructor._simple_new would break if MultiIndex
4932+
# didn't override __getitem__
4933+
return self._constructor._simple_new(result, name=self._name)
4934+
49074935
def _getitem_slice(self: _IndexT, slobj: slice) -> _IndexT:
49084936
"""
49094937
Fastpath for __getitem__ when we know we have a slice.

pandas/core/indexes/extension.py

-49
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,11 @@
1515
cache_readonly,
1616
doc,
1717
)
18-
from pandas.util._exceptions import rewrite_exception
1918

20-
from pandas.core.dtypes.common import (
21-
is_dtype_equal,
22-
pandas_dtype,
23-
)
2419
from pandas.core.dtypes.generic import ABCDataFrame
2520

2621
from pandas.core.arrays import IntervalArray
2722
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
28-
from pandas.core.indexers import deprecate_ndim_indexing
2923
from pandas.core.indexes.base import Index
3024

3125
_T = TypeVar("_T", bound="NDArrayBackedExtensionIndex")
@@ -138,22 +132,6 @@ class ExtensionIndex(Index):
138132

139133
_data: IntervalArray | NDArrayBackedExtensionArray
140134

141-
# ---------------------------------------------------------------------
142-
# NDarray-Like Methods
143-
144-
def __getitem__(self, key):
145-
result = self._data[key]
146-
if isinstance(result, type(self._data)):
147-
if result.ndim == 1:
148-
return type(self)(result, name=self._name)
149-
# Unpack to ndarray for MPL compat
150-
151-
result = result._ndarray
152-
153-
# Includes cases where we get a 2D ndarray back for MPL compat
154-
deprecate_ndim_indexing(result)
155-
return result
156-
157135
# ---------------------------------------------------------------------
158136

159137
def insert(self, loc: int, item) -> Index:
@@ -204,33 +182,6 @@ def map(self, mapper, na_action=None):
204182
except Exception:
205183
return self.astype(object).map(mapper)
206184

207-
@doc(Index.astype)
208-
def astype(self, dtype, copy: bool = True) -> Index:
209-
dtype = pandas_dtype(dtype)
210-
if is_dtype_equal(self.dtype, dtype):
211-
if not copy:
212-
# Ensure that self.astype(self.dtype) is self
213-
return self
214-
return self.copy()
215-
216-
# error: Non-overlapping equality check (left operand type: "dtype[Any]", right
217-
# operand type: "Literal['M8[ns]']")
218-
if (
219-
isinstance(self.dtype, np.dtype)
220-
and isinstance(dtype, np.dtype)
221-
and dtype.kind == "M"
222-
and dtype != "M8[ns]" # type: ignore[comparison-overlap]
223-
):
224-
# For now Datetime supports this by unwrapping ndarray, but DTI doesn't
225-
raise TypeError(f"Cannot cast {type(self).__name__} to dtype")
226-
227-
with rewrite_exception(type(self._data).__name__, type(self).__name__):
228-
new_values = self._data.astype(dtype, copy=copy)
229-
230-
# pass copy=False because any copying will be done in the
231-
# _data.astype call above
232-
return Index(new_values, dtype=new_values.dtype, name=self.name, copy=False)
233-
234185
@cache_readonly
235186
def _isnan(self) -> npt.NDArray[np.bool_]:
236187
# error: Incompatible return value type (got "ExtensionArray", expected

0 commit comments

Comments
 (0)