- Defines the distributed primitives used in Tensor Parallelism and Sequence Parallelism such that the forward pass and backward pass are appropriately defined
- `reduce_scatter` + `all_gather` and `all_reduce` are equivalent in terms of communication, because `all_reduce` is implemented as a ring all-reduce, i.e. a reduce-scatter followed by an all-gather. We just need to bucketize by the sequence dimension for things to work out correctly.
PyTorch specifics
torch.autograd.Function
- To define a differentiable operator, you need to define a class inheriting from `torch.autograd.Function`. How it works is explained in the torch.autograd docs.
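For reference, a minimal illustrative example (not nanotron code) of a custom differentiable operator:

```python
import torch

class MyScale(torch.autograd.Function):
    """Toy example: y = alpha * x, with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha  # stash non-tensor state on ctx
        return x * alpha

    @staticmethod
    def backward(ctx, grad_output):
        # Return one gradient per forward input; None for non-differentiable args.
        return grad_output * ctx.alpha, None

y = MyScale.apply(torch.randn(4, requires_grad=True), 2.0)
```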
@assert_cuda_max_connections_set_to_1
- function decorator that sets `CUDA_DEVICE_MAX_CONNECTIONS=1`; this ensures that only one connection can be made to a single GPU from any host thread
- In `ColumnLinearAsync`, in the backward pass with `tp_mode=REDUCE_SCATTER`, they asynchronously fetch (all-gather) the full input tensor while computing the gradient with respect to the input (see the Code section below)
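A plausible sketch of such a decorator, assuming it only checks the environment variable (the actual nanotron helper may set it or check it differently):

```python
import functools
import os

def assert_cuda_max_connections_set_to_1(func):
    """Sketch: refuse to run unless CUDA_DEVICE_MAX_CONNECTIONS=1.

    With a single hardware work queue, kernels run in the order they are
    issued, so launching the communication first guarantees it overlaps
    with the compute kernels issued right after it.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        assert os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") == "1", (
            "CUDA_DEVICE_MAX_CONNECTIONS must be 1 for async tensor parallelism"
        )
        return func(*args, **kwargs)
    return wrapper
```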
Tensor Parallelism only
Differentiable Identity
- Before a `ColumnLinear` operation (beginning of MLP)
- Forward: `f(x) = x`, i.e. all `ColumnLinear` shards receive the same input
- Backward: `b(x) = all_reduce_sum(x)`, because the same input was replicated over all processes
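A minimal sketch of this primitive (class and argument names are illustrative, not necessarily nanotron's exact API):

```python
import torch
import torch.distributed as dist

class DifferentiableIdentity(torch.autograd.Function):
    """f(x) = x in forward; all-reduce-sum of the gradient in backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Each TP rank produced a gradient for the same replicated input: sum them.
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_output, None
```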
Differentiable_All_Reduce_Sum
- After a `RowLinear` operation (end of MLP)
- Forward: `f(x) = all_reduce_sum(x)`, i.e. we need to accumulate the partial matmul results to get the correct output
- Backward: `b(x) = x`, i.e. the gradient must be replicated over all processes, as they all participated
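The mirror primitive, again as an illustrative sketch:

```python
import torch
import torch.distributed as dist

class DifferentiableAllReduceSum(torch.autograd.Function):
    """f(x) = all_reduce_sum(x) in forward; identity in backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        # Sum the partial results across all TP ranks (in place, for brevity).
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None
```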
Tensor + Sequence Parallelism
Differentiable_All_Gather
- Before a `ColumnLinear` operation (beginning of MLP)
- Forward: `f(x) = all_gather(x)`, i.e. the input was previously split along the sequence dimension, so we need to gather it back before feeding it to `ColumnLinear`
- Backward: `b(x) = reduce_scatter_sum(x)`: because the same (gathered) input was replicated over all processes, we need to reduce, and then scatter by splitting along the sequence dimension
How to implement it
- nanotron code assumes the tensor is sharded along the first dimension
- get the sharded shape: `sharded_batch_size, *rest_size = tensor.shape`
- create an empty tensor of the unsharded size, i.e. `unsharded_tensor = torch.empty((sharded_batch_size * group.size(), *rest_size))`
- call `dist.all_gather_into_tensor(unsharded_tensor, tensor, group)` (a full sketch follows)
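Putting these steps together, a hedged sketch of the full primitive (the real nanotron code also handles dtype/device placement, contiguity, and the trivial single-rank case):

```python
import torch
import torch.distributed as dist

class DifferentiableAllGather(torch.autograd.Function):
    """Gather sequence-sharded activations in forward; reduce-scatter grads in backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        sharded_batch_size, *rest_size = tensor.shape
        unsharded_tensor = torch.empty(
            (sharded_batch_size * group.size(), *rest_size),
            device=tensor.device, dtype=tensor.dtype,
        )
        dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group)
        return unsharded_tensor

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.group
        unsharded_batch_size, *rest_size = grad_output.shape
        sub_grad = torch.empty(
            (unsharded_batch_size // group.size(), *rest_size),
            device=grad_output.device, dtype=grad_output.dtype,
        )
        # Sum the gradients from all TP ranks and re-shard along the sequence dim.
        dist.reduce_scatter_tensor(sub_grad, grad_output, op=dist.ReduceOp.SUM, group=group)
        return sub_grad, None
```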
Differentiable_Reduce_Scatter
- After a `RowLinear` operation (end of MLP)
- Forward: `f(x) = reduce_scatter_sum(x)`, i.e. we need to accumulate the partial results to get the correct output, and then scatter by splitting along the sequence dimension
- Backward: `b(x) = all_gather(x)`, i.e. the gradient must be gathered along the sequence dimension and then replicated over all processes, as they all participated (sketch below)
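An illustrative sketch, mirroring `DifferentiableAllGather` above:

```python
import torch
import torch.distributed as dist

class DifferentiableReduceScatterSum(torch.autograd.Function):
    """Reduce-scatter along the sequence dim in forward; all-gather grads in backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        unsharded_batch_size, *rest_size = tensor.shape
        sharded_tensor = torch.empty(
            (unsharded_batch_size // group.size(), *rest_size),
            device=tensor.device, dtype=tensor.dtype,
        )
        # Sum the partial results across TP ranks and keep only this rank's sequence slice.
        dist.reduce_scatter_tensor(sharded_tensor, tensor, op=dist.ReduceOp.SUM, group=group)
        return sharded_tensor

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.group
        sharded_batch_size, *rest_size = grad_output.shape
        unsharded_grad = torch.empty(
            (sharded_batch_size * group.size(), *rest_size),
            device=grad_output.device, dtype=grad_output.dtype,
        )
        # Gather the gradient along the sequence dim so every rank sees the full gradient.
        dist.all_gather_into_tensor(unsharded_grad, grad_output, group=group)
        return unsharded_grad, None
```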
Async Tensor + Sequence Parallelism
Motivation
- If we’re smart, we can overlap communication and computation
- e.g. in the backward pass of `ColumnLinear`, start the reduce_scatter of `grad_tensor` while computing the gradients of the weight and bias, i.e. `grad_weight` and `grad_bias`
- We rely on `CUDA_DEVICE_MAX_CONNECTIONS=1` to ensure that the gather/reduce_scatter is scheduled before the tensor gradient computation in the code
Details
- `RowLinear` doesn't support async if the tp_mode is `all_reduce` (TP only) instead of `reduce_scatter` (TP + sequence)
- Must define `_{Row,Column}LinearAsyncCommunication(torch.autograd.Function)` classes
Code
ColumnLinear
Forward
- In `def forward(ctx, tensor, weight, bias, group, tp_mode)`
- `tensor` is sharded/split along the sequence dimension
- We call the full input `gathered_tensor`
- What you can do is `ctx.save_for_backward(tensor, weight)` (only save the sharded tensor to reduce activation memory; we will gather it back at backward time)
- start the async gather: `handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group, async_op=True)`
- meanwhile, compute the part of the matmul that only needs the sharded `tensor`: `torch.mm(tensor, weight, out=same_device_shard)`
- `handle.wait()`
- compute the rest of the matmul with the remaining shards of the gathered tensor (see the sketch below)
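A simplified sketch of this forward pass. The function name is mine, and I write the matmuls with a `(out_features, in_features)` weight (the `F.linear` convention), which matches the gradient formulas in the Backward section, so the partial product is `tensor @ weight.t()`:

```python
import torch
import torch.distributed as dist

def column_linear_forward_sketch(tensor, weight, bias, group):
    """Overlap the all-gather of the sequence-sharded input with the slice of
    the matmul that only needs the local shard.

    Assumptions (illustrative): `tensor` has shape (s, in_features), sharded
    along the sequence; `weight` has shape (out_features, in_features).
    """
    rank, world_size = dist.get_rank(group), group.size()
    s, in_features = tensor.shape
    out_features = weight.shape[0]

    gathered_tensor = torch.empty(
        (s * world_size, in_features), device=tensor.device, dtype=tensor.dtype
    )
    out = torch.empty(
        (s * world_size, out_features), device=tensor.device, dtype=tensor.dtype
    )

    # 1) Launch the all-gather first; with CUDA_DEVICE_MAX_CONNECTIONS=1 it is
    #    scheduled ahead of the compute kernels issued below, so the two overlap.
    handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

    # 2) Meanwhile, compute the output rows that only need the local shard.
    torch.mm(tensor, weight.t(), out=out[rank * s:(rank + 1) * s])

    # 3) Wait for the gather, then fill in the rows owned by the other ranks.
    handle.wait()
    if rank > 0:
        torch.mm(gathered_tensor[: rank * s], weight.t(), out=out[: rank * s])
    if rank < world_size - 1:
        torch.mm(gathered_tensor[(rank + 1) * s:], weight.t(), out=out[(rank + 1) * s:])

    if bias is not None:
        out += bias
    return out
```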
Backward
`def backward(ctx, grad_output)`
- What you can do is `tensor, weight = ctx.saved_tensors`
- async gather the full input back into `gathered_tensor`: `handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)`
- meanwhile, compute `grad_tensor = grad_output.matmul(weight)`
- `handle.wait()`
- async `reduce_scatter_sum` the `grad_tensor`: `handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)`
- meanwhile, compute `grad_weight` and `grad_bias`:
  - `grad_weight = grad_output.t().matmul(gathered_tensor)`
  - `grad_bias = grad_output.sum(dim=0) if use_bias else None`
- `handle.wait()`
- `return sub_grad_tensor, grad_weight, grad_bias, None, None` (see the sketch below)
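A matching sketch of the backward pass, under the same assumptions and naming:

```python
import torch
import torch.distributed as dist

def column_linear_backward_sketch(grad_output, tensor, weight, group, use_bias=True):
    """Overlap the re-gather of the saved (sharded) input and the reduce-scatter
    of grad_tensor with the local gradient matmuls.

    Assumptions (illustrative): `tensor` is the sequence-sharded input saved in
    forward; `weight` has shape (out_features, in_features) as in F.linear.
    """
    world_size = group.size()
    s, in_features = tensor.shape

    # 1) Start re-gathering the full input; only the shard was saved to cut activation memory.
    gathered_tensor = torch.empty(
        (s * world_size, in_features), device=tensor.device, dtype=tensor.dtype
    )
    handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

    # 2) Meanwhile, compute the gradient w.r.t. the (full, gathered) input.
    grad_tensor = grad_output.matmul(weight)  # (full_seq, in_features)
    handle.wait()

    # 3) Start reduce-scattering grad_tensor back to the sequence-sharded layout.
    sub_grad_tensor = torch.empty(
        (s, in_features), device=grad_tensor.device, dtype=grad_tensor.dtype
    )
    handle = dist.reduce_scatter_tensor(
        sub_grad_tensor, grad_tensor, op=dist.ReduceOp.SUM, group=group, async_op=True
    )

    # 4) Meanwhile, compute the weight/bias gradients from the gathered input.
    grad_weight = grad_output.t().matmul(gathered_tensor)
    grad_bias = grad_output.sum(dim=0) if use_bias else None

    handle.wait()
    return sub_grad_tensor, grad_weight, grad_bias
```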
Row Linear
Forward
`def forward(ctx, tensor, weight, bias, group, tp_mode)`
- nothing much tricky going on
- What you do is:
  - `ctx.save_for_backward(tensor, weight)`
  - `out = F.linear(tensor, weight, bias)`
  - `return differentiable_reduce_scatter_sum(out, group)` (see the sketch below)
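A minimal sketch of this forward in REDUCE_SCATTER mode, with the reduce-scatter written out inline (illustrative, not nanotron's exact structure; in the real autograd.Function you would also `ctx.save_for_backward(tensor, weight)`):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def row_linear_forward_sketch(tensor, weight, bias, group):
    """RowLinear forward under TP + sequence parallelism.

    `tensor` holds the full sequence but only this rank's slice of in_features,
    so F.linear produces a partial result that must be summed across TP ranks.
    """
    out = F.linear(tensor, weight, bias)  # partial result on this rank
    sharded_out = torch.empty(
        (out.shape[0] // group.size(), *out.shape[1:]),
        device=out.device, dtype=out.dtype,
    )
    # Sum the partials across TP ranks and re-shard the output along the sequence dim.
    dist.reduce_scatter_tensor(sharded_out, out, op=dist.ReduceOp.SUM, group=group)
    return sharded_out
```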
Backward
`def backward(ctx, grad_output)`
- What you do is similar to the `ColumnLinear` forward:
- start the async gather: `handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)`
- meanwhile, compute the local slice of the input gradient with the current `grad_output` shard: `torch.mm(grad_output, weight, out=same_device_shard_grad_tensor)`
- `handle.wait()`
- compute the rest of the `grad_tensor` using `total_grad_output`
- compute the weight and bias grads:
  - `grad_weight = total_grad_output.t().matmul(tensor)`
  - `grad_bias = total_grad_output.sum(dim=0) if use_bias else None`
- `return total_grad_tensor, grad_weight, grad_bias, None, None` (see the sketch below)
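A matching sketch of this backward, again with illustrative names and the `F.linear` weight convention:

```python
import torch
import torch.distributed as dist

def row_linear_backward_sketch(grad_output, tensor, weight, group, use_bias=True):
    """Overlap the all-gather of the sequence-sharded grad_output with the slice
    of the input-gradient matmul that only needs the local shard.

    Assumptions (illustrative): `grad_output` has shape (s, out_features), sharded
    along the sequence; `tensor`/`weight` were saved in forward, with `weight` of
    shape (out_features, in_features).
    """
    rank, world_size = dist.get_rank(group), group.size()
    s, out_features = grad_output.shape
    in_features = weight.shape[1]

    total_grad_output = torch.empty(
        (s * world_size, out_features), device=grad_output.device, dtype=grad_output.dtype
    )
    total_grad_tensor = torch.empty(
        (s * world_size, in_features), device=grad_output.device, dtype=grad_output.dtype
    )

    # 1) Launch the all-gather of grad_output first so it overlaps the matmul below.
    handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)

    # 2) Meanwhile, compute the rows of grad_tensor that only need the local grad_output shard.
    torch.mm(grad_output, weight, out=total_grad_tensor[rank * s:(rank + 1) * s])

    # 3) Wait, then fill in the remaining rows and compute the weight/bias grads.
    handle.wait()
    if rank > 0:
        torch.mm(total_grad_output[: rank * s], weight, out=total_grad_tensor[: rank * s])
    if rank < world_size - 1:
        torch.mm(total_grad_output[(rank + 1) * s:], weight,
                 out=total_grad_tensor[(rank + 1) * s:])

    grad_weight = total_grad_output.t().matmul(tensor)
    grad_bias = total_grad_output.sum(dim=0) if use_bias else None
    return total_grad_tensor, grad_weight, grad_bias
```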