- FSDP & torch.compile: https://pytorch.org/blog/maximizing-training-throughput/
- torch.compile, the missing manual: https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.ivdr7fmrbeab
Tips
Accumulated cache size
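Presumably this refers to Dynamo's recompile caches. A minimal sketch, assuming the tip is to raise torch._dynamo.config.cache_size_limit / accumulated_cache_size_limit when a module legitimately recompiles many times (the exact limit values here are arbitrary and the defaults vary by PyTorch version):

```python
import torch._dynamo

# Roughly: per-object and overall recompile caps for a compiled frame.
# Once a frame exceeds them, dynamo stops recompiling and runs it eagerly.
torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.accumulated_cache_size_limit = 256
```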
FSDP
Compiling the backward at first call
- Typically, the act of compiling a module also gives you the compiled backward. However, you can also run compiled autograd directly with torch._dynamo.compiled_autograd, which compiles the autograd graph executed by a backward() call (see the sketch below). There are three situations where you might want to use this:
  - (1) you get some performance out of the box, since accumulate-grad nodes can be fused into the compiled region (they can't with traditional AOTAutograd);
  - (2) your forward graph cannot be compiled all in one go because it has some dynamism, but your backward graph is the same every iteration;
  - (3) you are making use of nontrivial autograd features like hooks, which cannot be compiled ahead of time and whose compilation must be deferred until you actually run backward(). This is especially common when compiling distributed wrappers like FSDP.
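A minimal sketch of how this is typically wired up. The toy model, shapes, and the scoped alternative in the comment are illustrative; the exact compiled-autograd API surface has shifted a bit across PyTorch releases, so treat this as a sketch rather than the canonical recipe:

```python
import torch
import torch._dynamo

model = torch.nn.Linear(128, 128)
x = torch.randn(32, 128)

# Ask dynamo to also capture and compile the autograd graph when backward() runs.
torch._dynamo.config.compiled_autograd = True

@torch.compile
def train_step(model, x):
    loss = model(x).sum()
    # The backward graph is traced and compiled on the first backward() call,
    # at which point hooks (e.g. FSDP's) and accumulate-grad nodes are known.
    loss.backward()

train_step(model, x)

# Scoped alternative: pass the compiler to use for the autograd graph.
# with torch._dynamo.compiled_autograd.enable(torch.compile(backend="inductor")):
#     model(x).sum().backward()
```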
Dealing with buffers
- You may need to fill your buffers (e.g. RoPE frequencies via model.rot_emb.compute_freqs_cis) after the FSDP wrapping call to avoid a graph break with torch.compile; see the sketch below.
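A rough ordering sketch, with a stand-in rotary-embedding module. The real compute_freqs_cis signature, FSDP arguments, and model layout will differ in your codebase, and this assumes a process group is already initialized (e.g. via torchrun):

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class RotaryEmbedding(nn.Module):
    """Stand-in RoPE module that keeps freqs_cis in a non-persistent buffer."""

    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        self.dim, self.max_seq_len = dim, max_seq_len
        self.register_buffer("freqs_cis", torch.empty(0), persistent=False)

    def compute_freqs_cis(self) -> None:
        # Llama-style frequency table; the math here is only illustrative.
        freqs = 1.0 / (10000 ** (torch.arange(0, self.dim, 2).float() / self.dim))
        angles = torch.outer(torch.arange(self.max_seq_len).float(), freqs)
        self.freqs_cis = torch.polar(torch.ones_like(angles), angles)


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.rot_emb = RotaryEmbedding(dim=128, max_seq_len=2048)
        self.proj = nn.Linear(128, 128)

    def forward(self, x):
        return self.proj(x)


model = FSDP(TinyModel(), use_orig_params=True)

# Fill the RoPE buffer *after* the FSDP wrapping call, so buffer initialization
# does not happen inside the compiled region and cause a graph break.
model.rot_emb.compute_freqs_cis()

compiled_model = torch.compile(model)
```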