torch.autograd.Function

How it works

  • To define a differentiable primitive, define a class inheriting from torch.autograd.Function and implement forward(ctx, ...) and backward(ctx, grad_output) as @staticmethods that take ctx as their first argument (a minimal sketch follows this list)
    • ctx is an instance of torch.autograd.function.FunctionCtx
      • It is used to store information that you need in the backward pass.
      • Data stored in ctx during the forward pass is available in the backward pass
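A minimal sketch of the pattern, using a toy squaring op (the name Square is ours, for illustration) so that the tensor saved on ctx is genuinely needed in the backward pass:

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Stash tensors that backward will need.
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve what forward saved.
        (input,) = ctx.saved_tensors
        # d(x^2)/dx = 2x, chained with the incoming gradient.
        return 2 * input * grad_output

# Custom Functions are invoked through .apply(), not by instantiation.
x = torch.randn(3, requires_grad=True)
y = Square.apply(x)
y.sum().backward()
print(x.grad)  # == 2 * x
```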

Practical example with a linear layer

  • Let our linear layer be out = tensor.matmul(weight.t()) + bias, with weight of shape [out_features, in_features] (the torch.nn.functional.linear convention)
  • The forward signature is forward(ctx, tensor, weight, bias, group, tp_mode)
    • For a linear layer, we save the input activations and the weight, ctx.save_for_backward(tensor_input, weight), to be able to compute the gradients in the backward pass
  • At backward time, we must return one gradient per argument of the forward signature, i.e. grad_tensor, grad_weight, grad_bias, None, None, where None is returned for the non-tensor arguments group and tp_mode (see the full sketch after this list)
  • The backward signature is backward(ctx, grad_output)
    • In our case, grad_output is the gradient of the loss with respect to the layer's output, flowing back from the rest of the network
  • Then:
    • grad_tensor=grad_output.matmul(weight)
    • grad_weight=grad_output.t().matmul(tensor_input) (transposed so the result matches the shape of weight)
    • grad_bias=grad_output.sum(0) (summed over the batch dimension, across which bias was broadcast in the forward pass)
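Putting this together, a runnable sketch of the linear layer above. The group and tp_mode arguments appear in the signature quoted earlier but control tensor-parallel behavior not covered here, so this sketch accepts and ignores them; torch.autograd.gradcheck verifies the hand-written backward against finite differences:

```python
import torch

class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, weight, bias, group, tp_mode):
        # group/tp_mode belong to the tensor-parallel variant; ignored here.
        ctx.save_for_backward(tensor, weight)
        out = tensor.matmul(weight.t())  # weight is [out_features, in_features]
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_output):
        tensor_input, weight = ctx.saved_tensors
        grad_tensor = grad_output.matmul(weight)
        grad_weight = grad_output.t().matmul(tensor_input)
        grad_bias = grad_output.sum(0)  # bias was broadcast over the batch
        # One return value per forward argument; None for non-tensor args.
        return grad_tensor, grad_weight, grad_bias, None, None

# gradcheck compares the hand-written backward against finite differences.
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)
f = lambda x, w, b: LinearFunction.apply(x, w, b, None, None)
assert torch.autograd.gradcheck(f, (x, w, b))
```

Note that gradcheck is run in double precision, since the finite-difference comparison is too noisy in float32.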