Implementation details

Hybrid sharding

  • Given a number of workers/GPUs N, the sharding happens only across subsets of size M, with replication across the different subsets.
    • e.g. you have two nodes with 8xA100 GPUs each and relatively slow inter-node communication
    • You then have N = 16 and M = 8
  • You can use hybrid sharding to shard the model parameters inside each node, and then replicate across nodes.
  • Each forward/backward pass still uses the usual all-gather and reduce-scatter operations within each node, i.e. gathering model parameters from the other intra-node GPUs to compute intermediate activations and gradients.
  • You additionally have an all-reduce across nodes to average the gradients over the total mini-batch processed in that training step (see the configuration sketch below).
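
  • A minimal configuration sketch, assuming PyTorch FSDP with one process per GPU and that Net / rank are set up as in the code snippets later in these notes; with ShardingStrategy.HYBRID_SHARD, FSDP shards within each node and replicates across nodes by default:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

model = Net().to(rank)                                   # rank: this process's GPU index
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,     # shard intra-node, replicate inter-node
)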

Source code

Auto-wrap-policy

  • fsdp_auto_wrap_policy (Optional[Callable]): a callable specifying a policy to recursively wrap layers with FSDP. Note that this policy currently only applies to child modules of the passed-in module; the remaining modules are always wrapped in the returned FSDP root instance.

  • Can use TRANSFORMER_BASED_WRAP to automatically wrap transformer layers

  • For some architectures, such as Transformer encoder-decoders, parts of the model (e.g. the embedding table) are shared between the encoder and the decoder. In this case, we need to place the embedding table in the outer FSDP unit so that it can be accessed from both. In addition, by registering the transformer layer class, the sharding plan can be made much more communication efficient (see the sketch below).
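
  • A sketch of how the layer class might be registered, using PyTorch's transformer_auto_wrap_policy; T5Block (from Hugging Face transformers) is only an example of a transformer layer class, substitute the block class of your own model:

import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.t5.modeling_t5 import T5Block   # example layer class

t5_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={T5Block},     # each T5Block becomes its own FSDP unit
)
model = FSDP(model, auto_wrap_policy=t5_auto_wrap_policy)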

Backward prefetch

  • The backward prefetch setting controls when the next FSDP unit’s parameters are requested.
  • By setting it to BACKWARD_PRE, the next FSDP unit’s params are requested before the computation of the current unit starts, so they arrive earlier.
    • This overlaps the all_gather communication with the gradient computation, which can increase training speed in exchange for slightly higher memory consumption.
  • By setting it to BACKWARD_POST, the next FSDP unit’s params are not requested until the current FSDP unit’s processing is complete, minimizing memory overhead (see the sketch below).
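
  • A minimal sketch of enabling this when wrapping the model with FSDP:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch

model = FSDP(
    model,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # or BACKWARD_POST to minimize memory
)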

Code snippets

Imports

# stdlib, torch and torchvision imports used by the snippets below
import os
import functools
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.optim.lr_scheduler import StepLR
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)

Distributed training setup

  • world_size = torch.cuda.device_count()
  • Using torchrun sets the world_size and rank automatically (see the sketch after the setup/cleanup snippet below)
def setup(rank, world_size):
    # address/port of the rank-0 process; all workers use these to rendezvous
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
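
  • A sketch of the torchrun variant (setup_torchrun is only an illustrative name): torchrun, e.g. torchrun --nproc_per_node=8 train.py, launches one process per GPU and sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT in the environment, so init_process_group picks them up without explicit arguments:

def setup_torchrun():
    # rank, world_size and the master address/port are read from the environment
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank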

Training loop

  • All data, tensors and models need to be moved to the correct device using .to(rank)

  • To compute and report a loss aggregated across all ranks in the training loop:

def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
    model.train()
    ddp_loss = torch.zeros(2).to(rank)   # [summed loss, number of samples] on this rank
    if sampler:
        sampler.set_epoch(epoch)         # reshuffle differently each epoch
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target, reduction='sum')
        loss.backward()
        optimizer.step()
        ddp_loss[0] += loss.item()
        ddp_loss[1] += len(data)

    # sum the loss and sample counts across all ranks before reporting
    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
    if rank == 0:
        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

Launching the distributed training

WORLD_SIZE = torch.cuda.device_count()
mp.spawn(fsdp_main,
         args=(WORLD_SIZE, args),
         nprocs=WORLD_SIZE,
         join=True)

Setting up everything

def fsdp_main(rank, world_size, args):
    setup(rank, world_size)

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    dataset1 = datasets.MNIST('../data', train=True, download=True,
                        transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                        transform=transform)

    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)

    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
    cuda_kwargs = {'num_workers': 2,
                    'pin_memory': True,
                    'shuffle': False}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )
    # could be passed to FSDP below via auto_wrap_policy=my_auto_wrap_policy
    # (named fsdp_auto_wrap_policy in older PyTorch releases)
    
    torch.cuda.set_device(rank)  ## IMPORTANT


    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)

    model = Net().to(rank) ## IMPORTANT TO DO BEFORE CALLING FSDP
    model = FSDP(model)

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    init_start_event.record()
    for epoch in range(1, args.epochs + 1):
        train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
        test(model, rank, world_size, test_loader)
        scheduler.step()

    init_end_event.record()

    if rank == 0:
        print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
        print(f"{model}")

    if args.save_model:
        # use a barrier to make sure training is done on all ranks
        dist.barrier()
        states = model.state_dict()
        if rank == 0:
            torch.save(states, "mnist_cnn.pt")

    cleanup()

Saving a model

if args.save_model:
    # use a barrier to make sure training is done on all ranks
    dist.barrier()
    # with FSDP, state_dict() gathers the full (unsharded) state dict on every rank
    states = model.state_dict()
    if rank == 0:
        torch.save(states, "mnist_cnn.pt")
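
  • For larger models, gathering the full state dict on every rank can be expensive. A sketch of an alternative, assuming a reasonably recent PyTorch version, that materializes the state dict on rank 0 only and offloads it to CPU:

from torch.distributed.fsdp import StateDictType, FullStateDictConfig

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()      # full state dict only on rank 0, on CPU
if rank == 0:
    torch.save(cpu_state, "mnist_cnn.pt")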