The format
- As explained in Floating point numbers
- FP8 (E4M3): 1 bit sign, 4 bits exponent, 3 bits mantissa; range: ~2^-9 (smallest subnormal) to 448
    - more precision
    - WARNING: E4M3's dynamic range is extended by not representing infinities and by using only one mantissa bit-pattern for NaN; the extra range gained is much more useful than supporting multiple encodings for the special values
- FP8 (E5M2): 1 bit sign, 5 bits exponent, 2 bits mantissa; range: ~2^-16 (smallest subnormal) to 57344
    - more range
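These limits can be inspected directly; a minimal sketch assuming PyTorch ≥ 2.1, which exposes both FP8 dtypes:

```python
import torch

# Inspect the numeric limits of the two FP8 formats (PyTorch >= 2.1).
for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}, smallest normal={info.tiny}, eps={info.eps}")

# float8_e4m3fn: max=448.0,   tiny=0.015625 (2^-6),  eps=0.125 -> more precision
# float8_e5m2:   max=57344.0, tiny≈6.1e-05 (2^-14),  eps=0.25  -> more range
```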
When to use E4M3 and E5M2
- E4M3 (more precision): weights, forward-pass activations
- E5M2 (more range): gradients
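As a cast-only illustration of that assignment (assuming PyTorch's FP8 dtypes; a real recipe also needs the per-tensor scaling described below):

```python
import torch

activation = torch.randn(16, 16)        # forward-pass tensor
grad = torch.randn(16, 16) * 1e-4       # backward-pass tensor

# Forward tensors (weights, activations) -> E4M3 for precision.
activation_fp8 = activation.to(torch.float8_e4m3fn)
# Backward tensors (gradients) -> E5M2 for range.
grad_fp8 = grad.to(torch.float8_e5m2)
```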
Mixed precision recipe
For FP16
- Partition the DL network graph into safe and unsafe regions
    - Safe regions contain operations (e.g. MLPs)
        - benefiting from reduced precision
        - whose outputs' dynamic ranges are similar to their inputs'
    - Unsafe examples: exponentiation
- Loss scaling
    - Use the scaling factor during the backward pass to ensure all gradient values stay in a representable range (see the sketch below). Mentioned in Gradient scaling (fp16)
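A minimal sketch of this recipe with PyTorch's stock AMP tooling, assuming a CUDA device and a toy model/batch: torch.autocast does the per-op safe/unsafe partitioning, GradScaler does the loss scaling.

```python
import torch

# Placeholder model, data and optimizer; the point is the autocast/scaler pattern.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
x = torch.randn(32, 1024, device="cuda")
y = torch.randn(32, 1024, device="cuda")

for _ in range(10):  # a few steps on the same batch, purely for illustration
    optimizer.zero_grad()
    # autocast partitions per op: matmuls run in FP16, ops that need range
    # (exponentiation, reductions, ...) are kept in FP32.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()  # scale the loss so FP16 gradients don't underflow
    scaler.step(optimizer)         # unscales gradients before the optimizer step
    scaler.update()                # adapt the scaling factor for the next step
```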
For FP8
- Partition the DL network graph into safe and unsafe regions
    - Unsafe regions do not necessarily need to be FP32; the FP8 training recipe can be combined with the FP16/BF16 recipe
    - Explicit casts are not enough - FP8 operators need to use higher precision internally and be able to output a higher-precision result
        - in the matmul, the internal accumulator is in FP32, even if the hardware does an FP8 operation
        - Tensor Cores on Hopper GPUs have the option to accumulate matrix products directly into FP32, resulting in better numerical accuracy and avoiding the need for a separate casting kernel (Hopper Tensor Cores only)
        - thus, we can output high precision even when using an FP8 linear layer (see the sketch after this list)
- One loss scaling factor is not enough → use per-tensor scaling factors
    - Scaling factors are needed in both passes
    - E4M3 for forward, E5M2 for backward
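A rough numerical emulation of an FP8 matmul under this recipe: inputs stored as E4M3 with per-tensor scales chosen from each tensor's amax, multiply-accumulate in FP32, higher-precision output returned directly. The function names (quantize_e4m3, fp8_matmul_emulated) are made up for illustration; this is not how the actual Hopper kernel is written.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def quantize_e4m3(t: torch.Tensor):
    # Per-tensor scaling: map the tensor's amax near the top of E4M3's range.
    scale = E4M3_MAX / t.abs().max().clamp(min=1e-12)
    return (t * scale).to(torch.float8_e4m3fn), scale

def fp8_matmul_emulated(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    a_fp8, sa = quantize_e4m3(a)
    b_fp8, sb = quantize_e4m3(b)
    # The multiply-accumulate happens in FP32 (as on Hopper Tensor Cores),
    # so the result can be returned in higher precision directly.
    acc = a_fp8.float() @ b_fp8.float()
    return acc / (sa * sb)  # undo the per-tensor scales

out = fp8_matmul_emulated(torch.randn(64, 128), torch.randn(128, 32))
```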
Choosing the scaling factor
- Conceptually simple (find the maximum and scale so that the max lands in range) but hard in practice
    - Impossible to keep the entire high-precision output in high-speed memory just to find the maximum and scale with it
    - To overcome that, we need to know the scaling factor before seeing the output → keep track of a history of past maxima
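One standard fix is delayed scaling: derive the current step's scale from a short history of amax values observed in previous steps, so the scale is known before the output is produced (this is the idea behind e.g. NVIDIA Transformer Engine's DelayedScaling recipe). A minimal sketch, with a hypothetical class name and parameters:

```python
import collections
import torch

class DelayedScaler:
    """Keep a history of recent per-tensor amax values and derive the
    scaling factor from it, so the scale for the current step is known
    *before* this step's FP8 output is produced."""

    def __init__(self, fp8_max: float = 448.0, history_len: int = 16):
        self.fp8_max = fp8_max
        self.amax_history = collections.deque(maxlen=history_len)
        self.scale = 1.0

    def quantize(self, t: torch.Tensor) -> torch.Tensor:
        out = (t * self.scale).to(torch.float8_e4m3fn)  # uses the *previous* scale
        # Record this step's amax; it only influences future scales.
        self.amax_history.append(t.abs().max().item())
        self.scale = self.fp8_max / max(max(self.amax_history), 1e-12)
        return out
```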