Skip to content

Commit 54333d1

Browse files
committed
add support to pandas nullable dtypes
1 parent 154577c commit 54333d1

File tree

5 files changed

+85
-13
lines changed

5 files changed

+85
-13
lines changed

hypothesis-python/RELEASE.rst

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
RELEASE_TYPE: minor
2+
3+
This release adds support for `nullable pandas dtypes <https://pandas.pydata.org/docs/user_guide/integer_na.html>`__
4+
in :func:`~hypothesis.extra.pandas` (:issue:`3604`).
5+
Thanks to Cheuk Ting Ho for implementing this at the PyCon sprints!

hypothesis-python/src/hypothesis/extra/pandas/impl.py

+47-8
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ def is_categorical_dtype(dt):
4444
return dt == "category"
4545

4646

47+
try:
48+
from pandas.core.arrays.integer import IntegerDtype
49+
except ImportError:
50+
IntegerDtype = ()
51+
52+
4753
def dtype_for_elements_strategy(s):
4854
return st.shared(
4955
s.map(lambda x: pandas.Series([x]).dtype),
@@ -79,6 +85,12 @@ def elements_and_dtype(elements, dtype, source=None):
7985
f"{prefix}dtype is categorical, which is currently unsupported"
8086
)
8187

88+
if isinstance(dtype, type) and issubclass(dtype, IntegerDtype):
89+
raise InvalidArgument(
90+
f"Passed dtype={dtype!r} is a dtype class, please pass in an instance of this class."
91+
"Otherwise it would be treated as dtype=object"
92+
)
93+
8294
if isinstance(dtype, type) and np.dtype(dtype).kind == "O" and dtype is not object:
8395
note_deprecation(
8496
f"Passed dtype={dtype!r} is not a valid Pandas dtype. We'll treat it as "
@@ -92,13 +104,31 @@ def elements_and_dtype(elements, dtype, source=None):
92104
f"Passed dtype={dtype!r} is a strategy, but we require a concrete dtype "
93105
"here. See https://stackoverflow.com/q/74355937 for workaround patterns."
94106
)
95-
dtype = try_convert(np.dtype, dtype, "dtype")
107+
108+
pd_dtype_map = {
109+
t.name: t for t in getattr(IntegerDtype, "__subclasses__", lambda: [])()
110+
}
111+
112+
dtype = pd_dtype_map.get(dtype, dtype)
113+
114+
if isinstance(dtype, IntegerDtype):
115+
is_na_dtype = True
116+
dtype = np.dtype(dtype.name.lower())
117+
elif dtype is not None:
118+
is_na_dtype = False
119+
dtype = try_convert(np.dtype, dtype, "dtype")
120+
else:
121+
is_na_dtype = False
96122

97123
if elements is None:
98124
elements = npst.from_dtype(dtype)
125+
if is_na_dtype:
126+
elements = st.none() | elements
99127
elif dtype is not None:
100128

101129
def convert_element(value):
130+
if value is None:
131+
return None
102132
name = f"draw({prefix}elements)"
103133
try:
104134
return np.array([value], dtype=dtype)[0]
@@ -282,9 +312,17 @@ def series(
282312
else:
283313
check_strategy(index, "index")
284314

285-
elements, dtype = elements_and_dtype(elements, dtype)
315+
elements, np_dtype = elements_and_dtype(elements, dtype)
286316
index_strategy = index
287317

318+
# if it is converted to an object, use object for series type
319+
if (
320+
np_dtype is not None
321+
and np_dtype.kind == "O"
322+
and not isinstance(dtype, IntegerDtype)
323+
):
324+
dtype = np_dtype
325+
288326
@st.composite
289327
def result(draw):
290328
index = draw(index_strategy)
@@ -293,13 +331,13 @@ def result(draw):
293331
if dtype is not None:
294332
result_data = draw(
295333
npst.arrays(
296-
dtype=dtype,
334+
dtype=object,
297335
elements=elements,
298336
shape=len(index),
299337
fill=fill,
300338
unique=unique,
301339
)
302-
)
340+
).tolist()
303341
else:
304342
result_data = list(
305343
draw(
@@ -310,9 +348,8 @@ def result(draw):
310348
fill=fill,
311349
unique=unique,
312350
)
313-
)
351+
).tolist()
314352
)
315-
316353
return pandas.Series(result_data, index=index, dtype=dtype, name=draw(name))
317354
else:
318355
return pandas.Series(
@@ -549,7 +586,7 @@ def row():
549586

550587
column_names.add(c.name)
551588

552-
c.elements, c.dtype = elements_and_dtype(c.elements, c.dtype, label)
589+
c.elements, _ = elements_and_dtype(c.elements, c.dtype, label)
553590

554591
if c.dtype is None and rows is not None:
555592
raise InvalidArgument(
@@ -589,7 +626,9 @@ def just_draw_columns(draw):
589626
if columns_without_fill:
590627
for c in columns_without_fill:
591628
data[c.name] = pandas.Series(
592-
np.zeros(shape=len(index), dtype=c.dtype), index=index
629+
np.zeros(shape=len(index), dtype=object),
630+
index=index,
631+
dtype=c.dtype,
593632
)
594633
seen = {c.name: set() for c in columns_without_fill if c.unique}
595634

hypothesis-python/tests/pandas/test_argument_validation.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
from datetime import datetime
1212

1313
import pandas as pd
14+
import pytest
1415

1516
from hypothesis import given, strategies as st
17+
from hypothesis.errors import InvalidArgument
1618
from hypothesis.extra import pandas as pdst
1719

1820
from tests.common.arguments import argument_validation_test, e
21+
from tests.common.debug import find_any
1922
from tests.common.utils import checks_deprecated_behaviour
2023

2124
BAD_ARGS = [
@@ -30,7 +33,6 @@
3033
e(pdst.data_frames, pdst.columns(1, dtype=float, elements=1)),
3134
e(pdst.data_frames, pdst.columns(1, fill=1, dtype=float)),
3235
e(pdst.data_frames, pdst.columns(["A", "A"], dtype=float)),
33-
e(pdst.data_frames, pdst.columns(1, elements=st.none(), dtype=int)),
3436
e(pdst.data_frames, 1),
3537
e(pdst.data_frames, [1]),
3638
e(pdst.data_frames, pdst.columns(1, dtype="category")),
@@ -64,7 +66,6 @@
6466
e(pdst.indexes, dtype="not a dtype"),
6567
e(pdst.indexes, elements="not a strategy"),
6668
e(pdst.indexes, elements=st.text(), dtype=float),
67-
e(pdst.indexes, elements=st.none(), dtype=int),
6869
e(pdst.indexes, elements=st.integers(0, 10), dtype=st.sampled_from([int, float])),
6970
e(pdst.indexes, dtype=int, max_size=0, min_size=1),
7071
e(pdst.indexes, dtype=int, unique="true"),
@@ -77,7 +78,6 @@
7778
e(pdst.series),
7879
e(pdst.series, dtype="not a dtype"),
7980
e(pdst.series, elements="not a strategy"),
80-
e(pdst.series, elements=st.none(), dtype=int),
8181
e(pdst.series, dtype="category"),
8282
e(pdst.series, index="not a strategy"),
8383
]
@@ -99,3 +99,11 @@ def test_timestamp_as_datetime_bounds(dt):
9999
@checks_deprecated_behaviour
100100
def test_confusing_object_dtype_aliases():
101101
pdst.series(elements=st.tuples(st.integers()), dtype=tuple).example()
102+
103+
104+
def test_pandas_nullable_types_class():
105+
with pytest.raises(
106+
InvalidArgument, match="Otherwise it would be treated as dtype=object"
107+
):
108+
st = pdst.series(dtype=pd.core.arrays.integer.Int8Dtype)
109+
find_any(st, lambda s: s.isna().any())

hypothesis-python/tests/pandas/test_data_frame.py

+8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

1111
import numpy as np
12+
import pandas as pd
1213
import pytest
1314

1415
from hypothesis import HealthCheck, given, reject, settings, strategies as st
@@ -267,3 +268,10 @@ def works_with_object_dtype(df):
267268
assert dtype is None
268269
with pytest.raises(ValueError, match="Maybe passing dtype=object would help"):
269270
works_with_object_dtype()
271+
272+
273+
def test_pandas_nullable_types():
274+
st = pdst.data_frames(pdst.columns(2, dtype=pd.core.arrays.integer.Int8Dtype()))
275+
df = find_any(st, lambda s: s.isna().any().any())
276+
for s in df.columns:
277+
assert type(df[s].dtype) == pd.core.arrays.integer.Int8Dtype

hypothesis-python/tests/pandas/test_series.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# obtain one at https://mozilla.org/MPL/2.0/.
1010

1111
import numpy as np
12-
import pandas
12+
import pandas as pd
1313

1414
from hypothesis import assume, given, strategies as st
1515
from hypothesis.extra import numpy as npst, pandas as pdst
@@ -25,7 +25,7 @@ def test_can_create_a_series_of_any_dtype(data):
2525
# Use raw data to work around pandas bug in repr. See
2626
# https://github.com/pandas-dev/pandas/issues/27484
2727
series = data.conjecture_data.draw(pdst.series(dtype=dtype))
28-
assert series.dtype == pandas.Series([], dtype=dtype).dtype
28+
assert series.dtype == pd.Series([], dtype=dtype).dtype
2929

3030

3131
@given(pdst.series(dtype=float, index=pdst.range_indexes(min_size=2, max_size=5)))
@@ -61,3 +61,15 @@ def test_unique_series_are_unique(s):
6161
@given(pdst.series(dtype="int8", name=st.just("test_name")))
6262
def test_name_passed_on(s):
6363
assert s.name == "test_name"
64+
65+
66+
def test_pandas_nullable_types():
67+
st = pdst.series(dtype=pd.core.arrays.integer.Int8Dtype())
68+
e = find_any(st, lambda s: s.isna().any())
69+
assert type(e.dtype) == pd.core.arrays.integer.Int8Dtype
70+
71+
72+
def test_pandas_nullable_types_in_str():
73+
st = pdst.series(dtype="Int8")
74+
e = find_any(st, lambda s: s.isna().any())
75+
assert type(e.dtype) == pd.core.arrays.integer.Int8Dtype

0 commit comments

Comments
 (0)