- https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/
Available modules
- Linear, LayerNormLinear, LayerNormMLP, and a full transformer layer (see the sketch below)
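A quick sketch of instantiating the building blocks listed above. The sizes are arbitrary, and the constructor arguments follow current Transformer Engine releases; they may differ in your version:

```python
import transformer_engine.pytorch as te

# The modules named above, with arbitrary example sizes.
linear    = te.Linear(1024, 1024, device="cuda")
ln_linear = te.LayerNormLinear(1024, 3072, device="cuda")
ln_mlp    = te.LayerNormMLP(1024, 4096, device="cuda")
layer     = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
)
```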
Features
fp8_autocast
- The fp8_autocast context manager tells TE modules to use FP8 computation internally
- Enables specifying the details of the FP8 recipe (history window length, algorithm used to calculate the scaling factors)
- It does not change anything else in the model
- To do mixed FP16/FP8 training, you need to combine it with the native framework's AMP or cast the model yourself
- The backward pass inherits the settings of the forward pass and should be outside of the fp8_autocast region (see the sketch after this list)
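A minimal sketch of the flow described above. The DelayedScaling recipe and its field names are taken from current Transformer Engine releases and are assumptions relative to these notes; adjust to your TE version:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.Linear(768, 768, bias=True, device="cuda")
inp = torch.randn(1024, 768, device="cuda")

# Details of the FP8 recipe: history window and scaling-factor algorithm.
fp8_recipe = recipe.DelayedScaling(
    margin=0,
    fp8_format=recipe.Format.HYBRID,  # E4M3 forward, E5M2 backward
    amax_history_len=16,
    amax_compute_algo="max",
)

# The GEMMs inside the region run in FP8; nothing else in the model changes.
# For mixed FP16/FP8 training you would additionally wrap this forward in the
# framework's AMP (e.g. torch.autocast) or cast the model yourself.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)

# Backward inherits the forward settings and stays outside the region.
out.sum().backward()
```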
fp8_model_init
- The fp8_model_init context manager tells TE modules to create their weights in FP8 only, without a high-precision copy
- Useful for inference
- Enables the full memory saving from FP8 (see the sketch below)
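A sketch of FP8-only weights for inference, using the fp8_model_init and fp8_autocast context managers named in these notes (exact constructor arguments may differ between TE versions):

```python
import torch
import transformer_engine.pytorch as te

# Weights are created directly in FP8, with no high-precision master copy,
# so the full FP8 memory saving is realized.
with te.fp8_model_init(enabled=True):
    linear = te.Linear(1024, 1024, bias=True, device="cuda")

inp = torch.randn(128, 1024, device="cuda")

# Inference-style use: FP8 execution, no gradients.
with torch.no_grad(), te.fp8_autocast(enabled=True):
    out = linear(inp)
```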
Gradient accumulation fusion
- Hopper only: Tensor Cores on Hopper GPUs have the option to accumulate matrix products directly into FP32, resulting in better numerical accuracy and avoiding the need for a separate casting kernel.
- Transformer Engine therefore provides an option to directly generate FP32 gradients for weight tensors. The FP32 gradients are not written to the parameter's grad tensor, but rather to a main_grad tensor that must be initialized before the backward pass (see the sketch below).
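A sketch of the main_grad flow, using the fuse_wgrad_accumulation flag that appears under the Linear class below (sizes are arbitrary; exact behavior may vary by TE version):

```python
import torch
import transformer_engine.pytorch as te

linear = te.Linear(768, 768, fuse_wgrad_accumulation=True, device="cuda")

# main_grad must be created (in FP32) before the backward pass; the fused
# kernel accumulates the weight gradient directly into it instead of .grad.
linear.weight.main_grad = torch.zeros_like(linear.weight, dtype=torch.float32)

out = linear(torch.randn(512, 768, device="cuda"))
out.sum().backward()

print(linear.weight.main_grad.abs().sum())  # accumulated FP32 weight gradient
```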
FP8 weight caching
- Useful when doing gradient accumulation, since the FP8 copies of the weights can be reused across microbatches instead of being re-cast each time (see the sketch below)
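A sketch of weight caching across gradient-accumulation microbatches. The is_first_microbatch forward argument is an assumption about the TE module API relative to these notes; check your version's docs for the exact name:

```python
import torch
import transformer_engine.pytorch as te

linear = te.Linear(768, 768, device="cuda")
microbatches = [torch.randn(512, 768, device="cuda") for _ in range(4)]

for i, inp in enumerate(microbatches):
    with te.fp8_autocast(enabled=True):
        # The FP8 cast of the weights happens on the first microbatch and
        # the cached FP8 copy is reused for the remaining ones.
        out = linear(inp, is_first_microbatch=(i == 0))
    # Backward stays outside the fp8_autocast region, as noted above.
    out.sum().backward()
```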
Future
- More granular scaling factors (tiles of tensor)
- Possible optimized building blocks for SSMs
Linear class
- fuse_wgrad_accumulation: enables the FP32 main_grad accumulation path described under Gradient accumulation fusion above
- FP8 execution requires 2D input matrices with the height divisible by 8 and the width divisible by 16. In practice both dimensions need to be multiples of 16, since the backward pass will invoke a transpose (see the sketch below).
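A quick shape check illustrating the constraint above (the module and sizes are arbitrary; keeping both dimensions multiples of 16 satisfies both the forward GEMM and the transposed backward):

```python
import torch
import transformer_engine.pytorch as te

rows, cols = 8192, 4096
assert rows % 16 == 0 and cols % 16 == 0  # safe for forward and backward

linear = te.Linear(cols, cols, device="cuda")
with te.fp8_autocast(enabled=True):
    out = linear(torch.randn(rows, cols, device="cuda"))
out.sum().backward()
```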