Skip to content

Commit 3cb357a

Browse files
authored
bpo-46014: Add ability to use typing.Union with singledispatch (GH-30017)
1 parent 810c176 commit 3cb357a

File tree

3 files changed

+60
-7
lines changed

3 files changed

+60
-7
lines changed

Lib/functools.py

+28-7
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,14 @@ def dispatch(cls):
837837
dispatch_cache[cls] = impl
838838
return impl
839839

840+
def _is_union_type(cls):
841+
from typing import get_origin, Union
842+
return get_origin(cls) in {Union, types.UnionType}
843+
844+
def _is_valid_union_type(cls):
845+
from typing import get_args
846+
return _is_union_type(cls) and all(isinstance(arg, type) for arg in get_args(cls))
847+
840848
def register(cls, func=None):
841849
"""generic_func.register(cls, func) -> func
842850
@@ -845,7 +853,7 @@ def register(cls, func=None):
845853
"""
846854
nonlocal cache_token
847855
if func is None:
848-
if isinstance(cls, type):
856+
if isinstance(cls, type) or _is_valid_union_type(cls):
849857
return lambda f: register(cls, f)
850858
ann = getattr(cls, '__annotations__', {})
851859
if not ann:
@@ -859,12 +867,25 @@ def register(cls, func=None):
859867
# only import typing if annotation parsing is necessary
860868
from typing import get_type_hints
861869
argname, cls = next(iter(get_type_hints(func).items()))
862-
if not isinstance(cls, type):
863-
raise TypeError(
864-
f"Invalid annotation for {argname!r}. "
865-
f"{cls!r} is not a class."
866-
)
867-
registry[cls] = func
870+
if not isinstance(cls, type) and not _is_valid_union_type(cls):
871+
if _is_union_type(cls):
872+
raise TypeError(
873+
f"Invalid annotation for {argname!r}. "
874+
f"{cls!r} not all arguments are classes."
875+
)
876+
else:
877+
raise TypeError(
878+
f"Invalid annotation for {argname!r}. "
879+
f"{cls!r} is not a class."
880+
)
881+
882+
if _is_union_type(cls):
883+
from typing import get_args
884+
885+
for arg in get_args(cls):
886+
registry[arg] = func
887+
else:
888+
registry[cls] = func
868889
if cache_token is None and hasattr(cls, '__abstractmethods__'):
869890
cache_token = get_cache_token()
870891
dispatch_cache.clear()

Lib/test/test_functools.py

+30
Original file line numberDiff line numberDiff line change
@@ -2684,6 +2684,17 @@ def _(arg: typing.Iterable[str]):
26842684
'typing.Iterable[str] is not a class.'
26852685
))
26862686

2687+
with self.assertRaises(TypeError) as exc:
2688+
@i.register
2689+
def _(arg: typing.Union[int, typing.Iterable[str]]):
2690+
return "Invalid Union"
2691+
self.assertTrue(str(exc.exception).startswith(
2692+
"Invalid annotation for 'arg'."
2693+
))
2694+
self.assertTrue(str(exc.exception).endswith(
2695+
'typing.Union[int, typing.Iterable[str]] not all arguments are classes.'
2696+
))
2697+
26872698
def test_invalid_positional_argument(self):
26882699
@functools.singledispatch
26892700
def f(*args):
@@ -2692,6 +2703,25 @@ def f(*args):
26922703
with self.assertRaisesRegex(TypeError, msg):
26932704
f()
26942705

2706+
def test_union(self):
2707+
@functools.singledispatch
2708+
def f(arg):
2709+
return "default"
2710+
2711+
@f.register
2712+
def _(arg: typing.Union[str, bytes]):
2713+
return "typing.Union"
2714+
2715+
@f.register
2716+
def _(arg: int | float):
2717+
return "types.UnionType"
2718+
2719+
self.assertEqual(f([]), "default")
2720+
self.assertEqual(f(""), "typing.Union")
2721+
self.assertEqual(f(b""), "typing.Union")
2722+
self.assertEqual(f(1), "types.UnionType")
2723+
self.assertEqual(f(1.0), "types.UnionType")
2724+
26952725

26962726
class CachedCostItem:
26972727
_cost = 1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Add ability to use ``typing.Union`` and ``types.UnionType`` as dispatch
2+
argument to ``functools.singledispatch``. Patch provided by Yurii Karabas.

0 commit comments

Comments
 (0)