Floating point representation
- The number of bits in the exponent determines the scale/range of numbers we can represent
- The number of bits in the mantissa determines the precision of numbers we can represent
- Because of the exponent-times-fraction format, the gaps between adjacent representable numbers inevitably grow with magnitude: large numbers are spaced further apart than small ones
- Representation of fp32
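A quick way to see the growing gaps is to measure the distance from a float32 value to the next representable float32. The sketch below uses only Python's struct module, and the helper names are hypothetical. The gap doubles each time the exponent goes up by one, so it is 2^-23 at 1.0 but already 2.0 at 2^24.

```python
import struct

# Hypothetical helper names, written for this note.
def fp32_bits(x: float) -> int:
    """Round x to float32 and return its 32-bit pattern."""
    return struct.unpack('<I', struct.pack('<f', x))[0]

def fp32_from_bits(b: int) -> float:
    """Interpret a 32-bit pattern as a float32 value."""
    return struct.unpack('<f', struct.pack('<I', b))[0]

def fp32_gap(x: float) -> float:
    """Distance from a positive float32 x (below the float32 max) to the next float32 above it."""
    return fp32_from_bits(fp32_bits(x) + 1) - fp32_from_bits(fp32_bits(x))

for x in (1.0, 1024.0, 2.0 ** 24):
    print(x, fp32_gap(x))
```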
Bit assignments, range, and granularity
- float32: 1 bit sign, 8 bits exponent, 23 bits mantissa; normal range: ~1.18e-38 to ~3.40e38
- bfloat16: 1 bit sign, 8 bits exponent, 7 bits mantissa; normal range: ~1.18e-38 to ~3.39e38 (more range than fp16, less granular)
- FP16: 1 bit sign, 5 bits exponent, 10 bits mantissa; normal range: ~6.10e-5 to 65,504 (less range than bf16, more granular)
- FP8 (E4M3): 1 bit sign, 4 bits exponent, 3 bits mantissa; normal range: ~1.56e-2 to 448
  - WARNING: E4M3's dynamic range is extended by not representing infinities and by having only one mantissa bit pattern for NaNs. The greater range achieved is much more useful than supporting multiple encodings for the special values.
- FP8 (E5M2): 1 bit sign, 5 bits exponent, 2 bits mantissa; normal range: ~6.10e-5 to 57,344
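The ranges above follow from the bit counts alone. Here is a minimal sketch, assuming IEEE-style rules (all-ones exponent reserved for inf/NaN, subnormals when the exponent field is zero) for every format except E4M3, which is modeled with the exception described in the warning. The function name and the `ieee_specials` flag are hypothetical.

```python
def fp_format_stats(exp_bits: int, man_bits: int, ieee_specials: bool = True):
    """Return (bias, min_subnormal, min_normal, max_normal) for a binary float format.

    ieee_specials=True: the all-ones exponent code is reserved for inf/NaN (fp32, bf16, fp16, E5M2).
    ieee_specials=False: E4M3-style; the all-ones exponent still holds normal numbers,
    and only the all-ones mantissa under it encodes NaN (no infinities).
    """
    bias = 2 ** (exp_bits - 1) - 1
    min_normal = 2.0 ** (1 - bias)
    min_subnormal = 2.0 ** (1 - bias - man_bits)
    if ieee_specials:
        max_exp = (2 ** exp_bits - 2) - bias            # largest non-reserved exponent code
        max_significand = 2.0 - 2.0 ** -man_bits        # mantissa all ones
    else:
        max_exp = (2 ** exp_bits - 1) - bias            # all-ones exponent is usable
        max_significand = 2.0 - 2.0 ** (1 - man_bits)   # all-ones mantissa is NaN, so one step below
    return bias, min_subnormal, min_normal, max_significand * 2.0 ** max_exp

FORMATS = {
    "float32": (8, 23, True),
    "bfloat16": (8, 7, True),
    "fp16": (5, 10, True),
    "fp8 E4M3": (4, 3, False),
    "fp8 E5M2": (5, 2, True),
}

for name, (e, m, ieee) in FORMATS.items():
    bias, min_sub, min_norm, max_norm = fp_format_stats(e, m, ieee)
    print(f"{name:9s} bias={bias:3d} min_subnormal={min_sub:.3e} "
          f"min_normal={min_norm:.3e} max={max_norm:.5g}")
```

Dropping the E4M3 exception (calling it with `ieee_specials=True`) would give a maximum of 240 instead of 448, which is the extra range the warning refers to.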
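Converting from fp32 to bfloat16 is easy: the exponent is kept the same and the significand is rounded or truncated from 24 bits to 8 (both counts include the implicit leading bit). Because the exponent range is unchanged, normal fp32 values cannot underflow in the conversion, and the only way to overflow is round-to-nearest pushing values in the top sliver of the fp32 range up to infinity.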
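Since bf16 is just the top 16 bits of an fp32 pattern, the conversion can be sketched with bit operations. The example below uses the common "add 0x7FFF plus the surviving LSB" trick for round-to-nearest-even and assumes the input fits in float32. The function names are hypothetical and NaN is mapped to a single canonical quiet NaN, which is one choice among several.

```python
import math
import struct

def fp32_to_bf16_bits(x: float, round_nearest_even: bool = True) -> int:
    """Convert x (first rounded to float32) to a 16-bit bfloat16 pattern."""
    if math.isnan(x):
        return 0x7FC0                                # canonical quiet NaN (rounding could turn a NaN into inf)
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    if not round_nearest_even:
        return bits >> 16                            # plain truncation: drop the low 16 mantissa bits
    lsb = (bits >> 16) & 1                           # lowest mantissa bit that survives
    return ((bits + 0x7FFF + lsb) >> 16) & 0xFFFF    # round to nearest, ties to even

def bf16_bits_to_fp32(b: int) -> float:
    """bfloat16 -> float32 is exact: pad the mantissa with 16 zero bits."""
    return struct.unpack('<f', struct.pack('<I', b << 16))[0]

print(bf16_bits_to_fp32(fp32_to_bf16_bits(math.pi)))   # 3.140625
```

The overflow edge case shows up at the fp32 maximum (bit pattern 0x7F7FFFFF): rounding to nearest carries into the exponent and produces infinity, while truncation stays finite.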
BF16 bad precision example
- In bf16,
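- Because of this lack of precision, long running sums in bf16 accumulate large rounding errors: once the total is large enough, small addends are rounded away entirely. The same effect occurs much later in fp16 and far later in fp32, which have more mantissa bits.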
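A minimal simulation of this effect, assuming the accumulator itself is stored in the low-precision format and rounded after every addition (real kernels often accumulate in fp32, which avoids the problem). The helper names are hypothetical; fp16 rounding uses the struct module's 'e' (binary16) format.

```python
import struct

def round_to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); assumes x is finite and fits in float32."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack('<f', struct.pack('<I', bits))[0]

def round_to_fp16(x: float) -> float:
    """Round x to the nearest IEEE half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def running_sum(values, rounder):
    total = 0.0
    for v in values:
        total = rounder(total + v)   # the accumulator lives in the low-precision format
    return total

ones = [1.0] * 1000
print(running_sum(ones, round_to_bf16))   # 256.0
print(running_sum(ones, round_to_fp16))   # 1000.0
```

In bf16, 256 + 1 = 257 falls exactly halfway between 256 and 258 and rounds to the even neighbor 256, so the sum stalls there. fp16 represents every integer up to 2048 exactly, so it gets the right answer here.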
Encoding
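- $s$ = sign bit, $E$ = exponent bits (read as an unsigned integer), $F$ = fraction or mantissa bits; the format has $k$ exponent bits and $p$ fraction bits
- Value $= (-1)^{s} \times 2^{E - \text{bias}} \times \left(1 + F / 2^{p}\right)$ for normal numbers (special values such as zero, subnormals, and infinity follow other rules)
- Exponent bias $= 2^{k-1} - 1$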
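To connect the formula to actual bits, here is a sketch that splits a float32 into $s$, $E$, $F$ and rebuilds the value with the normal-number formula. The function names are hypothetical, and special cases are rejected rather than decoded.

```python
import struct

def fp32_fields(x: float):
    """Split x (rounded to float32) into sign s, exponent field E, and fraction field F."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    s = bits >> 31
    E = (bits >> 23) & 0xFF        # 8 exponent bits
    F = bits & 0x7FFFFF            # 23 fraction bits
    return s, E, F

def normal_value(s: int, E: int, F: int, k: int = 8, p: int = 23) -> float:
    """value = (-1)^s * 2^(E - bias) * (1 + F / 2^p), valid only for normal numbers."""
    bias = 2 ** (k - 1) - 1
    if not 0 < E < 2 ** k - 1:
        raise ValueError("E = 0 (zero/subnormal) and E = all ones (inf/NaN) are special cases")
    return (-1) ** s * 2.0 ** (E - bias) * (1 + F / 2 ** p)

s, E, F = fp32_fields(-6.5)
print(s, E, F)                  # 1 129 5242880
print(normal_value(s, E, F))    # -6.5
```

Here -6.5 = -1.625 x 2^2, so the stored exponent is 2 + 127 = 129 and the fraction field holds 0.625 x 2^23.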
Exponent encoding
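- To adequately represent values below 1, the exponent is encoded using an offset-binary representation: the stored field is the true exponent plus an offset of $2^{k-1} - 1$ for $k$ exponent bits (roughly the middle of the representable range), also called the exponent bias. It equals 127 for fp32 and 15 for fp16.
- For fp16 (sign bit $s$, $k = 5$ exponent bits $E$, $p = 10$ fraction bits $F$):
  - Exponent bias $= 2^{5-1} - 1 = 01111_2 = 15$
  - $E = 00000_2$ and $E = 11111_2$ are special cases (zero/subnormals and infinity/NaN, respectively).
  - When $E = 00000_2$, we go into the SUBNORMAL REGIME and the equation changes to $(-1)^{s} \times 2^{-14} \times (F / 2^{10})$ (no implicit leading 1); the smallest positive value in this regime is $2^{-14} \times 2^{-10} = 2^{-24} \approx 5.96 \times 10^{-8}$
  - The minimum positive normal value ($E = 00001_2$, $F = 0$) is $2^{1-15} = 2^{-14} \approx 6.10 \times 10^{-5}$
  - The maximum positive normal value (excluding infinity; $E = 11110_2$, $F$ all ones) is $2^{30-15} \times (2 - 2^{-10}) = 65504$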
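The fp16 special cases above can be checked with a hand-written decoder compared against Python's built-in binary16 support (the struct 'e' format). The decoder name is hypothetical; it simply follows the three regimes listed above.

```python
import math
import struct

def decode_fp16(bits: int) -> float:
    """Decode a 16-bit fp16 pattern by hand: 1 sign bit, 5 exponent bits (bias 15), 10 fraction bits."""
    s = (bits >> 15) & 1
    E = (bits >> 10) & 0x1F
    F = bits & 0x3FF
    sign = -1.0 if s else 1.0
    if E == 0:                                   # zero or subnormal: no implicit leading 1
        return sign * 2.0 ** -14 * (F / 2 ** 10)
    if E == 0x1F:                                # all ones: infinity (F == 0) or NaN
        return sign * math.inf if F == 0 else math.nan
    return sign * 2.0 ** (E - 15) * (1 + F / 2 ** 10)

def decode_fp16_struct(bits: int) -> float:
    """Reference decoder using struct's 'e' (IEEE 754 binary16) format."""
    return struct.unpack('<e', bits.to_bytes(2, 'little'))[0]

examples = {
    0x0001: "smallest positive subnormal",
    0x0400: "smallest positive normal",
    0x7BFF: "largest normal",
}
for b, label in examples.items():
    print(f"{b:#06x} {label}: {decode_fp16(b)!r} (struct: {decode_fp16_struct(b)!r})")
```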
Gradient scaling (fp16)
- When training in fp16, the maximum normal value is 65,504 and the minimum positive normal value is $2^{-14} \approx 6.10 \times 10^{-5}$. Gradients smaller than this lose precision in the subnormal range or flush to zero, so we need to prevent underflow!
- AMP/fp16 may not work for every model! For example, most bf16-pretrained models cannot operate in the fp16 numerical range (max 65,504) and will cause gradients to overflow instead of underflow. In this case, the scale factor may decrease below 1 in an attempt to bring gradients into the fp16 dynamic range. While one may expect the scale to always be above 1, PyTorch's GradScaler does NOT make this guarantee, in order to maintain performance.
- Loss scaling shifts gradient values into the representable range of fp16:
- Maintain a primary copy of weights in FP32.
- For each iteration:
  - Make an FP16 copy of the weights.
  - Forward propagation (FP16 weights and activations).
  - Multiply the resulting loss by the scaling factor S.
  - Backward propagation (FP16 weights, activations, and their gradients).
  - Multiply the weight gradients by 1/S.
  - Complete the weight update (including gradient clipping, etc.).
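The recipe above, plus the dynamic scale behavior mentioned for GradScaler, can be sketched on a toy problem. This is a simulation under loud assumptions: the fp16 backward pass is modeled by rounding S times a known gradient to fp16, Python floats (float64) stand in for the fp32 primary weights, and `ToyLossScaler` and `train_step` are hypothetical names. The constants follow the documented defaults of PyTorch's GradScaler (initial scale 2^16, growth 2x every 2000 good steps, backoff 0.5), but this is not its implementation.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to fp16; values beyond the fp16 range become +/-inf (an fp16 overflow)."""
    try:
        return struct.unpack('<e', struct.pack('<e', x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

class ToyLossScaler:
    """Hypothetical dynamic loss scaler (constants follow GradScaler's documented defaults)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, found_inf: bool) -> None:
        if found_inf:
            self.scale *= self.backoff_factor      # overflow: shrink S (it may drop below 1)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor   # long run without overflow: try a larger S
                self._good_steps = 0

def train_step(master_w, true_grads, scaler, lr):
    """One iteration of the recipe for a toy model whose true gradients are given."""
    S = scaler.scale
    grads16 = [to_fp16(S * g) for g in true_grads]          # simulated fp16 backward on loss * S
    found_inf = any(math.isinf(g) or math.isnan(g) for g in grads16)
    if not found_inf:
        grads32 = [g / S for g in grads16]                  # multiply by 1/S in higher precision
        master_w = [w - lr * g for w, g in zip(master_w, grads32)]  # update the primary copy
    scaler.update(found_inf)                                # on overflow the update is skipped
    return master_w

# A gradient of 1e-8 underflows to zero in fp16 ...
print(to_fp16(1e-8))                          # 0.0
# ... but survives when scaled by S = 2**16 and unscaled afterwards.
print(to_fp16(1e-8 * 2.0 ** 16) / 2.0 ** 16)
```

Feeding very large gradients (e.g. 1e6) shows the case from the AMP warning: the scaler keeps halving S until 1e6 x S fits in fp16, ending at S = 2^-4 < 1.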