Assume that vectors are column vectors in the notation below

Computing safe softmax in an online fashion

Online algorithm for the normalization factor

  • Let $m_i$ be the max value seen up to step $i$ and $d_i$ the running sum at step $i$

  • At step 0

    • $m_0 = x_0$, $d_0 = e^{x_0 - m_0} = 1$
  • At step 1

    • if $x_1 > m_0$: $m_1 = x_1$, otherwise $m_1 = m_0$; either way $d_1 = d_0\,e^{m_0 - m_1} + e^{x_1 - m_1}$
  • At step $i$,

    • By induction, we assume $d_{i-1} = \sum_{j=0}^{i-1} e^{x_j - m_{i-1}}$
    • then $m_i = \max(m_{i-1}, x_i)$ and $d_i = d_{i-1}\,e^{m_{i-1} - m_i} + e^{x_i - m_i}$ (a small sketch follows this list)
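
A minimal runnable sketch of the update above, assuming NumPy; the function name and the 1-D input are illustrative, not from any library:

import numpy as np

def online_softmax(x):
    m = -np.inf  # running max m_i
    d = 0.0      # running sum d_i = sum_j exp(x_j - m_i)
    for x_i in x:
        m_new = max(m, x_i)
        # rescale the old sum to the new max, then add the new term
        d = d * np.exp(m - m_new) + np.exp(x_i - m_new)
        m = m_new
    # the final softmax uses the global max m and the normalizer d
    return np.exp(x - m) / d

x = np.random.randn(16)
assert np.allclose(online_softmax(x), np.exp(x - x.max()) / np.exp(x - x.max()).sum())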

Computing $O$ without materializing $S$

  • Using Block Matrix Multiplication, the high-level idea is that you first compute a block of $S = QK^\top$ (ideally the shape of the block should be independent of the sequence length, e.g. (c,c)), and then directly reuse the resulting block to compute $O$ with the corresponding block of $V$ (see the sketch below).
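
A rough sketch of the blocking idea, assuming NumPy (softmax omitted here; the next sections add it). The names Q, K, V and the block size c are illustrative:

import numpy as np

T, dim, c = 8, 4, 2
Q, K, V = (np.random.randn(T, dim) for _ in range(3))

O = np.zeros((T, dim))
for i in range(0, T, c):      # block of c query rows
    for j in range(0, T, c):  # block of c key/value rows
        S_blk = Q[i:i+c] @ K[j:j+c].T  # (c,c) block of S, never stored globally
        O[i:i+c] += S_blk @ V[j:j+c]   # consumed immediately against the V block

assert np.allclose(O, (Q @ K.T) @ V)   # same result as materializing S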

Online block softmax computation

  • For the sake of simplicity, let’s say we have blocks of shape (c,dim) of $Q$ and of shape (dim,c) for $K^\top$

    • In this case, we’re selecting entire rows of $Q$ as blocks (called $Q_i$), and entire columns of $K^\top$ as blocks (called $K_j^\top$)
  • We’ll call $S = QK^\top$ the score matrix of shape (T,T)

    • Assuming we do block-matrix MM, we get that $S_{ij} = Q_i K_j^\top$ of shape (c,c) (note that there’s no accumulation, because the blocks span entire rows/columns)
  • Let’s call $\tilde{P} = \exp(S - \mathrm{rowmax}(S))$, which is softmax without the normalization (the max is still taken over each row vector)

  • Let’s call $\tilde{P}_{ij} = \exp(S_{ij} - \mathrm{rowmax}(S_{ij}))$, where the softmax is applied over each block independently, i.e. each block uses its own local row maximum

  • Then we do a block MM with $V$, with blocks $V_j$ of shape (c,dim), giving us block rows of shape (c,dim) for the output matrix, i.e. $O_i = \sum_j \tilde{P}_{ij} V_j$

  • Now obviously, each $\tilde{P}_{ij}$ in that sum uses a different local maximum. However, we can reuse our Computing safe softmax in an online fashion idea, communicate the maximum of each row in a block, and just readjust to the current maximum as we accumulate

    • We can compute the normalization factor in the same online fashion
    • Does this require the accumulation to be non-parallel, i.e. one-by-one?
      • No: because the combined update (take the max of the row maxima, rescale the partial sums accordingly) is an associative operator, we can do a parallel-prefix sum (see the sketch after this list)
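
A small sketch of the associativity claim, assuming NumPy; combine and the (max, sum) pair representation are illustrative, not from any library:

import numpy as np
from functools import reduce

def combine(a, b):
    m_a, d_a = a
    m_b, d_b = b
    m = max(m_a, m_b)
    # rescale both partial sums to the shared max before adding them
    return m, d_a * np.exp(m_a - m) + d_b * np.exp(m_b - m)

x = np.random.randn(8)
parts = [(x_i, 1.0) for x_i in x]  # each single element is a trivial (max, sum) pair

sequential = reduce(combine, parts)                                     # one-by-one accumulation
tree = combine(reduce(combine, parts[:4]), reduce(combine, parts[4:]))  # parallel-friendly grouping
assert np.isclose(sequential[0], tree[0]) and np.isclose(sequential[1], tree[1])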

The pseudo-code with for-loops to make it simpler

For each output block o_j:
	o_j = [0,0,...,0] # shape (c,dim)
	curr_max = [-infinity,...,-infinity] # shape (c,1)
	normalization = [0,...,0] # shape (c,1)
	for i in range(T/c):
		## compute the score block: query block Q_j against key block K_i
		S_ji = Q_j K_i^T # shape (c,c)
		## update the running row max; the correction factor rescales the old accumulators
		prev_max = curr_max
		curr_max = max(prev_max, rowmax(S_ji))
		correction_factor = exp(prev_max - curr_max)
		P_ji = exp(S_ji - curr_max)
		## correcting the sums
		normalization = normalization*correction_factor + rowsum(P_ji)
		o_j = o_j*correction_factor + P_ji*V_i

	## normalize at the end
	o_j = o_j / normalization
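
A runnable NumPy version of the loop above, checked against plain softmax(QK^T)V; the block size c and all names are illustrative assumptions, not FA2's actual kernel:

import numpy as np

def block_attention(Q, K, V, c):
    T, dim = Q.shape
    O = np.zeros((T, dim))
    for j in range(0, T, c):                          # one query/output block per iteration
        o = np.zeros((c, dim))
        curr_max = np.full((c, 1), -np.inf)
        norm = np.zeros((c, 1))
        for i in range(0, T, c):                      # stream over key/value blocks
            S = Q[j:j+c] @ K[i:i+c].T                 # (c,c) score block
            prev_max = curr_max
            curr_max = np.maximum(prev_max, S.max(axis=1, keepdims=True))
            corr = np.exp(prev_max - curr_max)        # rescales the old accumulators
            P = np.exp(S - curr_max)
            norm = norm * corr + P.sum(axis=1, keepdims=True)
            o = o * corr + P @ V[i:i+c]
        O[j:j+c] = o / norm                           # normalize once at the end
    return O

T, dim, c = 8, 4, 2
Q, K, V = (np.random.randn(T, dim) for _ in range(3))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(block_attention(Q, K, V, c), ref)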
 

Actual FA2 pseudocode (with more comments on loading between SRAM and HBM)