
Commit 97ccfad

isuruf authored and pytorchmergebot committed
Fix test_decomp test for ops with py_impl(CompositeImplicitAutograd) (pytorch#116832)
Pull Request resolved: pytorch#116832
Approved by: https://github.com/lezcano
1 parent a3e3693 commit 97ccfad
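Note: the decompositions the title refers to are registered at the Python dispatcher level via py_impl. A minimal sketch of that registration pattern, with a hypothetical decomposition (the decorated op and its body are illustrative, not from this PR):

import torch
from torch._C import DispatchKey

# Hypothetical example: register a Python-level decomposition for aten.trace
# under the CompositeImplicitAutograd key. Decomps registered this way are the
# ones the fixed test now re-runs without the Python dispatcher.
@torch.ops.aten.trace.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def trace_decomp(x):
    return torch.diagonal(x).sum()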

File tree

1 file changed: +42 -1 lines changed

test/test_decomp.py

Lines changed: 42 additions & 1 deletion
@@ -144,7 +144,8 @@ def ref_vjp_no_create(f, *primals):
 
     def wrapped(cotangents):
         return _autograd_grad(
-            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False
+            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False,
+            retain_graph=True,
         )
 
     return result, wrapped
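The retain_graph=True addition is needed because the test below now calls the same vjp function a second time (without the Python dispatcher), and autograd frees the graph after the first backward by default. A minimal standalone sketch of the failure mode this avoids (inputs are illustrative):

import torch

x = torch.randn(3, requires_grad=True)
y = (x * x).sum()

# The first backward keeps the graph alive only because of retain_graph=True.
torch.autograd.grad(y, x, retain_graph=True)
# A second pass over the same graph; without retain_graph above, this would
# raise "Trying to backward through the graph a second time".
torch.autograd.grad(y, x)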
@@ -200,6 +201,12 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
     (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
     (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
     (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
+    (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
+    (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
+    (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+    (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 3e-2,
+    (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
+    (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
     # see https://github.com/pytorch/pytorch/pull/96264
     (torch.float16, torch.ops.aten.mv.default): 1e-5,
 }
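These entries raise the per-(dtype, op) tolerance used when comparing a decomposition against the reference run. A hedged illustration of why such entries exist: low-precision outputs legitimately drift from a higher-precision reference (the magnitude printed here is illustrative, not the PR's measurement):

import torch

x = torch.randn(1024, dtype=torch.bfloat16)
# Compare a bfloat16 hardswish against a float64 reference on the same input;
# the difference is small but nonzero, hence the tolerance entries above.
ref = torch.nn.functional.hardswish(x.double())
out = torch.nn.functional.hardswish(x)
print((out.double() - ref).abs().max())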
@@ -488,6 +495,11 @@ def test_unsupported(t):
     skip('unsafe_split'),  # slow: takes 49 sec on A100
 })
 
+comprehensive_failures = {
+    xfail("nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)),  # off by one error
+    xfail("nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)),  # off by one error
+}
 
 @unMarkDynamoStrictTest
 class TestDecomp(TestCase):
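The new xfails track a known off-by-one discrepancy between the native uint8 interpolation kernels and the decomposition's float reference. A hedged repro sketch (shape, scale factor, and the exact max difference are illustrative):

import torch

x = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8)
# Native uint8 bilinear path vs. a float32 reference rounded back to uint8;
# the two can disagree by one, which is the "off by one error" noted above.
out_u8 = torch.nn.functional.interpolate(x, scale_factor=2, mode="bilinear")
ref = torch.nn.functional.interpolate(x.float(), scale_factor=2, mode="bilinear")
ref_u8 = ref.round().clamp(0, 255).to(torch.uint8)
print((out_u8.int() - ref_u8.int()).abs().max())  # typically 0 or 1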
@@ -524,6 +536,7 @@ def test_quick_core_backward(self, device, dtype, op):
     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
     @onlyNativeDeviceTypes
     @skipIfCrossRef
+    @skipOps('TestDecomp', 'test_comprehensive', comprehensive_failures)
     @suppress_warnings
     @ops(op_db)
     def test_comprehensive(self, device, dtype, op):
@@ -810,6 +823,12 @@ def do_cross_ref(self, device, dtype, op, *, run_all):
         aten_name = op.decomp_aten_name or op.aten_name
 
         func = op.get_op()
+
+        def run_without_python_dispatcher(mode):
+            return any(isinstance(op, torch._ops.OpOverload) and
+                       op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
+                       for op in mode.decomposed.union([func]))
+
         for sample_input in samples:
             if requires_grad:
                 fn, primals = normalize_op_input_output(func, sample_input)
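The helper keys off OpOverload.has_kernel_for_dispatch_key, which reports whether a kernel is registered for a given dispatch key. A quick illustration (the printed values are what one would typically expect; registrations can vary across builds):

import torch
from torch._C import DispatchKey

# aten.linear is implemented in terms of other aten ops, so it typically has a
# CompositeImplicitAutograd kernel; aten.mv ships explicit CPU/CUDA kernels.
print(torch.ops.aten.linear.default.has_kernel_for_dispatch_key(
    DispatchKey.CompositeImplicitAutograd))  # expected: True
print(torch.ops.aten.mv.default.has_kernel_for_dispatch_key(
    DispatchKey.CompositeImplicitAutograd))  # expected: False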
@@ -824,6 +843,12 @@ def do_cross_ref(self, device, dtype, op, *, run_all):
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                 if aten_name in decomposition_names:
                     self.check_decomposed(aten_name, mode)
 
@@ -833,15 +858,31 @@ def do_cross_ref(self, device, dtype, op, *, run_all):
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     decomp_vjp_fn(cotangents)
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        decomp_vjp_fn(cotangents)
                 if not run_all:
                     self.check_decomposed(op.aten_backward_name, mode)
 
             elif aten_name in decomposition_names or run_all:
                 args = [sample_input.input] + list(sample_input.args)
                 kwargs = sample_input.kwargs
+                # A failure here might be because the decomposition for the op is wrong or because a
+                # decomposition used by the particular op is wrong.
                 with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                      as mode, enable_python_dispatcher():
                     func(*args, **kwargs)
+
+                if run_without_python_dispatcher(mode):
+                    # without this check, incorrect decomps at the python dispatcher level can still pass because
+                    # they're checking aten decomps at the torch_dispatch level.
+                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
+                         as mode:
+                        func(*args, **kwargs)
+
                 if not run_all:
                     self.check_decomposed(aten_name, mode)
             else:
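The pattern added throughout this hunk runs each check twice: once under enable_python_dispatcher, where py_impl decomps are active, and once without it, so an incorrect Python-dispatcher-level decomp cannot hide behind a correct aten-level one. A minimal sketch of that double-run outside the test harness (the op choice is illustrative):

import torch
from torch._dispatch.python import enable_python_dispatcher

x = torch.randn(4, 4)
with enable_python_dispatcher():
    y_py = torch.ops.aten.mm.default(x, x)  # Python dispatcher active; py_impl decomps apply
y_cpp = torch.ops.aten.mm.default(x, x)  # plain C++ dispatch
torch.testing.assert_close(y_py, y_cpp)  # both paths must agree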
