Skip to content

Commit c35d385

Browse files
kit1980pmeierNicolasHug
authored
[TorchFix] Add weights_only to torch.load (#8105)
Co-authored-by: Philip Meier <github.pmeier@posteo.de> Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
1 parent 01dca0e commit c35d385

File tree

18 files changed

+29
-26
lines changed

18 files changed

+29
-26
lines changed

references/classification/train.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def load_data(traindir, valdir, args):
127127
if args.cache_dataset and os.path.exists(cache_path):
128128
# Attention, as the transforms are also cached!
129129
print(f"Loading dataset_train from {cache_path}")
130-
dataset, _ = torch.load(cache_path)
130+
# TODO: this could probably be weights_only=True
131+
dataset, _ = torch.load(cache_path, weights_only=False)
131132
else:
132133
# We need a default value for the variables below because args may come
133134
# from train_quantization.py which doesn't define them.
@@ -159,7 +160,8 @@ def load_data(traindir, valdir, args):
159160
if args.cache_dataset and os.path.exists(cache_path):
160161
# Attention, as the transforms are also cached!
161162
print(f"Loading dataset_test from {cache_path}")
162-
dataset_test, _ = torch.load(cache_path)
163+
# TODO: this could probably be weights_only=True
164+
dataset_test, _ = torch.load(cache_path, weights_only=False)
163165
else:
164166
if args.weights and args.test_only:
165167
weights = torchvision.models.get_weight(args.weights)
@@ -337,7 +339,7 @@ def collate_fn(batch):
337339
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
338340

339341
if args.resume:
340-
checkpoint = torch.load(args.resume, map_location="cpu")
342+
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
341343
model_without_ddp.load_state_dict(checkpoint["model"])
342344
if not args.test_only:
343345
optimizer.load_state_dict(checkpoint["optimizer"])

references/classification/train_quantization.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def main(args):
7474
model_without_ddp = model.module
7575

7676
if args.resume:
77-
checkpoint = torch.load(args.resume, map_location="cpu")
77+
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
7878
model_without_ddp.load_state_dict(checkpoint["model"])
7979
optimizer.load_state_dict(checkpoint["optimizer"])
8080
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

references/classification/utils.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,7 @@ def average_checkpoints(inputs):
287287
for fpath in inputs:
288288
with open(fpath, "rb") as f:
289289
state = torch.load(
290-
f,
291-
map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")),
290+
f, map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), weights_only=True
292291
)
293292
# Copies over the settings from the first checkpoint
294293
if new_state is None:
@@ -367,7 +366,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
367366

368367
# Deep copy to avoid side effects on the model object.
369368
model = copy.deepcopy(model)
370-
checkpoint = torch.load(checkpoint_path, map_location="cpu")
369+
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
371370

372371
# Load the weights to the model to validate that everything works
373372
# and remove unnecessary weights (such as auxiliaries, etc.)

references/depth/stereo/cascade_evaluation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def load_checkpoint(args):
262262
utils.setup_ddp(args)
263263

264264
if not args.weights:
265-
checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"))
265+
checkpoint = torch.load(args.checkpoint, map_location=torch.device("cpu"), weights_only=True)
266266
if "model" in checkpoint:
267267
experiment_args = checkpoint["args"]
268268
model = torchvision.prototype.models.depth.stereo.__dict__[experiment_args.model](weights=None)

references/depth/stereo/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ def main(args):
498498
# load them from checkpoint if needed
499499
args.start_step = 0
500500
if args.resume_path is not None:
501-
checkpoint = torch.load(args.resume_path, map_location="cpu")
501+
checkpoint = torch.load(args.resume_path, map_location="cpu", weights_only=True)
502502
if "model" in checkpoint:
503503
# this means the user requested to resume from a training checkpoint
504504
model_without_ddp.load_state_dict(checkpoint["model"])

references/detection/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def main(args):
288288
)
289289

290290
if args.resume:
291-
checkpoint = torch.load(args.resume, map_location="cpu")
291+
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
292292
model_without_ddp.load_state_dict(checkpoint["model"])
293293
optimizer.load_state_dict(checkpoint["optimizer"])
294294
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

references/optical_flow/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def main(args):
226226
model_without_ddp = model
227227

228228
if args.resume is not None:
229-
checkpoint = torch.load(args.resume, map_location="cpu")
229+
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
230230
model_without_ddp.load_state_dict(checkpoint["model"])
231231

232232
if args.test_only:

references/segmentation/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def main(args):
223223
lr_scheduler = main_lr_scheduler
224224

225225
if args.resume:
226-
checkpoint = torch.load(args.resume, map_location="cpu")
226+
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
227227
model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only)
228228
if not args.test_only:
229229
optimizer.load_state_dict(checkpoint["optimizer"])

references/similarity/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def main(args):
101101

102102
model = EmbeddingNet()
103103
if args.resume:
104-
model.load_state_dict(torch.load(args.resume))
104+
model.load_state_dict(torch.load(args.resume, weights_only=True))
105105

106106
model.to(device)
107107

references/video_classification/train.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def main(args):
164164

165165
if args.cache_dataset and os.path.exists(cache_path):
166166
print(f"Loading dataset_train from {cache_path}")
167-
dataset, _ = torch.load(cache_path)
167+
dataset, _ = torch.load(cache_path, weights_only=True)
168168
dataset.transform = transform_train
169169
else:
170170
if args.distributed:
@@ -201,7 +201,7 @@ def main(args):
201201

202202
if args.cache_dataset and os.path.exists(cache_path):
203203
print(f"Loading dataset_test from {cache_path}")
204-
dataset_test, _ = torch.load(cache_path)
204+
dataset_test, _ = torch.load(cache_path, weights_only=True)
205205
dataset_test.transform = transform_test
206206
else:
207207
if args.distributed:
@@ -295,7 +295,7 @@ def main(args):
295295
model_without_ddp = model.module
296296

297297
if args.resume:
298-
checkpoint = torch.load(args.resume, map_location="cpu")
298+
checkpoint = torch.load(args.resume, map_location="cpu", weights_only=True)
299299
model_without_ddp.load_state_dict(checkpoint["model"])
300300
optimizer.load_state_dict(checkpoint["optimizer"])
301301
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

test/test_functional_tensor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1024,7 +1024,8 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
10241024
# "23_23_1.7": ...
10251025
# }
10261026
p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt")
1027-
true_cv2_results = torch.load(p)
1027+
1028+
true_cv2_results = torch.load(p, weights_only=False)
10281029

10291030
if image_size == "small":
10301031
tensor = (

test/test_models.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def _assert_expected(output, name, prec=None, atol=None, rtol=None):
149149
if binary_size > MAX_PICKLE_SIZE:
150150
raise RuntimeError(f"The output for {filename}, is larger than 50kb - got {binary_size}kb")
151151
else:
152-
expected = torch.load(expected_file)
152+
expected = torch.load(expected_file, weights_only=True)
153153
rtol = rtol or prec # keeping prec param for legacy reason, but could be removed ideally
154154
atol = atol or prec
155155
torch.testing.assert_close(output, expected, rtol=rtol, atol=atol, check_dtype=False, check_device=False)
@@ -747,7 +747,7 @@ def check_out(out):
747747
# so instead of validating the probability scores, check that the class
748748
# predictions match.
749749
expected_file = _get_expected_file(model_name)
750-
expected = torch.load(expected_file)
750+
expected = torch.load(expected_file, weights_only=True)
751751
torch.testing.assert_close(
752752
out.argmax(dim=1), expected.argmax(dim=1), rtol=prec, atol=prec, check_device=False
753753
)
@@ -847,7 +847,7 @@ def compute_mean_std(tensor):
847847
# as in NMSTester.test_nms_cuda to see if this is caused by duplicate
848848
# scores.
849849
expected_file = _get_expected_file(model_name)
850-
expected = torch.load(expected_file)
850+
expected = torch.load(expected_file, weights_only=True)
851851
torch.testing.assert_close(
852852
output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False
853853
)

test/test_prototype_datasets_builtin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def test_save_load(self, dataset_mock, config):
215215
with io.BytesIO() as buffer:
216216
torch.save(sample, buffer)
217217
buffer.seek(0)
218-
assert_samples_equal(torch.load(buffer), sample)
218+
assert_samples_equal(torch.load(buffer, weights_only=True), sample)
219219

220220
@parametrize_dataset_mocks(DATASET_MOCKS)
221221
def test_infinite_buffer_size(self, dataset_mock, config):

test/test_transforms_v2.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3176,7 +3176,8 @@ def test__get_params(self, sigma):
31763176
# "26_28_1__23_23_1.7": cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7),
31773177
# }
31783178
REFERENCE_GAUSSIAN_BLUR_IMAGE_RESULTS = torch.load(
3179-
Path(__file__).parent / "assets" / "gaussian_blur_opencv_results.pt"
3179+
Path(__file__).parent / "assets" / "gaussian_blur_opencv_results.pt",
3180+
weights_only=False,
31803181
)
31813182

31823183
@pytest.mark.parametrize(

test/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def test_flow_to_image(batch):
375375
assert img.shape == (2, 3, h, w) if batch else (3, h, w)
376376

377377
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "expected_flow.pt")
378-
expected_img = torch.load(path, map_location="cpu")
378+
expected_img = torch.load(path, map_location="cpu", weights_only=True)
379379

380380
if batch:
381381
expected_img = torch.stack([expected_img, expected_img])

torchvision/datasets/imagenet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str
8484
file = os.path.join(root, file)
8585

8686
if check_integrity(file):
87-
return torch.load(file)
87+
return torch.load(file, weights_only=True)
8888
else:
8989
msg = (
9090
"The meta file {} is not present in the root directory or is corrupted. "

torchvision/datasets/mnist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _load_legacy_data(self):
116116
# This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data
117117
# directly.
118118
data_file = self.training_file if self.train else self.test_file
119-
return torch.load(os.path.join(self.processed_folder, data_file))
119+
return torch.load(os.path.join(self.processed_folder, data_file), weights_only=True)
120120

121121
def _load_data(self):
122122
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"

torchvision/datasets/phototour.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def __init__(
106106
self.cache()
107107

108108
# load the serialized data
109-
self.data, self.labels, self.matches = torch.load(self.data_file)
109+
self.data, self.labels, self.matches = torch.load(self.data_file, weights_only=True)
110110

111111
def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[Any, Any, torch.Tensor]]:
112112
"""

0 commit comments

Comments
 (0)