The narrow dynamic range of FP8 can lead to underflow or overflow in gradients.
We introduced an automatic scaling technique that adjusts a scaling factor μ on the fly during training.
If too many gradient values hit the maximum representable value in FP8, we decrease μ to prevent overflow.
Conversely, if values consistently stay below a threshold, we gradually increase μ to mitigate underflow.
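Below is a minimal sketch of this update rule in PyTorch; the thresholds, the step factor of 2, and the function name are illustrative choices, not values from the paper.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3

def update_mu(grad: torch.Tensor, mu: float,
              overflow_ratio: float = 0.001,   # hypothetical tolerance for saturated values
              underflow_margin: float = 0.5,   # hypothetical fraction of the FP8 range
              step: float = 2.0) -> float:
    """Adjust the gradient scaling factor mu based on how the scaled values
    fill the FP8 representable range."""
    scaled = grad.float().abs() * mu
    # Too many values saturating the FP8 maximum: shrink mu to avoid overflow.
    if (scaled >= FP8_E4M3_MAX).float().mean() > overflow_ratio:
        return mu / step
    # Values consistently well below the maximum: grow mu to avoid underflow.
    if scaled.max() < underflow_margin * FP8_E4M3_MAX:
        return mu * step
    return mu
```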
Precision Decoupling in Optimizer:
Not all components in the optimizer require the same level of precision.
We found that gradient statistics (like first-order moments in Adam) can tolerate lower precision.
However, master weights and second-order moments are more sensitive and require higher precision.
This insight led us to use FP8 for first-order moments, but higher precision (e.g., FP16 with tensor scaling) for master weights and second-order moments.
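As a rough illustration (not the paper's implementation), here is a simplified Adam-style update with decoupled precision: the first moment is stored in FP8 with a per-tensor scale, while the second moment and master weights stay in FP16. It assumes a PyTorch build that provides torch.float8_e4m3fn; bias correction and the tensor scaling of the master weights are omitted for brevity.

```python
import torch

FP8_E4M3_MAX = 448.0

def to_fp8(x: torch.Tensor):
    """Per-tensor scaling: scale is a multiplier mapping x into the FP8 range."""
    scale = FP8_E4M3_MAX / x.abs().max().clamp(min=1e-12)
    return (x * scale).to(torch.float8_e4m3fn), scale

def from_fp8(x8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x8.float() / scale

def adam_step(w16, m8, m_scale, v16, grad, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    g = grad.float()
    m = b1 * from_fp8(m8, m_scale) + (1 - b1) * g    # first moment: FP8 storage
    v = b2 * v16.float() + (1 - b2) * g * g          # second moment: FP16 storage
    w = w16.float() - lr * m / (v.sqrt() + eps)      # master weights: FP16 storage
    m8, m_scale = to_fp8(m)                          # re-quantize the first moment
    return w.half(), m8, m_scale, v.half()
```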
Tensor Scaling:
We use per-tensor scaling factors to better utilize the limited range of FP8.
This allows us to map the values of each tensor to the representable range of FP8 more effectively.
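A quick, illustrative comparison of quantization error shows the effect: with one scale fitted per tensor, a tensor of small values still fills the FP8 range, whereas reusing a scale tuned for a different magnitude pushes it into underflow. Assumes torch.float8_e4m3fn is available.

```python
import torch

FP8_E4M3_MAX = 448.0

def roundtrip(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize to FP8 with the given multiplier scale, then dequantize."""
    return (x * scale).to(torch.float8_e4m3fn).float() / scale

small = torch.randn(1024) * 1e-4                      # tensor of small magnitudes
per_tensor_scale = FP8_E4M3_MAX / small.abs().max()   # scale fitted to this tensor
reused_scale = torch.tensor(1.0)                      # scale tuned for a unit-range tensor

err_fitted = (roundtrip(small, per_tensor_scale) - small).abs().mean()
err_reused = (roundtrip(small, reused_scale) - small).abs().mean()
print(f"per-tensor scale error: {err_fitted:.2e}  vs  reused scale error: {err_reused:.2e}")
```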
Careful Handling of All-Reduce Operations:
All-reduce operations in distributed training can be particularly prone to underflow/overflow issues with FP8.
We developed a technique using a shared global minimum scaling factor across GPUs to ensure stable gradient aggregation.
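The sketch below shows the idea under the multiplier convention used above, assuming an initialized torch.distributed process group. The helper name is illustrative, and the reduction is shown over an FP16 buffer for portability; the paper's implementation reduces FP8 data directly, as discussed under Communication Optimizations below.

```python
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0

def allreduce_fp8_grad(grad: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # Per-rank multiplier that maps the local gradient into the FP8 range.
    local_scale = FP8_E4M3_MAX / grad.abs().max().clamp(min=1e-12)
    # Agree on the global minimum scale so no rank overflows after re-scaling.
    shared_scale = local_scale.clone()
    dist.all_reduce(shared_scale, op=dist.ReduceOp.MIN)
    # Quantize with the shared scale; pre-divide by world size so the summed
    # gradient also stays within the FP8 range.
    q = (grad * shared_scale / world_size).to(torch.float8_e4m3fn)
    buf = q.to(torch.float16)                  # transport buffer for the reduction
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    return buf.float() / shared_scale          # averaged gradient in full precision
```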
Adaptive Scaling for Activation in Parallel Training:
In sequence and tensor parallelism, we added FP8 conversion before all-gather and reduce-scatter operations on activations.
This helps maintain numerical stability while reducing communication costs.
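A sketch of the activation path (illustrative names, initialized process group assumed): the activation is quantized to FP8 with a scale agreed across ranks, the FP8 bytes are gathered, and the result is dequantized on arrival. The uint8 view is just a transport workaround; the same pattern applies to reduce-scatter, where the reduction additionally needs FP8-aware kernels.

```python
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0

def all_gather_fp8_activation(act: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # One multiplier shared by all ranks so the gathered chunks dequantize consistently.
    scale = FP8_E4M3_MAX / act.abs().max().clamp(min=1e-12)
    dist.all_reduce(scale, op=dist.ReduceOp.MIN)
    # Quantize to FP8, then view as raw bytes for transport.
    q = (act * scale).to(torch.float8_e4m3fn).view(torch.uint8)
    gathered = [torch.empty_like(q) for _ in range(world_size)]
    dist.all_gather(gathered, q)               # FP8-sized payload on the wire
    # Reinterpret as FP8 and dequantize.
    return torch.cat(gathered).view(torch.float8_e4m3fn).float() / scale
```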
ZeRO Optimizer Adaptation:
We modified the ZeRO optimizer to distribute whole FP8 tensors (rather than splitting each tensor into partitions) along with their scaling factors.
This ensures that tensor scaling information is properly maintained in distributed scenarios.
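A toy sketch of the tensor-level distribution (the function name and the greedy balancing heuristic are illustrative): each FP8 tensor is assigned as a whole, together with its scaling factor, to exactly one rank.

```python
from typing import List, Tuple
import torch

def assign_whole_tensors(tensors_with_scales: List[Tuple[torch.Tensor, torch.Tensor]],
                         world_size: int) -> List[List[Tuple[torch.Tensor, torch.Tensor]]]:
    """Greedily place each (FP8 tensor, scale) pair on the least-loaded rank."""
    buckets = [[] for _ in range(world_size)]
    loads = [0] * world_size
    # Largest tensors first keeps element counts roughly balanced, while each
    # per-tensor scaling factor always travels with its tensor, never split.
    for tensor, scale in sorted(tensors_with_scales,
                                key=lambda pair: pair[0].numel(), reverse=True):
        rank = loads.index(min(loads))
        buckets[rank].append((tensor, scale))
        loads[rank] += tensor.numel()
    return buckets
```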
Monitoring and Fallback Mechanisms:
While not explicitly mentioned in the paper, it’s generally good practice to monitor training stability.
This could involve watching for NaN values, unusually large gradients, or sudden spikes in loss.
Having fallback mechanisms (e.g., temporarily increasing precision or adjusting learning rates) can help recover from potential instabilities.
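One hedged way to wire this up (again, general practice rather than anything from the paper) is a small check before each optimizer step:

```python
import torch

def safe_to_step(loss: torch.Tensor, grad_norm: float, running_norm: float,
                 spike_factor: float = 10.0):
    """Return (ok_to_step, updated_running_norm); callers can fall back to a
    higher-precision step or a reduced learning rate when ok_to_step is False."""
    if not torch.isfinite(loss):
        return False, running_norm              # NaN/Inf loss: skip the update
    if running_norm > 0 and grad_norm > spike_factor * running_norm:
        return False, running_norm              # sudden gradient spike: skip
    # Exponential moving average of the gradient norm for future comparisons.
    updated = grad_norm if running_norm == 0 else 0.99 * running_norm + 0.01 * grad_norm
    return True, updated
```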
Communication Optimizations:
For distributed training, we implemented custom NCCL kernels that work directly with FP8 data, avoiding the need for type conversion during all-reduce operations.
We used ring-based algorithms for all-reduce to maximize bandwidth utilization across GPU interconnects.
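The custom kernels themselves are CUDA/NCCL code, but the ring pattern is easy to illustrate. Below is a single-process simulation (purely pedagogical) of its two phases: a reduce-scatter pass after which each rank owns one fully reduced chunk, followed by an all-gather pass that circulates those chunks. Only 1/N of the data crosses each link per step, which is what makes the ring bandwidth-efficient.

```python
import torch

def ring_all_reduce_sim(per_rank: list) -> list:
    """Simulate a ring all-reduce over len(per_rank) 'ranks' in one process."""
    n = len(per_rank)
    chunks = [[c.clone() for c in t.chunk(n)] for t in per_rank]
    # Phase 1: reduce-scatter. At step s, rank r passes chunk (r - s) mod n to
    # its neighbour, which accumulates it; after n-1 steps, rank r holds the
    # fully reduced chunk (r + 1) mod n.
    for s in range(n - 1):
        for r in range(n):
            idx = (r - s) % n
            chunks[(r + 1) % n][idx] += chunks[r][idx]
    # Phase 2: all-gather. The reduced chunks travel once more around the ring
    # so that every rank ends up with the complete reduced tensor.
    for s in range(n - 1):
        for r in range(n):
            idx = (r + 1 - s) % n
            chunks[(r + 1) % n][idx] = chunks[r][idx].clone()
    return [torch.cat(row) for row in chunks]

ranks = [torch.randn(8) for _ in range(4)]
reduced = ring_all_reduce_sim(ranks)
assert all(torch.allclose(out, sum(ranks)) for out in reduced)
```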
Kernel Fusion:
Where possible, we fused multiple operations into single kernels to reduce kernel launch overhead and maximize data reuse in on-chip memory.
This was particularly effective for operations like layer normalization and activation functions.
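As a stand-in for hand-written fused kernels (which are CUDA code), the sketch below uses torch.compile, which is free to fuse the normalization and activation into fewer kernels so the intermediate never round-trips through global memory. Shapes and the choice of GELU are illustrative; requires PyTorch 2.x.

```python
import torch
import torch.nn.functional as F

@torch.compile
def fused_norm_gelu(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Layer normalization followed by GELU, expressed in one compiled region so
    # the element-wise stages can be fused rather than launched separately.
    y = F.layer_norm(x, x.shape[-1:], weight, bias)
    return F.gelu(y)

x = torch.randn(8, 1024)
out = fused_norm_gelu(x, torch.ones(1024), torch.zeros(1024))
```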
Mixed Precision Strategies:
We carefully analyzed which parts of the computation could use FP8 without loss of accuracy, and where higher precision was needed.
This led to a hybrid approach where some operations (like accumulations) used higher precision internally but stored results in FP8.
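A sketch of the accumulate-high / store-low pattern (emulated here with casts; a real kernel keeps the FP32 accumulation inside the GEMM, and the helper names are illustrative):

```python
import torch

FP8_E4M3_MAX = 448.0

def quantize(x: torch.Tensor):
    """Per-tensor FP8 quantization with a multiplier scale."""
    scale = FP8_E4M3_MAX / x.abs().max().clamp(min=1e-12)
    return (x * scale).to(torch.float8_e4m3fn), scale

def fp8_matmul(a8, a_scale, b8, b_scale):
    # Operands live in FP8; the multiply-accumulate runs in FP32, and only the
    # final result is re-quantized to FP8 for storage.
    out = (a8.float() / a_scale) @ (b8.float() / b_scale)
    return quantize(out)

a8, a_scale = quantize(torch.randn(64, 128))
b8, b_scale = quantize(torch.randn(128, 32))
c8, c_scale = fp8_matmul(a8, a_scale, b8, b_scale)
```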