The task

  • Assume for the moment 1 head, no batch dimension
    • This is okay because heads and batch elements are fully independent, i.e. “embarrassingly parallel”
  • We have $Q, K, V$ of shape $N \times d$, seq. length $N$, head dimension $d$
  • $S = QK^\top \in \mathbb{R}^{N \times N}$, $O = \operatorname{softmax}(S)\,V \in \mathbb{R}^{N \times d}$
  • How do we parallelize this and do it in one go (a single fused pass)?
    • $S$ is an $N \times N$ intermediate, can we avoid materializing it? (See the sketch after this list.)
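
Below is a minimal NumPy sketch of both options (an illustration of the idea, not an actual fused kernel): `naive_attention` materializes the full $N \times N$ matrix $S$, while `streaming_attention` uses the online-softmax recurrence to visit $K$ and $V$ in blocks, carrying only $O(Nd)$ running state. The `block` size is an arbitrary choice, and the usual $1/\sqrt{d}$ scaling is omitted to match the formulas above.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference attention: materializes the full N x N score matrix."""
    S = Q @ K.T                                   # (N, N) intermediate
    P = np.exp(S - S.max(axis=1, keepdims=True))  # stable row-wise softmax
    P /= P.sum(axis=1, keepdims=True)
    return P @ V                                  # (N, d)

def streaming_attention(Q, K, V, block=128):
    """Online-softmax attention: never materializes S.

    Visits K/V in blocks, carrying only a running row max m, a running
    softmax denominator l, and the output accumulator O.
    """
    N, d = Q.shape
    O = np.zeros((N, d))
    m = np.full((N, 1), -np.inf)  # running row-wise max of scores
    l = np.zeros((N, 1))          # running row-wise sum of exp(scores)
    for j in range(0, N, block):
        S_blk = Q @ K[j:j + block].T                     # (N, block) tile
        m_new = np.maximum(m, S_blk.max(axis=1, keepdims=True))
        scale = np.exp(m - m_new)                        # rescale old stats
        P_blk = np.exp(S_blk - m_new)
        l = l * scale + P_blk.sum(axis=1, keepdims=True)
        O = O * scale + P_blk @ V[j:j + block]
        m = m_new
    return O / l

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((1024, 64)) for _ in range(3))
assert np.allclose(naive_attention(Q, K, V), streaming_attention(Q, K, V))
```

Both versions agree up to floating-point error, which is the point: the $N \times N$ intermediate was never needed.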

Thoughts

  • Attention is like a 2-layer network (matmul, softmax, matmul), but the head dimension $d$ is very small relative to the sequence length $N$ (see the arithmetic sketch below)
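
To make “very small” concrete, a back-of-the-envelope sketch with illustrative sizes $N = 8192$, $d = 64$ in fp32 (assumed here, not given in the notes):

```python
# Illustrative sizes, assumed for this estimate: one head, fp32.
N, d, itemsize = 8192, 64, 4

def mib(n_elems):
    return n_elems * itemsize / 2**20

print(f"S = QK^T (N x N): {mib(N * N):7.1f} MiB")       # 256.0 MiB per head
print(f"Q, K, V  (N x d): {mib(N * d):7.1f} MiB each")  # 2.0 MiB each
```

The “hidden layer” of this 2-layer network is $N$ wide, so materializing it costs $N/d$ times the memory of the inputs (128x here), which is exactly what the streaming formulation above avoids.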