- Tensor parallelism parallelizes the parts of the transformer layer that take the most time during training (attention and MLP), and as a result it is computationally efficient.
- However, it leaves the layer-norms, as well as the dropouts after the attention and MLP blocks, intact, and as a result these operations are replicated across the tensor-parallel group.
- These elements do not require a lot of compute but demand a considerable amount of activation memory.
- Quantitatively, part of the activation memory in the above equation is due to these replicated operations, and as a result it is not divided by the tensor parallel size t (see the formula sketch after this list).
- The operations are independent along the sequence dimension.
- This characteristic allows us to partition these regions along the sequence dimension s.
- Partitioning along the sequence dimension reduces the memory required for the activations (a small check is sketched after this list).
- This extra level of parallelism introduces new communication collectives, g and ḡ, which act as converters between the sequence-parallel and tensor-parallel regions.
- If you are careful about how the communication primitives are written, the communication bandwidth used for tensor parallelism and for tensor together with sequence parallelism is the same.
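
For concreteness, here is a rough sketch of the accounting, following the Megatron sequence-parallelism paper that these notes summarize; the symbols are assumptions about the notation used in the equation referenced above (micro-batch size b, sequence length s, hidden size h, number of attention heads a, tensor-parallel size t):

$$
sbh\left(10 + \frac{24}{t} + \frac{5as}{ht}\right)
\;\xrightarrow{\;+\ \text{sequence parallelism}\;}\;
\frac{sbh}{t}\left(34 + \frac{5as}{h}\right)
$$

The 10sbh term on the left comes from the replicated layer-norms and dropouts and is not divided by t; sequence parallelism partitions exactly those activations, so every term ends up divided by t.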
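The claim that these operations are independent along the sequence dimension can be checked with a tiny self-contained illustration (plain PyTorch, no distributed setup, not Megatron-LM code), assuming activations shaped [sequence, batch, hidden]:

```python
import torch

torch.manual_seed(0)
s, b, h, t = 8, 2, 16, 4                    # t = tensor/sequence parallel size
x = torch.randn(s, b, h)
ln = torch.nn.LayerNorm(h)                  # normalizes over the hidden dim only

full = ln(x)                                # LayerNorm on the full sequence
shards = [ln(chunk) for chunk in x.chunk(t, dim=0)]   # same op on sequence shards

# Each shard's result matches the corresponding slice of the full result,
# i.e. no cross-sequence communication is needed for this op.
# (Dropout is element-wise given its mask, so it is likewise independent.)
assert torch.allclose(torch.cat(shards, dim=0), full, atol=1e-6)
print("LayerNorm is independent along the sequence dimension")
```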
Combining smartly with Tensor-Parallelism
- To avoid extra communication, we fuse the new sequence-parallel conversions with the communication primitives already used in tensor parallelism and use g and ḡ instead.
- In an MLP block, the input X is split along the sequence dimension so that the LayerNorm can be computed locally. The output of the LayerNorm is then gathered along the sequence dimension, and the subsequent computation is split along the hidden dimension.
- g is an all-gather operation along the sequence dimension in the forward pass and a reduce-scatter in the backward pass.
- ḡ is a reduce-scatter in the forward pass and an all-gather in the backward pass (a minimal sketch of both is given after this list).
- By splitting A along its columns (A1 and A2) and B along its rows (B1 and B2), we avoid communication (for more details please see [19]) and arrive at the two partial outputs W1 and W2. These two tensors are no longer parallel and need to be summed as W = W1 + W2 before they are fed into the dropout layer. However, dropout needs its input to be parallel in the sequence dimension s.
- Instead of summing W1 and W2 and then re-partitioning along the sequence dimension, we combine these two operations into a single reduce-scatter. As a result, ḡ can be implemented as a single reduce-scatter operation in the forward pass (see the MLP sketch after this list).
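
To make g and ḡ concrete, here is a minimal sketch of how they can be expressed as PyTorch autograd functions. This is an illustration under assumptions, not Megatron-LM's implementation: it assumes activations laid out as [sequence, batch, hidden] and that the default process group is the tensor-parallel group.

```python
import torch
import torch.distributed as dist


class Gather(torch.autograd.Function):
    """g: all-gather along the sequence dim in forward, reduce-scatter in backward."""

    @staticmethod
    def forward(ctx, x):                      # x: [s/t, b, h] sequence shard
        t = dist.get_world_size()
        out = torch.empty(t * x.shape[0], *x.shape[1:],
                          dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(out, x.contiguous())
        return out                            # [s, b, h]

    @staticmethod
    def backward(ctx, grad):                  # grad: [s, b, h]
        t = dist.get_world_size()
        out = torch.empty(grad.shape[0] // t, *grad.shape[1:],
                          dtype=grad.dtype, device=grad.device)
        dist.reduce_scatter_tensor(out, grad.contiguous())
        return out                            # [s/t, b, h]


class ReduceScatter(torch.autograd.Function):
    """ḡ: reduce-scatter along the sequence dim in forward, all-gather in backward."""

    @staticmethod
    def forward(ctx, x):                      # x: [s, b, h] partial sums
        t = dist.get_world_size()
        out = torch.empty(x.shape[0] // t, *x.shape[1:],
                          dtype=x.dtype, device=x.device)
        dist.reduce_scatter_tensor(out, x.contiguous())
        return out                            # [s/t, b, h], fully summed

    @staticmethod
    def backward(ctx, grad):                  # grad: [s/t, b, h]
        t = dist.get_world_size()
        out = torch.empty(t * grad.shape[0], *grad.shape[1:],
                          dtype=grad.dtype, device=grad.device)
        dist.all_gather_into_tensor(out, grad.contiguous())
        return out                            # [s, b, h]
```

In a layer one would call Gather.apply(...) right after the LayerNorm and ReduceScatter.apply(...) right before the dropout.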
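And here is a forward-only sketch of the whole sequence-parallel MLP path described above, where the final reduce-scatter plays the role of ḡ: the sum of the partial outputs and the split along the sequence dimension happen in one collective. This is again a hedged illustration, not Megatron-LM's API; the function name, the 4h intermediate size, and the use of the default process group as the tensor-parallel group are assumptions.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def sequence_parallel_mlp(x_local, ln, A_col, B_row, p_drop=0.1):
    """x_local: [s/t, b, h] sequence shard; A_col: [h, 4h/t] column shard of A;
    B_row: [4h/t, h] row shard of B."""
    t = dist.get_world_size()

    y_local = ln(x_local)                          # LayerNorm on the local sequence shard

    # g: all-gather along the sequence dimension -> full [s, b, h]
    y = torch.empty(t * y_local.shape[0], *y_local.shape[1:],
                    dtype=y_local.dtype, device=y_local.device)
    dist.all_gather_into_tensor(y, y_local.contiguous())

    # Column-parallel GEMM with A_col + GeLU, then row-parallel GEMM with B_row.
    # w_partial is this rank's partial sum of the MLP output: [s, b, h].
    w_partial = F.gelu(y @ A_col) @ B_row

    # ḡ: a single reduce-scatter both sums the partial outputs across ranks
    # and splits the result back along the sequence dimension -> [s/t, b, h].
    out = torch.empty_like(x_local)
    dist.reduce_scatter_tensor(out, w_partial.contiguous())

    return F.dropout(out, p=p_drop, training=True)  # dropout on the sequence shard
```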
Communication costs
- Tensor parallelism requires four all-reduces in a single forward and backward pass whereas tensor together with sequence parallelism requires four all-gathers and four reduce-scatters in a single forward and backward pass.
- At first glance, it seems that tensor with sequence parallelism requires more communication than tensor parallelism alone.
- However, we note that a ring all-reduce is composed of two steps: a reduce-scatter followed by an all-gather. As a result, the communication bandwidth used for tensor parallelism and for tensor together with sequence parallelism is the same. Therefore, sequence parallelism does not introduce any communication overhead.
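
To make the equivalence concrete, under the standard ring-collective cost model each of the t ranks communicates roughly the following volume for a tensor of M elements:

$$
V_{\text{all-reduce}} \approx 2M\,\frac{t-1}{t}
\;=\; \underbrace{M\,\frac{t-1}{t}}_{\text{reduce-scatter}} + \underbrace{M\,\frac{t-1}{t}}_{\text{all-gather}}
$$

so four all-reduces (tensor parallelism) and four reduce-scatters plus four all-gathers (tensor + sequence parallelism) move the same number of bytes per rank.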