To define a differentiable primitive, one must define a class inheriting from torch.autograd.Function and implement forward(ctx, ...) and backward(ctx, grad_output) as @staticmethods that take ctx as their first argument
ctx is a context object (an instance of torch.autograd.function.FunctionCtx)
It is used to store information during the forward pass that the backward pass needs: anything stashed on ctx in forward is available again in backward
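A minimal sketch of the pattern, assuming a toy squaring op (the name SquareFn is made up purely to illustrate how ctx is used):

import torch

class SquareFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Stash the input on ctx: it is needed to compute the gradient later.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve what was saved during the forward pass.
        (x,) = ctx.saved_tensors
        # d(x^2)/dx = 2x, chained with the incoming gradient.
        return grad_output * 2 * x

x = torch.randn(3, requires_grad=True)
y = SquareFn.apply(x)   # custom Functions are invoked through .apply, not instantiated
y.sum().backward()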
Practical example with a linear layer
Let our linear layer be Y=XW+b
The forward signature is forward(ctx, tensor, weight, bias, group, tp_mode)
For a linear layer, we save the input activations and the weight, ctx.save_for_backward(tensor, weight), to be able to compute their gradients in the backward pass
At backward time, we must return one gradient per argument of the forward signature (excluding ctx), i.e. grad_tensor, grad_weight, grad_bias, None, None, with None for the non-tensor arguments group and tp_mode
The backward signature is backward(ctx, grad_output)
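A sketch of the whole pattern under these assumptions: the class name _LinearFn is hypothetical, and group and tp_mode are accepted only to match the signature above but ignored here (a real tensor-parallel implementation would use them for cross-rank communication):

import torch

class _LinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, weight, bias, group, tp_mode):
        # Save the activations and weight needed to compute gradients in backward.
        ctx.save_for_backward(tensor, weight)
        # Y = X W + b
        out = tensor @ weight
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_output):
        tensor, weight = ctx.saved_tensors
        # Chain rule for Y = X W + b:
        #   dL/dX = dL/dY @ W^T
        #   dL/dW = X^T @ dL/dY
        #   dL/db = sum of dL/dY over the batch dimension
        grad_tensor = grad_output @ weight.t()
        grad_weight = tensor.t() @ grad_output
        grad_bias = grad_output.sum(dim=0)
        # One gradient per forward argument; non-differentiable args get None.
        return grad_tensor, grad_weight, grad_bias, None, None

Usage sketch (group and tp_mode passed as None since this toy version ignores them):

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 16, requires_grad=True)
b = torch.zeros(16, requires_grad=True)
y = _LinearFn.apply(x, w, b, None, None)
y.sum().backward()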