
Code snippet for a bfloat16 optimizer with Kahan summation:
 https://github.com/pytorch/torchdistx/pull/52
 should be in PyTorch distributed too

keep a variable $c$, a running compensation for the rounding errors (it holds the low-order bits lost at each addition, so they stay within the safe range of the floating-point representation instead of accumulating as error)

Important for FP8 training.
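As a sketch of why this matters for low-precision optimizer states: below is a minimal Kahan-compensated parameter update in NumPy, using float16 as a stand-in for bfloat16 (all names, `kahan_update` included, are illustrative — this is not the torchdistx implementation):

```python
import numpy as np

# Minimal Kahan-compensated update of a low-precision parameter.
# float16 stands in for bfloat16; all names here are illustrative.
def kahan_update(p, c, step):
    y = step - c                 # re-inject the compensation saved so far
    t = np.float16(p + y)        # apply the update; low-order bits of y may be lost
    c = np.float16((t - p) - y)  # capture (the negation of) the lost bits
    return t, c

u = np.float16(1e-4)  # a tiny update, below half an ULP of 1.0 in float16

naive = np.float16(1.0)
p, c = np.float16(1.0), np.float16(0.0)
for _ in range(100):
    naive = np.float16(naive + u)  # the update is rounded away every time
    p, c = kahan_update(p, c, u)

print(float(naive))  # 1.0 -- all 100 updates were lost
print(float(p))      # ~1.01, close to the true accumulated value
```

Note the extra buffer `c`: the compensation costs one additional low-precision copy of the parameters.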
Worst-case relative error
 Computed as $\frac{|E_n|}{|S_n|}$, where $S_n = \sum x_i$ (with infinite precision) and $S_n + E_n$ is the result with compensated summation.
 All summation errors are relative to the condition number $C = \frac{\sum |x_i|}{|\sum x_i|}$
 $C$ represents the intrinsic sensitivity of the summation problem to errors, regardless of how it is computed
 For uncorrelated numbers, the sum is a random walk and the condition number grows as $O(\sqrt{n})$
 $ϵ$ is the machine precision of the arithmetic being employed
 Def: smallest positive number that, when added to 1.0, produces a result different from 1.0 using the arithmetic
 For fp32, the mantissa has 23 bits, so $\epsilon = 2^{-23}$
 The relative error bound for naive summation (simply adding the numbers in sequence, rounding at each step) grows as $O(\epsilon n) \cdot C$
 Kahan summation relative error is bounded by $O(\epsilon + \epsilon^{2} n) \cdot C$; the $n$-dependent term is second order in $\epsilon$, so with enough precision the error is effectively independent of $n$.
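The gap between the two bounds is easy to see numerically; a quick sketch comparing naive and compensated summation in float32 (`kahan_sum` is the standard algorithm transcribed directly; 0.1 is not exactly representable in binary, so rounding error accumulates):

```python
import numpy as np

def kahan_sum(xs):
    # Compensated summation, transcribed directly, in float32.
    s = np.float32(0.0)
    c = np.float32(0.0)
    for x in xs:
        y = x - c
        t = s + y
        c = (t - s) - y
        s = t
    return s

n = 100_000
x = np.float32(0.1)  # not exactly representable in binary floating point

naive = np.float32(0.0)
for _ in range(n):
    naive = naive + x

kahan = kahan_sum([x] * n)

# True sum is 10000 (up to the tiny representation error of 0.1).
print(abs(float(naive) - 10000.0))  # grows with n: the O(eps*n)*C regime
print(abs(float(kahan) - 10000.0))  # stays near one ULP: effectively independent of n
```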
Example

Suppose we are using six-digit decimal floating-point arithmetic, _sum_ has attained the value 10000.0, and the next two values of _input[i]_ are 3.14159 and 2.71828. The exact result is 10005.85987, which rounds to 10005.9. With plain summation, each incoming value would be aligned with _sum_, and many low-order digits would be lost (by truncation or rounding). The first result, after rounding, would be 10003.1. The second result would be 10005.81828 before rounding and 10005.8 after rounding, so the last digit is wrong.
However, with compensated summation, we get the correctly rounded result of 10005.9.
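This worked example can be reproduced with Python's `decimal` module by setting the context precision to six significant digits:

```python
from decimal import Decimal, getcontext

getcontext().prec = 6  # six-digit decimal floating point, as in the example

values = [Decimal("10000.0"), Decimal("3.14159"), Decimal("2.71828")]

# Plain summation: rounds to six significant digits after every addition.
naive = Decimal("0")
for v in values:
    naive = naive + v

# Compensated (Kahan) summation.
s, c = Decimal("0"), Decimal("0")
for v in values:
    y = v - c
    t = s + y
    c = (t - s) - y
    s = t

print(naive)  # 10005.8 -- the last digit is wrong
print(s)      # 10005.9 -- the correctly rounded value of 10005.85987
```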
Pseudocode
**function** KahanSum(input)
    // Prepare the accumulator.
    **var** sum = 0.0
    // A running compensation for lost low-order bits.
    **var** c = 0.0
    // The array _input_ has elements indexed input[1] to input[input.length].
    **for** i = 1 **to** input.length **do**
        // _c_ is zero the first time around.
        **var** y = input[i] - c
        // Alas, _sum_ is big, _y_ small, so low-order digits of _y_ are lost.
        **var** t = sum + y
        // _(t - sum)_ cancels the high-order part of _y_;
        // subtracting _y_ recovers negative (low part of _y_).
        c = (t - sum) - y
        // Algebraically, _c_ should always be zero. Beware
        // overly-aggressive optimizing compilers!
        sum = t
        // Next time around, the lost low part will be added to _y_ in a fresh attempt.
    **next** i
    **return** sum