Assume that all vectors are column vectors in the notation below
Computing safe softmax in online fashion
Online algorithm for the normalization factor
- Safe softmax over a vector $x$: $\text{softmax}(x)_k = \frac{e^{x_k - m}}{\sum_j e^{x_j - m}}$ where $m = \max_j x_j$
- The goal is to compute the denominator in a single pass over $x$
Let $m_i$ be the max value at step $i$ and $d_i$ the running sum at step $i$
- Invariant to maintain: $d_i = \sum_{j \le i} e^{x_j - m_i}$
At step 0
- $m_0 = x_0$, $d_0 = 1$
- (since $e^{x_0 - m_0} = e^0 = 1$)
At step 1
- if $x_1 > m_0$: $m_1 = x_1$, $d_1 = d_0\,e^{m_0 - m_1} + 1$
- else: $m_1 = m_0$, $d_1 = d_0 + e^{x_1 - m_1}$
At step $i$,
- By induction, we assume $d_{i-1} = \sum_{j \le i-1} e^{x_j - m_{i-1}}$
- then $m_i = \max(m_{i-1}, x_i)$ and $d_i = d_{i-1}\,e^{m_{i-1} - m_i} + e^{x_i - m_i}$
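As a sanity check, the recurrence above can be written directly in code. A minimal NumPy sketch (function names are mine, not from any library):

```python
import numpy as np

def online_softmax_denominator(x):
    """One pass over x, maintaining the running max m and the
    rescaled running sum d = sum_j exp(x_j - m)."""
    m = -np.inf
    d = 0.0
    for xi in x:
        m_new = max(m, xi)
        # rescale the old sum to the new max, then add the new term
        d = d * np.exp(m - m_new) + np.exp(xi - m_new)
        m = m_new
    return m, d

def safe_softmax(x):
    x = np.asarray(x)
    m, d = online_softmax_denominator(x)
    return np.exp(x - m) / d
```

Note that on the first step, `exp(-inf - m_new)` evaluates to 0, so the `d_0 = 1` base case falls out of the general update.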
Computing $\text{softmax}(QK^T)V$ without materializing $QK^T$
- Using Block Matrix Multiplication, the high-level idea is that you first compute a block of $QK^T$ (ideally the shape of the block should be independent of the sequence length, e.g. $(c, c)$), and then directly reuse the resulting block to compute the output with the corresponding block of $V$.
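To illustrate just the block-matrix part (no softmax yet), here is a minimal NumPy sketch; `c` is an assumed block size that divides the sequence length `T`:

```python
import numpy as np

def blockwise_qkt_v(Q, K, V, c):
    """Compute (Q @ K.T) @ V without materializing the full (T, T)
    score matrix: only one (c, c) score block exists at a time."""
    T, dim = Q.shape
    O = np.zeros((T, dim))
    for j in range(0, T, c):          # rows of Q -> rows of O
        for i in range(0, T, c):      # columns of K^T, rows of V
            S_block = Q[j:j+c] @ K[i:i+c].T   # (c, c) score block
            O[j:j+c] += S_block @ V[i:i+c]    # consume it immediately
    return O
```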
Online block softmax computation
For the sake of simplicity, let's say we have blocks of shape $(c, \text{dim})$ of $Q$ and of shape $(\text{dim}, c)$ for $K^T$
- In this case, we're selecting entire rows of $Q$ as blocks (called $Q_i$), and entire columns of $K^T$ as blocks (called $K_j^T$)
We'll call $S = QK^T$ the score matrix of shape $(T, T)$
- Assuming we do block-matrix MM, we get that $S_{ij} = Q_i K_j^T$ of shape $(c, c)$ (note that there's no accumulation, because the blocks span entire rows/columns)
Let's call $P = e^{S - \text{rowmax}(S)}$, which is softmax without the normalization (still over a row vector)
Let's call $\tilde{P}_{ij} = e^{S_{ij} - \text{rowmax}(S_{ij})}$, where the softmax is applied over each block independently, i.e. each block uses its own local row maximum
Then we do a block MM with $V$, with blocks of shape $(c, \text{dim})$, giving us block rows of shape $(c, \text{dim})$ for the output matrix, i.e. $O_i = \sum_j \tilde{P}_{ij} V_j$
Now obviously, each $\tilde{P}_{ij}$ in the sum uses a different local maximum. However, we can reuse our Computing safe softmax in online fashion idea: communicate the maximum of each row in a block, and just readjust the current maximum as we accumulate
- We can compute the normalization factor in the same online fashion
- Does this require the accumulation to be non-parallel, i.e. one-by-one?
- No: because merging running $(\max, \text{sum})$ pairs is an associative operation, we can do a parallel-prefix sum
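A minimal sketch of that merge operator on $(\max, \text{sum})$ pairs (the name `combine` is mine). Because it is associative, the per-block statistics can be reduced in any order, e.g. as a tree:

```python
import numpy as np

def combine(a, b):
    """Merge two (running max, rescaled running sum) pairs.
    Associative, so any reduction order gives the same result."""
    m_a, d_a = a
    m_b, d_b = b
    m = max(m_a, m_b)
    # rescale both sums to the common max before adding
    return m, d_a * np.exp(m_a - m) + d_b * np.exp(m_b - m)
```

A singleton element $x_j$ enters the reduction as the pair $(x_j, 1)$, since $e^{x_j - x_j} = 1$.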
The pseudo-code with for-loops to make it simpler
For each output block o_j:
    o_j = [0,0,...,0]              # shape (c,dim)
    curr_max = [-inf,...,-inf]     # shape (c,1)
    normalization = [0,...,0]      # shape (c,1)
    for i in range(T/c):
        ## compute the score block
        S_ji = Q_j K_i^T
        ## update the running max and rescale the old accumulators
        new_max = max(curr_max, rowmax(S_ji))
        correction_factor = exp(curr_max - new_max)
        curr_max = new_max
        P_ji = exp(S_ji - curr_max)
        ## correcting the sums
        normalization = normalization*correction_factor + rowsum(P_ji)
        o_j = o_j*correction_factor + P_ji V_i
    ## normalize at the end
    o_j = o_j / normalization
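Putting it all together, a runnable NumPy version of the pseudo-code above, checked against a naive reference (the function names are mine; assumes the block size `c` divides `T`):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materialize the full (T, T) score matrix."""
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def block_attention(Q, K, V, c):
    """Online-softmax block attention: one (c, c) score block at a time."""
    T, dim = Q.shape
    O = np.zeros((T, dim))
    for j in range(0, T, c):
        o = np.zeros((c, dim))
        m = np.full((c, 1), -np.inf)   # running row max
        d = np.zeros((c, 1))           # running row normalizer
        for i in range(0, T, c):
            S = Q[j:j+c] @ K[i:i+c].T                       # (c, c)
            m_new = np.maximum(m, S.max(axis=1, keepdims=True))
            corr = np.exp(m - m_new)   # rescale old accumulators
            P = np.exp(S - m_new)
            d = d * corr + P.sum(axis=1, keepdims=True)
            o = o * corr + P @ V[i:i+c]
            m = m_new
        O[j:j+c] = o / d               # normalize once at the end
    return O
```

The `exp(m - m_new)` correction with `m = -inf` evaluates to 0 on the first inner iteration, so no special-casing of the base case is needed.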