- Good reference code: MS-AMP's fused FP8 optimizer, https://github.com/Azure/MS-AMP/blob/main/msamp/deepspeed/runtime/fp8/fused_optimizer.py

# Summary

- Gradient Scaling:
  - The narrow dynamic range of FP8 can lead to underflow or overflow in gradients.
  - We introduced an automatic scaling technique that adjusts a scaling factor μ on the fly during training.
  - If too many gradient values hit the maximum representable value in FP8, we decrease μ to prevent overflow.
  - Conversely, if values consistently stay below a threshold, we gradually increase μ to mitigate underflow (a sketch of this update rule follows below).
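
A minimal sketch of this kind of auto-scaling update, assuming PyTorch tensors; `FP8_E4M3_MAX`, the thresholds, and the growth/backoff factors are illustrative placeholders, not the paper's exact hyperparameters.

```python
import torch

FP8_E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3

def update_mu(grad: torch.Tensor, mu: float, steps_since_growth: int,
              overflow_ratio: float = 1e-3, backoff: float = 0.5,
              growth: float = 2.0, growth_interval: int = 1000):
    """Adjust the gradient scaling factor mu: back off quickly when too
    many scaled values would saturate FP8, grow slowly when values have
    stayed comfortably in range for a while."""
    scaled = grad.abs() * mu
    saturating = (scaled >= FP8_E4M3_MAX).float().mean().item()
    if saturating > overflow_ratio:
        return mu * backoff, 0            # overflow risk: shrink mu
    if steps_since_growth + 1 >= growth_interval:
        return mu * growth, 0             # long stable stretch: grow mu
    return mu, steps_since_growth + 1
```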

- Precision Decoupling in the Optimizer:
  - Not all components of the optimizer require the same level of precision.
  - We found that gradient statistics (like the first-order moments in Adam) can tolerate lower precision.
  - However, master weights and second-order moments are more sensitive and require higher precision.
  - This insight led us to store first-order moments in FP8, while keeping master weights in higher precision (e.g., FP16 with tensor scaling); a sketch of this split follows below.
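
A minimal sketch of the precision split, assuming a PyTorch build that exposes `torch.float8_e4m3fn` (2.1+). This is not MS-AMP's actual fused optimizer, just an illustration of keeping the first moment in FP8 (with a per-tensor scale) while the second moment and master weights stay in FP16.

```python
import torch

FP8_MAX = 448.0  # max magnitude of FP8 E4M3

def to_fp8(t: torch.Tensor):
    """Quantize to FP8 storage with a per-tensor scale."""
    scale = FP8_MAX / t.abs().max().clamp(min=1e-12)
    return (t * scale).to(torch.float8_e4m3fn), scale

def adam_step(master_w, grad, m_fp8, m_scale, v_fp16, step,
              lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step with decoupled precision: all math in FP32,
    first moment re-quantized to FP8, second moment and master
    weights kept in FP16."""
    g = grad.float()
    m = m_fp8.float() / m_scale
    v = v_fp16.float()
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    update = lr * m_hat / (v_hat.sqrt() + eps)
    master_w -= update.to(master_w.dtype)   # FP16 master weights updated in place
    m_fp8, m_scale = to_fp8(m)
    return master_w, m_fp8, m_scale, v.half()
```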

- Tensor Scaling:
  - We use per-tensor scaling factors to better utilize the limited range of FP8.
  - This lets us map the values of each tensor onto the representable range of FP8 more effectively (see the quantize/dequantize sketch below).
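
A minimal quantize/dequantize sketch of per-tensor scaling, again assuming `torch.float8_e4m3fn` is available; the scale simply stretches each tensor so its largest magnitude lands at the top of the FP8 range.

```python
import torch

FP8_E4M3_MAX = 448.0

def quantize_fp8(t: torch.Tensor):
    """Per-tensor scaling: map the tensor's amax to the FP8 maximum."""
    amax = t.abs().max().clamp(min=1e-12)
    scale = FP8_E4M3_MAX / amax
    return (t * scale).to(torch.float8_e4m3fn), scale

def dequantize_fp8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.float() / scale

# Round-trip demo: tiny values survive because the scale lifts them
# into FP8's representable range before casting.
x = torch.randn(1024) * 1e-3
q, s = quantize_fp8(x)
print((x - dequantize_fp8(q, s)).abs().max())
```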

- Careful Handling of All-Reduce Operations:
  - All-reduce operations in distributed training are particularly prone to underflow/overflow issues with FP8.
  - We developed a technique that uses a shared global scaling factor, the minimum of the per-GPU scales, so that no rank overflows and gradient aggregation stays stable (sketched below).
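
A sketch of the idea, assuming an initialized `torch.distributed` process group. The per-rank scale is reduced with MIN so every rank uses the same (smallest) scale, which corresponds to the largest gradient anywhere and therefore cannot overflow. The FP16 reduction buffer is my placeholder, since FP8 summation support in collectives varies.

```python
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0

def allreduce_fp8_grad(grad: torch.Tensor, world_size: int) -> torch.Tensor:
    """All-reduce a gradient with one shared global-minimum scaling
    factor so all ranks quantize consistently (illustrative only)."""
    local_scale = FP8_E4M3_MAX / grad.abs().max().clamp(min=1e-12)
    global_scale = local_scale.clone()
    dist.all_reduce(global_scale, op=dist.ReduceOp.MIN)   # shared global scale
    # Pre-divide by world_size so the summed result stays in range.
    q = (grad * global_scale / world_size).to(torch.float8_e4m3fn)
    buf = q.half()                    # placeholder reduction dtype
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    return buf.float() / global_scale
```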

- Adaptive Scaling for Activations in Parallel Training:
  - In sequence and tensor parallelism, we added FP8 conversion before the all-gather and reduce-scatter operations on activations.
  - This helps maintain numerical stability while reducing communication costs (a sketch of the all-gather case follows below).
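
A sketch of the all-gather case, assuming an initialized process group and `torch.float8_e4m3fn`; the activation is quantized with a scale shared across ranks, the raw FP8 bytes are gathered, and the result is dequantized on the receiving side. The helper name is illustrative.

```python
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0

def all_gather_fp8(act: torch.Tensor, world_size: int) -> torch.Tensor:
    """All-gather an activation in FP8 instead of FP16/BF16,
    roughly halving communication volume (illustrative sketch)."""
    scale = FP8_E4M3_MAX / act.abs().max().clamp(min=1e-12)
    dist.all_reduce(scale, op=dist.ReduceOp.MIN)   # one scale for every rank
    q = (act * scale).to(torch.float8_e4m3fn)
    # Move the raw FP8 bytes, then reinterpret and dequantize.
    payload = q.view(torch.uint8)
    gathered = [torch.empty_like(payload) for _ in range(world_size)]
    dist.all_gather(gathered, payload)
    out = torch.cat(gathered).view(torch.float8_e4m3fn)
    return out.float() / scale
```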

- ZeRO Optimizer Adaptation:
  - We modified the ZeRO optimizer to distribute whole FP8 tensors (rather than splitting each tensor into partitions) across GPUs, together with their scaling factors.
  - This ensures that per-tensor scaling information stays intact in distributed scenarios (a greedy distribution sketch follows below).
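
A sketch of the whole-tensor distribution idea: each FP8 tensor is assigned in full to the currently least-loaded rank, so its scaling factor never has to be split across partitions. This is an illustrative greedy packer, not MS-AMP's actual ZeRO code.

```python
from typing import Dict, List
import torch

def partition_whole_tensors(tensors: List[torch.Tensor],
                            world_size: int) -> Dict[int, List[int]]:
    """Greedily assign whole tensors (by index) to ranks, balancing
    the element count per rank; largest tensors are placed first."""
    loads = [0] * world_size
    assignment: Dict[int, List[int]] = {r: [] for r in range(world_size)}
    order = sorted(range(len(tensors)),
                   key=lambda i: tensors[i].numel(), reverse=True)
    for idx in order:
        rank = min(range(world_size), key=lambda r: loads[r])
        assignment[rank].append(idx)
        loads[rank] += tensors[idx].numel()
    return assignment
```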

- Monitoring and Fallback Mechanisms:
  - While not explicitly mentioned in the paper, it is generally good practice to monitor training stability.
  - This can involve watching for NaN values, unusually large gradient norms, or sudden spikes in the loss (a simple health check is sketched below).
  - Fallback mechanisms (e.g., temporarily increasing precision or adjusting learning rates) can help recover from instabilities.
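
A simple health check of the sort described above; the thresholds are arbitrary placeholders, and the caller decides the fallback (skip the step, raise precision, lower the learning rate).

```python
import math
from typing import List, Optional
import torch

def step_is_healthy(loss: float, grads: List[torch.Tensor],
                    prev_loss: Optional[float] = None,
                    grad_norm_limit: float = 1e3,
                    loss_spike_factor: float = 3.0) -> bool:
    """Return False when the step looks unstable: non-finite loss,
    exploding global gradient norm, or a sudden loss spike."""
    if not math.isfinite(loss):
        return False
    total_sq = sum(float(g.float().pow(2).sum()) for g in grads if g is not None)
    grad_norm = math.sqrt(total_sq)
    if not math.isfinite(grad_norm) or grad_norm > grad_norm_limit:
        return False
    if prev_loss is not None and loss > loss_spike_factor * prev_loss:
        return False
    return True
```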

- Communication Optimizations:
  - For distributed training, we implemented custom NCCL kernels that work directly with FP8 data, avoiding type conversions during all-reduce operations.
  - We used ring-based algorithms for all-reduce to maximize bandwidth utilization across GPU interconnects (the ring schedule is sketched below).
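
The custom kernels themselves aren't reproduced here, but the ring schedule they follow is easy to illustrate: a reduce-scatter phase followed by an all-gather phase, with each rank exchanging one chunk per step with its neighbor. This pure-Python simulation is only a sketch of the algorithm.

```python
from typing import List

def ring_all_reduce(rank_data: List[List[float]]) -> List[List[float]]:
    """Simulate ring all-reduce: N-1 reduce-scatter steps, then N-1
    all-gather steps. Each rank moves about 2*(N-1)/N of the data in
    total, which is what makes the ring bandwidth-optimal."""
    world, n = len(rank_data), len(rank_data[0])
    assert n % world == 0, "data length must divide evenly into chunks"
    chunk = n // world
    data = [list(d) for d in rank_data]

    def sl(i: int) -> slice:
        i %= world
        return slice(i * chunk, (i + 1) * chunk)

    # Reduce-scatter: at step t, rank r sends chunk (r - t) to rank r+1,
    # which adds it in. Afterwards rank r holds the full sum of chunk r+1.
    for t in range(world - 1):
        msgs = [data[r][sl(r - t)] for r in range(world)]       # snapshot sends
        for r in range(world):
            dst, s = (r + 1) % world, sl(r - t)
            data[dst][s] = [a + b for a, b in zip(data[dst][s], msgs[r])]

    # All-gather: at step t, rank r forwards its completed chunk (r + 1 - t)
    # to rank r+1, which overwrites its own copy.
    for t in range(world - 1):
        msgs = [data[r][sl(r + 1 - t)] for r in range(world)]
        for r in range(world):
            dst, s = (r + 1) % world, sl(r + 1 - t)
            data[dst][s] = msgs[r]
    return data

# Example: 4 ranks, 8 values each; every rank ends with the elementwise sum.
out = ring_all_reduce([[float(r)] * 8 for r in range(4)])
assert all(row == [6.0] * 8 for row in out)
```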

- Kernel Fusion:
  - Where possible, we fused multiple operations into single kernels to reduce kernel-launch overhead and maximize data reuse in on-chip memory.
  - This was particularly effective for operations like layer normalization and activation functions (see the sketch below).
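
The fused kernels in the project were presumably hand-written; a lightweight way to get similar fusion in plain PyTorch is to put the normalization and activation in one `torch.compile` region so the compiler can fuse the elementwise work. This assumes PyTorch 2.x and is an illustration only.

```python
import torch
import torch.nn.functional as F

@torch.compile  # Inductor may fuse the normalization epilogue with the activation
def layernorm_gelu(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    """Layer normalization followed by GELU, expressed as one compiled
    region so the two elementwise stages can share kernels."""
    return F.gelu(F.layer_norm(x, (x.shape[-1],), weight, bias))

x = torch.randn(8, 1024)
out = layernorm_gelu(x, torch.ones(1024), torch.zeros(1024))
```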

- Mixed Precision Strategies:
  - We carefully analyzed which parts of the computation could use FP8 without loss of accuracy and where higher precision was needed.
  - This led to a hybrid approach in which some operations (like accumulations) use higher precision internally but store their results in FP8 (sketched below).
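
A sketch of the "accumulate high, store low" pattern, assuming `torch.float8_e4m3fn`; real FP8 GEMM kernels keep the FP32 accumulator inside the kernel rather than dequantizing to full tensors, so this only illustrates the numerics.

```python
import torch

FP8_E4M3_MAX = 448.0

def fp8_matmul_fp32_accum(a_q, a_scale, b_q, b_scale):
    """Multiply two FP8-stored matrices: dequantize the operands,
    accumulate the matmul in FP32, then re-quantize the result
    to FP8 storage with a fresh per-tensor scale."""
    a = a_q.float() / a_scale
    b = b_q.float() / b_scale
    c = a @ b                                    # FP32 accumulation
    c_scale = FP8_E4M3_MAX / c.abs().max().clamp(min=1e-12)
    return (c * c_scale).to(torch.float8_e4m3fn), c_scale
```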