Skip to content

Commit 96c150a

Browse files
authored
test: test most relevant dtype for aggregates (#595)
* fix: keep most relevant dtype for aggregates * add aggregate tests for bool result * refactor and reuse dtypes.lcd_dtype * check_dtype=False
1 parent 8d2a51c commit 96c150a

File tree

2 files changed

+119
-23
lines changed

2 files changed

+119
-23
lines changed

bigframes/dtypes.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -658,10 +658,14 @@ def is_compatible(scalar: typing.Any, dtype: Dtype) -> typing.Optional[Dtype]:
658658
return None
659659

660660

661-
def lcd_type(dtype1: Dtype, dtype2: Dtype) -> Dtype:
662-
"""Get the supertype of the two types."""
663-
if dtype1 == dtype2:
664-
return dtype1
661+
def lcd_type(*dtypes: Dtype) -> Dtype:
662+
if len(dtypes) < 1:
663+
raise ValueError("at least one dypes should be provided")
664+
if len(dtypes) == 1:
665+
return dtypes[0]
666+
unique_dtypes = set(dtypes)
667+
if len(unique_dtypes) == 1:
668+
return unique_dtypes.pop()
665669
# Implicit conversion currently only supported for numeric types
666670
hierarchy: list[Dtype] = [
667671
pd.BooleanDtype(),
@@ -670,9 +674,9 @@ def lcd_type(dtype1: Dtype, dtype2: Dtype) -> Dtype:
670674
pd.ArrowDtype(pa.decimal256(76, 38)),
671675
pd.Float64Dtype(),
672676
]
673-
if (dtype1 not in hierarchy) or (dtype2 not in hierarchy):
677+
if any([dtype not in hierarchy for dtype in dtypes]):
674678
return None
675-
lcd_index = max(hierarchy.index(dtype1), hierarchy.index(dtype2))
679+
lcd_index = max([hierarchy.index(dtype) for dtype in dtypes])
676680
return hierarchy[lcd_index]
677681

678682

tests/system/small/test_dataframe.py

+109-17
Original file line numberDiff line numberDiff line change
@@ -2390,12 +2390,27 @@ def test_dataframe_pct_change(scalars_df_index, scalars_pandas_df_index, periods
23902390
def test_dataframe_agg_single_string(scalars_dfs):
23912391
numeric_cols = ["int64_col", "int64_too", "float64_col"]
23922392
scalars_df, scalars_pandas_df = scalars_dfs
2393+
23932394
bf_result = scalars_df[numeric_cols].agg("sum").to_pandas()
23942395
pd_result = scalars_pandas_df[numeric_cols].agg("sum")
23952396

2396-
# Pandas may produce narrower numeric types, but bigframes always produces Float64
2397-
pd_result = pd_result.astype("Float64")
2398-
pd.testing.assert_series_equal(pd_result, bf_result, check_index_type=False)
2397+
assert bf_result.dtype == "Float64"
2398+
pd.testing.assert_series_equal(
2399+
pd_result, bf_result, check_dtype=False, check_index_type=False
2400+
)
2401+
2402+
2403+
def test_dataframe_agg_int_single_string(scalars_dfs):
2404+
numeric_cols = ["int64_col", "int64_too", "bool_col"]
2405+
scalars_df, scalars_pandas_df = scalars_dfs
2406+
2407+
bf_result = scalars_df[numeric_cols].agg("sum").to_pandas()
2408+
pd_result = scalars_pandas_df[numeric_cols].agg("sum")
2409+
2410+
assert bf_result.dtype == "Int64"
2411+
pd.testing.assert_series_equal(
2412+
pd_result, bf_result, check_dtype=False, check_index_type=False
2413+
)
23992414

24002415

24012416
def test_dataframe_agg_multi_string(scalars_dfs):
@@ -2431,6 +2446,27 @@ def test_dataframe_agg_multi_string(scalars_dfs):
24312446
).all()
24322447

24332448

2449+
def test_dataframe_agg_int_multi_string(scalars_dfs):
2450+
numeric_cols = ["int64_col", "int64_too", "bool_col"]
2451+
aggregations = [
2452+
"sum",
2453+
"nunique",
2454+
"count",
2455+
]
2456+
scalars_df, scalars_pandas_df = scalars_dfs
2457+
bf_result = scalars_df[numeric_cols].agg(aggregations).to_pandas()
2458+
pd_result = scalars_pandas_df[numeric_cols].agg(aggregations)
2459+
2460+
for dtype in bf_result.dtypes:
2461+
assert dtype == "Int64"
2462+
2463+
# Pandas may produce narrower numeric types
2464+
# Pandas has object index type
2465+
pd.testing.assert_frame_equal(
2466+
pd_result, bf_result, check_dtype=False, check_index_type=False
2467+
)
2468+
2469+
24342470
@skip_legacy_pandas
24352471
def test_df_describe(scalars_dfs):
24362472
scalars_df, scalars_pandas_df = scalars_dfs
@@ -2982,6 +3018,58 @@ def test_loc_setitem_bool_series_scalar_error(scalars_dfs):
29823018
pd_df.loc[pd_df["int64_too"] == 1, "string_col"] = 99
29833019

29843020

3021+
@pytest.mark.parametrize(
3022+
("col", "op"),
3023+
[
3024+
# Int aggregates
3025+
pytest.param("int64_col", lambda x: x.sum(), id="int-sum"),
3026+
pytest.param("int64_col", lambda x: x.min(), id="int-min"),
3027+
pytest.param("int64_col", lambda x: x.max(), id="int-max"),
3028+
pytest.param("int64_col", lambda x: x.count(), id="int-count"),
3029+
pytest.param("int64_col", lambda x: x.nunique(), id="int-nunique"),
3030+
# Float aggregates
3031+
pytest.param("float64_col", lambda x: x.count(), id="float-count"),
3032+
pytest.param("float64_col", lambda x: x.nunique(), id="float-nunique"),
3033+
# Bool aggregates
3034+
pytest.param("bool_col", lambda x: x.sum(), id="bool-sum"),
3035+
pytest.param("bool_col", lambda x: x.count(), id="bool-count"),
3036+
pytest.param("bool_col", lambda x: x.nunique(), id="bool-nunique"),
3037+
# String aggregates
3038+
pytest.param("string_col", lambda x: x.count(), id="string-count"),
3039+
pytest.param("string_col", lambda x: x.nunique(), id="string-nunique"),
3040+
],
3041+
)
3042+
def test_dataframe_aggregate_int(scalars_df_index, scalars_pandas_df_index, col, op):
3043+
bf_result = op(scalars_df_index[[col]]).to_pandas()
3044+
pd_result = op(scalars_pandas_df_index[[col]])
3045+
3046+
# Check dtype separately
3047+
assert bf_result.dtype == "Int64"
3048+
3049+
# Pandas may produce narrower numeric types
3050+
# Pandas has object index type
3051+
assert_series_equal(pd_result, bf_result, check_dtype=False, check_index_type=False)
3052+
3053+
3054+
@pytest.mark.parametrize(
3055+
("col", "op"),
3056+
[
3057+
pytest.param("bool_col", lambda x: x.min(), id="bool-min"),
3058+
pytest.param("bool_col", lambda x: x.max(), id="bool-max"),
3059+
],
3060+
)
3061+
def test_dataframe_aggregate_bool(scalars_df_index, scalars_pandas_df_index, col, op):
3062+
bf_result = op(scalars_df_index[[col]]).to_pandas()
3063+
pd_result = op(scalars_pandas_df_index[[col]])
3064+
3065+
# Check dtype separately
3066+
assert bf_result.dtype == "boolean"
3067+
3068+
# Pandas may produce narrower numeric types
3069+
# Pandas has object index type
3070+
assert_series_equal(pd_result, bf_result, check_dtype=False, check_index_type=False)
3071+
3072+
29853073
@pytest.mark.parametrize(
29863074
("ordered"),
29873075
[
@@ -2990,34 +3078,38 @@ def test_loc_setitem_bool_series_scalar_error(scalars_dfs):
29903078
],
29913079
)
29923080
@pytest.mark.parametrize(
2993-
("op"),
3081+
("op", "bf_dtype"),
29943082
[
2995-
(lambda x: x.sum(numeric_only=True)),
2996-
(lambda x: x.mean(numeric_only=True)),
2997-
(lambda x: x.min(numeric_only=True)),
2998-
(lambda x: x.max(numeric_only=True)),
2999-
(lambda x: x.std(numeric_only=True)),
3000-
(lambda x: x.var(numeric_only=True)),
3001-
(lambda x: x.count(numeric_only=False)),
3002-
(lambda x: x.nunique()),
3083+
(lambda x: x.sum(numeric_only=True), "Float64"),
3084+
(lambda x: x.mean(numeric_only=True), "Float64"),
3085+
(lambda x: x.min(numeric_only=True), "Float64"),
3086+
(lambda x: x.max(numeric_only=True), "Float64"),
3087+
(lambda x: x.std(numeric_only=True), "Float64"),
3088+
(lambda x: x.var(numeric_only=True), "Float64"),
3089+
(lambda x: x.count(numeric_only=False), "Int64"),
3090+
(lambda x: x.nunique(), "Int64"),
30033091
],
30043092
ids=["sum", "mean", "min", "max", "std", "var", "count", "nunique"],
30053093
)
3006-
def test_dataframe_aggregates(scalars_df_index, scalars_pandas_df_index, op, ordered):
3094+
def test_dataframe_aggregates(
3095+
scalars_df_index, scalars_pandas_df_index, op, bf_dtype, ordered
3096+
):
30073097
col_names = ["int64_too", "float64_col", "string_col", "int64_col", "bool_col"]
30083098
bf_series = op(scalars_df_index[col_names])
3009-
pd_series = op(scalars_pandas_df_index[col_names])
30103099
bf_result = bf_series.to_pandas(ordered=ordered)
3100+
pd_result = op(scalars_pandas_df_index[col_names])
3101+
3102+
# Check dtype separately
3103+
assert bf_result.dtype == bf_dtype
30113104

30123105
# Pandas may produce narrower numeric types, but bigframes always produces Float64
30133106
# Pandas has object index type
3014-
pd_series.index = pd_series.index.astype(pd.StringDtype(storage="pyarrow"))
30153107
assert_series_equal(
3016-
pd_series,
3108+
pd_result,
30173109
bf_result,
3110+
check_dtype=False,
30183111
check_index_type=False,
30193112
ignore_order=not ordered,
3020-
check_dtype=False,
30213113
)
30223114

30233115

0 commit comments

Comments
 (0)