
Commit d0dede0

Speed up Model tests by 20% (#5574)
* Measuring execution times of models.
* Speed up models by avoiding re-estimation of eager output
* Fixing linter
* Reduce input size for big models
* Speed up jit check method.
* Add simple jitscript fallback check for flaky models.
* Restore pytest filtering
* Fixing linter
1 parent cddad9c commit d0dede0

15 files changed (+61 -46 lines)
14 binary files changed (contents not shown)

test/test_models.py

+61 -46
@@ -4,7 +4,6 @@
 import os
 import pkgutil
 import sys
-import traceback
 import warnings
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
@@ -119,27 +118,16 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
     torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False)


-def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False):
+def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False, eager_out=None):
     """Check that a nn.Module's results in TorchScript match eager and that it can be exported"""

-    def assert_export_import_module(m, args):
-        """Check that the results of a model are the same after saving and loading"""
-
-        def get_export_import_copy(m):
-            """Save and load a TorchScript model"""
-            with TemporaryDirectory() as dir:
-                path = os.path.join(dir, "script.pt")
-                m.save(path)
-                imported = torch.jit.load(path)
-            return imported
-
-        m_import = get_export_import_copy(m)
-        with torch.no_grad(), freeze_rng_state():
-            results = m(*args)
-        with torch.no_grad(), freeze_rng_state():
-            results_from_imported = m_import(*args)
-        tol = 3e-4
-        torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol)
+    def get_export_import_copy(m):
+        """Save and load a TorchScript model"""
+        with TemporaryDirectory() as dir:
+            path = os.path.join(dir, "script.pt")
+            m.save(path)
+            imported = torch.jit.load(path)
+        return imported

     TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1"
     if not TEST_WITH_SLOW or skip:
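
Flattening the old `assert_export_import_module` helper into `_check_jit_scriptable` is part of the "Speed up jit check method" item: the old helper executed the scripted module once more for `results`, even though the caller had just run it for the eager-vs-script comparison. Below is a minimal, self-contained sketch of the save/load round-trip the kept helper performs; the `Toy` module is a hypothetical stand-in, not code from the test suite.

    import os
    from tempfile import TemporaryDirectory

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x.relu() + 1

    sm = torch.jit.script(Toy().eval())
    with TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "script.pt")
        sm.save(path)                    # serialize the TorchScript module
        imported = torch.jit.load(path)  # load it back from disk

    x = torch.rand(4)
    with torch.no_grad():
        script_out = sm(x)
        imported_script_out = imported(x)
    # Only script vs. imported-script is compared; the eager output is
    # validated once, elsewhere, instead of being recomputed here.
    torch.testing.assert_close(script_out, imported_script_out, atol=3e-4, rtol=3e-4)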
@@ -157,23 +145,33 @@ def get_export_import_copy(m):

     sm = torch.jit.script(nn_module)

-    with torch.no_grad(), freeze_rng_state():
-        eager_out = nn_module(*args)
+    if eager_out is None:
+        with torch.no_grad(), freeze_rng_state():
+            if unwrapper:
+                eager_out = nn_module(*args)

     with torch.no_grad(), freeze_rng_state():
         script_out = sm(*args)
         if unwrapper:
             script_out = unwrapper(script_out)

     torch.testing.assert_close(eager_out, script_out, atol=1e-4, rtol=1e-4)
-    assert_export_import_module(sm, args)
+
+    m_import = get_export_import_copy(sm)
+    with torch.no_grad(), freeze_rng_state():
+        imported_script_out = m_import(*args)
+        if unwrapper:
+            imported_script_out = unwrapper(imported_script_out)
+
+    torch.testing.assert_close(script_out, imported_script_out, atol=3e-4, rtol=3e-4)


-def _check_fx_compatible(model, inputs):
+def _check_fx_compatible(model, inputs, eager_out=None):
     model_fx = torch.fx.symbolic_trace(model)
-    out = model(inputs)
-    out_fx = model_fx(inputs)
-    torch.testing.assert_close(out, out_fx)
+    if eager_out is None:
+        eager_out = model(inputs)
+    fx_out = model_fx(inputs)
+    torch.testing.assert_close(eager_out, fx_out)


 def _check_input_backprop(model, inputs):
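
This is the central trick behind "Speed up models by avoiding re-estimation of eager output": every test already runs the model once for its own assertions, so that result can be threaded through as `eager_out` and the checkers skip their duplicate forward pass. A minimal sketch of the pattern, using a hypothetical toy model rather than a torchvision one:

    import torch

    def check_fx_compatible(model, inputs, eager_out=None):
        # Trace the model with FX and compare against the eager result.
        model_fx = torch.fx.symbolic_trace(model)
        if eager_out is None:
            # Fallback: only pay for the eager forward pass when the
            # caller did not supply a precomputed output.
            eager_out = model(inputs)
        fx_out = model_fx(inputs)
        torch.testing.assert_close(eager_out, fx_out)

    model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()).eval()
    x = torch.rand(2, 8)
    out = model(x)                                # the test needs this anyway
    check_fx_compatible(model, x, eager_out=out)  # no second eager run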
@@ -298,6 +296,24 @@ def _check_input_backprop(model, inputs):
         "rpn_post_nms_top_n_test": 1000,
     },
 }
+# speeding up slow models:
+slow_models = [
+    "convnext_base",
+    "convnext_large",
+    "resnext101_32x8d",
+    "wide_resnet101_2",
+    "efficientnet_b6",
+    "efficientnet_b7",
+    "efficientnet_v2_m",
+    "efficientnet_v2_l",
+    "regnet_y_16gf",
+    "regnet_y_32gf",
+    "regnet_y_128gf",
+    "regnet_x_16gf",
+    "regnet_x_32gf",
+]
+for m in slow_models:
+    _model_params[m] = {"input_shape": (1, 3, 64, 64)}


 # The following contains configuration and expected values to be used in tests that are model specific
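
The "Reduce input size for big models" item uses `_model_params`, the suite's per-model kwargs table: a 64x64 input has roughly 1/12 the pixels of the default 224x224, so the convolutional workload of these large models shrinks accordingly while the whole architecture is still exercised (the 14 binary files in this commit are presumably the expect files regenerated for the new shape). A hedged sketch of how such an override table is typically consumed; the `defaults` dict here is an assumption, not code from the suite:

    # Per-model overrides, as in the diff above (abbreviated list).
    _model_params = {}
    for m in ["efficientnet_b7", "regnet_y_128gf"]:
        _model_params[m] = {"input_shape": (1, 3, 64, 64)}

    # Assumed lookup pattern: merge defaults with the per-model entry.
    defaults = {"input_shape": (1, 3, 224, 224)}
    kwargs = {**defaults, **_model_params.get("efficientnet_b7", {})}
    assert kwargs["input_shape"] == (1, 3, 64, 64)    # big model: small input
    kwargs = {**defaults, **_model_params.get("resnet18", {})}
    assert kwargs["input_shape"] == (1, 3, 224, 224)  # others keep the default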
@@ -564,8 +580,8 @@ def test_classification_model(model_fn, dev):
     out = model(x)
     _assert_expected(out.cpu(), model_name, prec=0.1)
     assert out.shape[-1] == num_classes
-    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
-    _check_fx_compatible(model, x)
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
+    _check_fx_compatible(model, x, eager_out=out)

     if dev == torch.device("cuda"):
         with torch.cuda.amp.autocast():
@@ -595,7 +611,7 @@ def test_segmentation_model(model_fn, dev):
     model.eval().to(device=dev)
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
-    out = model(x)["out"]
+    out = model(x)

     def check_out(out):
         prec = 0.01
@@ -615,17 +631,17 @@ def check_out(out):

         return True  # Full validation performed

-    full_validation = check_out(out)
+    full_validation = check_out(out["out"])

-    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
-    _check_fx_compatible(model, x)
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
+    _check_fx_compatible(model, x, eager_out=out)

     if dev == torch.device("cuda"):
         with torch.cuda.amp.autocast():
-            out = model(x)["out"]
+            out = model(x)
             # See autocast_flaky_numerics comment at top of file.
             if model_name not in autocast_flaky_numerics:
-                full_validation &= check_out(out)
+                full_validation &= check_out(out["out"])

     if not full_validation:
         msg = (
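
The segmentation rewrite keeps `out` as the model's raw OrderedDict instead of indexing `["out"]` immediately: the scripted and FX-traced models return the same dict structure, so the whole dict can serve as `eager_out`, and only `check_out` needs the tensor. A small sketch with a hypothetical dict-returning module standing in for a torchvision segmentation model:

    from collections import OrderedDict

    import torch

    class ToySeg(torch.nn.Module):
        """Stand-in: torchvision segmentation models return a dict of tensors."""
        def __init__(self):
            super().__init__()
            self.head = torch.nn.Conv2d(3, 10, kernel_size=1)
        def forward(self, x):
            return OrderedDict([("out", self.head(x))])

    model = ToySeg().eval()
    x = torch.rand(1, 3, 64, 64)
    with torch.no_grad():
        out = model(x)   # keep the whole dict for eager_out=...
    logits = out["out"]  # ...and index only where a tensor is needed
    assert logits.shape == (1, 10, 64, 64)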
@@ -716,7 +732,7 @@ def compute_mean_std(tensor):
         return True  # Full validation performed

     full_validation = check_out(out)
-    _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None))
+    _check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)

     if dev == torch.device("cuda"):
         with torch.cuda.amp.autocast():
@@ -780,8 +796,8 @@ def test_video_model(model_fn, dev):
     # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
     x = torch.rand(input_shape).to(device=dev)
     out = model(x)
-    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
-    _check_fx_compatible(model, x)
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
+    _check_fx_compatible(model, x, eager_out=out)
     assert out.shape[-1] == 50

     if dev == torch.device("cuda"):
@@ -821,8 +837,13 @@ def test_quantized_classification_model(model_fn):
     if model_name not in quantized_flaky_models:
         _assert_expected(out, model_name + "_quantized", prec=0.1)
         assert out.shape[-1] == 5
-        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
-        _check_fx_compatible(model, x)
+        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None), eager_out=out)
+        _check_fx_compatible(model, x, eager_out=out)
+    else:
+        try:
+            torch.jit.script(model)
+        except Exception as e:
+            raise AssertionError("model cannot be scripted.") from e

     kwargs["quantize"] = False
     for eval_mode in [True, False]:
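
For the models in `quantized_flaky_models`, the numeric comparisons are intentionally skipped, but the new `else` branch still asserts scriptability with a bare `torch.jit.script` call: compilation without a forward pass, which is cheap. Because `raise ... from e` already chains the original exception and its traceback, the `traceback.format_exc()` formatting (removed in the next hunk, along with `import traceback` in the first one) became redundant. A minimal sketch of the guard with a hypothetical module:

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    model = Toy()
    try:
        torch.jit.script(model)   # compile only; no inputs are run
    except Exception as e:
        # `from e` keeps the original traceback attached to the assertion.
        raise AssertionError("model cannot be scripted.") from e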
@@ -843,12 +864,6 @@ def test_quantized_classification_model(model_fn):

     torch.ao.quantization.convert(model, inplace=True)

-    try:
-        torch.jit.script(model)
-    except Exception as e:
-        tb = traceback.format_exc()
-        raise AssertionError(f"model cannot be scripted. Traceback = {str(tb)}") from e
-

 @pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
 def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
