Skip to content

Commit 24f5fa5

Browse files
datumbox and fmassa authored
Use assertExpected on the segmentation tests (#3287)
* Modify segmentation tests to compare against expected values. * Exclude flaky autocast tests. Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
1 parent 05c5425 commit 24f5fa5

5 files changed

+44
-8
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

test/test_models.py

+44-8
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ def get_available_video_models():
5959
"resnet101",
6060
"resnet152",
6161
"wide_resnet101_2",
62+
"deeplabv3_resnet50",
63+
"deeplabv3_resnet101",
64+
"fcn_resnet50",
65+
"fcn_resnet101",
6266
)
6367

6468

@@ -85,21 +89,53 @@ def _test_classification_model(self, name, input_shape, dev):
8589
self.assertEqual(out.shape[-1], 50)
8690

8791
def _test_segmentation_model(self, name, dev):
88-
# passing num_class equal to a number other than 1000 helps in making the test
89-
# more enforcing in nature
90-
model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
92+
set_rng_seed(0)
93+
# passing num_classes equal to a number other than 21 helps in making the test's
94+
# expected file size smaller
95+
model = models.segmentation.__dict__[name](num_classes=10, pretrained_backbone=False)
9196
model.eval().to(device=dev)
92-
input_shape = (1, 3, 300, 300)
97+
input_shape = (1, 3, 32, 32)
9398
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
9499
x = torch.rand(input_shape).to(device=dev)
95-
out = model(x)
96-
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
100+
out = model(x)["out"]
101+
102+
def check_out(out):
103+
prec = 0.01
104+
strip_suffix = f"_{dev}"
105+
try:
106+
# We first try to assert the entire output if possible. This is not
107+
# only the best way to assert results but also handles the cases
108+
# where we need to create a new expected result.
109+
self.assertExpected(out.cpu(), prec=prec, strip_suffix=strip_suffix)
110+
except AssertionError:
111+
# Unfortunately some segmentation models are flaky with autocast
112+
# so instead of validating the probability scores, check that the class
113+
# predictions match.
114+
expected_file = self._get_expected_file(strip_suffix=strip_suffix)
115+
expected = torch.load(expected_file)
116+
self.assertEqual(out.argmax(dim=1), expected.argmax(dim=1), prec=prec)
117+
return False # Partial validation performed
118+
119+
return True # Full validation performed
120+
121+
full_validation = check_out(out)
122+
97123
self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
98124

99125
if dev == torch.device("cuda"):
100126
with torch.cuda.amp.autocast():
101-
out = model(x)
102-
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
127+
out = model(x)["out"]
128+
# See autocast_flaky_numerics comment at top of file.
129+
if name not in autocast_flaky_numerics:
130+
full_validation &= check_out(out)
131+
132+
if not full_validation:
133+
msg = "The output of {} could only be partially validated. " \
134+
"This is likely due to unit-test flakiness, but you may " \
135+
"want to do additional manual checks if you made " \
136+
"significant changes to the codebase.".format(self._testMethodName)
137+
warnings.warn(msg, RuntimeWarning)
138+
raise unittest.SkipTest(msg)
103139

104140
def _test_detection_model(self, name, dev):
105141
set_rng_seed(0)

0 commit comments

Comments (0)