• SplitConfig defines over which dimension we shard/split
    • contiguous_chunks can customize the sharding by explicitly defining the chunk boundaries inside which the sharding happens
      • helpful if you have an operation of the following sort.
        • let’s say we want a linear proj with gating i.e. f(x) = linear(x) * gate(x)
        • In the non-distributed case, we write:
          • we define gate_proj=nn.Linear(in,2*out) and do
            • def forward(x):
            • hidden, gate = torch.split(gate_proj(x), out, dim=-1)
            • return hidden*gate
        • For this to work in the TensorParallelColumnLinear case (say with pg.size() = 2),
          • we need to define contiguous_chunks=(out, out)
            • such that process_0 holds output rows [0, out/2) and [out, out + out/2) of the weight matrix (the first half of the hidden chunk and the first half of the gate chunk)
            • and process_1 holds output rows [out/2, out) and [out + out/2, 2*out) (the second half of each chunk)
            • such that hidden, gate = torch.split(gate_proj(x), out // pg.size(), dim=-1) is still correct when each process applies it to its local output
            • and also that the operation is still correct when un-sharding and merging things back (important for inference)
            • a minimal sketch of this case follows the SplitConfig definition below
import dataclasses
from typing import Optional, Tuple


@dataclasses.dataclass
class SplitConfig:
    split_dim: int
    # contiguous_chunks is a tuple of chunk sizes along the split_dim
    # sharding happens inside each chunk
    # if None, by default contiguous_chunks = (unsharded_param.shape[split_dim],)
    contiguous_chunks: Optional[Tuple[int, ...]] = None
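A minimal, self-contained sketch of the gated-projection case above, using plain torch rather than the real TensorParallelColumnLinear: the sizes, the tp variable and the shard_weight helper are made up for illustration, and the manual slicing mimics what contiguous_chunks=(out, out) is meant to produce.

import torch
import torch.nn as nn

# Hypothetical sizes; tp plays the role of pg.size() in the notes.
in_features, out, tp = 16, 8, 2

# Non-distributed reference: one projection producing [hidden | gate] along dim=-1.
gate_proj = nn.Linear(in_features, 2 * out, bias=False)


def forward_reference(x):
    hidden, gate = torch.split(gate_proj(x), out, dim=-1)
    return hidden * gate


# Column sharding with contiguous_chunks=(out, out): each chunk (the hidden rows and the
# gate rows of the (2*out, in) weight) is sharded separately, so rank r holds rows
# [r*out/tp : (r+1)*out/tp] of the hidden chunk followed by the same slice of the gate chunk.
def shard_weight(weight, rank):
    hidden_rows, gate_rows = torch.split(weight, out, dim=0)
    shard = out // tp
    return torch.cat(
        [hidden_rows[rank * shard:(rank + 1) * shard], gate_rows[rank * shard:(rank + 1) * shard]],
        dim=0,
    )


x = torch.randn(4, in_features)
ref = forward_reference(x)

# Each "process" applies the same split logic to its local output and gets its shard of the result.
per_rank_outputs = []
for rank in range(tp):
    local_weight = shard_weight(gate_proj.weight, rank)
    local_out = x @ local_weight.t()
    hidden, gate = torch.split(local_out, out // tp, dim=-1)
    per_rank_outputs.append(hidden * gate)

# Concatenating the per-rank results along the last dim recovers the unsharded computation.
assert torch.allclose(torch.cat(per_rank_outputs, dim=-1), ref, atol=1e-6)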
  • The param you pass to create_sharded_parameter_from_config(parameter=param, pg=pg, split_config=split_config) is actually already “sharded”

    • e.g. if you want to shard an nn.Linear(in, out) by its columns over the process group pg
    • you create a class ShardedLinear(nn.Linear) where each process calls super().__init__(in_features=in, out_features=out // pg.size())
      • Thus, the sharding is done implicitly at construction time, not when calling create_sharded_parameter_from_config().
    • and then call mark_all_parameters_in_module_as_sharded(self, pg=self.pg, split_config=SplitConfig(split_dim=0))
      • which takes care of explicitly recording how ShardedLinear is actually sharded (a sketch of this pattern follows below).
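A sketch of the construction pattern just described; ShardedColumnLinear is a made-up name, SplitConfig is the dataclass above, and mark_all_parameters_in_module_as_sharded is the function from these notes (its import is intentionally left out, so this is illustrative rather than runnable on its own).

import torch.nn as nn
import torch.distributed as dist


class ShardedColumnLinear(nn.Linear):
    # Each process only allocates its own slice of the weight; the class name is hypothetical.

    def __init__(self, in_features: int, out_features: int, pg: dist.ProcessGroup):
        # Sharding happens implicitly here: this process only holds
        # out_features // pg.size() output rows of the (out_features, in_features) weight.
        super().__init__(in_features=in_features, out_features=out_features // pg.size())
        self.pg = pg
        # Then explicitly record how the module is sharded (split_dim=0 because
        # nn.Linear.weight is stored as (out_features, in_features)).
        # SplitConfig and mark_all_parameters_in_module_as_sharded are the ones described above.
        mark_all_parameters_in_module_as_sharded(
            self, pg=self.pg, split_config=SplitConfig(split_dim=0)
        )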
  • A sharded parameter is defined as

    • global_ranks i.e. the list of the global ranks of the processes holding the shards
    • unsharded_shape i.e. the true shape of the parameter if we merged everything back
    • local_global_slices_pairs i.e. the pairs mapping slices of the local shard onto slices of the unsharded parameter
      • mostly matters when contiguous_chunks is not None

      • let’s say we have two contiguous chunks (100, 100) over which we shard with pg.size() = 2 (the slices below are along split_dim; every other dim gets slice(None), as in the general construction below)

        • for process_0
          • local_slice=slice(0, 50) → global_slice=slice(0, 50)
          • local_slice=slice(50, 100) → global_slice=slice(100, 150)
        • for process_1
          • local_slice=slice(0, 50) → global_slice=slice(50, 100)
          • local_slice=slice(50, 100) → global_slice=slice(150, 200)
        • this exact example is reproduced in the runnable sketch at the end of this section
      • list with one element when contiguous_chunks=None

        • for a given process in the process group, global_slice = slice(current_rank * shard_length, (current_rank + 1) * shard_length)
          • where current_rank = dist.get_rank(pg), the rank of the process within the process group, and shard_length = parameter.shape[split_dim], i.e. the size of the local (already sharded) parameter along split_dim
        • local_slices = tuple(slice(None) for _ in range(param_num_dims))
          • slice(None) selects a whole dimension, i.e. the local shard is taken as-is
        • global_slices = tuple(global_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims))
        • local_global_slices_pairs = (SlicesPair(local_slices=local_slices, global_slices=global_slices),)
    • unsharding/merging back (see the runnable sketch below)
      • true_param = torch.empty(unsharded_shape)
      • for rank, sharded_param in zip(global_ranks, all_sharded_params):
        • curr_local_global_slices_pairs = <the slices pairs recorded for this rank>
        • for pair in curr_local_global_slices_pairs:
          • true_param[pair.global_slices].copy_(sharded_param[pair.local_slices])
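To make the slices-pair rules and the merge-back loop concrete, here is a self-contained sketch; SlicesPair here is a minimal stand-in with the field names used above, and build_slices_pairs is a helper written for this note, not the library's implementation. It reproduces the (100, 100) / pg.size() = 2 example and then reconstructs the unsharded parameter from the shards.

import dataclasses
from typing import Optional, Tuple

import torch


@dataclasses.dataclass
class SlicesPair:
    # Minimal stand-in with the field names used above.
    local_slices: Tuple[slice, ...]
    global_slices: Tuple[slice, ...]


def build_slices_pairs(unsharded_shape, split_dim, rank, world_size,
                       contiguous_chunks: Optional[Tuple[int, ...]] = None):
    # Compute the local/global slices pairs of one rank, following the rules above:
    # each contiguous chunk is sharded independently across the ranks of the group.
    num_dims = len(unsharded_shape)
    if contiguous_chunks is None:
        contiguous_chunks = (unsharded_shape[split_dim],)

    pairs = []
    local_start = 0   # offset inside the local (sharded) parameter along split_dim
    global_start = 0  # offset of the current chunk inside the unsharded parameter
    for chunk in contiguous_chunks:
        shard_length = chunk // world_size
        local_slice = slice(local_start, local_start + shard_length)
        global_slice = slice(global_start + rank * shard_length,
                             global_start + (rank + 1) * shard_length)
        # slice(None) on every other dim: those dims are not sharded.
        local_slices = tuple(local_slice if d == split_dim else slice(None) for d in range(num_dims))
        global_slices = tuple(global_slice if d == split_dim else slice(None) for d in range(num_dims))
        pairs.append(SlicesPair(local_slices=local_slices, global_slices=global_slices))
        local_start += shard_length
        global_start += chunk
    return tuple(pairs)


# The example above: a 1-D parameter of unsharded length 200, two contiguous chunks of 100,
# sharded over a process group of size 2.
unsharded_shape, split_dim, world_size, chunks = (200,), 0, 2, (100, 100)
for rank in range(world_size):
    for pair in build_slices_pairs(unsharded_shape, split_dim, rank, world_size, chunks):
        print(rank, pair.local_slices, pair.global_slices)
# matches the slices listed above: process_0 gets global (0, 50) and (100, 150),
# process_1 gets global (50, 100) and (150, 200)

# Unsharding / merging back: copy every local region into its global position.
full = torch.arange(200, dtype=torch.float32)  # "ground truth" unsharded parameter
shards = []                                    # simulate what each rank holds locally
for rank in range(world_size):
    pairs = build_slices_pairs(unsharded_shape, split_dim, rank, world_size, chunks)
    local = torch.cat([full[p.global_slices] for p in pairs], dim=split_dim)
    shards.append((pairs, local))

true_param = torch.empty(unsharded_shape)
for pairs, sharded_param in shards:
    for pair in pairs:
        true_param[pair.global_slices].copy_(sharded_param[pair.local_slices])

assert torch.equal(true_param, full)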