Vanilla Softmax
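For reference (with queries $\mathbf{q}_t$, keys $\mathbf{k}_i$, and values $\mathbf{v}_i$, and the $1/\sqrt{d}$ scaling omitted), causal softmax attention computes \begin{align*} \mathbf{o}_t = \sum_{i=1}^{t} \frac{\exp(\mathbf{q}_t^\top \mathbf{k}_i)}{\sum_{j=1}^{t} \exp(\mathbf{q}_t^\top \mathbf{k}_j)}\, \mathbf{v}_i \end{align*} which requires the full set of past keys and values at decode time, i.e. a KV cache that grows with the sequence.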
Linear Attention
- While removing softmax alone doesn't immediately reduce computational complexity, it enables a crucial mathematical property: linearity.
- Linearity gives us associativity → a chunkwise parallel form for prefill
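To make the chunkwise idea concrete, here is a minimal numpy sketch (unnormalized causal linear attention; function names and toy shapes are illustrative, not from any particular library). Within a chunk, outputs are computed in parallel with a masked matmul, while a running state carries the contribution of all previous chunks; the result matches the quadratic parallel form exactly.

import numpy as np

def linear_attn_quadratic(Q, K, V):
    # reference parallel form: o_t = sum_{i<=t} (q_t . k_i) v_i, cost O(T^2 d)
    T = Q.shape[0]
    mask = np.tril(np.ones((T, T)))
    return (mask * (Q @ K.T)) @ V

def linear_attn_chunkwise(Q, K, V, chunk=4):
    # chunkwise-parallel prefill: inter-chunk part via a running state,
    # intra-chunk part via a masked matmul
    T, d = Q.shape
    O = np.zeros_like(V)
    S = np.zeros((d, d))                      # sum of k_i v_i^T over all *previous* chunks
    for s in range(0, T, chunk):
        q, k, v = Q[s:s+chunk], K[s:s+chunk], V[s:s+chunk]
        mask = np.tril(np.ones((len(q), len(q))))
        O[s:s+chunk] = q @ S + (mask * (q @ k.T)) @ v
        S = S + k.T @ v                       # fold this chunk into the state
    return O

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 12, 8))         # toy shapes: T=12, d=8
assert np.allclose(linear_attn_chunkwise(Q, K, V), linear_attn_quadratic(Q, K, V))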
Linear Attention as a linear RNN
- For inference, we can rearrange things: \begin{align*} \mathbf{o}_t = \sum_{i=1}^{t} (\mathbf{q}_t^\top \mathbf{k}_i)\,\mathbf{v}_i = \left( \sum_{i=1}^{t} \mathbf{v}_i \mathbf{k}_i^\top \right) \mathbf{q}_t \end{align*}
- Let's define a state matrix \begin{align*} \mathbf{S}_t = \sum_{i=1}^{t} \mathbf{v}_i \mathbf{k}_i^\top \in \mathbb{R}^{d \times d} \end{align*}
- Then we have a clear recurrent relationship: \begin{align*} \mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}_t^\top, \qquad \mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t \end{align*}
- Linear attention is essentially a linear RNN with a matrix-valued state that accumulates key-value outer products, keeping track of a compressed, fixed-size ($d \times d$) view of the history.
- This is called a state-size expansion, from a $d$-dimensional vector-valued state to a $d \times d$ matrix-valued one.
- By casting linear attention as an RNN:
  - we reduce the per-token inference cost from $O(L)$ to $O(1)$ with respect to the sequence length $L$
  - we reduce the space complexity from $O(L)$ (a KV cache that grows with the sequence) to $O(1)$ (a fixed $d \times d$ state); a short numpy sketch of this recurrent form follows this list
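A minimal numpy sketch of the recurrent view (toy shapes, illustration only): the state is a single $d \times d$ matrix updated in $O(d^2)$ per token, and the outputs agree with the quadratic parallel form.

import numpy as np

def linear_attn_recurrent(Q, K, V):
    # S_t = S_{t-1} + v_t k_t^T,  o_t = S_t q_t ; per-token cost/memory O(d^2), independent of T
    T, d = Q.shape
    S = np.zeros((d, d))
    O = np.zeros_like(V)
    for t in range(T):
        S = S + np.outer(V[t], K[t])          # write the new key-value association
        O[t] = S @ Q[t]                       # read out with the current query
    return O

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 10, 6))
mask = np.tril(np.ones((10, 10)))
O_parallel = (mask * (Q @ K.T)) @ V           # o_t = sum_{i<=t} (q_t . k_i) v_i
assert np.allclose(linear_attn_recurrent(Q, K, V), O_parallel)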
Limitations of Linear Attention - defining retrieval error
- The fixed-size state matrix in linear attention means it cannot perfectly preserve all historical information, making exact retrieval particularly challenging.
Retrieval error
- Linear attention implements a key-value associative memory, which is the sum of outer products between keys and values: \begin{align*} \mathbf{S}_t = \sum_{i=1}^{t} \mathbf{v}_i \mathbf{k}_i^\top \end{align*}
- Assuming all keys are normalized to unit length, when we try to retrieve the value associated with a specific key $\mathbf{k}_j$, the product between the state matrix and that key should ideally give us back $\mathbf{v}_j$ (because $\mathbf{k}_j^\top \mathbf{k}_j = 1$): \begin{align*} \mathbf{S}_t \mathbf{k}_j = \sum_{i=1}^{t} \mathbf{v}_i (\mathbf{k}_i^\top \mathbf{k}_j) = \mathbf{v}_j + \underbrace{\sum_{i \neq j} (\mathbf{k}_i^\top \mathbf{k}_j)\, \mathbf{v}_i}_{\text{retrieval error}} \end{align*}
- To minimize the retrieval error term, we need all the key vectors to be mutually orthogonal ($\mathbf{k}_i^\top \mathbf{k}_j = 0$ for $i \neq j$); a small numerical illustration follows this list
- However, in a space of dimension $d$, we cannot have more than $d$ vectors that are all mutually orthogonal.
- $d$ is the head dimension, and it has been shown that increasing the head dimension improves performance, as it gives more room for storing distinct key-value pairs.
- There is a tradeoff between increasing the head dimension and keeping chunked linear attention hardware-friendly (we want tiles that are small enough to keep in registers).
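A tiny numerical illustration of this (toy dimensions, random data): with $d$ orthonormal keys, retrieval from the associative memory is exact; once the number of stored pairs exceeds $d$, the cross-terms $(\mathbf{k}_i^\top \mathbf{k}_j)\,\mathbf{v}_i$ accumulate and retrieval degrades.

import numpy as np

rng = np.random.default_rng(0)
d = 16

def retrieval_error(K, V):
    # state S = sum_i v_i k_i^T ; retrieve v_hat_i = S k_i and compare with the stored v_i
    S = V.T @ K
    V_hat = (S @ K.T).T
    return np.linalg.norm(V_hat - V) / np.linalg.norm(V)

# d orthonormal keys: k_i^T k_j = 0 for i != j, so retrieval is exact (error ~ 1e-16)
K_orth = np.linalg.qr(rng.normal(size=(d, d)))[0].T
print(retrieval_error(K_orth, rng.normal(size=(d, d))))

# 4d random unit-norm keys: they can no longer be mutually orthogonal, so the error is large
K_rand = rng.normal(size=(4 * d, d))
K_rand /= np.linalg.norm(K_rand, axis=1, keepdims=True)
print(retrieval_error(K_rand, rng.normal(size=(4 * d, d))))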
Gating or forgetting as a mechanism to improve retrieval
- In this key-value associative memory system, we can only add new key-value associations, without the ability to erase existing information. As sequences grow longer, this leads to accumulating "retrieval errors" that degrade performance.
- We can narrow the performance gap with standard attention in language modeling tasks by incorporating a forgetting mechanism
- $\mathbf{S}_t = \mathbf{G}_t \odot \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}_t^\top$, where $\mathbf{G}_t$ is a (typically data-dependent) forget gate
- There are multiple different structured parameterizations of $\mathbf{G}_t$ (for parameter efficiency), often with an outer-product structure; a minimal scalar-decay sketch follows this list
- (decaying fast weight)
- (GLA)
- (Mamba)
- (Mamba 2)
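As a minimal sketch of the forgetting idea in code (only the simplest scalar-decay special case, with made-up gates; the parameterizations listed above differ in how $\mathbf{G}_t$ is structured and predicted from the input):

import numpy as np

def gated_linear_attn_recurrent(Q, K, V, g):
    # S_t = g_t * S_{t-1} + v_t k_t^T ,  o_t = S_t q_t ;  g_t in (0,1): near 0 erases, near 1 retains
    T, d = Q.shape
    S = np.zeros((d, d))
    O = np.zeros_like(V)
    for t in range(T):
        S = g[t] * S + np.outer(V[t], K[t])   # decay old associations, then write the new one
        O[t] = S @ Q[t]
    return O

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 10, 6))
g = rng.uniform(0.8, 1.0, size=10)            # in practice g_t would be predicted from the token
print(gated_linear_attn_recurrent(Q, K, V, g).shape)   # (10, 6)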
DeltaNet: Linear Attention with Delta Rule
What is the Delta Rule
- A very simple error-correction learning principle: adjust the model's parameters based on the difference (delta) between what we want (the target) and what we actually get (the prediction).
- Imagine teaching a child to aim at a target: if they shoot too far to the left, you'd tell them to adjust right; too far to the right, adjust left.
- The size of the adjustment depends on:
  - the size of the delta
  - the magnitude of the input itself (in the linear regression case)
Pseudocode
import numpy as np

def delta_rule(x, y, epochs=100, lr=0.1):
    """
    Simple delta rule implementation
    x: input features (N samples by D features)
    y: target values (N samples)
    """
    # Initialize weights
    w = np.zeros(x.shape[1])
    # Train
    for _ in range(epochs):
        for i in range(len(x)):
            # Forward pass
            pred = np.dot(x[i], w)
            # Compute error
            error = y[i] - pred
            # Update weights
            w += lr * error * x[i]
    return w
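A quick usage check on synthetic data (illustrative only): the learned weights should approach the true ones.

# fit y = x @ w_true with the delta rule on noiseless synthetic data
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 3))
w_true = np.array([2.0, -1.0, 0.5])
y = x @ w_true
w = delta_rule(x, y, epochs=50, lr=0.1)
print(np.round(w, 3))   # close to [ 2. -1.  0.5]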
DeltaNet
DeltaNet applies this error-correction principle to linear attention. Instead of simply accumulating key-value outer products, it updates its state based on prediction errors:
\begin{align*} \mathbf{S}_{t} &= \mathbf{S}_{t-1} - \beta_t(\mathbf{S}_{t-1} \mathbf{k}_t - \mathbf{v}_t)\mathbf{k}_t^\top \\ &= \mathbf{S}_{t-1} - \beta_t \mathbf{S}_{t-1} \mathbf{k}_t \mathbf{k}_t^\top + \beta_t \mathbf{v}_t \mathbf{k}_t^\top \end{align*}
- The parallel to the Delta Rule becomes clear when we break down the components:
  - $\beta_t$ acts as the learning rate
  - $\mathbf{k}_t$ is the input data
  - $\mathbf{v}_t$ is the target
  - $\mathbf{S}_{t-1} \mathbf{k}_t$ is our current prediction (trying to retrieve $\mathbf{v}_t$ from the state matrix)
- Think of $\mathbf{v}_t^{\text{old}} = \mathbf{S}_{t-1} \mathbf{k}_t$ as retrieving the "old value" associated with the current key from memory. When we encounter a newly associated value for the same key, rather than blindly overwriting, we make a careful update: \begin{align*} \mathbf{v}_t^{\text{new}} &= (1-\beta_t) \mathbf{v}_t^{\text{old}} + \beta_t \mathbf{v}_t, \\ \mathbf{S}_t &= \mathbf{S}_{t-1} - \underbrace{\mathbf{v}_t^{\text{old}} \mathbf{k}_t^\top}_{\text{erase}} + \underbrace{\mathbf{v}_t^{\text{new}} \mathbf{k}_t^\top}_{\text{write}} \end{align*}
- where $\mathbf{v}_t^{\text{new}}$ is a learned combination of the old and current values, controlled by a dynamic $\beta_t \in [0, 1]$: when $\beta_t = 0$, the memory content remains intact, and when $\beta_t = 1$, we completely replace the old associated value with the new one (a numpy sketch follows below).
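A minimal numpy sketch of this recurrence (illustrative shapes; $\beta_t$ is taken as given rather than predicted from the input), written in the erase-then-write form above:

import numpy as np

def deltanet_recurrent(Q, K, V, beta):
    # S_t = S_{t-1} - beta_t (S_{t-1} k_t - v_t) k_t^T ,  o_t = S_t q_t
    T, d = Q.shape
    S = np.zeros((d, d))
    O = np.zeros_like(V)
    for t in range(T):
        v_old = S @ K[t]                                   # value currently associated with k_t
        v_new = (1 - beta[t]) * v_old + beta[t] * V[t]     # interpolate old and incoming values
        S = S - np.outer(v_old, K[t]) + np.outer(v_new, K[t])   # erase, then write
        O[t] = S @ Q[t]
    return O

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 10, 6))
beta = rng.uniform(0.0, 1.0, size=10)
print(deltanet_recurrent(Q, K, V, beta).shape)   # (10, 6)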
DeltaNet as the gradient update for MSE
- DeltaNet's update rule can be derived by sequentially minimizing, at each time step, the mean squared error (MSE) between the desired output and the predicted output using gradient descent: \begin{align*} \mathcal{L}_t(\mathbf{S}) = \frac{1}{2} \left\| \mathbf{S} \mathbf{k}_t - \mathbf{v}_t \right\|^2 \end{align*}
- Applying gradient descent to minimize this MSE loss gives: \begin{align*} \mathbf{S}_t = \mathbf{S}_{t-1} - \eta_t \nabla_{\mathbf{S}} \mathcal{L}_t(\mathbf{S}_{t-1}) = \mathbf{S}_{t-1} - \eta_t (\mathbf{S}_{t-1} \mathbf{k}_t - \mathbf{v}_t) \mathbf{k}_t^\top \end{align*}
- When the learning rate is set to $\eta_t = \beta_t$, we recover DeltaNet's update rule.
- In contrast, vanilla linear attention employs a linear loss function: \begin{align*} \mathcal{L}_t(\mathbf{S}) = -\langle \mathbf{S} \mathbf{k}_t, \mathbf{v}_t \rangle \end{align*} whose gradient step recovers the purely additive update $\mathbf{S}_t = \mathbf{S}_{t-1} + \eta_t \mathbf{v}_t \mathbf{k}_t^\top$.
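A quick numerical sanity check of this claim (toy sizes): the finite-difference gradient of $\frac{1}{2}\|\mathbf{S}\mathbf{k}_t - \mathbf{v}_t\|^2$ matches $(\mathbf{S}\mathbf{k}_t - \mathbf{v}_t)\mathbf{k}_t^\top$, so one gradient step with learning rate $\beta_t$ reproduces the DeltaNet update.

import numpy as np

rng = np.random.default_rng(0)
d, beta, eps = 5, 0.3, 1e-6
S, k, v = rng.normal(size=(d, d)), rng.normal(size=d), rng.normal(size=d)
loss = lambda S_: 0.5 * np.sum((S_ @ k - v) ** 2)

# central-difference gradient of the MSE loss w.r.t. every entry of S
num_grad = np.zeros_like(S)
for i in range(d):
    for j in range(d):
        E = np.zeros_like(S)
        E[i, j] = eps
        num_grad[i, j] = (loss(S + E) - loss(S - E)) / (2 * eps)

# one gradient step with lr = beta equals the DeltaNet update S - beta (S k - v) k^T
assert np.allclose(S - beta * num_grad, S - beta * np.outer(S @ k - v, k), atol=1e-5)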