Multi-Node
1st approach (Torchrun on multiple machines)
- Example with two nodes
torchrun \
--nproc_per_node=$num_gpus_for_node_i \ ## doesn't have to be the same for all nodes
--nnodes=2 \
--node_rank={0,1} \ ## depends on which machine we're launching this from
## rendezvous arguments
--rdzv_id=456 \ ## any random number, shared by all nodes
--rdzv_backend=c10d \ ## recommended backend
--rdzv_endpoint=IP_address:port \ ## of any of the participant nodes (ideally one with high networking bandwidth)
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
2nd approach (SLURM)
#!/bin/bash
#SBATCH --job-name=multinode-example
#SBATCH --nodes=4
#SBATCH --ntasks=4
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=4
nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
head_node=${nodes[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
echo "Node IP: $head_node_ip"
export LOGLEVEL=INFO
srun torchrun \
--nnodes 4 \
--nproc_per_node 1 \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node_ip:29500 \
/shared/examples/multinode_torchrun.py 50 10
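To launch the job, submit this script with sbatch (the filename here is just an assumed example):
sbatch multinode_torchrun.sbatch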
Torchrun
Usage
torchrun
--standalone
--nnodes=1
--nproc-per-node=$NUM_TRAINERS
YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
Arguments
- --nnodes, total number of nodes
- --nproc-per-node, number of GPUs per node
- --standalone, for single-node multi-worker setups
Implicit rank and world size
- No need to pass rank and world_size manually anymore
- The local rank is now accessed via rank = int(os.environ["LOCAL_RANK"])
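torchrun also exports the global rank and the world size; a small sketch of reading all three (these environment variables are set by torchrun for every worker process):
import os

local_rank = int(os.environ["LOCAL_RANK"])   ## GPU index on this node
global_rank = int(os.environ["RANK"])        ## unique rank across all nodes
world_size = int(os.environ["WORLD_SIZE"])   ## total number of worker processes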
Recovering from a snapshot
- Check for an existing snapshot at the start of your training script and load it if present
class Trainer:
    def __init__(self, model, snapshot_path):
        self.model = model
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            self._load_snapshot(snapshot_path)

    def _load_snapshot(self, snapshot_path):
        snapshot = torch.load(snapshot_path)
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        ## recover anything else of interest (epochs run, optimizer state, ...)

    def _save_snapshot(self, snapshot_path):
        snapshot = {"MODEL_STATE": self.model.state_dict()}
        ## add anything else you want to restore later
        torch.save(snapshot, snapshot_path)
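A possible way to call this from the training loop (a sketch; run_epoch, max_epochs and save_every are assumed names; the snapshot is usually written by rank 0 only so a single process touches the file):
import os

def train(trainer, max_epochs, save_every):
    for epoch in range(max_epochs):
        run_epoch(trainer, epoch)  ## assumed per-epoch training function
        ## usually only the rank 0 process writes the snapshot file
        if int(os.environ["RANK"]) == 0 and epoch % save_every == 0:
            trainer._save_snapshot(trainer.snapshot_path)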
General Structure of a torchrun training script
from torch.distributed.elastic.multiprocessing.errors import record

@record
def main():
    load_checkpoint(checkpoint_path)
    initialize()
    train()

def train():
    for batch in iter(dataset):
        train_step(batch)
        if should_checkpoint:
            save_checkpoint(checkpoint_path)
General
Launching the distributed training
- If not using torchrun
    - We will use torch.multiprocessing
    - Given a training script entry point of this form: fsdp_main(rank, world_size, args)
    - We call: if __name__ == "__main__": mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
    - The rank argument is automatically created by mp.spawn and fed to fsdp_main as its first argument (see the sketch after this list)
- If using torchrun
    - just launch fsdp_main()
    - torchrun will take care of launching multiple processes and setting the correct environment variables for each
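A minimal, self-contained launcher sketch for the non-torchrun case (the body of fsdp_main is a stub; using the number of local GPUs as WORLD_SIZE is an assumption):
import torch
import torch.multiprocessing as mp

def fsdp_main(rank, world_size, args):
    ## rank is injected by mp.spawn as the first positional argument
    ## set up the process group here, then build the model and train
    pass

if __name__ == "__main__":
    WORLD_SIZE = torch.cuda.device_count()  ## one process per local GPU
    args = None                             ## placeholder for parsed CLI arguments
    mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)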
Constructing the process group
- Set up the process group at the beginning of the script
- world_size = total number of processes/GPUs in your training
- rank = unique identifier assigned to each process
- If not using torchrun
    - os.environ["MASTER_ADDR"] = IP address of the machine running the rank 0 process
        - single-machine setup ⇒ use localhost
    - os.environ["MASTER_PORT"] = a free port on the machine running the rank 0 process
    - torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        - sets up the process group and its communication primitives
- If using torchrun
    - everything is set up by torchrun, including rank and world_size
    - you just need to run torch.distributed.init_process_group(backend="nccl")
- Both cases are sketched below
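A small setup sketch covering both cases (localhost and port 12355 are assumptions for a single-machine run):
import os
import torch.distributed as dist

def setup(rank, world_size):
    ## non-torchrun case: tell every process where the rank 0 process lives
    os.environ["MASTER_ADDR"] = "localhost"  ## single-machine assumption
    os.environ["MASTER_PORT"] = "12355"      ## any free port
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

def setup_torchrun():
    ## torchrun case: MASTER_ADDR/PORT, RANK and WORLD_SIZE are already in the environment
    dist.init_process_group(backend="nccl")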
Destroying the process group
- At the end of the training script
- torch.distributed.destroy_process_group()
Creating the data loader
- just need to add the argument sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
    - DistributedSampler ensures each data-parallel process gets a distinct, evenly sized slice of the dataset
- and set shuffle=False in the DataLoader, because we're passing a sampler
- Total code
train_loader = torch.utils.data.DataLoader(
dataset,
num_workers = 2,
pin_memory = True,
shuffle = False,
batch_size = args.batch_size,
sampler = sampler)
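One caveat worth adding: with a DistributedSampler, the shuffling only changes between epochs if you call set_epoch at the start of each epoch, e.g.:
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  ## makes the shuffle order differ across epochs
    for batch in train_loader:
        ...  ## training step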
Wrapping the model
torch.cuda.set_device(rank)
model = Net().to(rank) ## IMPORTANT TO DO BEFORE CALLING FSDP
model = FSDP(model)
- PyTorch DDP uses autograd hooks, registered at construction time (when wrapping the model in DDP), to trigger gradient synchronization during .backward()
- When this is done, the averaged gradients are written to the param.grad field of every parameter
- This mechanism is agnostic to the communication framework (NCCL, …)
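For comparison, a minimal DDP wrapping sketch (assuming a torchrun launch so LOCAL_RANK is set, and reusing the same Net model as above):
import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = Net().to(local_rank)                      ## same toy model as the FSDP example
ddp_model = DDP(model, device_ids=[local_rank])   ## registers the autograd hooks mentioned above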