4
4
import os
5
5
import pkgutil
6
6
import sys
7
- import traceback
8
7
import warnings
9
8
from collections import OrderedDict
10
9
from tempfile import TemporaryDirectory
@@ -119,27 +118,16 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
119
118
torch .testing .assert_close (output , expected , rtol = rtol , atol = atol , check_dtype = False )
120
119
121
120
122
- def _check_jit_scriptable (nn_module , args , unwrapper = None , skip = False ):
121
+ def _check_jit_scriptable (nn_module , args , unwrapper = None , skip = False , eager_out = None ):
123
122
"""Check that a nn.Module's results in TorchScript match eager and that it can be exported"""
124
123
125
- def assert_export_import_module (m , args ):
126
- """Check that the results of a model are the same after saving and loading"""
127
-
128
- def get_export_import_copy (m ):
129
- """Save and load a TorchScript model"""
130
- with TemporaryDirectory () as dir :
131
- path = os .path .join (dir , "script.pt" )
132
- m .save (path )
133
- imported = torch .jit .load (path )
134
- return imported
135
-
136
- m_import = get_export_import_copy (m )
137
- with torch .no_grad (), freeze_rng_state ():
138
- results = m (* args )
139
- with torch .no_grad (), freeze_rng_state ():
140
- results_from_imported = m_import (* args )
141
- tol = 3e-4
142
- torch .testing .assert_close (results , results_from_imported , atol = tol , rtol = tol )
124
def get_export_import_copy(m):
    """Save a TorchScript module to disk and load it back.

    Round-tripping through ``m.save`` / ``torch.jit.load`` verifies that the
    scripted module survives serialization; the caller compares the loaded
    copy's outputs against the original.

    Args:
        m: a ``torch.jit.ScriptModule`` (must support ``.save(path)``).

    Returns:
        The module re-loaded via ``torch.jit.load``.
    """
    # Use a temp dir so the serialized file is cleaned up even on failure.
    # (local renamed from `dir`, which shadowed the builtin)
    with TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "script.pt")
        m.save(path)
        imported = torch.jit.load(path)
    return imported
143
131
144
132
TEST_WITH_SLOW = os .getenv ("PYTORCH_TEST_WITH_SLOW" , "0" ) == "1"
145
133
if not TEST_WITH_SLOW or skip :
@@ -157,23 +145,33 @@ def get_export_import_copy(m):
157
145
158
146
sm = torch .jit .script (nn_module )
159
147
160
- with torch .no_grad (), freeze_rng_state ():
161
- eager_out = nn_module (* args )
148
+ if eager_out is None :
149
+ with torch .no_grad (), freeze_rng_state ():
150
+ if unwrapper :
151
+ eager_out = nn_module (* args )
162
152
163
153
with torch .no_grad (), freeze_rng_state ():
164
154
script_out = sm (* args )
165
155
if unwrapper :
166
156
script_out = unwrapper (script_out )
167
157
168
158
torch .testing .assert_close (eager_out , script_out , atol = 1e-4 , rtol = 1e-4 )
169
- assert_export_import_module (sm , args )
159
+
160
+ m_import = get_export_import_copy (sm )
161
+ with torch .no_grad (), freeze_rng_state ():
162
+ imported_script_out = m_import (* args )
163
+ if unwrapper :
164
+ imported_script_out = unwrapper (imported_script_out )
165
+
166
+ torch .testing .assert_close (script_out , imported_script_out , atol = 3e-4 , rtol = 3e-4 )
170
167
171
168
172
def _check_fx_compatible(model, inputs, eager_out=None):
    """Check that a model is ``torch.fx``-traceable and that the traced graph
    matches eager mode.

    Args:
        model: module/callable under test; must be symbolically traceable.
        inputs: a single input object, passed as ``model(inputs)``.
        eager_out: optional precomputed eager output. When provided, the extra
            eager forward pass is skipped so callers that already ran the model
            don't pay for it twice (test speed-up); when ``None`` the eager
            output is computed here.

    Raises:
        AssertionError: if the traced output differs from the eager output.
    """
    model_fx = torch.fx.symbolic_trace(model)
    if eager_out is None:
        eager_out = model(inputs)
    fx_out = model_fx(inputs)
    torch.testing.assert_close(eager_out, fx_out)
177
175
178
176
179
177
def _check_input_backprop (model , inputs ):
@@ -298,6 +296,24 @@ def _check_input_backprop(model, inputs):
298
296
"rpn_post_nms_top_n_test" : 1000 ,
299
297
},
300
298
}
299
+ # speeding up slow models:
300
+ slow_models = [
301
+ "convnext_base" ,
302
+ "convnext_large" ,
303
+ "resnext101_32x8d" ,
304
+ "wide_resnet101_2" ,
305
+ "efficientnet_b6" ,
306
+ "efficientnet_b7" ,
307
+ "efficientnet_v2_m" ,
308
+ "efficientnet_v2_l" ,
309
+ "regnet_y_16gf" ,
310
+ "regnet_y_32gf" ,
311
+ "regnet_y_128gf" ,
312
+ "regnet_x_16gf" ,
313
+ "regnet_x_32gf" ,
314
+ ]
315
+ for m in slow_models :
316
+ _model_params [m ] = {"input_shape" : (1 , 3 , 64 , 64 )}
301
317
302
318
303
319
# The following contains configuration and expected values to be used tests that are model specific
@@ -564,8 +580,8 @@ def test_classification_model(model_fn, dev):
564
580
out = model (x )
565
581
_assert_expected (out .cpu (), model_name , prec = 0.1 )
566
582
assert out .shape [- 1 ] == num_classes
567
- _check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (model_name , None ))
568
- _check_fx_compatible (model , x )
583
+ _check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (model_name , None ), eager_out = out )
584
+ _check_fx_compatible (model , x , eager_out = out )
569
585
570
586
if dev == torch .device ("cuda" ):
571
587
with torch .cuda .amp .autocast ():
@@ -595,7 +611,7 @@ def test_segmentation_model(model_fn, dev):
595
611
model .eval ().to (device = dev )
596
612
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
597
613
x = torch .rand (input_shape ).to (device = dev )
598
- out = model (x )[ "out" ]
614
+ out = model (x )
599
615
600
616
def check_out (out ):
601
617
prec = 0.01
@@ -615,17 +631,17 @@ def check_out(out):
615
631
616
632
return True # Full validation performed
617
633
618
- full_validation = check_out (out )
634
+ full_validation = check_out (out [ "out" ] )
619
635
620
- _check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (model_name , None ))
621
- _check_fx_compatible (model , x )
636
+ _check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (model_name , None ), eager_out = out )
637
+ _check_fx_compatible (model , x , eager_out = out )
622
638
623
639
if dev == torch .device ("cuda" ):
624
640
with torch .cuda .amp .autocast ():
625
- out = model (x )[ "out" ]
641
+ out = model (x )
626
642
# See autocast_flaky_numerics comment at top of file.
627
643
if model_name not in autocast_flaky_numerics :
628
- full_validation &= check_out (out )
644
+ full_validation &= check_out (out [ "out" ] )
629
645
630
646
if not full_validation :
631
647
msg = (
@@ -716,7 +732,7 @@ def compute_mean_std(tensor):
716
732
return True # Full validation performed
717
733
718
734
full_validation = check_out (out )
719
- _check_jit_scriptable (model , ([x ],), unwrapper = script_model_unwrapper .get (model_name , None ))
735
+ _check_jit_scriptable (model , ([x ],), unwrapper = script_model_unwrapper .get (model_name , None ), eager_out = out )
720
736
721
737
if dev == torch .device ("cuda" ):
722
738
with torch .cuda .amp .autocast ():
@@ -780,8 +796,8 @@ def test_video_model(model_fn, dev):
780
796
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
781
797
x = torch .rand (input_shape ).to (device = dev )
782
798
out = model (x )
783
- _check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (model_name , None ))
784
- _check_fx_compatible (model , x )
799
+ _check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (model_name , None ), eager_out = out )
800
+ _check_fx_compatible (model , x , eager_out = out )
785
801
assert out .shape [- 1 ] == 50
786
802
787
803
if dev == torch .device ("cuda" ):
@@ -821,8 +837,13 @@ def test_quantized_classification_model(model_fn):
821
837
if model_name not in quantized_flaky_models :
822
838
_assert_expected (out , model_name + "_quantized" , prec = 0.1 )
823
839
assert out .shape [- 1 ] == 5
824
- _check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (model_name , None ))
825
- _check_fx_compatible (model , x )
840
+ _check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (model_name , None ), eager_out = out )
841
+ _check_fx_compatible (model , x , eager_out = out )
842
+ else :
843
+ try :
844
+ torch .jit .script (model )
845
+ except Exception as e :
846
+ raise AssertionError ("model cannot be scripted." ) from e
826
847
827
848
kwargs ["quantize" ] = False
828
849
for eval_mode in [True , False ]:
@@ -843,12 +864,6 @@ def test_quantized_classification_model(model_fn):
843
864
844
865
torch .ao .quantization .convert (model , inplace = True )
845
866
846
- try :
847
- torch .jit .script (model )
848
- except Exception as e :
849
- tb = traceback .format_exc ()
850
- raise AssertionError (f"model cannot be scripted. Traceback = { str (tb )} " ) from e
851
-
852
867
853
868
@pytest .mark .parametrize ("model_fn" , get_models_from_module (models .detection ))
854
869
def test_detection_model_trainable_backbone_layers (model_fn , disable_weight_loading ):
0 commit comments