Summary

  1. Gradient Scaling:
    • The narrow dynamic range of FP8 can lead to underflow or overflow in gradients.
    • We introduced an automatic scaling technique that adjusts a scaling factor μ on-the-fly during training.
    • If too many gradient values hit the maximum representable value in FP8, we decrease μ to prevent overflow.
    • Conversely, if values consistently stay below a threshold, we gradually increase μ to mitigate underflow (a minimal sketch of this rule follows below).
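A minimal sketch of the auto-scaling rule above, in PyTorch-flavoured Python. The E4M3 maximum (448), the saturation tolerance, and the growth interval are illustrative assumptions, not the paper's hyper-parameters.

```python
import torch

FP8_E4M3_MAX = 448.0   # largest finite magnitude in the FP8 E4M3 format

class AutoGradScaler:
    """Adjusts the gradient scaling factor mu on-the-fly (illustrative sketch)."""
    def __init__(self, mu=1.0, overflow_tol=1e-3, low_watermark=FP8_E4M3_MAX / 16, patience=100):
        self.mu = mu
        self.overflow_tol = overflow_tol      # tolerated fraction of saturated values
        self.low_watermark = low_watermark    # "values consistently stay below a threshold"
        self.patience = patience              # steps to wait before growing mu
        self.calm_steps = 0

    def step(self, grad: torch.Tensor) -> torch.Tensor:
        scaled = grad * self.mu
        saturated = (scaled.abs() >= FP8_E4M3_MAX).float().mean().item()
        if saturated > self.overflow_tol:
            self.mu *= 0.5                    # too many values at the FP8 max -> back off
            self.calm_steps = 0
        elif scaled.abs().max().item() < self.low_watermark:
            self.calm_steps += 1
            if self.calm_steps >= self.patience:
                self.mu *= 2.0                # persistently small values -> grow slowly
                self.calm_steps = 0
        return grad * self.mu                 # gradient as it would be cast to FP8
```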
  2. Precision Decoupling in Optimizer:
    • Not all components in the optimizer require the same level of precision.
    • We found that gradient statistics (like first-order moments in Adam) can tolerate lower precision.
    • However, master weights and second-order moments are more sensitive and require higher precision.
    • This insight led us to store first-order moments in FP8, second-order moments in FP16, and master weights in FP16 with tensor scaling (sketched below).
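A sketch of one Adam update with the decoupled precisions described above. FP8 storage is emulated by a round-trip cast, per-tensor scaling and bias correction are omitted for brevity, and all names are illustrative assumptions.

```python
import torch

def to_fp8_like(x: torch.Tensor) -> torch.Tensor:
    """Emulate FP8 (E4M3) storage by a saturating clamp plus round-trip cast."""
    x = x.clamp(-448.0, 448.0)
    if hasattr(torch, "float8_e4m3fn"):           # PyTorch >= 2.1
        return x.to(torch.float8_e4m3fn).to(x.dtype)
    return x.half().to(x.dtype)                   # fallback: at least drop precision

def adam_step(master_w, m, v, grad, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # First-order moment: tolerant of low precision -> stored as (emulated) FP8.
    m = to_fp8_like(b1 * m + (1 - b1) * grad)
    # Second-order moment: more sensitive -> kept in FP16.
    v = (b2 * v.float() + (1 - b2) * grad.float() ** 2).half()
    # Master weights: the most sensitive state, updated here in FP32.
    master_w = master_w - lr * m.float() / (v.float().sqrt() + eps)
    return master_w, m, v

w = torch.randn(1024)                             # FP32 master weight
m, v = torch.zeros(1024), torch.zeros(1024, dtype=torch.float16)
w, m, v = adam_step(w, m, v, grad=torch.randn(1024))
```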
  3. Tensor Scaling:
    • We use per-tensor scaling factors to better utilize the limited range of FP8.
    • This allows us to map the values of each tensor to the representable range of FP8 more effectively.
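A sketch of per-tensor scaling: the scale maps each tensor's largest magnitude onto the FP8 (E4M3) limit of 448. The helper names and the emulated payload are assumptions.

```python
import torch

FP8_E4M3_MAX = 448.0

def quantize_per_tensor(x: torch.Tensor):
    """Return (FP8-range payload, scale) such that x ≈ payload / scale."""
    amax = x.abs().max().clamp(min=1e-12)         # avoid division by zero
    scale = FP8_E4M3_MAX / amax                   # map amax onto the FP8 maximum
    payload = (x * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    if hasattr(torch, "float8_e4m3fn"):           # store in real FP8 when available
        payload = payload.to(torch.float8_e4m3fn)
    return payload, scale

def dequantize(payload: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return payload.float() / scale

x = torch.randn(4, 4) * 1e-3                      # small values that unscaled FP8 would squash
p, s = quantize_per_tensor(x)
print((dequantize(p, s) - x).abs().max())         # reconstruction error stays small
```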
  4. Careful Handling of All-Reduce Operations:
    • All-reduce operations in distributed training can be particularly prone to underflow/overflow issues with FP8.
    • We developed a technique that shares a global minimum scaling factor across GPUs to ensure stable gradient aggregation, as sketched below.
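A sketch of the shared-minimum-scale idea using torch.distributed. An initialized process group is assumed, and FP8 transport is emulated by clamping to the E4M3 range.

```python
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0

def fp8_allreduce_grad(grad: torch.Tensor) -> torch.Tensor:
    # Scale that would map this rank's gradient onto the FP8 range.
    local_scale = FP8_E4M3_MAX / grad.abs().max().clamp(min=1e-12)
    # Agree on the smallest scale across all ranks so no rank overflows after scaling.
    shared_scale = local_scale.clone()
    dist.all_reduce(shared_scale, op=dist.ReduceOp.MIN)
    # Scale, clamp to the FP8 range (emulating the FP8 cast), and sum.
    scaled = (grad * shared_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    dist.all_reduce(scaled, op=dist.ReduceOp.SUM)
    # Undo the shared scale and average over the data-parallel group.
    return scaled / (shared_scale * dist.get_world_size())
```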
  5. Adaptive Scaling for Activation in Parallel Training:
    • In sequence and tensor parallelism, we added FP8 conversion before all-gather and reduce-scatter operations on activations.
    • This helps maintain numerical stability while reducing communication costs.
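A sketch of the quantize-before-communication pattern for activations, again with emulated FP8 and an assumed, already-initialized process group; the per-tensor scale travels with the payload so peers can dequantize.

```python
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0

def all_gather_fp8(act: torch.Tensor):
    # Per-tensor scale for this rank's activation shard.
    scale = (FP8_E4M3_MAX / act.abs().max().clamp(min=1e-12)).reshape(1)
    payload = (act * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).half()  # emulated FP8 payload
    world = dist.get_world_size()
    scales = [torch.empty_like(scale) for _ in range(world)]
    chunks = [torch.empty_like(payload) for _ in range(world)]
    dist.all_gather(scales, scale)      # tiny: one scalar per rank
    dist.all_gather(chunks, payload)    # the bulk of the traffic is low precision
    return [c.float() / s for c, s in zip(chunks, scales)]
```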
  6. ZeRO Optimizer Adaptation:
    • We modified the ZeRO optimizer to distribute whole FP8 tensors (rather than partitions) along with their scaling factors.
    • This ensures that tensor scaling information is properly maintained in distributed scenarios.
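A pure-Python sketch of the whole-tensor placement idea: tensors are assigned greedily to the least-loaded rank so each (FP8 tensor, scaling factor) pair stays intact on one device. The greedy heuristic and names are assumptions, not the paper's exact scheme.

```python
from typing import List

def place_whole_tensors(sizes: List[int], world_size: int) -> List[List[int]]:
    """Return, per rank, the indices of the whole tensors it owns."""
    load = [0] * world_size
    owners = [[] for _ in range(world_size)]
    # Assign the largest tensors first to keep per-rank memory roughly balanced.
    for idx in sorted(range(len(sizes)), key=lambda i: -sizes[i]):
        rank = min(range(world_size), key=lambda r: load[r])
        owners[rank].append(idx)
        load[rank] += sizes[idx]
    return owners

# Example: 6 tensors spread over 2 ranks, each kept whole alongside its scale.
print(place_whole_tensors([100, 80, 60, 40, 20, 10], world_size=2))
```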
  7. Monitoring and Fallback Mechanisms:
    • While not explicitly mentioned in the paper, it’s generally good practice to monitor training stability.
    • This could involve watching for NaN values, unusually large gradients, or sudden spikes in loss.
    • Having fallback mechanisms (e.g., temporarily increasing precision or adjusting learning rates) can help recover from potential instabilities; a simple monitor is sketched below.
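A minimal sketch of such a monitor. As noted above, this is general good practice rather than something taken from the paper, and the thresholds are arbitrary.

```python
import math

class StabilityMonitor:
    """Flags NaN/Inf values and sudden loss spikes so training can fall back."""
    def __init__(self, spike_factor: float = 3.0, window: int = 50):
        self.spike_factor = spike_factor
        self.window = window
        self.recent_losses = []

    def check(self, loss: float, grad_norm: float) -> str:
        if any(math.isnan(v) or math.isinf(v) for v in (loss, grad_norm)):
            return "fallback"     # e.g. redo the step in BF16 or skip the update
        if self.recent_losses and loss > self.spike_factor * (
            sum(self.recent_losses) / len(self.recent_losses)
        ):
            return "warn"         # e.g. lower the learning rate or tighten scaling
        self.recent_losses = (self.recent_losses + [loss])[-self.window:]
        return "ok"

monitor = StabilityMonitor()
print(monitor.check(loss=2.3, grad_norm=1.1))            # 'ok'
print(monitor.check(loss=float("nan"), grad_norm=1.1))   # 'fallback'
```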
  8. Communication Optimizations:
    • For distributed training, we implemented custom NCCL kernels that work directly with FP8 data, avoiding the need for type conversion during all-reduce operations.
    • We used ring-based algorithms for all-reduce to maximize bandwidth utilization across GPU interconnects.
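The custom NCCL kernels cannot be reproduced here; as an illustration only, this single-process sketch simulates the two phases of the ring all-reduce pattern (reduce-scatter, then all-gather) that the production collective implements on real interconnects.

```python
def ring_allreduce(data):
    """data[r] holds rank r's P chunks; returns the element-wise sum on every rank."""
    P = len(data)
    # Reduce-scatter: after P-1 steps, rank r holds the full sum of chunk (r+1) % P.
    for step in range(P - 1):
        sends = [(r, (r - step) % P, data[r][(r - step) % P]) for r in range(P)]
        for r, idx, chunk in sends:
            data[(r + 1) % P][idx] = data[(r + 1) % P][idx] + chunk
    # All-gather: circulate the completed chunks once around the ring.
    for step in range(P - 1):
        sends = [(r, (r + 1 - step) % P, data[r][(r + 1 - step) % P]) for r in range(P)]
        for r, idx, chunk in sends:
            data[(r + 1) % P][idx] = chunk
    return data

# Three "ranks", each holding three scalar chunks.
print(ring_allreduce([[1, 2, 3], [10, 20, 30], [100, 200, 300]]))
# every rank ends up with [111, 222, 333]
```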
  9. Kernel Fusion:
    • Where possible, we fused multiple operations into single kernels to reduce kernel launch overhead and maximize data reuse in on-chip memory.
    • This was particularly effective for operations like layer normalization and activation functions.
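The hand-written fused kernels are not reproducible here; as a stand-in, this sketch relies on torch.compile (PyTorch 2.x) to fuse the element-wise work of layer normalization and the activation, which illustrates the same goal of fewer kernel launches and better on-chip data reuse.

```python
import torch
import torch.nn.functional as F

@torch.compile   # lets the compiler fuse adjacent element-wise operations into fewer kernels
def norm_then_gelu(x, weight, bias):
    return F.gelu(F.layer_norm(x, x.shape[-1:], weight, bias))

x = torch.randn(8, 1024)
w, b = torch.ones(1024), torch.zeros(1024)
y = norm_then_gelu(x, w, b)   # first call compiles; later calls reuse the fused kernels
```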
  10. Mixed Precision Strategies:
    • We carefully analyzed which parts of the computation could use FP8 without loss of accuracy, and where higher precision was needed.
    • This led to a hybrid approach where some operations (like accumulations) used higher precision internally but stored results in FP8.
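A sketch of the accumulate-high, store-low pattern: the operands live in (emulated) FP8 with per-tensor scales, the matrix product accumulates in FP32, and only the result is re-quantized. Function names and the emulation are assumptions.

```python
import torch

FP8_E4M3_MAX = 448.0

def quantize(x: torch.Tensor):
    """Per-tensor scaling into the FP8 range; FP8 storage emulated with FP16."""
    scale = FP8_E4M3_MAX / x.abs().max().clamp(min=1e-12)
    return (x * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).half(), scale

def fp8_matmul(a: torch.Tensor, b: torch.Tensor):
    qa, sa = quantize(a)
    qb, sb = quantize(b)
    # Accumulate the dot products in FP32, then rescale and re-quantize the output.
    out = (qa.float() @ qb.float()) / (sa * sb)
    return quantize(out)

a, b = torch.randn(64, 128), torch.randn(128, 32)
payload, scale = fp8_matmul(a, b)
approx = payload.float() / scale          # dequantized result for comparison
print((approx - a @ b).abs().max())       # quantization error of the hybrid path
```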