• at inference, we can technically stick with static scaling

  • during training, need to call precompute_float8_dynamic_scale_for_fsdp for correct dynamic scaling (see the setup sketch and toy loop below)
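
For context, here is a minimal setup sketch that makes the toy loop below self-contained. It is only a sketch: it assumes torchao's float8 training API (convert_to_float8_training, Float8LinearConfig with enable_fsdp_float8_all_gather, and the precompute helper imported from torchao.float8) plus FSDP2's fully_shard and a torchrun launch; the toy model shape, N_ITER, optimizer, and input x are placeholders chosen for illustration.

import os

import torch
import torch.distributed as dist
# FSDP2 fully_shard (import path may differ across PyTorch versions)
from torch.distributed._composable.fsdp import fully_shard
from torchao.float8 import (
    Float8LinearConfig,
    convert_to_float8_training,
    precompute_float8_dynamic_scale_for_fsdp,
)

# assumes launch via torchrun, which sets RANK / LOCAL_RANK / WORLD_SIZE
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

N_ITER = 10

# toy model; dims are multiples of 16 so the fp8 gemms have valid shapes
m = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.Linear(4096, 4096, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

# swap nn.Linear for float8 linears, with fp8 all-gather enabled for FSDP2
convert_to_float8_training(m, config=Float8LinearConfig(enable_fsdp_float8_all_gather=True))

# FSDP2: convert to float8 first, then shard
fully_shard(m)

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)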

# toy training loop
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
 
    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
    # this method is optional but highly recommended for performance
    # it calculates scales for all float8 parameters in a single all-reduce
    precompute_float8_dynamic_scale_for_fsdp(m)
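
• note: the call sits after optimizer.step() so the scales for the just-updated weights are precomputed (in one all-reduce) before the next iteration's fp8 all-gather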