diff --git a/intermediate_source/ddp_tutorial.rst b/intermediate_source/ddp_tutorial.rst index 366db8db13..a8955569df 100644 --- a/intermediate_source/ddp_tutorial.rst +++ b/intermediate_source/ddp_tutorial.rst @@ -269,8 +269,8 @@ either the application or the model ``forward()`` method. setup(rank, world_size) # setup mp_model and devices for this process - dev0 = (rank * 2) % world_size - dev1 = (rank * 2 + 1) % world_size + dev0 = rank * 2 + dev1 = rank * 2 + 1 mp_model = ToyMpModel(dev0, dev1) ddp_mp_model = DDP(mp_model) @@ -293,6 +293,7 @@ either the application or the model ``forward()`` method. world_size = n_gpus run_demo(demo_basic, world_size) run_demo(demo_checkpoint, world_size) + world_size = n_gpus//2 run_demo(demo_model_parallel, world_size) Initialize DDP with torch.distributed.run/torchrun