- Used in VQ-VAE and in training a model with N:M sparsity from scratch
- The straight-through estimator (STE) is used for training neural networks with discrete or non-differentiable operations (e.g. quantization).
- It’s a method to approximate gradients for non-differentiable functions, allowing backpropagation to work through these otherwise problematic layers.
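In code, the STE often shows up as a one-line "detach" identity rather than an explicit custom backward pass. A minimal sketch, assuming PyTorch (the notes don't name a framework) and a hypothetical `quantize` stand-in for the non-differentiable op:

```python
import torch

def quantize(z):
    # Hypothetical stand-in for a non-differentiable op,
    # e.g. a nearest-codebook lookup in VQ-VAE.
    return torch.round(z)

z = torch.randn(4, requires_grad=True)
# Forward value equals quantize(z); the detached term contributes
# no gradient, so in backward the graph behaves as if z_q = z.
z_q = z + (quantize(z) - z).detach()
z_q.sum().backward()
print(z.grad)  # all ones: gradients passed straight through
```

This is essentially how the straight-through gradient appears in VQ-VAE implementations: gradients are copied from the decoder input back to the encoder output.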
- How it Works:
- Forward Pass: Use the actual non-differentiable function.
- Backward Pass: Pretend the non-differentiable function was the identity, i.e. pass the incoming gradient through unchanged.
- Example with Binarization: Consider a binarization function that outputs +1 if the input is > 0, and -1 otherwise (see the sketch after this list).
- Forward: y = sign(x)
- Backward: dy/dx = 1 (pretending y = x)
- Often works well in practice, despite the mismatch between forward and backward passes.
- Can lead to instability in training if not carefully managed.
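A minimal sketch of the binarization example as a custom autograd function, again assuming PyTorch (the class name `BinarizeSTE` is mine):

```python
import torch

class BinarizeSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Forward: the actual non-differentiable binarization,
        # +1 where x > 0 and -1 otherwise.
        return torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: pretend y = x, so dy/dx = 1 and the upstream
        # gradient passes through unchanged.
        return grad_output

x = torch.randn(4, requires_grad=True)
y = BinarizeSTE.apply(x)
y.sum().backward()
print(x.grad)  # all ones, even though the true gradient is zero a.e.
```

The true derivative of the binarization is zero almost everywhere, so without the straight-through approximation no learning signal would ever reach earlier layers.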