• Defines the distributed primitives used in Tensor Parallelism and Sequence Parallelism such that the forward pass and backward pass are appropriately defined

  • reduce_scatter + all_gather have the same communication cost as a single all_reduce, because a ring all_reduce is itself implemented as a reduce_scatter followed by an all_gather

    • we just need to split along the sequence dimension for the results to stay correct.

PyTorch specifics

torch.autograd.Function

  • To define a differentiable operator, you define a class that inherits from torch.autograd.Function; how this works is explained in the torch.autograd docs (a minimal sketch follows)
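
  • A minimal sketch of such a class (the op here, a simple scale, is purely illustrative and not from nanotron):

```python
import torch


class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save anything needed for the backward pass on ctx.
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # Must return one gradient per input of forward.
        (x,) = ctx.saved_tensors  # not needed for this op, shown for illustration
        return grad_output * 2


# Call through .apply so autograd records the custom backward.
y = ScaleByTwo.apply(torch.randn(3, requires_grad=True))
y.sum().backward()
```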

@assert_cuda_max_connections_set_to_1

  • function decorator that ensures CUDA_DEVICE_MAX_CONNECTIONS=1, i.e. only one connection (hardware work queue) can be made to a single GPU from any host thread, so kernels execute in the order they are launched
    • e.g. in ColumnLinearAsync, the backward pass with tp_mode=REDUCE_SCATTER asynchronously gathers the saved input shard while computing gradients, and relies on this launch ordering
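
  • A hedged sketch of what such a decorator might look like (the actual nanotron implementation may set the variable rather than assert it; this is an assumption):

```python
import functools
import os


def assert_cuda_max_connections_set_to_1(func):
    """Sketch: refuse to run unless CUDA_DEVICE_MAX_CONNECTIONS == 1."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # With a single hardware work queue, kernels run in launch order, so an
        # async communication launched first is scheduled before the compute
        # kernel that overlaps it.
        assert os.environ.get("CUDA_DEVICE_MAX_CONNECTIONS") == "1", (
            "CUDA_DEVICE_MAX_CONNECTIONS must be 1 for correct comm/compute overlap"
        )
        return func(*args, **kwargs)

    return wrapper
```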

Tensor Parallelism only

Differentiable Identity

  • Before a ColumnLinear operation (beginning of MLP)
  • Forward: f(x) = x i.e. all ColumnLinear shards receive the same input
  • Backward: b(x) = all_reduce_sum(x), because the same input was replicated over all processes
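
  • A minimal sketch of this op, assuming a tensor-parallel process group group (the class name is illustrative, not necessarily nanotron's):

```python
import torch
import torch.distributed as dist


class DifferentiableIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        # f(x) = x: every rank keeps its (replicated) input.
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # b(x) = all_reduce_sum(x): sum the gradients coming from every rank.
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad_output, None  # no gradient for the process group
```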

Differentiable_All_Reduce_Sum

  • After a RowLinear operation (end of MLP)
  • Forward: f(x) = all_reduce_sum(x) i.e. we need to sum the partial outputs from all ranks to get the correct result
  • Backward: b(x) = x i.e. the gradient must be replicated over all processes, as they all participated
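
  • A minimal sketch (same caveats as above; cloning before the in-place all_reduce is a choice of this sketch to avoid mutating the caller's tensor):

```python
import torch
import torch.distributed as dist


class DifferentiableAllReduceSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        # f(x) = all_reduce_sum(x): sum the partial outputs of all ranks.
        tensor = tensor.clone()
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        # b(x) = x: the gradient is simply passed through on every rank.
        return grad_output, None
```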

Tensor + Sequence Parallelism

Differentiable_All_Gather

  • Before a ColumnLinear operation (beginning of MLP)
  • Forward: f(x) = all_gather(x) i.e. the input was previously split along the sequence dimension, so we need to gather it back before feeding it to ColumnLinear
  • Backward: b(x) = reduce_scatter_sum(x), because the gathered tensor is replicated on all processes, we need to sum the gradients and then scatter, splitting along the sequence dimension.

How to implement it

  • nanotron code assumes the tensor is sharded along its first dimension
  • get the shard's shape: sharded_batch_size, *rest_size = tensor.shape
  • create an empty tensor of the unsharded size, i.e. unsharded_tensor = torch.empty((sharded_batch_size * group.size(), *rest_size), dtype=tensor.dtype, device=tensor.device)
  • call dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group)
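
  • Putting these steps together, a hedged sketch (the class name is illustrative; the backward mirrors the reduce_scatter described above):

```python
import torch
import torch.distributed as dist


class DifferentiableAllGather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        sharded_batch_size, *rest_size = tensor.shape
        unsharded_tensor = torch.empty(
            (sharded_batch_size * group.size(), *rest_size),
            dtype=tensor.dtype,
            device=tensor.device,
        )
        dist.all_gather_into_tensor(unsharded_tensor, tensor.contiguous(), group=group)
        return unsharded_tensor

    @staticmethod
    def backward(ctx, grad_output):
        # b(x) = reduce_scatter_sum(x): sum gradients, keep only our sequence shard.
        group = ctx.group
        unsharded_batch_size, *rest_size = grad_output.shape
        sharded_grad = torch.empty(
            (unsharded_batch_size // group.size(), *rest_size),
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        dist.reduce_scatter_tensor(
            sharded_grad, grad_output.contiguous(), op=dist.ReduceOp.SUM, group=group
        )
        return sharded_grad, None
```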

Differentiable_Reduce_Scatter

  • After a RowLinear operation (end of MLP)
  • Forward: f(x) = reduce_scatter_sum(x) i.e. we need to sum the partial outputs to get the correct result, and then scatter by splitting along the sequence dimension
  • Backward: b(x) = all_gather(x) i.e. the gradient must be gathered along the sequence dimension and then replicated over all processes, as they all participated
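
  • A minimal sketch, the mirror image of the all-gather above (class name illustrative, first dimension assumed to be the sequence dimension):

```python
import torch
import torch.distributed as dist


class DifferentiableReduceScatterSum(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        unsharded_batch_size, *rest_size = tensor.shape
        assert unsharded_batch_size % group.size() == 0
        sharded = torch.empty(
            (unsharded_batch_size // group.size(), *rest_size),
            dtype=tensor.dtype,
            device=tensor.device,
        )
        dist.reduce_scatter_tensor(
            sharded, tensor.contiguous(), op=dist.ReduceOp.SUM, group=group
        )
        return sharded

    @staticmethod
    def backward(ctx, grad_output):
        # b(x) = all_gather(x): rebuild the full-sequence gradient on every rank.
        group = ctx.group
        sharded_batch_size, *rest_size = grad_output.shape
        gathered = torch.empty(
            (sharded_batch_size * group.size(), *rest_size),
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        dist.all_gather_into_tensor(gathered, grad_output.contiguous(), group=group)
        return gathered, None
```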

Async Tensor + Sequence Parallelism

Motivation

  • If we’re smart, we can overlap communication and computation
    • e.g. In backward
      • e.g. for ColumnLinear, start the reduce_scatter of grad_tensor, while computing the gradient of the weight and bias i.e. grad_weight and grad_bias.
  • We rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the gather/reduce_scatter is scheduled before the tensor gradient computation in the code.

Details

  • RowLinear doesn’t support async if the tp_mode is all_reduce (tp only) instead of reduce_scatter (tp+sequence)
  • Must define _{Row,Column}LinearAsyncCommunication(torch.autograd.Function) classes

Code

ColumnLinear

Forward

  • In def forward(ctx, tensor, weight, bias, group, tp_mode)
    • tensor is sharded/split by sequence dimension
    • We define the full input to be gathered_tensor
  • What you can do is
    1. ctx.save_for_backward(tensor, weight) (only save the sharded tensor to reduce activation memory; we will gather it back at backward time)
    2. start async gather: handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)
    3. meanwhile, compute the part of the matmul that only needs the sharded tensor
      1. torch.mm(tensor, weight.t(), out=same_device_shard) i.e. fill the slice of the output that corresponds to the local shard
    4. handle.wait()
    5. compute the rest of the matmul with the rest of the gathered tensor
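
  • A hedged sketch of this forward as a standalone function (shapes flattened to 2-D, weight stored as in F.linear i.e. (out_features, in_features); the function name and slicing are assumptions, not nanotron's exact code):

```python
import torch
import torch.distributed as dist


def column_linear_sp_forward(tensor, weight, bias, group):
    # tensor: (n_local, in) sequence-sharded input; weight: (out, in)
    n_local, in_features = tensor.shape
    world_size = group.size()
    rank = dist.get_rank(group)

    gathered_tensor = torch.empty(
        (n_local * world_size, in_features), dtype=tensor.dtype, device=tensor.device
    )
    # 2. start the gather asynchronously
    handle = dist.all_gather_into_tensor(
        gathered_tensor, tensor.contiguous(), group=group, async_op=True
    )

    # 3. meanwhile, compute the output rows that only need the local shard
    out = torch.empty(
        (n_local * world_size, weight.shape[0]), dtype=tensor.dtype, device=tensor.device
    )
    start, end = rank * n_local, (rank + 1) * n_local
    torch.mm(tensor, weight.t(), out=out[start:end])

    # 4./5. wait for the gather, then fill in the remaining rows
    handle.wait()
    if start > 0:
        torch.mm(gathered_tensor[:start], weight.t(), out=out[:start])
    if end < out.shape[0]:
        torch.mm(gathered_tensor[end:], weight.t(), out=out[end:])

    return out + bias if bias is not None else out
```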

Backward

  • def backward(ctx, grad_output)
  • What you can do is
    1. tensor, weight = ctx.saved_tensors
    2. async gather into gathered_tensor
      1. handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)
    3. Compute grad_tensor = grad_output.matmul(weight)
    4. handle.wait()
    5. async reduce_scatter_sum the grad_tensor
      1. handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
    6. meanwhile, compute grad_weight and grad_bias
      1. grad_weight = grad_output.t().matmul(gathered_tensor)
      2. grad_bias = grad_output.sum(dim=0) if use_bias else None
    7. handle.wait()
    8. return sub_grad_tensor, grad_weight, grad_bias, None, None
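
  • A hedged sketch of this backward as a standalone function (same assumptions as the forward sketch above):

```python
import torch
import torch.distributed as dist


def column_linear_sp_backward(grad_output, tensor, weight, group, use_bias):
    # tensor: (n_local, in) saved sharded input; grad_output: (n_local * world, out)
    n_local, in_features = tensor.shape
    world_size = group.size()

    # 2. re-gather the sharded input asynchronously
    gathered_tensor = torch.empty(
        (n_local * world_size, in_features), dtype=tensor.dtype, device=tensor.device
    )
    handle = dist.all_gather_into_tensor(
        gathered_tensor, tensor.contiguous(), group=group, async_op=True
    )

    # 3. meanwhile, compute the input gradient (does not need gathered_tensor)
    grad_tensor = grad_output.matmul(weight)            # (n_local * world, in)

    handle.wait()

    # 5. asynchronously reduce-scatter the input gradient back to sequence shards
    sub_grad_tensor = torch.empty(
        (n_local, in_features), dtype=grad_tensor.dtype, device=grad_tensor.device
    )
    handle = dist.reduce_scatter_tensor(
        sub_grad_tensor, grad_tensor, op=dist.ReduceOp.SUM, group=group, async_op=True
    )

    # 6. meanwhile, compute the weight and bias gradients (need the full input)
    grad_weight = grad_output.t().matmul(gathered_tensor)          # (out, in)
    grad_bias = grad_output.sum(dim=0) if use_bias else None

    handle.wait()
    return sub_grad_tensor, grad_weight, grad_bias
```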

RowLinear

Forward

  • def forward(ctx, tensor, weight, bias, group, tp_mode):
    • nothing much tricky going on
  • What you do is
    1. ctx.save_for_backward(tensor, weight)
    2. out = F.linear(tensor, weight, bias)
    3. return differentiable_reduce_scatter_sum(out, group)
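
  • A minimal sketch of this forward (the reduce-scatter is shown directly via dist here instead of differentiable_reduce_scatter_sum; bias handling follows the notes and assumes at most one rank actually holds a non-None bias):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def row_linear_sp_forward(tensor, weight, bias, group):
    # tensor: (n, in_local) full-sequence input; weight: (out, in_local)
    out = F.linear(tensor, weight, bias)                # (n, out)
    n, *rest = out.shape
    sharded = torch.empty(
        (n // group.size(), *rest), dtype=out.dtype, device=out.device
    )
    # sum the partial outputs and keep only our sequence shard
    dist.reduce_scatter_tensor(sharded, out.contiguous(), op=dist.ReduceOp.SUM, group=group)
    return sharded
```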

Backward

  • def backward(ctx, grad_output)
  • What you do is (similar to ColumnLinear forward)
    1. start async gather handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True)
    2. meanwhile, compute the local portion of grad_tensor from the current grad_output shard
      • torch.mm(grad_output, weight, out=same_device_shard_grad_tensor)
    3. handle.wait()
    4. Compute the rest of the grad_tensor using total_grad_output
    5. compute the weight and bias gradients
      1. grad_weight = total_grad_output.t().matmul(tensor)
      2. grad_bias = total_grad_output.sum(dim=0) if use_bias else None
    6. return total_grad_tensor, grad_weight, grad_bias, None, None
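
  • A hedged sketch of this backward as a standalone function (shapes flattened to 2-D; names and slicing are assumptions, not nanotron's exact code):

```python
import torch
import torch.distributed as dist


def row_linear_sp_backward(grad_output, tensor, weight, group, use_bias):
    # grad_output: (n_local, out) sequence-sharded; tensor: (n, in_local) saved
    # full-sequence input; weight: (out, in_local)
    n_local, out_features = grad_output.shape
    world_size = group.size()
    rank = dist.get_rank(group)

    total_grad_output = torch.empty(
        (n_local * world_size, out_features),
        dtype=grad_output.dtype, device=grad_output.device,
    )
    # 1. start the gather of grad_output asynchronously
    handle = dist.all_gather_into_tensor(
        total_grad_output, grad_output.contiguous(), group=group, async_op=True
    )

    # 2. meanwhile, compute the rows of grad_tensor that only need the local shard
    total_grad_tensor = torch.empty(
        (n_local * world_size, weight.shape[1]),
        dtype=grad_output.dtype, device=grad_output.device,
    )
    start, end = rank * n_local, (rank + 1) * n_local
    torch.mm(grad_output, weight, out=total_grad_tensor[start:end])

    # 3./4. wait, then fill in the remaining rows from the gathered grad_output
    handle.wait()
    if start > 0:
        torch.mm(total_grad_output[:start], weight, out=total_grad_tensor[:start])
    if end < total_grad_tensor.shape[0]:
        torch.mm(total_grad_output[end:], weight, out=total_grad_tensor[end:])

    # 5. weight and bias gradients need the full grad_output
    grad_weight = total_grad_output.t().matmul(tensor)  # (out, in_local)
    grad_bias = total_grad_output.sum(dim=0) if use_bias else None
    return total_grad_tensor, grad_weight, grad_bias
```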