# Support of Torch Distributed API in PyTorch/XLA

Before the 2.5 release, PyTorch/XLA only supported collective ops through our custom API, `torch_xla.core.xla_model.*`. In the 2.5 release, we adopt `torch.distributed.*` in PyTorch/XLA for both Dynamo and non-Dynamo cases.

## Collective ops lowering

### Collective ops lowering stack

After the introduction of the [traceable collective communication APIs](https://github.com/pytorch/pytorch/issues/93173), Dynamo can support collective ops by reimplementing their lowering in PyTorch/XLA. Collective ops are only traceable through `torch.ops._c10d_functional` calls. The figure below shows how a collective op, `all_reduce` in this case, is lowered between torch and torch_xla:

<img src="../_static/img/dist_op_stack.png" alt="Collective ops lowering stack" width="500" height="400">

_<span style="text-decoration:underline;">Figure 1. Collective ops lowering stack</span>_

### non-Dynamo case

Collective ops are lowered by registering `ProcessGroupXla`, which is derived from `ProcessGroup`:

```Python
# torch_xla/distributed/xla_backend.py (excerpt; imports shown for context)
import torch.distributed as dist
import torch_xla.runtime as xr
from torch.distributed import ProcessGroup


def _create_xla_process_group(prefix_store, rank, size, timeout):
  assert not xr.is_spmd(
  ), "XLA backend is not supported with SPMD. Please use a CPU process group instead."
  return ProcessGroupXla(prefix_store, rank, size, timeout)


def _register_xla_backend():
  dist.Backend.register_backend('xla', _create_xla_process_group, devices='xla')


class ProcessGroupXla(ProcessGroup):
  ...

  def allreduce(self, tensors, all_reduce_options):
    ...

  def allgather(self, output_tensors_list, input_tensors, opts=None):
    ...
```

The corresponding XLA dist backend is initialized when we call:

```Python
def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
```

In this way, collective ops are dispatched to the process group instance:

```Python
# E.g., pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
@_exception_logger
def all_gather(tensor_list, tensor, group=None, async_op=False):
  ...
  group = group or _get_default_group()
  work = group.allgather([tensor_list], [tensor])  # uses ProcessGroupXla.allgather instead
```
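
Putting the pieces together, a minimal end-to-end sketch of the non-Dynamo path might look like the following. The tensor values, the use of `torch_xla.launch` as the multiprocess launcher, and the `xm.mark_step()` call are illustrative assumptions rather than part of the excerpts above:

```Python
# Illustrative sketch: run a torch.distributed collective on XLA devices
# in the non-Dynamo path.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def _mp_fn(rank):
  # Registers ProcessGroupXla as the backend for the default process group.
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()

  # Each process contributes its ordinal; after all_reduce every process
  # holds the sum over all participating devices.
  t = torch.tensor([xr.global_ordinal()], dtype=torch.float32, device=device)
  dist.all_reduce(t, op=dist.ReduceOp.SUM)  # dispatched to ProcessGroupXla.allreduce
  xm.mark_step()
  print(f"rank {rank}: {t.cpu()}")


if __name__ == '__main__':
  torch_xla.launch(_mp_fn)
```
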
### Dynamo case

In the Dynamo case, certain collective ops are remapped to the new functions in [pytorch/torch/distributed/_functional_collectives.py](https://github.com/pytorch/pytorch/blob/v2.5.0-rc10/torch/distributed/_functional_collectives.py#L1129-L1150). For example, `all_reduce()` is mapped to `all_reduce_inplace()`, which eventually calls `torch.ops._c10d_functional.all_reduce()`. Once we reach `_c10d_functional`, we can rewrite the op through the PyTorch/XLA lowering:

```C++
// Register the XLA implementation of all_reduce for the _c10d_functional library.
at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
                      std::string /*group_name*/) {...}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
  m.impl("all_reduce", all_reduce);
}
```
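
For illustration, a minimal sketch of the Dynamo path compiles a function containing a collective with the `openxla` backend. The launcher, tensor values, and the `fullgraph=True` flag are assumptions for this example, not requirements stated above:

```Python
# Illustrative sketch: all_reduce is traced by Dynamo through
# torch.ops._c10d_functional and lowered by PyTorch/XLA.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm


def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()

  def reduce_fn(t):
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t

  compiled = torch.compile(reduce_fn, backend="openxla", fullgraph=True)
  t = torch.ones(4, device=device)
  print(f"rank {rank}: {compiled(t).cpu()}")


if __name__ == '__main__':
  torch_xla.launch(_mp_fn)
```
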
## API description

As of release 2.5, we support four collective operations for both Dynamo and non-Dynamo cases. Our goal is to align the distributed operation (dist op) APIs with PyTorch's upstream implementation. While the function signatures remain consistent, certain input restrictions still apply.
For instance, specifying multiple groups for distributed collective operations is not yet supported. For usage examples, refer to [test_collective_ops_tpu.py](https://github.com/pytorch/xla/blob/v2.5.0-rc10/test/pjrt/test_collective_ops_tpu.py), which demonstrates the use of dist ops in both Dynamo and non-Dynamo scenarios.

Below are the details for each operation:

```Python
dist.all_reduce(input: torch.Tensor, op: dist.ReduceOp = ReduceOp.SUM)
```

`all_reduce` performs an in-place reduction on the `input` tensor by aggregating data from all nodes.

```Python
dist.all_gather_into_tensor(output, input)
```

`all_gather_into_tensor` gathers the input tensor from all nodes and updates the `output` tensor in-place. It also returns an alias of the output.

```Python
dist.reduce_scatter_tensor(output, input, op: dist.ReduceOp = ReduceOp.SUM)
```

`reduce_scatter_tensor` reduces the input tensor across all nodes and distributes the result to the `output` tensor in-place. It returns an alias of the output.

```Python
dist.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None)
```

`all_to_all_single` performs an all-to-all communication, updating the `output` tensor in-place and returning its alias.

Note: Although `output_split_sizes` and `input_split_sizes` are accepted as arguments, they must be either `None` or set to all 1s. This limitation reflects a compromise between maintaining PyTorch's API signature and the constraints of the XLA AllToAll operation.
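
As a quick reference, the sketch below exercises all four operations in the non-Dynamo path. It is an illustrative example rather than an excerpt from the test file above; the launcher, tensor shapes, and values are assumptions:

```Python
# Illustrative sketch of the four supported dist ops on XLA devices.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()
  world = xr.world_size()

  # all_reduce: in-place sum over all devices.
  t = torch.ones(4, device=device)
  dist.all_reduce(t, op=dist.ReduceOp.SUM)

  # all_gather_into_tensor: output is `world` times larger than the input.
  inp = torch.ones(4, device=device) * rank
  out = torch.zeros(4 * world, device=device)
  dist.all_gather_into_tensor(out, inp)

  # reduce_scatter_tensor: input is `world` times larger than the output;
  # each device receives its reduced shard.
  inp = torch.ones(4 * world, device=device)
  out = torch.zeros(4, device=device)
  dist.reduce_scatter_tensor(out, inp, op=dist.ReduceOp.SUM)

  # all_to_all_single: split sizes left as None per the note above.
  inp = torch.arange(world, dtype=torch.float32, device=device)
  out = torch.zeros(world, dtype=torch.float32, device=device)
  dist.all_to_all_single(out, inp)

  xm.mark_step()


if __name__ == '__main__':
  torch_xla.launch(_mp_fn)
```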
