FP8

  • What is the plan for fp8 and FSDP?

  • For per-parameter-sharding FSDP (fully_shard in this PR), we can flexibly support fp8 all-gather because each parameter is sharded individually (on dim-0) rather than flattened into a FlatParameter. The approach will be something like:

# Import paths reflect the FSDP2 and float8_experimental prototypes and may differ by version
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = ...
swap_linear_with_float8_linear(model, Float8Linear)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
for transformer_block in model.blocks:
    fully_shard(transformer_block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
  • Each fully_shard call constructs 1 communication group (i.e. each transformer block’s parameters/gradients are communicated together).
  • Non-fp8 parameters use param_dtype for all-gather and forward/backward compute. This means that their gradients use this dtype.
  • The Float8Linear weights are treated specially by FSDP; param_dtype only dictates their gradients’ dtype.
    • They use Float8Tensor, which stores the raw data in float8 but advertises its dtype to autograd as bfloat16 (see the sketch after this list). More details as to why are in the autograd section.
  • This allows for mixed fp8 all-gather where Float8Linear.weights use fp8, while other parameters use bf16, and the reduce-scatter uniformly uses bf16.
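
The dtype-advertising trick above can be illustrated with a minimal tensor subclass. This is only a sketch of the idea, not the real Float8Tensor from float8_experimental: the class name is made up, there is no scaling, and the real class implements __torch_dispatch__ for actual ops.

import torch

class AdvertisedBf16Fp8Tensor(torch.Tensor):
    # Sketch only: raw storage is float8, but the dtype advertised to
    # autograd (and to FSDP's mp_policy) is bfloat16.
    @staticmethod
    def __new__(cls, raw_fp8: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, raw_fp8.shape, dtype=torch.bfloat16, device=raw_fp8.device
        )

    def __init__(self, raw_fp8: torch.Tensor):
        self._raw = raw_fp8  # the actual float8 payload that gets all-gathered

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # The real Float8Tensor implements ops here; omitted in this sketch.
        raise NotImplementedError(f"{func} is not implemented in this sketch")

raw = torch.randn(4, 4).to(torch.float8_e4m3fn)
t = AdvertisedBf16Fp8Tensor(raw)
print(t.dtype)       # torch.bfloat16 (advertised to autograd)
print(t._raw.dtype)  # torch.float8_e4m3fn (actually stored)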

API differences

FSDP2:

import itertools

import torch
from torch.distributed._composable.fsdp import fully_shard

with torch.device("meta"):
    model = Transformer()
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module)
fully_shard(model)
for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")
# Allocate buffers and sharded parameters on GPU
model.to_empty(device="cuda")
# Run user-defined initializers
model.init_weights() # or `model.apply(init_weights)`
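
For reference, the user-defined initializer passed to model.apply above might look like the following. This is a hypothetical example (the module types and init scheme are illustrative, not part of the FSDP2 API); after fully_shard, the parameters are DTensors, and the usual torch.nn.init functions still apply to them.

import torch

def init_weights(module: torch.nn.Module) -> None:
    # Hypothetical initializer: adjust module types and init scheme to the model.
    if isinstance(module, torch.nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, torch.nn.Embedding):
        torch.nn.init.normal_(module.weight, std=0.02)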

Why?

FlatParameter flaws

  • FSDP1 defines a FlatParameter by flattening and concatenating a group of parameters to represent a communication bucket.
  • However, this FlatParameter makes it hard to apply different behaviors to individual parameters within it,
    • e.g. parameter freezing or parameter casting (see the sketch after this list),
    • which hurts composability,
    • and it complicates the internal implementation, e.g. state dict logic runs to thousands of lines and requires additional communications.
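
A conceptual sketch of the freezing problem (not FSDP1's actual code): once a group of parameters is flattened and concatenated into one tensor, a single requires_grad flag governs the whole bucket.

import torch

# Conceptual sketch only: flatten-and-concat bucketing in the style of
# FlatParameter forces one requires_grad flag onto the whole bucket.
w_trainable = torch.randn(4, 4)  # intended to be trained
w_frozen = torch.randn(4, 4)     # intended to be frozen
flat_param = torch.nn.Parameter(
    torch.cat([w_trainable.flatten(), w_frozen.flatten()])
)
# Freezing only `w_frozen` now requires extra bookkeeping, and gradient
# memory is still allocated for the entire FlatParameter.
print(flat_param.requires_grad)  # one flag for both original parameters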

Improvements

  • FSDP2 represents sharded parameters as DTensors sharded on dim-0 (a minimal sketch of this representation follows the list), which enables:
  1. Flexible fp8 all-gather: fp8 weights and other non-fp8 parameters can be flexibly mixed in the same all-gather
  2. Flexible frozen parameters: frozen and non-frozen parameters can be flexibly mixed in the same communication group without using extra memory
  3. Communication-free sharded state dicts: matching the training and state-dict representations simplifies and speeds up checkpointing
  4. Future communication optimizations in the compiler: a partial-graph compiler like torch.compile can change the communication groups for all-gather/reduce-scatter
  5. Improved memory management: lower and deterministic GPU memory usage by avoiding recordStream, achieved without any CPU synchronization
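
A minimal single-process sketch of the dim-0 DTensor representation (assuming a one-rank gloo group purely for illustration; FSDP2 constructs these DTensors for you):

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# One-rank process group so the example runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

weight = torch.randn(8, 4)
sharded = distribute_tensor(weight, mesh, [Shard(0)])  # sharded on dim-0
# Each rank already holds its local shard, so a sharded state dict can return
# such DTensors directly with no all-gather.
print(type(sharded).__name__, sharded.to_local().shape)

dist.destroy_process_group()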