Skip to content

Commit 680e238

Browse files
author
Kei
committed
Update impl to use maybe_cast_pointwise_result instead of maybe_cast_to_pyarrow_array
1 parent 3b6696b commit 680e238

File tree

2 files changed

+13
-83
lines changed

2 files changed

+13
-83
lines changed

pandas/core/dtypes/cast.py

+1-74
Original file line numberDiff line numberDiff line change
@@ -478,40 +478,6 @@ def maybe_cast_pointwise_result(
478478
return result
479479

480480

481-
def maybe_cast_to_pyarrow_result(result: ArrayLike, obj_dtype: DtypeObj) -> ArrayLike:
482-
"""
483-
Try casting result of a pointwise operation to its pyarrow dtype
484-
and arrow extension array if appropriate. If not possible,
485-
returns np.ndarray.
486-
487-
Parameters
488-
----------
489-
result : array-like
490-
Result to cast.
491-
492-
Returns
493-
-------
494-
result : array-like
495-
result maybe casted to the dtype.
496-
"""
497-
from pandas.core.construction import array as pd_array
498-
499-
# maybe_convert_objects is unable to detect NA as nan
500-
# (detects it as object instead)
501-
stripped_result = result[~isna(result)]
502-
npvalues = lib.maybe_convert_objects(stripped_result, try_float=False)
503-
504-
if stripped_result.size == 0:
505-
return maybe_cast_pointwise_result(npvalues, obj_dtype, numeric_only=True)
506-
507-
try:
508-
dtype = convert_dtypes(npvalues, dtype_backend="pyarrow")
509-
out = pd_array(result, dtype=dtype)
510-
except (TypeError, ValueError, np.ComplexWarning):
511-
out = npvalues
512-
return out
513-
514-
515481
def _maybe_cast_to_extension_array(
516482
cls: type[ExtensionArray], obj: ArrayLike, dtype: ExtensionDtype | None = None
517483
) -> ArrayLike:
@@ -1061,7 +1027,6 @@ def convert_dtypes(
10611027
np.dtype, or ExtensionDtype
10621028
"""
10631029
inferred_dtype: str | DtypeObj
1064-
orig_inferred_dtype = None
10651030

10661031
if (
10671032
convert_string or convert_integer or convert_boolean or convert_floating
@@ -1070,7 +1035,6 @@ def convert_dtypes(
10701035
inferred_dtype = lib.infer_dtype(input_array)
10711036
else:
10721037
inferred_dtype = input_array.dtype
1073-
orig_inferred_dtype = inferred_dtype
10741038

10751039
if is_string_dtype(inferred_dtype):
10761040
if not convert_string or inferred_dtype == "bytes":
@@ -1168,8 +1132,7 @@ def convert_dtypes(
11681132
elif isinstance(inferred_dtype, StringDtype):
11691133
base_dtype = np.dtype(str)
11701134
else:
1171-
base_dtype = _infer_pyarrow_dtype(input_array, orig_inferred_dtype)
1172-
1135+
base_dtype = inferred_dtype
11731136
if (
11741137
base_dtype.kind == "O" # type: ignore[union-attr]
11751138
and input_array.size > 0
@@ -1180,10 +1143,8 @@ def convert_dtypes(
11801143
pa_type = pa.null()
11811144
else:
11821145
pa_type = to_pyarrow_type(base_dtype)
1183-
11841146
if pa_type is not None:
11851147
inferred_dtype = ArrowDtype(pa_type)
1186-
11871148
elif dtype_backend == "numpy_nullable" and isinstance(inferred_dtype, ArrowDtype):
11881149
# GH 53648
11891150
inferred_dtype = _arrow_dtype_mapping()[inferred_dtype.pyarrow_dtype]
@@ -1193,40 +1154,6 @@ def convert_dtypes(
11931154
return inferred_dtype # type: ignore[return-value]
11941155

11951156

1196-
def _infer_pyarrow_dtype(
1197-
input_array: ArrayLike,
1198-
inferred_dtype: str,
1199-
) -> DtypeObj:
1200-
import pyarrow as pa
1201-
1202-
if inferred_dtype == "date":
1203-
return ArrowDtype(pa.date32())
1204-
elif inferred_dtype == "time":
1205-
return ArrowDtype(pa.time64("us"))
1206-
elif inferred_dtype == "bytes":
1207-
return ArrowDtype(pa.binary())
1208-
elif inferred_dtype == "decimal":
1209-
from pyarrow import (
1210-
ArrowInvalid,
1211-
ArrowMemoryError,
1212-
ArrowNotImplementedError,
1213-
)
1214-
1215-
try:
1216-
pyarrow_array = pa.array(input_array)
1217-
return ArrowDtype(pyarrow_array.type)
1218-
except (
1219-
TypeError,
1220-
ValueError,
1221-
ArrowInvalid,
1222-
ArrowMemoryError,
1223-
ArrowNotImplementedError,
1224-
):
1225-
return input_array.dtype
1226-
1227-
return input_array.dtype
1228-
1229-
12301157
def maybe_infer_to_datetimelike(
12311158
value: npt.NDArray[np.object_],
12321159
) -> np.ndarray | DatetimeArray | TimedeltaArray | PeriodArray | IntervalArray:

pandas/core/groupby/ops.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
from pandas.core.dtypes.cast import (
3838
maybe_cast_pointwise_result,
39-
maybe_cast_to_pyarrow_result,
4039
maybe_downcast_to_dtype,
4140
)
4241
from pandas.core.dtypes.common import (
@@ -46,6 +45,7 @@
4645
ensure_uint64,
4746
is_1d_only_ea_dtype,
4847
)
48+
from pandas.core.dtypes.dtypes import ArrowDtype
4949
from pandas.core.dtypes.missing import (
5050
isna,
5151
maybe_fill,
@@ -917,21 +917,24 @@ def agg_series(
917917
"""
918918

919919
result = self._aggregate_series_pure_python(obj, func)
920+
npvalues = lib.maybe_convert_objects(result, try_float=False)
920921

921922
if isinstance(obj._values, ArrowExtensionArray):
922-
return maybe_cast_to_pyarrow_result(result, obj.dtype)
923+
out = maybe_cast_pointwise_result(
924+
npvalues, obj.dtype, numeric_only=True, same_dtype=False
925+
)
926+
import pyarrow as pa
927+
928+
if isinstance(out.dtype, ArrowDtype) and pa.types.is_struct(
929+
out.dtype.pyarrow_dtype
930+
):
931+
out = npvalues
923932

924-
if not isinstance(obj._values, np.ndarray) and not isinstance(
925-
obj._values, ArrowExtensionArray
926-
):
933+
elif not isinstance(obj._values, np.ndarray):
927934
# we can preserve a little bit more aggressively with EA dtype
928935
# because maybe_cast_pointwise_result will do a try/except
929936
# with _from_sequence. NB we are assuming here that _from_sequence
930937
# is sufficiently strict that it casts appropriately.
931-
preserve_dtype = True
932-
933-
npvalues = lib.maybe_convert_objects(result, try_float=False)
934-
if preserve_dtype:
935938
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
936939
else:
937940
out = npvalues

0 commit comments

Comments
 (0)