@@ -144,7 +144,8 @@ def ref_vjp_no_create(f, *primals):

     def wrapped(cotangents):
         return _autograd_grad(
-            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False
+            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False,
+            retain_graph=True,
         )

     return result, wrapped
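The `retain_graph=True` added above matters because, later in this PR, the returned vjp function is invoked a second time (once more without the python dispatcher), and the second `_autograd_grad` call needs the graph built by the first one to still exist. A minimal, standalone sketch of that autograd behaviour (not part of the diff):

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()
cotangent = torch.tensor(1.0)

# First VJP: retain_graph=True keeps the graph alive for another pass.
(g1,) = torch.autograd.grad(y, x, cotangent, retain_graph=True)

# Second VJP over the same graph; without retain_graph above, autograd
# would raise "Trying to backward through the graph a second time".
(g2,) = torch.autograd.grad(y, x, cotangent)
```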
@@ -200,6 +201,12 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
     (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
     (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
     (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
+    (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
+    (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
+    (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+    (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+    (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
+    (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
     # see https://github.com/pytorch/pytorch/pull/96264
     (torch.float16, torch.ops.aten.mv.default): 1e-5,
 }
@@ -488,6 +495,11 @@ def test_unsupported(t):
     skip('unsafe_split'),  # slow: takes 49 sec on A100
 })

+comprehensive_failures = {
+    xfail("nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)),  # off by one error
+}

 @unMarkDynamoStrictTest
 class TestDecomp(TestCase):
@@ -524,6 +536,7 @@ def test_quick_core_backward(self, device, dtype, op):
     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
     @onlyNativeDeviceTypes
     @skipIfCrossRef
+    @skipOps('TestDecomp', 'test_comprehensive', comprehensive_failures)
     @suppress_warnings
     @ops(op_db)
     def test_comprehensive(self, device, dtype, op):
@@ -810,6 +823,12 @@ def do_cross_ref(self, device, dtype, op, *, run_all):
         aten_name = op.decomp_aten_name or op.aten_name

         func = op.get_op()
+
+        def run_without_python_dispatcher(mode):
+            return any(isinstance(op, torch._ops.OpOverload) and
+                       op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
+                       for op in mode.decomposed.union([func]))
+
         for sample_input in samples:
             if requires_grad:
                 fn, primals = normalize_op_input_output(func, sample_input)
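The `run_without_python_dispatcher` helper added above asks whether the op under test, or anything it decomposed into, has a CompositeImplicitAutograd kernel, i.e. whether there is a python-dispatcher-level decomposition worth re-checking at the torch_dispatch level. A rough sketch of that kind of query (`aten::matmul` is only an assumed example of a CompositeImplicitAutograd op):

```python
import torch
from torch._C import DispatchKey

op = torch.ops.aten.matmul.default  # assumed to be CompositeImplicitAutograd

print(isinstance(op, torch._ops.OpOverload))  # True
# Asks the dispatcher whether a kernel is registered for this key; if so,
# the op decomposes into other aten ops below autograd.
print(op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd))
```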
@@ -824,6 +843,12 @@ def do_cross_ref(self, device, dtype, op, *, run_all):
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                 if aten_name in decomposition_names:
                     self.check_decomposed(aten_name, mode)

@@ -833,15 +858,31 @@ def do_cross_ref(self, device, dtype, op, *, run_all):
                     with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                          as mode, enable_python_dispatcher():
                         decomp_vjp_fn(cotangents)
+                    if run_without_python_dispatcher(mode):
+                        # without this check, incorrect decomps at the python dispatcher level can still pass because
+                        # they're checking aten decomps at the torch_dispatch level.
+                        with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                             as mode:
+                            decomp_vjp_fn(cotangents)
                     if not run_all:
                         self.check_decomposed(op.aten_backward_name, mode)

             elif aten_name in decomposition_names or run_all:
                 args = [sample_input.input] + list(sample_input.args)
                 kwargs = sample_input.kwargs
+                # A failure here might be because the decomposition for the op is wrong or because a
+                # decomposition used by the particular op is wrong.
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     func(*args, **kwargs)
+
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        func(*args, **kwargs)
+
                 if not run_all:
                     self.check_decomposed(aten_name, mode)
             else: