Multi-Node

1st approach (Torchrun on multiple machines)

  • Example with two nodes
## run the same command on every node; only --node_rank (and possibly --nproc_per_node) changes
torchrun \
--nproc_per_node=$num_gpus_for_node{i} \ ## doesn't have to be the same for all nodes
--nnodes=2 \
--node_rank={0,1} \ ## depends on which machine we're launching this
## rendezvous arguments
--rdzv_id=456 \ ## any random number, shared by all nodes
--rdzv_backend=c10d \ ## recommended backend
--rdzv_endpoint=IP_address:port \ ## of any of the participant nodes (ideally one with high networking bandwidth)
YOUR_TRAINING_SCRIPT.py (script args)

2nd approach (SLURM)

#!/bin/bash

#SBATCH --job-name=multinode-example
#SBATCH --nodes=4
#SBATCH --ntasks=4
#SBATCH --gpus-per-task=1
#SBATCH --cpus-per-task=4

nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
head_node=${nodes[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo "Node IP: $head_node_ip"
export LOGLEVEL=INFO

srun torchrun \
--nnodes 4 \
--nproc_per_node 1 \
--rdzv_id $RANDOM \
--rdzv_backend c10d \
--rdzv_endpoint $head_node_ip:29500 \
/shared/examples/multinode_torchrun.py 50 10
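
  • Each srun task launches one torchrun on its own node; the shared rendezvous id and the head node's IP let the four launchers find each other. Submit the script with sbatch.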

Torchrun

Usage

torchrun
    --standalone
    --nnodes=1
    --nproc-per-node=$NUM_TRAINERS
    YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

Arguments

  • --nnodes, total number of nodes
  • --nproc-per-node, number of worker processes per node (typically one per GPU)
  • --standalone, for single-node multi-worker jobs; it starts a local rendezvous, so no --rdzv_* arguments are needed

Implicit rank and world size

  • No need to pass rank and world_size anymore
  • the local rank is read from the environment, e.g. local_rank = int(os.environ["LOCAL_RANK"]); the global rank and world size are available as RANK and WORLD_SIZE
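
A minimal sketch of reading these variables in a torchrun-launched script (the variable names are just illustrative):

import os
import torch

local_rank = int(os.environ["LOCAL_RANK"])  ## index of this process on its node; used to pick the GPU
global_rank = int(os.environ["RANK"])       ## unique index across all nodes
world_size = int(os.environ["WORLD_SIZE"])  ## total number of processes

torch.cuda.set_device(local_rank)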

Recovering from a snapshot

  • Place at the beginning of your training script
class Trainer:
	def __init__(self, model, snapshot_path):
		## some code
		self.model = model
		if os.path.exists(snapshot_path):
			self._load_snapshot(snapshot_path)

	def _load_snapshot(self, snapshot_path):
		snapshot = torch.load(snapshot_path, map_location="cpu")  ## load on CPU to avoid every process hitting GPU 0
		self.model.load_state_dict(snapshot["MODEL_STATE"])
		## recover anything else of interest (optimizer state, epochs already run, ...)

	def _save_snapshot(self, snapshot_path):
		snapshot = {"MODEL_STATE": self.model.state_dict()}
		## add anything else you need in order to resume
		torch.save(snapshot, snapshot_path)
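
A sketch of where the save might be called during training, continuing the Trainer above; only one process writes, so the nodes don't all save the same file (save_every, snapshot_path, self.gpu_id and _run_epoch are assumed names):

	def train(self, max_epochs):
		for epoch in range(max_epochs):
			self._run_epoch(epoch)                    ## assumed: runs one epoch of training
			if self.gpu_id == 0 and epoch % save_every == 0:
				self._save_snapshot(snapshot_path)    ## only rank 0 writes the snapshot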

General Structure of a torchrun training script

from torch.distributed.elastic.multiprocessing.errors import record

@record
def main():
  load_checkpoint(checkpoint_path)
  initialize()
  train()

def train():
  for batch in iter(dataset):
    train_step(batch)

    if should_checkpoint:
      save_checkpoint(checkpoint_path)
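
  • @record makes torchrun write the traceback of any uncaught exception in a worker to its error file, which makes failures in individual processes much easier to debug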

General

Launching the distributed training

  • If not using torchrun
    • We will use torch.multiprocessing
    • Given a training script whose entry point has the form fsdp_main(rank, world_size, args)
      • We call
    import torch.multiprocessing as mp

    if __name__ == "__main__":
        mp.spawn(fsdp_main,
                 args=(WORLD_SIZE, args),
                 nprocs=WORLD_SIZE,
                 join=True)
    
    • The rank argument is created automatically by mp.spawn and passed as the first argument to fsdp_main (see the sketch after this list)
  • If using torchrun
    • just launch fsdp_main()
    • torchrun will take care of launching multiple processes and set the correct environment variables for each
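
A minimal sketch of the spawned entry point (the body is only a placeholder; fsdp_main and args are the names used above):

def fsdp_main(rank, world_size, args):
    ## mp.spawn supplies rank (the process index) as the first argument
    print(f"Running process {rank} of {world_size}")
    ## set up the process group, wrap the model, run training ... (see the sections below)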

Constructing the process group

  • Set up the process group at the beginning of the script

  • world_size = total number of processes/GPUs in your training

  • rank = unique identifier assigned to each process (0 to world_size - 1)

  • If not using torchrun

    • os.environ["MASTER_ADDR"] = IP address of the machine running the rank 0 process

      • for a single-machine setup, use localhost
    • os.environ["MASTER_PORT"] = port of the machine running the rank 0 process

    • torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)

      • sets up the process group and its communication primitives
  • If using torchrun

    • everything is set up by torchrun, including rank and world_size
    • you just need to run torch.distributed.init_process_group(backend="nccl")
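
A minimal sketch of both paths (the port number and the helper names are illustrative):

import os
import torch.distributed as dist

def setup(rank, world_size):
    ## without torchrun: tell every process where the rank-0 process lives, then init
    os.environ["MASTER_ADDR"] = "localhost"  ## IP of the machine running rank 0; localhost on a single machine
    os.environ["MASTER_PORT"] = "12355"      ## any free port on that machine
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

def setup_with_torchrun():
    ## with torchrun: MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are already set as environment variables
    dist.init_process_group(backend="nccl")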

Destroying the process group

  • At the end of the training script
  • torch.distributed.destroy_process_group()

Creating the data loader

  • just need to add the argument sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
    • DistributedSampler ensures the batches are split evenly among the data-parallel processes
  • and set shuffle=False because we’re passing a sampler
  • Total code
from torch.utils.data.distributed import DistributedSampler

sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=True)
train_loader = torch.utils.data.DataLoader(dataset,
                                           num_workers=2,
                                           pin_memory=True,
                                           shuffle=False,  ## the sampler handles shuffling
                                           batch_size=args.batch_size,
                                           sampler=sampler)
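
With shuffle=True on the sampler, the current epoch also has to be passed to the sampler at the start of each epoch so the shuffle differs between epochs (num_epochs is an assumed name):

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  ## reseeds the sampler so each epoch gets a different shuffle
    for batch in train_loader:
        ...                   ## training step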

Wrapping the model

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

torch.cuda.set_device(rank)
model = Net().to(rank) ## IMPORTANT: move the model to the GPU BEFORE wrapping it with FSDP
model = FSDP(model)
  • PyTorch DDP uses autograd hooks registered at construction time (when the DDP model is constructed) to trigger gradient synchronization when .backward() is called
    • When this is done, the averaged gradients are written to the param.grad field of all parameters
    • This mechanism is agnostic to the communication backend (NCCL, …)
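
Since these bullets describe DDP rather than FSDP, the corresponding DDP wrap is sketched below (assuming the same Net and rank as above):

from torch.nn.parallel import DistributedDataParallel as DDP

torch.cuda.set_device(rank)
model = Net().to(rank)
model = DDP(model, device_ids=[rank])  ## registers the autograd hooks that all-reduce gradients during .backward()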