Summary
 Computing the gradients of a given layer $b_{i}$ requires both the upper layer gradients $b_{i+1}$ and the cached activations $f_{i}(x)$.
 This uses a lot of VRAM
 You can decide to drop them and compute them twice (once in forward, again in backward)
 Leads to a 20-25% decrease in throughput for a fixed batch size
 Not 50%, because a forward pass is much cheaper than a backward pass, so the extra forward adds only ~25% to total compute
 However, it frees up lots of memory ⇒ can double or quadruple the batch size
 Effective throughput may then be (100/125) × 2 = 160% of the baseline
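The arithmetic above can be checked quickly (assuming recomputation adds ~25% compute and the freed memory allows a 2× larger batch):

```python
# Full recomputation adds roughly one extra forward pass (~25% more compute),
# so throughput at a fixed batch size drops to 100/125 = 0.8x the baseline.
slowdown = 100 / 125

# The freed activation memory lets us double the batch size.
batch_scale = 2

effective_throughput = slowdown * batch_scale
print(f"{effective_throughput:.0%}")  # 160%
```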
How it works at the low level
 Reducing Activation Recomputation in Large Transformer Models
 Let $b,s,h$ be the batch size, sequence length, and hidden size
 The input size to a Transformer block is $bsh$
Activation Memory per Transformer Layer

Each transformer layer consists of an attention and an MLP block connected with two layernorms. Below, we derive the memory required to store activations for each of these elements.

We assume the activations are stored in 16-bit floating-point format, i.e. $2$ bytes per element, which is why we often multiply by 2 (compared to by 1 for a boolean mask)

We look at what needs to be stored to compute the backward pass
 Remember, as explained in Backpropagation, if $\hat{y}=h_{1}W_{2}$, where $h_{1}$ is the input and $W_{2}$ a matrix or linear layer, then $\frac{dL}{dW_{2}}=h_{1}^{\top}\frac{dL}{d\hat{y}}$
 ⇒ we only need to store the input to the linear layer, $h_{1}$, to compute the backward pass, since $\frac{dL}{d\hat{y}}$ will be provided at the moment of computation
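A minimal numpy sketch of this fact (shapes and names are illustrative): for $\hat{y}=h_{1}W_{2}$ and a scalar loss, the weight gradient is $h_{1}^{\top}\,dL/d\hat{y}$, so only the input $h_{1}$ needs to be cached.

```python
import numpy as np

rng = np.random.default_rng(0)
h1 = rng.normal(size=(4, 3))   # cached input to the linear layer
W2 = rng.normal(size=(3, 2))   # weight matrix

y = h1 @ W2                    # forward: y_hat = h1 W2
dL_dy = np.ones_like(y)        # upstream gradient, provided during backward

# The weight gradient needs only the cached input h1 and the upstream gradient.
dL_dW2 = h1.T @ dL_dy

# Check against a finite-difference estimate of dL/dW2 for L = sum(y_hat).
eps = 1e-6
num = np.zeros_like(W2)
for i in range(W2.shape[0]):
    for j in range(W2.shape[1]):
        Wp = W2.copy()
        Wp[i, j] += eps
        num[i, j] = ((h1 @ Wp).sum() - y.sum()) / eps
assert np.allclose(dL_dW2, num, atol=1e-4)
```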
Attention block:
 QKV matrix multiply, i.e. `q, k, v = linear_proj(x).split(h, dim=-1)` (the projection output has width $3h$): we only need to store their shared input, with size $2sbh$
 $QK^{T}$ matrix multiply: it requires storage of both Q and K with total size $4sbh$.
 Softmax: the softmax output, with size $2as^{2}b$, is required for backpropagation.
 Softmax dropout: only a mask with size $as^{2}b$ is needed.
 Attention over Values (V): we need to store the dropout output ($2as^{2}b$) and the Values ($2sbh$), and therefore need $2as^{2}b+2sbh$ of storage.
Summing the above values, in total the attention block requires $11sbh+5as^{2}b$ bytes of storage.
MLP
 The two linear layers store their inputs with size $2sbh$ and $8sbh$. The GeLU nonlinearity also needs its input with size $8sbh$ for backpropagation. Finally, dropout stores its mask with size $sbh$. In total, the MLP block requires $19sbh$ bytes of storage.
Layer norm
Each layer norm stores its input with size $2sbh$ and therefore in total, we will need $4sbh$ of storage.
Total
 $\text{Activation memory for a transformer layer}=sbh\left(34+\frac{5as}{h}\right)$
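The per-block terms can be checked numerically; the config below is an illustrative GPT-like setting, not from the source:

```python
# Illustrative config: batch size, sequence length, hidden size, attention heads.
b, s, h, a = 4, 2048, 4096, 32

# Per-block activation bytes, as derived above.
attention = 11 * s * b * h + 5 * a * s**2 * b   # QKV, QK^T, softmax, dropouts, AV
mlp = 19 * s * b * h                            # two linears, GeLU, dropout mask
layernorms = 4 * s * b * h                      # two layernorms, 2sbh each

total = attention + mlp + layernorms
closed_form = s * b * h * (34 + 5 * a * s / h)
assert total == closed_form
print(f"{total / 2**30:.1f} GiB per layer")  # 3.6 GiB per layer
```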
Full Activation Recomputation
 Now if you only store the transformer-layer input and fully recompute the activations, you can get away with storing only $2sbh$ bytes
Tensor and Sequence Parallelism as paradigm to reduce activation storage per GPU
Tensor parallelism
 If we parallelize the attention and the MLP blocks
 Not only does tensor parallelism parallelize model parameters and optimizer states inside the attention and MLP blocks, but it also parallelizes the activations inside those blocks.
 Note that the input activations to these blocks (for example, the input to the Q, K, and V matrix multiplies or the input to the h → 4h linear layer) are not parallelized, and only activations within each block are divided across the tensor-parallel group
 Assuming $t$-way tensor parallelism, the per-layer memory required to store the activations reduces to (a single GPU now only handles the $1/t$ slice of the attention and MLP blocks): $\text{Activation memory per layer}=sbh\left(10+\frac{24}{t}+\frac{5as}{ht}\right)$
Sequence Parallelism

Tensor parallelism parallelizes the parts of the transformer layer that take the most time during training (attention and MLP), and as a result it is computationally efficient.

However, it leaves the layernorms as well as the dropouts after attention and MLP blocks intact and as a result, they are replicated across the tensor parallel group.

These elements do not require a lot of compute but demand a considerable amount of activation memory
 Quantitatively, the $10sbh$ term of the above equation is due to these replicated operations, and as a result it is not divided by the tensor-parallel size $t$.

The operations are independent along the sequence dimension.
 This characteristic allows us to partition these regions along the sequence dimension s.
 Partitioning along the sequence dimension reduces the memory required for the activations.
 This extra level of parallelism introduces new communication collectives, which act as converters between the sequence-parallel and tensor-parallel regions
 With carefully written communication primitives, the communication bandwidth used for tensor parallelism alone and for tensor together with sequence parallelism is the same
 Now every operation is well distributed among GPUs, and assuming the sequence-parallel degree equals the tensor-parallel degree (which is likely): $\text{Activation memory per layer}=\frac{sbh}{t}\left(34+\frac{5as}{h}\right)$
Selective recomputation
 The most problematic parts (in terms of memory) are:
 Softmax: $2as^{2}b$
 Softmax dropout: $as^{2}b$
 Attention over Values (V): $2as^{2}b+2sbh$
 If we recompute only this part and use tensor + sequence parallelism, we now have $\text{Activation memory for a transformer layer}=\frac{34sbh}{t}$
 Using selective activation recomputation allows the required activation memory to scale linearly with sequence length and be independent of the number of attention heads
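The per-layer formulas above can be compared under one illustrative config ($t$-way tensor parallelism; the numbers are ad hoc, not from the source):

```python
# Illustrative config: batch, sequence length, hidden size, heads, TP degree.
b, s, h, a, t = 4, 2048, 4096, 32, 8
sbh = s * b * h

none_ = sbh * (34 + 5 * a * s / h)                  # no parallelism
tp = sbh * (10 + 24 / t + 5 * a * s / (h * t))      # tensor parallelism
tp_sp = sbh / t * (34 + 5 * a * s / h)              # tensor + sequence parallelism
selective = 34 * sbh / t                            # + selective recomputation
full = 2 * sbh                                      # full recomputation

for name, nbytes in [("none", none_), ("TP", tp), ("TP+SP", tp_sp),
                     ("TP+SP+selective", selective), ("full recompute", full)]:
    print(f"{name:>16}: {nbytes / 2**30:.2f} GiB")
```

Note that with selective recomputation the $5as^{2}b$ term disappears entirely, so the memory no longer grows quadratically in $s$.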
How to do it in code
In practice
 Just use
from torch.utils.checkpoint import checkpoint
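A minimal usage sketch (the module and shapes are illustrative; `use_reentrant=False` selects the non-reentrant variant recommended in recent PyTorch versions):

```python
import torch
from torch.utils.checkpoint import checkpoint

# An illustrative block whose internal activations we don't want to store.
block = torch.nn.Sequential(
    torch.nn.Linear(16, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 16),
)

x = torch.randn(8, 16, requires_grad=True)

# Activations inside `block` are dropped after the forward pass and
# recomputed on the fly during backward.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([8, 16])
```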
Thorough
 Define a class `CheckpointFunction(torch.autograd.Function)` that will need to define a `forward()` and a `backward()`
 The `forward()` will save the `run_function` and `input_tensors` in `ctx`
 Do the forward with no gradients, i.e.
with torch.no_grad():
    output_tensors = run_function(*input_tensors)
 At backward time
 add back the `input_tensors` to the computation graph: `ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]`
 Run the forward with grad enabled:
with torch.enable_grad():
    output_tensors = ctx.run_function(*ctx.input_tensors)
 Finally compute the gradients:
input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors, output_grads)
 and return `input_grads`
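The steps above can be collected into a single sketch (simplified: no RNG-state or keyword-argument handling, which production implementations also take care of):

```python
import torch

class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, *input_tensors):
        # Save the function and its inputs; intermediate activations are NOT kept.
        ctx.run_function = run_function
        ctx.input_tensors = list(input_tensors)
        with torch.no_grad():
            output_tensors = run_function(*input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Re-attach the inputs to the computation graph, then redo the forward
        # with gradients enabled to rebuild the local graph.
        ctx.input_tensors = [x.detach().requires_grad_(True)
                             for x in ctx.input_tensors]
        with torch.enable_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        input_grads = torch.autograd.grad(output_tensors, ctx.input_tensors,
                                          output_grads)
        # The leading None matches the non-tensor `run_function` argument.
        return (None,) + input_grads

# Usage sketch with an illustrative layer:
layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = CheckpointFunction.apply(lambda t: layer(t), x)
out.sum().backward()
```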
 `torch.autograd.grad(outputs, inputs, grad_outputs=None, ...)`
 Computes and returns the sum of gradients of outputs with respect to the inputs.
 outputs (sequence of Tensor) – outputs of the differentiated function.
 inputs (sequence of Tensor or GradientEdge) – Inputs w.r.t. which the gradient will be returned (and not accumulated into `.grad`).
 grad_outputs (sequence of Tensor) – The "vector" in the vector-Jacobian product. Usually gradients w.r.t. each output. None values can be specified for scalar Tensors or ones that don't require grad.