Independent weight decay

  • Parameterizing weight decay independently of learning rate reduces LR sensitivity
    • recommended by the original AdamW authors, but not the default behavior in PyTorch

Formulation

  • For parameters θ, let Δ denote the AdamW update without learning rate or weight decay.
    • For weight decay coefficient λ, max learning rate η, and schedule s_t ∈ [0, 1],
    • the original authors recommend the update θ ← θ − s_t (η Δ + λ θ),
      • which we refer to as independent decay.
    • The default implementation in PyTorch or Optax applies the update θ ← θ − s_t η (Δ + λ θ),
      • i.e. η now also scales the weight decay term (see the sketch after this list).
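
A minimal sketch of the two update rules, assuming a plain tensor and a precomputed AdamW direction (the `delta` argument stands in for m̂/(√v̂ + ε) and is not computed here); the function and argument names are illustrative, not the PyTorch API:

```python
import torch

def adamw_step(param, delta, lr_max, wd, schedule_t, independent=True):
    """One simplified decoupled-weight-decay step.

    delta stands in for the AdamW direction m_hat / (sqrt(v_hat) + eps),
    i.e. the update before learning rate and weight decay are applied.
    """
    if independent:
        # Independent decay: the schedule scales both terms, but the
        # decay strength is NOT multiplied by the learning rate.
        param -= schedule_t * (lr_max * delta + wd * param)
    else:
        # PyTorch/Optax default: the effective decay is lr * wd, so
        # changing the learning rate also changes the decay strength.
        param -= schedule_t * lr_max * (delta + wd * param)
    return param

if __name__ == "__main__":
    p = torch.ones(3)
    d = torch.full((3,), 0.1)  # stand-in AdamW direction
    print(adamw_step(p.clone(), d, lr_max=1e-2, wd=1e-4, schedule_t=0.5))
```

With the default parameterization the effective decay is η·λ, so sweeping the learning rate also sweeps the decay; the independent form keeps the two axes separate, which is the claimed source of the reduced LR sensitivity.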

Attention logit growth / QK-layernorm

  • Summary: apply LayerNorm to the queries and keys before computing the attention logits (qk-layernorm)
  • Why?
    • instability was observed when training large models; it was caused by extremely large values in the attention logits, which lead to (almost one-hot) attention weights with near-zero entropy.
    • attention logits: z_ij = ⟨q_i, k_j⟩ / √d_h, for query q_i, key k_j, and head dimension d_h
    • models with qk-layernorm exhibit considerably lower LR sensitivity and train to low loss at high learning rates (a sketch follows this list)
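
A minimal single-head sketch of qk-layernorm, assuming a standard PyTorch attention layout; the class name and projections are illustrative, and multi-head reshaping, masking, and the output projection are omitted:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKLayerNormAttention(nn.Module):
    """Single-head self-attention with LayerNorm applied to queries and
    keys before the logits are computed (illustrative sketch)."""

    def __init__(self, d_model: int, d_head: int):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_head, bias=False)
        self.k_proj = nn.Linear(d_model, d_head, bias=False)
        self.v_proj = nn.Linear(d_model, d_head, bias=False)
        # The extra normalization that bounds the logit magnitude.
        self.q_norm = nn.LayerNorm(d_head)
        self.k_norm = nn.LayerNorm(d_head)
        self.d_head = d_head

    def forward(self, x):                  # x: (batch, seq, d_model)
        q = self.q_norm(self.q_proj(x))    # normalize queries
        k = self.k_norm(self.k_proj(x))    # normalize keys
        v = self.v_proj(x)
        # logits z_ij = <q_i, k_j> / sqrt(d_head)
        logits = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        attn = F.softmax(logits, dim=-1)
        return attn @ v
```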

z-loss

  • Another instability, when training large models, is divergence of the output logits from the log probabilities [6]. Let y denote the model’s output logits, which are used to compute class probabilities p_i via a softmax, p_i = e^{y_i} / Z, where Z = Σ_j e^{y_j}.
  • This instability occurs when the logits diverge and become very negative.
  • In contrast to the attention logit growth instability, this divergence occurs towards the end of training.
  • The mitigation proposed by Chowdhery et al. [6] is to encourage log Z to remain close to zero, i.e. Z ≈ 1.
  • To do so, they add an auxiliary loss log²(Z), referred to as z-loss, with a coefficient of 10⁻⁴ (see the sketch below).
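
A minimal sketch of adding the auxiliary term to a standard cross-entropy loss in PyTorch; the function name and the mean reduction over the batch are assumptions, while the log²Z penalty and the 10⁻⁴ coefficient come from the text above:

```python
import torch
import torch.nn.functional as F

def cross_entropy_with_z_loss(logits, targets, z_loss_coef=1e-4):
    """Cross-entropy plus the auxiliary z-loss term.

    logits:  (batch, num_classes) raw model outputs y
    targets: (batch,) integer class labels
    """
    # log Z = logsumexp over the class dimension (the softmax normalizer).
    log_z = torch.logsumexp(logits, dim=-1)
    ce = F.cross_entropy(logits, targets)
    # z-loss encourages log Z to stay near zero (Z close to 1).
    z_loss = z_loss_coef * (log_z ** 2).mean()
    return ce + z_loss
```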