DISCLAIMER

Major shout-out to xjdr, doomslide, kalomaze, and Minh Nhat Nguyen. Pretty much everything I’m about to share with you comes from them. Consider me the messenger, passing along their wisdom with my own spin.

Source tweets: 1, 2, 3, 4

Defining entropy

  • Entropy = H = -Σₓ p(x) ln p(x), i.e. the expected surprisal -ln p(x)

  • Varentropy = Var[-ln p(x)] = Σₓ p(x) (ln p(x) + H)², i.e. the variance of the surprisal around H

  • A uniform distribution will have high entropy but zero varentropy (every outcome has the same surprisal)

  • A perfectly peaked distribution, i.e. p(x₁) = 1, will have low entropy and low varentropy (both zero)

  • An almost peaked distribution, i.e. one with p(x₁) close to 1 and the remaining mass spread over rare tokens, will have low entropy but high varentropy (the rare tokens carry very high surprisal)

  • Consider a distribution over 10 outcomes: p(x₁) = p(x₂) = … = p(x₈) = 0.12, p(x₉) = 0.03, p(x₁₀) = 0.01. This distribution has relatively high entropy because it’s close to uniform over most outcomes. It also has nonzero varentropy because the two rarer outcomes have much higher surprisal than the rest.
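
To make the numbers concrete, here is a small JAX sketch (plain jax.numpy, nothing model-specific) that computes both quantities for the example distribution above and for the uniform baseline:

```python
import jax.numpy as jnp

def entropy_varentropy(probs):
    """Entropy H = -sum(p * ln p); varentropy = sum(p * (ln p + H)^2)."""
    logprobs = jnp.log(probs)
    ent = -jnp.sum(probs * logprobs)
    varent = jnp.sum(probs * (logprobs + ent) ** 2)
    return ent, varent

# The example distribution from above.
p = jnp.array([0.12] * 8 + [0.03, 0.01])
ent, varent = entropy_varentropy(p)

# Uniform baseline over 10 outcomes: maximal entropy, zero varentropy.
p_unif = jnp.full(10, 0.1)
ent_u, varent_u = entropy_varentropy(p_unif)
```

For the example distribution this gives H ≈ 2.19 nats with varentropy ≈ 0.11, while the uniform baseline gives H = ln 10 ≈ 2.30 with varentropy 0.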

  • Caveat: varentropy could be defined differently to better work as a heuristic in the below framework

    • e.g. varentropy(logits of steps 0 … t-1), i.e. variation across time steps, vs varentropy(logits at step t), i.e. variation within a single step’s distribution

Entropy computation

  • As seen above, entropy is defined as -sum(probs * logprobs)
  • However, since we’re computing in low-precision floating point, we need to be careful.
  • Below, we derive a logit noise threshold under which a prob/logprob should be considered “negligible” and excluded from the computation

Logit noise threshold

  • source: https://x.com/doomslide/status/1840801746307764239

  • The relative error from exponentiating a bfloat16 value is u = 2^{-7} (the size of its mantissa)

  • The total relative error of summing over the vocabulary is sqrt(vocab_size) * u (assuming the per-term errors are uncorrelated)

  • Hence any prob below sqrt(vocab_size) * u / vocab_size = u / sqrt(vocab_size) should be considered negligible

  • Hence any logprob below ln(u) - ln(vocab_size)/2 should be considered -inf

  • Hence the maximal usable precision of an LLM’s logits is:

    • ln(min_p) = ln(u) - ln(vocab_size)/2
  • Example:

    • For vocab_size=128256, ln(min_p) = -10.7329 i.e. min_p ~ 2e-5
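
Putting the derivation together as a sketch. The vocab size is the Llama-3 value from the example above; `safe_entropy` is an illustrative helper, not from the source:

```python
import math
import jax.numpy as jnp

U = 2.0 ** -7        # bfloat16 relative error, as in the derivation above
VOCAB_SIZE = 128256  # Llama-3 tokenizer size, as in the example

# Any logprob below this floor is indistinguishable from rounding noise.
MIN_LOGP = math.log(U) - math.log(VOCAB_SIZE) / 2   # ~ -10.7329
MIN_P = math.exp(MIN_LOGP)                          # ~ 2e-5

def safe_entropy(logprobs, min_logp=MIN_LOGP):
    """Entropy that ignores contributions below the noise floor."""
    probs = jnp.exp(logprobs)
    keep = logprobs >= min_logp
    return -jnp.sum(jnp.where(keep, probs * logprobs, 0.0))
```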

Final form/framework

  • Objective is for the sampling to stay within acceptable entropy and varentropy bounds

  • Why?

Using entropy as a measure of confidence or confusion

  • Because entropy and varentropy are good indicators of a degenerate logits distribution

    • e.g. low entropy is synonymous with a peaked distribution, which occurs in both good and bad scenarios: (1) “doom looping”, where the model repeats its previous output with high certainty; (2) an obvious next token learned from the text distribution, e.g. “The brown” → “fox”; (3) simply the “correct token”, e.g. in a Q/A format
    • Not all scenarios are bad, but the assumption is that eliciting more diversity from your model via entropy sampling should not decrease quality in good scenarios and should increase quality in bad ones
  • Additionally, when doing tree search, varentropy correlates surprisingly well with model confusion/ungrounding, making it a good heuristic for deciding when beam search should stop or continue

  • Also, if you can drive the varentropy low enough, the model will usually emit an EOT token and let you know it’s done.

Summary

  • Let ent_1 and ent_2 be the entropy lower and upper bound

  • Let varent_1 and varent_2 be the varentropy lower and upper bound

  • Let ent and varent be the current entropy and varentropy of the logits distribution at time step t, after having generated x_1, ..., x_(t-1)

  • If ent < ent_1 & varent < varent_1 (low entropy, low varentropy)

    • the model is confident
    • we do normal greedy decoding/argmax
  • If ent < ent_1 & varent_1 < varent (low entropy, high varentropy)

    • we use branching (described below) to inject entropy.
  • If ent_2 < ent & varent_2 < varent (high entropy, high varentropy)

    • we use backspace i.e. we resample the previous token
  • If ent_2 < ent & varent < varent_2 (high entropy, low varentropy)

    • we force CoT by injecting a “Wait…” token to tell the model to re-evaluate (o1 style)
    • it’s important to add a cooldown, otherwise the sampler may inject the CoT token over and over
  • Visual summary (source)
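
The four cases above can be sketched as a dispatcher. The bounds, and the fallback “sample” action for in-between readings, are hypothetical placeholders that need per-model tuning, not values from the source:

```python
# Hypothetical entropy/varentropy bounds; real values need per-model tuning.
ENT_1, ENT_2 = 0.5, 3.0        # entropy lower / upper bound
VARENT_1, VARENT_2 = 0.5, 3.0  # varentropy lower / upper bound

def choose_action(ent, varent):
    """Map (entropy, varentropy) at step t to a sampling strategy."""
    if ent < ENT_1 and varent < VARENT_1:
        return "argmax"       # confident: greedy decoding
    if ent < ENT_1:
        return "branch"       # low entropy, high varentropy: inject entropy
    if ent > ENT_2 and varent > VARENT_2:
        return "backspace"    # confused: resample the previous token
    if ent > ENT_2:
        return "inject_cot"   # high entropy, low varentropy: force "Wait..."
    return "sample"           # within bounds: ordinary temperature sampling
```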

Branching

  • If entropy is low but varentropy is high, the model could be stuck in a “local optimum”, e.g. a doom loop

  • We use branching to “inject entropy” i.e.

    • adjust temperature
    • noise the logits distribution
    • do beam search to get a “good token”
      • stopping criteria are end of text, max search length, or entropy back within bounds
  • Top-k beam search can kind of suck, because not every uncertainty situation should use the same truncation criterion.

    • Adaptive beam search using either top-p or min-p can make more sense for selecting how many candidates to keep
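
A minimal sketch of the adaptive min-p idea for picking beam candidates (`min_p_frac` is a hypothetical knob):

```python
import jax
import jax.numpy as jnp

def min_p_candidates(logits, min_p_frac=0.1):
    """Adaptive truncation: keep tokens whose probability is at least
    min_p_frac of the top token's, so the beam width tracks uncertainty."""
    probs = jax.nn.softmax(logits)
    keep = probs >= min_p_frac * jnp.max(probs)
    return jnp.nonzero(keep)[0]
```

A peaked distribution yields a narrow beam, a flat one yields a wide beam, with no fixed k.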

Sample code in Jax to illustrate branching and CoT injection
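
A minimal JAX sketch of the two mechanisms. The temperature ladder, `WAIT_TOKEN` id, and `COOLDOWN` value are hypothetical placeholders; a real implementation would also roll each branch forward through the model to score it:

```python
import jax
import jax.numpy as jnp

WAIT_TOKEN = 14524   # hypothetical token id for "Wait" - depends on tokenizer
COOLDOWN = 32        # hypothetical min number of steps between CoT injections

def branch(key, logits, temps=(0.7, 1.0, 1.3, 1.7)):
    """Entropy injection: sample one candidate token per temperature.
    A fuller implementation would roll each candidate forward and keep
    the branch whose entropy comes back within bounds."""
    keys = jax.random.split(key, len(temps))
    return jnp.array([jax.random.categorical(k, logits / t)
                      for k, t in zip(keys, temps)])

def maybe_inject_cot(sampled_token, steps_since_cot):
    """Replace the sampled token with the "Wait..." token unless on cooldown."""
    if steps_since_cot >= COOLDOWN:
        return WAIT_TOKEN, 0          # inject and reset the cooldown timer
    return sampled_token, steps_since_cot + 1
```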