
Defines the distributed primitives used in Tensor Parallelism and Sequence Parallelism so that both the forward pass and the backward pass are correctly defined

reduce_scatter and all_reduce are equivalent in terms of communication cost, because a ring all_reduce is implemented as a reduce_scatter followed by an all_gather; we just need to bucketize by the sequence dimension for things to work out correctly.
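As a quick illustration of that decomposition (a minimal sketch, assuming torch.distributed is already initialized, the tensor's first dimension is divisible by the group size, and the function name is purely illustrative):

```python
import torch
import torch.distributed as dist

def all_reduce_as_two_phases(tensor: torch.Tensor, group) -> torch.Tensor:
    """Equivalent (up to floating-point rounding) to dist.all_reduce(tensor)."""
    world_size = dist.get_world_size(group)
    # Phase 1: reduce_scatter — every rank ends up with the summed version of
    # its own shard (first dimension split across ranks).
    shard = torch.empty(
        tensor.shape[0] // world_size, *tensor.shape[1:],
        dtype=tensor.dtype, device=tensor.device,
    )
    dist.reduce_scatter_tensor(shard, tensor, group=group)
    # Phase 2: all_gather — reassemble the fully reduced tensor on every rank.
    out = torch.empty_like(tensor)
    dist.all_gather_into_tensor(out, shard, group=group)
    return out
```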
PyTorch specifics
torch.autograd.Function
To define a differentiable operator, you need to define classes inheriting from torch.autograd.Function. How it works is explained in the torch.autograd documentation.
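A minimal, generic example of the pattern (standard PyTorch API; the operator itself is illustrative and unrelated to parallelism):

```python
import torch

class MyScale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha          # stash anything needed for backward
        return x * alpha

    @staticmethod
    def backward(ctx, grad_output):
        # Return one gradient per forward argument; non-tensor args get None.
        return grad_output * ctx.alpha, None

x = torch.randn(4, requires_grad=True)
y = MyScale.apply(x, 2.0)
y.sum().backward()                 # x.grad is now a tensor of 2.0s
```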
@assert_cuda_max_connections_set_to_1
A function decorator that sets CUDA_DEVICE_MAX_CONNECTIONS=1, which ensures that only one connection can be made to a single GPU from any host thread. In ColumnLinearAsync, in the backward pass with tp_mode=REDUCE_SCATTER, this guarantee matters because the sharded input is fetched back asynchronously while other work runs (see the backward walkthrough below).
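A minimal sketch of what such a decorator could look like (not nanotron's exact implementation; the real decorator may instead assert that the variable was set before CUDA initialization rather than setting it at call time):

```python
import functools
import os

def assert_cuda_max_connections_set_to_1(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # One hardware work queue per host thread: kernels and collectives are
        # issued to the GPU in the exact order they appear in the code.
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
        return func(*args, **kwargs)
    return wrapper
```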
In Tensor Parallelism only
Differentiable Identity
Before a ColumnLinear operation (beginning of MLP)
Forward: f(x) = x, i.e. all ColumnLinear shards receive the same input $X$
Backward: b(x) = all_reduce_sum(x), because the same input $X$ was replicated over all processes (a minimal sketch follows)
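A minimal sketch of this operator, assuming an initialized tensor-parallel process group passed as group (class and argument names are illustrative, not nanotron's exact API):

```python
import torch
import torch.distributed as dist

class DifferentiableIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        return tensor  # f(x) = x: every rank already holds the same input X

    @staticmethod
    def backward(ctx, grad_output):
        # b(x) = all_reduce_sum(x): X was replicated, so the gradients coming
        # from all tensor-parallel ranks must be summed.
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_output, None
```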
Differentiable_All_Reduce_Sum
After a RowLinear operation (end of MLP)
Forward: f(x) = all_reduce_sum(x), i.e. we need to accumulate the partial results to get the correct output
Backward: b(x) = x, i.e. the gradient must be replicated over all processes, as they all participated (a minimal sketch follows)
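The mirrored operator, sketched under the same assumptions (illustrative names):

```python
import torch
import torch.distributed as dist

class DifferentiableAllReduceSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        # f(x) = all_reduce_sum(x): each rank holds a partial matmul result
        # that must be summed to obtain the full output.
        tensor = tensor.clone()
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # b(x) = x: the output is replicated on all ranks, so the incoming
        # gradient is already each rank's correct gradient.
        return grad_output, None
```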
Tensor + Sequence Parallelism
Differentiable_All_Gather

Before a ColumnLinear operation (beginning of MLP)
Forward: f(x) = all_gather(x), i.e. $X$ was previously split along the sequence dimension, so we need to gather it back before feeding it to ColumnLinear
Backward: b(x) = reduce_scatter_sum(x): the gathered input $X$ was replicated over all processes, so we need to reduce the gradients and then scatter them by splitting along the sequence dimension
How to implement it (see the sketch after this list)
nanotron code assumes the tensor is sharded along its first dimension
get the current tensor's shape: sharded_batch_size, *rest_size = tensor.shape
create an empty tensor of the unsharded size, i.e. unsharded_tensor = torch.empty((sharded_batch_size * group.size(), *rest_size))
call dist.all_gather_into_tensor(unsharded_tensor, tensor, group)
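Putting those steps together, a sketch of the full differentiable operator (illustrative names; assumes the tensor is sharded along its first dimension and group is the tensor-parallel process group):

```python
import torch
import torch.distributed as dist

class DifferentiableAllGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        sharded_batch_size, *rest_size = tensor.shape
        # Allocate the unsharded buffer: the first dimension grows by the group size.
        unsharded_tensor = torch.empty(
            sharded_batch_size * group.size(), *rest_size,
            dtype=tensor.dtype, device=tensor.device,
        )
        dist.all_gather_into_tensor(unsharded_tensor, tensor.contiguous(), group=group)
        return unsharded_tensor

    @staticmethod
    def backward(ctx, grad_output):
        # b(x) = reduce_scatter_sum(x): sum gradients across ranks, then split
        # back along the sequence (first) dimension.
        group = ctx.group
        unsharded_batch_size, *rest_size = grad_output.shape
        sub_grad = torch.empty(
            unsharded_batch_size // group.size(), *rest_size,
            dtype=grad_output.dtype, device=grad_output.device,
        )
        dist.reduce_scatter_tensor(sub_grad, grad_output.contiguous(), group=group)
        return sub_grad, None
```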
Differentiable_Reduce_Scatter
After a RowLinear operation (end of MLP)
Forward: f(x) = reduce_scatter_sum(x), i.e. we need to accumulate the partial results to get the correct output, and then scatter it by splitting along the sequence dimension
Backward: b(x) = all_gather(x), i.e. the gradient must be gathered along the sequence dimension and replicated over all processes, as they all participated (a minimal sketch follows)
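A compact sketch of this operator under the same assumptions as the previous ones (illustrative names):

```python
import torch
import torch.distributed as dist

class DifferentiableReduceScatterSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        unsharded_batch_size, *rest_size = tensor.shape
        # f(x) = reduce_scatter_sum(x): sum partials across ranks and keep only
        # this rank's slice of the sequence (first) dimension.
        sharded_tensor = torch.empty(
            unsharded_batch_size // group.size(), *rest_size,
            dtype=tensor.dtype, device=tensor.device,
        )
        dist.reduce_scatter_tensor(sharded_tensor, tensor.contiguous(), group=group)
        return sharded_tensor

    @staticmethod
    def backward(ctx, grad_output):
        # b(x) = all_gather(x): rebuild the full-sequence gradient on every rank.
        group = ctx.group
        sharded_batch_size, *rest_size = grad_output.shape
        gathered = torch.empty(
            sharded_batch_size * group.size(), *rest_size,
            dtype=grad_output.dtype, device=grad_output.device,
        )
        dist.all_gather_into_tensor(gathered, grad_output.contiguous(), group=group)
        return gathered, None
```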
Async Tensor + Sequence Parallelism
Motivation
 If we’re smart, we can overlap communication and computation
 e.g. In backward
e.g. for ColumnLinear, start the reduce_scatter of grad_tensor while computing the gradients of the weight and bias, i.e. grad_weight and grad_bias.
 We rely on
CUDA_DEVICE_MAX_CONNECTIONS=1
to ensure that the gather/reduce_scatter is scheduled before the tensor gradient computation in the code.
Details
RowLinear doesn't support async if the tp_mode is all_reduce (tp only) instead of reduce_scatter (tp + sequence)
Must define _{Row,Column}LinearAsyncCommunication(torch.autograd.Function) classes
Code
ColumnLinear
Forward
In def forward(ctx, tensor, weight, bias, group, tp_mode)
tensor is sharded/split along the sequence dimension. We call the full gathered input $X$ gathered_tensor
What you can do is (a condensed sketch follows these steps):
ctx.save_for_backward(tensor, weight) (only save the sharded tensor to reduce activation memory; we will gather it back at backward time)
start the async gather: handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group, async_op=True)
meanwhile, compute the part of the matmul that only needs the sharded tensor: torch.mm(tensor, weight, out=same_device_shard)
handle.wait()
compute the rest of the matmul with the rest of the gathered tensor
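A condensed sketch of these steps as a standalone function (in nanotron this logic lives in the forward() of _ColumnLinearAsyncCommunication; the 2D input shape, the nn.Linear-style [out_features, in_features] weight layout, and the per-rank loop are assumptions of the sketch):

```python
import torch
import torch.distributed as dist

def column_linear_async_forward(tensor, weight, bias, group):
    # tensor: this rank's sequence shard of X, shape [sharded_seq, hidden] (assumed 2D).
    # weight: nn.Linear-style layout [out_per_rank, hidden] (assumed).
    rank, world_size = dist.get_rank(group), group.size()
    sharded_batch_size, hidden = tensor.shape

    # Start gathering the full input X asynchronously.
    gathered_tensor = torch.empty(
        sharded_batch_size * world_size, hidden,
        dtype=tensor.dtype, device=tensor.device,
    )
    handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

    # Meanwhile, compute the part of the matmul we already own: this rank's
    # shard of X writes directly into its slice of the output.
    out = torch.empty(
        sharded_batch_size * world_size, weight.shape[0],
        dtype=tensor.dtype, device=tensor.device,
    )
    own = slice(rank * sharded_batch_size, (rank + 1) * sharded_batch_size)
    torch.mm(tensor, weight.t(), out=out[own])

    handle.wait()
    # Finish the matmul with the shards received from the other ranks.
    for other in range(world_size):
        if other != rank:
            sl = slice(other * sharded_batch_size, (other + 1) * sharded_batch_size)
            torch.mm(gathered_tensor[sl], weight.t(), out=out[sl])

    if bias is not None:
        out = out + bias
    # Only the sharded tensor (plus weight) needs to be saved for backward.
    return out, (tensor, weight)
```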
Backward
def backward(ctx, grad_output)
What you can do is (a condensed sketch follows these steps):
tensor, weight = ctx.saved_tensors
start the async gather of gathered_tensor: handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)
meanwhile, compute grad_tensor = grad_output.matmul(weight)
handle.wait()
start the async reduce_scatter_sum of grad_tensor: handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
meanwhile, compute grad_weight and grad_bias:
grad_weight = grad_output.t().matmul(gathered_tensor)
grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return sub_grad_tensor, grad_weight, grad_bias, None, None
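A condensed sketch of the matching backward, again as a standalone function (in nanotron this is the backward() of _ColumnLinearAsyncCommunication; argument names and the 2D shapes are assumptions of the sketch):

```python
import torch
import torch.distributed as dist

def column_linear_async_backward(grad_output, tensor, weight, group, use_bias):
    # tensor: the sequence-sharded input saved at forward time.
    sharded_batch_size, hidden = tensor.shape

    # 1) Re-gather the full input asynchronously (it was not kept unsharded).
    gathered_tensor = torch.empty(
        sharded_batch_size * group.size(), hidden,
        dtype=tensor.dtype, device=tensor.device,
    )
    handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

    # 2) Meanwhile, compute the input gradient: it only needs grad_output and weight.
    grad_tensor = grad_output.matmul(weight)
    handle.wait()

    # 3) Reduce-scatter the input gradient back to sequence shards, asynchronously.
    sub_grad_tensor = torch.empty_like(tensor)
    handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)

    # 4) Meanwhile, compute the weight/bias gradients using the gathered input.
    grad_weight = grad_output.t().matmul(gathered_tensor)
    grad_bias = grad_output.sum(dim=0) if use_bias else None
    handle.wait()

    return sub_grad_tensor, grad_weight, grad_bias
```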
RowLinear
Forward
def forward(ctx, tensor, weight, bias, group, tp_mode):
nothing much tricky going on
What you do is (assembled in the short sketch below):
ctx.save_for_backward(tensor, weight)
out = F.linear(tensor, weight, bias)
return differentiable_reduce_scatter_sum(out, group)
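Assembled, the forward fits in a few lines (an illustrative sketch, not nanotron's exact code; it reuses the DifferentiableReduceScatterSum sketch from the section above for the trailing collective, and the backward sketched after the next list would complete the class):

```python
import torch
import torch.nn.functional as F

class _RowLinearAsyncCommunication(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, weight, bias, group, tp_mode):
        ctx.save_for_backward(tensor, weight)
        ctx.group, ctx.use_bias = group, bias is not None
        # Local partial matmul: the hidden dimension is sharded across ranks,
        # so `out` holds only this rank's partial sum.
        out = F.linear(tensor, weight, bias)
        # Sum the partials across ranks and shard the result along the sequence
        # dimension (DifferentiableReduceScatterSum is the sketch from above).
        return DifferentiableReduceScatterSum.apply(out, group)
```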
Backward
def backward(ctx, grad_output)
What you do is (similar to the ColumnLinear forward; a condensed sketch follows these steps):
start the async gather: handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
meanwhile, compute the local part of grad_tensor with the current grad_output shard: torch.mm(grad_output, weight, out=same_device_shard_grad_tensor)
handle.wait()
compute the rest of grad_tensor using total_grad_output
compute the weight and bias gradients:
grad_weight = total_grad_output.t().matmul(tensor)
grad_bias = total_grad_output.sum(dim=0) if use_bias else None
return total_grad_tensor, grad_weight, grad_bias, None, None
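A condensed sketch of these steps as a standalone function (in nanotron this is the backward() of _RowLinearAsyncCommunication; the 2D shapes and the per-rank loop are assumptions of the sketch):

```python
import torch
import torch.distributed as dist

def row_linear_async_backward(grad_output, tensor, weight, group, use_bias):
    # grad_output arrives sharded along the sequence dimension (the forward
    # ended with a reduce_scatter), so gather it back asynchronously.
    rank, world_size = dist.get_rank(group), group.size()
    sharded_batch_size, out_features = grad_output.shape
    total_grad_output = torch.empty(
        sharded_batch_size * world_size, out_features,
        dtype=grad_output.dtype, device=grad_output.device,
    )
    handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)

    # Meanwhile, compute the part of grad_tensor we can already get from the
    # local grad_output shard, writing into this rank's slice of the result.
    total_grad_tensor = torch.empty(
        sharded_batch_size * world_size, weight.shape[1],
        dtype=grad_output.dtype, device=grad_output.device,
    )
    own = slice(rank * sharded_batch_size, (rank + 1) * sharded_batch_size)
    torch.mm(grad_output, weight, out=total_grad_tensor[own])

    handle.wait()
    # Finish grad_tensor with the slices gathered from the other ranks.
    for other in range(world_size):
        if other != rank:
            sl = slice(other * sharded_batch_size, (other + 1) * sharded_batch_size)
            torch.mm(total_grad_output[sl], weight, out=total_grad_tensor[sl])

    # Weight / bias gradients need the full grad_output.
    grad_weight = total_grad_output.t().matmul(tensor)
    grad_bias = total_grad_output.sum(dim=0) if use_bias else None
    return total_grad_tensor, grad_weight, grad_bias
```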