- at inference, we can technically stick with static scaling (see the sketch after this list)
- need to call precompute_float8_dynamic_scale_for_fsdp for correct dynamic scaling
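To make the static-vs-dynamic distinction concrete, here is a minimal sketch (illustrative only, not the library's implementation): dynamic scaling recomputes the scale from each tensor's current abs-max on every cast, while static scaling reuses a scale fixed ahead of time, e.g. from a calibration pass. E4M3_MAX, dynamic_scale, and calibration_amax are made-up names for this sketch; 448.0 is the largest finite value of torch.float8_e4m3fn.

```python
import torch

E4M3_MAX = 448.0  # largest finite value representable in torch.float8_e4m3fn

def dynamic_scale(x: torch.Tensor) -> torch.Tensor:
    # dynamic scaling: recompute the scale from the live tensor's abs-max
    # every time it is cast to float8
    return E4M3_MAX / torch.clamp(x.abs().max(), min=1e-12)

# static scaling: freeze a scale computed ahead of time (e.g. from a
# calibration pass over representative data) and reuse it at inference
calibration_amax = torch.tensor(3.5)  # hypothetical calibration statistic
static_scale = E4M3_MAX / calibration_amax

x = torch.randn(16, 32)
x_fp8 = (x * dynamic_scale(x)).to(torch.float8_e4m3fn)  # cast with a per-tensor dynamic scale
```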
# toy training loop
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()

    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
    # this method is optional but is highly recommended for performance
    # it calculates scales for all float8 parameters in a single all-reduce
    precompute_float8_dynamic_scale_for_fsdp(m)
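For context, a minimal sketch of the kind of setup this loop assumes, with the fp8 all-gather turned on before wrapping with FSDP2. The calls to convert_to_float8_training, Float8LinearConfig(enable_fsdp_float8_all_gather=True), fully_shard, and precompute_float8_dynamic_scale_for_fsdp reflect the torchao / PyTorch APIs as I understand them and are not taken from this snippet; the model shape, learning rate, and N_ITER are illustrative, and a distributed process group is assumed to be initialized (e.g. launched via torchrun).

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# assumes torch.distributed is already initialized (e.g. via torchrun)
m = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()

# swap nn.Linear for float8 linears and enable the fp8 all-gather
convert_to_float8_training(
    m, config=Float8LinearConfig(enable_fsdp_float8_all_gather=True)
)

# FSDP2: shard each layer, then the root module
for layer in m:
    fully_shard(layer)
fully_shard(m)

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
N_ITER = 10
```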