Summary
- Feature learning is achieved by scaling the spectral norm of weight matrices and their updates like $\sqrt{\texttt{fan\_out}/\texttt{fan\_in}}$.
    - i.e. weight matrices should take in vectors of length `fan_in` and spit out vectors of length `fan_out`.
    - This scaling happens at init time and also at gradient update time (see the sketch at the end of this summary).
- An important fact about a matrix `self.weight` with `fan_in` much larger than `fan_out` is that the null space is huge, meaning that most of the input space is mapped to zero. The dimension of the null space is at least `fan_in - fan_out`. At initialization, most of a fixed input `x` will lie in this null space.
    - This means that to get the output of `self.forward` to have unit variance at initialization, you need to pick a huge initialization scale `sigma` in order to scale up the component of `x` that does not lie in the null space.
    - But after a few steps of training, the situation changes. Gradient descent will cause the input `x` to align with the non-null space of `self.weight`. This means that the `sigma` you chose to control the activations at initialization is now far too large in hindsight, and the activations will blow up! This problem only gets worse with increasing `fan_in`.
    - The solution to this problem is simple: don't choose `sigma` to control variance at initialization! Instead, choose `sigma` under the assumption that inputs fall in the non-null space. Even if this makes the activations too small at initialization, this is fine as they will quickly "warm up" after a few steps of training.
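A minimal PyTorch sketch of the summary above (my own illustration, not the paper's code): the weight is set to spectral norm $\sqrt{\texttt{fan\_out}/\texttt{fan\_in}}$ at init, and each update applied to it is rescaled to the same spectral scale. The explicit rescaling via `torch.linalg.matrix_norm`, the shapes, and the learning rate are illustrative choices.

```python
import torch

def scale_to_spectral_norm(M: torch.Tensor, target: float) -> torch.Tensor:
    """Rescale M so that its largest singular value equals `target`."""
    return M * (target / torch.linalg.matrix_norm(M, ord=2))

fan_out, fan_in = 256, 1024
target = (fan_out / fan_in) ** 0.5              # sqrt(fan_out / fan_in)

# Init time: weight has spectral norm sqrt(fan_out / fan_in).
weight = scale_to_spectral_norm(torch.randn(fan_out, fan_in), target)

# Update time: the step applied to the weight is scaled the same way.
grad = torch.randn(fan_out, fan_in)             # placeholder for a real gradient
lr = 0.1
weight = weight - scale_to_spectral_norm(grad, lr * target)
```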
Detailed Math
Alignment
- It's worth expanding a little on what we mean by alignment here. When we say that an input `x` aligns with a weight matrix `weight`, we mean that if we compute `U, S, V = torch.linalg.svd(weight)`, then the input `x` will tend to have a larger dot product with the rows of `V` that correspond to larger diagonal entries of the singular value matrix `S`. When we say that layers align, we mean that the outputs of one layer will align with the next layer.
    - This happens naturally during gradient descent. https://jeremybernste.in/modula/golden-rules/#three-golden-rules
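As a concrete way to see this, here is a small sketch (the helper name `alignment_profile` is mine, not from the source) that measures how strongly an input projects onto each right singular direction of a weight matrix; during training, the entries corresponding to large singular values tend to grow.

```python
import torch

def alignment_profile(weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """|<x, v_i>| for each right singular vector v_i, ordered by decreasing singular value."""
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    return (Vh @ x).abs()    # rows of Vh are the right singular vectors

weight = torch.randn(64, 256) / 256 ** 0.5
x = torch.randn(256)
print(alignment_profile(weight, x)[:5])   # roughly flat at init; top entries grow with training
```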
Spectral norm
- The spectral norm is the largest factor by which a matrix can increase the norm of a vector on which it acts.
- In the case of deep learning, the spectral norm of a weight matrix upper-bounds the scale of the activations it produces, since $\|W h\|_2 \le \|W\|_* \, \|h\|_2$.
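A quick numerical check of this bound, using `torch.linalg.matrix_norm(..., ord=2)` for the spectral norm (the shapes are arbitrary):

```python
import torch

W = torch.randn(128, 512)
h = torch.randn(512)

spec = torch.linalg.matrix_norm(W, ord=2)     # largest singular value of W
lhs = torch.linalg.vector_norm(W @ h)         # activation scale ||W h||
rhs = spec * torch.linalg.vector_norm(h)      # bound ||W||_* ||h||
print(f"||Wh|| = {lhs:.2f} <= ||W||_* ||h|| = {rhs:.2f}")
```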
Feature learning
- The feature learning regime can be summarized as: both the features and their updates upon a step of gradient descent must be the proper size.
- Let $h_\ell(x) \in \mathbb{R}^{n_\ell}$ denote the features of input $x$ at layer $\ell$ of a neural network (with $n_\ell$ the width of that layer), and let $\Delta h_\ell(x)$ denote their change after a gradient step. We desire that $\|h_\ell(x)\|_2 = \Theta(\sqrt{n_\ell})$ and $\|\Delta h_\ell(x)\|_2 = \Theta(\sqrt{n_\ell})$ at every layer.
    - This amounts to asking that the "typical element size" of the vectors $h_\ell$ and $\Delta h_\ell$ is $\Theta(1)$ with respect to the width $n_\ell$.
    - The requirement on $h_\ell$ is motivated by the fact that activation functions are designed to take order-one inputs and give order-one outputs (e.g. tanh).
    - The requirement on $\Delta h_\ell$ stipulates that feature entries also undergo order-one updates during training. Note that any larger updates would blow up at large width, and any smaller updates would vanish at large width.
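A tiny sketch of the "typical element size" in question, i.e. $\|h\|_2 / \sqrt{n_\ell}$, which the desideratum asks to stay $\Theta(1)$ as the width grows (the random vectors below are just stand-ins for features):

```python
import torch

def typical_element_size(h: torch.Tensor) -> float:
    """||h||_2 / sqrt(width): the per-entry scale that should be Θ(1) in width."""
    return (torch.linalg.vector_norm(h) / h.numel() ** 0.5).item()

for width in (256, 1024, 4096):
    h = torch.randn(width)                             # stand-in for features at some layer
    print(width, round(typical_element_size(h), 3))    # stays ≈ 1 at every width
```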
Condition 1 (Spectral scaling)
- The main message is that feature learning in the sense of the above definition may be ensured by the following spectral scaling condition on the weight matrices of a deep network and their gradient updates.
- Consider applying a gradient update $\Delta W_\ell$ to the $\ell$-th weight matrix $W_\ell \in \mathbb{R}^{n_\ell \times n_{\ell-1}}$ (fan-in $n_{\ell-1}$, fan-out $n_\ell$). The spectral norms of these matrices should satisfy $\|W_\ell\|_* = \Theta\!\left(\sqrt{n_\ell / n_{\ell-1}}\right)$ and $\|\Delta W_\ell\|_* = \Theta\!\left(\sqrt{n_\ell / n_{\ell-1}}\right)$, i.e. both should scale like $\sqrt{\texttt{fan\_out}/\texttt{fan\_in}}$.
- We have implicitly assumed that the input has size $\|x\|_2 = \Theta(\sqrt{n_0})$, which is standard for image data. Language models are an important counterexample, where embedding matrices take one-hot inputs (i.e. not all of the width acts on the first layer), and the $n_0$ in Condition 1 should be replaced by 1.
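For concreteness, here is the target spectral scale from Condition 1 for a few illustrative layer shapes (my own examples), including the language-model embedding case where the fan-in is replaced by 1:

```python
def target_spectral_norm(fan_in: int, fan_out: int) -> float:
    """Condition 1: ||W||_* and ||ΔW||_* should both scale like sqrt(fan_out / fan_in)."""
    return (fan_out / fan_in) ** 0.5

print(target_spectral_norm(fan_in=1024, fan_out=1024))   # square hidden layer: 1.0
print(target_spectral_norm(fan_in=1024, fan_out=10))     # readout layer: ~0.1
print(target_spectral_norm(fan_in=1, fan_out=1024))      # embedding with one-hot inputs: 32.0
```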
Parametrization 1 (Spectral parametrization) - Efficient implementation of the spectral scaling condition
- Spectral scaling induces feature learning.
- How to implement it?
    - They claim that the spectral scaling condition (Condition 1) is satisfied and feature learning is achieved (as per Desideratum 1) if the initialization scale and learning rate of each layer are chosen according to $\sigma_\ell = \Theta\!\left(\sqrt{\tfrac{n_\ell}{n_{\ell-1}}} \cdot \tfrac{1}{\sqrt{\max(n_{\ell-1},\, n_\ell)}}\right)$ and $\eta_\ell = \Theta\!\left(\tfrac{n_\ell}{n_{\ell-1}}\right)$ (for SGD), where $n_{\ell-1}$ is the fan-in and $n_\ell$ the fan-out of layer $\ell$.
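A minimal sketch of what this could look like for a plain SGD-trained MLP in PyTorch: per-layer init std and learning rate follow the formulas above (up to constant factors). The widths, base learning rate, and helper names (`spectral_init_std`, `spectral_lr`) are my own illustrative choices, not the paper's code.

```python
import torch
import torch.nn as nn

def spectral_init_std(fan_in: int, fan_out: int) -> float:
    # sigma_l ~ sqrt(fan_out / fan_in) / sqrt(max(fan_in, fan_out))
    return (fan_out / fan_in) ** 0.5 / max(fan_in, fan_out) ** 0.5

def spectral_lr(fan_in: int, fan_out: int, base_lr: float) -> float:
    # eta_l ~ fan_out / fan_in for SGD
    return base_lr * fan_out / fan_in

widths = [784, 1024, 1024, 10]   # input -> hidden -> hidden -> output
layers = [nn.Linear(fi, fo, bias=False) for fi, fo in zip(widths[:-1], widths[1:])]

param_groups = []
for layer in layers:
    fan_out, fan_in = layer.weight.shape          # nn.Linear stores weight as (out, in)
    nn.init.normal_(layer.weight, std=spectral_init_std(fan_in, fan_out))
    param_groups.append({"params": [layer.weight],
                         "lr": spectral_lr(fan_in, fan_out, base_lr=0.1)})

optimizer = torch.optim.SGD(param_groups, lr=0.1)  # per-layer lr overrides the default
```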
Random initialization
- Common practice: $W_\ell$ is initialized as $W_\ell = \sigma_\ell \Omega_\ell$, where all elements of $\Omega_\ell$ are initialized i.i.d. from a normal distribution with mean zero and unit variance. The spectral norm of a matrix thus constructed is roughly $\|W_\ell\|_* \approx \sigma_\ell \left(\sqrt{n_{\ell-1}} + \sqrt{n_\ell}\right)$.
    - To get the desired scaling $\|W_\ell\|_* = \Theta\!\left(\sqrt{n_\ell / n_{\ell-1}}\right)$, we need merely choose $\sigma_\ell = \sqrt{n_\ell / n_{\ell-1}} \,\big/\, \left(\sqrt{n_{\ell-1}} + \sqrt{n_\ell}\right)$.
    - Simplifying within the $\Theta(\cdot)$, we arrive at $\sigma_\ell$ scaled as in the spectral parametrization (Parametrization 1). Initializing weights with a prefactor scaling in this manner achieves the correct spectral norm of $\Theta\!\left(\sqrt{n_\ell / n_{\ell-1}}\right)$.
    - We note that the constant factor suppressed by the $\Theta(\cdot)$ here will usually be small: for example, a prefactor of $\sqrt{2}$ agrees with typical practice for ReLU networks at most layers.
- If $\Omega_\ell$ is instead a random semi-orthogonal matrix (spectral norm of 1), then we can simply use the prefactor $\sigma_\ell = \sqrt{n_\ell / n_{\ell-1}}$.
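A quick numerical check of both claims (shapes and $\sigma$ are arbitrary): the spectral norm of a Gaussian-initialized matrix concentrates around $\sigma\left(\sqrt{\texttt{fan\_in}} + \sqrt{\texttt{fan\_out}}\right)$, while a semi-orthogonal matrix has spectral norm exactly 1, so the prefactor alone sets $\|W\|_*$.

```python
import torch
import torch.nn as nn

fan_in, fan_out, sigma = 4096, 1024, 0.02

# Gaussian init: ||W||_* ≈ sigma * (sqrt(fan_in) + sqrt(fan_out)).
W = sigma * torch.randn(fan_out, fan_in)
print(torch.linalg.matrix_norm(W, ord=2).item(),           # ≈ 1.92
      sigma * (fan_in ** 0.5 + fan_out ** 0.5))             # = 1.92

# Semi-orthogonal init: prefactor sqrt(fan_out / fan_in) directly gives the target norm.
Q = nn.init.orthogonal_(torch.empty(fan_out, fan_in))
target = (fan_out / fan_in) ** 0.5
print(torch.linalg.matrix_norm(target * Q, ord=2).item(), target)   # both 0.5
```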
Biases
- Extend the spectral analysis to biases.
- Let $b_\ell \in \mathbb{R}^{n_\ell}$ be a bias vector which enters during forward propagation as $h_\ell = W_\ell h_{\ell-1} + b_\ell$. We may choose to view the bias vector as a weight matrix connecting an auxiliary layer with width 1 (and constant output 1) to the $\ell$-th hidden layer, after which we may simply apply our scaling analysis for weight matrices.
- With fan-in 1 and fan-out $n_\ell$, the spectral scaling condition (Condition 1) prescribes that $\|b_\ell\|_2 = \Theta(\sqrt{n_\ell})$ and $\|\Delta b_\ell\|_2 = \Theta(\sqrt{n_\ell})$, and Parametrization 1 prescribes that the initialization scale and learning rate should be $\sigma_\ell = \Theta(1)$ and $\eta_\ell = \Theta(n_\ell)$. In practice, one may usually just take $b_\ell = 0$ at initialization.
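Plugging fan-in 1 into the same formulas makes the bias prescription concrete (the width and base learning rate below are arbitrary illustrative values):

```python
n_l = 1024                                  # width of the l-th hidden layer
fan_in, fan_out = 1, n_l                    # bias viewed as an (n_l x 1) weight matrix

sigma_b = (fan_out / fan_in) ** 0.5 / max(fan_in, fan_out) ** 0.5   # = 1.0  -> Θ(1) init scale
eta_b = 0.1 * fan_out / fan_in                                      # = 102.4 -> Θ(n_l) SGD lr
print(sigma_b, eta_b)
```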
Comparison to standard parametrization (SP)
- "Kaiming," "Xavier," or "LeCun" initialization
- Notice that SP initialization exceeds the spectral parametrization in any layer with fan-out smaller than fan-in (e.g. the second MLP matrix in a GLU, in most cases).
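A concrete instance of this gap (illustrative shapes: a GLU/MLP down-projection with width $d$ and a 4x expansion): SP's $1/\sqrt{\texttt{fan\_in}}$ overshoots the spectral scale by a factor of $\sqrt{\texttt{fan\_in}/\texttt{fan\_out}}$.

```python
d = 1024
fan_in, fan_out = 4 * d, d                  # down-projection: fan_out < fan_in

sp_std = fan_in ** -0.5                                            # SP (Kaiming/LeCun-style)
spectral_std = (fan_out / fan_in) ** 0.5 / max(fan_in, fan_out) ** 0.5
print(sp_std, spectral_std, sp_std / spectral_std)                 # ratio = sqrt(fan_in/fan_out) = 2
```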