- Remember, running a small number of tokens through a model takes about the same amount of time as running a single token, because the forward pass is dominated by memory access (loading the weights) rather than compute
- Example:
- Given the sequence “I’m going”, it is very likely that “to” is the next prediction.
- Speculative decoding will speculatively run the forward pass with “I’m going to” as the context ⇒ if “to” is indeed predicted after “going”, we get the token after “to” for free!
- To produce these speculative predictions (to be verified by the large model), we can use a “draft” model that’s small enough (and therefore quick enough to run) that it pays for itself by avoiding passes through the larger “oracle” model. A good rule of thumb is for this model to be ~1/10 the size of the oracle model. It should also use the same tokenizer (to avoid needing to detokenize and retokenize the sequence over and over).
- We produce `n_draft` tokens as the speculative context; a minimal sketch of the loop follows this list.
- Speculative decoding performance can be very context dependent! If the draft model is well-correlated with the oracle model and the text is easy to predict, you’ll get lots of drafted tokens and fast inference. But if the models aren’t correlated, speculative decoding can actually make inference slower, because you’re wasting time generating draft tokens that will just be rejected.
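Below is a minimal sketch of one greedy draft-then-verify round, under some assumptions not in the notes above: HuggingFace-style causal LMs, batch size 1, greedy sampling, and placeholder checkpoint names.

```python
# Minimal sketch of one speculative decoding step (greedy, batch size 1).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("oracle-model")     # placeholder checkpoint names
oracle = AutoModelForCausalLM.from_pretrained("oracle-model")
draft = AutoModelForCausalLM.from_pretrained("draft-model")   # ~1/10 the size, same tokenizer

@torch.no_grad()
def speculative_step(input_ids: torch.Tensor, n_draft: int = 4) -> torch.Tensor:
    # 1. The cheap draft model proposes n_draft tokens autoregressively.
    draft_ids = input_ids
    for _ in range(n_draft):
        next_id = draft(draft_ids).logits[:, -1, :].argmax(-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_id], dim=-1)

    # 2. A single oracle forward pass scores every drafted position at once;
    #    this costs roughly the same as scoring one token (memory-bound).
    oracle_preds = oracle(draft_ids).logits.argmax(-1)   # oracle's choice after each prefix

    # 3. Accept drafted tokens while they match what the oracle would have emitted.
    n_prompt = input_ids.shape[1]
    accepted = input_ids
    for i in range(n_draft):
        oracle_token = oracle_preds[:, n_prompt + i - 1 : n_prompt + i]  # prediction for position n_prompt + i
        accepted = torch.cat([accepted, oracle_token], dim=-1)           # always keep the oracle's token
        if oracle_token.item() != draft_ids[0, n_prompt + i].item():
            break                                                        # first mismatch: stop here
    else:
        # Every draft matched, so the oracle's final prediction is an extra free token.
        accepted = torch.cat([accepted, oracle_preds[:, -1:]], dim=-1)
    return accepted
```

In the best case each call advances the sequence by `n_draft + 1` tokens for one oracle pass; in the worst case it still advances by one, plus the wasted draft passes.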
Threshold decoding
- Instead of always generating `n_draft` tokens, we adjust the speculative context length based on:
    - cumulative likelihood of the sequence: if prob(context) < threshold, stop generating.
    - adjusting the threshold based on whether the oracle model accepts the generations from the draft model (a sketch of both rules follows below)
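A rough sketch of threshold decoding under the same assumptions as above (greedy drafting, batch size 1); the specific threshold-update rule here is an illustrative guess, not from the notes.

```python
# Threshold decoding sketch: draft until the drafted context becomes too unlikely.
import torch

@torch.no_grad()
def draft_until_threshold(draft_model, input_ids, threshold: float, max_draft: int = 8):
    draft_ids = input_ids
    cum_prob = 1.0
    for _ in range(max_draft):
        probs = torch.softmax(draft_model(draft_ids).logits[:, -1, :], dim=-1)
        token = probs.argmax(-1, keepdim=True)        # greedy; assumes batch size 1
        cum_prob *= probs[0, token.item()].item()
        if cum_prob < threshold:
            break                                     # prob(context) < threshold: stop drafting
        draft_ids = torch.cat([draft_ids, token], dim=-1)
    return draft_ids

def update_threshold(threshold: float, n_accepted: int, n_drafted: int) -> float:
    # Illustrative update rule: adapt the threshold to the oracle's acceptance rate.
    if n_drafted and n_accepted / n_drafted > 0.8:
        return max(threshold * 0.9, 1e-4)   # oracle accepts most drafts: lower the bar, draft longer
    return min(threshold * 1.1, 0.99)       # drafts keep getting rejected: raise the bar, draft shorter
```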
Staged speculative decoding
Two improvements:
- Restructure the draft batch as a tree, instead of a single generation.
- This helps because longer draft batches on complex text can quickly diverge from the oracle model. It can instead make more sense to do multiple, shorter drafts, branching off from each other, and then verify them all against the oracle model using a specially-crafted attention mask (a mask-construction sketch follows this list).
- Lets us use batch generation: generating multiple draft sequences lets you reuse prior tokens and sample the draft model in batches, further accelerating the process.
- Speculatively decode the draft model as well—it’s usually a Transformer after all. This could be a yet-smaller Transformer (they recommend 15-20x smaller than the oracle model), or even a simple N-gram model (a toy N-gram drafter is sketched below).
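Here is one way the tree-structured verification could look: each drafted token is only allowed to attend to its own ancestors, so several branches can be packed into one sequence and verified by a single oracle pass. The tree layout and helper are hypothetical examples, not the paper’s exact implementation.

```python
# Sketch: build an attention mask for a packed tree of draft branches.
import torch

def build_tree_attention_mask(parents: list[int]) -> torch.Tensor:
    """parents[i] is the index of token i's parent in the packed tree (-1 for the root).
    Returns a boolean mask where mask[i, j] == True means token i may attend to token j."""
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:            # walk up the tree: attend to yourself and every ancestor
            mask[i, j] = True
            j = parents[j]
    return mask

# Example: root token 0 with two branches, 0 -> 1 -> 2 and 0 -> 3 -> 4.
# One oracle forward pass over the 5 packed tokens (with this mask) verifies both branches.
print(build_tree_attention_mask([-1, 0, 1, 0, 3]).int())
```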
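And a toy version of the simplest “draft for the draft”: an N-gram table built from previously generated text. It proposes continuations essentially for free, so it can speculate for the small draft Transformer the same way the draft speculates for the oracle. This is a minimal sketch, not the paper’s implementation.

```python
# Toy N-gram draft model: count continuations seen so far, propose the most frequent ones.
from collections import defaultdict

class NGramDraft:
    def __init__(self, n: int = 2):
        self.n = n
        self.table = defaultdict(lambda: defaultdict(int))  # (n-1)-token context -> next-token counts

    def update(self, token_ids: list[int]) -> None:
        # Record which token follows each (n-1)-token context in text generated so far.
        for i in range(len(token_ids) - self.n + 1):
            ctx = tuple(token_ids[i : i + self.n - 1])
            self.table[ctx][token_ids[i + self.n - 1]] += 1

    def propose(self, token_ids: list[int], k: int = 3) -> list[int]:
        # Greedily extend the sequence with the most frequent continuation, up to k tokens.
        out = list(token_ids)
        for _ in range(k):
            ctx = tuple(out[-(self.n - 1):])
            if ctx not in self.table:
                break
            out.append(max(self.table[ctx], key=self.table[ctx].get))
        return out[len(token_ids):]
```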