Vanilla Softmax

Linear Attention

  • While removing softmax alone doesn’t immediately reduce computational complexity, it enables a crucial mathematical property: linearity.
  • Linearity gives us associativity ⇒ chunkwise parallel form for prefill
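
To make the associativity concrete, here is a minimal NumPy sketch (assuming an identity feature map and no normalization; function names and the chunk size C are illustrative): the quadratic form masks the full L x L score matrix, while the chunkwise form carries a running K^T V state between chunks and only materializes scores within each chunk.

import numpy as np

def linear_attn_quadratic(Q, K, V):
    """Causal linear attention (no softmax): mask the full L x L score matrix."""
    L = Q.shape[0]
    mask = np.tril(np.ones((L, L)))              # causal mask
    return (Q @ K.T * mask) @ V

def linear_attn_chunkwise(Q, K, V, C=16):
    """Chunkwise parallel form: attend within each chunk, and to previous
    chunks through a running state S = K^T V (d_k x d_v)."""
    L = Q.shape[0]
    O = np.zeros_like(V)
    S = np.zeros((K.shape[1], V.shape[1]))
    for s in range(0, L, C):
        q, k, v = Q[s:s+C], K[s:s+C], V[s:s+C]
        O[s:s+C] = q @ S                          # inter-chunk: history via the state
        mask = np.tril(np.ones((len(q), len(q))))
        O[s:s+C] += (q @ k.T * mask) @ v          # intra-chunk: small quadratic part
        S += k.T @ v                              # fold this chunk into the state
    return O

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 64, 8))
assert np.allclose(linear_attn_quadratic(Q, K, V), linear_attn_chunkwise(Q, K, V))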

Linear Attention as a linear RNN

  • For inference, we can rearrange the computation into a recurrent form

  • Let’s define a state matrix $\mathbf{S}_t = \sum_{j=1}^{t} \mathbf{v}_j \mathbf{k}_j^\top \in \mathbb{R}^{d_v \times d_k}$

    • Then we have a clear recurrent relationship: $\mathbf{S}_t = \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}_t^\top$, with output $\mathbf{o}_t = \mathbf{S}_t \mathbf{q}_t$
  • Linear attention is essentially a linear RNN with a matrix-valued state that accumulates key-value outer products, and keeps track of a compressed, fixed-size ($d_v \times d_k$) view of the history.

  • This is called a state size expansion from $d$ to $d \times d$ (each of the $d$ value channels gets its own $d$-dimensional key state)

  • By casting linear attention as an RNN,

    • we reduce the per-token inference cost from $O(Ld)$ (attending over all previous tokens) to $O(d^2)$
    • we reduce the space complexity from $O(Ld)$ (the KV cache) to $O(d^2)$ (the fixed-size state)
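
A minimal NumPy sketch of this recurrent view (illustrative names, single head and single sequence): the per-step cost and memory depend only on the head dimension, not on how many tokens have already been seen.

import numpy as np

def linear_attn_recurrent(Q, K, V):
    """Linear attention decoded as a linear RNN with a matrix-valued state."""
    L, d_k = K.shape
    d_v = V.shape[1]
    S = np.zeros((d_v, d_k))                     # fixed-size compressed history
    out = np.empty((L, d_v))
    for t in range(L):
        S = S + np.outer(V[t], K[t])             # S_t = S_{t-1} + v_t k_t^T
        out[t] = S @ Q[t]                        # o_t = S_t q_t
    return out

# Agrees with the quadratic causal form, but needs only O(d^2) state
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 32, 8))
ref = (np.tril(np.ones((32, 32))) * (Q @ K.T)) @ V
assert np.allclose(linear_attn_recurrent(Q, K, V), ref)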

Limitations of Linear Attention - defining retrieval error

  • The fixed-size state matrix in linear attention means it cannot perfectly preserve all historical information, making exact retrieval particularly challenging.

Retrieval error

  • Linear attention implements a key-value associative memory, which is the sum of outer products between keys and values: $\mathbf{S}_t = \sum_{j=1}^{t} \mathbf{v}_j \mathbf{k}_j^\top$

  • Assuming all keys are normalized to unit length, when we try to retrieve the value associated with a specific key $\mathbf{k}_i$ (with $i \le t$), the product between the state matrix and the key should ideally give us back $\mathbf{v}_i$ (because $\mathbf{k}_i^\top \mathbf{k}_i = 1$): $\mathbf{S}_t \mathbf{k}_i = \sum_{j=1}^{t} \mathbf{v}_j (\mathbf{k}_j^\top \mathbf{k}_i) = \mathbf{v}_i + \underbrace{\sum_{j \neq i} \mathbf{v}_j (\mathbf{k}_j^\top \mathbf{k}_i)}_{\text{retrieval error}}$

  • To minimize the retrieval error term, we need all the key vectors to be mutually orthogonal ($\mathbf{k}_j^\top \mathbf{k}_i = 0$ for $j \neq i$)

    • However, in a space of dimension $d$, we cannot define more than $d$ vectors that are all orthogonal to each other.
    • $d$ is the head dimension, and it has been shown that increasing the head dimension improves performance, as it gives more space for storing distinct key-value pairs
    • There is a tradeoff between increasing the head dimension and keeping chunked linear attention hardware friendly (we want tiles that are small enough to keep in registers)
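
A small numerical illustration of this retrieval error (NumPy only; the dimensions and random keys are made up for illustration): with at most $d$ orthonormal keys retrieval is exact, while packing many more unit-norm keys into the same $d$ dimensions makes the cross-terms leak into the retrieved value.

import numpy as np

rng = np.random.default_rng(0)
d = 16

def store(keys, values):
    """Key-value associative memory: S = sum_i v_i k_i^T."""
    return sum(np.outer(v, k) for k, v in zip(keys, values))

# d mutually orthogonal unit keys -> exact retrieval
keys, values = np.eye(d), rng.normal(size=(d, d))
S = store(keys, values)
print(np.abs(S @ keys[0] - values[0]).max())     # ~0

# 4d random unit keys in the same d dimensions -> large retrieval error
keys = rng.normal(size=(4 * d, d))
keys /= np.linalg.norm(keys, axis=1, keepdims=True)
values = rng.normal(size=(4 * d, d))
S = store(keys, values)
print(np.abs(S @ keys[0] - values[0]).max())     # dominated by sum_{j!=0} (k_j . k_0) v_j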

Gating or forgetting as a mechanism to improve retrieval

  • In this key-value associative memory system, we can only add new key-value associations without the ability to erase existing information. As sequences grow longer, this leads to accumulating β€œretrieval errors” that degrade performance.
  • We can narrow the performance gap with standard attention in language modeling tasks by incorporating a forgetting mechanism
    • $\mathbf{S}_t = \mathbf{G}_t \odot \mathbf{S}_{t-1} + \mathbf{v}_t \mathbf{k}_t^\top$, where $\mathbf{G}_t \in (0,1)^{d_v \times d_k}$ is a (typically data-dependent) forget gate applied elementwise
    • There are multiple different structured parameterizations of $\mathbf{G}_t$ (for parameter efficiency), often with an outer-product structure (a minimal sketch follows this list).
      • a per-step decay applied uniformly to the fast-weight state (decaying fast weight)
      • $\mathbf{G}_t = \mathbf{1}\boldsymbol{\alpha}_t^\top$, a per-key-dimension decay, i.e. $\mathbf{S}_t = \mathbf{S}_{t-1}\,\mathrm{Diag}(\boldsymbol{\alpha}_t) + \mathbf{v}_t \mathbf{k}_t^\top$ (GLA)
      • a decay derived from the discretized SSM transition $\exp(\boldsymbol{\Delta}_t \mathbf{A})$ (Mamba)
      • $\mathbf{G}_t = \gamma_t \mathbf{1}\mathbf{1}^\top$, a single data-dependent scalar decay per head (Mamba 2)
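
Here is a minimal NumPy sketch of the gated recurrence, using the simplest of the parameterizations above (a single data-dependent scalar decay per step); names and gate values are illustrative.

import numpy as np

def gated_linear_attn_recurrent(Q, K, V, gamma):
    """Gated linear attention RNN: decay the old state, then write the new association."""
    L, d_k = K.shape
    d_v = V.shape[1]
    S = np.zeros((d_v, d_k))
    out = np.empty((L, d_v))
    for t in range(L):
        S = gamma[t] * S + np.outer(V[t], K[t])  # S_t = gamma_t * S_{t-1} + v_t k_t^T
        out[t] = S @ Q[t]
    return out

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 32, 8))
gamma = 1.0 / (1.0 + np.exp(-rng.normal(size=32)))   # data-dependent gates in (0, 1)
out = gated_linear_attn_recurrent(Q, K, V, gamma)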

DeltaNet: Linear Attention with Delta Rule

What is the Delta Rule

  • Very simple error-correction learning principle
    • principle: adjust the model’s parameters based on the difference (delta) between what we want (target) and what we actually get (prediction).
  • Imagine teaching a child to aim at a target. If they shoot too far to the left, you’d tell them to adjust right; too far right, adjust left.
  • The size of the adjustment depends on
    • the delta size
    • the magnitude of the input itself (in the linear regression case)

Pseudocode

import numpy as np
 
def delta_rule(x, y, epochs=100, lr=0.1):
    """
    Simple delta rule implementation
    x: input features (N samples by D features)
    y: target values (N samples)
    """
    # Initialize weights
    w = np.zeros(x.shape[1])
    
    # Train
    for _ in range(epochs):
        for i in range(len(x)):
            # Forward pass
            pred = np.dot(x[i], w)
            
            # Compute error
            error = y[i] - pred
            
            # Update weights
            w += lr * error * x[i]
            
    return w
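
A quick usage check of the routine above, on an assumed noiseless linear target (the weight vector and data here are made up for illustration):

# Recover a known weight vector from noiseless data
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 3))
true_w = np.array([2.0, -1.0, 0.5])
y = x @ true_w
print(delta_rule(x, y))   # converges to approximately [2.0, -1.0, 0.5]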

DeltaNet

DeltaNet applies this error-correction principle to linear attention. Instead of simply accumulating key-value outer products, it updates its state based on prediction errors:

\begin{align*} \mathbf{S}_{t} &= \mathbf{S}_{t-1} - \beta_t(\mathbf{S}_{t-1} \mathbf{k}_t - \mathbf{v}_t)\mathbf{k}_t^\top \\ &= \mathbf{S}_{t-1} - \beta_t \mathbf{S}_{t-1} \mathbf{k}_t \mathbf{k}_t^\top + \beta_t \mathbf{v}_t \mathbf{k}_t^\top \end{align*}

  • The parallel to the Delta Rule becomes clear when we break down the components:

    • $\beta_t$ acts as the learning rate
    • $\mathbf{k}_t$ is the input data
    • $\mathbf{v}_t$ is the target
    • $\mathbf{S}_{t-1} \mathbf{k}_t$ is our current prediction (trying to retrieve the value for $\mathbf{k}_t$ from the state matrix)
  • Think of $\mathbf{v}_t^{\text{old}} = \mathbf{S}_{t-1} \mathbf{k}_t$ as retrieving the “old value” associated with the current key from memory. When we encounter a newly associated value for the same key, rather than blindly overwriting, we make a careful update: \begin{align*} \mathbf{v}_t^{\text{new}} &= (1-\beta_t) \mathbf{v}_t^{\text{old}} + \beta_t \mathbf{v}_t, \\ \mathbf{S}_t &= \mathbf{S}_{t-1} - \underbrace{\mathbf{v}_t^{\text{old}} \mathbf{k}_t^\top}_{\text{erase}} + \underbrace{\mathbf{v}_t^{\text{new}} \mathbf{k}_t^\top}_{\text{write}} \end{align*}

  • where $\mathbf{v}_t^{\text{new}}$ is a learned combination of the old and current values, controlled by a dynamic $\beta_t \in [0, 1]$: when $\beta_t = 0$, the memory content remains intact, and when $\beta_t = 1$, we completely replace the old associated value with the new one.
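
Putting this erase-and-write view into code, here is a minimal NumPy sketch of the DeltaNet recurrence (illustrative names; single head, single sequence):

import numpy as np

def deltanet_recurrent(Q, K, V, beta):
    """DeltaNet: delta-rule update of the matrix-valued state.

    S_t = S_{t-1} - beta_t (S_{t-1} k_t - v_t) k_t^T
    """
    L, d_k = K.shape
    d_v = V.shape[1]
    S = np.zeros((d_v, d_k))
    out = np.empty((L, d_v))
    for t in range(L):
        k, v = K[t], V[t]
        v_old = S @ k                                 # value currently stored under k_t
        v_new = beta[t] * v + (1 - beta[t]) * v_old   # interpolate old and new values
        S = S - np.outer(v_old, k) + np.outer(v_new, k)   # erase, then write
        out[t] = S @ Q[t]
    return out

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 32, 8))
beta = 1.0 / (1.0 + np.exp(-rng.normal(size=32)))     # beta_t in (0, 1)
out = deltanet_recurrent(Q, K, V, beta)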

DeltaNet as the gradient update for MSE

  • DeltaNet’s update rule can be derived by sequentially minimizing the mean squared error (MSE) between the desired output and the predicted output at each time step using gradient descent: $\mathcal{L}_t(\mathbf{S}) = \frac{1}{2} \lVert \mathbf{S} \mathbf{k}_t - \mathbf{v}_t \rVert^2$
  • Applying gradient descent to minimize this MSE loss gives: $\mathbf{S}_t = \mathbf{S}_{t-1} - \eta_t \nabla_{\mathbf{S}} \mathcal{L}_t(\mathbf{S}_{t-1}) = \mathbf{S}_{t-1} - \eta_t (\mathbf{S}_{t-1} \mathbf{k}_t - \mathbf{v}_t) \mathbf{k}_t^\top$

  • When the learning rate $\eta_t$ is set to $\beta_t$, we recover DeltaNet
  • In contrast, vanilla linear attention employs a linear loss function: $\mathcal{L}_t(\mathbf{S}) = -\langle \mathbf{S} \mathbf{k}_t, \mathbf{v}_t \rangle$, whose gradient step $\mathbf{S}_t = \mathbf{S}_{t-1} + \eta_t \mathbf{v}_t \mathbf{k}_t^\top$ only accumulates new associations and never corrects old ones
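
As a sanity check on this derivation (NumPy only; the dimensions and values are made up), we can verify the analytic gradient of the per-step MSE loss by finite differences and confirm that one gradient step with learning rate $\beta_t$ is exactly the DeltaNet update:

import numpy as np

rng = np.random.default_rng(0)
d, beta = 6, 0.3
S = rng.normal(size=(d, d))
k, v = rng.normal(size=d), rng.normal(size=d)

loss = lambda S_: 0.5 * np.sum((S_ @ k - v) ** 2)     # L_t(S) = 0.5 * ||S k_t - v_t||^2

# Finite-difference gradient of the loss w.r.t. every entry of S
eps, num_grad = 1e-6, np.zeros_like(S)
for i in range(d):
    for j in range(d):
        E = np.zeros_like(S)
        E[i, j] = eps
        num_grad[i, j] = (loss(S + E) - loss(S - E)) / (2 * eps)

analytic_grad = np.outer(S @ k - v, k)                # (S k_t - v_t) k_t^T
assert np.allclose(num_grad, analytic_grad, atol=1e-5)

S_next = S - beta * analytic_grad                     # one gradient step = DeltaNet update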