From 8a431f7aa75f63632d43068d5ba4fe37447aaed7 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 21 Apr 2023 18:34:50 +0200 Subject: [PATCH] BUG: fix setitem with enlargment with pyarrow Scalar --- pandas/_libs/lib.pyx | 27 +++++++++++++++++++++++++++ pandas/core/dtypes/missing.py | 6 ++++++ pandas/core/indexing.py | 12 ++++++++++-- pandas/tests/extension/test_arrow.py | 23 +++++++++++++++++++++++ 4 files changed, 66 insertions(+), 2 deletions(-) diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index 5bf99301d9261..5b839b4f56446 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -145,6 +145,32 @@ i8max = INT64_MAX u8max = UINT64_MAX +cdef bint PYARROW_INSTALLED = False + +try: + import pyarrow as pa + + PYARROW_INSTALLED = True +except ImportError: + pa = None + + +cpdef is_pyarrow_array(obj): + if PYARROW_INSTALLED: + return isinstance(obj, (pa.Array, pa.ChunkedArray)) + return False + + +cpdef is_pyarrow_scalar(obj): + if PYARROW_INSTALLED: + return isinstance(obj, pa.Scalar) + return False + + +def is_pyarrow_installed(): + return PYARROW_INSTALLED + + @cython.wraparound(False) @cython.boundscheck(False) def memory_usage_of_objects(arr: object[:]) -> int64_t: @@ -238,6 +264,7 @@ def is_scalar(val: object) -> bool: # Note: PyNumber_Check check includes Decimal, Fraction, numbers.Number return (PyNumber_Check(val) + or is_pyarrow_scalar(val) or is_period_object(val) or is_interval(val) or is_offset_object(val)) diff --git a/pandas/core/dtypes/missing.py b/pandas/core/dtypes/missing.py index 718404f0799e4..5fad49301f8c2 100644 --- a/pandas/core/dtypes/missing.py +++ b/pandas/core/dtypes/missing.py @@ -690,6 +690,12 @@ def is_valid_na_for_dtype(obj, dtype: DtypeObj) -> bool: """ if not lib.is_scalar(obj) or not isna(obj): return False + elif lib.is_pyarrow_scalar(obj): + return ( + obj.is_null() + and hasattr(dtype, "pyarrow_dtype") + and dtype.pyarrow_dtype == obj.type + ) elif dtype.kind == "M": if isinstance(dtype, np.dtype): # i.e. not tzaware diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index 6d5daf5025c49..850e87350ea04 100644 --- a/pandas/core/indexing.py +++ b/pandas/core/indexing.py @@ -15,6 +15,7 @@ from pandas._config import using_copy_on_write +from pandas._libs import lib from pandas._libs.indexing import NDFrameIndexerBase from pandas._libs.lib import item_from_zerodim from pandas.compat import PYPY @@ -2098,8 +2099,15 @@ def _setitem_with_indexer_missing(self, indexer, value): # We should not cast, if we have object dtype because we can # set timedeltas into object series curr_dtype = self.obj.dtype - curr_dtype = getattr(curr_dtype, "numpy_dtype", curr_dtype) - new_dtype = maybe_promote(curr_dtype, value)[0] + if lib.is_pyarrow_scalar(value) and hasattr( + curr_dtype, "pyarrow_dtype" + ): + # TODO promote arrow scalar and type + new_dtype = curr_dtype + value = value.as_py() + else: + curr_dtype = getattr(curr_dtype, "numpy_dtype", curr_dtype) + new_dtype = maybe_promote(curr_dtype, value)[0] else: new_dtype = None diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 3ecbc723be2eb..1d97a557c5db6 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2418,3 +2418,26 @@ def test_describe_numeric_data(pa_type): index=["count", "mean", "std", "min", "25%", "50%", "75%", "max"], ) tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "value, target_value, dtype", + [ + (pa.scalar(4, type="int32"), 4, "int32[pyarrow]"), + (pa.scalar(4, type="int64"), 4, "int32[pyarrow]"), + # (pa.scalar(4.5, type="float64"), 4, "int32[pyarrow]"), + (4, 4, "int32[pyarrow]"), + (pd.NA, None, "int32[pyarrow]"), + (None, None, "int32[pyarrow]"), + (pa.scalar(None, type="int32"), None, "int32[pyarrow]"), + (pa.scalar(None, type="int64"), None, "int32[pyarrow]"), + ], +) +def test_series_setitem_with_enlargement(value, target_value, dtype): + # GH-52235 + # similar to series/inedexing/test_setitem.py::test_setitem_keep_precision + # and test_setitem_enlarge_with_na, but for arrow dtypes + ser = pd.Series([1, 2, 3], dtype=dtype) + ser[3] = value + expected = pd.Series([1, 2, 3, target_value], dtype=dtype) + tm.assert_series_equal(ser, expected)