https://rockt.github.io/2018/04/30/einsum

Let's say you want to multiply a 3x2 matrix A by a 2x5 matrix B. The result is 3x5. You can then write

torch.einsum("ij,jk->ik", a, b)

The index that does not appear in the output (here j) gets "collapsed", i.e. summed over. You can also think of it this way: a dimension missing from the output has been reduced to size 1 by the sum and then dropped.
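
To make the "summed over" part concrete, here is a minimal pure-Python sketch of what "ij,jk->ik" computes, written as explicit loops. The helper name matmul_einsum and the sample values are made up for illustration; this is not how torch implements einsum.

```python
def matmul_einsum(a, b):
    """Loop-based equivalent of einsum("ij,jk->ik") for nested lists (illustrative helper)."""
    n_i, n_j, n_k = len(a), len(b), len(b[0])
    assert all(len(row) == n_j for row in a), "inner dimensions must match"
    out = [[0] * n_k for _ in range(n_i)]
    for i in range(n_i):          # i appears in the output: kept
        for k in range(n_k):      # k appears in the output: kept
            for j in range(n_j):  # j is absent from the output: summed over
                out[i][k] += a[i][j] * b[j][k]
    return out


a = [[1, 2], [3, 4], [5, 6]]             # 3x2
b = [[1, 0, 2, 0, 1], [0, 1, 0, 2, 1]]   # 2x5
c = matmul_einsum(a, b)                  # 3x5
print(len(c), len(c[0]))
```

With PyTorch tensors of the same shapes, torch.einsum("ij,jk->ik", a, b) should give the same values as a @ b.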