Most testing strategies miss rare edge cases until customers find them. We developed a system that automatically generates targeted unit tests for rare bugs, including the one that would have caught Anthropic’s recent approximate top-K bug.

Comparison of testing approaches
Figure 1: Unit-level property-based tests (PBTs) are fast but miss edge cases. Proofs offer exhaustive coverage but require extensive reasoning and code refactoring. End-to-end PBTs have broad coverage but are not compute efficient. Fractional proofs sit at the intersection, using proof decomposition to generate targeted unit tests that balance compute efficiency, developer accuracy, and speed.

Catching the rare bug in top-K sampling

A bug in the TPU implementation of approximate top-K caused the most likely token to sometimes be excluded. Rare bugs like this frequently slip through to production because testing every behavior is infeasible in practice. After the bug was discovered, Anthropic provided a simple reproducer, but it is the sort of test you only manage to write after a laborious bug-minimization process.

We used fractional proof decomposition to generate the unit test automatically, without relying on Anthropic's reproducer code. You can run the unit test on Colab. More generally, when testing is done via fractional proof decomposition, bugs can be found systematically, without the benefit of hindsight.

Top-K sampling property verification
Figure 2: Top-K sampling should always have some chance of picking the most likely token. We encode this property as a PBT asserting max(approximate_top_k(arr, k=k)) == max(arr). If lax.approx_max_k is implemented correctly, the test should pass: the approximate top-K algorithm divides the data points into L bins and computes the true max in each bin, so the global max is always the max of its own bin and is therefore among the returned values. L is chosen based on the desired average recall r.
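To see why the test should pass, here is a minimal pure-NumPy sketch of the binned scheme described above. It is not `lax.approx_max_k`: `binned_approx_top_k`, `conservative_num_bins`, and the use of contiguous bins are our assumptions for illustration. For the choice of L, the helper uses a simple conservative model, not necessarily the rule libtpu uses: if each true top-k element lands independently in a uniformly random bin, it is certainly retrieved when no other top-k element shares its bin, so expected recall is at least ((L − 1)/L)^(k − 1) ≈ exp(−(k − 1)/L), which suggests L ≈ (k − 1)/ln(1/r).

```python
import math

import numpy as np
from hypothesis import given, strategies as st
from hypothesis.extra.numpy import arrays


def binned_approx_top_k(arr: np.ndarray, k: int, num_bins: int) -> np.ndarray:
    """Hypothetical sketch of binned approximate top-k (not lax.approx_max_k):
    split arr into num_bins contiguous bins, keep each bin's true max,
    then return the k largest bin maxima in ascending order."""
    assert 1 <= k <= num_bins <= arr.size
    bin_maxes = np.array([b.max() for b in np.array_split(arr, num_bins)])
    return np.sort(bin_maxes)[-k:]


def conservative_num_bins(k: int, recall: float) -> int:
    """Smallest L >= k with ((L - 1) / L) ** (k - 1) >= recall (hypothetical sizing rule).

    (1 - 1/L) ** (k - 1) <= exp(-(k - 1) / L), so the answer is at least
    (k - 1) / ln(1 / recall); the search starts there.
    """
    assert k >= 1 and 0.0 < recall < 1.0
    num_bins = max(k, math.ceil((k - 1) / math.log(1.0 / recall)))
    while ((num_bins - 1) / num_bins) ** (k - 1) < recall:
        num_bins += 1
    return num_bins


@st.composite
def top_k_inputs(draw):
    """Draw (arr, k, num_bins) with 1 <= k <= num_bins <= len(arr)."""
    n = draw(st.integers(min_value=1, max_value=64))
    arr = draw(arrays(np.float32, n, elements=st.floats(-1e6, 1e6, width=32)))
    num_bins = draw(st.integers(min_value=1, max_value=n))
    k = draw(st.integers(min_value=1, max_value=num_bins))
    return arr, k, num_bins


@given(top_k_inputs())
def test_true_max_always_included(inputs):
    # Theorem 1: the true max is always among the approximate top-k values.
    arr, k, num_bins = inputs
    assert binned_approx_top_k(arr, k, num_bins).max() == arr.max()
```

Calling `test_true_max_always_included()` runs the property under Hypothesis; pointing the same test at the real implementation is what turns it into a test of `lax.approx_max_k`.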

Systematically generating tests via fractional proof decomposition

Fractional proof decomposition process
Figure 3: We encode theorems as PBTs, recursively decompose them into smaller sub-theorems using reasoning, and fuzz the theorems (that is, run the PBTs) once the decomposition yields compute-efficient unit-level PBTs.

Step 1: Identify the theorem, which is the property that your implementation must satisfy. Then, encode it as a PBT using the Hypothesis framework. We call the top-level theorem an end-to-end PBT because it corresponds to the end-to-end behavior of the function.

The theorem for the top-K bug is:

$$\forall\ \text{prompt},\ k:\quad LLM_{\text{top-}1}(\text{prompt}) \in LLM_{\text{top-}k}(\text{prompt})$$

For any $\text{prompt}$, $LLM_{\text{top-}k}(\text{prompt})$ denotes the model's prediction of the $k$ most likely next tokens. We can now encode this theorem as an end-to-end PBT. Since the end-to-end PBT does not need to run on TPU, we set it up in a separate Colab notebook.

Top-K theorem encoded as PBT
Figure 4: The theorem stating that the most likely token should always be included in the top k tokens, encoded as a PBT.
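The shape of such an end-to-end PBT can be sketched without a real LLM. In this sketch `toy_logits` is a hypothetical stand-in for a model's forward pass (deterministic pseudo-random logits seeded by the prompt), and `top_k_token_ids` uses an exact top-k; in a real test, both would be replaced by the model and the top-K implementation under test.

```python
import numpy as np
from hypothesis import given, settings, strategies as st

VOCAB_SIZE = 1000  # hypothetical vocabulary size


def toy_logits(prompt: list[int]) -> np.ndarray:
    """Hypothetical stand-in for an LLM forward pass: deterministic pseudo-random
    logits seeded by the prompt tokens."""
    rng = np.random.default_rng([len(prompt), *prompt])
    return rng.standard_normal(VOCAB_SIZE).astype(np.float32)


def top_k_token_ids(logits: np.ndarray, k: int) -> set[int]:
    """Exact top-k via a stable sort on -logits, so ties resolve to the lowest
    token id, matching np.argmax. Swap in the implementation under test here."""
    return set(np.argsort(-logits, kind="stable")[:k].tolist())


@settings(max_examples=200)
@given(
    prompt=st.lists(st.integers(min_value=0, max_value=VOCAB_SIZE - 1), max_size=16),
    k=st.integers(min_value=1, max_value=VOCAB_SIZE),
)
def test_top1_in_top_k(prompt, k):
    logits = toy_logits(prompt)
    top1 = int(np.argmax(logits))  # LLM_top-1(prompt)
    assert top1 in top_k_token_ids(logits, k)  # must be in LLM_top-k(prompt)
```

With an exact top-k the test passes trivially; its value is as a harness. Against a real model, every example costs a full forward pass, which is what makes the end-to-end PBT expensive.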

Although the end-to-end PBT has comprehensive coverage, catching rare bugs with it would require sampling an excessively large number of tokens. The rarer the bug, the more compute is required.
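A back-of-the-envelope calculation makes the scaling concrete. Assume each sampled input independently triggers the bug with probability p. Observing the bug at least once with confidence c requires n samples with 1 − (1 − p)^n ≥ c, i.e. n ≥ ln(1 − c)/ln(1 − p) ≈ ln(1/(1 − c))/p, so compute grows in proportion to 1/p. The bug rates below are illustrative, not measurements of the top-K bug.

```python
import math


def samples_needed(p: float, confidence: float) -> int:
    """Samples n such that 1 - (1 - p)**n >= confidence (up to float rounding),
    assuming each sample independently triggers the bug with probability p."""
    assert 0.0 < p < 1.0 and 0.0 < confidence < 1.0
    # log1p(-x) computes ln(1 - x) accurately even when x is tiny.
    return math.ceil(math.log1p(-confidence) / math.log1p(-p))


for p in (1e-3, 1e-6, 1e-9):  # illustrative bug rates, not measurements
    print(f"p = {p:g}: {samples_needed(p, 0.99):,} samples for 99% confidence")
```

Each 1000× drop in the bug rate multiplies the required number of samples by roughly 1000.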

Compute requirements for rare bugs

Step 2: Recursively decompose the theorems into a collection of smaller theorems, which are also encoded as PBTs. These smaller theorems are intermediate results that compose together to establish the original end-to-end PBT.

Decomposing the end-to-end PBT for the top-K bug yields, by construction, three theorems:

  1. max(approximate_top_k(arr, k=k)) == max(arr) (true max always included) 1
  2. For any input tokens, the logits are finite (neither ±∞ nor NaN)
  3. In vLLM, the token ids are the same as the logprobs dict keys
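Each sub-theorem is a small, cheap PBT in its own right. The sketch below encodes theorems 2 and 3 against hypothetical stand-ins: `toy_logits` plays the model and `build_logprobs` plays the step that returns token IDs alongside a logprobs dict; neither is vLLM's or JAX's actual API. Theorem 2 is what justifies testing theorem 1 over finite floats only: with a NaN present, `np.max` returns NaN, and since NaN never compares equal to itself, the max-equality check would fail for reasons unrelated to top-K.

```python
import numpy as np
from hypothesis import given, strategies as st

VOCAB_SIZE, DIM = 256, 16  # hypothetical sizes
EMBED = np.random.default_rng(0).standard_normal((VOCAB_SIZE, DIM)).astype(np.float32)


def toy_logits(prompt: list[int]) -> np.ndarray:
    """Hypothetical model: average the prompt's token embeddings, score every token."""
    if not prompt:
        # Without this guard, the mean over an empty prompt is NaN,
        # and test_logits_are_finite fails on prompt == [].
        hidden = np.zeros(DIM, dtype=np.float32)
    else:
        hidden = EMBED[prompt].mean(axis=0)
    return EMBED @ hidden


def build_logprobs(logits: np.ndarray, k: int) -> tuple[list[int], dict[int, float]]:
    """Hypothetical output step: top-k token ids plus a {token_id: logprob} dict."""
    shifted = logits.astype(np.float64) - logits.max()
    logprobs = shifted - np.log(np.exp(shifted).sum())  # numerically stable log-softmax
    ids = np.argsort(-logits, kind="stable")[:k].tolist()
    return ids, {i: float(logprobs[i]) for i in ids}


prompts = st.lists(st.integers(min_value=0, max_value=VOCAB_SIZE - 1), max_size=32)


@given(prompts)
def test_logits_are_finite(prompt):
    # Theorem 2: for any input tokens, the logits are neither +/-inf nor NaN.
    assert np.isfinite(toy_logits(prompt)).all()


@given(prompts, st.integers(min_value=1, max_value=VOCAB_SIZE))
def test_ids_match_logprob_keys(prompt, k):
    # Theorem 3: the returned token ids are exactly the keys of the logprobs dict.
    ids, logprobs = build_logprobs(toy_logits(prompt), k)
    assert set(ids) == set(logprobs)
```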

You can think of PBTs as fractional components of a brute-force proof. Just as a brute-force proof is optimized by using reasoning to decompose properties into logical sub-properties, an approach better known as partial evaluation, we apply reasoning to decompose PBTs. This reasoning bootstraps trust in PBT coverage: you don't have to exhaustively check every input as a formal proof would, but you still get a systematic understanding of your programs and of how you spend your testing compute. We call this sampling technique fractional proofs.

Step 3: Continue decomposing until each sub-test's input space is small enough that sampling it catches rare bugs compute-efficiently. You stop decomposing when:

  1. Each sub-test runs sufficiently quickly,
  2. Each sub-test covers a sufficient fraction of its input distribution to catch most bugs that would plausibly appear in code of similar complexity, and
  3. The sub-properties provably compose to cover the full end-to-end property.
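Criterion 2 is easiest to see at its extreme: once a sub-property's input space is small enough, the test can cover all of it. The sketch below exhaustively checks footnote 1's generalization of theorem 1 on every permutation of six distinct values and every k' ≤ k ≤ L ≤ 6. `binned_approx_top_k` is a hypothetical contiguous-bin reduction, not `lax.approx_max_k`; for larger input spaces, PBT sampling takes the place of enumeration.

```python
import itertools

import numpy as np


def binned_approx_top_k(arr: np.ndarray, k: int, num_bins: int) -> np.ndarray:
    """Hypothetical contiguous-bin reduction (not lax.approx_max_k): keep each
    bin's max, then return the k largest bin maxima."""
    bin_maxes = np.array([b.max() for b in np.array_split(arr, num_bins)])
    return np.sort(bin_maxes)[-k:]


def check_all_permutations(n: int = 6) -> int:
    """Exhaustively check footnote 1's bound on every permutation of range(n).

    For every k' <= k <= L <= n, at most floor(k'(L - 1) / L) of the true top-k'
    values may be missing from the approximate top-k output.
    Returns the number of (permutation, L, k, k') cases checked.
    """
    cases = 0
    for perm in itertools.permutations(range(n)):
        arr = np.array(perm)
        for num_bins in range(1, n + 1):
            for k in range(1, num_bins + 1):
                kept = set(binned_approx_top_k(arr, k, num_bins).tolist())
                for k_prime in range(1, k + 1):
                    true_top = set(range(n - k_prime, n))  # values are 0..n-1, all distinct
                    missing = len(true_top - kept)
                    bound = (k_prime * (num_bins - 1)) // num_bins
                    assert missing <= bound, (perm, num_bins, k, k_prime)
                    cases += 1
    return cases


print(check_all_permutations())  # 40320 cases: 720 permutations x 56 (L, k, k') triples
```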

While we found the top-K bug in about 10 minutes of sampling, we found the XLA:TPU excess-precision bug (also discussed in Anthropic's post) in just a few seconds. 2

Fractional proofs for efficient oversight

Systematic decomposition catches rare bugs without sacrificing developer speed or compute efficiency. Instead of requiring compute proportional to a bug's rarity, fractional proofs require compute that grows with the logarithm of its rarity.

Compute scaling comparison

We can straightforwardly extend the approximate top-K example in this post to real-world codebases. For example, top-K can be decomposed into a sequence of PBTs testing how libtpu implements the algorithm described in its reference paper. Or we can use this reasoning to establish how single-TPU behavior composes into cluster behavior.

At Theorem, we’re training models that can automatically reason about program correctness. If you want to catch bugs earlier and make your devs happy, talk to us.


  1. More generally, for any $k' \le k$, at most $\lfloor k'(L-1)/L \rfloor$ of the true top-$k'$ values are excluded. ↩︎

  2. Because the older version of Anthropic's code performs more computation around approximate top-k, we decompose the theorem max(top_k_computation(arr, k=k)) == max(arr) into max(arr) >= min(arr) and max(softmax(arr)) >= min(top_k(softmax(arr), k=k)). You can find this work at the bottom of the same Colab notebook. ↩︎