The main loop can be summarized as: send the previous payload (initially the rank's own tensor) to the right neighbor, receive a new payload from the left neighbor, and accumulate the newly received payload.
To store the previous payload while receiving the new one, the code alternates between two buffers, first_buff and second_buff.
```python
# send is the tensor to be reduced
# recv is just where to store the final result of the accumulation
def allreduce(send, recv):
    rank = dist.get_rank()
    size = dist.get_world_size()
    first_buff = send.clone()
    second_buff = send.clone()
    accum = send.clone()

    left = ((rank - 1) + size) % size
    right = (rank + 1) % size

    for i in range(size - 1):
        if i % 2 == 0:
            # Send first_buff and receive second_buff
            send_req = dist.isend(first_buff, right)
            dist.recv(second_buff, left)
            accum[:] += second_buff[:]
        else:
            # Send second_buff and receive first_buff
            send_req = dist.isend(second_buff, right)
            dist.recv(first_buff, left)
            accum[:] += first_buff[:]
        send_req.wait()
    recv[:] = accum[:]
```
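The snippet above defines only the reduction itself. A minimal driver might look like the sketch below; the init_process helper, the gloo backend, and the address/port values are assumptions taken from the usual torch.distributed boilerplate rather than from these notes.

```python
"""Hypothetical driver for the ring allreduce above (assumes the gloo backend)."""
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank, size, fn, backend='gloo'):
    # Rendezvous information for the process group; values are illustrative.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

def run_allreduce(rank, size):
    send = torch.ones(4) * rank   # each rank contributes a different tensor
    recv = torch.zeros(4)
    allreduce(send, recv)         # the function defined above
    print('Rank', rank, 'result', recv)

if __name__ == "__main__":
    size = 2
    mp.set_start_method("spawn")
    processes = []
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run_allreduce))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
```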
Collective Communication
As opposed to point-to-point communication, collectives allow for communication patterns across all processes in a group.
A group is a subset of all our processes
To create a group, we can pass a list of ranks to dist.new_group(group)
By default, collectives are executed on all processes, also known as the world.
By default, all collectives are blocking
The function will not return until the operation has been completed across all the participating processes in the specified group
Can be made asynchronous by passing async_op=True; the call then returns a work handle to wait on
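For example, a sketch of an asynchronous all-reduce (assuming the process group has already been initialized) holds on to the returned work handle and waits on it before reading the result:

```python
"""Async collective sketch; assumes dist.init_process_group has already run."""
import torch
import torch.distributed as dist

def async_allreduce_example():
    tensor = torch.ones(1)
    # async_op=True returns a work handle instead of blocking.
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    # ... unrelated computation could overlap with the communication here ...
    work.wait()  # the result in `tensor` is only guaranteed after wait()
    print('Rank', dist.get_rank(), 'has data', tensor[0])
```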
All-Reduce Example
In order to obtain the sum of all tensors across all processes, we can use the dist.all_reduce(tensor, op, group) collective
""" All-Reduce example."""def run(rank, size): """ Simple collective communication. """ group = dist.new_group([0, 1]) tensor = torch.ones(1) dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group) print('Rank ', rank, ' has data ', tensor[0])
All collectives
dist.broadcast(tensor, src, group): Copies tensor from src to all other processes.
dist.reduce(tensor, dst, op, group): Applies op to every tensor and stores the result in dst.
dist.all_reduce(tensor, op, group): Same as reduce, but the result is stored in all processes.
dist.scatter(tensor, scatter_list, src, group): Copies the ith tensor scatter_list[i] to the ith process.
dist.reduce_scatter_tensor(output, input, op, group): Reduces input across all processes, then scatters the result; input is the tensor to be reduced and scattered, its size should be the output tensor size times the group size, and it can be either a concatenation or a stack of the per-process output tensors along the primary dimension.
dist.gather(tensor, gather_list, dst, group): Copies tensor from all processes into gather_list on dst.
dist.all_gather(tensor_list, tensor, group): Copies tensor from all processes to tensor_list, on all processes.
An equivalent variant is dist.all_gather_into_tensor(output_tensor, input_tensor, group), shown in the sketch after this list.
output_tensor must be correctly sized, i.e. either:
a concatenation of all the input tensors along the primary dimension
OR a stack of all the input tensors along the primary dimension
dist.barrier(group): Blocks all processes in group until each one has entered this function.
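As a sketch of the two all-gather variants listed above (assuming an already-initialized process group and using the default world group), note that the output must be pre-allocated in both cases:

```python
"""Sketch of dist.all_gather vs dist.all_gather_into_tensor (assumes an initialized group)."""
import torch
import torch.distributed as dist

def all_gather_example():
    size = dist.get_world_size()
    rank = dist.get_rank()
    tensor = torch.ones(2) * rank          # each rank contributes a length-2 tensor

    # Variant 1: gather into a list of pre-allocated tensors, one per rank.
    tensor_list = [torch.zeros(2) for _ in range(size)]
    dist.all_gather(tensor_list, tensor)
    # tensor_list[i] now holds rank i's tensor on every process.

    # Variant 2: gather into a single pre-allocated tensor whose size is
    # the input size times the group size (a concatenation along dim 0).
    output_tensor = torch.zeros(2 * size)
    dist.all_gather_into_tensor(output_tensor, tensor)
    # output_tensor now holds rank 0's tensor, rank 1's tensor, ... concatenated.
```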
Point-to-Point Communication
A transfer of data from one process to another is called point-to-point communication. It is achieved through the send and recv functions or their immediate counterparts, isend and irecv.
Example
Two processes communicate a tensor
Both processes start with a zero tensor, then process 0 increments the tensor and sends it to process 1 so that they both end up with 1.0. Notice that process 1 needs to allocate memory in order to store the data it will receive.
Blocking point-to-point communication
"""Blocking point-to-point communication."""def run(rank, size): tensor = torch.zeros(1) if rank == 0: tensor += 1 # Send the tensor to process 1 dist.send(tensor=tensor, dst=1) else: # Receive tensor from process 0 dist.recv(tensor=tensor, src=0) print('Rank ', rank, ' has data ', tensor[0])
Non-Blocking point-to-point communication
"""Non-blocking point-to-point communication."""def run(rank, size): tensor = torch.zeros(1) req = None if rank == 0: tensor += 1 # Send the tensor to process 1 req = dist.isend(tensor=tensor, dst=1) print('Rank 0 started sending') else: # Receive tensor from process 0 req = dist.irecv(tensor=tensor, src=0) print('Rank 1 started receiving') req.wait() print('Rank ', rank, ' has data ', tensor[0])
When using immediates we have to be careful about how we use the sent and received tensors. Since we do not know when the data will be communicated to the other process, we should not modify the sent tensor nor access the received tensor before req.wait() has completed.
In other words:
writing to tensor after dist.isend() will result in undefined behaviour.
reading from tensor after dist.irecv() will result in undefined behaviour.
However, after req.wait() has been executed we are guaranteed that the communication took place, and that the value stored in tensor[0] is 1.0.
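As a brief sketch of these rules (reusing the two-process setup and imports from the examples above), the commented-out lines mark accesses that would be unsafe before the handle completes:

```python
"""Sketch of safe vs. unsafe tensor use around isend/irecv (two-process setup assumed)."""
def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        req = dist.isend(tensor=tensor, dst=1)
        # tensor += 1        # UNSAFE: writing to a tensor that is still being sent
        req.wait()
        tensor += 1           # safe: after wait() the send buffer may be reused
    else:
        req = dist.irecv(tensor=tensor, src=0)
        # value = tensor[0]  # UNSAFE: reading before the receive has completed
        req.wait()
        value = tensor[0]     # safe: after wait() the value 1.0 has arrived
```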