- `SplitConfig` defines over which dimension we shard/split
- `contiguous_chunks` can customize the sharding by explicitly defining the shard boundaries, which is helpful if you have an operation of this sort:
- let’s say we want a linear proj with gating i.e. f(x) = linear(x) * gate(x)
- In the non-distributed case, we write:
- we define `gate_proj = nn.Linear(in, 2 * out)` and do:

```python
def forward(x):
    hidden, gate = torch.split(gate_proj(x), out, dim=-1)
    return hidden * gate
```
- For this to work in the `TensorParallelColumnLinear` case, we need to define `contiguous_chunks=(out, out)`
    - such that process_0 has the first half of the hidden chunk and the first half of the gate chunk of the weight matrix
    - and process_1 has the second half of the hidden chunk and the second half of the gate chunk of the weight matrix
    - such that `hidden, gate = torch.split(gate_proj(x), out, dim=-1)` is correct when each process applies it locally, and also that the operation is still correct when un-sharding and merging things back (important for inference); see the numeric sketch after this list
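To make the chunked column sharding above concrete, here is a self-contained sketch (plain PyTorch, two simulated processes; the toy sizes and the slicing code are illustrative, not the library's implementation). It checks that sharding each chunk of the fused `gate_proj` weight separately keeps the local `torch.split` valid, and that concatenating the per-process results reproduces the unsharded `f(x)`.

```python
import torch

torch.manual_seed(0)
in_, out, world_size = 8, 6, 2          # toy sizes, hypothetical values
W = torch.randn(2 * out, in_)           # fused gate_proj weight: [hidden chunk; gate chunk]
x = torch.randn(3, in_)

# Unsharded reference: f(x) = hidden * gate
hidden_ref, gate_ref = torch.split(x @ W.t(), out, dim=-1)
ref = hidden_ref * gate_ref

# contiguous_chunks=(out, out): shard *each* chunk over the process group,
# so rank r holds the r-th slice of the hidden chunk AND the r-th slice of the gate chunk.
shard = out // world_size
outputs = []
for rank in range(world_size):
    rows = torch.cat([
        W[rank * shard:(rank + 1) * shard],                 # this rank's slice of the hidden chunk
        W[out + rank * shard:out + (rank + 1) * shard],     # this rank's slice of the gate chunk
    ])
    local_out = x @ rows.t()                                # local column-parallel matmul
    # the local split is still "hidden first, gate second", just with the local width
    local_hidden, local_gate = torch.split(local_out, shard, dim=-1)
    outputs.append(local_hidden * local_gate)

# merging the per-rank results back along the sharded dim matches the unsharded f(x)
assert torch.allclose(torch.cat(outputs, dim=-1), ref)
```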
- The `param` you pass in `create_sharded_parameter_from_config(parameter=param, pg=pg, split_config=split_config)` is actually already “sharded”
    - e.g. if you want to shard a `nn.Linear(in, out)` by the columns over the process group `pg`, you create a class `ShardedLinear(nn.Linear)` where each process will call `super().__init__(in_features=in, out_features=out // pg.size())`
    - Thus, the sharding is done implicitly at construction time, not when calling `create_sharded_parameter_from_config()`.
    - and then call `mark_all_parameters_in_module_as_sharded(self, pg=self.pg, split_config=SplitConfig(split_dim=0))`, which will take care of explicitly recording where `ShardedLinear` is actually sharded (see the sketch after this list)
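A minimal sketch of the construction-time sharding pattern described above. `SplitConfig`, `mark_all_parameters_in_module_as_sharded`, and the process group are hypothetical stand-ins so the snippet runs on its own; the real library versions attach the full sharding metadata described in the next section.

```python
from dataclasses import dataclass

import torch.nn as nn

# Hypothetical stand-ins so this sketch is self-contained; in the real codebase,
# SplitConfig and mark_all_parameters_in_module_as_sharded come from the library
# and record richer metadata (global_ranks, unsharded_shape, local/global slices).
@dataclass
class SplitConfig:
    split_dim: int

def mark_all_parameters_in_module_as_sharded(module, pg, split_config):
    # Stand-in: just attach the sharding description to every parameter.
    for param in module.parameters():
        param.split_config = split_config
        param.pg = pg

class FakeProcessGroup:
    """Stand-in for a torch.distributed process group; only .size() is used here."""
    def __init__(self, world_size):
        self._world_size = world_size
    def size(self):
        return self._world_size

class ShardedLinear(nn.Linear):
    """Column-sharded linear: each process only allocates its own slice of the weight."""
    def __init__(self, in_features, out_features, pg):
        # The parameter is already sharded at construction time:
        # this rank only holds out_features // pg.size() of the output rows.
        super().__init__(in_features=in_features, out_features=out_features // pg.size())
        self.pg = pg
        # Record where the module is sharded so it can be un-sharded / merged back later.
        mark_all_parameters_in_module_as_sharded(
            self, pg=self.pg, split_config=SplitConfig(split_dim=0)
        )

layer = ShardedLinear(8, 6, FakeProcessGroup(world_size=2))
print(layer.weight.shape)  # torch.Size([3, 8]): only this process's shard
```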
- A sharded parameter is defined by
    - `global_ranks`, i.e. the list of the global ranks of the processes holding the shards
    - `unsharded_shape`, i.e. the true shape of the parameter if we merged everything back
    - `local_global_slices_pairs`, i.e. the pairs mapping each local slice of the shard to its global slice in the unsharded parameter; mostly important when `contiguous_chunks` is not None (both cases are worked through in the sketch after this list)
        - let’s say we have two contiguous chunks (100, 100) over which we shard with pg.size() = 2
            - for process_0
                - local_slice=slice(0, 50, None) ↔ global_slice=slice(0, 50, None)
                - local_slice=slice(50, 100, None) ↔ global_slice=slice(100, 150, None)
            - for process_1
                - local_slice=slice(0, 50, None) ↔ global_slice=slice(50, 100, None)
                - local_slice=slice(50, 100, None) ↔ global_slice=slice(150, 200, None)
        - list with one element when `contiguous_chunks=None`
            - for a given process in the process group, `global_slice = slice(current_rank * shard_length, (current_rank + 1) * shard_length)`
                - where `current_rank = dist.get_rank(pg)` is the process’s local rank within the process group, and `shard_length = parameter.shape[split_dim]`
            - `local_slices = tuple(slice(None) for _ in range(param_num_dims))`, where `slice(None)` does nothing (it selects the whole dimension)
            - `global_slices = tuple(global_slice if dim_id == split_dim else slice(None) for dim_id in range(param_num_dims))`
            - `local_global_slices_pairs = (SlicesPair(local_slices=local_slices, global_slices=global_slices),)`
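A self-contained sketch of how the local/global slice pairs can be computed in both cases. `SlicesPair` here is a stand-in namedtuple and `compute_slices_pairs` is a hypothetical helper reconstructed from the rules and the (100, 100) example above, not the library's code; the final loop checks the chunked case against that table.

```python
from collections import namedtuple

# Stand-in for the library's SlicesPair
SlicesPair = namedtuple("SlicesPair", ["local_slices", "global_slices"])

def compute_slices_pairs(unsharded_shape, split_dim, rank, world_size, contiguous_chunks=None):
    """Return the (local, global) slice pairs describing one rank's shard."""
    param_num_dims = len(unsharded_shape)

    def expand(local_slice, global_slice):
        # slice(None) on the non-split dims selects the whole dimension
        local = tuple(local_slice if d == split_dim else slice(None) for d in range(param_num_dims))
        global_ = tuple(global_slice if d == split_dim else slice(None) for d in range(param_num_dims))
        return SlicesPair(local_slices=local, global_slices=global_)

    if contiguous_chunks is None:
        # one pair: the whole local shard maps to a single contiguous global slice
        # (unsharded_shape[split_dim] // world_size equals the local parameter.shape[split_dim])
        shard_length = unsharded_shape[split_dim] // world_size
        global_slice = slice(rank * shard_length, (rank + 1) * shard_length)
        return (expand(slice(None), global_slice),)

    # with contiguous chunks, each chunk is sharded separately over the process group
    pairs, local_offset, global_offset = [], 0, 0
    for chunk in contiguous_chunks:
        shard_length = chunk // world_size
        local_slice = slice(local_offset, local_offset + shard_length)
        global_slice = slice(global_offset + rank * shard_length,
                             global_offset + (rank + 1) * shard_length)
        pairs.append(expand(local_slice, global_slice))
        local_offset += shard_length
        global_offset += chunk
    return tuple(pairs)

# Check against the (100, 100) example with pg.size() = 2, for a 1-D parameter:
for rank in range(2):
    for pair in compute_slices_pairs((200,), split_dim=0, rank=rank, world_size=2,
                                     contiguous_chunks=(100, 100)):
        print(rank, pair.local_slices[0], "<->", pair.global_slices[0])
# matches the local/global pairs listed above for process_0 and process_1
```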
- unsharding / merging back (pseudocode; a runnable version follows below):

```python
true_param = init()  # allocate a tensor with the unsharded_shape
for process, sharded_param in zip(pg, params):
    curr_local_global_slices_pairs = get(pg)  # the SlicesPairs recorded for that process
    for local_slice, global_slice in curr_local_global_slices_pairs:
        # copy each local shard region into its position in the full parameter
        true_param[global_slice].copy_(sharded_param[local_slice])
```
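A runnable version of the merge-back loop, assuming the two shards and slice pairs from the (100, 100) / pg.size() = 2 example above (plain tensors stand in for the sharded parameters):

```python
import torch

# Two "process" shards of a 1-D parameter with unsharded_shape (200,),
# sharded with contiguous_chunks=(100, 100) over a group of size 2
# (slice pairs taken from the example above).
unsharded = torch.arange(200, dtype=torch.float32)
shards = {
    0: torch.cat([unsharded[0:50],   unsharded[100:150]]),  # process_0's local parameter
    1: torch.cat([unsharded[50:100], unsharded[150:200]]),  # process_1's local parameter
}
slices_pairs = {
    0: [(slice(0, 50), slice(0, 50)),   (slice(50, 100), slice(100, 150))],
    1: [(slice(0, 50), slice(50, 100)), (slice(50, 100), slice(150, 200))],
}

# unsharding / merging back
true_param = torch.empty(200)
for rank, sharded_param in shards.items():
    for local_slice, global_slice in slices_pairs[rank]:
        true_param[global_slice].copy_(sharded_param[local_slice])

assert torch.equal(true_param, unsharded)  # the merged parameter matches the original
```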