
BUG: groupby.agg with UDF changing pyarrow dtypes #59601

Open · wants to merge 45 commits into base: main from fix/group_by_agg_pyarrow_bool_numpy_same_type

Commits (45)
9faa460
Set preserve_dtype flag for bool type only when result is also bool
Apr 1, 2024
969d5b1
Update implementation to change type to pyarrow only
Apr 2, 2024
66114f3
Change import order
Apr 2, 2024
b0290ed
Convert numpy array to pandas representation of pyarrow array
Apr 3, 2024
20c8fa0
Add tests
Apr 3, 2024
97b3d54
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 3, 2024
932d737
Change pyarrow to optional import in agg_series() method
Apr 5, 2024
82ddeb5
Separate tests
Apr 5, 2024
d510052
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 5, 2024
62a31d9
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 8, 2024
a54bf58
Revert to old implementation
Apr 8, 2024
64330f0
Update implementation to use pyarrow array method
Apr 8, 2024
0647711
Update test_aggregate tests
Apr 8, 2024
affde38
Move pyarrow import to top of method
Apr 8, 2024
842f561
Update according to pr comments
Apr 12, 2024
93b5bf3
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 20, 2024
6f35c0e
Fallback convert to input dtype if output is all nan or empty array
Apr 20, 2024
abd0adf
Strip na values when inferring pyarrow dtype
Apr 20, 2024
bebc442
Update tests to check expected inferred dtype instead of input dtype
Apr 20, 2024
bb6343b
Override test case for test_arrow.py
Apr 21, 2024
3a3f2a2
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
Apr 21, 2024
6dc40f5
Empty commit to trigger build run
Apr 21, 2024
4ef96f7
In agg series, convert to np values, then cast to pyarrow dtype, acco…
Apr 23, 2024
c6a98c0
Update tests
Apr 23, 2024
9181eaf
Update rst docs
Apr 25, 2024
612d7d0
Update impl to fix tests
Apr 25, 2024
3b6696b
Declare variable in outer scope
Apr 25, 2024
680e238
Update impl to use maybe_cast_pointwise_result instead of maybe_cast…
Apr 29, 2024
3a8597e
Fix tests with nested array
Apr 29, 2024
6496b15
Update according to pr comments
May 2, 2024
712c36a
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
May 2, 2024
e1ccef6
Preserve_dtype if argument is passed in, else don't preserve
May 7, 2024
0ce083d
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
undermyumbrella1 May 7, 2024
a1d73f5
Update tests
May 7, 2024
57845a8
Merge branch 'fix/group_by_agg_pyarrow_bool_numpy_same_type' of githu…
May 7, 2024
fa257b0
Remove redundant tests
undermyumbrella1 May 12, 2024
0a9b83f
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
undermyumbrella1 May 12, 2024
139319a
retrigger pipeline
undermyumbrella1 May 12, 2024
9c2f9f2
Merge main
rhshadrach Aug 25, 2024
fef315d
Merge branch 'main' into fix/group_by_agg_pyarrow_bool_numpy_same_type
rhshadrach Oct 6, 2024
f758eb1
Merge branch 'main' of https://github.com/pandas-dev/pandas into fix/…
rhshadrach Mar 22, 2025
283eda9
Rework
rhshadrach Mar 22, 2025
d6edeff
Cleanup
rhshadrach Mar 22, 2025
b2e34fb
Fixup
rhshadrach Mar 22, 2025
9cbf339
More skips
rhshadrach Mar 22, 2025
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
@@ -772,6 +772,7 @@ Groupby/resample/rolling
- Bug in :meth:`.DataFrameGroupBy.quantile` when ``interpolation="nearest"`` is inconsistent with :meth:`DataFrame.quantile` (:issue:`47942`)
- Bug in :meth:`.Resampler.interpolate` on a :class:`DataFrame` with non-uniform sampling and/or indices not aligning with the resulting resampled index would result in wrong interpolation (:issue:`21351`)
- Bug in :meth:`DataFrame.ewm` and :meth:`Series.ewm` when passed ``times`` and aggregation functions other than mean (:issue:`51695`)
- Bug in :meth:`DataFrameGroupBy.agg` and :meth:`SeriesGroupBy.agg` that was returning numpy dtype values when input values are pyarrow dtype values, instead of returning pyarrow dtype values. (:issue:`53030`)
- Bug in :meth:`DataFrameGroupBy.agg` that raises ``AttributeError`` when there is dictionary input and duplicated columns, instead of returning a DataFrame with the aggregation of all duplicate columns. (:issue:`55041`)
- Bug in :meth:`DataFrameGroupBy.apply` and :meth:`SeriesGroupBy.apply` for empty data frame with ``group_keys=False`` still creating output index using group keys. (:issue:`60471`)
- Bug in :meth:`DataFrameGroupBy.apply` that was returning a completely empty DataFrame when all return values of ``func`` were ``None`` instead of returning an empty DataFrame with the original columns and dtypes. (:issue:`57775`)
25 changes: 18 additions & 7 deletions pandas/core/groupby/ops.py
@@ -50,6 +50,7 @@
 )

 from pandas.core.arrays import Categorical
+from pandas.core.arrays.arrow.array import ArrowExtensionArray
 from pandas.core.frame import DataFrame
 from pandas.core.groupby import grouper
 from pandas.core.indexes.api import (
@@ -954,18 +955,28 @@ def agg_series(
         -------
         np.ndarray or ExtensionArray
         """
-        if not isinstance(obj._values, np.ndarray):
+        result = self._aggregate_series_pure_python(obj, func)
+        npvalues = lib.maybe_convert_objects(result, try_float=False)
+
+        if isinstance(obj._values, ArrowExtensionArray):
+            from pandas.core.dtypes.common import is_string_dtype
+
+            # When obj.dtype is a string, any object can be cast. Only do so if the
+            # UDF returned strings or NA values.
+            if not is_string_dtype(obj.dtype) or is_string_dtype(
+                npvalues[~isna(npvalues)]
+            ):
+                out = maybe_cast_pointwise_result(
+                    npvalues, obj.dtype, numeric_only=True, same_dtype=preserve_dtype
+                )
+            else:
+                out = npvalues
+        elif not isinstance(obj._values, np.ndarray):
             # we can preserve a little bit more aggressively with EA dtype
             # because maybe_cast_pointwise_result will do a try/except
             # with _from_sequence. NB we are assuming here that _from_sequence
             # is sufficiently strict that it casts appropriately.
-            preserve_dtype = True
-
-        result = self._aggregate_series_pure_python(obj, func)
-
-        npvalues = lib.maybe_convert_objects(result, try_float=False)
-        if preserve_dtype:
             out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
         else:
             out = npvalues
98 changes: 98 additions & 0 deletions pandas/tests/groupby/aggregate/test_aggregate.py
@@ -10,6 +10,7 @@
 import pytest

 from pandas.errors import SpecificationError
+import pandas.util._test_decorators as td

 from pandas.core.dtypes.common import is_integer_dtype

@@ -23,6 +24,7 @@
     to_datetime,
 )
 import pandas._testing as tm
+from pandas.arrays import ArrowExtensionArray
 from pandas.core.groupby.grouper import Grouping


@@ -1807,3 +1809,99 @@ def test_groupby_aggregation_func_list_multi_index_duplicate_columns():
index=Index(["level1.1", "level1.2"]),
)
tm.assert_frame_equal(result, expected)


@td.skip_if_no("pyarrow")
@pytest.mark.parametrize(
"input_dtype, output_dtype",
[
# With NumPy arrays, the results from the UDF would be e.g. np.float32 scalars
# which we can therefore preserve. However with PyArrow arrays, the results are
# Python scalars so we have no information about size or uint vs int.
("float[pyarrow]", "double[pyarrow]"),
("int64[pyarrow]", "int64[pyarrow]"),
("uint64[pyarrow]", "int64[pyarrow]"),
("bool[pyarrow]", "bool[pyarrow]"),
],
)
def test_agg_lambda_pyarrow_dtype_conversion(input_dtype, output_dtype):
# GH#59601
# Test PyArrow dtype conversion back to PyArrow dtype
df = DataFrame(
{
"A": ["c1", "c2", "c3", "c1", "c2", "c3"],
"B": pd.array([100, 200, 255, 0, 199, 40392], dtype=input_dtype),
}
)
gb = df.groupby("A")
result = gb.agg(lambda x: x.min())

expected = DataFrame(
{"B": pd.array([0, 199, 255], dtype=output_dtype)},
index=Index(["c1", "c2", "c3"], name="A"),
)
tm.assert_frame_equal(result, expected)


@td.skip_if_no("pyarrow")
def test_agg_lambda_complex128_dtype_conversion():
# GH#59601
df = DataFrame(
{"A": ["c1", "c2", "c3"], "B": pd.array([100, 200, 255], "int64[pyarrow]")}
)
gb = df.groupby("A")
result = gb.agg(lambda x: complex(x.sum(), x.count()))

expected = DataFrame(
{
"B": pd.array(
[complex(100, 1), complex(200, 1), complex(255, 1)], dtype="complex128"
),
},
index=Index(["c1", "c2", "c3"], name="A"),
)
tm.assert_frame_equal(result, expected)


@td.skip_if_no("pyarrow")
def test_agg_lambda_numpy_uint64_to_pyarrow_dtype_conversion():
# GH#59601
df = DataFrame(
{
"A": ["c1", "c2", "c3"],
"B": pd.array([100, 200, 255], dtype="uint64[pyarrow]"),
}
)
gb = df.groupby("A")
result = gb.agg(lambda x: np.uint64(x.sum()))

expected = DataFrame(
{
"B": pd.array([100, 200, 255], dtype="uint64[pyarrow]"),
},
index=Index(["c1", "c2", "c3"], name="A"),
)
tm.assert_frame_equal(result, expected)


@td.skip_if_no("pyarrow")
def test_agg_lambda_pyarrow_struct_to_object_dtype_conversion():
# GH#59601
import pyarrow as pa

df = DataFrame(
{
"A": ["c1", "c2", "c3"],
"B": pd.array([100, 200, 255], dtype="int64[pyarrow]"),
}
)
gb = df.groupby("A")
result = gb.agg(lambda x: {"number": 1})

arr = pa.array([{"number": 1}, {"number": 1}, {"number": 1}])
expected = DataFrame(
{"B": ArrowExtensionArray(arr)},
index=Index(["c1", "c2", "c3"], name="A"),
)
Review comment (Member, Author) on lines +1899 to +1905:

When the column starts as a PyArrow dtype and the UDF returns dictionaries, it seems questionable to me whether we should return the corresponding PyArrow dtype. The other option is a NumPy array of object dtype. But both seem like reasonable results and I imagine the PyArrow is likely to be more convenient for the user who is using PyArrow dtypes.
tm.assert_frame_equal(result, expected)
13 changes: 8 additions & 5 deletions pandas/tests/groupby/test_groupby.py
@@ -2434,25 +2434,28 @@ def test_rolling_wrong_param_min_period():

 def test_by_column_values_with_same_starting_value(any_string_dtype):
     # GH29635
+    dtype = any_string_dtype
     df = DataFrame(
         {
             "Name": ["Thomas", "Thomas", "Thomas John"],
             "Credit": [1200, 1300, 900],
-            "Mood": Series(["sad", "happy", "happy"], dtype=any_string_dtype),
+            "Mood": Series(["sad", "happy", "happy"], dtype=dtype),
         }
     )
     aggregate_details = {"Mood": Series.mode, "Credit": "sum"}

     result = df.groupby(["Name"]).agg(aggregate_details)
-    expected_result = DataFrame(
+    expected = DataFrame(
         {
             "Mood": [["happy", "sad"], "happy"],
             "Credit": [2500, 900],
             "Name": ["Thomas", "Thomas John"],
-        }
+        },
     ).set_index("Name")
-
-    tm.assert_frame_equal(result, expected_result)
+    if getattr(dtype, "storage", None) == "pyarrow":
+        mood_values = pd.array(["happy", "sad"], dtype=dtype)
+        expected["Mood"] = [mood_values, "happy"]
+    tm.assert_frame_equal(result, expected)


 def test_groupby_none_in_first_mi_level():
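The `getattr(dtype, "storage", None)` probe in the updated test works because pandas `StringDtype` exposes a `storage` attribute while plain NumPy dtypes do not. A quick illustration (using the "python" storage so no pyarrow is needed):

```python
import numpy as np
import pandas as pd

# StringDtype records its backing storage ("python" or "pyarrow").
dt = pd.StringDtype(storage="python")
print(dt.storage)  # python

# A plain NumPy dtype has no such attribute, so getattr with a
# None default is the safe way to branch on it.
print(getattr(np.dtype(object), "storage", None))  # None
```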