Commit dc448c2

Authored by ChanBong, svekars, NicolasHug, and awgu

Add image for better explanation to FSDP tutorial (#2644)

* Add image for better explanation
* Edit explanation for fsdp sharding

---------

Co-authored-by: Svetlana Karslioglu <svekars@meta.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Andrew Gu <31054793+awgu@users.noreply.github.com>

1 parent f05f050

File tree

2 files changed: +9 −0 lines changed

_static/img/distributed/fsdp_sharding.png (new binary image, 91 KB)
intermediate_source/FSDP_tutorial.rst (+9 lines)
@@ -46,6 +46,15 @@ At a high level FSDP works as follows:
 * Run reduce_scatter to sync gradients
 * Discard parameters.
 
+One way to view FSDP's sharding is to decompose the DDP gradient all-reduce into a reduce-scatter followed by an all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank holds a shard of the summed gradients. It then updates the corresponding shard of the parameters in the optimizer step. Finally, in the subsequent forward pass, it performs an all-gather to collect and combine the updated parameter shards.
+
+.. figure:: /_static/img/distributed/fsdp_sharding.png
+    :width: 100%
+    :align: center
+    :alt: FSDP allreduce
+
+    FSDP Allreduce
+
 How to use FSDP
 ---------------
 Here we use a toy model to run training on the MNIST dataset for demonstration purposes. The APIs and logic can be applied to training larger models as well.
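To make the decomposition described in the new paragraph concrete, here is a minimal sketch (not part of this commit) expressed with torch.distributed collectives: an all-reduce of a gradient is equivalent to a reduce-scatter followed by an all-gather. It assumes an already-initialized process group (e.g. launched via torchrun) and a recent PyTorch release that provides reduce_scatter_tensor and all_gather_into_tensor; the function name is ours, for illustration only.

    import torch
    import torch.distributed as dist

    def allreduce_via_reduce_scatter(grad: torch.Tensor) -> torch.Tensor:
        world_size = dist.get_world_size()
        flat = grad.contiguous().view(-1)
        assert flat.numel() % world_size == 0, "sketch assumes even sharding"

        # Backward pass: reduce-scatter sums gradients across ranks and
        # leaves each rank with one shard of the summed gradient.
        shard = torch.empty(flat.numel() // world_size,
                            dtype=flat.dtype, device=flat.device)
        dist.reduce_scatter_tensor(shard, flat, op=dist.ReduceOp.SUM)

        # Optimizer step: each rank would update only its parameter shard here.

        # Next forward pass: all-gather reassembles the full, updated tensor
        # from the per-rank shards.
        full = torch.empty_like(flat)
        dist.all_gather_into_tensor(full, shard)
        return full.view_as(grad)

Run on every rank, the returned tensor equals what a single dist.all_reduce(grad) would produce, which is exactly the equivalence the new figure illustrates.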

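For the "How to use FSDP" section that follows in the tutorial, a minimal usage sketch looks like the following. The tiny MLP is a placeholder of ours rather than the tutorial's MNIST model, and it assumes the process group is initialized and each rank has selected its CUDA device.

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Placeholder toy model; the tutorial itself uses a small MNIST network.
    model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).cuda()
    model = FSDP(model)  # parameters are now sharded across ranks

    # Construct the optimizer after wrapping, so it references the sharded parameters.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training then proceeds as usual (forward, backward, optimizer.step());
    # FSDP issues the all-gather and reduce-scatter collectives automatically.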