• Used in VQ-VAE and in training models with N:M sparsity from scratch

  • The straight-through estimator is used for training neural networks with discrete or non-differentiable operations (e.g. quantization).

    • It’s a method to approximate gradients for non-differentiable functions, allowing backpropagation to work through these otherwise problematic layers.
  • How it Works:

    • Forward Pass: Use the actual non-differentiable function.
    • Backward Pass: Pretend the non-differentiable function was the identity function.
  • Example with Binarization: Consider a binarization function that outputs 1 if the input > 0, and 0 otherwise (a hard step function).
    • Forward: y = 1 if x > 0, else 0
    • Backward: dy/dx = 1 (pretending y = x)
  • Often works well in practice, despite the mismatch between forward and backward passes.
    • Can lead to instability in training if not carefully managed.
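The binarization example above can be sketched in plain NumPy (function names here are illustrative, not from any particular library): the forward pass applies the non-differentiable hard threshold, while the backward pass returns the upstream gradient unchanged, as if the layer had been the identity.

```python
import numpy as np

def binarize_forward(x):
    # Forward pass: hard threshold, 1 if x > 0 else 0 (non-differentiable at 0)
    return (x > 0).astype(x.dtype)

def binarize_backward(grad_output):
    # Backward pass (straight-through): pretend the forward was y = x,
    # so the incoming gradient passes through unchanged
    return grad_output

x = np.array([-1.5, 0.3, 2.0])
y = binarize_forward(x)                 # binarized activations
upstream = np.array([0.1, 0.2, 0.3])    # gradient arriving from the next layer
grad_x = binarize_backward(upstream)    # gradient passed back to x
```

In frameworks with autograd (e.g. PyTorch) the same effect is often written as `x + (binarize(x) - x).detach()`, which keeps the binarized value in the forward pass while routing gradients through `x`.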