

QKV proj

  • qkv_proj is a TensorParallelColumnLinear
    • d_qk = d_model // num_attention_heads
    • in_features = d_model
    • out_features = num_attention_heads * d_qk + 2 * num_key_value_heads * d_qk
      • support for MQA/GQA Self-Attention (to shrink the kv cache size)
        • MQA: num_key_value_heads=1
    • to output cleanly separated Q, K and V chunks over which we can apply attention easily, we must define contiguous_chunks = (num_attention_heads * d_qk, num_key_value_heads * d_qk, num_key_value_heads * d_qk)
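
To make the effect of the chunk definition concrete, here is a minimal sketch in plain Python of which slices of the fused output dimension each TP rank would own. The helper `shard_ranges` and the sizes are hypothetical and only illustrate the idea; they are not the library's implementation.

```python
def shard_ranges(contiguous_chunks, tp_size, rank):
    """Return the (start, end) ranges of the full output dimension owned by `rank`.

    Each chunk is split evenly across ranks, so every rank gets a slice of Q,
    a slice of K and a slice of V instead of one block of the fused output.
    """
    ranges = []
    offset = 0
    for chunk in contiguous_chunks:
        assert chunk % tp_size == 0, "each chunk must be divisible by the TP size"
        per_rank = chunk // tp_size
        start = offset + rank * per_rank
        ranges.append((start, start + per_rank))
        offset += chunk
    return ranges


# Illustrative sizes (assumed, not taken from any specific config).
d_model, num_attention_heads, num_key_value_heads = 4096, 32, 8
d_qk = d_model // num_attention_heads
contiguous_chunks = (
    num_attention_heads * d_qk,
    num_key_value_heads * d_qk,
    num_key_value_heads * d_qk,
)
out_features = sum(contiguous_chunks)
rank1 = shard_ranges(contiguous_chunks, tp_size=4, rank=1)
print(out_features, rank1)
```

Without the chunk definition, an even split of the 6144 output features over 4 ranks would give the first ranks only Q rows and the last rank all of K and V, which is why the chunks are needed.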


  • o_proj is a TensorParallelRowLinear(d_model,d_model)


  • Is checkpointed using decorator @checkpoint_method(attr_name="checkpoint_attention")
  • because it sits in between qkv_proj (ColumnLinear) and o_proj (RowLinear), there’s no need to think about anything distributed
  • just usual attention (code is still not great to look at because we’re using flash_attn_varlen_func)


  • need to pass TP mode (all_reduce or reduce_scatter) and TP process group

  • All complexity is handled by TensorParallel paradigms

  • In Llama, they want to use GLUActivation for the up_proj, which is basically activation & gating.

    • def forward(merged_states):
      • gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1)
      • return self.act(gate_states) * up_states
  • Thus you need to define contiguous_chunks = (intermediate_size, intermediate_size) such that the nn.Linear is sharded correctly.

  • gate_up_proj = TensorParallelColumnLinear(hidden_size, 2*intermediate_size, contiguous_chunks, pg=tp_pg)

  • down_proj = TensorParallelRowLinear(intermediate_size, hidden_size, pg=tp_pg)

  • forward(x) = down_proj(GLUActivation(gate_up_proj(x)))
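
The following is a small plain-Python simulation of this MLP, offered as a sketch rather than the actual implementation. It checks that sharding gate_up_proj by contiguous chunks, applying the GLU locally, and summing the row-sharded down_proj outputs (standing in for the all-reduce) reproduces the unsharded result. For readability the weights are stored as (in, out), unlike nn.Linear; SiLU is assumed as the activation, and all sizes and values are made up.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def silu(v):
    return v / (1.0 + math.exp(-v))


def glu(merged):
    """Split the last dimension in two halves and return act(gate) * up."""
    half = len(merged[0]) // 2
    return [[silu(g) * u for g, u in zip(row[:half], row[half:])] for row in merged]


hidden_size, intermediate_size, tp_size = 2, 4, 2

# Assumed toy weights, stored as (in, out) so that y = x @ w.
w_gate_up = [[0.1 * ((7 * i + 3 * j) % 5 - 2) for j in range(2 * intermediate_size)]
             for i in range(hidden_size)]
w_down = [[0.1 * ((5 * i + 2 * j) % 7 - 3) for j in range(hidden_size)]
          for i in range(intermediate_size)]
x = [[1.0, -2.0]]

# Unsharded reference: down_proj(GLUActivation(gate_up_proj(x)))
reference = matmul(glu(matmul(x, w_gate_up)), w_down)

per_rank = intermediate_size // tp_size
partials = []
for rank in range(tp_size):
    gate_cols = list(range(rank * per_rank, (rank + 1) * per_rank))
    up_cols = [intermediate_size + c for c in gate_cols]
    # Column shard: this rank's slice of the gate chunk, then its slice of the up chunk.
    local_gate_up = [[row[c] for c in gate_cols + up_cols] for row in w_gate_up]
    # Row shard of down_proj: the rows matching this rank's intermediate features.
    local_down = [w_down[r] for r in gate_cols]
    partials.append(matmul(glu(matmul(x, local_gate_up)), local_down))

# Summing the partial outputs stands in for the all-reduce.
reduced = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
print(reference, reduced)
```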


  • Diagram
  • Signature __init__(in_features, out_features, pg, mode, ...)
    • init under the hood as nn.Linear(in_features, out_features // tp_pg.size())
    • automatic sharding
  • How we mark the parameters as sharded
    • define split_dim=0 for sharding by the column (the output dimension), because a module = nn.Linear(in, out) stores its weight matrix as W of shape (out, in), such that y = x W^T + b
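
As a rough illustration of column sharding, the sketch below uses plain Python lists with the weight stored as (out_features, in_features), the same layout nn.Linear uses. Splitting along dimension 0 gives each rank a block of output features, and concatenating the local outputs recovers the full output. The sizes and values are assumptions, and the concatenation only stands in for a gather that a real setup may not need if a row-parallel layer follows.

```python
def linear(x, weight):
    """y = x @ weight.T, with weight stored as (out_features, in_features)."""
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in weight] for row in x]


in_features, out_features, tp_size = 3, 4, 2
weight = [[float(o * in_features + i) for i in range(in_features)]
          for o in range(out_features)]
x = [[1.0, 0.5, -1.0]]

full = linear(x, weight)

per_rank = out_features // tp_size
# split_dim=0: each rank keeps a block of weight rows, i.e. a block of output features.
shards = [weight[r * per_rank:(r + 1) * per_rank] for r in range(tp_size)]
local_outputs = [linear(x, shard) for shard in shards]

# Concatenate the per-rank outputs along the feature dimension.
gathered = [sum((out[b] for out in local_outputs), []) for b in range(len(x))]
print(full, gathered)
```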


  • Diagram
  • Signature __init__(in_features, out_features, pg, mode, ...)
    • init under the hood as nn.Linear(in_features // tp_pg.size(), out_features)
    • automatic sharding
  • How we mark the parameters as sharded
    • define split_dim=1 for sharding by the row (the input dimension), because a module = nn.Linear(in, out) stores its weight matrix as W of shape (out, in), such that y = x W^T + b
  • No need to shard the bias term, only rank 0 would have it
    • i.e. bias = dist.get_rank(pg) == 0 and bias when initializing the linear layer
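
A minimal plain-Python sketch of the row-parallel case, with made-up sizes and values: each rank holds a slice of the input features and of the matching weight columns, and the partial outputs are summed (standing in for the all-reduce). Because the partials are summed, the bias has to be added on one rank only, otherwise it would be counted tp_size times.

```python
def linear(x, weight, bias=None):
    """y = x @ weight.T (+ bias), with weight stored as (out_features, in_features)."""
    out = [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in weight] for row in x]
    if bias is not None:
        out = [[v + b for v, b in zip(row, bias)] for row in out]
    return out


in_features, out_features, tp_size = 4, 2, 2
weight = [[1.0, 2.0, 3.0, 4.0], [-1.0, 0.0, 1.0, 2.0]]
bias = [0.5, -1.0]
x = [[1.0, 2.0, 3.0, 4.0]]

full = linear(x, weight, bias)

per_rank = in_features // tp_size
partials = []
for rank in range(tp_size):
    cols = slice(rank * per_rank, (rank + 1) * per_rank)
    local_weight = [row[cols] for row in weight]  # split_dim=1: slice the input dimension
    local_x = [row[cols] for row in x]            # the input is sharded the same way
    local_bias = bias if rank == 0 else None      # only rank 0 keeps the bias
    partials.append(linear(local_x, local_weight, local_bias))

# Summing the partial outputs stands in for the all-reduce.
reduced = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
print(full, reduced)
```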


  • basically a TensorParallelRowLinear in how it functions
  • interestingly, split_dim=0, because nn.Embedding stores its embedding matrix in the natural order row x column
  • at forward time,
    • just need to keep track of which part of the vocab a given process contains, and given the input, mask out the tokens not present in the current vocab
    • The differentiable_reduce_scatter_sum or all_reduce will work well, because each token is contained in exactly one rank's vocab shard, so every other rank contributes zeros for it
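
The forward pass described above can be sketched in plain Python as follows. This is an illustration with an assumed toy vocabulary, not the library's code: each rank owns a contiguous range of token ids, returns zeros for tokens outside its range, and the sum over ranks (standing in for the all-reduce) recovers the usual embedding lookup.

```python
vocab_size, embedding_dim, tp_size = 8, 2, 2
# Toy embedding matrix of shape (vocab_size, embedding_dim).
embedding = [[float(t), 10.0 * t] for t in range(vocab_size)]
input_ids = [5, 0, 3, 7]

per_rank = vocab_size // tp_size


def local_lookup(rank, ids):
    """Look up `ids` in this rank's vocab shard, masking out-of-range tokens with zeros."""
    start, end = rank * per_rank, (rank + 1) * per_rank
    shard = embedding[start:end]  # split_dim=0: shard over the vocabulary rows
    out = []
    for token in ids:
        if start <= token < end:
            out.append(list(shard[token - start]))
        else:
            out.append([0.0] * embedding_dim)
    return out


partials = [local_lookup(rank, input_ids) for rank in range(tp_size)]

# Summing over ranks stands in for the all-reduce.
reduced = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
expected = [embedding[token] for token in input_ids]
print(reduced == expected)
```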