Summary of Attention

  • In a typical transformer block, there's (1) self-attention (communication) and (2) a feedforward network (computation), with layer norms and residual connections sprinkled in for stability (see the block sketch after this list).
  • GPT is decoder-only: it is auto-regressive and only conditions on past tokens when generating.
  • The original Vaswani Transformer was encoder-decoder because it did machine translation, so decoding is conditioned on the source input to translate, and that conditioning is fed in through a (non-causal) cross-attention block inside each decoder block. (see Vaswani picture)
  • In production, you compute q, k, and v yourself and then pass them to a fused FlashAttention-style kernel: y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
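A minimal flat sketch of one such block, assuming illustrative sizes (B, T, n_embd, n_head) and pre-norm placement; the names here are made up for the example, not taken from any specific codebase:

	import torch
	import torch.nn as nn
	import torch.nn.functional as F
	from einops import rearrange

	B, T, n_embd, n_head = 4, 8, 64, 4       # illustrative sizes
	x = torch.randn(B, T, n_embd)

	ln1, ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)
	qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)   # fused q, k, v projection
	proj = nn.Linear(n_embd, n_embd, bias=False)      # output projection after the heads
	mlp = nn.Sequential(nn.Linear(n_embd, 4 * n_embd), nn.GELU(), nn.Linear(4 * n_embd, n_embd))

	# 1. self-attention (communication), pre-norm + residual
	q, k, v = qkv(ln1(x)).chunk(3, dim=-1)
	q, k, v = (rearrange(t, "b t (h d) -> b h t d", h=n_head) for t in (q, k, v))
	y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
	x = x + proj(rearrange(y, "b h t d -> b t (h d)"))

	# 2. feedforward (computation), pre-norm + residual
	x = x + mlp(ln2(x))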

Self-attention Implementation

  • Every token in a sequence will emit three vectors: a query, a key, and a value vector
    • query vector = “what the token is looking for” e.g. vowel looking for consonants
    • key vector = “what do I contain”
    • value vector = “what do I communicate when aggregated”
	import torch
	import torch.nn as nn
	import torch.nn.functional as F
	from einops import rearrange

	B, T, C = 4, 8, 32          # batch, time (sequence length), channels
	x = torch.randn(B, T, C)    # usually x = token embeddings + positional embeddings
	head_size = 16
	Q = nn.Linear(C, head_size, bias=False)
	K = nn.Linear(C, head_size, bias=False)
	V = nn.Linear(C, head_size, bias=False)

	query = Q(x)                                   # (B, T, head_size)
	key = K(x)                                     # (B, T, head_size)
	W = query @ rearrange(key, "b t h -> b h t")   # batched matrix multiply -> (B, T, T)
	W = W / head_size ** 0.5                       # scale *before* the softmax
	mask = torch.tril(torch.ones(T, T))            # lower-triangular causal mask
	W = W.masked_fill(mask == 0, float("-inf"))    # hide future tokens
	W = F.softmax(W, dim=-1)                       # (B, T, T), each row sums to 1

	out = W @ V(x)   # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
  • Self-attention just means that the keys and values come from the same source as the queries, i.e. they all take the same x as input. For cross-attention, the keys and values may come from an external source, e.g. some context such as encoder outputs (see the cross-attention sketch below).
  • An encoder block deletes the masking part, i.e. it is not causal. A decoder block keeps the masking, making it causal, so it can be used in an auto-regressive manner.
  • The scaling by 1/sqrt(head_size) keeps W (before the softmax) roughly unit-Gaussian at init time, which makes sure the softmax doesn't output one-hot vectors. This is because softmax approximates the maximum, and thus tends toward peaked distributions when values are large (because of the exponentiation).
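A quick numerical sanity check of that claim (illustrative sizes only): dot products of unit-Gaussian queries and keys have variance around head_size, and softmaxing such large logits gives strongly peaked rows until the 1/sqrt(head_size) scaling is applied:

	import torch
	import torch.nn.functional as F

	torch.manual_seed(0)
	head_size = 64
	q = torch.randn(1000, head_size)                  # unit-Gaussian queries
	k = torch.randn(1000, head_size)                  # unit-Gaussian keys
	logits = (q * k).sum(dim=-1)                      # raw dot products
	print(logits.var())                               # ~head_size, not ~1
	print((logits / head_size ** 0.5).var())          # ~1 after scaling

	row = logits[:8]                                  # pretend this is one row of W
	print(F.softmax(row, dim=-1))                     # peaked: one or two entries dominate
	print(F.softmax(row / head_size ** 0.5, dim=-1))  # much more diffuse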

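For the self- vs cross-attention bullet above, a minimal cross-attention sketch, assuming a hypothetical external context tensor ctx (e.g. encoder outputs); the only change from self-attention is where the keys and values come from:

	import torch
	import torch.nn as nn
	from einops import rearrange

	B, T_dec, T_enc, C, head_size = 4, 8, 12, 32, 16
	x = torch.randn(B, T_dec, C)       # decoder-side tokens: queries come from here
	ctx = torch.randn(B, T_enc, C)     # external context, e.g. encoder outputs: keys/values come from here

	Q = nn.Linear(C, head_size, bias=False)
	K = nn.Linear(C, head_size, bias=False)
	V = nn.Linear(C, head_size, bias=False)

	W = Q(x) @ rearrange(K(ctx), "b t h -> b h t") / head_size ** 0.5   # (B, T_dec, T_enc)
	W = torch.softmax(W, dim=-1)       # no causal mask: looking at the whole context is allowed
	out = W @ V(ctx)                   # (B, T_dec, head_size)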
Many implementations are written with KV caching in mind: the keys and values of already-processed tokens are stored so they don't have to be recomputed at every decoding step.
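A shapes-only sketch of that idea, with random tensors standing in for the per-token projections: each decoding step computes the newest token's key/value once, appends them to the cache, and lets the new query attend over everything cached so far:

	import torch
	import torch.nn.functional as F

	B, head_size = 1, 16
	k_cache = torch.empty(B, 0, head_size)     # grows by one row per generated token
	v_cache = torch.empty(B, 0, head_size)

	for step in range(5):                      # pretend we decode 5 tokens
	    q_new = torch.randn(B, 1, head_size)   # projections of the newest token only
	    k_new = torch.randn(B, 1, head_size)
	    v_new = torch.randn(B, 1, head_size)
	    k_cache = torch.cat([k_cache, k_new], dim=1)   # (B, step + 1, head_size)
	    v_cache = torch.cat([v_cache, v_new], dim=1)
	    # the new token attends over everything cached so far; no causal mask is
	    # needed because the cache only ever contains the past
	    out = F.scaled_dot_product_attention(q_new, k_cache, v_cache)  # (B, 1, head_size)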

Multi-head attention

  • Given the original input X, MHA computes the attention in parallel over all the heads.
  • It is just a bunch of attention heads run in parallel whose outputs are concatenated at the end (and projected).
  • MHA is Q_heads=K_heads=V_heads=N
  • i.e. MultiHead(Q, K, V) = Concat(head_1, …, head_N) W^O where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V), and there are a bunch of projection matrices (W_i^Q, W_i^K, W_i^V, W^O).
  • In this case, Q, K, V may very well all just be X: for self-attention the per-head projection matrices do the splitting (see the sketch below).
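A sketch of MHA with the N per-head projections stacked into single linear layers (sizes are illustrative):

	import torch
	import torch.nn as nn
	import torch.nn.functional as F
	from einops import rearrange

	B, T, C, n_head = 4, 8, 64, 8          # head_size = C // n_head = 8
	x = torch.randn(B, T, C)

	Wq = nn.Linear(C, C, bias=False)       # the N per-head query projections, stacked
	Wk = nn.Linear(C, C, bias=False)
	Wv = nn.Linear(C, C, bias=False)
	Wo = nn.Linear(C, C, bias=False)       # output projection applied after concatenation

	# Q_heads = K_heads = V_heads = n_head
	q = rearrange(Wq(x), "b t (h d) -> b h t d", h=n_head)
	k = rearrange(Wk(x), "b t (h d) -> b h t d", h=n_head)
	v = rearrange(Wv(x), "b t (h d) -> b h t d", h=n_head)

	y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # all heads computed in parallel
	y = Wo(rearrange(y, "b h t d -> b t (h d)"))                 # concat the heads, then project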

Multi-Query Attention

  • Multi-Query attention is a change to the model architecture that shrinks the size of the KV cache by assigning multiple heads to Q, and only a single head to K and V.
  • Instead of Q, K, and V all being split into separate heads, only Q is split. K and V are smaller, the size of a single head, and that single K and V is shared among all the Q heads.
  • MQA is Q_heads=N; K_heads=V_heads=1, i.e. head_i = Attention(X W_i^Q, X W^K, X W^V). Note the lack of subscript i on W^K and W^V: a single key/value projection is shared across all the query heads.
  • The KV cache is consequently much smaller than with vanilla attention (one K/V head instead of N).
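A sketch of MQA with the same illustrative sizes: N query heads, but one shared key/value head that is broadcast across the query heads during the matmuls, so only that single head needs to live in the KV cache:

	import torch
	import torch.nn as nn
	from einops import rearrange

	B, T, C, n_head = 4, 8, 64, 8
	head_size = C // n_head
	x = torch.randn(B, T, C)

	Wq = nn.Linear(C, C, bias=False)            # N query heads, as in MHA
	Wk = nn.Linear(C, head_size, bias=False)    # a single shared key head
	Wv = nn.Linear(C, head_size, bias=False)    # a single shared value head

	q = rearrange(Wq(x), "b t (h d) -> b h t d", h=n_head)   # (B, N, T, d)
	k = Wk(x).unsqueeze(1)                                   # (B, 1, T, d): this is all the KV cache stores
	v = Wv(x).unsqueeze(1)                                   # (B, 1, T, d)

	att = q @ k.transpose(-2, -1) / head_size ** 0.5         # (B, N, T, T), K broadcast over query heads
	att = att.masked_fill(torch.tril(torch.ones(T, T)) == 0, float("-inf"))
	att = att.softmax(dim=-1)
	y = att @ v                                              # V broadcasts too -> (B, N, T, d)
	y = rearrange(y, "b h t d -> b t (h d)")                 # (B, T, C)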

Grouped-Query Attention

  • hybrid between MQA and MHA

  • GQA is Q_heads=N; K_heads=V_heads=G where 1 < G < N. The GQA paper claims a smaller hit to perplexity and better training stability than MQA (see the sketch after this list).

  • One can convert a transformer checkpoint trained with MHA by

    1. converting the multi-head checkpoint into a multi-query (or grouped-query) checkpoint
      1. the projection matrices for the key and value heads are mean-pooled into a single projection matrix (one per group for GQA)
    2. additional pre-training ("uptraining") to allow the model to adapt to its new structure
  • The GQA paper notes that GQA is not applied to the encoder self-attention layers; encoder representations are computed in parallel, so memory bandwidth is generally not the primary bottleneck there.
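A sketch of GQA with illustrative sizes (N = 8 query heads, G = 2 key/value groups), followed by the mean-pooling step from the conversion recipe above; the tensor names and the random stand-in checkpoint weights are assumptions for the example, not the paper's code:

	import torch
	import torch.nn as nn
	from einops import rearrange

	B, T, C, n_head, n_group = 4, 8, 64, 8, 2     # N = 8 query heads, G = 2 KV groups
	head_size = C // n_head
	x = torch.randn(B, T, C)

	Wq = nn.Linear(C, n_head * head_size, bias=False)    # N query heads
	Wk = nn.Linear(C, n_group * head_size, bias=False)   # G key heads
	Wv = nn.Linear(C, n_group * head_size, bias=False)   # G value heads

	q = rearrange(Wq(x), "b t (h d) -> b h t d", h=n_head)    # (B, N, T, d)
	k = rearrange(Wk(x), "b t (g d) -> b g t d", g=n_group)   # (B, G, T, d): only this much is cached
	v = rearrange(Wv(x), "b t (g d) -> b g t d", g=n_group)

	# each group of N // G consecutive query heads shares one K/V head
	k = k.repeat_interleave(n_head // n_group, dim=1)          # (B, N, T, d)
	v = v.repeat_interleave(n_head // n_group, dim=1)

	att = q @ k.transpose(-2, -1) / head_size ** 0.5           # (B, N, T, T)
	att = att.masked_fill(torch.tril(torch.ones(T, T)) == 0, float("-inf"))
	y = att.softmax(dim=-1) @ v                                # (B, N, T, d)
	y = rearrange(y, "b h t d -> b t (h d)")                   # (B, T, C)

	# MHA -> GQA conversion: mean-pool each group of per-head key (and value) projections
	W_k_mha = torch.randn(n_head, head_size, C)                # stand-in per-head K projections from an MHA checkpoint
	W_k_gqa = rearrange(W_k_mha, "(g r) d c -> g r d c", g=n_group).mean(dim=1)   # (G, d, C), then uptrain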