By Apoorv Khandelwal and Peter Curtin
The easiest way to run PyTorch on multiple GPUs or machines.
torchrunx is a functional utility for distributing PyTorch code across devices. It is a more convenient, robust, and feature-rich alternative to CLI-based launchers such as torchrun, accelerate launch, and deepspeed.
It enables complex workflows within a single script and has useful features even when using only one GPU.
pip install torchrunx
Requires: Linux. If using multiple machines: SSH & shared filesystem.
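For the multi-machine requirement above, it can help to confirm that each remote host accepts non-interactive SSH before launching anything. The following is a minimal sketch, not something torchrunx provides: `ssh_check_command` and `can_reach` are hypothetical helpers, and the only external pieces assumed are the standard OpenSSH client options `BatchMode` and `ConnectTimeout`. It does not check the shared filesystem.

```python
from __future__ import annotations

import subprocess


def ssh_check_command(hostname: str, timeout_s: int = 5) -> list[str]:
    """Build a non-interactive SSH command that runs `true` on the host.

    Hypothetical helper for illustration only; not part of torchrunx.
    """
    return [
        "ssh",
        "-o", "BatchMode=yes",  # fail instead of prompting for a password
        "-o", f"ConnectTimeout={timeout_s}",
        hostname,
        "true",
    ]


def can_reach(hostname: str) -> bool:
    """Return True if the SSH command above exits with status 0."""
    try:
        result = subprocess.run(
            ssh_check_command(hostname), capture_output=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        # ssh is not installed, or the connection hung
        return False
    return result.returncode == 0
```

Calling `can_reach("second_machine")` for each remote hostname before launching gives an early, readable failure if key-based SSH access is not set up.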
Suppose we have some distributed training function (needs to run on every GPU):
def distributed_training(output_dir: str, num_steps: int = 10) -> str:
    # returns path to model checkpoint
    ...
A full implementation:
from __future__ import annotations
import os
import torch
import torch.nn as nn
def distributed_training(output_dir: str, num_steps: int = 10) -> str | None:
    # Global rank and per-machine rank of this worker, read from the environment
    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    # Place the model on this worker's GPU and wrap it for data-parallel training
    model = nn.Linear(10, 10)
    model.to(local_rank)
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(ddp_model.parameters())
    # Train on random inputs and labels for num_steps steps
    for step in range(num_steps):
        optimizer.zero_grad()
        inputs = torch.randn(5, 10).to(local_rank)
        labels = torch.randn(5, 10).to(local_rank)
        outputs = ddp_model(inputs)
        torch.nn.functional.mse_loss(outputs, labels).backward()
        optimizer.step()
    # Only rank 0 saves the (unwrapped) model and returns the checkpoint path
    if rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        checkpoint_path = os.path.join(output_dir, "model.pt")
        torch.save(model, checkpoint_path)
        return checkpoint_path
    return None
We can distribute and run this function (e.g. on 2 machines x 2 GPUs) using torchrunx!
import logging
import torchrunx
logging.basicConfig(level=logging.INFO)
launcher = torchrunx.Launcher(
    hostnames = ["localhost", "second_machine"],  # or IP addresses
    workers_per_host = "gpu"  # default, or just: 2
)
results = launcher.run(
    distributed_training,
    output_dir = "outputs",
    num_steps = 10,
)
Once completed, you can retrieve the results and process them as you wish.
checkpoint_path: str = results.rank(0)
                 # or: results.index(hostname="localhost", local_rank=0)
# and continue your script
model = torch.load(checkpoint_path, weights_only=False)
model.eval()
See more examples where we fine-tune LLMs using:
Refer to our API, Features, and Usage for many more capabilities!
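The two lookup forms shown earlier, `results.rank(0)` and `results.index(hostname="localhost", local_rank=0)`, address the same worker in the 2 machines x 2 GPUs example. The sketch below illustrates why, under the assumption that global ranks are numbered host by host in the order the hostnames are listed (the usual convention for distributed launchers; not confirmed here for torchrunx). `rank_layout` is a hypothetical helper written for this illustration and is not part of torchrunx.

```python
from __future__ import annotations


def rank_layout(
    hostnames: list[str], workers_per_host: int
) -> dict[int, tuple[str, int]]:
    """Map each global rank to (hostname, local_rank).

    Hypothetical helper for illustration only; not part of torchrunx.
    Assumes ranks are numbered host by host, in the order hosts are listed.
    """
    layout = {}
    for host_index, hostname in enumerate(hostnames):
        for local_rank in range(workers_per_host):
            global_rank = host_index * workers_per_host + local_rank
            layout[global_rank] = (hostname, local_rank)
    return layout


layout = rank_layout(["localhost", "second_machine"], workers_per_host=2)
for rank, (hostname, local_rank) in sorted(layout.items()):
    print(f"rank {rank}: {hostname}, local_rank {local_rank}")
```

Under that assumption, rank 0 is the first worker on `localhost`, which is the worker that saves and returns the checkpoint path in `distributed_training`.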