@@ -1,6 +1,7 @@
 import datetime
 import os
 import time
+import warnings
 
 import presets
 import torch
@@ -50,6 +51,7 @@ def evaluate(model, criterion, data_loader, device):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = "Test:"
+    num_processed_samples = 0
     with torch.inference_mode():
         for video, target in metric_logger.log_every(data_loader, 100, header):
             video = video.to(device, non_blocking=True)
@@ -64,7 +66,28 @@ def evaluate(model, criterion, data_loader, device):
             metric_logger.update(loss=loss.item())
             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            num_processed_samples += batch_size
     # gather the stats from all processes
+    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+    if isinstance(data_loader.sampler, DistributedSampler):
+        # Get the len of UniformClipSampler inside DistributedSampler
+        num_data_from_sampler = len(data_loader.sampler.dataset)
+    else:
+        num_data_from_sampler = len(data_loader.sampler)
+
+    if (
+        hasattr(data_loader.dataset, "__len__")
+        and num_data_from_sampler != num_processed_samples
+        and torch.distributed.get_rank() == 0
+    ):
+        # See FIXME above
+        warnings.warn(
+            f"It looks like the sampler has {num_data_from_sampler} samples, but {num_processed_samples} "
+            "samples were used for the validation, which might bias the results. "
+            "Try adjusting the batch size and / or the world size. "
+            "Setting the world size to 1 is always a safe bet."
+        )
+
     metric_logger.synchronize_between_processes()
 
     print(
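Note on the hunk above: `utils.reduce_across_processes` is a helper from the repo's `utils.py` and its body is not shown in this diff. A minimal sketch of what such a helper typically does, assuming `torch.distributed` with a CUDA backend, is:

    import torch
    import torch.distributed as dist

    def reduce_across_processes(val):
        # Illustrative sketch, not the repo's exact implementation.
        if not (dist.is_available() and dist.is_initialized()):
            # Single-process run: nothing to synchronize.
            return torch.tensor(val)
        t = torch.tensor(val, device="cuda")
        dist.barrier()
        dist.all_reduce(t)  # default reduce op is SUM
        return t

After the reduction, every rank holds the global count of validated samples, which the warning compares against the sampler length to detect duplicated or dropped clips.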
@@ -99,7 +122,11 @@ def main(args):
 
     device = torch.device(args.device)
 
-    torch.backends.cudnn.benchmark = True
+    if args.use_deterministic_algorithms:
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+    else:
+        torch.backends.cudnn.benchmark = True
 
     # Data loading code
     print("Loading data")
@@ -173,7 +200,7 @@ def main(args):
     test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
     if args.distributed:
         train_sampler = DistributedSampler(train_sampler)
-        test_sampler = DistributedSampler(test_sampler)
+        test_sampler = DistributedSampler(test_sampler, shuffle=False)
 
     data_loader = torch.utils.data.DataLoader(
         dataset,
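Note on the hunk above: `DistributedSampler` shuffles by default, but the wrapped `UniformClipSampler` already defines the intended clip order for evaluation, so `shuffle=False` keeps that order. Wrapping one sampler in another works because `DistributedSampler` only needs the length of its first argument, which it stores as `.dataset` (this is what the `data_loader.sampler.dataset` lookup in `evaluate` relies on). A toy illustration of the rank split with `shuffle=False`:

    from torch.utils.data import DistributedSampler

    # Passing num_replicas/rank explicitly avoids needing an initialized
    # process group; each rank gets a strided slice of the indices.
    data = list(range(8))
    for rank in range(2):
        sampler = DistributedSampler(data, num_replicas=2, rank=rank, shuffle=False)
        print(rank, list(sampler))
    # 0 [0, 2, 4, 6]
    # 1 [1, 3, 5, 7]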
@@ -248,6 +275,9 @@ def main(args):
             scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
+        # We disable the cudnn benchmarking because it can noticeably affect the accuracy
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
         evaluate(model, criterion, data_loader_test, device=device)
         return
 
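Note on the hunk above: in the `--test-only` path the stricter cuDNN settings are forced unconditionally, since an accuracy measurement should be reproducible. `torch.backends.cudnn.benchmark = True` lets cuDNN time several convolution algorithms and keep the fastest, and the selected algorithm can differ numerically between runs; `torch.backends.cudnn.deterministic = True` restricts cuDNN to deterministic algorithms. A quick illustrative check (assumes a CUDA device):

    import torch

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    conv = torch.nn.Conv2d(3, 8, 3).cuda()
    x = torch.randn(1, 3, 32, 32, device="cuda")
    with torch.inference_mode():
        # Deterministic kernels: repeated runs are bit-identical.
        assert torch.equal(conv(x), conv(x))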
@@ -335,6 +365,9 @@ def parse_args():
         help="Only test the model",
         action="store_true",
     )
+    parser.add_argument(
+        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
+    )
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")