The format

  • As explained in Floating point numbers

  • FP8 (E4M3): 1 bit sign, 4 bits exponent, 3 bits mantissa; range: ≈2^-9 to 448 (in magnitude)

    • more precision
    • WARNING: E4M3’s dynamic range is extended by not representing infinities and reserving only one mantissa bit pattern for NaN. The extra range gained this way is much more useful than supporting multiple encodings for the special values
  • FP8 (E5M2): 1 bit sign, 5 bits exponent, 2 bits mantissa; range: ≈2^-16 to 57344 (in magnitude)

    • more range
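
A quick way to check these limits (a minimal sketch, assuming PyTorch ≥ 2.1, which exposes torch.float8_e4m3fn and torch.float8_e5m2):

```python
import torch

# Inspect both FP8 formats (assumes PyTorch >= 2.1)
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    # max: largest finite value; tiny: smallest normal value; eps: relative precision
    print(dtype, "max =", info.max, "tiny =", info.tiny, "eps =", info.eps)
```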

When to use E4M3 and E5M2

  • E4M3 (more precision): weights and forward-pass activations
  • E5M2 (more range): gradients (see the example below)
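
To see why gradients prefer E5M2, a minimal sketch (assuming PyTorch ≥ 2.1 with native float8 dtypes): a small gradient value underflows to zero in E4M3 but is still representable in E5M2.

```python
import torch

# A typical small gradient value (hypothetical, for illustration only)
grad = torch.tensor([1e-4])

# E4M3: smallest subnormal is 2**-9 ≈ 0.002, so 1e-4 rounds down to zero
print(grad.to(torch.float8_e4m3fn).float())   # tensor([0.])

# E5M2: smallest subnormal is 2**-16 ≈ 1.5e-5, so 1e-4 survives (≈1.07e-4)
print(grad.to(torch.float8_e5m2).float())
```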

Mixed precision recipe

For FP16

  • Partition the DL network graph into safe and unsafe regions

    • Safe regions contain operations (e.g. MLPs)
      • benefiting from reduced precision
      • whose output dynamic ranges are similar to those of their inputs
    • Unsafe example: exponentiation
  • Loss scaling

    • Multiply the loss by a scaling factor before the backward pass so that all gradient values stay in FP16’s representable range, and unscale the gradients before the optimizer step. Mentioned in Gradient scaling (fp16). See the sketch after this list
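
A minimal sketch of this FP16 recipe with PyTorch automatic mixed precision (autocast handles the safe/unsafe partition per op, GradScaler handles the loss scaling); the model, optimizer, and data are placeholders and a CUDA GPU is assumed:

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()   # loss scaling for FP16

inputs = torch.randn(32, 1024, device="cuda")
targets = torch.randn(32, 1024, device="cuda")

# "Safe" ops (matmuls) run in FP16, "unsafe" ops stay in FP32 under autocast
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)

scaler.scale(loss).backward()   # backward pass on the scaled loss
scaler.step(optimizer)          # unscales gradients, skips the step on inf/NaN
scaler.update()                 # adjusts the scaling factor for the next step
```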

For FP8

  • Partition the DL network graph into safe and unsafe regions

    • Unsafe regions do not necessarily need to be FP32; the FP8 training recipe can be combined with the FP16/BF16 recipe
    • Explicit casts are not enough - FP8 operators need to use higher precision internally and be able to produce higher-precision outputs
      • in the matmul, the internal accumulator is in FP32, even if the hardware does an FP8 operation
        • Tensor Cores on Hopper GPUs have the option to accumulate matrix products directly into FP32, resulting in better numerical accuracy and avoiding the need for a separate casting kernel (Hopper-only)
      • thus, we can output high precision even when using an FP8 linear layer
  • A single loss scaling factor is not enough ⇒ use per-tensor scaling factors (see the sketch after this list)

    • Scaling factors are needed in both passes
      • E4M3 for forward, E5M2 for backward
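
A minimal sketch of one FP8 linear step with per-tensor scales and FP32 accumulation (a pure-PyTorch simulation, assuming PyTorch ≥ 2.1 float8 dtypes; in practice a kernel such as torch._scaled_mm or a library like Transformer Engine drives the Hopper Tensor Cores):

```python
import torch

E4M3_MAX = 448.0   # largest finite E4M3 value

def per_tensor_scale(t, fmt_max=E4M3_MAX):
    # Scale so that the tensor's largest magnitude maps to the format's max
    amax = t.abs().max().clamp(min=1e-12)
    return fmt_max / amax

# Hypothetical high-precision activations and weights
x = torch.randn(32, 64)
w = torch.randn(64, 64)

sx, sw = per_tensor_scale(x), per_tensor_scale(w)

# Quantize to E4M3 for the forward pass (one scale per tensor)
x_fp8 = (x * sx).to(torch.float8_e4m3fn)
w_fp8 = (w * sw).to(torch.float8_e4m3fn)

# "FP8 matmul": inputs are FP8, but the accumulator is FP32 - simulated here
# by upcasting before the matmul, the way real FP8 kernels accumulate in FP32
out = (x_fp8.float() @ w_fp8.float()) / (sx * sw)   # undo both scales
print(out.dtype, out.shape)   # FP32, i.e. a higher-precision output
```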

Choosing the scaling factor

  • Conceptually simple (find the maximum absolute value and choose the scale so that this maximum lands in range), but hard in practice
  • It is not feasible to keep the entire high-precision output in high-speed memory just to find its maximum and then rescale it
  • To overcome that, we need to know the scaling factor before seeing the output ⇒ keep a history of the maxima (amax) from previous iterations and derive the scale from it (sketched below)
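
A minimal sketch of this idea (often called delayed scaling); the class name, history length, and update rule here are illustrative, not a real library API:

```python
import torch

class DelayedScale:
    """Derive the current FP8 scale from the amax history of previous steps."""

    def __init__(self, fmt_max=448.0, history_len=16):   # 448 = E4M3 max
        self.fmt_max = fmt_max
        self.history_len = history_len
        self.amax_history = []
        self.scale = 1.0   # used before any history exists

    def quantize(self, t):
        # Use the scale decided *before* seeing this tensor's values
        t_fp8 = (t * self.scale).to(torch.float8_e4m3fn)
        # Record this tensor's amax to update the scale for the next step
        self.amax_history.append(t.abs().max().item())
        self.amax_history = self.amax_history[-self.history_len:]
        self.scale = self.fmt_max / max(max(self.amax_history), 1e-12)
        return t_fp8

q = DelayedScale()
for _ in range(3):
    _ = q.quantize(torch.randn(32, 64))   # hypothetical layer output
    print(q.scale)
```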