@@ -381,15 +381,19 @@ def get_transform(train):
381
381
# Before iterating over the dataset, it's good to see what the model
382
382
# expects during training and inference time on sample data.
383
383
import utils
384
+ import platform
384
385
386
+ # In Windows and MacOS, the number of workers have to be set to 0 if we are using notebook
387
+ # We can set num_workers > 0 if we are using a python script and wrapping the code in if __name__ == "__main__": block
388
+ num_workers = 0 if platform .system () in ('Windows' , 'Darwin' ) else 4
385
389
386
390
model = torchvision .models .detection .fasterrcnn_resnet50_fpn (weights = "DEFAULT" )
387
391
dataset = PennFudanDataset ('data/PennFudanPed' , get_transform (train = True ))
388
392
data_loader = torch .utils .data .DataLoader (
389
393
dataset ,
390
394
batch_size = 2 ,
391
395
shuffle = True ,
392
- num_workers = 4 ,
396
+ num_workers = num_workers ,
393
397
collate_fn = utils .collate_fn
394
398
)
395
399
@@ -428,20 +432,22 @@ def get_transform(train):
428
432
dataset = torch .utils .data .Subset (dataset , indices [:- 50 ])
429
433
dataset_test = torch .utils .data .Subset (dataset_test , indices [- 50 :])
430
434
435
+ num_workers = 0 if platform .system () in ('Windows' , 'Darwin' ) else 4
436
+
431
437
# define training and validation data loaders
432
438
data_loader = torch .utils .data .DataLoader (
433
439
dataset ,
434
440
batch_size = 2 ,
435
441
shuffle = True ,
436
- num_workers = 4 ,
442
+ num_workers = num_workers ,
437
443
collate_fn = utils .collate_fn
438
444
)
439
445
440
446
data_loader_test = torch .utils .data .DataLoader (
441
447
dataset_test ,
442
448
batch_size = 1 ,
443
449
shuffle = False ,
444
- num_workers = 4 ,
450
+ num_workers = num_workers ,
445
451
collate_fn = utils .collate_fn
446
452
)
447
453
0 commit comments