Skip to content

Commit d0fff0e

Browse files
davidriazatifacebook-github-bot
authored andcommitted
Make is_optional check more robust (pytorch#26312)
Summary: If the `Union` contains a non-class type, `issubclass` would fail, this adds a check for that case ](https://our.intern.facebook.com/intern/diff/17505206/) Pull Request resolved: pytorch#26312 Pulled By: driazati Differential Revision: D17505206 fbshipit-source-id: 1331e412f938e2f08ecb079972147f11e3ec77cd
1 parent 5cc3534 commit d0fff0e

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

test/test_jit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from functools import wraps
4545
from itertools import product, chain
4646
from textwrap import dedent
47-
from typing import List, Dict, Optional, Tuple
47+
from typing import List, Dict, Optional, Tuple, Union
4848
import copy
4949
import inspect
5050
import math
@@ -3525,6 +3525,10 @@ def invalid_prefix_annotation3(a):
35253525
# type: (Int) -> Int
35263526
return a + 2
35273527

3528+
def test_is_optional(self):
3529+
ann = Union[List[int], List[float]]
3530+
torch._jit_internal.is_optional(ann)
3531+
35283532
def test_interpreter_fuzz(self):
35293533
# This test generates random tree-like programs to fuzz test
35303534
# that the interpreter does not have a bug in its stack manipulation

torch/_jit_internal.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,13 +507,20 @@ def is_dict(ann):
507507

508508
def is_optional(ann):
509509
# Optional[T] is just shorthand for Union[T, None], so check for both
510+
def safe_is_subclass(the_type, super_type):
511+
# Don't throw if `the_type` isn't a class type (e.g. if it is
512+
# another type annotation instance)
513+
if not inspect.isclass(the_type):
514+
return False
515+
return issubclass(the_type, super_type)
516+
510517
union_optional = False
511518
if ann.__module__ == 'typing' and \
512519
(getattr(ann, '__origin__', None) is typing.Union):
513520
args = getattr(ann, '__args__', ())
514521
if len(args) == 2:
515-
union_optional = (issubclass(args[1], type(None)) and not issubclass(args[0], type(None))) \
516-
or (issubclass(args[0], type(None)) and not issubclass(args[1], type(None)))
522+
union_optional = (safe_is_subclass(args[1], type(None)) and not safe_is_subclass(args[0], type(None))) \
523+
or (safe_is_subclass(args[0], type(None)) and not safe_is_subclass(args[1], type(None)))
517524

518525
optional = ann.__module__ == 'typing' and \
519526
(getattr(ann, '__origin__', None) is typing.Optional)

0 commit comments

Comments
 (0)