Skip to content

Commit 13bd09d

Browse files
NicolasHugfmassa
andauthored
Prevent tests from leaking their respective RNG (#4497)
* Add autouse fixture to save and reset RNG in tests * Add other RNG generators * delete freeze_rng_state * Hopefully fix GaussianBlur test * Alternative fix, probably better * revert changes to test_models Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
1 parent 5e8a211 commit 13bd09d

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

test/conftest.py

+25
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG
22
import torch
3+
import numpy as np
4+
import random
35
import pytest
46

57

@@ -80,3 +82,26 @@ def pytest_sessionfinish(session, exitstatus):
8082
# To avoid this, we transform this 5 into a 0 to make testpilot happy.
8183
if exitstatus == 5:
8284
session.exitstatus = 0
85+
86+
87+
@pytest.fixture(autouse=True)
88+
def prevent_leaking_rng():
89+
# Prevent each test from leaking the rng to all other test when they call
90+
# torch.manual_seed() or random.seed() or np.random.seed().
91+
# Note: the numpy rngs should never leak anyway, as we never use
92+
# np.random.seed() and instead rely on np.random.RandomState instances (see
93+
# issue #4247). We still do it for extra precaution.
94+
95+
torch_rng_state = torch.get_rng_state()
96+
builtin_rng_state = random.getstate()
97+
nunmpy_rng_state = np.random.get_state()
98+
if torch.cuda.is_available():
99+
cuda_rng_state = torch.cuda.get_rng_state()
100+
101+
yield
102+
103+
torch.set_rng_state(torch_rng_state)
104+
random.setstate(builtin_rng_state)
105+
np.random.set_state(nunmpy_rng_state)
106+
if torch.cuda.is_available():
107+
torch.cuda.set_rng_state(cuda_rng_state)

test/test_transforms_tensor.py

+1
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,7 @@ def test_random_apply(device):
714714
@pytest.mark.parametrize('channels', [1, 3])
715715
def test_gaussian_blur(device, channels, meth_kwargs):
716716
tol = 1.0 + 1e-10
717+
torch.manual_seed(12)
717718
_test_class_op(
718719
T.GaussianBlur, meth_kwargs=meth_kwargs, channels=channels,
719720
test_exact_match=False, device=device, agg_method="max", tol=tol

0 commit comments

Comments
 (0)