Summary

  • Feature learning is achieved by scaling the spectral norm of weight matrices and their updates like sqrt(fan_out / fan_in)

    • i.e. weight matrices should take in vectors of length (Euclidean norm) sqrt(fan_in) and spit out vectors of length sqrt(fan_out), so that entries stay order one on both sides.
    • This scaling happens at init time and also at gradient update time.
  • An important fact about a matrix self.weight with fan_in much larger than fan_out is that its null space is huge, meaning that most of the input space is mapped to zero. The dimension of the null space is at least fan_in - fan_out. At initialization, most of the energy of a fixed input x will lie in this null space.

    • This means that to get the output of self.forward to have unit variance at initialization, you need to pick a huge initialization scale sigma in order to scale up the component of x that does not lie in the null space.
    • But after a few steps of training, the situation changes. Gradient descent will cause the input x to align with the non-null space of self.weight. This means that the sigma you chose to control the activations at initialization is now far too large in hindsight, and the activations will blow up! This problem only gets worse with increasing fan_in.
    • The solution to this problem is simple: don’t choose sigma to control variance at initialization! Instead, choose sigma under the assumption that inputs fall in the non-null space. Even if this makes the activations too small at initialization, this is fine, as they will quickly “warm up” after a few steps of training. (A minimal numerical sketch of the null-space effect follows this list.)
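
A minimal numerical sketch of the effect described above, assuming a LeCun-style sigma = 1/sqrt(fan_in) as the “control the variance at initialization” choice; the shapes are arbitrary:

```python
import torch

torch.manual_seed(0)
fan_in, fan_out = 4096, 64                         # fan_in >> fan_out: null space has dim >= fan_in - fan_out
W = torch.randn(fan_out, fan_in) / fan_in ** 0.5   # sigma chosen to give unit-variance outputs at init

# A generic input: most of its energy lies in the null space of W.
x_random = torch.randn(fan_in)

# An "aligned" input of the same norm, built entirely from the non-null row space of W.
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
x_aligned = Vh.T @ torch.randn(fan_out)
x_aligned = x_aligned * (x_random.norm() / x_aligned.norm())

rms = lambda v: (v.norm() / v.numel() ** 0.5).item()   # "typical element size" of a vector
print(rms(W @ x_random))    # ~1: this sigma does control the output variance on a generic input
print(rms(W @ x_aligned))   # ~sqrt(fan_in / fan_out) = 8: the same sigma blows up once x aligns with W
```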

Detailed

Math

Alignment

  • It’s worth expanding a little on what we mean by alignment here. When we say that an input x aligns with a weight matrix weight, we mean that if we compute U, S, Vh = torch.linalg.svd(weight), then x will tend to have a larger dot product with the rows of Vh (the right singular vectors) that correspond to larger entries of the singular value vector S. When we say that layers align, we mean that the outputs of one layer align, in this sense, with the weight matrix of the next layer. A sketch of how one might measure this is given below.
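
A hedged sketch of measuring this numerically; the helper name and the energy-fraction metric are illustrative choices of mine, not from the paper:

```python
import torch

def alignment_profile(weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Fraction of ||x||^2 carried by each right singular direction of `weight`,
    ordered from the largest singular value to the smallest."""
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    coeffs = Vh @ x                       # dot products of x with the right singular vectors
    return coeffs.pow(2) / x.pow(2).sum()

torch.manual_seed(0)
weight = torch.randn(64, 256)
x = torch.randn(256)
profile = alignment_profile(weight, x)
# For a random x the energy is spread evenly (~1/256 per direction); after training,
# an aligned x would concentrate this profile on the first few (largest-S) entries.
print(profile[:5], profile.sum())
```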

Spectral norm

  • Spectral Norm
    • The spectral norm is the largest factor by which a matrix can increase the norm of a vector on which it acts.
    • In the case of deep learning, the spectral norm of a weight matrix upper-bounds how much that layer can grow the norm of the activations passing through it (see the check below).
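
A small check of this definition with an arbitrary matrix shape: the spectral norm is the largest factor by which any input’s norm can grow, and the top right singular vector attains it.

```python
import torch

torch.manual_seed(0)
W = torch.randn(64, 256)

sigma_max = torch.linalg.matrix_norm(W, ord=2)     # spectral norm = largest singular value

# Random directions are stretched by at most sigma_max ...
x = torch.randn(256, 1000)
gains = (W @ x).norm(dim=0) / x.norm(dim=0)

# ... and the top right singular vector attains the bound.
_, _, Vh = torch.linalg.svd(W, full_matrices=False)
top_gain = (W @ Vh[0]).norm()                      # Vh[0] already has unit norm

print(sigma_max.item(), gains.max().item(), top_gain.item())
```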

Feature learning

  • The feature learning regime can be summarized as: both the features and their updates upon a step of gradient descent must be the proper size.
  • Let $h_\ell \in \mathbb{R}^{n_\ell}$ denote the features of input $x$ at layer $\ell$ of a neural network, and let $\Delta h_\ell$ denote their change after a gradient step. We desire that: $\|h_\ell\|_2 = \Theta(\sqrt{n_\ell})$ and $\|\Delta h_\ell\|_2 = \Theta(\sqrt{n_\ell})$.
  • This amounts to asking that the “typical element size” of the vectors $h_\ell$ and $\Delta h_\ell$ is $\Theta(1)$ with respect to the width $n_\ell$ (illustrated in the snippet after this list).
    • motivated by the fact that activation functions are designed to take order-one inputs and give order-one outputs (e.g. tanh)
    • the requirement $\|\Delta h_\ell\| = \Theta(\sqrt{n_\ell})$ stipulates that feature entries also undergo $\Theta(1)$ updates during training. Note that any larger updates would blow up at large width, and any smaller updates would vanish at large width.
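
A tiny illustration (the widths are arbitrary) of what the desideratum means: a feature vector with order-one entries has norm growing like sqrt(width), so ||h|| / sqrt(n) is the width-independent quantity to track.

```python
import torch

torch.manual_seed(0)
for n in (256, 4096, 65536):                 # increasing widths
    h = torch.tanh(torch.randn(n))           # features with Theta(1) entries (tanh of unit Gaussians)
    norm = h.norm().item()
    print(f"n={n:6d}  ||h||={norm:8.1f}  ||h||/sqrt(n)={norm / n ** 0.5:.3f}")
# ||h|| grows like sqrt(n), while ||h||/sqrt(n) (the rms entry size) stays constant across widths.
```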

Condition 1 (Spectral scaling)

  • The main message is that feature learning in the sense of the above definition may be ensured by the following spectral scaling condition on the weight matrices of a deep network and their gradient updates.

  • Consider applying a gradient update $\Delta W_\ell$ to the $\ell$th weight matrix $W_\ell \in \mathbb{R}^{n_\ell \times n_{\ell-1}}$ (fan_out $= n_\ell$, fan_in $= n_{\ell-1}$). The spectral norms of these matrices should satisfy: $\|W_\ell\|_* = \Theta\big(\sqrt{n_\ell / n_{\ell-1}}\big)$ and $\|\Delta W_\ell\|_* = \Theta\big(\sqrt{n_\ell / n_{\ell-1}}\big)$. A numerical check follows at the end of this list.

  • We have implicitly assumed that the input has norm $\|x\| = \Theta(\sqrt{n_0})$, which is standard for image data. Language models are an important counterexample, where embedding matrices take one-hot inputs of norm 1 (i.e. not all of the width acts on the first layer), and the $n_0$ (the first layer’s fan_in) in Condition 1 should be replaced by 1.
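
A numerical sketch of why spectral norms of size sqrt(fan_out / fan_in) on W and ΔW give features and feature updates of size sqrt(fan_out), under the simplifying assumptions that the incoming features have norm sqrt(fan_in) and are fully aligned with W and ΔW; the shapes and the rank-one update are illustrative.

```python
import torch

torch.manual_seed(0)
fan_in, fan_out = 4096, 1024
target = (fan_out / fan_in) ** 0.5                    # the prescribed spectral norm sqrt(n_l / n_{l-1})

# A weight matrix with spectral norm exactly equal to the target.
W = torch.randn(fan_out, fan_in)
W = W * (target / torch.linalg.matrix_norm(W, ord=2))

# A fully aligned input of norm sqrt(fan_in): the top right singular direction, rescaled.
_, _, Vh = torch.linalg.svd(W, full_matrices=False)
x = Vh[0] * fan_in ** 0.5

# A rank-one update aligned with x, also with spectral norm equal to the target
# (gradient updates are low-rank and tend to align with their inputs).
u = torch.randn(fan_out)
dW = target * torch.outer(u / u.norm(), x / x.norm())

h, dh = W @ x, dW @ x
print((h.norm() / fan_out ** 0.5).item())    # ~1, i.e. ||h||  = Theta(sqrt(fan_out))
print((dh.norm() / fan_out ** 0.5).item())   # ~1, i.e. ||dh|| = Theta(sqrt(fan_out))
```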

Parametrization 1 (Spectral parametrization) - Efficient implementation of the spectral scaling condition

  • Spectral scaling induces feature learning

  • How to implement it?

  • They claim that the spectral scaling condition (Condition 1) is satisfied and feature learning is achieved (as per Desideratum 1) if the initialization scale and learning rate of each layer are chosen according to: $\sigma_\ell = \Theta\big(\sqrt{\min(n_{\ell-1}, n_\ell)} \,/\, n_{\ell-1}\big)$ and, for SGD, $\eta_\ell = \Theta\big(n_\ell / n_{\ell-1}\big)$ (see the sketch below).
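
A possible PyTorch sketch of this parametrization for linear layers trained with plain SGD; the helper names and the base constants are placeholders of mine, not the paper’s code.

```python
import torch
import torch.nn as nn

def spectral_linear(n_in: int, n_out: int, base_sigma: float = 1.0) -> nn.Linear:
    """Linear layer initialized with sigma = base_sigma * sqrt(min(n_in, n_out)) / n_in."""
    layer = nn.Linear(n_in, n_out, bias=False)
    sigma = base_sigma * min(n_in, n_out) ** 0.5 / n_in
    nn.init.normal_(layer.weight, mean=0.0, std=sigma)
    return layer

def spectral_sgd(model: nn.Module, base_lr: float = 0.1) -> torch.optim.SGD:
    """Per-layer SGD learning rates eta = base_lr * n_out / n_in for every weight matrix."""
    groups = []
    for module in model.modules():
        if isinstance(module, nn.Linear):
            n_out, n_in = module.weight.shape
            groups.append({"params": [module.weight], "lr": base_lr * n_out / n_in})
    return torch.optim.SGD(groups, lr=base_lr)

# Usage sketch: a small MLP in the spectral parametrization.
model = nn.Sequential(spectral_linear(3072, 1024), nn.Tanh(), spectral_linear(1024, 10))
opt = spectral_sgd(model)
```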

Random initialization

  • In common practice, $W_\ell$ is initialized as $W_\ell = \sigma_\ell Z_\ell$, where all elements of $Z_\ell$ are drawn i.i.d. from a normal distribution with mean zero and unit variance. The spectral norm of a matrix thus constructed is roughly: $\|W_\ell\|_* \approx \sigma_\ell \big(\sqrt{n_{\ell-1}} + \sqrt{n_\ell}\big)$.
    • To get the desired scaling $\|W_\ell\|_* = \Theta\big(\sqrt{n_\ell / n_{\ell-1}}\big)$, we need merely choose: $\sigma_\ell = \sqrt{n_\ell / n_{\ell-1}} \,\big/\, \big(\sqrt{n_{\ell-1}} + \sqrt{n_\ell}\big)$.
      • Simplifying within the $\Theta(\cdot)$, we arrive at $\sigma_\ell = \Theta\big(\sqrt{\min(n_{\ell-1}, n_\ell)} \,/\, n_{\ell-1}\big)$, i.e. $\sigma_\ell$ scaled as in the spectral parametrization (Parametrization 1). Initializing weights with a prefactor scaling in this manner achieves the correct spectral norm of $\Theta\big(\sqrt{n_\ell / n_{\ell-1}}\big)$.
        • We note that the constant factor suppressed by the $\Theta(\cdot)$ here will usually be small: for example, a prefactor of $\sqrt{2}$ agrees with typical (Kaiming) practice for ReLU networks at most layers, since at square layers ($n_{\ell-1} = n_\ell$) the prescription reduces to $\sigma_\ell = \Theta(1/\sqrt{n_{\ell-1}})$.
  • If $Z_\ell$ is instead a random semi-orthogonal matrix (spectral norm exactly 1), then we can simply use the prefactor $\sigma_\ell = \sqrt{n_\ell / n_{\ell-1}}$, as in the sketch below.
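
A sketch of this semi-orthogonal variant using torch.nn.init.orthogonal_; the helper name is mine, and the final print is just a sanity check of the resulting spectral norm.

```python
import torch
import torch.nn as nn

def semi_orthogonal_spectral_(weight: torch.Tensor) -> torch.Tensor:
    """In-place init: random semi-orthogonal matrix scaled so ||W||_* = sqrt(fan_out / fan_in)."""
    fan_out, fan_in = weight.shape
    nn.init.orthogonal_(weight)                    # semi-orthogonal: spectral norm exactly 1
    with torch.no_grad():
        weight.mul_((fan_out / fan_in) ** 0.5)     # the prefactor sqrt(n_l / n_{l-1})
    return weight

layer = nn.Linear(4096, 256, bias=False)
semi_orthogonal_spectral_(layer.weight)
print(torch.linalg.matrix_norm(layer.weight.detach(), ord=2))   # sqrt(256 / 4096) = 0.25
```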

Biases

  • Extend the spectral analysis to biases.
  • Let $b_\ell$ be a bias vector which enters during forward propagation as $W_\ell h_{\ell-1} + b_\ell$. We may choose to view the bias vector as an $n_\ell \times 1$ weight matrix connecting an auxiliary layer of width 1, whose activation is fixed to 1, to the $\ell$th hidden layer, after which we may simply apply our scaling analysis for weight matrices.
  • The spectral scaling condition (Condition 1) prescribes that $\|b_\ell\|_* = \|b_\ell\|_2 = \Theta(\sqrt{n_\ell})$ and $\|\Delta b_\ell\|_2 = \Theta(\sqrt{n_\ell})$, and Parametrization 1 prescribes that the initialization scale and learning rate should be $\sigma_\ell = \Theta(1)$ and (for SGD) $\eta_\ell = \Theta(n_\ell)$. In practice, one may usually just initialize $b_\ell = 0$. A minimal sketch follows.
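
A minimal sketch of this bias prescription under SGD (zero init, per-bias learning rate scaled by the layer width); the helper and constants are placeholders.

```python
import torch
import torch.nn as nn

def bias_param_group(layer: nn.Linear, base_lr: float = 0.1) -> dict:
    """Treat the bias as an (n_out x 1) weight: fan_in = 1, so eta = Theta(n_out) under SGD."""
    nn.init.zeros_(layer.bias)                     # sigma = Theta(1) is allowed; in practice just use 0
    return {"params": [layer.bias], "lr": base_lr * layer.out_features}

layer = nn.Linear(1024, 512)
opt = torch.optim.SGD([bias_param_group(layer)], lr=0.1)   # weights would get their own groups
```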

Comparison to standard parametrization (SP)

  • SP corresponds to common practice: “Kaiming,” “Xavier,” or “LeCun” initialization, where $\sigma_\ell$ scales like $1/\sqrt{n_{\ell-1}}$ (Xavier also involves $n_\ell$).
  • Notice that SP initialization exceeds the spectral parametrization in any layer with fan_out smaller than fan_in (e.g. the second linear layer of an MLP or GLU in most cases): with $\sigma_\ell \propto 1/\sqrt{n_{\ell-1}}$, the initial spectral norm is $\Theta\big(1 + \sqrt{n_\ell / n_{\ell-1}}\big) = \Theta(1)$, which is larger than the spectral target $\Theta\big(\sqrt{n_\ell / n_{\ell-1}}\big) < 1$. A quick numerical comparison follows.
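
A quick numerical comparison of the two initializations for a layer with fan_out < fan_in; the shapes are illustrative.

```python
import torch

torch.manual_seed(0)
fan_in, fan_out = 4096, 256                      # a down-projection: fan_out < fan_in

W_sp = torch.randn(fan_out, fan_in) / fan_in ** 0.5                               # SP (LeCun-style) init
W_spec = torch.randn(fan_out, fan_in) * (min(fan_in, fan_out) ** 0.5 / fan_in)    # Parametrization 1

spec_norm = lambda W: torch.linalg.matrix_norm(W, ord=2).item()
print("SP init:         ", spec_norm(W_sp))             # ~1.25, i.e. Theta(1)
print("spectral init:   ", spec_norm(W_spec))           # ~0.31, within a constant of the target
print("spectral target: ", (fan_out / fan_in) ** 0.5)   # 0.25
```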