From 3c4d782a7d15fb1526a405d943bb4610f9b1e9db Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 17 Mar 2025 11:06:28 +0100 Subject: [PATCH 1/7] API (string dtype): implement hierarchy (NA > NaN, pyarrow > python) for consistent comparisons between different string dtypes --- pandas/core/arrays/arrow/array.py | 5 +-- pandas/core/arrays/string_.py | 25 ++++++++++++- pandas/core/arrays/string_arrow.py | 8 +++++ pandas/tests/arrays/string_/test_string.py | 42 ++++++++++++++-------- 4 files changed, 61 insertions(+), 19 deletions(-) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 9295cf7873d98..e3433ffcb24e8 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -33,7 +33,6 @@ infer_dtype_from_scalar, ) from pandas.core.dtypes.common import ( - CategoricalDtype, is_array_like, is_bool_dtype, is_float_dtype, @@ -730,9 +729,7 @@ def __setstate__(self, state) -> None: def _cmp_method(self, other, op) -> ArrowExtensionArray: pc_func = ARROW_CMP_FUNCS[op.__name__] - if isinstance( - other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray) - ) or isinstance(getattr(other, "dtype", None), CategoricalDtype): + if isinstance(other, (ExtensionArray, np.ndarray, list)): try: result = pc_func(self._pa_array, self._box_pa(other)) except pa.ArrowNotImplementedError: diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 7227ea77ca433..1a0a07f3a686b 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -1018,7 +1018,30 @@ def searchsorted( return super().searchsorted(value=value, side=side, sorter=sorter) def _cmp_method(self, other, op): - from pandas.arrays import BooleanArray + from pandas.arrays import ( + ArrowExtensionArray, + BooleanArray, + ) + + if ( + isinstance(other, BaseStringArray) + and self.dtype.na_value is not libmissing.NA + and other.dtype.na_value is libmissing.NA + ): + # NA has priority over NaN semantics + return 
NotImplemented + + if isinstance(other, ArrowExtensionArray): + if isinstance(other, BaseStringArray): + # pyarrow storage has priority over python storage + # (except if we have NA semantics and other does not) + if not ( + self.dtype.na_value is libmissing.NA + and other.dtype.na_value is not libmissing.NA + ): + return NotImplemented + else: + return NotImplemented if isinstance(other, StringArray): other = other._ndarray diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index d35083fd892a8..dc7343d0ea616 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -473,6 +473,14 @@ def value_counts(self, dropna: bool = True) -> Series: return result def _cmp_method(self, other, op): + if ( + isinstance(other, BaseStringArray) + and self.dtype.na_value is not libmissing.NA + and other.dtype.na_value is libmissing.NA + ): + # NA has priority over NaN semantics + return NotImplemented + result = super()._cmp_method(other, op) if self.dtype.na_value is np.nan: if op == operator.ne: diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 336a0fef69170..148ab4c59d5b8 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -45,6 +45,14 @@ def cls(dtype): return dtype.construct_array_type() +DTYPE_HIERARCHY = [ + pd.StringDtype("python", na_value=np.nan), + pd.StringDtype("pyarrow", na_value=np.nan), + pd.StringDtype("python", na_value=pd.NA), + pd.StringDtype("pyarrow", na_value=pd.NA), +] + + def test_dtype_constructor(): pytest.importorskip("pyarrow") @@ -319,13 +327,18 @@ def test_comparison_methods_scalar_not_string(comparison_op, dtype): tm.assert_extension_array_equal(result, expected) -def test_comparison_methods_array(comparison_op, dtype): +def test_comparison_methods_array(comparison_op, dtype, dtype2): op_name = f"__{comparison_op.__name__}__" a = pd.array(["a", None, "c"], dtype=dtype) - other = 
[None, None, "c"] - result = getattr(a, op_name)(other) - if dtype.na_value is np.nan: + other = pd.array([None, None, "c"], dtype=dtype2) + result = comparison_op(a, other) + + # ensure operation is commutative + result2 = comparison_op(other, a) + tm.assert_equal(result, result2) + + if dtype.na_value is np.nan and dtype2.na_value is np.nan: if operator.ne == comparison_op: expected = np.array([True, True, False]) else: @@ -333,23 +346,24 @@ def test_comparison_methods_array(comparison_op, dtype): expected[-1] = getattr(other[-1], op_name)(a[-1]) tm.assert_numpy_array_equal(result, expected) - result = getattr(a, op_name)(pd.NA) - if operator.ne == comparison_op: - expected = np.array([True, True, True]) + else: + h1 = DTYPE_HIERARCHY.index(dtype) + h2 = DTYPE_HIERARCHY.index(dtype2) + max_dtype = DTYPE_HIERARCHY[max(h1, h2)] + if max_dtype.storage == "python": + expected_dtype = "boolean" else: - expected = np.array([False, False, False]) - tm.assert_numpy_array_equal(result, expected) + expected_dtype = "bool[pyarrow]" - else: - expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" expected = np.full(len(a), fill_value=None, dtype="object") expected[-1] = getattr(other[-1], op_name)(a[-1]) expected = pd.array(expected, dtype=expected_dtype) tm.assert_extension_array_equal(result, expected) - result = getattr(a, op_name)(pd.NA) - expected = pd.array([None, None, None], dtype=expected_dtype) - tm.assert_extension_array_equal(result, expected) + # # with list + # other = [None, None, "c"] + # result3 = getattr(a, op_name)(other) + # tm.assert_equal(result, result3) def test_constructor_raises(cls): From 7ffb08f012893af38e2546ffd3600dca90218647 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 19 Mar 2025 10:16:30 +0100 Subject: [PATCH 2/7] fix string arith tests --- pandas/tests/arrays/string_/test_string.py | 10 +++++++--- pandas/tests/extension/test_string.py | 12 ++++++------ 2 files changed, 13 insertions(+), 9 
deletions(-) diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 148ab4c59d5b8..c519a48fc1e49 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -53,6 +53,12 @@ def cls(dtype): ] +def string_dtype_highest_priority(dtype1, dtype2): + h1 = DTYPE_HIERARCHY.index(dtype1) + h2 = DTYPE_HIERARCHY.index(dtype2) + return DTYPE_HIERARCHY[max(h1, h2)] + + def test_dtype_constructor(): pytest.importorskip("pyarrow") @@ -347,9 +353,7 @@ def test_comparison_methods_array(comparison_op, dtype, dtype2): tm.assert_numpy_array_equal(result, expected) else: - h1 = DTYPE_HIERARCHY.index(dtype) - h2 = DTYPE_HIERARCHY.index(dtype2) - max_dtype = DTYPE_HIERARCHY[max(h1, h2)] + max_dtype = string_dtype_highest_priority(dtype, dtype2) if max_dtype.storage == "python": expected_dtype = "boolean" else: diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 25129111180d6..7fc45a8439133 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -22,8 +22,6 @@ import numpy as np import pytest -from pandas.compat import HAS_PYARROW - from pandas.core.dtypes.base import StorageExtensionDtype import pandas as pd @@ -31,6 +29,7 @@ from pandas.api.types import is_string_dtype from pandas.core.arrays import ArrowStringArray from pandas.core.arrays.string_ import StringDtype +from pandas.tests.arrays.string_.test_string import string_dtype_highest_priority from pandas.tests.extension import base @@ -202,10 +201,13 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result): dtype = cast(StringDtype, tm.get_dtype(obj)) if op_name in ["__add__", "__radd__"]: cast_to = dtype + dtype_other = tm.get_dtype(other) if not isinstance(other, str) else None + if isinstance(dtype_other, StringDtype): + cast_to = string_dtype_highest_priority(dtype, dtype_other) elif dtype.na_value is np.nan: cast_to = 
np.bool_ # type: ignore[assignment] elif dtype.storage == "pyarrow": - cast_to = "boolean[pyarrow]" # type: ignore[assignment] + cast_to = "bool[pyarrow]" # type: ignore[assignment] else: cast_to = "boolean" # type: ignore[assignment] return pointwise_result.astype(cast_to) @@ -236,9 +238,7 @@ def test_arith_series_with_array( if ( using_infer_string and all_arithmetic_operators == "__radd__" - and ( - (dtype.na_value is pd.NA) or (dtype.storage == "python" and HAS_PYARROW) - ) + and dtype.na_value is pd.NA ): mark = pytest.mark.xfail( reason="The pointwise operation result will be inferred to " From 48907c35c588d571ab06e68b735a97ebe38ba89a Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 19 Mar 2025 10:18:55 +0100 Subject: [PATCH 3/7] fix for build without pyarrow --- pandas/tests/arrays/string_/test_string.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index c519a48fc1e49..852783640ba36 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -10,6 +10,7 @@ from pandas._config import using_string_dtype +from pandas.compat import HAS_PYARROW from pandas.compat.pyarrow import ( pa_version_under12p0, pa_version_under19p0, @@ -45,15 +46,20 @@ def cls(dtype): return dtype.construct_array_type() -DTYPE_HIERARCHY = [ - pd.StringDtype("python", na_value=np.nan), - pd.StringDtype("pyarrow", na_value=np.nan), - pd.StringDtype("python", na_value=pd.NA), - pd.StringDtype("pyarrow", na_value=pd.NA), -] - - def string_dtype_highest_priority(dtype1, dtype2): + if HAS_PYARROW: + DTYPE_HIERARCHY = [ + pd.StringDtype("python", na_value=np.nan), + pd.StringDtype("pyarrow", na_value=np.nan), + pd.StringDtype("python", na_value=pd.NA), + pd.StringDtype("pyarrow", na_value=pd.NA), + ] + else: + DTYPE_HIERARCHY = [ + pd.StringDtype("python", na_value=np.nan), + pd.StringDtype("python", 
na_value=pd.NA), + ] + h1 = DTYPE_HIERARCHY.index(dtype1) h2 = DTYPE_HIERARCHY.index(dtype2) return DTYPE_HIERARCHY[max(h1, h2)] From 2058120c41fc54699bae476c220b26ebb8c0b2f9 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 19 Mar 2025 14:07:39 +0100 Subject: [PATCH 4/7] fix xfail condition --- pandas/tests/extension/test_string.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 7fc45a8439133..6ea8ac59ca3e6 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -22,6 +22,8 @@ import numpy as np import pytest +from pandas.compat import HAS_PYARROW + from pandas.core.dtypes.base import StorageExtensionDtype import pandas as pd @@ -238,8 +240,12 @@ def test_arith_series_with_array( if ( using_infer_string and all_arithmetic_operators == "__radd__" - and dtype.na_value is pd.NA + and ( + dtype.na_value is pd.NA + and not (not HAS_PYARROW and dtype.storage == "python") + ) ): + # TODO(infer_string) mark = pytest.mark.xfail( reason="The pointwise operation result will be inferred to " "string[nan, pyarrow], which does not match the input dtype" From 4ebd93b5284cb3f0d1ea5957267e5aa8b8dbd06c Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 19 Mar 2025 14:10:14 +0100 Subject: [PATCH 5/7] fix type annotation --- pandas/core/ops/invalid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/ops/invalid.py b/pandas/core/ops/invalid.py index 395db1617cb63..62aa79a881717 100644 --- a/pandas/core/ops/invalid.py +++ b/pandas/core/ops/invalid.py @@ -25,7 +25,7 @@ def invalid_comparison( left: ArrayLike, - right: ArrayLike | Scalar, + right: ArrayLike | list | Scalar, op: Callable[[Any, Any], bool], ) -> npt.NDArray[np.bool_]: """ From 51340a93d596176cef32cd1bb582b4f03dc6dd47 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 26 Mar 2025 09:20:14 +0100 Subject: 
[PATCH 6/7] re-add test with list --- pandas/tests/arrays/string_/test_string.py | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index 852783640ba36..ff947e353c14c 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -376,6 +376,33 @@ def test_comparison_methods_array(comparison_op, dtype, dtype2): # tm.assert_equal(result, result3) +def test_comparison_methods_list(comparison_op, dtype): + op_name = f"__{comparison_op.__name__}__" + + a = pd.array(["a", None, "c"], dtype=dtype) + other = [None, None, "c"] + result = comparison_op(a, other) + + # ensure operation is commutative + result2 = comparison_op(other, a) + tm.assert_equal(result, result2) + + if dtype.na_value is np.nan: + if operator.ne == comparison_op: + expected = np.array([True, True, False]) + else: + expected = np.array([False, False, False]) + expected[-1] = getattr(other[-1], op_name)(a[-1]) + tm.assert_numpy_array_equal(result, expected) + + else: + expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean" + expected = np.full(len(a), fill_value=None, dtype="object") + expected[-1] = getattr(other[-1], op_name)(a[-1]) + expected = pd.array(expected, dtype=expected_dtype) + tm.assert_extension_array_equal(result, expected) + + def test_constructor_raises(cls): if cls is pd.arrays.StringArray: msg = "StringArray requires a sequence of strings or pandas.NA" From e2bfe18e8e1bee0e5e9744d7f9eb83b598f0f477 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 26 Mar 2025 09:20:54 +0100 Subject: [PATCH 7/7] cleanup --- pandas/tests/arrays/string_/test_string.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index ff947e353c14c..975a539a79724 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ 
b/pandas/tests/arrays/string_/test_string.py @@ -370,11 +370,6 @@ def test_comparison_methods_array(comparison_op, dtype, dtype2): expected = pd.array(expected, dtype=expected_dtype) tm.assert_extension_array_equal(result, expected) - # # with list - # other = [None, None, "c"] - # result3 = getattr(a, op_name)(other) - # tm.assert_equal(result, result3) - def test_comparison_methods_list(comparison_op, dtype): op_name = f"__{comparison_op.__name__}__"