Commit 1b5d762

Jessica Lin, James Reed, ilia-cher, IvanKobzarev, Shen Li authored

Release/1.6 (#1087)
* Add TorchScript fork/join tutorial
* Add note about zipfile format in serialization tutorial
* Profiler recipe (#1019)
* Profiler recipe Summary: Adding a recipe for profiler Test Plan: make html-noplot
* [mobile] Mobile Perf Recipe
* Minor syntax edits to mobile perf recipe
* Remove built files
* [android] android native app recipe
* [mobile_perf][recipe] Add ChannelsLast recommendation
* Adding distributed pipeline parallel tutorial
* Add async execution tutorials
* Fix code block in pipeline tutorial
* Adding an Overview Page for PyTorch Distributed (#1056)
* Adding an Overview Page for PyTorch Distributed
* Let existing PT Distributed tutorials link to the overview page
* Add a link to AMP
* Address Comments
* Remove unnecessary dist.barrier()
* [Mobile Perf Recipe] Add the benchmarking part for iOS (#1055)
* [Mobile Perf Recipe] Add the benchmarking part for iOS
* [Mobile Perf Recipe] Add the benchmarking part for iOS Co-authored-by: Jessica Lin <jplin@fb.com>
* RPC profiling recipe (#1068)
* Initial commit
* Update
* Complete most of recipe
* Add image
* Link image
* Remove extra file
* update
* Update
* update
* Push latest changes from master into release/1.6 (#1074)
* Update feature classification labels
* Update NVidia -> Nvidia
* Bring back default filename_pattern so that by default we run all galleries. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Add prototype_source directory
* Add prototype directory
* Add prototype
* Remove extra "done"
* Add REAME.txt
* Update for prototype instructions
* Update for prototype feature
* refine torchvision_tutorial doc for windows
* Update neural_style_tutorial.py (#1059) Updated the mistake in the Loading Images Section.
* torch_script_custom_ops restructure (#1057) Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Port custom ops tutorial to new registration API, increase testability. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Kill some other occurrences of RegisterOperators Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Update README.md
* Make torch_script_custom_classes tutorial runnable I also fixed some warnings in the tutorial, and fixed some minor bitrot (e.g., torch::script::Module to torch::jit::Module) I also added some missing quotes around some bash expansions. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Update torch_script_custom_classes to use TORCH_LIBRARY (#1062) Signed-off-by: Edward Z. Yang <ezyang@fb.com> Co-authored-by: Edward Z. Yang <ezyang@fb.com> Co-authored-by: Yang Gu <yangu@microsoft.com> Co-authored-by: Hritik Bhandari <bhandari.hritik@gmail.com>
* Tutorial for DDP + RPC (#1071)
* Update feature classification labels
* Update NVidia -> Nvidia
* Bring back default filename_pattern so that by default we run all galleries. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Tutorial for DDP + RPC. Summary: Based on example from pytorch/examples#800
* Add to main section Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
* Added separate code file and used literalinclude Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Co-authored-by: Jessica Lin <jplin@fb.com> Co-authored-by: Edward Z. Yang <ezyang@fb.com> Co-authored-by: pritam <pritam.damania@fb.com>
* Make RPC profiling recipe into prototype tutorial (#1078)
* Add RPC tutorial
* Update to include recipes
* Add Graph Mode Dynamic Quant tutorial (#1065)
* Update feature classification labels
* Update NVidia -> Nvidia
* Bring back default filename_pattern so that by default we run all galleries. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Add prototype_source directory
* Add prototype directory
* Add prototype
* Remove extra "done"
* Add REAME.txt
* Update for prototype instructions
* Update for prototype feature
* refine torchvision_tutorial doc for windows
* Update neural_style_tutorial.py (#1059) Updated the mistake in the Loading Images Section.
* torch_script_custom_ops restructure (#1057) Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Port custom ops tutorial to new registration API, increase testability. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Kill some other occurrences of RegisterOperators Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Update README.md
* Make torch_script_custom_classes tutorial runnable I also fixed some warnings in the tutorial, and fixed some minor bitrot (e.g., torch::script::Module to torch::jit::Module) I also added some missing quotes around some bash expansions. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Update torch_script_custom_classes to use TORCH_LIBRARY (#1062) Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Add Graph Mode Dynamic Quant tutorial Summary: Tutorial to demonstrate graph mode dynamic quant on BERT model. Currently not directly runnable as it requires to download glue dataset and fine-tuned model Co-authored-by: Jessica Lin <jplin@fb.com> Co-authored-by: Edward Z. Yang <ezyang@fb.com> Co-authored-by: Yang Gu <yangu@microsoft.com> Co-authored-by: Hritik Bhandari <bhandari.hritik@gmail.com>
* Add mobile recipes images
* Update mobile recipe index
* Remove RPC Profiling recipe from index
* 1.6 model freezing tutorial (#1077)
* Update feature classification labels
* Update NVidia -> Nvidia
* Bring back default filename_pattern so that by default we run all galleries. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Add prototype_source directory
* Add prototype directory
* Add prototype
* Remove extra "done"
* Add REAME.txt
* Update for prototype instructions
* Update for prototype feature
* refine torchvision_tutorial doc for windows
* Update neural_style_tutorial.py (#1059) Updated the mistake in the Loading Images Section.
* torch_script_custom_ops restructure (#1057) Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Port custom ops tutorial to new registration API, increase testability. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Kill some other occurrences of RegisterOperators Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Update README.md
* Make torch_script_custom_classes tutorial runnable I also fixed some warnings in the tutorial, and fixed some minor bitrot (e.g., torch::script::Module to torch::jit::Module) I also added some missing quotes around some bash expansions. Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Update torch_script_custom_classes to use TORCH_LIBRARY (#1062) Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Add Model Freezing in TorchScript Co-authored-by: Edward Z. Yang <ezyang@fb.com> Co-authored-by: Yang Gu <yangu@microsoft.com> Co-authored-by: Hritik Bhandari <bhandari.hritik@gmail.com>
* Update title
* Update recipes_index.rst Touch for rebuild.
* Update dcgan_faces_tutorial.py Update labels to be floats to work around torch.full inference change.

Co-authored-by: James Reed <jamesreed@fb.com>
Co-authored-by: ilia-cher <30845429+ilia-cher@users.noreply.github.com>
Co-authored-by: Ivan Kobzarev <ivankobzarev@fb.com>
Co-authored-by: Shen Li <shenli@devfair017.maas>
Co-authored-by: Shen Li <cs.shenli@gmail.com>
Co-authored-by: Tao Xu <taox@fb.com>
Co-authored-by: Rohan Varma <rvarm1@fb.com>
Co-authored-by: Edward Z. Yang <ezyang@fb.com>
Co-authored-by: Yang Gu <yangu@microsoft.com>
Co-authored-by: Hritik Bhandari <bhandari.hritik@gmail.com>
Co-authored-by: Pritam Damania <9958665+pritamdamania87@users.noreply.github.com>
Co-authored-by: pritam <pritam.damania@fb.com>
Co-authored-by: supriyar <supriyar@fb.com>
Co-authored-by: Brian Johnson <brianjo@fb.com>
Co-authored-by: gchanan <gchanan@fb.com>
1 parent 11569e0 commit 1b5d762

33 files changed: +4075 -16 lines

_static/img/rpc-images/batch.png (19.7 KB)

_static/img/rpc_trace_img.png (307 KB)

(four additional image files, names not captured: 15.9 KB, 17.6 KB, 23.5 KB, 34.9 KB)

_static/img/trace_img.png (134 KB)

advanced_source/rpc_ddp_tutorial.rst

+159
@@ -0,0 +1,159 @@
Combining Distributed DataParallel with Distributed RPC Framework
=================================================================
**Author**: `Pritam Damania <https://github.com/pritamdamania87>`_


This tutorial uses a simple example to demonstrate how you can combine
`DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__ (DDP)
with the `Distributed RPC framework <https://pytorch.org/docs/master/rpc.html>`__
to combine distributed data parallelism with distributed model parallelism to
train a simple model. Source code of the example can be found `here <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc>`__.

Previous tutorials,
`Getting Started With Distributed Data Parallel <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
and `Getting Started with Distributed RPC Framework <https://pytorch.org/tutorials/intermediate/rpc_tutorial.html>`__,
described how to perform distributed data parallel and distributed model
parallel training respectively. However, there are several training paradigms
where you might want to combine these two techniques. For example:

1) If we have a model with a sparse part (large embedding table) and a dense
   part (FC layers), we might want to put the embedding table on a parameter
   server and replicate the FC layer across multiple trainers using `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.
   The `Distributed RPC framework <https://pytorch.org/docs/master/rpc.html>`__
   can be used to perform embedding lookups on the parameter server.
2) Enable hybrid parallelism as described in the `PipeDream <https://arxiv.org/abs/1806.03377>`__ paper.
   We can use the `Distributed RPC framework <https://pytorch.org/docs/master/rpc.html>`__
   to pipeline stages of the model across multiple workers and replicate each
   stage (if needed) using `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.

|
In this tutorial we will cover case 1 mentioned above. We have a total of 4
workers in our setup, as follows:


1) 1 Master, which is responsible for creating an embedding table
   (nn.EmbeddingBag) on the parameter server. The master also drives the
   training loop on the two trainers.
2) 1 Parameter Server, which holds the embedding table in memory and
   responds to RPCs from the Master and Trainers.
3) 2 Trainers, which store an FC layer (nn.Linear) that is replicated amongst
   themselves using `DistributedDataParallel <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__.
   The trainers are also responsible for executing the forward pass, backward
   pass and optimizer step.

|
The entire training process is executed as follows:

1) The master creates an embedding table on the Parameter Server and holds an
   `RRef <https://pytorch.org/docs/master/rpc.html#rref>`__ to it.
2) The master then kicks off the training loop on the trainers and passes the
   embedding table RRef to the trainers.
3) The trainers create a ``HybridModel`` which first performs an embedding lookup
   using the embedding table RRef provided by the master and then executes the
   FC layer which is wrapped inside DDP.
4) The trainer executes the forward pass of the model and uses the loss to
   execute the backward pass using `Distributed Autograd <https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework>`__.
5) As part of the backward pass, the gradients for the FC layer are computed
   first and synced to all trainers via allreduce in DDP.
6) Next, Distributed Autograd propagates the gradients to the parameter server,
   where the gradients for the embedding table are updated.
7) Finally, the `Distributed Optimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__ is used to update all the parameters.


.. attention::

  You should always use `Distributed Autograd <https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework>`__
  for the backward pass if you're combining DDP and RPC.
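To make the note above concrete, here is a minimal sketch (not part of this commit) of what the backward pass looks like under Distributed Autograd; ``model`` and ``criterion`` stand in for the objects built later in this tutorial:

.. code:: python

   # Minimal sketch: with DDP + RPC, run backward inside a distributed autograd
   # context via dist_autograd.backward() rather than calling loss.backward().
   import torch.distributed.autograd as dist_autograd

   def distributed_backward(model, criterion, indices, offsets, target):
       with dist_autograd.context() as context_id:
           loss = criterion(model(indices, offsets), target)
           # Propagates gradients over RPC to the parameter server and lets DDP
           # allreduce the gradients of the locally wrapped FC layer.
           dist_autograd.backward(context_id, [loss])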
Now, let's go through each part in detail. First, we need to set up all of our
workers before we can perform any training. We create 4 processes such that
ranks 0 and 1 are our trainers, rank 2 is the master and rank 3 is the
parameter server.

We initialize the RPC framework on all 4 workers using the TCP init_method.
Once RPC initialization is done, the master creates an `EmbeddingBag <https://pytorch.org/docs/master/generated/torch.nn.EmbeddingBag.html>`__
on the Parameter Server using `rpc.remote <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.remote>`__.
The master then loops through each trainer and kicks off the training loop by
calling ``_run_trainer`` on each trainer using `rpc_async <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.rpc_async>`__.
Finally, the master waits for all training to finish before exiting.

The trainers first initialize a ``ProcessGroup`` for DDP with world_size=2
(for two trainers) using `init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__.
Next, they initialize the RPC framework using the TCP init_method. Note that
the ports used for RPC initialization and ProcessGroup initialization are
different, which avoids port conflicts between the two frameworks. Once the
initialization is done, the trainers simply wait for the ``_run_trainer``
RPC from the master.
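For concreteness, here is a condensed sketch of just this trainer-side initialization (``rank`` is assumed to be the trainer's rank, 0 or 1; addresses and ports mirror the full example included below):

.. code:: python

   # Sketch of the trainer-side setup: DDP's ProcessGroup and the RPC framework
   # each need their own rendezvous endpoint, so two different ports are used.
   import os
   import torch.distributed as dist
   import torch.distributed.rpc as rpc
   from torch.distributed.rpc import ProcessGroupRpcBackendOptions

   os.environ['MASTER_ADDR'] = 'localhost'
   os.environ['MASTER_PORT'] = '29500'   # used by init_process_group (env:// rendezvous)

   dist.init_process_group(backend="gloo", rank=rank, world_size=2)

   rpc_backend_options = ProcessGroupRpcBackendOptions()
   rpc_backend_options.init_method = 'tcp://localhost:29501'   # separate port for RPC
   rpc.init_rpc("trainer{}".format(rank), rank=rank, world_size=4,
                rpc_backend_options=rpc_backend_options)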
The parameter server just initializes the RPC framework and waits for RPCs from
the trainers and master.


.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py
   :language: py
   :start-after: BEGIN run_worker
   :end-before: END run_worker

Before we discuss details of the Trainer, let's introduce the ``HybridModel`` that
the trainer uses. As described below, the ``HybridModel`` is initialized using an
RRef to the embedding table (``emb_rref``) on the parameter server and the ``device``
to use for DDP. The initialization of the model wraps an
`nn.Linear <https://pytorch.org/docs/master/generated/torch.nn.Linear.html>`__
layer inside DDP to replicate and synchronize this layer across all trainers.

The forward method of the model is pretty straightforward. It performs an
embedding lookup on the parameter server using an
`RRef helper <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.RRef.rpc_sync>`__
and passes its output on to the FC layer.


.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py
   :language: py
   :start-after: BEGIN hybrid_model
   :end-before: END hybrid_model
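The ``rpc_sync()`` RRef helper used in the forward method is essentially shorthand for a blocking RPC to the RRef's owner. A rough equivalent, sketched here purely for illustration (the helper function below is not part of the example):

.. code:: python

   # Rough expansion of emb_rref.rpc_sync().forward(indices, offsets): run the
   # embedding lookup on the RRef's owner (the parameter server) and block until
   # the result arrives.
   def _remote_embedding_lookup(emb_rref, indices, offsets):
       # Runs on the owner, where local_value() returns the actual EmbeddingBag.
       return emb_rref.local_value()(indices, offsets)

   emb_lookup = rpc.rpc_sync(
       emb_rref.owner(), _remote_embedding_lookup, args=(emb_rref, indices, offsets))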
Next, let's look at the setup on the Trainer. The trainer first creates the
``HybridModel`` described above using an RRef to the embedding table on the
parameter server and its own rank.

Now, we need to retrieve a list of RRefs to all the parameters that we would
like to optimize with `DistributedOptimizer <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__.
To retrieve the parameters for the embedding table from the parameter server,
we define a simple helper function ``_retrieve_embedding_parameters``, which
walks through all the parameters of the embedding table and returns
a list of RRefs. The trainer calls this method on the parameter server via RPC
to receive a list of RRefs to the desired parameters. Since the
DistributedOptimizer always takes a list of RRefs to parameters that need to
be optimized, we need to create RRefs even for the local parameters of our
FC layers. This is done by walking ``model.parameters()``, creating an RRef for
each parameter and appending it to a list. Note that ``model.parameters()`` only
returns local parameters and doesn't include ``emb_rref``.

Finally, we create our DistributedOptimizer using all the RRefs and define a
CrossEntropyLoss function.

.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py
   :language: py
   :start-after: BEGIN setup_trainer
   :end-before: END setup_trainer

Now we're ready to introduce the main training loop that is run on each trainer.
``get_next_batch`` is just a helper function to generate random inputs and
targets for training. We run the training loop for multiple epochs and for each
batch:

1) Set up a `Distributed Autograd Context <https://pytorch.org/docs/master/rpc.html#torch.distributed.autograd.context>`__
   for Distributed Autograd.
2) Run the forward pass of the model and retrieve its output.
3) Compute the loss based on our outputs and targets using the loss function.
4) Use Distributed Autograd to execute a distributed backward pass using the loss.
5) Finally, run a Distributed Optimizer step to optimize all the parameters.

.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py
   :language: py
   :start-after: BEGIN run_trainer
   :end-before: END run_trainer
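One detail worth calling out: the ``offsets`` built by ``get_next_batch`` follow ``nn.EmbeddingBag``'s convention, where each offset marks the start of a bag within the flat ``indices`` tensor. A tiny standalone example (the values are arbitrary and used only for illustration):

.. code:: python

   # Standalone illustration of the EmbeddingBag indices/offsets convention.
   import torch

   emb = torch.nn.EmbeddingBag(100, 16, mode="sum")
   indices = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
   offsets = torch.LongTensor([0, 4])    # two bags: indices[0:4] and indices[4:8]
   print(emb(indices, offsets).shape)    # torch.Size([2, 16])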
Source code for the entire example can be found `here <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc>`__.
advanced_source/rpc_ddp_tutorial/main.py

+191

@@ -0,0 +1,191 @@
import os
from functools import wraps

import random
import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.distributed.rpc import ProcessGroupRpcBackendOptions
import torch.multiprocessing as mp
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef
from torch.nn.parallel import DistributedDataParallel as DDP

NUM_EMBEDDINGS = 100
EMBEDDING_DIM = 16

# BEGIN hybrid_model
class HybridModel(torch.nn.Module):
    r"""
    The model consists of a sparse part and a dense part. The dense part is an
    nn.Linear module that is replicated across all trainers using
    DistributedDataParallel. The sparse part is an nn.EmbeddingBag that is
    stored on the parameter server.

    The model holds a Remote Reference to the embedding table on the parameter
    server.
    """

    def __init__(self, emb_rref, device):
        super(HybridModel, self).__init__()
        self.emb_rref = emb_rref
        self.fc = DDP(torch.nn.Linear(16, 8).cuda(device), device_ids=[device])
        self.device = device

    def forward(self, indices, offsets):
        emb_lookup = self.emb_rref.rpc_sync().forward(indices, offsets)
        return self.fc(emb_lookup.cuda(self.device))
# END hybrid_model

# BEGIN setup_trainer
def _retrieve_embedding_parameters(emb_rref):
    param_rrefs = []
    for param in emb_rref.local_value().parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs


def _run_trainer(emb_rref, rank):
    r"""
    Each trainer runs a forward pass which involves an embedding lookup on the
    parameter server and running nn.Linear locally. During the backward pass,
    DDP is responsible for aggregating the gradients for the dense part
    (nn.Linear) and distributed autograd ensures gradient updates are
    propagated to the parameter server.
    """

    # Setup the model.
    model = HybridModel(emb_rref, rank)

    # Retrieve all model parameters as rrefs for DistributedOptimizer.

    # Retrieve parameters for embedding table.
    model_parameter_rrefs = rpc.rpc_sync(
        "ps", _retrieve_embedding_parameters, args=(emb_rref,))

    # model.parameters() only includes local parameters.
    for param in model.parameters():
        model_parameter_rrefs.append(RRef(param))

    # Setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model_parameter_rrefs,
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()
    # END setup_trainer

    # BEGIN run_trainer
    def get_next_batch(rank):
        for _ in range(10):
            num_indices = random.randint(20, 50)
            indices = torch.LongTensor(num_indices).random_(0, NUM_EMBEDDINGS)

            # Generate offsets.
            offsets = []
            start = 0
            batch_size = 0
            while start < num_indices:
                offsets.append(start)
                start += random.randint(1, 10)
                batch_size += 1

            offsets_tensor = torch.LongTensor(offsets)
            target = torch.LongTensor(batch_size).random_(8).cuda(rank)
            yield indices, offsets_tensor, target

    # Train for 100 epochs
    for epoch in range(100):
        # create distributed autograd context
        for indices, offsets, target in get_next_batch(rank):
            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer
                opt.step(context_id)

                # Not necessary to zero grads as each iteration creates a different
                # distributed autograd context which hosts different grads
        print("Training done for epoch {}".format(epoch))
    # END run_trainer


# BEGIN run_worker
def run_worker(rank, world_size):
    r"""
    A wrapper function that initializes RPC, calls the function, and shuts down
    RPC.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    rpc_backend_options = ProcessGroupRpcBackendOptions()
    rpc_backend_options.init_method = 'tcp://localhost:29501'

    # Rank 2 is master, 3 is ps and 0 and 1 are trainers.
    if rank == 2:
        rpc.init_rpc(
            "master",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)

        # Build the embedding table on the ps.
        emb_rref = rpc.remote(
            "ps",
            torch.nn.EmbeddingBag,
            args=(NUM_EMBEDDINGS, EMBEDDING_DIM),
            kwargs={"mode": "sum"})

        # Run the training loop on trainers.
        futs = []
        for trainer_rank in [0, 1]:
            trainer_name = "trainer{}".format(trainer_rank)
            # Pass each trainer its own rank so it uses its own GPU.
            fut = rpc.rpc_async(
                trainer_name, _run_trainer, args=(emb_rref, trainer_rank))
            futs.append(fut)

        # Wait for all training to finish.
        for fut in futs:
            fut.wait()
    elif rank <= 1:
        # Initialize process group for Distributed DataParallel on trainers.
        dist.init_process_group(
            backend="gloo", rank=rank, world_size=2)

        # Initialize RPC.
        trainer_name = "trainer{}".format(rank)
        rpc.init_rpc(
            trainer_name,
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)

        # Trainer just waits for RPCs from master.
    else:
        rpc.init_rpc(
            "ps",
            rank=rank,
            world_size=world_size,
            rpc_backend_options=rpc_backend_options)
        # The parameter server does nothing; it just responds to RPCs.
        pass

    # Block until all RPCs finish.
    rpc.shutdown()


if __name__ == "__main__":
    # 2 trainers, 1 parameter server, 1 master.
    world_size = 4
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
# END run_worker
