Summary

  1. Primitives to express a float8 matrix multiplication with per-tensor scaling:

    • the torch.float8_e4m3fn and torch.float8_e5m2 dtypes

    • the torch._scaled_mm op, which calls into cuBLAS (see the sketch after this list)

  2. float8_experimental, a lightweight library for accelerating training with float8 in native PyTorch, with support for torch.compile and distributed. Initial results show throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs. Peak memory usage improvements and large scale distributed support are coming soon.
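
A minimal sketch of how these primitives fit together, assuming a recent PyTorch build and a float8-capable GPU (e.g. H100). Note that torch._scaled_mm is a private op whose exact signature and return type have changed across releases, so treat the keyword form below as illustrative rather than canonical:

```python
import torch

device = "cuda"  # float8 matmuls require a float8-capable GPU (e.g. H100)
a_hp = torch.randn(16, 32, device=device)   # high-precision inputs
b_hp = torch.randn(64, 32, device=device)

# Per-tensor scaling: map each tensor's amax onto the float8 representable range.
f8_max = torch.finfo(torch.float8_e4m3fn).max
scale_a = f8_max / a_hp.abs().max()
scale_b = f8_max / b_hp.abs().max()

a_fp8 = (a_hp * scale_a).to(torch.float8_e4m3fn)
b_fp8 = (b_hp * scale_b).to(torch.float8_e4m3fn)

# cuBLAS expects the second operand in column-major layout, hence the .t().
# scale_a/scale_b here are the dequantization scales applied in the epilogue.
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),
    scale_a=1.0 / scale_a,
    scale_b=1.0 / scale_b,
    out_dtype=torch.bfloat16,
)
```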

Detailed

autograd

  • In the context of float8 training, for a tensor x we usually need x.dtype to be float8 but x.grad.dtype to be bfloat16. Autograd currently enforces x.dtype to equal x.grad.dtype for historical reasons. To get around this restriction we use Float8Tensor, which stores the raw data in float8 but advertises its dtype to autograd as bfloat16.
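
A heavily simplified sketch of the idea (not the actual Float8Tensor from float8_experimental): a tensor subclass built with torch.Tensor._make_wrapper_subclass can hold a float8 payload while advertising bfloat16 to autograd; the real class also defines __torch_dispatch__ so that ops route to float8 kernels.

```python
import torch

class ToyFloat8Tensor(torch.Tensor):
    # Hypothetical illustration: stores raw float8 bits plus a per-tensor scale,
    # but reports orig_dtype (e.g. bfloat16) so autograd's dtype check passes.
    @staticmethod
    def __new__(cls, data_fp8: torch.Tensor, scale: torch.Tensor, orig_dtype=torch.bfloat16):
        self = torch.Tensor._make_wrapper_subclass(
            cls, data_fp8.shape, dtype=orig_dtype, device=data_fp8.device
        )
        self._data = data_fp8   # raw float8 payload
        self._scale = scale     # per-tensor scale used to quantize
        return self

    def to_original_precision(self) -> torch.Tensor:
        # Dequantize back to the advertised dtype.
        return self._data.to(self.dtype) / self._scale
```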

Use model rewrites to implement per-tensor scaling

  • The current SOTA scaling strategy is delayed scaling; this requires stateful per-tensor statistics collection for a subset of weights, activations, and gradients. A model rewrite is necessary to implement this cleanly; lighter-weight approaches such as automatic mixed precision are not expressive enough. Even for stateless scaling strategies such as dynamic scaling, a model rewrite implementation allows them to be easily compared with stateful strategies.

  • The current model rewrite approach we are using in float8_experimental is module swaps. In the future, we may explore module hooks and graph capture + graph pass to cover more cases.
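
To make the module-swap rewrite concrete, here is a toy sketch (not the float8_experimental API): a linear layer that carries the per-tensor state delayed scaling needs, plus a helper that recursively swaps nn.Linear modules for it.

```python
import torch
import torch.nn as nn

class ToyFloat8Linear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias)
        # Delayed scaling is stateful: keep a history of observed amax values
        # and derive the scale from the history rather than the current tensor.
        self.register_buffer("amax_history", torch.zeros(16))

    def forward(self, x):
        with torch.no_grad():
            self.amax_history = torch.roll(self.amax_history, 1)
            self.amax_history[0] = x.abs().max().float()
        # A real implementation would cast x and self.weight to float8 using the
        # delayed scales and call a scaled matmul; here we fall back to the
        # high-precision path to keep the sketch short.
        return super().forward(x)

def swap_linear_with_toy_float8(module: nn.Module) -> nn.Module:
    """Module swap: recursively replace nn.Linear instances with ToyFloat8Linear."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_child = ToyFloat8Linear(
                child.in_features, child.out_features, child.bias is not None
            )
            new_child.load_state_dict(child.state_dict(), strict=False)
            setattr(module, name, new_child)
        else:
            swap_linear_with_toy_float8(child)
    return module
```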

Performance

torch.compile

  • Scaling and casting tensors to float8 introduces overhead; we accept this overhead in eager mode to keep the code simple, and depend on torch.compile + inductor to recover performance. For example, LLaMa 7B training with float8 dynamic scaling has a speedup of 0.81x over bf16 in eager mode, and 1.22x with torch.compile.
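
A small sketch of the pattern torch.compile is recovering here, assuming dynamic scaling (the function name and shapes are illustrative): in eager mode the amax reduction, the scale computation, and the float8 cast each launch their own kernels; under torch.compile, inductor can fuse them.

```python
import torch

def scale_and_cast(x: torch.Tensor):
    # Dynamic scaling: derive the scale from the current tensor's amax,
    # then cast to float8. In eager mode this launches several kernels.
    amax = torch.clamp(x.abs().max(), min=1e-12)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    return (x * scale).to(torch.float8_e4m3fn), scale

# Under torch.compile, inductor fuses the amax, scaling, and cast.
scale_and_cast_compiled = torch.compile(scale_and_cast)
```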

inductor

  • After we get a graph from torch.compile, we use inductor to generate kernels in which the amax and scaled-cast operations are fused into surrounding ops. We added inductor support for the float8 dtypes and optimized code generation so that the amax calculation, scaling, and float8 casts needed by Float8Linear are performant.
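
Assuming the scale_and_cast sketch above is in scope, one way to inspect the fusion is to dump the code inductor generates; TORCH_LOGS="output_code" (or the equivalent torch._logging call) is standard torch.compile tooling rather than anything float8-specific:

```python
import torch
import torch._logging

# Print the Triton/C++ code inductor generates; the amax reduction, scaling,
# and float8 cast are emitted fused with surrounding ops rather than as the
# separate kernels the eager version launches.
torch._logging.set_logs(output_code=True)
torch.compile(scale_and_cast)(torch.randn(4096, 4096, device="cuda"))
```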

Support for matmuls