- code snippet for a bfloat16 optimizer with Kahan summation (a toy sketch follows these notes)
    - https://github.com/pytorch/torchdistx/pull/52
    - should be in PyTorch distributed too
- keep a variable `c`, a running compensation of the rounding errors (effectively meant to stay within the safe range of the floating-point representation, and thus not suffer from accumulating rounding errors)
- Important for FP8 training.
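
Below is a minimal sketch of what such an optimizer step can look like, assuming only the standard PyTorch tensor API. It is illustrative, not the torchdistx implementation; the names `kahan_sgd_step` and `comp` are made up here. The bfloat16 parameters carry a per-element compensation buffer, which is exactly the variable `c` from the pseudo-code below.

```python
import torch

@torch.no_grad()
def kahan_sgd_step(param, grad, comp, lr=1e-3):
    """One plain-SGD step in bfloat16 with Kahan-compensated accumulation.

    param, comp: bfloat16 tensors of the same shape; comp holds the running
    compensation (the `c` of the pseudo-code), one element per parameter.
    """
    update = grad.to(param.dtype) * -lr
    y = update - comp            # correct the update by the stored error
    t = param + y                # low-order bits of y are lost here
    comp.copy_((t - param) - y)  # recover (the negative of) the lost bits
    param.copy_(t)

# usage: without `comp`, tiny bf16 updates eventually round away entirely
p = torch.zeros(1000, dtype=torch.bfloat16)
g = torch.full_like(p, 1e-3)
c = torch.zeros_like(p)
for _ in range(1000):
    kahan_sgd_step(p, g, c)
print(p.float().mean())  # roughly -1e-3; naive bf16 accumulation stalls early
```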
Worst-case relative error
- Computed as $|\hat{S}_n - S_n| / |S_n|$, where $S_n = \sum_{i=1}^{n} x_i$ is the exact sum (with infinite precision) and $\hat{S}_n$ is the result with compensated summation.
- All relative error bounds below are proportional to the condition number of the summation problem, $\kappa = \sum_{i=1}^{n} |x_i| \,\big/\, \left|\sum_{i=1}^{n} x_i\right|$:
    - it represents the intrinsic sensitivity of the summation problem to errors, regardless of how it is computed
    - for uncorrelated numbers, the sum is a random walk and the condition number grows as $\sqrt{n}$
- $\varepsilon$ is the machine precision of the arithmetic being employed:
    - Def: smallest positive number that, when added to 1.0, produces a result different from 1.0 using that arithmetic
    - for fp32 it is set by the number of mantissa bits, i.e. $\varepsilon = 2^{-23} \approx 1.2 \times 10^{-7}$
- The relative error bound for naive summation (simply adding the numbers in sequence, rounding at each step) grows as $O(\varepsilon n) \cdot \kappa$.
- Kahan summation's relative error is bounded by $\left(2\varepsilon + O(n\varepsilon^2)\right) \cdot \kappa$. With enough precision that $n\varepsilon^2$ is negligible, this is effectively $O(\varepsilon)$, independent of $n$ (a numeric check is sketched below).
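
A quick way to see both bounds in practice, sketched with NumPy (an assumption; the exact numbers depend on the seed): sum uncorrelated fp32 values naively and with Kahan compensation, then compare relative errors against an fp64 reference.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(200_000).astype(np.float32)

exact = x.astype(np.float64).sum()                       # stand-in for infinite precision
kappa = np.abs(x).astype(np.float64).sum() / abs(exact)  # condition number

naive = np.float32(0.0)
s = np.float32(0.0)  # Kahan accumulator
c = np.float32(0.0)  # running compensation
for v in x:
    naive += v           # naive: round after every add, error grows with n
    y = v - c
    t = s + y
    c = (t - s) - y
    s = t

print(f"condition number     : {kappa:.3e}")
print(f"naive relative error : {abs(float(naive) - exact) / abs(exact):.3e}")
print(f"Kahan relative error : {abs(float(s) - exact) / abs(exact):.3e}")
# Kahan typically stays near epsilon times kappa or better; naive is
# orders of magnitude worse at this n.
```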
Example
Suppose we are using six-digit decimal floating-point arithmetic, `sum` has attained the value 10000.0, and the next two values of `input[i]` are 3.14159 and 2.71828. The exact result is 10005.85987, which rounds to 10005.9. With plain summation, each incoming value would be aligned with `sum`, and many low-order digits would be lost (by truncation or rounding). The first result, after rounding, would be 10003.1. The second result would be 10005.81828 before rounding and 10005.8 after rounding. This is not correct. However, with compensated summation, we get the correctly rounded result of 10005.9.
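
The same example can be reproduced with Python's `decimal` module by setting the context to six significant digits (a small sketch; the loop mirrors the pseudo-code below):

```python
from decimal import Decimal, getcontext

getcontext().prec = 6  # six-digit decimal floating-point arithmetic

values = [Decimal("10000.0"), Decimal("3.14159"), Decimal("2.71828")]

plain = Decimal(0)
for v in values:
    plain += v  # 10000.0 -> 10003.1 -> 10005.8

s = Decimal(0)  # compensated sum
c = Decimal(0)  # running compensation
for v in values:
    y = v - c
    t = s + y
    c = (t - s) - y
    s = t

print(plain)  # 10005.8  (wrong in the last digit)
print(s)      # 10005.9  (correctly rounded)
```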
Pseudo-code
**function** KahanSum(input)
    // Prepare the accumulator.
    **var** sum = 0.0
    // A running compensation for lost low-order bits.
    **var** c = 0.0
    // The array _input_ has elements indexed input[1] to input[input.length].
    **for** i = 1 **to** input.length **do**
        // _c_ is zero the first time around.
        **var** y = input[i] - c
        // Alas, _sum_ is big, _y_ small, so low-order digits of _y_ are lost.
        **var** t = sum + y
        // _(t - sum)_ cancels the high-order part of _y_;
        // subtracting _y_ recovers negative (low part of _y_).
        c = (t - sum) - y
        // Algebraically, _c_ should always be zero. Beware
        // overly-aggressive optimizing compilers!
        sum = t
        // Next time around, the lost low part will be added to _y_ in a fresh attempt.
    **next** i
    **return** sum
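
For reference, a direct runnable Python translation of the pseudo-code (0-indexed). CPython never reassociates floating-point arithmetic, so the `(t - sum) - y` step is safe here; the "optimizing compilers" warning applies to languages like C/C++, where flags such as `-ffast-math` allow the compiler to simplify that expression to zero algebraically.

```python
def kahan_sum(values):
    total = 0.0  # the accumulator (the pseudo-code's `sum`)
    c = 0.0      # running compensation for lost low-order bits
    for v in values:
        y = v - c            # apply the stored correction
        t = total + y        # low-order digits of y are lost here
        c = (t - total) - y  # recover (the negative of) the lost part
        total = t
    return total

vals = [0.1] * 10**6  # 0.1 is inexact in binary, so errors accumulate

naive = 0.0
for v in vals:
    naive += v

print(naive)            # drifts to ~100000.000001333
print(kahan_sum(vals))  # ~100000.0, off by at most a few ulps
```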