
ZeRO sits at the Data Parallelism Layer
 It is a super-scalable extension of DP.

ZeRO removes the memory redundancies across data-parallel processes by partitioning the model states—parameters, gradients, and optimizer states—across the data-parallel processes instead of replicating them.
Setup
 How memory is organized within a GPU in vanilla DDP
 Let $Ψ$ be the number of parameters of the model.
 "Naive" memory consumption of mixed-precision training:
 fp16 copies of the parameters and the gradients: $2Ψ + 2Ψ$ bytes
 Optimizer states: $KΨ$ bytes
 $K=12$ for Adam, as we need to hold fp32 copies of the parameters, momentum, and variance
 Total: $2Ψ + 2Ψ + KΨ = 16Ψ$ bytes (for $K=12$)
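The accounting above can be sketched as a quick calculator (a hypothetical helper, not part of any library):

```python
def naive_mixed_precision_bytes(psi: float, k: int = 12) -> float:
    """Per-GPU model-state memory for vanilla mixed-precision DP.

    fp16 parameters (2 bytes/param) + fp16 gradients (2 bytes/param)
    + optimizer states (K bytes/param; K=12 for Adam: fp32 copies of
    parameters, momentum, and variance, 4 bytes each).
    """
    return (2 + 2 + k) * psi

# A 7.5B-parameter model already needs ~120 GB per GPU for model states alone.
print(naive_mixed_precision_bytes(7.5e9) / 1e9)  # -> 120.0
```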
ZeRO++
 ZeRO++ reduces communication volume by 4x compared to ZeRO3.
 How?
 Quantized weights (qwZ): reduces all-gather parameter communication volume by half by quantizing model weights from fp16 to int8.
 Hierarchical Partitioning (hpZ): a hybrid partitioning scheme for multi-node settings with DeepSpeed ZeRO 3. Model parameters are sharded within a node and replicated across nodes. This gives less memory savings than classic ZeRO3 sharding across the full setup, but avoids expensive inter-node parameter communication, improving throughput in general. Related to hybrid sharding in FSDP (Fully Sharded Data Parallel).
 Quantized gradients (qgZ): enables even more savings in communication volume by replacing fp16 with int4 quantized data during gradient reduce-scatter ops (recall: this is the gradient gather + averaging step in ZeRO 2/3 with sharded gradients).
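The weight-quantization idea behind qwZ can be illustrated with a simple block-wise symmetric int8 scheme (a simplified numpy sketch; the actual ZeRO++ implementation uses fused CUDA kernels, and block sizes here are illustrative):

```python
import numpy as np

def quantize_int8_blockwise(w: np.ndarray, block: int = 256):
    """Symmetric per-block int8 quantization: each block keeps its own
    fp32 scale, which bounds the quantization error per block."""
    blocks = w.reshape(-1, block)
    scale = np.abs(blocks).max(axis=1, keepdims=True) / 127.0
    q = np.round(blocks / scale).astype(np.int8)  # int8 payload: half of fp16
    return q, scale

def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return (q.astype(np.float32) * scale).ravel()

w = np.random.randn(4096).astype(np.float32)
q, scale = quantize_int8_blockwise(w)
max_err = np.abs(dequantize(q, scale) - w).max()  # bounded by half a step per block
```

Keeping a scale per block (rather than per tensor) is what limits the error introduced before the all-gather.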
Partitioning (ZeRO-1, 2, 3)
ZeRO-DP (optimizer states, gradients, parameters)
Let the DP degree be $N_{d}$.
$P_{os}$: Optimizer State Partitioning
 For a DP degree of $N_{d}$, we group the optimizer states into $N_{d}$ equal partitions, such that the $i$th data parallel process only updates the optimizer states corresponding to the $i$th partition.
 Thus, each data parallel process only needs to store and update $1/N_{d}$ of the total optimizer states, and then only updates $1/N_{d}$ of the parameters (an all-gather at the end of each step gives every process the full set of updated parameters)
 Memory requirements: $4Ψ + \frac{KΨ}{N_{d}}$. If $N_{d}$ is large, we get approx. 4x reduction
$P_{os+g}$: Optimizer State and Gradient Partitioning
 For a DP degree of $N_{d}$, we group the optimizer states and the gradients into $N_{d}$ equal partitions, such that the $i$th data parallel process only updates the optimizer states and gradients corresponding to the $i$th partition.
 As each gradient of each layer becomes available during the backward propagation, we only reduce them on the data parallel process responsible for updating the corresponding parameters.
 After the reduction, we no longer need the gradients, and their memory can be released on all processes except the relevant one
 Memory requirements: $2Ψ + \frac{(2+K)Ψ}{N_{d}}$. If $N_{d}$ is large, approx. 8x reduction in memory needs.
$P_{os+g+p}$: Parameter + Optimizer State + Gradient Partitioning
 Just as with the optimizer states and the gradients, each process only stores the parameters corresponding to its partition.
 When the parameters outside of its partition are required for forward and backward propagation, they are received from the appropriate data parallel process through broadcast.
 The approach only increases the total communication volume to 1.5x that of a baseline DP system.
 Memory requirements: $\frac{(2+2+K)Ψ}{N_{d}} = \frac{16Ψ}{N_{d}}$. Can fit any model as long as there is a sufficient number of devices to share the model states; reduces consumption by a factor of $N_{d}$.
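Putting the three stages together, the per-GPU memory formulas above can be captured in one hypothetical helper:

```python
def zero_memory_per_gpu_gb(psi: float, nd: int, k: int = 12, stage: int = 0) -> float:
    """Per-GPU model-state memory (GB) under ZeRO, following the formulas above."""
    params, grads, opt = 2 * psi, 2 * psi, k * psi  # bytes, mixed precision
    if stage >= 1:   # P_os: partition optimizer states
        opt /= nd
    if stage >= 2:   # P_os+g: also partition gradients
        grads /= nd
    if stage >= 3:   # P_os+g+p: also partition parameters
        params /= nd
    return (params + grads + opt) / 1e9

# 7.5B parameters, N_d = 64 (the example used in the ZeRO paper):
for s in range(4):
    print(s, round(zero_memory_per_gpu_gb(7.5e9, 64, stage=s), 1))
# 0: 120.0, 1: 31.4, 2: 16.6, 3: 1.9
```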
Communication volume analysis
 If two communications of volume $k$ happen in parallel, we only count a volume of $k$
 What matters is sequential communications, as those will impact performance
All-reduce communication volume for gradients in typical DP = $2Ψ$
 State-of-the-art implementations of all-reduce use a two-step approach: reduce-scatter + all-gather
 reduce-scatter and all-gather each require $Ψ$ communication volume ⇒ $2Ψ$ communication volume for a single process (the chunks overlap in a ring, so runtime stays roughly constant as more nodes are added)
 Sending everything to a root node, reducing there, and sending the result back is also $2Ψ$ communication volume
 Total volume = $2Ψ$
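As a sanity check, the per-process send volume of a ring-based reduce-scatter + all-gather can be computed directly (hypothetical helper):

```python
def ring_allreduce_send_volume(psi: float, nd: int) -> float:
    """Per-process send volume (in elements) of a ring all-reduce.

    Reduce-scatter: each process forwards nd-1 chunks of size psi/nd.
    All-gather: the same again. Total = 2*psi*(nd-1)/nd, approx. 2*psi.
    """
    chunk = psi / nd
    return 2 * (nd - 1) * chunk

print(ring_allreduce_send_volume(1e9, 64) / 1e9)  # -> 1.96875, i.e. approx. 2Ψ
```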
Communication Volume with $P_{os+g}=2Ψ$
 With gradient partitioning, each process only stores the portion of the gradients that is required to update its corresponding parameter partition.
 During the backward pass
 Each process sends a gradient chunk of volume $Ψ/N_{d}$ to the process responsible for the corresponding partition (a reduce-scatter). Each process does so $N_{d}$ times over the course of the backward pass ⇒ $Ψ$ volume
 Once the parameters are updated on each process
 Each process sends its updated parameter partition of volume $Ψ/N_{d}$ to the other processes (an all-gather) ⇒ $Ψ$ volume
 Total volume = $2Ψ$
Communication Volume with $P_{os+g+p}=3Ψ$

After parameter partitioning, each data parallel process only stores the parameters that it updates.

Therefore, during the forward and backward propagation it needs to receive the parameters of all the other partitions.

During the forward
 The process responsible for a partition broadcasts its parameters of volume $Ψ/N_{d}$ to all other processes, when it's time to process its layers.
 This happens $N_{d}$ times.
 Volume = $Ψ$

The broadcast is repeated during the backward pass (parameters are discarded after use and re-fetched) ⇒ another $Ψ$, and gradients are reduce-scattered as in $P_{os+g}$ ⇒ $Ψ$.

Total volume = $Ψ + Ψ + Ψ = 3Ψ$, i.e. 1.5x the baseline $2Ψ$
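Summarizing the analysis, the per-step per-process communication volume by stage (a hypothetical summary function; stages 1 and 2 match the baseline):

```python
def zero_comm_volume(psi: float, stage: int) -> float:
    """Total per-process communication volume per training step, in elements."""
    if stage <= 2:
        # Baseline DP and ZeRO-1/2: gradient reduce-scatter (Ψ)
        # + all-gather of gradients or updated parameters (Ψ).
        return 2 * psi
    # ZeRO-3: parameter gather in forward (Ψ) and backward (Ψ)
    # + gradient reduce-scatter (Ψ).
    return 3 * psi

assert zero_comm_volume(1.0, 2) == 2.0
assert zero_comm_volume(1.0, 3) == 3.0  # 1.5x the baseline
```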