From c9a1ddf634bdc4df1dfc4a5126e83e984096cb45 Mon Sep 17 00:00:00 2001 From: Anthony Ting Date: Mon, 14 Apr 2025 16:17:24 -0700 Subject: [PATCH] fix: pass weights_only=True to torch.load --- tests/data/pytorch_mnist/mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/pytorch_mnist/mnist.py b/tests/data/pytorch_mnist/mnist.py index bb318ba..0feeeab 100644 --- a/tests/data/pytorch_mnist/mnist.py +++ b/tests/data/pytorch_mnist/mnist.py @@ -169,7 +169,7 @@ def model_fn(model_dir): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = torch.nn.DataParallel(Net()) with open(os.path.join(model_dir, 'model.pth'), 'rb') as f: - model.load_state_dict(torch.load(f)) + model.load_state_dict(torch.load(f, weights_only=True)) return model.to(device)