Links
- FSDP2 main GitHub issue
- [From TorchTitan, why FSDP2?](## Why FSDP2?)
Experimental/closed
- Separate unshard stream for each process group with FSDP
    - quite important, but sufficiently precise wrapping of the modules might be good enough
FP8
- For per-parameter-sharding FSDP (`fully_shard` in this PR), we can flexibly support fp8 all-gather due to the different sharding. The approach will be something like:

      model = ...
      swap_linear_with_float8_linear(model, Float8Linear)
      mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
      for transformer_block in model.blocks:
          fully_shard(transformer_block, mp_policy=mp_policy)
      fully_shard(model, mp_policy=mp_policy)
- Each `fully_shard` call constructs 1 communication group (i.e. each transformer block's parameters/gradients are communicated together).
- Non-fp8 parameters use `param_dtype` for all-gather and forward/backward compute. This means that their gradients use this dtype.
- The `Float8Linear.weight`s are treated specially by FSDP: `param_dtype` only dictates their gradients' dtype.
    - They use `Float8Tensor`, which stores the raw data in float8 but advertises its dtype to autograd as bfloat16. More details as to why in autograd.
- This allows for a mixed fp8 all-gather, where `Float8Linear.weight`s use fp8 while other parameters use bf16, and the reduce-scatter uniformly uses bf16.
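
As a reference point, the same flow spelled with torchao's current float8 training API might look like the sketch below. This is an assumption about library versions, not part of the RFC snippet above: `convert_to_float8_training` and `Float8LinearConfig(enable_fsdp_float8_all_gather=True)` are torchao's successors to `swap_linear_with_float8_linear`, and the `fully_shard` import path shown is the recent (PyTorch >= 2.6) one.

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard  # torch >= 2.6 path (assumption)
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = ...  # your transformer, with a `.blocks` container of transformer blocks (as above)

# Swap nn.Linear -> Float8Linear and opt in to fp8 all-gather for those weights.
convert_to_float8_training(
    model,
    config=Float8LinearConfig(enable_fsdp_float8_all_gather=True),
)

# Non-fp8 parameters are all-gathered and computed in bf16; Float8Linear weights
# are all-gathered in fp8; the reduce-scatter is uniformly bf16.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
for transformer_block in model.blocks:
    fully_shard(transformer_block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```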
API differences
- https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md (torchtitan summary of FSDP1 vs. FSDP2 API differences)
FSDP2:

    with torch.device("meta"):
        model = Transformer()
    for module in model.modules():
        if isinstance(module, TransformerBlock):
            fully_shard(module)
    fully_shard(model)
    for tensor in itertools.chain(model.parameters(), model.buffers()):
        assert tensor.device == torch.device("meta")
    # Allocate buffers and sharded parameters on GPU
    model.to_empty(device="cuda")
    # Run user-defined initializers
    model.init_weights()  # or `model.apply(init_weights)`
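
For contrast, a rough FSDP1-style equivalent is sketched below. This is not taken from the torchtitan doc; it assumes the same user-defined `Transformer`/`TransformerBlock` classes, an already-initialized process group, and flags that depend on the PyTorch version.

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# FSDP1: one wrapper call plus an auto-wrap policy, instead of per-module fully_shard.
model = Transformer()  # materialized normally; meta-device init needs a param_init_fn in FSDP1
model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}),  # wrap each block into its own unit
    device_id=torch.cuda.current_device(),
    use_orig_params=True,
)
```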
Why?
FlatParameter flaws
- FSDP1 defines a `FlatParameter` by flattening and concatenating a group of parameters to represent a communication bucket.
- However, this `FlatParameter` complicates applying different behaviors to individual parameters within the `FlatParameter`,
    - e.g. parameter freezing, parameter casting, etc.,
    - hurting composability,
- and it complicates the internal implementation, e.g. making state-dict logic thousands of lines and requiring additional communications.
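
A toy sketch of the problem (illustrative only, not FSDP1's actual implementation): once a bucket of parameters is flattened and concatenated into one tensor, per-parameter knobs such as `requires_grad` and dtype exist only at the bucket level.

```python
import torch
import torch.nn as nn

# Toy "flat parameter": flatten and concatenate a bucket of parameters.
block = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
flat_param = nn.Parameter(
    torch.cat([p.detach().reshape(-1) for p in block.parameters()])
)

# The bucket has a single requires_grad flag and a single dtype, so freezing or
# casting just one of the original parameters is no longer a per-tensor operation;
# FSDP1 needs extra machinery (and hence complexity) to emulate it.
flat_param.requires_grad_(False)                      # freezes every parameter in the bucket
flat_param.data = flat_param.data.to(torch.bfloat16)  # casts every parameter in the bucket
```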
Improvements
- FSDP2 represents sharded parameters as `DTensor`s sharded on dim-0, which enables:
    - Flexible fp8 all-gather: fp8 weights and other non-fp8 parameters can be flexibly mixed in the same all-gather
    - Flexible frozen parameters: frozen and non-frozen parameters can be flexibly mixed in the same communication group without using extra memory
    - Communication-free sharded state dicts: matching the training and state-dict representation simplifies and speeds up checkpointing
    - Future communication optimization in the compiler: a partial-graph compiler like `torch.compile` can change the communication groups for all-gather/reduce-scatter
- Improved memory management system that achieves lower and deterministic GPU memory by avoiding `recordStream`, and does so without any CPU synchronization.
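
A minimal sketch of what the dim-0 `DTensor` representation looks like from user code, assuming `model` has already been sharded with `fully_shard` (the `DTensor` import lives in `torch.distributed.tensor` on recent PyTorch, `torch.distributed._tensor` on older versions):

```python
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

# After fully_shard, every parameter is a DTensor sharded on dim-0 across the mesh.
for name, param in model.named_parameters():
    assert isinstance(param, DTensor)
    local_shard = param.to_local()  # this rank's dim-0 shard, a plain torch.Tensor

# state_dict() returns those same DTensors, so producing a sharded checkpoint
# needs no all-gather: the training and state-dict representations match.
sharded_state_dict = model.state_dict()
```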