diff --git a/_static/img/thumbnails/cropped/Distributed-Pipeline-Parallelism-Using-RPC.png b/_static/img/thumbnails/cropped/Distributed-Pipeline-Parallelism-Using-RPC.png
new file mode 100644
index 00000000000..426a14d98f5
Binary files /dev/null and b/_static/img/thumbnails/cropped/Distributed-Pipeline-Parallelism-Using-RPC.png differ
diff --git a/index.rst b/index.rst
index 81113c919ed..22921f61b35 100644
--- a/index.rst
+++ b/index.rst
@@ -332,6 +332,13 @@ Welcome to PyTorch Tutorials
:link: intermediate/rpc_param_server_tutorial.html
:tags: Parallel-and-Distributed-Training
+.. customcarditem::
+ :header: Distributed Pipeline Parallelism Using RPC
+ :card_description: Demonstrate how to implement distributed pipeline parallelism using RPC
+ :image: _static/img/thumbnails/cropped/Distributed-Pipeline-Parallelism-Using-RPC.png
+ :link: intermediate/dist_pipeline_parallel_tutorial.html
+ :tags: Parallel-and-Distributed-Training
+
.. End of tutorial card section
.. raw:: html
@@ -497,3 +504,4 @@ Additional Resources
intermediate/rpc_tutorial
beginner/aws_distributed_training_tutorial
intermediate/rpc_param_server_tutorial
+ intermediate/dist_pipeline_parallel_tutorial
diff --git a/intermediate_source/dist_pipeline_parallel_tutorial.rst b/intermediate_source/dist_pipeline_parallel_tutorial.rst
new file mode 100644
index 00000000000..ef7df000508
--- /dev/null
+++ b/intermediate_source/dist_pipeline_parallel_tutorial.rst
@@ -0,0 +1,370 @@
+Distributed Pipeline Parallelism Using RPC
+==========================================
+**Author**: `Shen Li `_
+
+Prerequisites:
+
+- `Single-Machine Model Parallel Best Practices `__
+- `Getting started with Distributed RPC Framework `__
+- RRef helper functions:
+ `RRef.rpc_sync() `__,
+ `RRef.rpc_async() `__, and
+ `RRef.remote() `__
+
+
+
+This tutorial uses a Resnet50 model to demonstrate implementing distributed
+pipeline parallelism with `torch.distributed.rpc `__
+APIs. This can be viewed as the distributed counterpart of the multi-GPU
+pipeline parallelism discussed in
+`Single-Machine Model Parallel Best Practices `_.
+
+.. note:: This tutorial requires PyTorch v1.6.0 or above.
+
+.. note:: Full source code of this tutorial can be found at
+ `pytorch/examples `__.
+
+Basics
+------
+
+
+The previous tutorial, `Getting Started with Distributed RPC Framework `_
+shows how to use `torch.distributed.rpc `_
+to implement distributed model parallelism for an RNN model. That tutorial uses
+one GPU to host the ``EmbeddingTable``, and the provided code works fine.
+However, if a model lives on multiple GPUs, it would require some extra steps to
+increase the amortized utilization of all GPUs. Pipeline parallelism is one type
+of paradigm that can help in this case.
+
+In this tutorial, we use ``ResNet50`` as an example model which is also used by
+the `Single-Machine Model Parallel Best Practices `_
+tutorial. Similarly, the ``ResNet50`` model is divided into two shards and
+the input batch is partitioned into multiple splits and fed into the two model
+shards in a pipelined fashion. The difference is that, instead of parallelizing
+the execution using CUDA streams, this tutorial invokes asynchronous RPCs. So,
+the solution presented in this tutorial also works across machine boundaries.
+The remainder of this tutorial presents the implementation in four steps.
+
+
+
+Step 1: Partition ResNet50 Model
+--------------------------------
+
+This is the preparation step which implements ``ResNet50`` in two model shards.
+The code below is borrowed from the
+`ResNet implementation in torchvision `_.
+The ``ResNetBase`` module contains the common building blocks and attributes for
+the two ResNet shards.
+
+
+.. code:: python
+ import threading
+
+ import torch
+ import torch.nn as nn
+
+ from torchvision.models.resnet import Bottleneck
+
+ num_classes = 1000
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+ class ResNetBase(nn.Module):
+ def __init__(self, block, inplanes, num_classes=1000,
+ groups=1, width_per_group=64, norm_layer=None):
+ super(ResNetBase, self).__init__()
+
+ self._lock = threading.Lock()
+ self._block = block
+ self._norm_layer = nn.BatchNorm2d
+ self.inplanes = inplanes
+ self.dilation = 1
+ self.groups = groups
+ self.base_width = width_per_group
+
+ def _make_layer(self, planes, blocks, stride=1):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if stride != 1 or self.inplanes != planes * self._block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * self._block.expansion, stride),
+ norm_layer(planes * self._block.expansion),
+ )
+
+ layers = []
+ layers.append(self._block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * self._block.expansion
+ for _ in range(1, blocks):
+ layers.append(self._block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def parameter_rrefs(self):
+ return [RRef(p) for p in self.parameters()]
+
+
+Now, we are ready to define the two model shards. For the constructor, we
+simply split all ResNet50 layers into two parts and move each part into the
+provided device. The ``forward`` functions of both shards take an ``RRef`` of
+the input data, fetch the data locally, and then move it to the expected device.
+After applying all layers to the input, it moves the output to CPU and returns.
+It is because the RPC API requires tensors to reside on CPU to avoid invalid
+device errors when the numbers of devices in the caller and the callee do not
+match.
+
+
+.. code:: python
+
+ class ResNetShard1(ResNetBase):
+ def __init__(self, device, *args, **kwargs):
+ super(ResNetShard1, self).__init__(
+ Bottleneck, 64, num_classes=num_classes, *args, **kwargs)
+
+ self.device = device
+ self.seq = nn.Sequential(
+ nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
+ self._norm_layer(self.inplanes),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
+ self._make_layer(64, 3),
+ self._make_layer(128, 4, stride=2)
+ ).to(self.device)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x_rref):
+ x = x_rref.to_here().to(self.device)
+ with self._lock:
+ out = self.seq(x)
+ return out.cpu()
+
+
+ class ResNetShard2(ResNetBase):
+ def __init__(self, device, *args, **kwargs):
+ super(ResNetShard2, self).__init__(
+ Bottleneck, 512, num_classes=num_classes, *args, **kwargs)
+
+ self.device = device
+ self.seq = nn.Sequential(
+ self._make_layer(256, 6, stride=2),
+ self._make_layer(512, 3, stride=2),
+ nn.AdaptiveAvgPool2d((1, 1)),
+ ).to(self.device)
+
+ self.fc = nn.Linear(512 * self._block.expansion, num_classes).to(self.device)
+
+ def forward(self, x_rref):
+ x = x_rref.to_here().to(self.device)
+ with self._lock:
+ out = self.fc(torch.flatten(self.seq(x), 1))
+ return out.cpu()
+
+
+Step 2: Stitch ResNet50 Model Shards Into One Module
+----------------------------------------------------
+
+
+Then, we create a ``DistResNet50`` module to assemble the two shards and
+implement the pipeline parallel logic. In the constructor, we use two
+``rpc.remote`` calls to put the two shards on two different RPC workers
+respectively and hold on to the ``RRef`` to the two model parts so that they
+can be referenced in the forward pass. The ``forward`` function
+splits the input batch into multiple micro-batches, and feeds these
+micro-batches to the two model parts in a pipelined fashion. It first uses an
+``rpc.remote`` call to apply the first shard to a micro-batch and then forwards
+the returned intermediate output ``RRef`` to the second model shard. After that,
+it collects the ``Future`` of all micro-outputs, and waits for all of them after
+the loop. Note that both ``remote()`` and ``rpc_async()`` return immediately and
+run asynchronously. Therefore, the entire loop is non-blocking, and will launch
+multiple RPCs concurrently. The execution order of one micro-batch on two model
+parts are preserved by intermediate output ``y_rref``. The execution order
+across micro-batches does not matter. In the end, the forward function
+concatenates outputs of all micro-batches into one single output tensor and
+returns. The ``parameter_rrefs`` function is a helper to
+simplify distributed optimizer construction, which will be used later.
+
+
+
+.. code:: python
+
+ class DistResNet50(nn.Module):
+ def __init__(self, num_split, workers, *args, **kwargs):
+ super(DistResNet50, self).__init__()
+
+ self.num_split = num_split
+
+ # Put the first part of the ResNet50 on workers[0]
+ self.p1_rref = rpc.remote(
+ workers[0],
+ ResNetShard1,
+ args = ("cuda:0",) + args,
+ kwargs = kwargs
+ )
+
+ # Put the second part of the ResNet50 on workers[1]
+ self.p2_rref = rpc.remote(
+ workers[1],
+ ResNetShard2,
+ args = ("cuda:1",) + args,
+ kwargs = kwargs
+ )
+
+ def forward(self, xs):
+ out_futures = []
+ for x in iter(xs.split(self.split_size, dim=0)):
+ x_rref = RRef(x)
+ y_rref = self.p1_rref.remote().forward(x_rref)
+ z_fut = self.p2_rref.rpc_async().forward(y_rref)
+ out_futures.append(z_fut)
+
+ return torch.cat(torch.futures.wait_all(out_futures))
+
+ def parameter_rrefs(self):
+ remote_params = []
+ remote_params.extend(self.p1_rref.remote().parameter_rrefs().to_here())
+ remote_params.extend(self.p2_rref.remote().parameter_rrefs().to_here())
+ return remote_params
+
+
+Step 3: Define The Training Loop
+--------------------------------
+
+
+After defining the model, let us implement the training loop. We use a
+dedicated "master" worker to prepare random inputs and labels, and control the
+distributed backward pass and distributed optimizer step. It first creates an
+instance of the ``DistResNet50`` module. It specifies the number of
+micro-batches for each batch, and also provides the name of the two RPC workers
+(i.e., "worker1", and "worker2"). Then it defines the loss function and creates
+a ``DistributedOptimizer`` using the ``parameter_rrefs()`` helper to acquire a
+list of parameter ``RRefs``. Then, the main training loop is very similar to
+regular local training, except that it uses ``dist_autograd`` to launch
+backward and provides the ``context_id`` for both backward and optimizer
+``step()``.
+
+
+.. code:: python
+
+ import torch.distributed.autograd as dist_autograd
+ import torch.optim as optim
+ from torch.distributed.optim import DistributedOptimizer
+
+ num_batches = 3
+ batch_size = 120
+ image_w = 128
+ image_h = 128
+
+
+ def run_master(num_split):
+ # put the two model parts on worker1 and worker2 respectively
+ model = DistResNet50(num_split, ["worker1", "worker2"])
+ loss_fn = nn.MSELoss()
+ opt = DistributedOptimizer(
+ optim.SGD,
+ model.parameter_rrefs(),
+ lr=0.05,
+ )
+
+ one_hot_indices = torch.LongTensor(batch_size) \
+ .random_(0, num_classes) \
+ .view(batch_size, 1)
+
+ for i in range(num_batches):
+ print(f"Processing batch {i}")
+ # generate random inputs and labels
+ inputs = torch.randn(batch_size, 3, image_w, image_h)
+ labels = torch.zeros(batch_size, num_classes) \
+ .scatter_(1, one_hot_indices, 1)
+
+ with dist_autograd.context() as context_id:
+ outputs = model(inputs)
+ dist_autograd.backward(context_id, [loss_fn(outputs, labels)])
+ opt.step(context_id)
+
+
+Step 4: Launch RPC Processes
+----------------------------
+
+
+Finally, the code below shows the target function for all processes. The main
+logic is defined in ``run_master``. The workers passively waiting for
+commands from the master, and hence simply runs ``init_rpc`` and ``shutdown``,
+where the ``shutdown`` by default will block until all RPC participants finish.
+
+.. code:: python
+
+ import os
+ import time
+
+ import torch.multiprocessing as mp
+
+
+ def run_worker(rank, world_size, num_split):
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = '29500'
+ options = rpc.ProcessGroupRpcBackendOptions(num_send_recv_threads=128)
+
+ if rank == 0:
+ rpc.init_rpc(
+ "master",
+ rank=rank,
+ world_size=world_size,
+ rpc_backend_options=options
+ )
+ run_master(num_split)
+ else:
+ rpc.init_rpc(
+ f"worker{rank}",
+ rank=rank,
+ world_size=world_size,
+ rpc_backend_options=options
+ )
+ pass
+
+ # block until all rpcs finish
+ rpc.shutdown()
+
+
+ if __name__=="__main__":
+ world_size = 3
+ for num_split in [1, 2, 4, 8]:
+ tik = time.time()
+ mp.spawn(run_worker, args=(world_size, num_split), nprocs=world_size, join=True)
+ tok = time.time()
+ print(f"number of splits = {num_split}, execution time = {tok - tik}")
+
+
+The output below shows the speedup attained by increasing the number of splits
+in each batch.
+
+::
+
+ $ python main.py
+ Processing batch 0
+ Processing batch 1
+ Processing batch 2
+ number of splits = 1, execution time = 16.45062756538391
+ Processing batch 0
+ Processing batch 1
+ Processing batch 2
+ number of splits = 2, execution time = 12.329529762268066
+ Processing batch 0
+ Processing batch 1
+ Processing batch 2
+ number of splits = 4, execution time = 10.164430618286133
+ Processing batch 0
+ Processing batch 1
+ Processing batch 2
+ number of splits = 8, execution time = 9.076049566268921
\ No newline at end of file