Skip to content

Commit 3a15b6a

Browse files
committed
Fix multiprocessing in torchvision_tutorial.py
1 parent 3607e8a commit 3a15b6a

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

Diff for: intermediate_source/torchvision_tutorial.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -381,15 +381,19 @@ def get_transform(train):
381381
# Before iterating over the dataset, it's good to see what the model
382382
# expects during training and inference time on sample data.
383383
import utils
384+
import platform
384385

386+
# In Windows and MacOS, the number of workers have to be set to 0 if we are using notebook
387+
# We can set num_workers > 0 if we are using a python script and wrapping the code in if __name__ == "__main__": block
388+
num_workers = 0 if platform.system() in ('Windows', 'Darwin') else 4
385389

386390
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
387391
dataset = PennFudanDataset('data/PennFudanPed', get_transform(train=True))
388392
data_loader = torch.utils.data.DataLoader(
389393
dataset,
390394
batch_size=2,
391395
shuffle=True,
392-
num_workers=4,
396+
num_workers=num_workers,
393397
collate_fn=utils.collate_fn
394398
)
395399

@@ -428,20 +432,22 @@ def get_transform(train):
428432
dataset = torch.utils.data.Subset(dataset, indices[:-50])
429433
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])
430434

435+
num_workers = 0 if platform.system() in ('Windows', 'Darwin') else 4
436+
431437
# define training and validation data loaders
432438
data_loader = torch.utils.data.DataLoader(
433439
dataset,
434440
batch_size=2,
435441
shuffle=True,
436-
num_workers=4,
442+
num_workers=num_workers,
437443
collate_fn=utils.collate_fn
438444
)
439445

440446
data_loader_test = torch.utils.data.DataLoader(
441447
dataset_test,
442448
batch_size=1,
443449
shuffle=False,
444-
num_workers=4,
450+
num_workers=num_workers,
445451
collate_fn=utils.collate_fn
446452
)
447453

0 commit comments

Comments
 (0)