diff --git a/tests/data/pytorch_mnist/mnist.py b/tests/data/pytorch_mnist/mnist.py index bb318ba..0feeeab 100644 --- a/tests/data/pytorch_mnist/mnist.py +++ b/tests/data/pytorch_mnist/mnist.py @@ -169,7 +169,7 @@ def model_fn(model_dir): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = torch.nn.DataParallel(Net()) with open(os.path.join(model_dir, 'model.pth'), 'rb') as f: - model.load_state_dict(torch.load(f)) + model.load_state_dict(torch.load(f, weights_only=True)) return model.to(device)