Summary

Detailed

  • FSDP introduces deferred initialization that allows users to create a model instance on a dummy device and record the operations invoked during initialization; the model can then be materialized unit by unit by replaying those operations on a real GPU.
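
FSDP's real deferred initialization records and replays operations on a dummy/fake device; the snippet below only sketches the analogous "construct without allocating, then materialize" pattern using the meta device on a recent PyTorch, with a made-up MyModel:

```python
import torch
import torch.nn as nn

# Hypothetical two-layer model, used only for illustration.
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 4096)
        self.fc2 = nn.Linear(4096, 1024)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Construct the model without allocating any real parameter storage.
with torch.device("meta"):
    model = MyModel()

# Later, materialize parameters on a real device and re-run initialization.
model = model.to_empty(device="cuda")
for module in model.modules():
    if isinstance(module, nn.Linear):
        nn.init.kaiming_uniform_(module.weight)
        nn.init.zeros_(module.bias)
```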

System Design

  • Example: a model decomposed into three units

  • FSDP decomposes the model instance into smaller units and handles each unit independently.

    • It flattens and shards all of the parameters within each unit.
  • During forward and backward computation, FSDP only materializes the unsharded parameters and gradients of one unit at a time (see the sketch below).
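
A schematic of this per-unit schedule (not FSDP's actual code; Unit and its methods all_gather_params, reshard, run_backward, and reduce_scatter_grads are placeholder names):

```python
# Schematic only: every helper below is a placeholder, not an FSDP API.
def fsdp_forward(units, x):
    for unit in units:
        unit.all_gather_params()     # materialize this unit's unsharded parameters
        x = unit.module(x)           # run the unit's forward computation
        unit.reshard()               # free the unsharded parameters, keep only the local shard
    return x

def fsdp_backward(units):
    for unit in reversed(units):     # backward visits units roughly in reverse forward order
        unit.all_gather_params()     # unshard the parameters needed for this unit's backward
        unit.run_backward()          # compute this unit's gradients
        unit.reduce_scatter_grads()  # each rank keeps only its gradient shard
        unit.reshard()               # free the unsharded parameters again
```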

Optimizing communication performance

  • Input size impact on performance
    • The paper's benchmark plot fixes the total communication volume at ≈ 1B FP32 elements and varies the size per AllGather.
    • Even Input Size: The NVIDIA NCCL [22] library offers efficient collective implementations for all-gather and reduce-scatter that require even input tensor sizes across ranks.
      • NCCL’s AllGather API (All-Gather Base) requires even input tensor sizes and writes its output into a single tensor.
      • PyTorch’s ProcessGroup all-gather wraps the NCCL API and enhances it by supporting uneven input tensor sizes across ranks and letting users provide a list of output tensors. This flexibility comes with an efficiency trade-off:
        • it incurs additional copies between the individual output tensors and a consolidated single large output tensor before and after the communication (see the sketch after this list).
    • Larger Input Size: For fixed communication volume, batching data and issuing fewer collectives improves performance by avoiding the collectives’ launch overhead and increasing network bandwidth utilization.
      • This is why full sharding across many GPUs can become very inefficient: each rank’s per-collective input shrinks as the world size grows.
      • Thus, to deliver highly efficient communications, FSDP organizes all parameters within one FSDP unit into a large FlatParameter.
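
A small sketch of the two PyTorch all-gather paths described above; it assumes a NCCL process group has already been initialized with one GPU per rank:

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl") was called and the current device is set per rank.
world_size = dist.get_world_size()
shard = torch.randn(1_000_000, device="cuda")

# Even-input path: every rank contributes the same numel and the result lands in one
# contiguous tensor, mapping directly onto NCCL's AllGather with no extra copies.
output = torch.empty(world_size * shard.numel(), device="cuda")
dist.all_gather_into_tensor(output, shard)

# List-of-outputs path: more flexible (it can also handle uneven sizes across ranks),
# but it pays for extra copies between the individual output tensors and an internal
# contiguous communication buffer.
outputs = [torch.empty_like(shard) for _ in range(world_size)]
dist.all_gather(outputs, shard)
```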

FlatParameter Design - Sharding

  • One FlatParameter accommodates storage for all parameter tensors within one FSDP unit.

    • Example: one FSDP unit sharding a 4 × 3 nn.Linear layer across 16 GPUs.
  • FSDP organizes all parameters within one FSDP unit into a large FlatParameter, where the FlatParameter coalesces the communications of its individual parameters and also evenly shards them across ranks.

    • More specifically, the FlatParameter is a 1D tensor constructed by concatenating p flattened original parameters and padding on the right to achieve a size divisible by the sharding factor.
    • To shard, FSDP divides the FlatParameter into equal-sized chunks and assigns one chunk per rank (see the construction sketch after this list).
    • The FlatParameter’s gradient inherits the same unsharded and sharded shapes from the FlatParameter, and the FlatParameter and its gradient own the underlying storage of the original parameters and their gradients, respectively.
      • The sharded and unsharded FlatParameter and its gradient have the exact data layout expected by AllGather and ReduceScatter, respectively. This enables calling the collectives without any additional copies for either the input or output tensor
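
A minimal sketch of the flatten-concat-pad-chunk construction (not FSDP's implementation; build_flat_param is a helper invented here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def build_flat_param(params, world_size):
    """Concatenate flattened parameters, right-pad to a multiple of world_size,
    and split into equal-sized per-rank chunks."""
    flat = torch.cat([p.detach().reshape(-1) for p in params])
    pad = (world_size - flat.numel() % world_size) % world_size
    flat = F.pad(flat, [0, pad])               # pad on the right so numel divides evenly
    chunks = torch.chunk(flat, world_size)     # one equal-sized chunk per rank
    return flat, chunks

# The paper's example: a 4x3 nn.Linear (12 weight + 4 bias = 16 elements) over 16 GPUs.
layer = nn.Linear(3, 4)
flat, chunks = build_flat_param(list(layer.parameters()), world_size=16)
print(flat.numel(), [c.numel() for c in chunks])   # 16 elements total, 1 element per rank
```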

Code

  • FlatParameters are always seen as plain Tensors by outside PyTorch code thanks to a metaclass trick.
    • FSDP’s internal logic interacts with them through a FlatParamHandle.
  • def get_shard(tensor, rank, world_size)
    • takes this rank’s chunk of the unsharded tensor and returns chunk.clone()
    • the clone allocates new memory, since the unsharded tensor may be deallocated after this method returns (see the sketch below).
  • At shard time, flat_param.set_(sharded_flat_param) makes the flat_param on the device point to the same storage, sizes, and strides as the sharded flat_param.
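
A hedged sketch of what get_shard does based on the notes above (simplified; the real FSDP helper also tracks how much padding the shard needs):

```python
import torch

def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's chunk of an unsharded 1D flat tensor.

    clone() allocates fresh storage because the unsharded tensor may be
    deallocated after this function returns."""
    chunks = torch.chunk(tensor, world_size)
    if rank < len(chunks):
        return chunks[rank].clone()
    # torch.chunk can return fewer than world_size chunks for tiny tensors;
    # such ranks get an empty shard (real FSDP pads instead).
    return tensor.new_empty(0)
```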

Peak memory consumption

  • Given a model with Ψ elements, FSDP constructs N FlatParameters / units with ψ_1, …, ψ_N elements, where Σᵢ ψᵢ ≈ Ψ (up to padding).
  • With a sharding factor F, the peak parameter memory usage is Σᵢ ψᵢ / F + maxᵢ ψᵢ
    • all the sharded parameters (always resident) + the biggest unsharded FlatParameter (materialized during its forward + backward)
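
A toy numeric check of that formula (the unit sizes below are made up):

```python
# 3 FSDP units with psi = [4e9, 2e9, 1e9] elements, sharding factor F = 128.
psi = [4_000_000_000, 2_000_000_000, 1_000_000_000]
F = 128

sharded = sum(p / F for p in psi)   # all shards stay resident: ~54.7M elements
largest = max(psi)                  # the biggest FlatParameter is fully materialized
peak_elems = sharded + largest      # ~4.05B elements
print(f"peak parameter memory ≈ {peak_elems * 4 / 2**30:.1f} GiB in FP32")  # ≈ 15.1 GiB
```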

Hybrid sharding

  • World size W, sharding factor F, with F dividing W and 1 < F < W.
  • Divide the model into units, shard each unit across F devices, and replicate each shard across W/F replicas.
  • For gradients, first do a reduce_scatter over the F “local” devices, then each device does an all_reduce with the corresponding devices in the other replicas (see the sketch after this list).
  • The collectives used in hybrid sharding operate at smaller world sizes.
    • They empirically achieve better performance than invoking collectives at the global scale (as with full replication or full sharding), due to straggler effects (the slowest link in the chain defines the speed of execution) and larger network interference.
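
A sketch of the two-level gradient reduction; shard_group and replicate_group are assumed to have been created so that each rank belongs to one group of F shards and one group of W/F replicas (these names are not FSDP's API):

```python
import torch
import torch.distributed as dist

def hybrid_shard_grad_reduction(unsharded_grad, shard_group, replicate_group):
    """unsharded_grad: this rank's full FlatParameter gradient.
    shard_group: the F ranks that share one shard layout.
    replicate_group: the W/F ranks that hold the same shard."""
    F = dist.get_world_size(shard_group)
    # (1) ReduceScatter inside the sharding group: each rank keeps one gradient shard.
    grad_shard = torch.empty(
        unsharded_grad.numel() // F,
        dtype=unsharded_grad.dtype,
        device=unsharded_grad.device,
    )
    dist.reduce_scatter_tensor(grad_shard, unsharded_grad, group=shard_group)
    # (2) AllReduce across the replicas holding the same shard to finish the global reduction.
    dist.all_reduce(grad_shard, group=replicate_group)
    return grad_shard
```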

Autograd

  • We use Tensor views in the forward so that they are tracked by autograd.
    • _use_unsharded_views()
  • We use them in the pre-backward as well to support reentrant activation checkpointing, which needs the views to be tracked by autograd in the backward pass’s recomputed forward.
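
A minimal illustration of why views keep the original parameters hooked into autograd (the flat parameter and the 6/2 split below are made up; FSDP manages this inside _use_unsharded_views()):

```python
import torch

# A made-up flat parameter holding a 2x3 "weight" (6 elements) and a 2-element "bias".
flat_param = torch.randn(8, requires_grad=True)
weight_view = flat_param[:6].view(2, 3)   # a view, so autograd traces it back to flat_param
bias_view = flat_param[6:]                # another view into the same storage

out = (torch.ones(1, 3) @ weight_view.t() + bias_view).sum()
out.backward()
print(flat_param.grad.shape)   # torch.Size([8]): gradients land in the single flat parameter
```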

Pre-forward logic

  • There is root_pre_forward which runs on the root FSDP node
    • This starts with an attempt at lazy initialization (which only does real work the first time it runs).
  • torch/distributed/fsdp/_runtime_utils.py/pre_forward()
    • (1) This includes an opportunity to unshard currently sharded parameters such as those for the current forward (with _pre_forward_unshard)
      • torch/distributed/fsdp/_runtime_utils.py/_unshard()
      • The unsharding for a FlatParamHandle happens in 3 steps
        • launched in its own pre-unshard CUDA stream for the current device
          • handle.pre_unshard()
            • prepares views and allocates memory in the case of mixed precision
        • launched on the unshard CUDA stream
          • handle.unshard()
            • (1) allocates the padded_unsharded_flat_param
            • (2) AllGathers into padded_unsharded_flat_param for efficient comms
            • (3) switches to using the unpadded unsharded flat parameter (see the sketch after this list)
          • handle.post_unshard()
            • frees the low precision shards if forward/backward is not using the same precision as the param dtype
    • (2) Register post-backward hooks to reshard the parameters and reduce-scatter their gradients.
    • (3) also converts forward args and kwargs to the given precision.
      • recursively apply x.to(dtype) to all tensors
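
A simplified sketch of the unshard steps for one flat parameter, using raw collectives instead of FSDP's handle methods (all names are stand-ins, and the process group is assumed to be initialized):

```python
import torch
import torch.distributed as dist

def unshard_flat_param(sharded_flat_param, unpadded_numel, group=None):
    """(1) allocate the padded unsharded buffer, (2) AllGather every rank's shard
    into it, (3) switch to the unpadded unsharded view."""
    world_size = dist.get_world_size(group)
    padded_unsharded = torch.empty(                       # (1) one contiguous destination
        world_size * sharded_flat_param.numel(),
        dtype=sharded_flat_param.dtype,
        device=sharded_flat_param.device,
    )
    dist.all_gather_into_tensor(                          # (2) no extra input/output copies
        padded_unsharded, sharded_flat_param, group=group
    )
    return padded_unsharded[:unpadded_numel]              # (3) unpadded view, no copy
```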

Communication Optimizations

  • FSDP uses a separate CUDA stream to issue the AllGathers.
  • FSDP enforces a single CUDA device per rank and uses a single process group for both AllGather and ReduceScatter, which means that its collectives run sequentially in the process group’s internal NCCL stream.
    • problematic for performance
    • but also helpful for memory usage, since the CUDA caching allocator does not share memory blocks across streams.
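
The overlap pattern itself can be sketched with plain CUDA streams; comm_stream below is an assumption standing in for FSDP's unshard stream:

```python
import torch

comm_stream = torch.cuda.Stream()             # stands in for FSDP's AllGather/unshard stream
default_stream = torch.cuda.current_stream()  # where forward/backward compute runs

def issue_on_comm_stream(launch_fn):
    """Launch a communication kernel on the side stream so it can overlap with compute."""
    comm_stream.wait_stream(default_stream)   # respect ordering w.r.t. compute issued so far
    with torch.cuda.stream(comm_stream):
        return launch_fn()                    # e.g. an AllGather into a freshly allocated buffer

def wait_before_use(gathered):
    """Call right before compute consumes the gathered buffer."""
    default_stream.wait_stream(comm_stream)   # compute must not start before the gather ends
    gathered.record_stream(default_stream)    # tell the caching allocator about the cross-stream use
```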

Backward prefetching

  • In the backward pass, FSDP issues the ReduceScatter for the current FlatParameter and then the AllGather for the next FlatParameter.
    • Hence, the single NCCL stream forces the ReduceScatter to block the next AllGather, which in turn blocks the next gradient computation and may become exposed on the critical path.
  • To avoid two consecutive exposed communication calls in the backward pass,
    • FSDP backward prefetching issues the next AllGather before the current ReduceScatter.
    • However, as mentioned before, a challenge for eager execution is knowing which FlatParameter to AllGather next. FSDP resolves this challenge by recording the reverse forward execution order of modules as a proxy for their backward execution order (see the sketch after this list).
      • (why we register backward hooks during pre_forward )
      • Moreover, the forward order is freshly recorded each iteration, meaning that the backward prefetching is compatible with dynamism across iterations.
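
A schematic of the reverse-order prefetch (pure pseudostructure; the handle objects and their all_gather / reduce_scatter_grad methods are placeholders):

```python
# Schematic only: handle objects and their methods are placeholders, not FSDP internals.
forward_order = []                  # recorded afresh every iteration during the forward pass

def on_forward(handle):
    forward_order.append(handle)    # reverse forward order serves as a proxy for backward order

def on_backward(handle):
    idx = forward_order.index(handle)
    if idx > 0:
        next_handle = forward_order[idx - 1]   # the unit expected next in the backward pass
        next_handle.all_gather()               # prefetch: issue its AllGather now ...
    handle.reduce_scatter_grad()               # ... so it is not blocked behind this ReduceScatter
```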

Forward prefetching

  • For some workloads with relatively slow CPU execution, the CPU thread may not be able to issue the next forward AllGather early enough to efficiently fill the NCCL stream.
  • forward prefetching issues the next AllGather before the forward computation of the current FSDP unit

Memory management

  • PyTorch uses a CUDA caching allocator as a middle layer to serve GPU allocation and free requests for PyTorch programs.
  • FSDP uses a rate limiter to take into account the memory impact of the caching allocator on programs that use several CUDA streams and run fast CPU threads.
  • Producer stream: responsible for generating or preparing data that will be used by subsequent operations.
    • e.g. the AllGather destination tensor representing the unsharded FlatParameter is allocated in a producer stream
    • The producer stream runs ahead of the consumer stream, preparing data in advance
  • Consumer streams: responsible for performing computations on the data produced by the producer stream.
    • model forward passes, loss calculations
  1. PyTorch uses a caching allocator to minimize calls to cudaMalloc and cudaFree, which can be costly.
  2. The caching allocator runs on the CPU thread and must decide on memory allocation before GPU kernels actually run.
  3. For a single stream, memory blocks can be reused easily due to sequential ordering.
  4. With separate producer and consumer streams, reusing memory is more complex due to lack of inter-stream ordering guarantees.
  5. The caching allocator allocates blocks per stream, which can lead to over-allocation in the producer stream.
  6. Over-allocation can cause memory allocation failures in the consumer stream, even if the GPU has enough memory overall.
  7. Failed allocations can trigger a “cudaMalloc retry” process, which significantly reduces training performance.
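
The retry behavior in point 7 can be observed through the caching allocator's counters (requires a CUDA build of PyTorch):

```python
import torch

# A steadily growing num_alloc_retries is the symptom of the over-allocation problem above.
stats = torch.cuda.memory_stats()
print("cudaMalloc retries:", stats["num_alloc_retries"])
print("bytes reserved by the caching allocator:", stats["reserved_bytes.all.current"])
```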

Rate limiter

  • FSDP allocates the AllGather destination tensor representing the unsharded FlatParameter in a producer stream, and the forward and backward computations using the AllGathered parameters run in a consumer stream (typically the default stream)
  • For a fast CPU thread, there may be pending GPU computation kernels when the caching allocator must serve the next AllGather, leading to no block reuse.
  • Even once the blocks are no longer active in the AllGather producer stream, these reserved blocks cannot serve the computation stream’s allocation requests, and thus may force blocking cudaFrees and cudaMallocs.
  • FSDP offers a rate limiter that intentionally blocks the CPU thread to ensure proper caching allocator block reuse. It allows at most two inflight AllGathers, which is the minimum amount to still achieve communication and computation overlap.
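
A sketch of the rate-limiting idea with CUDA events (a simplification of what FSDP does internally; the names below are not FSDP's):

```python
import collections
import torch

MAX_INFLIGHT_ALL_GATHERS = 2            # the minimum that still overlaps communication and compute
_inflight_events = collections.deque()  # one CUDA event per in-flight AllGather

def rate_limit_before_next_allgather():
    """Block the CPU thread until an older AllGather has finished, so its caching-allocator
    blocks can be reused instead of forcing new allocations."""
    if len(_inflight_events) >= MAX_INFLIGHT_ALL_GATHERS:
        oldest = _inflight_events.popleft()
        oldest.synchronize()            # CPU waits here, keeping the GPU queues short

def record_allgather_issued(comm_stream):
    event = torch.cuda.Event()
    event.record(comm_stream)           # completes once this AllGather's work is done
    _inflight_events.append(event)
```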

Mixed precision

  • In terms of parameter management, it adheres to the standard mixed precision technique, which maintains both low and full precision copies of parameters.
  • Forward and backward computation use the low precision, and the optimizer step uses full precision.
  • In contrast to torch.amp.autocast, which performs just-in-time casts at the operator level,
    • FSDP’s native mixed precision only incurs a full-to-low-precision cast per FlatParameter in its pre-forward and, if resharding after forward, its pre-backward.
  • Moreover, FSDP’s mixed precision permits running all collectives in the low precision, which saves communication volume.
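
FSDP exposes this through its MixedPrecision config; a minimal usage sketch (assumes a process group and a model are already set up):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

mp = MixedPrecision(
    param_dtype=torch.bfloat16,    # dtype of the unsharded parameters used in forward/backward
    reduce_dtype=torch.bfloat16,   # dtype used for the gradient ReduceScatter
    buffer_dtype=torch.bfloat16,   # dtype for module buffers
)
# sharded_model = FSDP(model, mixed_precision=mp)   # the sharded full-precision copy stays for the optimizer
```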