 float8_experimental (https://github.com/pytorch-labs/float8_experimental), compatible with torch.compile
 torchtitan support for it (https://github.com/pytorch/torchtitan)
Summary

Primitives to express a float8 matrix multiplication with per-tensor scaling

the torch.float8_e4m3fn and torch.float8_e5m2 dtypes

the torch._scaled_mm op
 which calls into cuBLAS

float8_experimental , a lightweight library for accelerating training with float8 in native PyTorch with support for torch.compile and distributed. Initial results show throughput speedups of up to 1.2x on small scale (8 GPUs) LLaMa pretraining jobs. Peak memory usage improvements and large scale distributed support are coming soon.
Detailed
autograd
 In the context of float8 training, for a tensor x we usually need x.dtype to be float8 but x.grad.dtype to be bfloat16. Autograd currently enforces that x.dtype equals x.grad.dtype, for historical reasons. To get around this restriction we use Float8Tensor, which stores the raw data in float8 but advertises its dtype to autograd as bfloat16.
Use model rewrites to implement per-tensor scaling

The current SOTA scaling strategy is delayed scaling; this requires stateful per-tensor statistics collection for a subset of weights, activations, and gradients. A model rewrite is necessary to implement this cleanly; lighter-weight approaches such as automated mixed precision are not expressive enough. Even for stateless scaling strategies such as dynamic scaling, a model rewrite implementation allows them to be easily compared with stateful strategies.

The current model rewrite approach we are using in float8_experimental is module swaps. In the future, we may explore module hooks and graph capture + graph pass to cover more cases.
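
The module-swap pattern is straightforward to sketch: walk the module tree and replace each nn.Linear with a float8 variant. This is a generic illustration of the pattern; the `from_float` constructor on the passed-in class is a hypothetical convention, not necessarily the library's interface:

```python
import torch.nn as nn

def swap_linear_with_float8(module: nn.Module, float8_cls) -> nn.Module:
    # Recursively replace nn.Linear children with the float8 variant,
    # constructed from the existing module's weights.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, float8_cls.from_float(child))
        else:
            swap_linear_with_float8(child, float8_cls)
    return module
```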
Performance
torch.compile
 Scaling and casting tensors to float8 introduces overhead; we accept this overhead in eager mode to keep the codebase simple and
 depend on torch.compile + inductor to recover performance. For example, LLaMa 7B training with float8 dynamic scaling has a speedup of 0.81x over bf16 in eager mode, and 1.22x with torch.compile.
inductor
 After we get a graph from torch.compile, we use inductor to generate kernels with the amax calculation and scaled cast fused into surrounding ops. We added inductor support for the float8 dtypes, and optimized code generation to be performant for the amax calculation, scaling, and float8 cast needed by Float8Linear.