I was looking for a real systems application that would show what TorchLean is actually for: a place where an ML system makes a promise that is easy to say in English and surprisingly hard to guarantee in production.
The Thinking Machines Lab post Defeating Nondeterminism in LLM Inference[1] caught my attention because it explains a failure mode that sounds almost impossible the first time you hear it: temperature zero LLM inference can still return different completions for the same prompt. The model is not sampling. The decoder is supposed to pick the largest logit. And yet the serving system can still leak enough numerical variation into the computation that the final greedy token changes.
My request may be evaluated alone in one run and beside unrelated requests in another. That can change the batch shape, which can change the kernel path, tile size, key/value cache traversal, or partial-sum order. Because floating point addition is not associative, those choices can change bits. In an LLM, the changed bits can move through RMSNorm, matmul, attention, logits, and finally the argmax. Logits are the model’s raw scores for possible next tokens. Argmax is the greedy choice: pick the token with the largest score. If two logits are close enough, a bit level change can become a token level change.[4]
The formal target is simple: the output for my request should depend on my request, not on the accidental batch it was served in. To make that sentence checkable, we have to say what “depend on” means, what a kernel schedule is, what attention chunking is allowed to change, and how a CUDA runtime path can make a checkable promise instead of asking us to trust a black box.
I go through it roughly in the order I figured it out. Start with the serving promise. Name the missing mathematical object: the reduction schedule. Then RMSNorm and matmul become simple instances, attention becomes a scheduling problem over features and KV blocks, and decoding becomes a question about which tokens the server is allowed to release. At the end, the Lean development touches a compiled CUDA value reduction kernel through a checked FMA-chain certificate.
The checked code for this post is in Robertboy18/TorchLean-Verified-Examples, under week-01-batch-invariant-inference. The README there has the exact lake build command and the CUDA certificate regeneration command.
One thing up front, because the page mixes Python and Lean: I do not treat them as two competing “status” columns. I probe in Python, repeat a hosted call, compare logits alone versus in a batch, and see if anything flips. That is how I convince myself the symptom is real. Lean is where I formalize what would have to be true for the symptom to go away. Python is observation; Lean is the checked contract.
Here is the theorem I formalized mathematically: if the selected request has the same local state, if the numerical schedules that affect it are fixed or certified, and if the server only commits verifier accepted tokens, then the output shown to the user agrees with a canonical reference decoder.
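The mathematical shape is not “same input, same batch, different result.” It is “same selected request, different batch context, different reduction schedule, different result,” which is how a greedy argmax can flip.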
The formalization below turns that sentence into a theorem. The serving promise becomes something the implementation side can be asked to satisfy: a schedule, a graph, a verifier, and a certificate, all talking about the same selected request.
The observable probe
The original systems post is persuasive because the experiment is simple. Hold the prompt fixed. Set temperature to zero. Ask the model for completions repeatedly. Count how many distinct token sequences come back.
Temperature zero means the sampler is supposed to be greedy. Instead of drawing randomly from a distribution over next tokens, it should repeatedly choose the largest logit. The surprise is that there is no intentional sampling noise left to blame.
The Thinking Machines post ran this at a much larger scale: 1000 temperature zero completions from Qwen/Qwen3-235B-A22B-Instruct-2507 on the prompt “Tell me about Richard Feynman.” They saw 80 unique completions. The first divergence was especially striking: the completions matched for the first 102 tokens, then split at token 103 between “Queens, New York” and “New York City.”[1]
Here I run the same kind of probe through Tinker: a hosted model API, a fixed prompt, fixed sampling parameters, repeated calls, and a count of unique completions. The measurement gives the observable symptom. The formalization asks for the condition underneath it: what would have to be true about the selected request, the numerical schedule, and the serving rule for this kind of split to be impossible?
export TINKER_API_KEY=...
python3 scripts/batch_invariance_demo.py \
  --tinker-repeat \
  --tinker-base-model meta-llama/Llama-3.2-1B \
  --prompt "Tell me about Richard Feynman" \
  --trials 10 \
  --max-tokens 32

I keep the probe in Python on purpose. Hosted inference is something I observe; Lean is where I state and check claims about schedules, batches, and serving. The script creates a sampling client, uses temperature zero, repeats the call, and counts distinct token tuples. It is deliberately close to the original experiment so the empirical symptom and the formal target line up.
from collections import Counter

import tinker

# `sampler` (the sampling client), `prompt`, `trials`, and `max_tokens`
# are set up earlier in the script.
params = tinker.SamplingParams(
    max_tokens=max_tokens,
    temperature=0,
    seed=0,
)
outputs = []
for _ in range(trials):
    resp = sampler.sample(
        prompt=prompt,
        num_samples=1,
        sampling_params=params,
    ).result()
    # Tuples are hashable, so identical completions collapse in the Counter.
    outputs.append(tuple(resp.sequences[0].tokens))
counts = Counter(outputs)
print(f"unique_completions={len(counts)}")

There is also a local Hugging Face probe for the smaller matrix-style symptom: compare the logits for one prompt alone against the logits for the same prompt as the first row of a padded batch. It may or may not find a mismatch on a given machine. Either outcome is useful. The experiment is not trying to certify the hosted service. It is a diagnostic that points toward the mathematical condition: the selected request must see a request-local numerical schedule.
What Lean is checking
After the Python probe identifies the symptom, Lean is where I try to say exactly what “batch invariant” should mean and prove results under that meaning. Lean is both a programming language and an interactive theorem prover.[2] You write definitions, state theorems about those definitions, and construct proofs. Lean checks them with a small trusted kernel. The proof is not a comment that says “this should be true”; it is an object Lean type checks.
A tiny example: if I define \(f(x)=x+1\), I can ask Lean to check \(f(3)=4\). Here it is the same idea, just with larger objects: batched forward pass, reduction tree, attention schedule, reference decoder. Name the object, state the property, build the proof.
If you do not use theorem provers day to day, the feel is a bit like a very strict compiler, except it checks math. A theorem might say a parser is sound, or that a selected row of a batched computation cannot see unrelated rows. If it compiles without sorry, admit, or new axioms, I treat that as “checked,” not “probably fine.”
TorchLean provides the neural network machinery.[3] It gives shape indexed tensors, an op tagged graph IR, finite precision semantics, autograd/runtime definitions, and certificate checking infrastructure. “IR” means intermediate representation: a graph-shaped object that records operations and dependencies after model code has been lowered into a more inspectable form. Here I use those ingredients in a deliberately narrow way: define specifications for pieces of an inference system, prove theorems about those specifications, and mark the places where compiled CUDA or hardware must provide a refinement certificate.
I use three words in a very literal way. A specification is the mathematical object we prove about. An implementation is a program or kernel that is supposed to behave like that object for the property at hand. A certificate is finite evidence that Lean can check to connect compiled code or generated data back to the specification.
The Lean development defines a vocabulary for the serving problem. Each concept names something production systems usually leave implicit.
Schedules
Reduction order is represented explicitly as a binary tree, which is how floating point non associativity enters the semantics.
Batch invariance
A batched forward pass is correct for serving only if the selected request output depends on the selected request state, not on unrelated rows.
Certificates
Generated CUDA/PTX/SASS facts are packaged as a Lean value; the checker recomputes the contract on that value and produces the semantic theorem.
The first specification: a batched forward pass
First strip away Transformers and keep only the shape of a serving call. A batched forward pass receives a batch size \(B\), the inputs in that batch, and the index of the row we care about. It returns the output for that selected row.
If \(B=3\), think of the batch as three requests sitting next to each other: Alice, me, and Bob. The property is about the middle row. If my request is the same, then my result should be the same whether Alice and Bob are there, whether Carol and Dave are there, or whether I am served alone.
Lean notation can make this look scarier than it is. The type \(\mathrm{Fin}(B)\) means “a valid index less than \(B\).” So if \(B=4\), an element of \(\mathrm{Fin}(B)\) is one of \(0,1,2,3\), together with the fact that the index is in bounds. A function \(\mathrm{Fin}(B)\to X\) is just a batch of \(B\) inputs: give it a legal row index and it returns the \(X\)-value stored at that row.
Give me any batch size, give me the rows in the batch, tell me which row is mine, and I will return my output. The property is not ordinary determinism for one fixed batch. It is a relation between two different batch contexts. If the selected row in the first batch equals the selected row in the second batch, the selected outputs should agree.
-- A batched forward pass receives:
--   B        : the batch size,
--   rows     : a function mapping each legal row index to an input,
--   selected : the row whose output we observe.
abbrev BatchedForward (X : Type u) (Y : Type v) :=
  (B : Nat) -> (Fin B -> X) -> Fin B -> Y

-- Batch invariance compares two different batch contexts.
-- If the selected rows are equal, the selected outputs must be equal.
def BatchInvariantForward (forward : BatchedForward X Y) : Prop :=
  forall {B C : Nat}
    (xs : Fin B -> X) (ys : Fin C -> X)
    (i : Fin B) (j : Fin C),
    xs i = ys j -> forward B xs i = forward C ys j

This definition is intentionally simple. It says exactly what the user of an inference endpoint needs: my request’s output should not depend on which other rows happen to be next to it. Everything later in the note exists to make real kernels and real serving paths satisfy this relation.
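As a rough Python analogue of the Lean definition (a sketch, not code from the repository), a batched forward pass can be modeled as a function of a list of rows and a selected index. The names `violates_batch_invariance`, `rowwise_square`, and `mean_centered` are hypothetical. A finite check like this can only find counterexamples; it cannot establish the universally quantified property, which is why the definition lives in Lean.

```python
from typing import Callable, Sequence

# A batched forward pass: (rows in the batch, selected index) -> selected output.
Forward = Callable[[Sequence[float], int], float]

def violates_batch_invariance(forward: Forward,
                              xs: Sequence[float], i: int,
                              ys: Sequence[float], j: int) -> bool:
    """True when the selected rows agree but the selected outputs differ."""
    return xs[i] == ys[j] and forward(xs, i) != forward(ys, j)

def rowwise_square(rows: Sequence[float], i: int) -> float:
    # Reads only the selected row.
    return rows[i] * rows[i]

def mean_centered(rows: Sequence[float], i: int) -> float:
    # Reads every row, so unrelated requests leak into the selected output.
    return rows[i] - sum(rows) / len(rows)

alone = [3.0]               # my request served by itself
crowded = [7.0, 3.0, 2.0]   # my request in the middle of a batch of three

print(violates_batch_invariance(rowwise_square, alone, 0, crowded, 1))  # False
print(violates_batch_invariance(mean_centered, alone, 0, crowded, 1))   # True
```

The second function is the shape of bug the definition rules out: the selected row is identical in both batches, yet its output changes with its neighbors.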
The finite precision counterexample
Before defining the schedule machinery, it helps to see why we need it. I start with a finite precision counterexample using TorchLean’s executable IEEE32 model. IEEE32 is the usual 32 bit floating point format used all over ML: one sign bit, 8 exponent bits, 23 fraction bits, rounding, infinities, NaNs, and all the small surprises that come with real hardware floats.[4]
This is one place where TorchLean is doing something more useful than ordinary real-number math. If we proved the property over exact real numbers, addition would be associative and this bug would disappear from the model. TorchLean gives us an executable Float32 semantics inside Lean, so the theorem can talk about the same kind of finite precision arithmetic that makes the serving problem possible in the first place.
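The chosen numbers are the classic large cancellation pattern. Take \(a = 10^{20}\), \(b = -10^{20}\), and \(c = 1\), each rounded to Float32, and write \(\oplus\) for rounded Float32 addition:
\[
(a \oplus b) \oplus c = 1, \qquad a \oplus (b \oplus c) = 0 .
\]
Those equations are about the executable float model, not real arithmetic. They show that two legal parenthesizations of the same three-term sum can produce different Float32 results.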
Real arithmetic
Addition is associative, so changing parentheses does not change the result.
Float32 arithmetic
Every operation rounds to a finite format, so changing the reduction tree can change bits.
TorchLean
The finite precision behavior is an executable Lean object, so the counterexample is checked rather than merely described.
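The same cancellation can be reproduced outside Lean. The sketch below emulates Float32 addition in plain Python by rounding each binary64 sum back to binary32 with `struct`. The helpers `f32` and `add32` are illustrative names, not TorchLean's `IEEE32Exec`, and the sketch assumes this emulation matches round-to-nearest Float32 behavior for a single addition.

```python
import struct

def f32(x: float) -> float:
    """Round a binary64 Python float to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def add32(a: float, b: float) -> float:
    """Emulated Float32 addition: add in binary64, then round to binary32."""
    return f32(a + b)

big, neg_big, one = f32(1e20), f32(-1e20), f32(1.0)

left = add32(add32(big, neg_big), one)    # ((a + b) + c): the big terms cancel first
right = add32(big, add32(neg_big, one))   # (a + (b + c)): the 1.0 is absorbed first

print(left, right)  # 1.0 0.0
```

This is only an observation in Python. The Lean development is where the inequality becomes a checked fact about the executable float model.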
The warning example builds a deliberately bad schedule selector. If the selected request is served alone, the selector uses the left parenthesization. If the same selected request is served inside a larger batch, the selector uses the right parenthesization. Nothing about the selected row changed. The only thing that changed is the surrounding batch context, and the selector used that context to choose a different arithmetic tree.
A production kernel would choose algorithms through more complex shape heuristics, not this tiny branch. The counterexample isolates the danger: if a heuristic lets unrelated batch context choose the selected request’s reduction order, then Float32 non-associativity can make the selected result change.
-- This theorem says the bad schedule selector is not batch invariant.
--
-- The forward pass below reduces one selected row with IEEE32 addition.
-- `chooseByBatchSize3` is intentionally bad:
--   * batch size 1 -> left tree:  ((a + b) + c)
--   * larger batch -> right tree: (a + (b + c))
--
-- The selected row is the same in both runs. Only the surrounding
-- batch size changes, and that is enough to change the reduction tree.
theorem IEEE32_batchDependentSchedule_counterexample :
    ¬ BatchInvariantForward
      (reduceForwardWithScheduleSelector
        TorchLean.Floats.IEEE754.IEEE32Exec.add
        chooseByBatchSize3) := by
  -- The generic lemma says: if a combiner has a non-associative
  -- witness, then this batch-dependent selector can break invariance.
  exact batchDependentSchedule_counterexample
    -- The combiner is TorchLean's executable Float32 addition.
    TorchLean.Floats.IEEE754.IEEE32Exec.add
    -- These are concrete IEEE32 values: 1e20, -1e20, and 1.0.
    IEEE32.big IEEE32.negBig IEEE32.one
    -- Lean checks that the two parenthesizations really differ.
    IEEE32.add_nonassoc_witness

You do not need all the Lean names yet. Read chooseByBatchSize3 as the intentionally bad runtime policy: batch size one gets one addition tree, larger batches get another. The next section introduces the missing formal object, the reduction schedule, which lets us state this dependency directly.
The theorem does not say every batching system is broken. It says a specific kind of dependence is unsafe: if the selected request’s arithmetic order is allowed to depend on unrelated batch context, then there are concrete Float32 inputs where the selected output changes. The rest of the article turns that negative example into a positive requirement: schedules must be fixed, request local, or checked by a certificate.
The second specification: a reduction schedule
The counterexample above forces us to name the thing that changed. In ordinary mathematical writing, we usually write a sum and move on. We do not say whether the numbers are added left to right, as a balanced tree, in tiles, or through a split reduction. For real numbers that silence is harmless. For floating point numbers, that silence hides the bug.
For example, a GPU kernel might compute a sum by giving different chunks to different warps, then combining the partial sums. Another kernel path might use a different tile size and combine the same values in another order. Both are reasonable implementations of “sum.” They are not necessarily the same Float32 computation.[4]
Most neural network operations are pointwise or shape movement; they do not mix values across a reduction dimension. The operations that matter here are reductions. In RMSNorm we reduce over the hidden dimension. In matmul we reduce over \(K\). In attention we reduce over features to compute scores and over the KV dimension to compute the weighted value sum.
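In real arithmetic, a sum is a sum. In floating point arithmetic, the tree matters. Writing \(\oplus\) for rounded addition, there are inputs \(a, b, c\) with
\[
(a \oplus b) \oplus c \;\neq\; a \oplus (b \oplus c) .
\]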
So the spec does not define a reduction as merely a list. It defines a reduction as a binary tree plus an evaluator. The concrete failure is sharper than “floating point is approximate”: the runtime changes the tree that determines the selected request’s rounded operations.
import NN
import NN.Floats.IEEEExec.Reductions

-- A reduction schedule is a binary tree over leaf indices.
-- Leaves name input positions; internal nodes say which partial
-- results are combined first.
abbrev SumTree (ι : Type u) := TorchLean.Floats.IEEE754.SumTree ι

-- To evaluate the schedule, provide:
--   combine : how to combine two partial results,
--   leaf    : how to read each leaf value.
def SumTree.evalWith
    (combine : β -> β -> β) (leaf : ι -> β) : SumTree ι -> β
  | .leaf i => leaf i
  | .node l r =>
    combine (evalWith combine leaf l)
            (evalWith combine leaf r)

Once the schedule is explicit, the first theorem is almost forced: if two selected rows have the same leaves and the same schedule, they have the same reduction result. In Lean this theorem is by induction on the tree. That proof is short, but it is the foundation for everything else.
The theorem gives the systems side a precise target to implement or certify: same request local values, same request local tree.
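To make the tree idea concrete, here is a minimal Python sketch of a schedule evaluator, assuming trees are encoded as nested pairs with integer leaves. The names `eval_with`, `left_to_right`, and `balanced` are hypothetical, and Float32 addition is emulated with `struct` rather than taken from TorchLean.

```python
import struct
from typing import Callable, Union

# A schedule is either a leaf index (int) or a pair (left subtree, right subtree).
Tree = Union[int, tuple]

def f32(x: float) -> float:
    """Round a binary64 Python float to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def add32(a: float, b: float) -> float:
    """Emulated Float32 addition."""
    return f32(a + b)

def eval_with(combine: Callable[[float, float], float],
              leaf: Callable[[int], float], tree: Tree) -> float:
    """Evaluate a reduction tree: leaves read inputs, nodes combine partial results."""
    if isinstance(tree, int):
        return leaf(tree)
    left, right = tree
    return combine(eval_with(combine, leaf, left), eval_with(combine, leaf, right))

def left_to_right(n: int) -> Tree:
    """Sequential schedule: (((0, 1), 2), ...)."""
    tree: Tree = 0
    for i in range(1, n):
        tree = (tree, i)
    return tree

def balanced(lo: int, hi: int) -> Tree:
    """Pairwise schedule over the index range [lo, hi)."""
    if hi - lo == 1:
        return lo
    mid = (lo + hi) // 2
    return (balanced(lo, mid), balanced(mid, hi))

row = [f32(v) for v in (1e8, 1.0, -1e8, 1.0)]
leaf = row.__getitem__

sequential = eval_with(add32, leaf, left_to_right(4))  # (((x0 + x1) + x2) + x3)
pairwise = eval_with(add32, leaf, balanced(0, 4))      # ((x0 + x1) + (x2 + x3))
print(sequential, pairwise)  # 1.0 0.0
```

Same leaves, two trees, two answers. Evaluating either tree twice on the same leaves always agrees, which is the content of the first theorem.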
Schedules are allowed to depend on local shape, not on server load
Here “schedule” means the arithmetic plan used by a kernel: which values are grouped together, which partial sums are formed first, how tiles are combined, and which tree of additions is used to produce the selected output. It is not a wall-clock schedule like “run request A before request B.” It is the numerical schedule inside the operation.
For a dot product, a schedule might say: multiply all \(K\) pairs, add them left to right, and return the final accumulator. A faster kernel might split \(K\) into blocks, sum each block separately, then add the block sums. For attention, a schedule also includes how the logical KV sequence is split into chunks or blocks before the weighted values are combined.
My first attempt was too blunt: require every batch row to use the exact same global schedule. That would make the proof easy, but it is not the right requirement. Real kernels are not written that way, and they should not have to be. A kernel may reasonably choose one strategy for hidden dimension 4096 and another for hidden dimension 8192. It may choose a schedule based on dtype, head dimension, sequence length, KV length, or block size. I had to back up and separate local shape choices from server load choices.
The line I want to draw is different: the schedule can depend on facts local to the selected request, but not on accidental facts about the server batch. If my prompt has the same prefix, same KV length, same head dimension, and same dtype, then the reduction schedule that determines my result should not change merely because another request arrived next to mine.
This is the engineering requirement. Optimized kernels are welcome. The choice of arithmetic path for request \(r\) should be a function of request \(r\)'s own local metadata, rather than a function of whoever happened to share the batch with \(r\).
| Schedule choice | Allowed reason | Suspicious reason |
|---|---|---|
| hidden dimension reduction tree | the selected row has hidden dimension \(d\) | the server batched the row with 31 other requests |
| matmul split over \(K\) | the selected dot product has a fixed \(K\) and dtype | the batch size crosses a kernel heuristic threshold |
| attention KV block tree | the selected request has a fixed logical KV length and block size | prefill was chunked differently because of current load |
So I introduce a request local context:
-- Facts the kernel may legitimately use when choosing a schedule.
-- These are local to the selected request and operation.
structure LocalKernelCtx where
  hiddenDim : Nat
  seqLen    : Nat
  kvLen     : Nat
  dtypeTag  : Nat
  blockSize : Nat

-- The schedule selector is given the batch size and row index,
-- but this property says those may not matter once the local
-- contexts are equal.
def RequestLocalScheduleInvariant
    (choose : (B : Nat) -> Fin B -> LocalKernelCtx -> SumTree ι) : Prop :=
  forall {B C : Nat} (i : Fin B) (j : Fin C)
    (ctx1 ctx2 : LocalKernelCtx),
    ctx1 = ctx2 -> choose B i ctx1 = choose C j ctx2

That is the design rule in formal form. The schedule may depend on the work local to the selected request. It may not depend on the accidental batch context generated by load.
That distinction matters because otherwise the property becomes unrealistic. Batch invariance does not mean “use one slow kernel shape for the whole universe.” It means “do not let unrelated requests decide the arithmetic tree for my request.”
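The distinction can be sketched in Python with two schedule selectors that share a signature. Everything here is illustrative: the dataclass mirrors the Lean structure's fields in snake case, `blocked_tree` is a hypothetical tree builder, and the threshold of 8 in `choose_load_dependent` is an invented stand-in for a kernel heuristic.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class LocalKernelCtx:
    hidden_dim: int
    seq_len: int
    kv_len: int
    dtype_tag: int
    block_size: int

def left_chain(items):
    """Combine items left to right into a nested-pair tree."""
    tree = items[0]
    for item in items[1:]:
        tree = (tree, item)
    return tree

def blocked_tree(n, block):
    """Reduce each block of `block` indices left to right, then combine the block results."""
    blocks = [left_chain(list(range(s, min(s + block, n)))) for s in range(0, n, block)]
    return left_chain(blocks)

def choose_local(batch_size, row, ctx):
    # Receives batch_size and row, but uses only request-local facts.
    return blocked_tree(ctx.hidden_dim, ctx.block_size)

def choose_load_dependent(batch_size, row, ctx):
    # Hypothetical heuristic: widen the block once the batch is large.
    block = ctx.block_size if batch_size < 8 else 2 * ctx.block_size
    return blocked_tree(ctx.hidden_dim, block)

ctx = LocalKernelCtx(hidden_dim=8, seq_len=1, kv_len=16, dtype_tag=0, block_size=4)

# Same local context, served alone (B=1, row 0) and inside a batch (B=32, row 5).
print(choose_local(1, 0, ctx) == choose_local(32, 5, ctx))                    # True
print(choose_load_dependent(1, 0, ctx) == choose_load_dependent(32, 5, ctx))  # False
```

Both selectors are free to pick a blocked, optimized tree. Only the second lets server load decide which tree the selected request gets.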
RMSNorm and matmul become instances
Once the schedule idea is in place, RMSNorm and matmul are less mysterious. They are not easy because production kernels are easy; they are easy at the semantic layer because each selected output is built from a request local reduction.
RMSNorm is a normalization layer used in many Transformer models. For each token row, it computes a scale from the row’s hidden activations, then rescales each coordinate. The only part that mixes hidden coordinates is the sum of squares. Matmul is the usual matrix multiply behind linear layers and attention projections. Each output entry is a dot product, so the only part that mixes values is the reduction over the inner dimension.
The key question becomes very local: for the row I care about, which values are read, and in what tree are they combined? If that local read pattern and tree are fixed, the unrelated rows have no mathematical route into the answer.
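RMSNorm has one essential reduction, the sum of squares over the hidden dimension \(d\):
\[
y_j = x_j \cdot \operatorname{rsqrt}\!\left(\frac{1}{d}\sum_{k=1}^{d} x_k^2 + \varepsilon\right) \cdot w_j .
\]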
The spec keeps the scalar operations abstract. The theorem does not need to know whether \(\operatorname{rsqrt}\) is real valued, rounded real, or executable IEEE. It only needs to know that the selected row, weights, and hidden dimension schedule are fixed. If those are fixed, the other rows in the batch have nowhere to enter the computation.
import torch

def rms_norm(x, weight, eps=1e-5):
    # x: [batch, hidden]
    mean_sq = torch.mean(x * x, dim=-1, keepdim=True)
    return x * torch.rsqrt(mean_sq + eps) * weight

request = torch.randn(4096)
others = torch.randn(7, 4096)
weight = torch.randn(4096)
single = rms_norm(request[None, :], weight)[0]
batched = rms_norm(torch.cat([request[None, :], others]), weight)[0]
# This checks one model state, one input, one backend, one run.
print(torch.equal(single, batched))

-- RMSNorm is modeled as a scheduled hidden-dimension reduction.
-- `sched` fixes the order of the sum of squares.
-- `scaleMean` and `rsqrt` stay abstract, because this theorem
-- is about batch dependence, not about a particular square-root implementation.
def rmsNormRow
    (add mul : β -> β -> β)
    (scaleMean rsqrt : β -> β)
    (sched : SumTree ι)
    (x weight : ι -> β) : ι -> β :=
  let sqsum :=
    SumTree.evalWith add (fun k => mul (x k) (x k)) sched
  let invRms := rsqrt (scaleMean sqsum)
  fun j => mul (mul (x j) invRms) (weight j)

-- If the selected row, weights, and schedule agree, every output
-- coordinate of the RMSNorm row agrees.
theorem rmsNorm_batchInvariant
    (hrow : forall k, x₁ k = x₂ k)
    (hweight : forall k, w₁ k = w₂ k) :
    rmsNormRow add mul scaleMean rsqrt sched x₁ w₁ =
      rmsNormRow add mul scaleMean rsqrt sched x₂ w₂ := by
  funext j
  simp [rmsNormRow, hrow, hweight]

How I use the two sides: Python tests one backend, one input, and one run. Lean states the condition that would make the result independent of batch context for all rows satisfying the hypotheses.
Matmul is the same idea repeated for each output coordinate:
\[
y_{b,o} = \sum_{k} x_{b,k}\, W_{o,k} .
\]
The theorem for matmul is a dot product theorem with the schedule written out. The formalization also includes a TorchLean tensor version, reading selected rows from tensors whose shapes are tracked by Lean. This connects the simple function level proof to the tensor discipline used by TorchLean.
Here is the intuition I would give a kernel engineer: matmul is safe for a selected request if the dot product that computes that request’s output coordinate uses the same \(K\)-reduction schedule in both serving contexts. Split-K can be fast, but if the decision to use it is driven by the surrounding batch shape, then the selected request has inherited nondeterminism from server load.
import torch

def linear(x, W, b):
    # x: [batch, in_dim], W: [out_dim, in_dim]
    return x @ W.T + b

request = torch.randn(4096)
others = torch.randn(31, 4096)
W = torch.randn(4096, 4096)
b = torch.randn(4096)
alone = linear(request[None, :], W, b)[0]
inside_batch = linear(torch.cat([request[None, :], others]), W, b)[0]
# If the backend changes the reduction schedule with batch shape,
# these can differ even though the selected request is the same.
print(torch.equal(alone, inside_batch))

-- A dot product is a reduction over K with an explicit tree.
def scheduledDot
    (add mul : β -> β -> β)
    (sched : SumTree K)
    (x : K -> β) (w : K -> β) : β :=
  SumTree.evalWith add (fun k => mul (x k) (w k)) sched

def batchedMatmul
    (add mul : β -> β -> β)
    (sched : SumTree K)
    (x : Fin B -> K -> β)
    (w : Out -> K -> β)
    (b : Fin B) (o : Out) : β :=
  scheduledDot add mul sched (x b) (w o)

-- Same selected row, same weights, same K-tree: same scalar output.
theorem matmul_batchInvariant
    (hrow : forall k, x₁ i k = x₂ j k)
    (hW : forall o k, W₁ o k = W₂ o k) :
    batchedMatmul add mul sched x₁ W₁ i o =
      batchedMatmul add mul sched x₂ W₂ j o := by
  simp [batchedMatmul, scheduledDot, hrow, hW]

Again, Python checks an instance. Lean states the reusable requirement: same selected row, same \(K\)-tree, same weights ⇒ same scalar, no matter who else is in the batch. That is what a CUDA or Triton implementation would need to certify.
| Kernel | Specification | Checked statement |
|---|---|---|
| RMSNorm | hidden dimension tree for \(\sum x_k^2\) | same row + same tree + same weights gives same selected output |
| Matmul | dot product tree over \(K\) | same selected input row + same weight row + same tree gives same scalar output |
| Tensor matmul | TorchLean shape indexed tensor rows | the function level theorem lifts to Spec.Tensor row access |
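The matmul row of the table can be illustrated with a small split-K sketch in plain Python. This is an assumption-laden toy, not a real kernel: Float32 arithmetic is emulated with `struct`, `scheduled_dot` is a hypothetical blocked dot product, and the batch threshold of 16 in `load_driven_split` is an invented heuristic.

```python
import struct

def f32(x: float) -> float:
    """Round a binary64 Python float to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

def add32(a: float, b: float) -> float:
    return f32(a + b)

def mul32(a: float, b: float) -> float:
    return f32(a * b)

def scheduled_dot(x, w, block):
    """Sum products left to right inside each block of `block` terms,
    then add the block partial sums left to right."""
    partials = []
    for start in range(0, len(x), block):
        acc = mul32(x[start], w[start])
        for k in range(start + 1, min(start + block, len(x))):
            acc = add32(acc, mul32(x[k], w[k]))
        partials.append(acc)
    total = partials[0]
    for p in partials[1:]:
        total = add32(total, p)
    return total

def fixed_split(k_dim, batch_size):
    # Depends only on the selected dot product's own K.
    return max(1, k_dim // 2)

def load_driven_split(k_dim, batch_size):
    # Hypothetical heuristic: split K only when the batch is small.
    return k_dim if batch_size >= 16 else max(1, k_dim // 2)

x = [f32(v) for v in (1e8, 1.0, -1e8, 1.0)]
w = [f32(1.0)] * 4

for policy in (fixed_split, load_driven_split):
    alone = scheduled_dot(x, w, policy(len(x), 1))     # served alone
    batched = scheduled_dot(x, w, policy(len(x), 32))  # served in a batch of 32
    print(policy.__name__, alone, batched)
# fixed_split 0.0 0.0
# load_driven_split 0.0 1.0
```

Neither answer matches the exact sum of 2; the point is only that the fixed policy gives the selected request the same rounded result in both serving contexts, while the load-driven policy does not.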
Attention required a more careful spec
Attention is where the simple reduction picture needs more structure. It is tempting to say “attention is just matmul plus softmax plus matmul,” but inference engines do not execute attention as one clean mathematical line. They prefill prompts in chunks, maintain KV caches, use paged layouts, and split work across blocks for performance.[9]
The letters Q, K, and V stand for query, key, and value. For the selected token, the query vector is compared against key vectors from the context to produce scores. Softmax turns those scores into weights. Then the value vectors are combined with those weights to produce the output. A KV cache stores previously computed keys and values so decoding does not recompute the whole prefix on every new token.
For a single token, attention asks: which keys do I compare against, in what order do I compute the scores, and in what order do I combine the values? A serving system may store those keys and values in pages, reuse old prefixes, or process a long prompt in chunks. The theorem has to see through those implementation choices to the same logical KV sequence: the same ordered list of keys and values belonging to the selected request.
Attention is harder because there are two kinds of reductions. First, for each key position \(t\), a score is computed by reducing over the feature dimension:
\[
s_t = \sum_{f} q_f\, k_{t,f} .
\]
Then the output coordinate reduces over the KV dimension:
\[
o_j = \sum_{t} \operatorname{softmax}(s)_t \, v_{t,j} .
\]
The Thinking Machines analysis emphasizes two attention sensitivities: how many requests are batched together, and how each request is sliced by the inference engine.[1] Chunked prefill, prefix caching, decode with a KV cache, and Split-KV strategies can all change the reduction order.
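import torch

def attention(q, k, v):
    # q: [D], k: [T, D], v: [T, O]
    scores = q @ k.T
    probs = torch.softmax(scores, dim=-1)
    return probs @ v

# Illustrative stand-in data: one query and a logical KV sequence of length T.
T, D, O = 128, 64, 64
q = torch.randn(D)
k_full, v_full = torch.randn(T, D), torch.randn(T, O)
# Stand-in for a paged cache: the same logical sequence reassembled from two pages.
# A real engine traverses pages inside the kernel, which is where the order can change.
k_from_cache_pages = torch.cat([k_full[:64], k_full[64:]])
v_from_cache_pages = torch.cat([v_full[:64], v_full[64:]])

# In a serving engine, the same logical k/v sequence may be
# traversed through different chunks, pages, or split-KV blocks.
out_full = attention(q, k_full, v_full)
out_chunked = attention(q, k_from_cache_pages, v_from_cache_pages)

-- A policy maps local attention context and chunk plan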
-- to the logical KV reduction tree.
structure KVLayoutPolicy (KV : Type u) where
  kvTree : AttentionLocalCtx -> ChunkPlan -> SumTree KV
  -- Same local context, same logical KV tree,
  -- even if the server chose a different chunk plan.
  canonical :
    forall {ctx1 ctx2 : AttentionLocalCtx} (chunk1 chunk2 : ChunkPlan),
      ctx1 = ctx2 -> kvTree ctx1 chunk1 = kvTree ctx2 chunk2

-- The attention theorem uses this canonicality condition:
--   same selected q/k/v payload + canonical schedules
--   implies same selected attention output.

The Python block is the operational picture. The Lean block names the KV traversal tree explicitly. Once that tree is canonical for the selected request, different prefill chunks or cache pages are implementation details rather than reasons the selected token should change.
There is an important caveat here. Real FlashAttention does not literally materialize the full score matrix, run softmax, and then multiply by values. It is an IO aware tiled algorithm: it processes blocks and maintains a running maximum, normalizer, and partial output so the exact attention result can be computed without writing the whole attention matrix to memory.[8] Newer FlashAttention kernels add still more scheduling structure, including asynchrony and low precision paths on Hopper GPUs.[11] So a FlashAttention certificate needs more than a KV summation tree. It also needs the order of max updates, normalizer updates, accumulator rescaling, and any low precision conversion steps. The Lean theorem here is the schedule semantic target; a concrete FlashAttention proof would have to show that its tiled online softmax chain refines that target.
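The running maximum, normalizer, and partial output can be sketched in a few lines of plain Python. This is a simplified illustration of the online softmax recurrence for one query and one output coordinate, not FlashAttention itself; the function names are hypothetical and the arithmetic is ordinary binary64.

```python
import math

def attention_reference(scores, values):
    """Materialize the full softmax, then take the weighted sum."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return sum(e * v for e, v in zip(exps, values)) / z

def attention_online(scores, values, block):
    """Process KV positions block by block, keeping only three running values."""
    m = -math.inf   # running maximum of the scores seen so far
    z = 0.0         # running normalizer, expressed relative to m
    acc = 0.0       # running unnormalized output, expressed relative to m
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # rescales old state; exp(-inf) = 0.0 on the first block
        z = z * scale + sum(math.exp(s - m_new) for s in s_blk)
        acc = acc * scale + sum(math.exp(s - m_new) * v for s, v in zip(s_blk, v_blk))
        m = m_new
    return acc / z

scores = [1.0, 3.0, 2.0, 0.5]
values = [10.0, 20.0, 30.0, 40.0]

reference = attention_reference(scores, values)
for block in (1, 2, 3, 4):
    print(block, attention_online(scores, values, block) - reference)
```

The differences printed are at rounding level, but they need not be exactly zero and need not be the same for every block size. That is why the block size, the rescaling order, and the accumulation order all belong in the schedule a certificate has to pin down.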
The Lean spec separates the two schedule policies:
Feature schedule
A tree over the head dimension, used to compute \(q\cdot k_t\). It may depend on request local metadata such as head dimension and dtype.
KV schedule
A tree over logical KV blocks, used to reduce the weighted values. It must be canonical with respect to chunking.
The key object is a canonical KV layout policy. It maps the selected request’s local attention context and the server’s chunk plan to a KV reduction tree. Canonicality says that for the same local context, different chunk plans produce the same tree.
The claim is narrower than “attention is deterministic.” Attention becomes batch invariant when the logical KV sequence is traversed through a canonical tree. If the server slices the request differently, the proof obligation is to show that those slices still produce the same logical reduction tree for the selected token.
The current theorem keeps softmax abstract: it is a deterministic function from the complete score vector to KV weights. The checked proof first proves that the two executions have the same score vector. Once the score vector is equal, applying the same abstract softmax gives the same weights, and the remaining KV reduction follows by induction on the KV tree.
-- The actual proof separates score equality from the KV reduction.
-- Softmax is abstract here:
--   softmax : (KV -> β) -> KV -> β
theorem selectedAttention_batchInvariant_of_schedulePolicies
    (hctx : ctxs1 i = ctxs2 j)
    (hquery : forall f, (inputs1 i).query f = (inputs2 j).query f)
    (hkey : forall kv f, (inputs1 i).key kv f = (inputs2 j).key kv f)
    (hvalue : forall kv out, (inputs1 i).value kv out = (inputs2 j).value kv out) :
    selectedAttentionWithPolicies add mul featurePolicy layout softmax
        inputs1 ctxs1 chunks1 i out =
      selectedAttentionWithPolicies add mul featurePolicy layout softmax
        inputs2 ctxs2 chunks2 j out := by
  unfold selectedAttentionWithPolicies
  rw [featurePolicy.canonical hctx]
  rw [layout.canonical (chunks1 i) (chunks2 j) hctx]
  exact scheduledAttentionOut_eq_of_same_inputs
    add mul (featurePolicy.featureTree (ctxs2 j))
    (layout.kvTree (ctxs2 j) (chunks2 j)) softmax
    (inputs1 i) (inputs2 j) out hquery hkey hvalue
The attention theorem then says: if the selected request local context agrees, the selected Q/K/V payloads agree, the feature schedule is request local, and the KV layout is canonical, then the selected attention output is batch invariant.
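The three-stage shape of that argument (same scores, same weights, same KV reduction) can be mirrored in a few lines of Python. This is a minimal sketch, not the Lean development: `eval_tree`, `left_chain`, `scheduled_attention_out`, and the dictionary layout of a request are hypothetical names introduced here, the trees are plain left-to-right chains, and `softmax` is an ordinary stand-in for the abstract softmax of the theorem.

```python
import math

def eval_tree(tree, leaf, add):
    """Evaluate a reduction tree. An int is a leaf index; a pair is (left, right)."""
    if isinstance(tree, int):
        return leaf(tree)
    left, right = tree
    return add(eval_tree(left, leaf, add), eval_tree(right, leaf, add))

def left_chain(n):
    """Left-to-right tree (((0, 1), 2), ...) over n leaves."""
    tree = 0
    for i in range(1, n):
        tree = (tree, i)
    return tree

def softmax(scores):
    """Stand-in for the abstract softmax: any deterministic function of the scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def scheduled_attention_out(request, out, feature_tree, kv_tree):
    """One output coordinate, with both reduction orders given explicitly."""
    add = lambda a, b: a + b
    q, keys, values = request["query"], request["keys"], request["values"]
    # Stage 1: scores q . k_t, each reduced along the feature tree.
    scores = [eval_tree(feature_tree, lambda f, k=k: q[f] * k[f], add) for k in keys]
    # Stage 2: the same softmax applied to the same scores gives the same weights.
    weights = softmax(scores)
    # Stage 3: weighted values reduced along the KV tree.
    return eval_tree(kv_tree, lambda t: weights[t] * values[t][out], add)

mine = {"query": [0.1, 0.2, 0.3, 0.4],
        "keys": [[0.5, -1.0, 0.25, 2.0], [1.5, 0.1, -0.3, 0.7], [0.0, 0.9, 0.9, -0.4]],
        "values": [[1.0, 2.0], [3.0, -4.0], [0.5, 0.25]]}
other = {"query": [9.0] * 4, "keys": [[1.0] * 4] * 3, "values": [[7.0, 7.0]] * 3}

batch1, i = [mine], 0                  # served alone
batch2, j = [other, other, mine], 2    # served beside unrelated requests

# Both trees are chosen from request-local sizes only, never from the batch.
feature_tree, kv_tree = left_chain(4), left_chain(3)
out1 = scheduled_attention_out(batch1[i], 0, feature_tree, kv_tree)
out2 = scheduled_attention_out(batch2[j], 0, feature_tree, kv_tree)
print(out1 == out2)
```

The equality is unsurprising here because nothing in `scheduled_attention_out` can see the batch. The theorem's hypotheses describe this same situation for a real server: the selected payloads agree, and both trees are functions of request-local data.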
[Figure: attention schedule diagram. Labels: "q · k_t", "canonical tree", "same under batching".]
There is also a concrete fixed layout construction. Instead of treating canonicality only as an assumption, the formalization defines a fixed KV block plan and proves that its tree is independent of the chunk plan. This is the Lean analogue of the fixed split size idea: split by a request local block size, not by whatever batch/chunk shape the server happened to use.
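A small Python sketch of the fixed split size idea, under stated assumptions: `chunk_following_layout`, `fixed_layout`, and `reduce_blocks` are hypothetical names, the block size of 2 is arbitrary, the terms stand for already-weighted values of one output coordinate, and Python doubles stand in for the runtime's floating point type. The inputs are chosen so that rounding makes the reduction order visible.

```python
from functools import reduce

def chunk_following_layout(kv_len, chunk_sizes):
    """Non-canonical layout: the blocks are whatever chunks the server chose."""
    blocks, start = [], 0
    for size in chunk_sizes:
        blocks.append(list(range(start, start + size)))
        start += size
    return blocks

def fixed_layout(kv_len, chunk_sizes, block_size=2):
    """Canonical layout: cut at a request-local block size; chunk_sizes is ignored."""
    return [list(range(s, min(s + block_size, kv_len)))
            for s in range(0, kv_len, block_size)]

def reduce_blocks(terms, blocks):
    """Sum each block left to right, then sum the block partials left to right."""
    add = lambda a, b: a + b
    partials = [reduce(add, (terms[t] for t in block)) for block in blocks]
    return reduce(add, partials)

# Already-weighted value terms for one output coordinate, chosen to expose rounding.
terms = [1e16, 1.0, -1e16, 1.0]
plan_a, plan_b = [2, 2], [3, 1]   # two chunk plans for the same 4-token context

by_chunks_a = reduce_blocks(terms, chunk_following_layout(4, plan_a))
by_chunks_b = reduce_blocks(terms, chunk_following_layout(4, plan_b))
fixed_a = reduce_blocks(terms, fixed_layout(4, plan_a))
fixed_b = reduce_blocks(terms, fixed_layout(4, plan_b))

print(by_chunks_a, by_chunks_b)   # the chunk-following layout gives two different sums
print(fixed_a == fixed_b)         # the fixed layout gives one tree and one result
```

Canonicality in this sketch is the statement `fixed_layout(n, plan_a) == fixed_layout(n, plan_b)` for every pair of plans, which holds because the function never reads its second argument.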
This is the specification that FlashAttention, FlashDecoding, or paged attention would have to refine.[8][9] First state the clean mathematical schedule object; then ask a runtime path, certificate, or kernel proof to show that the implementation follows that object.
When bits move but the token cannot
Bitwise equality is the cleanest route to deterministic serving. It is not the only route. Sometimes schedules differ, logits move slightly, but the top token remains safely separated from the rest.
This is the everyday intuition behind margins. If the winning logit is \(12.0\) and the next best is \(3.0\), a tiny numerical perturbation will leave the token alone. If the winning logit is \(12.00001\) and the next best is \(12.00000\), the same perturbation might matter a lot.
The margin is the gap between the winning logit and the best competing logit. A large margin means the greedy choice is robust to small numerical drift. A tiny margin means one rounded addition somewhere upstream might be enough to change the token.
Serving systems often care about tokens, not every intermediate bit. If two logit vectors differ by a tiny amount but the winning token is still far ahead, then greedy decoding is stable. This gives a different kind of certificate: not “the two executions are bitwise equal,” but “the numerical drift is too small to change the committed token.”
The margin theorem is stated over rationals. If two logit vectors are close in \(\ell_\infty\), and the winning logit has margin greater than twice the error, then any strict argmax implementation returns the same token.
The theorem is deliberately stated with an explicit StrictArgmaxSound contract. That prevents the proof from assuming an implementation detail of tie breaking. The implementation only has to satisfy the usual spec: whenever one token strictly beats all others, choose it.
This gives a second certification path. Instead of requiring all kernels to be bitwise batch invariant, a verifier can certify that all admissible schedule variations stay within a logit error budget and that the margin is large enough.
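The reasoning behind the factor of two is short. If \(|\ell_i - \ell'_i| \le \varepsilon\) for every token \(i\) and the winner \(w\) satisfies \(\ell_w - \ell_j > 2\varepsilon\) for every \(j \ne w\), then \(\ell'_w \ge \ell_w - \varepsilon > \ell_j + \varepsilon \ge \ell'_j\), so \(w\) still strictly wins after the drift. The sketch below checks that condition over exact rationals using Python's `fractions` module. The function names are hypothetical, and `strict_argmax` is just one implementation that satisfies the strict argmax contract.

```python
from fractions import Fraction

def strict_argmax(logits):
    """Index of the largest logit. Ties, which the contract leaves open, go to the first."""
    return max(range(len(logits)), key=lambda i: logits[i])

def margin(logits):
    """Gap between the winning logit and the best competing logit."""
    w = strict_argmax(logits)
    runner_up = max(x for i, x in enumerate(logits) if i != w)
    return logits[w] - runner_up

def margin_certificate_ok(reference_logits, eps):
    """Accept only if the margin exceeds twice the l-infinity error budget."""
    return margin(reference_logits) > 2 * eps

eps = Fraction(1, 100000)

# Wide margin: 12.0 against 3.0.
safe = [Fraction(12), Fraction(3), Fraction(1)]
safe_drifted = [safe[0] - eps, safe[1] + eps, safe[2]]

# Tiny margin: 12.00001 against 12.00000.
tight = [Fraction(1200001, 100000), Fraction(12)]
tight_drifted = [tight[0] - eps, tight[1] + eps]

print(margin_certificate_ok(safe, eps),
      strict_argmax(safe) == strict_argmax(safe_drifted))
print(margin_certificate_ok(tight, eps),
      strict_argmax(tight) == strict_argmax(tight_drifted))
```

In the first case the certificate is accepted and the token is unchanged. In the second the certificate is rejected, and the worst-case drift within the same budget does flip the greedy token.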
Batch invariant kernels are the clean systems answer. Margin certificates cover cases where a fast path is not bitwise identical to the reference but the committed token is still provably unchanged.
Serving is a refinement problem
A real server has queues, workers, memory allocation, batching policy, speculative candidates, CUDA streams, and many other details. I decided not to model all of that first. The smaller target is the part relevant to observable correctness: which tokens are released to the user.
Refinement is the word I use for “the fast thing behaves like the reference thing for the property we care about.” The reference decoder is the simple, canonical semantics: fixed model graph, fixed finite precision assumptions, fixed token choice rule. The fast server is allowed to use dynamic batching and optimized kernels, but the user-visible output should refine the reference decoder.
The move here is very practical, and it is in the same family of ideas as speculative decoding: use a fast path to propose work, then use a trusted check before committing the visible result.[10] Here I use that pattern as a refinement rule. A fast path can be messy internally as long as the server has a rule for what it is allowed to commit. If a candidate token is checked against a canonical reference decoder, it can be released. If not, the server rolls back to the last verified prefix.
At this point the question shifts from a single kernel to the serving protocol. The fast path can be complicated. The theorem about what the user sees should stay simple: every token released to the user is the token the canonical reference decoder would have released.
The server theorem introduces a canonical reference decoder. Its step function is the fixed finite precision decoder we want to be the ground truth. A fast path is allowed to speculate. It may use dynamic batching. It may produce candidate tokens under a schedule dependent implementation. But committed tokens are appended only through a verifier sound window.
[Figure: serving protocol diagram. Labels: "dynamic batches allowed", "accepts or rejects", "user visible output".]
The central theorem is a trace theorem. Starting from a state whose committed list matches the empty reference prefix, every legal decode/verify/rollback trace preserves the invariant. Speculation and rollback do not change the user visible prefix. Verification acceptance can extend the prefix, but only by a sequence of tokens equal to the reference decoder’s next window.
This avoids proving the entire server implementation first. Instead, it proves that a protocol with verifier sound acceptance refines the reference decoder. The production server can be fast and messy internally, but the observable stream is pinned to the checked reference semantics.
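A toy version of the protocol, written as a minimal sketch. The decoder rule in `reference_step`, the glitching fast path, and the window size of 4 are all made up for illustration. Only the shape matters: candidates are committed through `verify_and_commit` and nowhere else.

```python
def reference_step(prefix):
    """Toy canonical decoder: the next token is a fixed function of the prefix."""
    return (sum(prefix) * 31 + len(prefix)) % 50

def reference_window(prefix, n):
    """The next n tokens the reference decoder would release after `prefix`."""
    window, state = [], list(prefix)
    for _ in range(n):
        tok = reference_step(state)
        window.append(tok)
        state.append(tok)
    return window

def fast_path_propose(prefix, n, glitch_at=None):
    """Toy fast path: agrees with the reference unless told to glitch at one position."""
    window = reference_window(prefix, n)
    if glitch_at is not None and glitch_at < n:
        window[glitch_at] = (window[glitch_at] + 1) % 50
    return window

def verify_and_commit(committed, candidates):
    """Commit the window only if it equals the reference window; otherwise roll back."""
    if candidates == reference_window(committed, len(candidates)):
        return committed + candidates, True
    return committed, False   # rollback: the visible prefix is unchanged

committed = []
trace = [None, 2, None, 0, None]   # which proposed windows glitch, and where
for glitch_at in trace:
    candidates = fast_path_propose(committed, 4, glitch_at)
    committed, accepted = verify_and_commit(committed, candidates)
    # Invariant: the committed list is always a prefix of the reference output.
    assert committed == reference_window([], len(committed))

print(len(committed))   # three of the five windows were accepted
```

The loop body is the trace theorem in miniature: every step either leaves `committed` alone or extends it by the reference decoder's next window, so the invariant holds however the fast path misbehaves.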
The TorchLean bridge then specializes the reference decoder to NN.IR.Graph.denote. In TorchLean, Graph.denote is the mathematical evaluator for an IR graph: give it a graph, parameter payload, input value, and output id, and it returns the denotational result or an error. One decoder step turns the request state into an IR input value, evaluates the TorchLean graph with a payload, chooses a token from the denotational result, and updates the request local state.
The CUDA microkernel certificate
To make the runtime bridge concrete, I used a tiny CUDA kernel. CUDA is NVIDIA’s programming model for running many lightweight GPU threads in parallel. A kernel is a function launched over a grid of thread blocks. Inside the kernel, each thread has an id, can read and write GPU memory, and may cooperate with other threads through shared memory or synchronization. That is powerful, but it also creates exactly the kind of question a proof needs to answer: which thread computed which value, which memory locations did it read, and in what arithmetic order?
The kernel I chose is deliberately small. It verifies the value reduction part of attention, after softmax weights have already been computed. In normal attention, after the model computes attention weights, each output coordinate is a weighted sum over value vectors. For one batch row \(b\) and one output coordinate \(o\), the kernel computes \(\mathrm{out}[b,o] = \sum_{t=0}^{7} \mathrm{weights}[b,t]\cdot V[b,t,o]\), accumulated left to right in \(t\).
This is not the full attention kernel. It is the smallest useful bridge from compiled CUDA back to the schedule semantics. There is no online softmax yet, no tensor cores, no shared memory, no warp reductions, and no barriers. That is intentional. Before verifying a larger FlashAttention style kernel, I wanted one end-to-end example where every active thread has simple ownership: thread \(o\) writes only out[b,o], reads weights[b,0..7] and V[b,0..7,o], and performs the same eight-step reduction order.
extern "C" __global__
void tiny_attn_one_row(
    const float* __restrict__ weights,  // [B, 8]
    const float* __restrict__ v,        // [B, 8, 4]
    float* __restrict__ out,            // [B, 4]
    unsigned int B) {
  const unsigned int b = blockIdx.x;
  const unsigned int tid = threadIdx.x;
  if (b >= B) return;
  if (tid < 4) {
    float acc = 0.0f;
    // Fixed request-local KV schedule: 0,1,2,3,4,5,6,7.
    for (unsigned int t = 0; t < 8; ++t) {
      const float w = weights[b * 8 + t];
      const float val = v[(b * 8 + t) * 4 + tid];
      acc = acc + w * val;
    }
    out[b * 4 + tid] = acc;
  }
}
Compiling this source does not immediately give us a proof. The compiler lowers CUDA C++ through several representations. PTX is NVIDIA’s virtual assembly language: a low-level, readable instruction set that still abstracts over the exact GPU machine code. SASS is the actual architecture-specific machine instruction listing produced for a target GPU. In this example, the pipeline compiles the CUDA source with nvcc, extracts PTX and SASS, emits a JSON certificate plus a generated Lean certificate, and then asks Lean to check the facts the certificate claims.[5][6][7]
The verification question is narrow but real: does the compiled fragment have the dataflow we think it has? In particular, does the value stored to out[b,o] come from the left-to-right chain over \(t=0,\ldots,7\)? Are the arithmetic instructions fused multiply-adds? Are there atomics, barriers, or shared-memory effects hiding in the compiled files? The certificate records those facts in a finite object that Lean can check. The source, PTX, SASS, and hashes remain useful for audit and reproducibility; the Lean theorem depends on the accepted certificate and the checker’s soundness lemma.
[Figure: certificate pipeline diagram. Labels: "PTX", "SASS", "dataflow", "hashes", "accepts certificate", "value reduction".]
The certificate is not merely instruction counting. A weak checker would say “I saw eight FMA instructions” and stop there. That is not enough, because eight unrelated FMAs do not prove anything about the value stored to memory. The certificate instead records the dataflow chain: the accumulator starts at zero, each step consumes the previous accumulator, and the final accumulator is the value that feeds the store.
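The gap between the weak and the strong check can be shown on a toy instruction list. This is a sketch under loud assumptions: the tuple format `("fma", dst, w, v, acc)` and `("store", src)`, the register names, and both checker functions are hypothetical and do not reproduce the project's JSON or Lean certificate format. Each register is assumed to be written once.

```python
def count_fmas(instrs):
    """Weak check: only counts FMA instructions."""
    return sum(1 for ins in instrs if ins[0] == "fma")

def check_fma_chain(instrs, steps=8):
    """Strong check: the stored register ends a left-to-right FMA chain from zero."""
    defs = {ins[1]: ins for ins in instrs if ins[0] == "fma"}   # dst register -> instruction
    stores = [ins for ins in instrs if ins[0] == "store"]
    if len(defs) != steps or len(stores) != 1:
        return False
    reg = stores[0][1]
    for t in reversed(range(steps)):        # walk back from the store: t = 7, ..., 0
        ins = defs.get(reg)
        if ins is None:
            return False
        _, _, w, v, acc = ins
        if (w, v) != (f"w{t}", f"v{t}"):    # step t must multiply weight t by value t
            return False
        reg = acc                           # and consume the previous accumulator
    return reg == "zero"                    # the chain must start at zero

# A genuine chain: acc1 = fma(w0, v0, zero), ..., acc8 = fma(w7, v7, acc7).
good = [("fma", f"acc{t + 1}", f"w{t}", f"v{t}", "zero" if t == 0 else f"acc{t}")
        for t in range(8)] + [("store", "acc8")]

# Eight unrelated FMAs: each one starts from zero, so only the last reaches the store.
bad = [("fma", f"acc{t + 1}", f"w{t}", f"v{t}", "zero")
       for t in range(8)] + [("store", "acc8")]

print(count_fmas(good), count_fmas(bad))            # the weak check cannot tell them apart
print(check_fma_chain(good), check_fma_chain(bad))  # the chain check can
```

Both lists contain eight FMAs, but only in `good` does the stored value depend on all eight weight and value pairs in order.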
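This matters because the compiler emits fused multiply-add instructions. FMA has its own floating point meaning: it computes the multiply and add as one rounded operation rather than as a rounded multiply followed by a rounded add.[4] So the Lean specification for this compiled kernel is not “multiply then add.” It is an explicit FMA chain. Writing \(\mathrm{fma}(x,y,z)\) for \(x\cdot y + z\) with a single rounding, \(w_t = \mathrm{weights}[b,t]\), and \(v_t = V[b,t,o]\), the chain is \(\mathrm{acc}_0 = 0\), \(\mathrm{acc}_{t+1} = \mathrm{fma}(w_t, v_t, \mathrm{acc}_t)\) for \(t = 0,\ldots,7\), and \(\mathrm{out}[b,o] = \mathrm{acc}_8\).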
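A small sketch makes the difference between the two readings visible. It uses Python doubles rather than the kernel's float32, and it emulates a correctly rounded FMA with exact rationals so that it does not depend on any particular hardware or library FMA. The names `fma`, `mul_then_add`, and `value_reduce` are illustrative, not the project's definitions.

```python
from fractions import Fraction

def fma(x, y, z):
    """x*y + z with a single rounding, emulated exactly with rationals."""
    exact = Fraction(x) * Fraction(y) + Fraction(z)
    return exact.numerator / exact.denominator   # int / int rounds once, correctly

def mul_then_add(x, y, z):
    """Rounded multiply followed by a rounded add: two roundings."""
    return x * y + z

def value_reduce(step, weights, values):
    """Left-to-right reduction acc = step(w_t, v_t, acc), starting from zero."""
    acc = 0.0
    for w, v in zip(weights, values):
        acc = step(w, v, acc)
    return acc

# Eight steps. The second product, 1 - 2**-60, needs more bits than a double holds.
weights = [1.0, 1.0 + 2.0**-30] + [0.0] * 6
values = [-1.0, 1.0 - 2.0**-30] + [0.0] * 6

fused = value_reduce(fma, weights, values)
unfused = value_reduce(mul_then_add, weights, values)
print(fused, unfused)
print(fused == unfused)
```

With these inputs the unfused chain rounds the product to exactly 1 and cancels to zero, while the fused chain keeps the residual \(-2^{-60}\). A specification written as multiply then add would therefore describe a different function from the one the compiled kernel computes.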
Lean checks that the accepted certificate denotes the schedule explicit FMA value reduction spec. After the checker accepts the certificate, the theorem states that evaluating the extracted FMA chain is equal to the Lean specification valueReduceFMA on the same weights and values. Some generated certificate checks close with by rfl. That does not mean “trust me”; it means the checker is a pure Lean function, and on this concrete certificate it computes to true by definitional reduction. The final theorem packages that result as a batch invariant runtime fragment.
-- The generated certificate states an eight step FMA chain.
-- Lean checks that this chain denotes the intended value reduction.
theorem compiledKernelCert_denotes_valueReduceFMA
    (fma : β -> β -> β -> β)
    (zero : β)
    (inputs : TinyPTXSemantics.ValueInputs β) :
    TinyPTXSemantics.evalFMAChain8 compiledKernelCert.dataflow fma zero inputs =
      TinyAttentionSpec.valueReduceFMA fma zero inputs.weights inputs.values :=
  TinyPTXSemantics.kernelCert_denotes_valueReduceFMA
    compiledKernelCert (by rfl) fma zero inputs
The boundary matches TorchLean’s native boundary design: external tools and CUDA paths can produce values, traces, or certificates, while Lean checks the object that is brought back into the proof world.[12] In this CUDA example, Lean checks a finite certificate derived from the compiled files and connects the accepted certificate to the FMA-chain specification. Full PTX/SASS operational semantics and NVIDIA hardware correctness sit one layer lower in the stack.[6][7]
So the bridge is: the compiled CUDA files yield a certificate; the certificate checks that the stored value is the intended FMA chain; the FMA chain is a schedule explicit value reduction; and that reduction is one of the pieces needed for batch invariant attention. The next CUDA step is to make the certificate less tiny: parse more PTX/SASS directly, include memory safety and inactive-thread facts more deeply, and then move from this value reduction kernel toward a tiled attention kernel with online softmax.
The checked stack
The result is a sequence of checked contracts. Deterministic serving is not one giant statement; it is several smaller promises that have to line up.
| Layer | Specification | Checked theorem |
|---|---|---|
| Reduction | SumTree plus evalWith | same selected leaves and same tree imply same result |
| Schedule policy | request local metadata selector | same local context implies same selected schedule |
| Finite precision | TorchLean IEEE32Exec addition | batch dependent schedule selector can break invariance |
| RMSNorm / matmul | schedule explicit reductions | selected row batch invariance |
| Attention | feature tree plus canonical KV tree | selected attention is invariant to batching and chunking |
| Token choice | margin certificate and strict argmax contract | bounded logit drift cannot change the greedy token |
| Serving | decode/verify/rollback trace | committed output is a prefix of the reference decoder's output |
| TorchLean IR | NN.IR.Graph.denote reference step | serving theorem specializes to TorchLean graph semantics |
| CUDA microkernel | FMA chain certificate | accepted certificate denotes value reduction and gives batch invariance |
What I proved: batch invariant inference is formalized as a property in Lean; schedule explicit reductions, RMSNorm, matmul, attention schedules, greedy token stability, serving refinement, and a proof carrying CUDA value reduction microkernel are connected to that property. The CUDA certificate checks a compiled value reduction fragment; full PTX/SASS execution semantics and hardware correctness are separate refinement layers.
Next
Dynamic batching, CUDA kernels, KV caches, and speculative serving do not disappear in this formalization. They become parts of the system that either refine the reference decoder or carry a certificate explaining why they are allowed to affect the computation.
This is the kind of problem TorchLean should make easier to study. The value is not only proving a fact about an idealized neural network. It gives the pipeline a place to meet: Lean definitions for the model, finite precision semantics for the arithmetic, graph semantics for the compiled form, and checked certificates for the runtime path. If the serving system claims “this token is the one the reference decoder would have produced,” the proof should say exactly what has to be checked for that claim to be true.
I stop here before full PTX/SASS operational semantics or hardware verification. The CUDA part is a proof carrying microkernel certificate, not a proof of NVIDIA hardware. The next concrete target is a larger attention certificate: parse more PTX/SASS dataflow directly, model the online softmax recurrence used by FlashAttention, and replace more of the current runtime assumptions with checkable certificates.
This is the direction I expect useful verification work to move: not one language replacing the whole stack, but formal contracts connecting several ecosystems. Python observes the failure mode. Lean states and checks the reference property. CUDA, Triton, or vendor libraries provide fast paths plus certificates. The theorem is the bridge between them.
AI usage
The code is available at Robertboy18/TorchLean-Verified-Examples.
Some of the proof development, refactoring, debugging, and prose editing for this project were assisted by GPT-5.5 Pro. The correctness claims do not come from the model. They come from the Lean files compiling without sorry, admit, new axioms, or unchecked proof holes, plus the explicit certificate checks described above.