@@ -59,6 +59,10 @@ def get_available_video_models():
59
59
"resnet101" ,
60
60
"resnet152" ,
61
61
"wide_resnet101_2" ,
62
+ "deeplabv3_resnet50" ,
63
+ "deeplabv3_resnet101" ,
64
+ "fcn_resnet50" ,
65
+ "fcn_resnet101" ,
62
66
)
63
67
64
68
@@ -85,21 +89,53 @@ def _test_classification_model(self, name, input_shape, dev):
85
89
self .assertEqual (out .shape [- 1 ], 50 )
86
90
87
91
def _test_segmentation_model (self , name , dev ):
88
- # passing num_class equal to a number other than 1000 helps in making the test
89
- # more enforcing in nature
90
- model = models .segmentation .__dict__ [name ](num_classes = 50 , pretrained_backbone = False )
92
+ set_rng_seed (0 )
93
+ # passing num_classes equal to a number other than 21 helps in making the test's
94
+ # expected file size smaller
95
+ model = models .segmentation .__dict__ [name ](num_classes = 10 , pretrained_backbone = False )
91
96
model .eval ().to (device = dev )
92
- input_shape = (1 , 3 , 300 , 300 )
97
+ input_shape = (1 , 3 , 32 , 32 )
93
98
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
94
99
x = torch .rand (input_shape ).to (device = dev )
95
- out = model (x )
96
- self .assertEqual (tuple (out ["out" ].shape ), (1 , 50 , 300 , 300 ))
100
+ out = model (x )["out" ]
101
+
102
+ def check_out (out ):
103
+ prec = 0.01
104
+ strip_suffix = f"_{ dev } "
105
+ try :
106
+ # We first try to assert the entire output if possible. This is not
107
+ # only the best way to assert results but also handles the cases
108
+ # where we need to create a new expected result.
109
+ self .assertExpected (out .cpu (), prec = prec , strip_suffix = strip_suffix )
110
+ except AssertionError :
111
+ # Unfortunately some segmentation models are flaky with autocast
112
+ # so instead of validating the probability scores, check that the class
113
+ # predictions match.
114
+ expected_file = self ._get_expected_file (strip_suffix = strip_suffix )
115
+ expected = torch .load (expected_file )
116
+ self .assertEqual (out .argmax (dim = 1 ), expected .argmax (dim = 1 ), prec = prec )
117
+ return False # Partial validation performed
118
+
119
+ return True # Full validation performed
120
+
121
+ full_validation = check_out (out )
122
+
97
123
self .check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (name , None ))
98
124
99
125
if dev == torch .device ("cuda" ):
100
126
with torch .cuda .amp .autocast ():
101
- out = model (x )
102
- self .assertEqual (tuple (out ["out" ].shape ), (1 , 50 , 300 , 300 ))
127
+ out = model (x )["out" ]
128
+ # See autocast_flaky_numerics comment at top of file.
129
+ if name not in autocast_flaky_numerics :
130
+ full_validation &= check_out (out )
131
+
132
+ if not full_validation :
133
+ msg = "The output of {} could only be partially validated. " \
134
+ "This is likely due to unit-test flakiness, but you may " \
135
+ "want to do additional manual checks if you made " \
136
+ "significant changes to the codebase." .format (self ._testMethodName )
137
+ warnings .warn (msg , RuntimeWarning )
138
+ raise unittest .SkipTest (msg )
103
139
104
140
def _test_detection_model (self , name , dev ):
105
141
set_rng_seed (0 )
0 commit comments