GPT / Transformers

Positional embeddings

  • Sinusoidal embeddings are not the norm anymore.
  • GPT-2 simply learns a positional embedding table of shape (context_length, emb_dim) that is added to the word embeddings (see the sketch after this list).
  • There are also relative positional embeddings (Shaw et al., 2018): learned matrices $a^K_{ij}$ and $a^V_{ij}$ such that $a_{ij}$ captures the interaction between tokens $i$ and $j$. The self-attention is modified such that $e_{ij} = \frac{x_i W^Q (x_j W^K + a^K_{ij})^T}{\sqrt{d_k}}$ and $z_i = \sum_j \alpha_{ij} (x_j W^V + a^V_{ij})$.
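
A minimal sketch of the GPT-2-style learned positional embedding table in PyTorch (class and argument names are illustrative, not the official implementation):

```python
import torch
import torch.nn as nn

class GPT2Embeddings(nn.Module):
    """Token embeddings plus a learned positional embedding table (GPT-2 style)."""
    def __init__(self, vocab_size: int, context_length: int, emb_dim: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, emb_dim)       # word/token embeddings
        self.wpe = nn.Embedding(context_length, emb_dim)   # learned positional embeddings

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (batch, seq_len) token ids
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)           # positions 0 .. T-1
        return self.wte(idx) + self.wpe(pos)               # broadcast add over the batch
```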

LayerNorm

  • In Vaswani et al. (2017), the LayerNorm is applied after the multi-head attention and again after the feed-forward block (post-LN).
  • Now it’s more usual to use pre-LayerNorm, i.e. a LayerNorm before the multi-head attention and another before the feed-forward block (see the sketch after this list).
  • Gemma applies LayerNorm both before and after each block, i.e. a LayerNorm sandwich.
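
A minimal sketch of a pre-LayerNorm block in PyTorch (module names are illustrative; the causal attention mask is omitted for brevity):

```python
import torch
import torch.nn as nn

class PreLNBlock(nn.Module):
    """Pre-LayerNorm transformer block: LN -> attention -> residual, LN -> MLP -> residual."""
    def __init__(self, emb_dim: int, n_heads: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(emb_dim, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, 4 * emb_dim),
            nn.GELU(),
            nn.Linear(4 * emb_dim, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln1(x)                                   # normalize *before* attention
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                                  # residual around attention
        x = x + self.mlp(self.ln2(x))                     # normalize *before* feed-forward
        return x
```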

Dropout

  • It’s fairly common to add dropout just after the feed-forward block, after the multi-head attention, and even inside the self-attention itself, e.g. $\mathrm{Attention}(Q,K,V) = \mathrm{Dropout}\left(\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)\right)V$
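
A minimal sketch of the formula above in PyTorch, with dropout applied to the attention weights (function name and the default p are illustrative):

```python
import math
import torch
import torch.nn.functional as F

def attention_with_dropout(q, k, v, p: float = 0.1, training: bool = True):
    """Scaled dot-product attention with dropout on the attention weights."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)       # (..., T, T)
    weights = F.softmax(scores, dim=-1)
    weights = F.dropout(weights, p=p, training=training)    # Dropout(softmax(QK^T / sqrt(d_k)))
    return weights @ v
```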

Weight Tying

  • Weight Tying improves the performance of language models by tying (sharing) the weights of the embedding and softmax layers. This method also massively reduces the total number of parameters in the language models that it is applied to.
  • Language models are typically comprised of an embedding layer, followed by a number of Transformer or LSTM layers, which are finally followed by a softmax layer. Embedding layers learn word representations, such that similar words (in meaning) are represented by vectors that are near each other (in cosine distance). [Press & Wolf, 2016] showed that the softmax matrix, in which every word also has a vector representation, also exhibits this property. This leads them to propose to share the softmax and embedding matrices, which is done today in nearly all language models.
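
In PyTorch, weight tying can be done by sharing the parameter between the embedding and the output projection; a minimal sketch (GPT-2-ish dimensions, purely illustrative):

```python
import torch.nn as nn

vocab_size, emb_dim = 50257, 768
wte = nn.Embedding(vocab_size, emb_dim)               # input (embedding) layer
lm_head = nn.Linear(emb_dim, vocab_size, bias=False)  # output (softmax) projection

# Tie the weights: both layers now share the same (vocab_size, emb_dim) matrix,
# so the softmax matrix is the embedding matrix and the parameter count drops.
lm_head.weight = wte.weight
```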

Weight Decay

  • Optimizer setup: any parameter that is 2D gets weight decay, otherwise not, i.e. all weight tensors in matmuls + embeddings decay, while all biases and layernorm parameters don’t (see the sketch after this list).
  • The usual weight decay for LLMs is 0.1, a surprisingly large value.
  • Why weight decay? (1) Better optimization, as observed in Chinchilla, and (2) prevention of loss divergences when using bfloat16.
  • Another crucial effect of weight decay is that it enables stable bfloat16 mixed-precision training. Scao et al. (2022) briefly observed that usage of float16 causes spikes, while bfloat16 is more stable. Although bfloat16 shares the same floating-point exponent size as float32, it offers lower precision, with only 7 bits for the fraction instead of 23. Interestingly, we observe that even the presumably more stable bfloat16 can still exhibit late-training spikes that irreparably harm model performance. We suspect that LLM practitioners may be aware of this phenomenon qualitatively, but we could not find any systematic reference addressing it.
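
A sketch of the 2D/non-2D parameter grouping described above, in the style of nanoGPT’s configure_optimizers (hyperparameters are illustrative):

```python
import torch

def configure_optimizer(model, weight_decay=0.1, lr=3e-4, betas=(0.9, 0.95)):
    """Apply weight decay only to parameters with >= 2 dimensions."""
    params = [p for p in model.parameters() if p.requires_grad]
    decay_params = [p for p in params if p.dim() >= 2]    # matmul weights, embeddings
    nodecay_params = [p for p in params if p.dim() < 2]   # biases, layernorm gains/biases
    groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": nodecay_params, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr, betas=betas)
```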

Gradient clipping

  • GPT-2 clips the gradient L2 norm at 1. This means that the gradients of all parameters, concatenated into a single vector, have an L2 norm of at most 1; if the norm exceeds 1, all gradients are scaled down by the same factor.
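
A minimal sketch of global-norm gradient clipping in a PyTorch training step (the tiny model and data are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x, y = torch.randn(8, 16), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# Clip the global L2 norm of all gradients (concatenated) to at most 1.0;
# if the norm exceeds 1.0, every gradient is scaled down by the same factor.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```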