diff --git a/crypto/stark/src/constraint_ir/builder.rs b/crypto/stark/src/constraint_ir/builder.rs new file mode 100644 index 000000000..29328d9b2 --- /dev/null +++ b/crypto/stark/src/constraint_ir/builder.rs @@ -0,0 +1,203 @@ +//! Explicit-builder capture front-end (Plan B). +//! +//! Where the symbolic-field front-end (Plan A) records IR by running a +//! constraint's generic `evaluate` over recording field types, this front-end +//! builds the same [`ConstraintProgram`] through an explicit [`IrBuilder`]: +//! each constraint implements [`Capture`](super::Capture) and translates its +//! `evaluate` body into builder calls (`main`, `add`, `sub`, `mul`, ...). +//! +//! No fake field, no thread-local arena. The builder hash-conses every node on +//! `(Op, Dim)` and only emits leaves for columns the constraint actually reads, +//! so captured programs are minimal. + +use std::collections::HashMap; + +use math::field::element::FieldElement; +use math::field::goldilocks::GoldilocksField; + +use super::ir::{ConstraintProgram, Dim, Op}; + +/// A handle to a node in an [`IrBuilder`]: its arena id and result dimension. +/// +/// `Copy` so constraint bodies read like ordinary field arithmetic. +#[derive(Clone, Copy, Debug)] +pub struct Expr { + id: u32, + dim: Dim, +} + +impl Expr { + /// The node's result dimension. + pub fn dim(self) -> Dim { + self.dim + } +} + +/// Builds a [`ConstraintProgram`] from explicit node-construction calls. +/// +/// Nodes are appended in topological order (id `i` references only `< i`) and +/// hash-consed on `(Op, Dim)`, so structurally identical subexpressions share a +/// single id. Base-field constants are additionally deduplicated by value via +/// `const_cache`. Node id `0` is reserved for `Op::Const1(0)`, matching the +/// interpreter's convention and Plan A's arena. +pub struct IrBuilder { + nodes: Vec, + dims: Vec, + cse: HashMap<(Op, Dim), u32>, + const_cache: HashMap, + roots: Vec, +} + +impl Default for IrBuilder { + fn default() -> Self { + Self::new() + } +} + +impl IrBuilder { + /// Create a builder with the reserved `Op::Const1(0)` node at id 0. + pub fn new() -> Self { + let mut b = IrBuilder { + nodes: Vec::new(), + dims: Vec::new(), + cse: HashMap::new(), + const_cache: HashMap::new(), + roots: Vec::new(), + }; + // Reserve id 0 = Const1(0). `const_base(0)` will hash-cons to this. + let zero = b.push(Op::Const1(0), Dim::D1); + debug_assert_eq!(zero.id, 0); + b.const_cache.insert(0, 0); + b + } + + /// Append (or reuse) a node with the given op and result dimension. + fn push(&mut self, op: Op, dim: Dim) -> Expr { + if let Some(&id) = self.cse.get(&(op, dim)) { + return Expr { id, dim }; + } + let id = self.nodes.len() as u32; + self.nodes.push(op); + self.dims.push(dim); + self.cse.insert((op, dim), id); + Expr { id, dim } + } + + // --------------------------------------------------------------------- + // Leaves + // --------------------------------------------------------------------- + + /// A main-trace column read at the given frame `offset`, row 0. + pub fn main(&mut self, offset: u8, col: usize) -> Expr { + self.push( + Op::Var { + main: true, + offset, + row: 0, + col: col as u16, + }, + Dim::D1, + ) + } + + /// An aux-trace column read at the given frame `offset`, row 0 (`D3`). + pub fn aux(&mut self, offset: u8, col: usize) -> Expr { + self.push( + Op::Var { + main: false, + offset, + row: 0, + col: col as u16, + }, + Dim::D3, + ) + } + + // --------------------------------------------------------------------- + // Constants + // --------------------------------------------------------------------- + + /// A base-field constant from a `u64`, reduced and deduplicated by value. + pub fn const_base(&mut self, v: u64) -> Expr { + let canon = *FieldElement::::from(v).value(); + if let Some(&id) = self.const_cache.get(&canon) { + return Expr { id, dim: Dim::D1 }; + } + let e = self.push(Op::Const1(canon), Dim::D1); + self.const_cache.insert(canon, e.id); + e + } + + /// A base-field constant from an `i64`; negatives map to `p - |v|`. + pub fn const_signed(&mut self, v: i64) -> Expr { + let canon = *FieldElement::::from(v).value(); + if let Some(&id) = self.const_cache.get(&canon) { + return Expr { id, dim: Dim::D1 }; + } + let e = self.push(Op::Const1(canon), Dim::D1); + self.const_cache.insert(canon, e.id); + e + } + + /// The base-field constant `1`. + pub fn one(&mut self) -> Expr { + self.const_base(1) + } + + // --------------------------------------------------------------------- + // Arithmetic + // --------------------------------------------------------------------- + + /// `a + b`. Result is `D1` only if both operands are `D1`. + pub fn add(&mut self, a: Expr, b: Expr) -> Expr { + let dim = Self::join(a.dim, b.dim); + self.push(Op::Add(a.id, b.id), dim) + } + + /// `a - b`. Result is `D1` only if both operands are `D1`. + pub fn sub(&mut self, a: Expr, b: Expr) -> Expr { + let dim = Self::join(a.dim, b.dim); + self.push(Op::Sub(a.id, b.id), dim) + } + + /// `a * b`. Result is `D1` only if both operands are `D1`. + pub fn mul(&mut self, a: Expr, b: Expr) -> Expr { + let dim = Self::join(a.dim, b.dim); + self.push(Op::Mul(a.id, b.id), dim) + } + + /// `-a`. Preserves the operand's dimension. + pub fn neg(&mut self, a: Expr) -> Expr { + self.push(Op::Neg(a.id), a.dim) + } + + /// Typing join: `(D1, D1) -> D1`; any `D3` operand -> `D3`. + fn join(a: Dim, b: Dim) -> Dim { + match (a, b) { + (Dim::D1, Dim::D1) => Dim::D1, + _ => Dim::D3, + } + } + + // --------------------------------------------------------------------- + // Emit / finish + // --------------------------------------------------------------------- + + /// Record `e` as the root for constraint `constraint_idx`. + /// + /// Roots are stored in emit order; the minimal spike emits exactly one root + /// per program, so `constraint_idx` is accepted for parity with the + /// production design but not used to index `roots` here. + pub fn emit(&mut self, _constraint_idx: usize, e: Expr) { + self.roots.push(e.id); + } + + /// Consume the builder and produce the captured program. + pub fn finish(self) -> ConstraintProgram { + ConstraintProgram { + nodes: self.nodes, + dims: self.dims, + roots: self.roots, + } + } +} diff --git a/crypto/stark/src/constraint_ir/interp.rs b/crypto/stark/src/constraint_ir/interp.rs new file mode 100644 index 000000000..62e502594 --- /dev/null +++ b/crypto/stark/src/constraint_ir/interp.rs @@ -0,0 +1,97 @@ +//! CPU interpreter for a captured [`ConstraintProgram`]. +//! +//! A single forward pass over the topologically ordered nodes evaluates each +//! node into a [`Value`] (base `D1` or extension `D3`), reusing the real +//! `FieldElement` arithmetic so per-op results are bit-identical to the boxed +//! constraint path. Mixed-dimension ops auto-embed the `D1` operand into `D3`, +//! mirroring the field tower's `F: IsSubFieldOf` arithmetic. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField as GoldilocksExtension; +use math::field::goldilocks::GoldilocksField; + +use super::ir::{ConstraintProgram, Dim, Op}; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +/// A node's computed value: base field (`D1`) or degree-3 extension (`D3`). +#[derive(Clone, Copy, Debug)] +enum Value { + D1(Fp), + D3(Fp3), +} + +impl Value { + /// Promote to the extension field, embedding a base value if needed. + fn to_ext(self) -> Fp3 { + match self { + Value::D1(x) => x.to_extension::(), + Value::D3(x) => x, + } + } + + fn as_base(self) -> Fp { + match self { + Value::D1(x) => x, + Value::D3(_) => { + panic!("expected a base (D1) value but found an extension (D3) value") + } + } + } +} + +/// Evaluate the program's single root over a base-field main row. +/// +/// `main_row[col]` resolves `Var { main: true, col, .. }` leaves. The minimal +/// algebraic constraint set only reads main columns at offset 0, row 0 and +/// returns a base-field (`D1`) value, so this returns a `FieldElement`. +pub fn eval_program_base(prog: &ConstraintProgram, main_row: &[Fp]) -> Fp { + let mut values: Vec = Vec::with_capacity(prog.nodes.len()); + + for (i, op) in prog.nodes.iter().enumerate() { + let v = match *op { + Op::Const1(c) => Value::D1(Fp::from(c)), + Op::Const3([c0, c1, c2]) => { + Value::D3(Fp3::from_raw([Fp::from(c0), Fp::from(c1), Fp::from(c2)])) + } + Op::Var { main, row, col, .. } => { + assert!(main, "aux leaves are not part of the minimal algebraic set"); + assert_eq!(row, 0, "minimal set reads row 0 only"); + Value::D1(main_row[col as usize]) + } + Op::Add(a, b) => binop(&values, a, b, prog.dims[i], |x, y| x + y, |x, y| x + y), + Op::Sub(a, b) => binop(&values, a, b, prog.dims[i], |x, y| x - y, |x, y| x - y), + Op::Mul(a, b) => binop(&values, a, b, prog.dims[i], |x, y| x * y, |x, y| x * y), + Op::Neg(a) => match (values[a as usize], prog.dims[i]) { + (Value::D1(x), Dim::D1) => Value::D1(-x), + (val, Dim::D3) => Value::D3(-val.to_ext()), + (Value::D3(x), Dim::D1) => Value::D3(-x), // dim mismatch, keep ext + }, + Op::Embed(a) => Value::D3(values[a as usize].to_ext()), + }; + values.push(v); + } + + let root = prog.roots[0]; + values[root as usize].as_base() +} + +/// Apply a binary op, auto-embedding to the extension field when the result +/// dimension is `D3` (or either operand is already `D3`). +#[inline] +fn binop( + values: &[Value], + a: u32, + b: u32, + result_dim: Dim, + base_op: impl Fn(Fp, Fp) -> Fp, + ext_op: impl Fn(Fp3, Fp3) -> Fp3, +) -> Value { + let va = values[a as usize]; + let vb = values[b as usize]; + match (va, vb, result_dim) { + (Value::D1(x), Value::D1(y), Dim::D1) => Value::D1(base_op(x, y)), + _ => Value::D3(ext_op(va.to_ext(), vb.to_ext())), + } +} diff --git a/crypto/stark/src/constraint_ir/ir.rs b/crypto/stark/src/constraint_ir/ir.rs new file mode 100644 index 000000000..8d0a3c449 --- /dev/null +++ b/crypto/stark/src/constraint_ir/ir.rs @@ -0,0 +1,80 @@ +//! Flat intermediate representation (IR) for captured transition constraints. +//! +//! A [`ConstraintProgram`] is a topologically ordered list of [`Op`] nodes plus +//! a per-constraint root id. It is produced by the builder capture front-end +//! (see [`crate::constraint_ir::builder`]) and consumed by the CPU interpreter +//! (see [`crate::constraint_ir::interp`]). +//! +//! The IR is single-field over Goldilocks, with a [`Dim`] tag distinguishing +//! base (`D1`, one `u64`) from the degree-3 extension (`D3`, three `u64`). + +/// Field-arithmetic dimension of a node's value: base Goldilocks (`D1`) or its +/// degree-3 extension (`D3`). +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Default)] +pub enum Dim { + /// Base field (one Goldilocks `u64`). + #[default] + D1, + /// Degree-3 extension (`[u64; 3]`). + D3, +} + +/// One IR instruction. Operand fields are `u32` ids into the program's `nodes` +/// arena; a node with id `i` only references nodes with id `< i`. +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +pub enum Op { + /// A base-field literal (already reduced mod the Goldilocks prime). + Const1(u64), + /// An extension-field literal `[c0, c1, c2]` (each component reduced). + Const3([u64; 3]), + /// A leaf read of a main-trace cell. `main` is always `true` for the + /// minimal algebraic set captured by the spike; aux reads would set it + /// `false`. `offset`/`row` select the frame step/row, `col` the column. + Var { + /// `true` for a main-trace column read, `false` for an aux read. + main: bool, + /// Frame step index (0-based). + offset: u8, + /// Row within the step. + row: u8, + /// Column index. + col: u16, + }, + /// `nodes[a] + nodes[b]`. + Add(u32, u32), + /// `nodes[a] - nodes[b]`. + Sub(u32, u32), + /// `nodes[a] * nodes[b]`. + Mul(u32, u32), + /// `-nodes[a]`. + Neg(u32), + /// Embed a `D1` value into `D3` (`>::embed`). + Embed(u32), +} + +/// A captured program for one transition constraint (or a set of them). +/// +/// `nodes` is topologically ordered (id `i` references only `< i`). `dims[i]` +/// is the result dimension of `nodes[i]`. `roots[c]` is the node id of +/// constraint `c`'s value. +#[derive(Clone, Debug)] +pub struct ConstraintProgram { + /// Topologically ordered instruction list. + pub nodes: Vec, + /// Per-node result dimension, parallel to `nodes`. + pub dims: Vec, + /// Per-constraint root node ids. + pub roots: Vec, +} + +impl ConstraintProgram { + /// Number of nodes in the program (an effectiveness measure for hash-consing). + pub fn len(&self) -> usize { + self.nodes.len() + } + + /// Whether the program has no nodes. + pub fn is_empty(&self) -> bool { + self.nodes.is_empty() + } +} diff --git a/crypto/stark/src/constraint_ir/mod.rs b/crypto/stark/src/constraint_ir/mod.rs new file mode 100644 index 000000000..a515ff177 --- /dev/null +++ b/crypto/stark/src/constraint_ir/mod.rs @@ -0,0 +1,39 @@ +//! Explicit-builder constraint capture spike (Plan B). +//! +//! Proof-of-concept that lambda_vm's algebraic transition constraints can be +//! captured into a flat, single-field Goldilocks IR via an explicit +//! [`IrBuilder`] (rather than the recording "symbolic field" of Plan A), and +//! that interpreting that IR on the CPU reproduces the constraint's real +//! `evaluate` bit-for-bit. +//! +//! Both plans produce the SAME IR and use the SAME interpreter; they differ +//! only in the capture front-end. Here each constraint implements [`Capture`] +//! and translates its `evaluate` body into builder calls. This is CPU-only and +//! does not touch the prover hot loop, the LogUp framework, or GPU code. +//! +//! - [`ir`]: the IR data structures ([`ConstraintProgram`], [`Op`], [`Dim`]). +//! - [`builder`]: the [`IrBuilder`] and [`Expr`] capture API. +//! - [`interp`]: a CPU forward-pass interpreter over the IR. +//! +//! [`ConstraintProgram`]: ir::ConstraintProgram +//! [`Op`]: ir::Op +//! [`Dim`]: ir::Dim + +pub mod builder; +pub mod interp; +pub mod ir; + +pub use builder::{Expr, IrBuilder}; +pub use interp::eval_program_base; +pub use ir::{ConstraintProgram, Dim, Op}; + +/// A transition constraint that can record its algebra into an [`IrBuilder`]. +/// +/// Object-safe: `capture` is non-generic (it takes `&mut IrBuilder`), so a +/// constraint can be captured behind a `&dyn Capture`, mirroring the production +/// design where the capture method is not generic over the field tower. +pub trait Capture { + /// Translate this constraint's algebra into builder nodes, finishing with a + /// single `b.emit(constraint_idx, root)` call. + fn capture(&self, b: &mut IrBuilder); +} diff --git a/crypto/stark/src/lib.rs b/crypto/stark/src/lib.rs index e9f6a1cda..5ec372c23 100644 --- a/crypto/stark/src/lib.rs +++ b/crypto/stark/src/lib.rs @@ -5,6 +5,7 @@ compile_error!("the `disk-spill` feature requires memmap2, which does not compil #[cfg(feature = "debug-checks")] pub mod bus_debug; +pub mod constraint_ir; pub mod constraints; pub mod context; pub mod debug; diff --git a/prover/src/constraints/cpu.rs b/prover/src/constraints/cpu.rs index facc9e16d..1c811471b 100644 --- a/prover/src/constraints/cpu.rs +++ b/prover/src/constraints/cpu.rs @@ -17,6 +17,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, IrBuilder}; use stark::constraints::transition::{TransitionConstraint, TransitionConstraintEvaluator}; use stark::table::TableView; @@ -112,6 +113,16 @@ impl TransitionConstraint for ProductZeroC } } +impl Capture for ProductZeroConstraint { + fn capture(&self, b: &mut IrBuilder) { + // col_a * col_b + let a = b.main(0, self.col_a); + let b_col = b.main(0, self.col_b); + let root = b.mul(a, b_col); + b.emit(self.constraint_idx, root); + } +} + /// `(1 - MEMORY - BRANCH) · read_register2 · imm[i] = 0`: when neither MEMORY nor /// BRANCH is set, the `arg2` multiplex needs at most one of `rv2`/`imm` nonzero. /// Decoding already guarantees this; a spec defense-in-depth assumption. diff --git a/prover/src/constraints/templates.rs b/prover/src/constraints/templates.rs index ef5b6c036..daf25ae6d 100644 --- a/prover/src/constraints/templates.rs +++ b/prover/src/constraints/templates.rs @@ -13,6 +13,7 @@ use math::field::element::FieldElement; use math::field::traits::{IsField, IsSubFieldOf}; +use stark::constraint_ir::{Capture, Expr, IrBuilder}; use stark::{constraints::transition::TransitionConstraint, table::TableView}; use crate::tables::types::{GoldilocksExtension, GoldilocksField}; @@ -107,6 +108,28 @@ impl TransitionConstraint for IsBitConstra } } +impl Capture for IsBitConstraint { + fn capture(&self, b: &mut IrBuilder) { + // Mirrors `evaluate`: x = main(value_col), one - x, then the product. + let x = b.main(0, self.value_col); + let one = b.one(); + let one_minus_x = b.sub(one, x); + + let root = match self.cond_col { + Some(cond_col) => { + // cond * x * (1 - x), left-associated like `&cond * &x * (one - x)`. + let cond = b.main(0, cond_col); + let cond_x = b.mul(cond, x); + b.mul(cond_x, one_minus_x) + } + // x * (1 - x) + None => b.mul(x, one_minus_x), + }; + + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // ADD Template (Embedded Carry Approach) // ========================================================================= @@ -177,6 +200,22 @@ impl AddLinearTerm { AddLinearTerm::Constant(value) => FieldElement::::from(*value), } } + + /// Capture this term into builder nodes, mirroring [`Self::eval`]. + fn capture(&self, b: &mut IrBuilder) -> Expr { + match self { + AddLinearTerm::Column { + coefficient, + column, + } => { + // `col_val * FieldElement::from(coeff)`: column on the left. + let col = b.main(0, *column); + let coeff = b.const_signed(*coefficient); + b.mul(col, coeff) + } + AddLinearTerm::Constant(value) => b.const_signed(*value), + } + } } /// Evaluate a slice of terms as a sum. @@ -195,6 +234,24 @@ where } } +/// Capture a slice of terms as a sum, mirroring [`eval_terms`]. +/// +/// Empty -> `0`; otherwise `0 + t0 + t1 + ...` (same fold seed and order as +/// `eval_terms`, so the captured node tree matches bit-for-bit). +fn capture_terms(terms: &[AddLinearTerm], b: &mut IrBuilder) -> Expr { + let zero = b.const_base(0); + if terms.is_empty() { + zero + } else { + let mut acc = zero; + for t in terms { + let term = t.capture(b); + acc = b.add(acc, term); + } + acc + } +} + impl AddOperand { /// Get the low word value from the trace. pub fn eval_lo(&self, step: &TableView) -> FieldElement @@ -224,6 +281,22 @@ impl AddOperand { } } + /// Capture the low word, mirroring [`Self::eval_lo`]. + pub fn capture_lo(&self, b: &mut IrBuilder) -> Expr { + match self { + AddOperand::DWordWL { start_column } => b.main(0, *start_column), + AddOperand::Linear { lo, .. } => capture_terms(lo, b), + } + } + + /// Capture the high word, mirroring [`Self::eval_hi`]. + pub fn capture_hi(&self, b: &mut IrBuilder) -> Expr { + match self { + AddOperand::DWordWL { start_column } => b.main(0, *start_column + 1), + AddOperand::Linear { hi, .. } => capture_terms(hi, b), + } + } + // ------------------------------------------------------------------------- // Convenience constructors for common cast types // ------------------------------------------------------------------------- @@ -485,6 +558,68 @@ impl TransitionConstraint for AddConstrain } } +impl AddConstraint { + /// Capture carry_0, mirroring [`Self::compute_carry_0`]. + fn capture_carry_0(&self, b: &mut IrBuilder) -> Expr { + let lhs_lo = self.lhs.capture_lo(b); + let rhs_lo = self.rhs.capture_lo(b); + let sum_lo = self.sum.capture_lo(b); + let inv = b.const_base(INV_SHIFT_32); + + // ((lhs_lo + rhs_lo) - sum_lo) * inv_2_32 + let s = b.add(lhs_lo, rhs_lo); + let s = b.sub(s, sum_lo); + b.mul(s, inv) + } + + /// Capture carry_1, mirroring [`Self::compute_carry_1`]. + fn capture_carry_1(&self, b: &mut IrBuilder) -> Expr { + let lhs_hi = self.lhs.capture_hi(b); + let rhs_hi = self.rhs.capture_hi(b); + let sum_hi = self.sum.capture_hi(b); + let carry_0 = self.capture_carry_0(b); + let inv = b.const_base(INV_SHIFT_32); + + // (((lhs_hi + rhs_hi) + carry_0) - sum_hi) * inv_2_32 + let s = b.add(lhs_hi, rhs_hi); + let s = b.add(s, carry_0); + let s = b.sub(s, sum_hi); + b.mul(s, inv) + } +} + +impl Capture for AddConstraint { + fn capture(&self, b: &mut IrBuilder) { + let one = b.one(); + + let carry = match self.carry_idx { + 0 => self.capture_carry_0(b), + 1 => self.capture_carry_1(b), + _ => unreachable!("carry_idx validated <= 1 at construction"), + }; + + let root = if self.cond_cols.is_empty() { + // Unconditional: carry * (1 - carry) + let one_minus_carry = b.sub(one, carry); + b.mul(carry, one_minus_carry) + } else { + // Conditional: cond * carry * (1 - carry), left-associated like + // `cond * &carry * (one - carry)`. + // cond = fold over cond_cols starting from zero: 0 + col0 + col1 + ... + let mut cond = b.const_base(0); + for &col in &self.cond_cols { + let c = b.main(0, col); + cond = b.add(cond, c); + } + let one_minus_carry = b.sub(one, carry); + let cond_carry = b.mul(cond, carry); + b.mul(cond_carry, one_minus_carry) + }; + + b.emit(self.constraint_idx, root); + } +} + // ========================================================================= // Helper Functions // ========================================================================= diff --git a/prover/src/tests/constraint_ir_tests.rs b/prover/src/tests/constraint_ir_tests.rs new file mode 100644 index 000000000..86bf51d81 --- /dev/null +++ b/prover/src/tests/constraint_ir_tests.rs @@ -0,0 +1,113 @@ +//! Differential tests for the explicit-builder constraint capture spike (Plan B). +//! +//! For each algebraic transition constraint, capture it into a flat IR via its +//! `Capture::capture` method (an explicit `IrBuilder`), then assert that +//! interpreting the IR reproduces the constraint's real +//! `evaluate::` bit-for-bit over many +//! random main rows. + +use crate::constraints::cpu::ProductZeroConstraint; +use crate::constraints::templates::{AddConstraint, AddOperand, IsBitConstraint}; +use crate::tables::types::{FE, GoldilocksExtension, GoldilocksField}; + +use math::field::element::FieldElement; +use stark::constraint_ir::{Capture, IrBuilder, eval_program_base}; +use stark::constraints::transition::TransitionConstraint; +use stark::table::TableView; + +/// Number of random trials per constraint. +const TRIALS: usize = 1000; + +/// Column count for the random frame; larger than any column index read by the +/// constraints under test (CPU columns go up to 37). +const NUM_COLS: usize = 64; + +/// A tiny deterministic SplitMix64 PRNG so the test needs no `rand` dependency +/// and is fully reproducible. +struct SplitMix64 { + state: u64, +} + +impl SplitMix64 { + fn new(seed: u64) -> Self { + Self { state: seed } + } + + fn next_u64(&mut self) -> u64 { + self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15); + let mut z = self.state; + z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9); + z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB); + z ^ (z >> 31) + } +} + +/// Run the differential check: capture `c` via the builder, then for `TRIALS` +/// random rows compare the real `evaluate` against the IR interpreter, +/// bit-for-bit. +fn assert_ir_matches_evaluate(c: &T, label: &str) +where + T: TransitionConstraint + Capture, +{ + let mut b = IrBuilder::new(); + c.capture(&mut b); + let prog = b.finish(); + eprintln!("[{label}] captured {} IR nodes", prog.len()); + + let mut rng = SplitMix64::new(0xDEAD_BEEF_CAFE_F00D ^ (label.len() as u64)); + + for trial in 0..TRIALS { + // Build a random main row. + let row: Vec = (0..NUM_COLS).map(|_| FE::from(rng.next_u64())).collect(); + + // Real evaluate: wrap the row in a base/ext TableView (1 row, no aux). + let real_step: TableView = + TableView::new(vec![row.clone()], vec![Vec::new()]); + let real: FieldElement = + c.evaluate::(&real_step); + + // IR interpreter over the same row. + let got = eval_program_base(&prog, &row); + + assert_eq!( + real, got, + "[{label}] mismatch at trial {trial}: real={real:?} got={got:?}" + ); + } +} + +#[test] +fn test_ir_matches_is_bit_unconditional() { + // X * (1 - X), X at column 7. + let c = IsBitConstraint::unconditional(7, 0); + assert_ir_matches_evaluate(&c, "is_bit_unconditional"); +} + +#[test] +fn test_ir_matches_is_bit_conditional() { + // cond * X * (1 - X), cond at column 3, X at column 5. + let c = IsBitConstraint::new(3, 5, 0); + assert_ir_matches_evaluate(&c, "is_bit_conditional"); +} + +#[test] +fn test_ir_matches_add_constraint_carries() { + // 64-bit ADD with embedded carries, DWordWL operands. + // cond at col 0; lhs=[1,2], rhs=[3,4], sum=[5,6]. + let (carry0, carry1) = AddConstraint::new_pair( + vec![0], + AddOperand::dword(1), + AddOperand::dword(3), + AddOperand::dword(5), + 0, + ); + assert_ir_matches_evaluate(&carry0, "add_carry_0"); + assert_ir_matches_evaluate(&carry1, "add_carry_1"); +} + +#[test] +fn test_ir_matches_product_zero() { + // col_a * col_b, columns 12 and 17. + let c = ProductZeroConstraint::new(12, 17, 0); + assert_ir_matches_evaluate(&c, "product_zero"); +} diff --git a/prover/src/tests/mod.rs b/prover/src/tests/mod.rs index 4d0ac4477..d7c5824ea 100644 --- a/prover/src/tests/mod.rs +++ b/prover/src/tests/mod.rs @@ -15,6 +15,8 @@ pub mod commit_tests; #[cfg(test)] pub mod compute_commit_bus_offset_tests; #[cfg(test)] +pub mod constraint_ir_tests; +#[cfg(test)] pub mod constraints_tests; #[cfg(all(test, feature = "disk-spill"))] pub mod count_table_lengths_drift_tests; diff --git a/thoughts/gpu-constraint-eval/README.md b/thoughts/gpu-constraint-eval/README.md new file mode 100644 index 000000000..bfefac626 --- /dev/null +++ b/thoughts/gpu-constraint-eval/README.md @@ -0,0 +1,96 @@ +# GPU-ready constraint evaluation — design + +Design docs for moving STARK transition-constraint evaluation onto the GPU. + +## Why (the real motivation) + +Not to make constraint evaluation faster — it isn't the CPU bottleneck. The goal is +**data residency**: keep the whole prove pipeline on-device so we never round-trip the +LDE trace across the PCIe bus. + +``` +LDE (GPU) → constraint eval / composition poly → Merkle commit (GPU) → FRI (GPU) +``` + +The LDE trace (main + aux columns × blowup factor) is the largest array in the pipeline. +If constraint eval stays on the CPU, every proof must D2H-copy the entire LDE and push +results back — that transfer dominates. `gpu_lde.rs` already keeps columns resident +(`GpuLdeBase`/`GpuLdeExt3` keep-handles); on-GPU constraint eval consumes them in place. + +Consequence: the GPU kernel must also do the accumulation (`Σ αⁱ·Cᵢ·Zᵢ⁻¹`, Horner form ++ ÷Z) so it emits **composition-polynomial evaluations on-device**, not a raw `Cᵢ` +matrix (which would itself be a large D2H copy). + +## The blocker + +The prover is Rust; you can't run arbitrary Rust on a GPU. Constraints today are +`Vec>` evaluated via a generic +`evaluate` — two layers of dynamic dispatch, scalar, CPU-only. The logic has to +be re-expressed in a GPU-executable form. + +## The decided architecture + +**Capture each table's constraints once into a flat, single-field Goldilocks IR (a +typed `Dim1`/`Dim3` op-DAG), then interpret that IR** — on CPU (verifier/optional +prover) and on GPU (one universal Goldilocks kernel). Single source of truth → CPU and +GPU can't diverge. Not codegen, not a DSL, not hand-written per-table kernels. + +- Field: Goldilocks base (`Dim1`, `u64`) + degree-3 extension (`Dim3`, `[u64;3]`). +- IR ops: `Add/Sub/Mul/Neg` + leaves (`Main/Aux{offset,col}`, `Const`, `Periodic`, + `RapChallenge`, `AlphaPow`, `TableOffset`, `Shift`). +- Boundary: the zerofier/coefficient machinery stays in + `ConstraintEvaluator::evaluate_transitions`; the IR replaces only the per-row, + per-constraint step that produces each `Cᵢ` (on GPU, fused with the accumulation). + +This is the same family SP1, OpenVM, and zisk converged on. zisk is the closest match +(Goldilocks, FRI-STARK LDE quotient). + +## The two plans (the only open decision) + +Both produce the **same IR** and feed the **same interpreter + GPU kernel + validation**. +They differ *only* in how the IR is captured. + +| | [Plan A — symbolic field](./plan-symbolic-field.md) | [Plan B — builder rewrite](./plan-builder-rewrite.md) | +|---|---|---| +| Constraint edits | ~0 (record existing `evaluate` via a `SymField`) | ~600–800 LOC across 33 structs rewritten to `capture()` | +| Feasibility | HIGH — `SymField` needs only `IsField`+`IsSubFieldOf` (capture never builds an `AIR`); unreachable methods stubbed | No doubt; just labor | +| Risk shape | Concentrated in `SymField` — spike-able in 1–2 days | Spread across 33 transcriptions (Dvrm 11 / Cpu32 8 / Shift 7 kinds) | +| CPU path | can stay unchanged (IR GPU-only) | forced onto the interpreter (old `evaluate` deleted) | +| End state | recording field + arena + stubs retained | cleanest; generic `evaluate`+adapter deleted; ecosystem-idiomatic | +| Effort (CPU validated) | ~10–14 d | ~12–18 d | +| Effort (GPU) | ~6–10 d | ~5–7 d (identical, shared) | + +Both lose the per-row LogUp zero-skip (value-identical; recover via a static +const-fold peephole). Neither AVX nor monomorphization differentiates them (AVX lives +in the shared interpreter; monomorphization is a third thing neither plan does). + +## Reference implementation + +`others/openvm-stark-backend` (cloned `openvm-org/stark-backend@v1.4.0`) is a working +implementation of this exact approach for a FRI-STARK LDE quotient. Key files: + +- `crates/cuda-backend/src/transpiler/mod.rs` — lowers the symbolic DAG to three-address + code + liveness/linear-scan register allocation (the **IR processing** — the most + portable, field-agnostic piece; ~200 lines; the reg-alloc is optional for v1). +- `crates/cuda-backend/src/transpiler/codec.rs` + `cuda/include/codec.cuh` — encode rules + to a 128-bit packed word. +- `crates/cuda-backend/cuda/src/quotient.cu` — the interpreter kernel: per-row loop over + rules, fused Horner accumulation + ÷Z, per-thread intermediate buffer (local for small + programs, global spill for large — solves GPU scratch pressure). + +BabyBear→Goldilocks deltas to be aware of: the codec packs constants in 32 bits (fits +BabyBear's 31-bit modulus, **not** Goldilocks' 64 bits) → needs a side constant table; +extension is BabyBear's vs our degree-3; OpenVM evaluates everything in `FpExt` (no +base/ext split). It's a blueprint to port, not a crate to depend on (it's tied to +OpenVM's symbolic-DAG type, `PrimeField32` bound, and trace/bus conventions). + +SP1's `sp1-gpu` is the same pattern via an SSA register-machine bytecode (~60+ opcodes, +operand types in the opcode); OpenVM puts operand types in the source tag (~6 ops) — +the latter is the better template for single-field Goldilocks (even fewer source types). + +## Recommendation / next step + +Spike **Plan A** first (1–2 days): implement `SymField`, capture the CPU table, diff the +interpreted IR bit-for-bit against the current evaluator, and dump per-table node counts. +The IR/interpreter/GPU kernel are shared, so switching to Plan B later costs almost +nothing. If `SymField` fights the trait tower, fall back to the Plan B rewrite. diff --git a/thoughts/gpu-constraint-eval/plan-builder-rewrite.md b/thoughts/gpu-constraint-eval/plan-builder-rewrite.md new file mode 100644 index 000000000..e6e3d2bd2 --- /dev/null +++ b/thoughts/gpu-constraint-eval/plan-builder-rewrite.md @@ -0,0 +1,940 @@ +# Plan: Capture-Method Rewrite ("Change the constraints") for GPU-ready STARK constraint evaluation + +> Approach: rewrite each transition constraint so it *emits* its polynomial into a +> builder/capture abstraction once at setup, producing a flat single-field +> Goldilocks IR, then interpret that IR on CPU (prover over the LDE coset; verifier +> at the OOD point) and later on GPU. This is the head-to-head sibling of the +> "wrap the field type / shadow `IsField`" approach. + +All file/line references below were read and verified against the working tree +(branch `main`) unless explicitly marked `? INFERRED` or `✗ UNCERTAIN`. + +--- + +## 1. Overview & end-state + +After this change, every table's transition constraints are *captured once* into a +flat per-table IR program (`TableProgram`) at AIR-construction time. The +per-row/per-OOD hot path no longer dispatches through +`Vec>` calling a generic `evaluate`; +instead `air.compute_transition_prover` (prover) and `air.compute_transition` +(verifier) call a single **interpreter** that walks the IR against the current +`Frame`/`TableView`, writing each constraint's scalar `Cᵢ` into the existing +`base_evals`/`ext_evals` buffers. The accumulation `Σ βᵢ·Cᵢ·Zᵢ⁻¹ + boundary`, +zerofiers, and `ZerofierEvaluations` machinery in +`ConstraintEvaluator::evaluate_transitions` are **untouched** — the IR/interpreter +only replaces the step that produces each `Cᵢ`. Because the IR is a flat array of +Goldilocks-typed ops, the same bytes feed a single universal Goldilocks +interpreter CUDA kernel, dispatched through the existing TypeId+transmute seam used +by `gpu_lde.rs`. + +``` + ┌─────────────── setup (once per AIR) ──────────────┐ + constraint structs ──► capture(&mut IrBuilder) ──► TableProgram (flat IR) │ + (IsBit, Add, Mul…) (column reads/+/-/* → IR nodes) { ops, consts, │ + emits, n_dim1 } │ + └───────────────────────────────────────────────────┘ + │ stored in the AIR + ┌───────────────────────────┼───────────────────────────────┐ + ▼ prover, per LDE row ▼ verifier, at OOD z ▼ GPU + interpret(program, frame_prover) interpret(program, frame_verifier) cuda kernel(program, lde) + → base_evals[], ext_evals[] → ext_evals[] → Cᵢ per row, per table + │ │ │ + └────── feeds unchanged ───────────┴── ConstraintEvaluator ─────┘ + Σ βᵢ·Cᵢ·Zᵢ⁻¹ + boundary (evaluator.rs:102-134, untouched) +``` + +--- + +## 2. The IR (concrete Rust data structures) + +The IR is **single-field over Goldilocks** with explicit base (`dim1`) vs cubic-ext +(`dim3`) typing on each node, so the interpreter knows the storage width and which +arithmetic routine to use. It lives in a new module `crypto/stark/src/ir.rs`. + +### 2.1 Node typing + +`Dim` records the field width of a value. Goldilocks base = `Dim1` (`[u64;1]`), +`Degree3GoldilocksExtensionField` = `Dim3` (`[u64;3]`). (Verified: `IsField for +Degree3GoldilocksExtensionField { type BaseType = [FpE;3] }`, +`extensions_goldilocks.rs:277`; base = `repr(transparent)` `u64`.) + +```rust +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum Dim { D1, D3 } + +/// Index into the program's node arena. Nodes are in topological (emission) order, +/// so an interpreter can evaluate left-to-right into a value stack/arena. +pub type NodeId = u32; +``` + +### 2.2 Op / node enum + +```rust +#[derive(Clone, Copy, Debug)] +pub enum Op { + // ---- leaves (inputs) ---- + /// main trace column read: step offset (0 = this row, 1 = next row), column idx. + Main { offset: u8, col: u16 }, // always Dim1 on prover; Dim3 on verifier (see §6) + /// aux trace column read. + Aux { offset: u8, col: u16 }, // Dim3 (aux is extension-valued) + /// constant base-field element, index into `consts_d1`. + ConstD1 { k: u32 }, // Dim1 + /// constant ext element, index into `consts_d3`. + ConstD3 { k: u32 }, // Dim3 (rarely needed for algebraic; see §5) + /// periodic column j at this row (uniform per-row input). + Periodic { j: u16 }, // Dim1 + /// LogUp challenge i (z=0, alpha=1 by convention), uniform per-proof. + Challenge { i: u16 }, // Dim3 + /// alpha power k (precomputed Σ over the proof), uniform per-proof. + AlphaPow { k: u16 }, // Dim3 + /// logup_table_offset uniform per-proof. + TableOffset, // Dim3 + /// packing shift constant (8,16,24) — small base consts, can also be ConstD1. + // (shifts are just ConstD1 entries; no dedicated op needed.) + + // ---- arithmetic (operands are NodeIds already emitted) ---- + Add { a: NodeId, b: NodeId }, + Sub { a: NodeId, b: NodeId }, + Mul { a: NodeId, b: NodeId }, + Neg { a: NodeId }, +} + +#[derive(Clone, Copy, Debug)] +pub struct Node { pub op: Op, pub dim: Dim } +``` + +The interpreter's typing rule for `Add/Sub/Mul`: `dim = max(dim(a), dim(b))` +(D3 > D1). A `Mul` of `D1×D3` is the cheap subfield mul (componentwise scalar, +`GoldilocksField: IsSubFieldOf::mul` = `[a*b0, a*b1, a*b2]`, verified +`extensions_goldilocks.rs:413`); `D3×D3` is the full cubic mul +`c0=a0b0+2(a1b2+a2b1)`, … (verified `extensions_goldilocks.rs:298-306`); `D1×D1` +is a plain `GoldilocksField::mul`. This single rule subsumes every "mixed base×ext" +case the current code handles via `IsSubFieldOf`. + +### 2.3 Per-table program + +```rust +pub struct TableProgram { + pub nodes: Vec, // topological arena + pub consts_d1: Vec, // deduplicated base constants + pub consts_d3: Vec<[u64; 3]>, // deduplicated ext constants (usually empty) + /// emit[c] = NodeId of constraint c's root value. Length = num_transition_constraints. + pub emits: Vec, + /// First `num_base` constraints are D1 (base-field), matching + /// `num_base_transition_constraints()`. Used to split base/ext eval buffers. + pub num_base: usize, + /// metadata for input plumbing / GPU upload + pub num_main_cols: u16, + pub num_aux_cols: u16, + pub max_offset: u8, // 1 (next-row) for LogUp accumulated; else 0 +} +``` + +### 2.4 Serialization for GPU + +`nodes` is `Vec`; `Node`/`Op` are `Copy` plain-old-data. For GPU we lower +`Op` to a fixed-width tagged record `struct GpuOp { tag: u32, dim: u32, a: u32, b: u32 }` +(16 bytes) — leaves pack their immediates into `a`/`b` (e.g. `Main`: a=offset, +b=col). `consts_d1`/`consts_d3` upload as `&[u64]`. This is a flat `Vec`/`Vec` +blob: H2D once per table at setup, reused for every LDE row (the kernel runs the +program per row). No per-row host work crosses the boundary — only the device- +resident LDE columns (already kept on device by R1, see `gpu_lde.rs` `gpu_main()`/ +`gpu_aux()` handles) plus the uniforms (challenges, alpha powers, periodic, table +offset). `Op`'s representation is internal; we do **not** need `serde` on it unless +we want to cache programs to disk (out of scope). + +--- + +## 3. Capture front-end — builder/capture API & object-safety (distinguishing section) + +### 3.1 Object-safety decision (Question 1) — **RECOMMENDATION: non-generic `capture(&self, &mut IrBuilder)`** + +The constraints are stored heterogeneously as +`Vec>>` +(verified `traits.rs:316`, `lookup.rs:813`). A method generic over a builder type +`fn capture(&self, &mut AB)` is **NOT object-safe** (generic methods +can't go through a vtable), so it could not be called on `Box`. Two ways out: + +- **(a) non-generic `capture(&self, builder: &mut IrBuilder)` with a CONCRETE + builder.** Object-safe. Runs **once at setup** (not in the hot loop), so the + concrete builder costs nothing at steady state. The builder is a struct, not a + trait. This is the minimal, lowest-risk change: the existing + `TransitionConstraint` trait gains one object-safe method. +- (b) builder-generic `eval` (Plonky3/SP1 `AirBuilder` style). To call it + through a boxed trait object you must either (i) monomorphize per concrete AIR + (de-box: store constraints as a concrete `enum`/typed vec per table, a much + bigger refactor touching every table's assembly fn and `AirWithBuses`), or + (ii) add a non-generic shim per constraint anyway (which is just (a) again). + +**Recommendation: (a).** Reasoning: +1. It is object-safe, so it drops straight into the existing + `Box` storage with zero changes to how tables assemble constraints + (`create_all_cpu_constraints`, `mul_constraints`, `dvrm_constraints`, …). +2. Capture is a one-time setup cost; there is no monomorphization win to be had at + runtime because the runtime work is the interpreter, not the constraint body. +3. The interpreter is the single hot path; we want exactly one concrete builder so + the IR is canonical and identical for CPU and GPU. A generic builder would let + callers instantiate it with an "eval-directly" builder, re-introducing the very + `IsField` trait-tower fight this approach exists to avoid. + +The one real cost of (a): the builder is monomorphic on Goldilocks, so a +constraint can't be captured for a non-Goldilocks field. That is exactly the +project's constraint (base = Goldilocks, ext = degree-3 Goldilocks), so it's a +non-issue here. The generic `evaluate` is retained transitionally for +migration/validation (see §6, §9) and deleted at the end. + +### 3.2 The `IrBuilder` surface (Question 2) + +```rust +pub struct IrBuilder { + nodes: Vec, + consts_d1: Vec, + consts_d3: Vec<[u64; 3]>, + emits: Vec, // indexed by constraint_idx + const_d1_cache: HashMap, // dedupe constants + const_d3_cache: HashMap<[u64;3], u32>, + num_main_cols: u16, + num_aux_cols: u16, + max_offset: u8, + // CSE cache: (Op canonicalized) -> NodeId, to coalesce repeated subexpressions + // (e.g. `one`, `1 - x`, shift consts). Optional but cheap and shrinks the IR a lot. + cse: HashMap, +} + +/// Typed handle so `+ - *` compose with compile-time dim tracking and a tiny op set. +#[derive(Clone, Copy)] +pub struct Expr { id: NodeId, dim: Dim } + +impl IrBuilder { + // ---- leaves ---- + pub fn main(&mut self, offset: u8, col: usize) -> Expr; // Dim1 + pub fn aux(&mut self, offset: u8, col: usize) -> Expr; // Dim3 + pub fn const_base(&mut self, v: u64) -> Expr; // Dim1 (dedup) + pub fn const_signed(&mut self, v: i64) -> Expr; // Dim1, maps i64→field + pub fn const_ext(&mut self, v: [u64;3]) -> Expr; // Dim3 (dedup) + pub fn one(&mut self) -> Expr; // = const_base(1) + pub fn periodic(&mut self, j: usize) -> Expr; // Dim1 + pub fn challenge(&mut self, i: usize) -> Expr; // Dim3 + pub fn alpha_power(&mut self, k: usize) -> Expr; // Dim3 + pub fn table_offset(&mut self) -> Expr; // Dim3 + pub fn bus_id(&mut self, id: u64) -> Expr; // = const_base(id) (α⁰ term) + + // ---- arithmetic (auto dim = max) ---- + pub fn add(&mut self, a: Expr, b: Expr) -> Expr; + pub fn sub(&mut self, a: Expr, b: Expr) -> Expr; + pub fn mul(&mut self, a: Expr, b: Expr) -> Expr; + pub fn neg(&mut self, a: Expr) -> Expr; + + // ---- output ---- + /// Record that constraint `constraint_idx` evaluates to `e`. + pub fn emit(&mut self, constraint_idx: usize, e: Expr); + + pub fn finish(self) -> TableProgram; +} +``` + +Notes on the surface vs the prompt's sketch: +- No `table_offset()` for periodic *exemption offsets* — those stay in the + zerofier machinery (`transition.rs:160`), which is outside the boundary. +- `Expr` carries `dim`, so `mul(d1, d3)` is legal and lowers to the cheap subfield + mul; `Expr` makes the constraint bodies read almost identically to today. +- CSE + constant dedup are pure size optimizations; correctness doesn't depend on + them. (`one`, `shift_16`, `INV_SHIFT_32` recur across most bodies.) + +### 3.3 Trait change + +`TransitionConstraint` (`transition.rs:332`) gains: + +```rust +/// Emit this constraint's polynomial into the builder. Called once at setup. +/// `builder.emit(self.constraint_idx(), root)` records the result. +fn capture(&self, builder: &mut IrBuilder); +``` + +`TransitionConstraintEvaluator` (`transition.rs:10`, object-safe) gains a forwarding +non-generic method: + +```rust +fn capture(&self, builder: &mut IrBuilder); +``` + +The adapter `TransitionConstraintAdapter` (`transition.rs:395`) forwards +`capture` to `self.0.capture(builder)`. During migration the adapter keeps its +`evaluate_verifier`/`evaluate_prover` too (used by the parallel old path for +bit-for-bit validation, §12). + +--- + +## 4. Rewriting the algebraic constraints (Question 3 + full scope) + +### 4.1 Before/after: `IsBitConstraint` (`templates.rs:80-108`) + +**Before** (`evaluate`): +```rust +let x = step.get_main_evaluation_element(0, self.value_col).clone(); +let one = FieldElement::::one(); +match self.cond_col { + Some(cond_col) => { let cond = step.get_main_evaluation_element(0, cond_col).clone(); + &cond * &x * (one - x) } + None => &x * (one - &x), +} +``` +**After** (`capture`): +```rust +fn capture(&self, b: &mut IrBuilder) { + let x = b.main(0, self.value_col); + let one = b.one(); + let omx = b.sub(one, x); + let root = match self.cond_col { + Some(c) => { let cond = b.main(0, c); let xm = b.mul(x, omx); b.mul(cond, xm) } + None => b.mul(x, omx), + }; + b.emit(self.constraint_idx, root); +} +``` + +### 4.2 Before/after: `AddConstraint` — `AddOperand`/`AddLinearTerm` mapping (`templates.rs:359-486`) + +The lo/hi-limb abstraction with i64 coefficients maps cleanly. `AddLinearTerm::eval` +(`templates.rs:164`) becomes `capture`: +```rust +impl AddLinearTerm { + fn capture(&self, b: &mut IrBuilder) -> Expr { + match self { + AddLinearTerm::Column { coefficient, column } => { + let col = b.main(0, *column); + let k = b.const_signed(*coefficient); // i64 → field, was FieldElement::from(*coefficient) + b.mul(col, k) + } + AddLinearTerm::Constant(v) => b.const_signed(*v), + } + } +} +fn eval_terms_capture(terms: &[AddLinearTerm], b: &mut IrBuilder) -> Expr { + // empty → zero + let mut acc = b.const_base(0); + for t in terms { let e = t.capture(b); acc = b.add(acc, e); } + acc +} +``` +`AddOperand::eval_lo/eval_hi` → `capture_lo/capture_hi` (DWordWL reads +`main(0,start)` / `main(0,start+1)`; Linear → `eval_terms_capture`). Then +`compute_carry_0` (`templates.rs:414`): +```rust +// carry_0 = (lhs_lo + rhs_lo - sum_lo) * 2^(-32) +let inv = b.const_base(INV_SHIFT_32); // templates.rs:30, precomputed 2^-32 +let s = b.sub(b.add(lhs_lo, rhs_lo), sum_lo); +let c0 = b.mul(s, inv); +``` +`compute_carry_1` adds `carry_0` then the same `*inv`. `compute` then folds the cond +columns (`fold(zero, +)` → chain of `add`) and emits +`cond * carry * (one - carry)` (or unconditional). **The i64 coefficients are the +only subtlety** and they vanish because `const_signed(i64)` reproduces +`FieldElement::::from(i64)` exactly (the field's `From` already canonicalizes +negatives mod p). The lo/hi limb logic is pure compile-time structure; the captured +IR is a flat add/mul chain identical in value to the current `evaluate`. + +### 4.3 Before/after: `ProductZeroConstraint` (`cpu.rs:96-113`) + +**Before:** `step.get_main(0,col_a) * step.get_main(0,col_b)`. +**After:** +```rust +fn capture(&self, b: &mut IrBuilder) { + let a = b.main(0, self.col_a); let c = b.main(0, self.col_b); + let r = b.mul(a, c); b.emit(self.constraint_idx, r); +} +``` + +### 4.4 Before/after: a more complex algebraic constraint — `MulConstraint::RawProduct` (`mul.rs:766-844`) + +This is the representative "mega-constraint": a `kind` enum dispatched in +`compute()` (`mul.rs:721`), with a convolution body whose `for k` / `for j` loops +are bounded by compile-time `i` (not data). Capturing it **runs those loops once**, +unrolling them into a flat IR chain: + +```rust +// raw_product[i] - Σ_k 2^(16k) Σ_j lhs_ext[j]·rhs_ext[idx-j] +fn capture_raw_product(&self, i: usize, b: &mut IrBuilder) -> Expr { + let lhs = [cols::LHS_0, cols::LHS_1, cols::LHS_2, cols::LHS_3].map(|c| b.main(0, c)); + let rhs = [cols::RHS_0, cols::RHS_1, cols::RHS_2, cols::RHS_3].map(|c| b.main(0, c)); + let ln = b.main(0, cols::LHS_IS_NEGATIVE); + let rn = b.main(0, cols::RHS_IS_NEGATIVE); + let sf = b.const_base(SIGN_FILL); + let mut lhs_ext = [b.const_base(0); 8]; + let mut rhs_ext = [b.const_base(0); 8]; + lhs_ext[..4].copy_from_slice(&lhs); rhs_ext[..4].copy_from_slice(&rhs); + for j in 4..8 { lhs_ext[j] = b.mul(sf, ln); rhs_ext[j] = b.mul(sf, rn); } + let shift_16 = b.const_base(SHIFT_16); + let mut sum = b.const_base(0); + for k in 0..=1usize { + let idx = 2*i + k; + if idx < 8 { + let mut inner = b.const_base(0); + for j in 0..=idx { if j < 8 && idx-j < 8 { + inner = b.add(inner, b.mul(lhs_ext[j], rhs_ext[idx-j])); } } + sum = if k==0 { b.add(sum, inner) } else { b.add(sum, b.mul(inner, shift_16)) }; + } + } + let raw = b.main(0, raw_col_for(i)); + b.sub(raw, sum) +} +``` +**This is the central churn-reducing insight: no algebraic body has data-dependent +control flow.** Every loop bound, conditional, and column index is a function of +`self` only. So `capture` is a *mechanical mirror* of the existing body: swap +`FieldElement` constructors for builder leaves and `+ - *` for `b.add/sub/mul`. The +`kind`-enum dispatch in `compute` becomes a `kind`-enum dispatch in `capture`. + +### 4.5 Full rewrite scope (counts verified by grep + reads) + +`grep -rn "impl TransitionConstraint"` across `prover/src/` yields the following +**distinct constraint structs** (each implements the user trait once; structs with a +`kind` enum produce many constraint *instances* but are ONE body to rewrite): + +**`prover/src/constraints/` (11 structs):** +- `templates.rs`: `IsBitConstraint`, `AddConstraint` (+ `AddOperand`/`AddLinearTerm` + helper enums — these get `capture` helpers, not trait impls). +- `cpu.rs`: `ProductZeroConstraint`, `Arg2ExclusiveConstraint`, `MemFlagsBitConstraint`, + `RegNotReadIsZeroConstraint`, `Arg2Constraint`, `RvdEqResConstraint`, + `BranchRvdConstraint`, `BranchCondConstraint`, `NextPcAddConstraint` (+ helper + `res_word`). All small (≤ ~30-line bodies). + +**`prover/src/tables/` (per verified grep, 21 impl sites across 17 files; some files +hold several structs):** +- `mul.rs` `MulConstraint` (kind enum, ~250 lines incl. helpers; convolution). +- `dvrm.rs` `DvrmConstraint` (kind enum, 11 variants; the biggest, ~1300-line file — + body+helpers the largest single rewrite). +- `shift.rs` `ShiftConstraint` (kind enum; ~1100-line file). +- `cpu32.rs` `Cpu32Constraint` (kind enum; ~845-line file). +- `memw.rs` `MemwConstraint`; `memw_aligned.rs` `MemwAlignedConstraint`; + `memw_register.rs` `MemwRegisterMuSumIsBit`. +- `load.rs` `LoadConstraint`; `store.rs` `StoreConstraint`. +- `lt.rs` `LtConstraint`; `eq.rs` `EqXorConstraint`. +- `branch.rs` `BranchConstraint`; `commit.rs` `CommitConstraint`. +- `keccak.rs` `KeccakAddressNoOverflowConstraint` (one small struct). NOTE: keccak's + 51 constraints are mostly **reused `AddConstraint` instances** (`from_dword_bl` + + `constant` + `from_dword_hl`, verified `keccak.rs:545-557`) — so keccak adds almost + no rewrite cost once `AddConstraint::capture` exists, and its program is small (no + GPU register-pressure risk). +- `ec_scalar.rs` `MulZeroConstraint`. +- `ecsm.rs`: `ConvCarry`, `ColIsZero`, `CarryBit`, `OverflowRequired` (4 structs). +- `ecdas.rs`: `ConvCarry`, `ColIsZero`, `MulZero` (3 structs). + +**Authoritative count (enumeration-verified): 33 algebraic +`impl TransitionConstraint` structs across 19 files + 2 framework LogUp +`TransitionConstraintEvaluator` structs (§5).** Breakdown: +- `prover/src/constraints/cpu.rs` (9): ProductZero, Arg2Exclusive, MemFlagsBit, + RegNotReadIsZero, Arg2, RvdEqRes, BranchRvd, BranchCond, NextPcAdd. +- `prover/src/constraints/templates.rs` (2): IsBit, Add (Add carries the + AddOperand/AddLinearTerm combinators — the trickiest single rewrite, §4.2). +- `prover/src/tables/` (22): Branch, Commit, Cpu32, Dvrm, EqXor, MulZero(ec_scalar), + ConvCarry+ColIsZero+MulZero(ecdas, 3), ConvCarry+ColIsZero+CarryBit+OverflowRequired + (ecsm, 4), KeccakAddressNoOverflow, Load, Lt, Memw, MemwAligned, MemwRegisterMuSumIsBit, + Mul, Shift, Store. + +**Scope driver — multi-kind dispatch structs** (one struct, a `kind` enum + a +`compute()` helper; each kind is a separate constraint *instance* needing its own +capture path). Verified kind counts: Dvrm(11), Cpu32(8), Shift(7), Lt(6), Load(6), +Mul(6), Branch(5), Memw(3), MemwAligned(3), Store(2). Dvrm/Cpu32/Shift dominate. +**Rough total evaluate/compute body LOC ≈ 600-800 across the 19 files** — far less +than the raw file sizes suggest, because the kind-enum bodies are short matches that +delegate to `compute()`, and the heavy loops (carry chains, raw-product convolution, +shift formulas) are **statically bounded / metadata-driven**, so they unroll into +builder calls at capture time without per-kind hand-coding of each iteration. +(I read `mul.rs`/`dvrm.rs` bodies in full; the rest share the single-`evaluate` +→`compute()` pattern, kind counts enumeration-verified.) + +--- + +## 5. Rewriting the LogUp / extension framework constraints (Question 4 — the crux) + +These two live in `crypto/stark/src/lookup.rs` and are the only constraints that use +extension arithmetic, challenges, alpha powers, and (for the accumulated one) +next-row reads. They are **not** `TransitionConstraint` impls — they directly +implement the object-safe `TransitionConstraintEvaluator` (`lookup.rs:1741`, +`lookup.rs:1868`). So for these we write `capture` directly on the evaluator impl. +I read both bodies in full; here is how each maps. + +### 5.1 Fingerprint, multiplicity, sign — the shared pieces + +`compute_fingerprint_from_step` (`lookup.rs:1689-1709`) builds +`z − (bus_id + Σ α^k · vₖ)` where `vₖ` are the packed bus elements. In IR: + +```rust +// fingerprint(interaction) -> Expr (Dim3, because z and alpha powers are Dim3) +fn capture_fingerprint(b: &mut IrBuilder, bi: &BusInteraction) -> Expr { + let z = b.challenge(0); // rap_challenges[0] + // α⁰ term: bus_id is a base const, added directly (matches lookup.rs:1697) + let mut lc = b.bus_id(bi.bus_id); // Dim1 const, promoted on first add to Dim3 + let mut alpha_idx = 1usize; // α⁰ handled, start at α¹ (lookup.rs:1698) + for bv in &bi.values { + alpha_idx += capture_busvalue_fingerprint(b, bv, alpha_idx, &mut lc); + } + b.sub(z, lc) // z - lc (Dim3) +} +``` + +`BusValue::accumulate_fingerprint_from_step` (`lookup.rs:738-796`) and +`Packing::accumulate_fingerprint_with` (`lookup.rs:272-369`) are the packing +formulas. They are pure compile-time structure (the `match self { Packing::Word2L => +h0 + 2^16·h1, … }`), so capturing them unrolls the same way as §4.4: + +```rust +fn capture_busvalue_fingerprint(b: &mut IrBuilder, bv: &BusValue, + alpha_off: usize, lc: &mut Expr) -> usize { + match bv { + BusValue::Packed { start_column, packing } => { + // mirror accumulate_fingerprint_with: e.g. Word2L: + // combined = col[start] + col[start+1]·shift_16 (Dim1) + // *lc += combined · alpha_powers[alpha_off] (Dim1 · Dim3 -> Dim3) + let elems = capture_packing(b, *packing, *start_column); // Vec (Dim1) + for (i, e) in elems.iter().enumerate() { + let ap = b.alpha_power(alpha_off + i); // Dim3 + let t = b.mul(*e, ap); // Dim1·Dim3 -> Dim3 + *lc = b.add(*lc, t); + } + packing.num_bus_elements() + } + BusValue::Linear(terms) => { + // result = Σ coeff·col + const (Dim1), then *lc += result·α^alpha_off + let mut r = b.const_base(0); + for t in terms { match t { + LinearTerm::Column{coefficient, column} + | LinearTerm::ColumnUnsigned{coefficient, column}=> { + let col=b.main(0,*column); let k=b.const_signed(*coefficient as i64); + r=b.add(r, b.mul(col,k)); } + LinearTerm::Constant(v)=> r=b.add(r, b.const_signed(*v)), + }} + let ap=b.alpha_power(alpha_off); *lc=b.add(*lc, b.mul(r, ap)); + 1 + } + } +} +``` +> **Honesty note on the runtime zero-skip:** the current code skips the +> `result · α` multiply when `result == 0` *on that row* (`lookup.rs:675-677`, +> `790-792`). That is a *data-dependent* optimization the IR **cannot** reproduce — +> the IR is row-agnostic. The IR always emits the multiply. This is the one place +> the capture approach is strictly less optimal than the current per-row code: a +> few extra D1×D3 muls per row for bus elements that happen to be zero. It does +> **not** change the result (adding `0·α` is a no-op), only cost. Quantify in +> validation; likely negligible vs. the dispatch savings. (✓ VERIFIED the skip +> exists and is value-preserving.) + +`compute_multiplicity_from_step` (`lookup.rs:1679-1684`) = `Multiplicity::evaluate_with` +(`lookup.rs:1252-1282`): `One→1`, `Column→col`, `Sum→a+b`, `Negated→1-col`, +`Diff→a-b`, `Sum3→a+b+c`, `Linear→Σ`. All Dim1, captured as add/sub/mul chains. + +**The sign** (`is_sender`) is a **compile-time bool on the interaction**, so it is +resolved during capture by choosing `add` vs `sub` (or wrapping in `neg`) — never an +IR value. This matches the current "conditional negation instead of E×E sign +multiplication" (`lookup.rs:1779-1790`). + +### 5.2 `LookupBatchedTermConstraint::capture` (was `lookup.rs:1754-1831`) + +Formula (verified `lookup.rs:1791`): +`c·fp_a·fp_b − sign_a·m_a·fp_b − sign_b·m_b·fp_a`. + +```rust +fn capture(&self, b: &mut IrBuilder) { + let c = b.aux(0, self.term_column_idx); // Dim3 + let fp_a = capture_fingerprint(b, &self.interaction_a); + let fp_b = capture_fingerprint(b, &self.interaction_b); + let m_a = capture_multiplicity(b, &self.interaction_a.multiplicity); // Dim1 + let m_b = capture_multiplicity(b, &self.interaction_b.multiplicity); + let term_a = b.mul(m_a, fp_b); // Dim1·Dim3 -> Dim3 + let term_a = if self.interaction_a.is_sender { term_a } else { b.neg(term_a) }; + let term_b = b.mul(m_b, fp_a); + let term_b = if self.interaction_b.is_sender { term_b } else { b.neg(term_b) }; + let main = b.mul(b.mul(c, fp_a), fp_b); + let root = b.sub(b.sub(main, term_a), term_b); + b.emit(self.constraint_idx, root); +} +``` +Clean. Degree 3, all Dim3 at the top, exactly mirrors the read body. + +### 5.3 `LookupAccumulatedConstraint::capture` (was `lookup.rs:1881-2005`) — the messy one + +This is the only constraint that reads **two row offsets** (`acc_curr` at offset 0, +`acc_next` and the term columns at offset 1) — verified +`first_step.get_aux(0, acc)` / `second_step.get_aux(0, …)` where `first_step = +frame.get_evaluation_step(0)` and `second_step = frame.get_evaluation_step(1)` +(`lookup.rs:1971-1972`, `1899-1905`). The IR addresses next-row values with +`b.aux(1, col)` — this is exactly why `Op::Main/Aux` carry an `offset: u8` and why +the program records `max_offset` (the interpreter must fill a 2-step frame for these +tables; the prover already builds frames with `offsets = [0,1]`, see +`AirWithBuses` context `transition_offsets: vec![0,1]`, `lookup.rs:909`). + +```rust +fn capture(&self, b: &mut IrBuilder) { + let acc_curr = b.aux(0, self.acc_column_idx); // offset 0 + let acc_next = b.aux(1, self.acc_column_idx); // offset 1 <-- next row + // terms_sum over committed term columns at offset 1 (lookup.rs:1903) + let mut terms = b.const_base(0); + for i in 0..self.num_term_columns { terms = b.add(terms, b.aux(1, i)); } + // delta = acc_next - acc_curr - terms_sum + L/N + let off = b.table_offset(); // logup_table_offset (Dim3) + let delta = b.add(b.sub(b.sub(acc_next, acc_curr), terms), off); + match self.absorbed.len() { + 1 => { // delta·f - sign·m (lookup.rs:1932) + let f = capture_fingerprint_at(b, &self.absorbed[0], /*offset*/1); + let m = capture_multiplicity_at(b, &self.absorbed[0].multiplicity, 1); + let mt = if self.absorbed[0].is_sender { m } else { b.neg(m) }; + let root = b.sub(b.mul(delta, f), mt); + b.emit(self.constraint_idx, root); + } + 2 => { // delta·f1·f2 - sign1·m1·f2 - sign2·m2·f1 (lookup.rs:1957) + let f1=capture_fingerprint_at(b,&self.absorbed[0],1); + let f2=capture_fingerprint_at(b,&self.absorbed[1],1); + let m1=capture_multiplicity_at(b,&self.absorbed[0].multiplicity,1); + let m2=capture_multiplicity_at(b,&self.absorbed[1].multiplicity,1); + let t1=b.mul(m1,f2); let t1=if self.absorbed[0].is_sender{t1}else{b.neg(t1)}; + let t2=b.mul(m2,f1); let t2=if self.absorbed[1].is_sender{t2}else{b.neg(t2)}; + let root=b.sub(b.sub(b.mul(b.mul(delta,f1),f2),t1),t2); + b.emit(self.constraint_idx, root); + } + _ => unreachable!(), + } +} +``` +> **The messiness, stated honestly:** +> 1. `capture_fingerprint`/`capture_multiplicity` need an **offset parameter** because +> the absorbed interactions read columns at the *next* row (`second_step`, +> `lookup.rs:1919-1946`), whereas the batched-term constraint reads the *current* +> row. The fingerprint/packing capture helpers (§5.1) must thread `offset: u8` +> through to every `b.main(offset, …)`/`b.aux(offset, …)`. This is a real but +> mechanical generalization (one extra arg). +> 2. The `1` vs `2` absorbed cases have different degree (2 vs 3) and different +> formulas; both must be captured (matches the existing `match absorbed.len()`). +> 3. `logup_table_offset` becomes the `TableOffset` uniform leaf (§8). It is `L/N`, +> a single Dim3 value computed in `ConstraintEvaluator::new` (`evaluator.rs:199`) +> and passed via the context — already a per-proof uniform. +> +> **Verdict:** LogUp maps to the builder *cleanly but with one wart* — the per-row +> zero-skip (§5.1) is lost, and the fingerprint helpers must be offset-parameterized. +> Neither blocks the approach; both are mechanical. This is materially **less messy** +> than fighting `IsField` to make a shadow-field type carry the same z/α/alpha-power +> uniforms through `compute_fingerprint_from_step`'s generic `>` +> signature (the sibling approach's burden). The deciding factor leans toward this +> approach because the capture is a near-verbatim transcription of the existing, +> already-factored helpers. + +--- + +## 6. CPU interpreter & the boundary (Question 5) + +### 6.1 Where it slots in + +The boundary is exactly `air.compute_transition_prover` (prover, `traits.rs:254`) +and `air.compute_transition` (verifier, `traits.rs:223`). Today both loop over +`transition_constraints()` calling `evaluate_prover`/`evaluate_verifier`. After the +rewrite, `AirWithBuses` (the only production AIR, `lookup.rs:964`) overrides both to +call the interpreter against its stored `TableProgram`: + +```rust +fn compute_transition_prover(&self, ctx, base_evals, ext_evals) { + interpret_prover(&self.program, ctx, base_evals, ext_evals); +} +fn compute_transition(&self, ctx) -> Vec> { + let mut ext = vec![FieldElement::zero(); self.num_transition_constraints()]; + interpret_verifier(&self.program, ctx, &mut ext); + ext +} +``` + +`ConstraintEvaluator::evaluate_transitions` (`evaluator.rs:79-135`) is **unchanged**: +it still calls `air.compute_transition_prover(&ctx, base_buf, transition_buf)` +(`evaluator.rs:100`) and accumulates with zerofiers (`evaluator.rs:102-134`). The IR +sits entirely inside the AIR's override. + +### 6.2 Base vs ext handling — two interpreters, shared walk + +- **Prover** frame is `Frame`: `main` reads are **Dim1** + (base), `aux` reads are **Dim3**. So `interpret_prover` evaluates each node into + either a `u64` (D1) or `[u64;3]` (D3) slot. The first `num_base` constraints are + D1-rooted and written into `base_evals: &mut [FieldElement]`; the rest are + D3-rooted into `ext_evals[num_base..]`. This reproduces the existing F×E split + (`evaluator.rs:104-114`, `transition.rs:439-458`). Verified: base constraints + must be the first `num_base_transition_constraints()` and the LogUp constraints + are appended last (`lookup.rs:857`, `traits.rs:244`). +- **Verifier** frame is `Frame`: there is no base field; every value + is Dim3 (the verifier "works with a frame that contains only elements from the + extension", `traits.rs:69-71`). So `interpret_verifier` runs the *same* node walk + but treats `Main` reads as Dim3 (the column value is already an ext element) and + every op as Dim3. The IR's per-node `dim` is the prover's typing; the verifier + simply promotes D1 leaves to D3. One IR, two interpreters differing only in leaf + loading and whether D1 storage is used. + +Implementation: a value arena `Vec` where `enum Val { D1(u64), D3([u64;3]) }`, +or two parallel arenas (`Vec` for D1 ids, `Vec<[u64;3]>` for D3 ids) keyed by +node dim. Arithmetic dispatches on `(dim(a),dim(b))` using the raw Goldilocks ops +(`GoldilocksField::add/mul`, the cubic-ext formulas). Reuse the per-thread buffer +pattern already in `evaluate_transitions` (`map_init`, `evaluator.rs:142`): the value +arena is a per-thread scratch `Vec` sized to `program.nodes.len()`. + +### 6.3 Fate of `TransitionConstraintAdapter` (Question 5) + +**End state:** `TransitionConstraint::evaluate` and the adapter's +`evaluate_prover`/`evaluate_verifier` are **deleted**. The user trait keeps +`degree/constraint_idx/period/offset/exemptions/end_exemptions` + the new `capture`. +`TransitionConstraintEvaluator` keeps the zerofier/degree/index methods + `capture`, +and **drops** `evaluate_prover`/`evaluate_verifier` (the per-row eval path no longer +goes through the trait object — it goes through the interpreter). The adapter shrinks +to a forwarder for `capture` and the metadata methods. + +**Transitional:** during migration we keep both `evaluate*` and `capture` so the old +per-row path and the new interpreter can run in parallel and be diff'd +(§9, §12). Only after every table validates bit-for-bit do we delete the old methods. + +--- + +## 7. GPU interpreter sketch (Question 7) + +Model on the `gpu_lde.rs` seam: TypeId checks gate entry, `repr(transparent)`/`[u64;3]` +layout lets us reinterpret `FieldElement` slices as raw `u64`, and a `_keep` device +handle holds the LDE columns resident from R1. + +- **Entry/dispatch.** A new `try_compute_transition_gpu(program, lde_trace, uniforms)` + guarded by `TypeId::of::()==Goldilocks && TypeId::of::()==Ext3` and an + lde-size threshold (mirror `check_base_layout`, `gpu_lde.rs:106`). Returns + `Option>>` of length `num_transition · lde_size` (the per-row + `Cᵢ` values), or `None` to fall back to the CPU interpreter. It is called from the + AIR's `compute_transition_prover` analog — but note the current + `evaluate_transitions` calls `compute_transition_prover` *per row*; for GPU we add a + batched override that produces all rows at once and feeds the accumulation loop + (this is a small refactor of `evaluate_transitions` to optionally accept a + precomputed `Cᵢ` matrix; the accumulation stays on whichever side is cheaper). + `✗ UNCERTAIN`: exact placement of the batched call (per-row vs whole-table) needs a + design pass — the cleanest is a new `air.compute_transitions_batched(lde) -> + Option` that `evaluate_transitions` tries before the per-row loop. +- **What crosses the boundary (once per table).** The program blob (`GpuOp[]` + + `consts_d1` + `consts_d3`), the uniforms (challenges, alpha_powers, periodic + columns, table_offset, packing shifts-as-consts). The LDE main/aux columns are + already on device (`lde_trace.gpu_main()`/`gpu_aux()`, `gpu_lde.rs:832,915`). No + per-row H2D. +- **Kernel.** One `interpret_transition_ext3` kernel, one thread per LDE row + (strided like `barycentric_*_strided`). Each thread walks `nodes` left-to-right + into a small per-thread register/local array indexed by NodeId (program is tiny — + hundreds of nodes — fits in local/shared memory), loading `Main/Aux` from the + resident LDE at `(row + offset·stride)`, doing D1/D3 ops with the existing device + primitives (`gl_add/gl_mul/gl_sub` and `ext3_add/ext3_mul/ext3_sub`, verified + present `device.rs:124-131`). Writes `Cᵢ` for each emit. Because the program is + uniform across rows, this is an embarrassingly parallel single-field kernel — the + whole point of the IR. New `.cu` file `transition_interp.cu` + `Backend` field + + `load_function` (mirror `device.rs:227-229`). +- **Fallback.** Any unsupported op/dim, sub-threshold size, or non-Goldilocks → CPU + interpreter (identical IR, identical result). Same `Option`-returning contract as + every `try_*` in `gpu_lde.rs`. + +--- + +## 8. Inputs plumbing (Question 6) + +The interpreter needs the per-proof/per-row uniforms that today live in +`TransitionEvaluationContext` (`traits.rs:72-93`). They become **leaf opcodes** read +from a uniform table the interpreter is handed alongside the program: + +| Current source (verified) | IR leaf | Const-vs-varies | +|---|---|---| +| `periodic_values[j]` (`evaluator.rs:88-90`, filled per row) | `Op::Periodic{j}` | varies per row (Dim1) | +| `rap_challenges[i]` (`ctx`, `traits.rs:80`) | `Op::Challenge{i}` | per proof (Dim3) | +| `logup_alpha_powers[k]` (precomputed `evaluator.rs:53`) | `Op::AlphaPow{k}` | per proof (Dim3) | +| `logup_table_offset` (`evaluator.rs:199`, `traits.rs:82`) | `Op::TableOffset` | per proof (Dim3) | +| `packing_shifts` (8/16/24, `lookup.rs:53`) | `Op::ConstD1` | program constant | + +The interpreter signature: +```rust +fn interpret_prover(prog: &TableProgram, ctx: &TransitionEvaluationContext, + base: &mut [FieldElement], ext: &mut [FieldElement]); +``` +pulls `frame`, `periodic_values`, `rap_challenges`, `logup_alpha_powers`, +`logup_table_offset` straight out of `ctx` (already plumbed through +`evaluate_transitions`, `evaluator.rs:92-99`). **No new plumbing into the +evaluator** — the context already carries everything; we only change what *consumes* +it. For GPU, these uniforms upload once per table (challenges/alpha/offset are +per-proof; periodic is `num_periodic · lde_size` Dim1, uploaded once). + +--- + +## 9. Coexistence & migration (Question 9) + +- **Table-by-table migration is fully supported.** The interpreter dispatch is on the + AIR. We add `capture` to all constraints up front (it can default to a `todo!()` + or, better, a generic auto-capture, see below), but flip an AIR to *use* the + interpreter independently. Concretely, `AirWithBuses` gets an `Option`: + when `Some`, `compute_transition_prover` interprets; when `None`, it falls back to + the existing `transition_constraints().iter()…evaluate_prover` loop (the current + `traits.rs:267-269` default). So a table is "migrated" by building its program in + `AirWithBuses::new`; unmigrated tables keep the old path verbatim. +- **Auto-capture bridge (optional but valuable):** because every algebraic body is + data-independent, we *could* provide a blanket `capture` that runs the existing + generic `evaluate` against a recording `TableView` whose elements are IR nodes — + i.e. a `TableView` where `IrField` is a field-like type whose + `add/mul` push IR nodes. **However** that is precisely the "shadow IsField" trick + the sibling approach owns, and making `IrField: IsField` is the trait-tower fight + we're avoiding. So for *this* approach we hand-write `capture` per struct and do + **not** rely on an auto-bridge. (Mentioned for completeness; explicitly rejected + here to keep the approaches distinct.) +- **Feature/TypeId gating:** GPU path behind the existing `cuda` feature + TypeId + guard (no new feature). CPU interpreter is unconditional. A `LAMBDA_VM_USE_IR` + env/feature can force the old path for A/B benchmarking during migration. + +--- + +## 10. Exhaustive file-by-file change list + +**New files:** +- `crypto/stark/src/ir.rs` — `Dim`, `NodeId`, `Op`, `Node`, `Expr`, `IrBuilder` + (full API §3.2), `TableProgram`, const/CSE dedup. `~400 LOC`. +- `crypto/stark/src/interpreter.rs` — `interpret_prover`, `interpret_verifier`, + `Val` arena, op dispatch, D1/D3 raw arithmetic helpers. `~300 LOC`. +- `crypto/math-cuda/src/transition_interp.rs` + `cuda/transition_interp.cu` — GPU + kernel + host wrapper `compute_transition_ext3`. `~400 LOC + kernel`. +- `crypto/stark/src/gpu_transition.rs` — `try_compute_transition_gpu` dispatch + (TypeId guard, blob upload, fallback). `~250 LOC`. (Or fold into `gpu_lde.rs`.) + +**Modified — framework:** +- `crypto/stark/src/constraints/transition.rs`: + - `TransitionConstraint`: add `fn capture(&self, &mut IrBuilder)`; delete + `evaluate` (end state). + - `TransitionConstraintEvaluator`: add object-safe `fn capture(&self, &mut + IrBuilder)`; delete `evaluate_prover`/`evaluate_verifier` (end state). + - `TransitionConstraintAdapter`: forward `capture`; drop `evaluate_*`. +- `crypto/stark/src/lookup.rs`: + - `LookupBatchedTermConstraint`: replace `evaluate_verifier` body with `capture` + (§5.2). `LookupAccumulatedConstraint`: replace with `capture` (§5.3). + - Add offset-parameterized capture helpers mirroring + `compute_fingerprint_from_step` (1689), `compute_multiplicity_from_step` (1679), + `BusValue::accumulate_fingerprint_from_step` (738), `Packing::accumulate_*` (272), + `Multiplicity::evaluate_with` (1252). + - `AirWithBuses`: add `program: Option`; build it in `new` + (`lookup.rs:848`) by `capture`-ing every constraint after assembly; override + `compute_transition_prover`/`compute_transition` to interpret. +- `crypto/stark/src/traits.rs`: optionally add + `fn compute_transitions_batched(&self, lde) -> Option>` default `None` + (GPU batched hook for `evaluate_transitions`). +- `crypto/stark/src/constraints/evaluator.rs`: (optional) try the batched GPU hook + before the per-row loop; otherwise **unchanged**. +- `crypto/stark/src/lib.rs` / `crypto/math-cuda/src/lib.rs`: module decls. + +**Modified — every constraint struct (`capture` body, delete `evaluate`):** +- `prover/src/constraints/templates.rs`: `IsBitConstraint`, `AddConstraint`, + `AddOperand::capture_lo/hi`, `AddLinearTerm::capture`, `eval_terms`→`capture`. +- `prover/src/constraints/cpu.rs`: `ProductZeroConstraint`, `Arg2ExclusiveConstraint`, + `MemFlagsBitConstraint`, `RegNotReadIsZeroConstraint`, `Arg2Constraint`, + `RvdEqResConstraint`, `BranchRvdConstraint`, `BranchCondConstraint`, + `NextPcAddConstraint`, `res_word`→capture helper. +- `prover/src/tables/`: `mul.rs (MulConstraint+compute helpers)`, `dvrm.rs + (DvrmConstraint)`, `shift.rs (ShiftConstraint)`, `cpu32.rs (Cpu32Constraint)`, + `memw.rs (MemwConstraint)`, `memw_aligned.rs (MemwAlignedConstraint)`, + `memw_register.rs (MemwRegisterMuSumIsBit)`, `load.rs (LoadConstraint)`, + `store.rs (StoreConstraint)`, `lt.rs (LtConstraint)`, `eq.rs (EqXorConstraint)`, + `branch.rs (BranchConstraint)`, `commit.rs (CommitConstraint)`, + `keccak.rs (one struct)`, `ec_scalar.rs (MulZeroConstraint)`, + `ecsm.rs (ConvCarry, ColIsZero, CarryBit, OverflowRequired)`, + `ecdas.rs (ConvCarry, ColIsZero, MulZero)`. + +**Key new type/function signatures (summary):** +```rust +pub struct TableProgram { nodes, consts_d1, consts_d3, emits, num_base, … } +pub struct IrBuilder { … } impl IrBuilder { main/aux/const_*/periodic/challenge/alpha_power/table_offset/add/sub/mul/neg/emit/finish } +pub fn interpret_prover(&TableProgram, &TransitionEvaluationContext, &mut[FE], &mut[FE]); +pub fn interpret_verifier(&TableProgram, &TransitionEvaluationContext, &mut[FE]); +trait TransitionConstraint { fn capture(&self, &mut IrBuilder); … } // generic evaluate removed +trait TransitionConstraintEvaluator { fn capture(&self, &mut IrBuilder); … } // evaluate_* removed +pub(crate) fn try_compute_transition_gpu(&TableProgram, &LDETraceTable, …) -> Option>>; +``` + +--- + +## 11. Risks & unknowns, ranked (brutally honest) + +1. **Breadth of the manual rewrite (33 structs / 19 files, ~600-800 LOC of bodies).** + This is the dominant cost and risk. Every body is mechanical but the multi-kind + mega-constraints (`dvrm` 11 kinds, `cpu32` 8, `shift` 7) have many capture paths + that are easy to transcribe subtly wrong. *Mitigation:* the bit-for-bit + parallel-path validation (§12) catches any divergence immediately; migrate one + table at a time behind the `Option` flag. +2. **LogUp `LookupAccumulatedConstraint` offset handling + lost per-row zero-skip.** + The fingerprint helpers must thread `offset` (next-row reads, §5.3) and the IR + cannot do the data-dependent `result==0` multiply-skip (§5.1). Correctness is + safe (value-preserving); the cost is a few extra D1×D3 muls/row. *Risk:* the + skip might matter more than expected on wide-bus tables; measure before deleting + the old path. `? INFERRED` it's negligible vs. dispatch savings — not yet + measured. +3. **Verifier-side typing (`Main` reads are Dim3 in the verifier).** The IR's + per-node `dim` is the prover's; the verifier interpreter must promote D1 leaves + to D3 and run everything as D3. If any constraint body relied on F-specific + behavior (e.g. `inv()` in base field) this would break — but I verified the + algebraic bodies only use `+ - * ` and `const` (the only "division" is + multiply-by-precomputed-`INV_SHIFT_32` const, `templates.rs:30`, which is just a + `Mul` by a constant — safe in any field). ✓ VERIFIED no body calls `inv()` at + eval time. +4. **GPU kernel program-size / divergence.** Programs are small (hundreds of nodes) + and uniform across rows (no divergence), but the per-thread value arena must fit + in registers/local mem; a large mega-constraint program (`dvrm`/`shift` are the + biggest) could spill. Keccak is NOT a concern (mostly reused `AddConstraint`s, + verified). *Mitigation:* per-thread arena lives in shared/local mem indexed by + NodeId; CPU fallback always available; GPU is opt-in per table above a threshold. +5. **Refactor of `evaluate_transitions` for the batched GPU hook.** The current loop + is per-row (`evaluator.rs:79`); a whole-table GPU call needs either a batched + path or accepting a precomputed `Cᵢ` matrix. `✗ UNCERTAIN` on the cleanest seam; + CPU interpreter needs none of this (it slots into the existing per-row call). +6. **CSE/const-dedup correctness.** Optional, but if the CSE key mis-merges two ops + with the same shape but different dim, results corrupt. *Mitigation:* key on + `(Op, Dim)`; or ship without CSE first (correctness independent of it). + +--- + +## 12. Effort estimate & validation strategy + +### Effort (by workstream) +- **IR + builder + CPU interpreter (framework):** `ir.rs` + `interpreter.rs` + + trait changes + `AirWithBuses` wiring. **~4-5 days.** Highest design value; + unblocks everything. +- **Rewrite algebraic constraints (33 structs, ~600-800 LOC):** 11 small structs in + `constraints/` (~1 day) + 22 in `tables/`, of which ~10 are multi-kind dispatch + structs. Budget the small ones at ~10-15/day; the multi-kind ones at ~0.5-1.5 each + (dvrm/cpu32/shift the costliest): **~4-6 days** total. +- **LogUp framework (2 constraints + offset-parameterized helpers):** **~2-3 days** + — small count but the highest per-line care (the crux). +- **Validation harness (parallel old/new diff):** **~1-2 days.** +- **GPU interpreter (kernel + dispatch + batched hook):** **~5-7 days** incl. the + `evaluate_transitions` batched seam and parity tests. Can land *after* the CPU + path is fully migrated and validated. +- **Total: ~2.5-3.5 weeks** for CPU-complete + validated; +1-1.5 weeks for GPU. + +### Validation (bit-for-bit, real tables, parallel paths) +1. **Keep the old generic `evaluate*` alongside `capture` during migration.** In a + `#[cfg(test)]` / debug harness, for each table and each LDE row, run BOTH: + the old `compute_transition_prover` (current trait-object loop) and + `interpret_prover(program, …)`, then `assert_eq!` the full `base_evals` and + `ext_evals` arrays. This is exactly the existing `validate_trace`-style + debug-assert pattern referenced in project memory; here it asserts + *evaluator equality* not trace validity. +2. **Drive it with the existing prove test** (`cargo test --release -p + lambda-vm-prover test_prove_elfs_test_sb_sh_8`) and the per-table bus tests + (`prover/src/tests/*_bus_tests.rs`, `*_tests.rs`) — these already exercise every + table's full constraint set on real traces. A mismatch pinpoints the exact + constraint_idx and row. +3. **Verifier parity:** at the OOD point, diff `air.compute_transition` (old) vs + `interpret_verifier` for the same frame — small (one frame), cheap, catches the + D1→D3 promotion bugs (Risk 3). +4. **GPU parity:** standard `gpu_lde.rs` pattern — compute on GPU and on CPU + interpreter, assert equal (the math-cuda test suite already does this per kernel; + add a `transition_interp` parity test). +5. Because the old path *coexists* (Option flag), CI can run both and assert equality + on every prove until we delete the old methods — zero-risk cutover. + +### What I could not confirm +- Struct count (33 algebraic + 2 LogUp / 19 files) and per-struct kind counts are + enumeration-verified; the ~600-800 LOC body total is an aggregate estimate (I read + `mul.rs`/`dvrm.rs` in full; the rest share the `kind`-enum→`compute()` pattern). +- Whether any table reads periodic columns (none of the bodies I read did; the + `Periodic` leaf is provided for completeness — `get_periodic_column_values` + defaults to empty, `traits.rs:290`). `? INFERRED` periodic is unused by current + tables. +- The cleanest `evaluate_transitions` seam for the batched GPU call (Risk 5). +- Keccak constraint body size/shape (didn't read it) — flagged for GPU register + pressure (Risk 4). diff --git a/thoughts/gpu-constraint-eval/plan-symbolic-field.md b/thoughts/gpu-constraint-eval/plan-symbolic-field.md new file mode 100644 index 000000000..1b4c4abad --- /dev/null +++ b/thoughts/gpu-constraint-eval/plan-symbolic-field.md @@ -0,0 +1,893 @@ +# Plan: GPU-ready constraint evaluation via a "Symbolic field" capture + +> **Status:** the CPU spike from this plan is **implemented** (PR #737, branch +> `spike/constraint-ir-symfield`). For the as-built state and the detailed, +> checkbox continuation plan, see **[`roadmap.md`](./roadmap.md)** — that is the +> execution / handoff doc. This file remains the full design rationale. + +**Approach:** keep the ~29 constraint bodies UNCHANGED; introduce a recording +field type `SymField`/`SymExt` whose field operations build an expression graph +instead of computing. Run each constraint's existing generic +`evaluate::(...)` (and the LogUp helpers) ONCE at setup to +capture a flat single-field Goldilocks IR, then INTERPRET that IR on CPU (prover +over the LDE coset + verifier at the OOD point) and on GPU (one universal +Goldilocks interpreter kernel). + +All file/line references below were read directly from the current tree. + +--- + +## 1. Overview & end-state + +After this change, each `AIR` (per table) owns, in addition to its existing +`Vec>`, a captured **constraint program**: +a flat list of typed Goldilocks IR ops plus a per-constraint root id. The program +is built once, at AIR construction, by running every constraint through a +recording field (`SymField`/`SymExt`) and recording the LogUp framework +constraints (`LookupBatchedTermConstraint`, `LookupAccumulatedConstraint`) via +the same recording field. At evaluation time, an **interpreter** walks the IR: +on CPU it replaces the per-row `air.compute_transition_prover(...)` call inside +`ConstraintEvaluator::evaluate_transitions` (crypto/stark/src/constraints/evaluator.rs:100) +and the verifier's `air.compute_transition(...)` call +(crypto/stark/src/verifier.rs:209); on GPU it is one Goldilocks kernel that +reads the serialized IR plus the device-resident LDE columns and produces the +per-constraint `Cᵢ` values. The accumulation `Σ βᵢ·Cᵢ·Zᵢ⁻¹ + boundary` and all +zerofier/coefficient machinery stay exactly where they are in +`evaluate_transitions` — the IR only replaces the step that produces each +constraint's scalar `Cᵢ`. + +``` + ┌─ capture (ONCE, at AIR::new, concrete types known) ─┐ +constraint structs ──► run evaluate::(sym_step) │ +LogUp framework ──► run evaluate_batched/accumulated::(...) │ + records into thread-local arena ──► ConstraintProgram │ + └────────────────────────────────────────────────────┘ + │ (serialize) + ┌─────────────────────────────────┼───────────────────────────────┐ + CPU prover (per LDE row) CPU verifier (1 OOD point) GPU kernel + interp(program, frame) ─► Cᵢ interp(program, ood_frame) ─► Cᵢ interp over device cols + │ │ │ + └─► Σ βᵢ·Cᵢ·Zᵢ⁻¹ (unchanged accumulation in evaluate_transitions / verifier) +``` + +The boxed `dyn TransitionConstraintEvaluator` path is retained verbatim as a +fallback and as the differential-test oracle (Section 9, 12). + +--- + +## 2. The IR (concrete Rust data structures) + +The IR is **single-field over Goldilocks**, with a dimension tag distinguishing +base (`dim1`, one u64) from extension (`dim3`, three u64). New crate module: +`crypto/stark/src/symbolic/ir.rs`. + +```rust +/// Field-arithmetic dimension of a node's value. +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum Dim { D1, D3 } // base Goldilocks, or its degree-3 extension + +/// A leaf input slot, resolved by the interpreter against the current frame +/// and the per-proof uniform inputs. +#[derive(Clone, Copy, Debug)] +pub enum Leaf { + /// Main trace column read: step.data[row][col], offset selects frame step. + Main { step: u8, row: u8, col: u16 }, // dim1 (base) for prover, dim3 for verifier + /// Aux trace column read: step.aux_data[row][col]. + Aux { step: u8, row: u8, col: u16 }, // always dim3 + /// Periodic column value at this row. + Periodic { idx: u16 }, // dim1 + /// rap_challenges[idx] (z, alpha, ...) + Rap { idx: u16 }, // dim3 + /// logup_alpha_powers[idx] + AlphaPow { idx: u16 }, // dim3 + /// logup_table_offset + TableOffset, // dim3 + /// One of the three precomputed packing shift constants (2^8, 2^16, 2^24) + Shift { which: u8 }, // dim1 (prover) / dim3 (verifier) +} + +/// One IR instruction. Indices are u32 ids into the program's `nodes` arena. +#[derive(Clone, Copy, Debug)] +pub enum Op { + Const1(u64), // dim1 literal (from FieldElement::from(u64/i64), one(), zero()) + Const3([u64; 3]), // dim3 literal (rare: produced by to_extension / from(u64) in E) + Leaf(Leaf), + Add(u32, u32), + Sub(u32, u32), + Mul(u32, u32), + Neg(u32), + // Embed a dim1 value into dim3 (the to_extension() / IsSubFieldOf::embed step, + // and the implicit base→ext promotion that F×E ops perform). + Embed(u32), +} + +/// A captured program for one table's transition constraints. +pub struct ConstraintProgram { + pub nodes: Vec, // topologically ordered (id i only references < i) + pub dims: Vec, // dims[i] = result dimension of nodes[i] + pub roots: Vec, // roots[c] = node id of constraint c's value Cᵢ + pub num_base: usize, // first num_base roots are dim1 (base-field) constraints + // metadata needed to size interpreter input arrays: + pub max_step: u8, pub max_main_col: u16, pub max_aux_col: u16, +} +``` + +**Typing rule.** Every node carries a `Dim`. `Add/Sub/Mul` of (D1,D1)→D1; +any operand D3 ⇒ result D3 (the interpreter auto-`Embed`s the D1 operand, +matching the `F: IsSubFieldOf` mixed-arithmetic the field tower performs at +crypto/math/src/field/element.rs:344). `Embed(D1)→D3`. This mirrors the real +arithmetic exactly: a base×ext multiply is 3 Goldilocks muls (the +`IsSubFieldOf::mul` at crypto/math/src/field/extensions_goldilocks.rs:413), an +ext×ext multiply is one `dot_product_3` schoolbook (extensions_goldilocks.rs:297). + +**Serialization for GPU.** `nodes` is encoded as a packed `Vec` opcode +stream: `[opcode_tag, operand_a, operand_b]` (3×u32 per node; `Const1`/`Const3` +store their literal in a side `Vec` indexed by operand_a). `dims` is a +`Vec`. `roots` is a `Vec`. This is a flat POD layout that copies to the +device as three buffers (`ops: &[u32]`, `consts: &[u64]`, `roots: &[u32]`), +following the same "reinterpret as `&[u64]`/`&[u32]`, transmute-free POD" +discipline used by the GPU LDE bridge in crypto/stark/src/gpu_lde.rs. + +--- + +## 3. Capture front-end — `SymField` design (the distinguishing section) + +`SymField` is a **marker type** that implements `IsField`, exactly like +`GoldilocksField` is a zero-sized marker whose `BaseType = u64` +(crypto/math/src/field/goldilocks.rs:70-73). The constraint bodies are generic +over the *field marker* `F` and operate on `FieldElement`, whose data is +`F::BaseType` (crypto/math/src/field/element.rs:50-52). So we choose: + +```rust +pub struct SymField; // base-field recorder (dim1) +pub struct SymExt; // extension recorder (dim3) +impl IsField for SymField { type BaseType = SymId; ... } +impl IsField for SymExt { type BaseType = SymId; ... } +``` + +where `SymId` wraps a `u32` node id (plus the `Dim` it denotes, see arena +decision). Because `BaseType` is just an id, every `IsField::add/mul/...` call +*records* a node into a thread-local arena and returns a fresh id. + +### Q1 — ARENA PROBLEM: thread-local arena returning u32 ids (chosen) + +`IsField` ops are static, contextless `fn mul(a: &BaseType, b: &BaseType) -> BaseType` +(crypto/math/src/field/traits.rs:104-112). There is no `&self`/arena parameter +to thread. Two options: + +* **`BaseType = Arc` (tree, hash-consed).** Each op allocates an `Arc` + node holding its children `Arc`s. Dedup requires hash-consing through a + thread-local `HashMap>`. *Downsides:* Arc clone/drop traffic + during capture, recursion in `Drop` for deep trees, and we *still* need a + thread-local for the hash-cons table — so it buys nothing over ids while + costing pointer-chasing and an `Arc` per node. Rejected. + +* **Thread-local arena returning `u32` ids (CHOSEN).** A `thread_local!` arena: + + ```rust + thread_local! { + static ARENA: RefCell> = const { RefCell::new(None) }; + } + struct Arena { + nodes: Vec, + dims: Vec, + cse: HashMap, // hash-consing: (opcode + operand ids) → id + } + ``` + + `BaseType` is a small `Copy` struct: + + ```rust + #[derive(Clone, Copy, Debug, Default)] + pub struct SymId { id: u32, dim: Dim } // Default = id 0 ... see Q2 Default note + ``` + + Each op does `ARENA.with(|a| { let a = a.borrow_mut().as_mut().unwrap(); + a.push(Op::Mul(x.id, y.id)) })` where `push` consults `cse` for dedup + (hash-consing gives a DAG, not a tree, for free). Capture is wrapped in + `with_arena(|| { ... run constraints ...; arena.take() })` which installs a + fresh `Arena`, runs the closure, and extracts `(nodes, dims, roots)`. + + This avoids `Arc` entirely, gives DAG dedup via the `cse` map, is `Copy` + (so `.clone()` in constraint bodies — used heavily, e.g. templates.rs:97, + cpu.rs:147 — is free and correct), and the only state lives in one + `thread_local`. Capture runs single-threaded per program (it's a setup-time + one-shot per table), so the thread-local is uncontended. **This is the + pick.** + + Hash-consing is mandatory, not optional: without it the ADD-carry templates + (templates.rs:414-440, `compute_carry_1` recomputes `compute_carry_0`) and + the LogUp fingerprints (each `compute_fingerprint_from_step` re-reads the same + columns) would blow up the node count. With `cse`, `compute_carry_0`'s subtree + is shared. + +### Q2 — TRAIT-METHOD COVERAGE (exhaustive) + +`SymField` must satisfy `IsField` and the `BaseType` bounds. `SymExt` must +satisfy `IsField`. The `IsSubFieldOf for SymField` impl is also needed +because constraint bodies are bounded `F: IsSubFieldOf` and `evaluate` +returns `FieldElement` (transition.rs:352-355). Below, every required method +with its symbolic implementation or a flag. + +**`IsField for SymField` (BaseType = SymId, dim D1):** + +| Method | Symbolic impl | +|---|---| +| `type BaseType = SymId` | id+dim, Copy | +| `add(a,b)` | record `Add(a,b)` → D1 | +| `sub(a,b)` | record `Sub(a,b)` → D1 | +| `mul(a,b)` | record `Mul(a,b)` → D1 | +| `neg(a)` | record `Neg(a)` → D1 | +| `double(a)` | default `add(a,a)` works; or record `Add(a,a)` | +| `square(a)` | default `mul(a,a)` works | +| `zero()` | record/return `Const1(0)` id (default `BaseType::default()`; see note) | +| `one()` | record/return `Const1(1)` id | +| `from_u64(x)` | record `Const1(GoldilocksField::from_u64(x))` id | +| `from_base_type(x)` | identity (return x) | +| `inv(a)` | **PROBLEM if ever called** — emit `Op::Inv` only if needed; NOT used by any algebraic constraint nor by the LogUp framework constraints (verified: no `.inv()`/`.div()`/`.pow()` in prover/src/constraints/; LogUp clears denominators so the constraint bodies never invert — fingerprints are *subtracted/multiplied*, not divided, in `evaluate_batched_term_constraint` lookup.rs:1759 and `evaluate_accumulated_constraint` lookup.rs:1887). **Make `inv` `unimplemented!("symbolic inv")`** — if capture ever hits it we want a loud failure, not silent wrong IR. | +| `div(a,b)` | same: `unimplemented!()` (not reached) | +| `eq(a,b)` | **SUBTLE** — returns `bool`, can't be symbolic. Used by `result != FieldElement::zero()` short-circuits in lookup.rs:675, lookup.rs:790. Must return a **conservative `false`** so the "skip zero term" optimization is *not* taken during capture (we always record the multiply). See Q5; this is correct because the skip is a runtime data optimization, and the captured IR must be data-independent. | +| `pow`, `sqrt`, `legendre_symbol` | not reached; default impls call `mul`/`square` and would work but should never run. | + +**`BaseType: Clone + Debug + ByteConversion + Default + Send + Sync` +(traits.rs:101):** + +| Bound | For `SymId` | +|---|---| +| `Clone + Copy` | derive (it's `{u32, Dim}`) | +| `Debug` | derive | +| `Default` | derive — **but** `Default` is used by `FieldElement::default()` → `value: F::zero()` (element.rs:488) and by `Frame::preallocate` (frame.rs:90-95). During capture we *don't* call `preallocate`; we build a symbolic frame by hand (Q4). A derived `SymId::default()` = `{id:0,dim:D1}` is fine as long as id 0 is a valid node — we reserve node id 0 = `Const1(0)` so a stray default is the zero element. **Resolved, no problem.** | +| `Send + Sync` | `SymId` is `Copy` POD ⇒ auto. The thread-local arena is not part of `SymId`, so no `Send` issue. | +| `ByteConversion` | **FLAG — must implement but never call.** `write_bytes_be/to_bytes_be/from_bytes_be/from_bytes_le` ⇒ `unimplemented!()`. ByteConversion is only exercised by transcript/serialization paths (goldilocks.rs:436), which capture never touches. Acceptable: it's a trait-bound satisfier, not a live method. | + +**`IsField for SymExt` (BaseType = SymId, dim D3):** identical table, but every +recorded node is tagged D3, and `from_u64(x)` records `Const3([from_u64(x),0,0])` +(matching `Degree3...::from_u64` extensions_goldilocks.rs:399). `one()`→ +`Const3([1,0,0])`, `zero()`→`Const3([0,0,0])`. `inv`/`div` `unimplemented!()`. + +**`IsSubFieldOf for SymField` (traits.rs:17-25):** this is the mixed +base×ext arithmetic surface the field-element operators dispatch through +(element.rs:223,295,346). Each must record the correct mixed node: + +| Method | Symbolic impl | +|---|---| +| `mul(a: &SymId/*D1*/, b: &SymId/*D3*/) -> SymId/*D3*/` | record `Mul(a,b)` tagged D3 (the interpreter sees a D1×D3 mul and does the 3-mul base×ext path) | +| `add(a,b) -> D3` | record `Add(a,b)` D3 | +| `sub(a,b) -> D3` | record `Sub(a,b)` D3 | +| `div(a,b)` | `unimplemented!()` (not reached) | +| `embed(a: SymId/*D1*/) -> SymId/*D3*/` | record `Embed(a)` → D3 | +| `to_subfield_vec(b)` | `unimplemented!()` (not reached; only serialization uses it) | + +Note the blanket `impl IsSubFieldOf for F` (traits.rs:27-60) automatically +gives us `IsSubFieldOf for SymField` and `IsSubFieldOf for +SymExt` (the prover-frame `evaluate` with FF=F and the verifier-frame with +FF=E both rely on these reflexive impls). Those route to `SymField::mul` etc., +so no extra code. + +**`IsFFTField for SymField`?** The `AIR` trait bounds `Field: IsFFTField` +(traits.rs:139) and `AirWithBuses` further bounds `Field: IsPrimeField` +(lookup.rs:805). **But capture does NOT instantiate any `AIR`.** +Capture only calls the *constraint object's* generic `evaluate::(step)` and the LogUp helper fns `::` — those are +bounded only `FF: IsSubFieldOf, EE: IsField` (transition.rs:352-355, +lookup.rs:1759, lookup.rs:1887). So `SymField` needs **only** `IsField + +IsSubFieldOf`, NOT `IsFFTField`/`IsPrimeField`. This is the single most +important feasibility fact: it sidesteps `IsFFTField::{TWO_ADICITY, root, +field_name}` and `IsPrimeField::{canonical, from_hex, field_bit_size}` entirely +(none are reachable from `evaluate`). *Verified:* `evaluate`'s only bound is +`FF: IsSubFieldOf` (transition.rs:354), the LogUp inner fns' +only bound is `A: IsSubFieldOf, B: IsField` (lookup.rs:1759, lookup.rs:1887, +lookup.rs:1679, lookup.rs:1689). Capture never builds the AIR with sym types. + +### Q3 — Constants & `to_extension` / `one()` / `zero()` + +* `FieldElement::::from(i64/u64)` → `From`/`From` (element.rs:136,149) + → `F::from_u64(value)`. For `F = SymField` this records `Const1(c)` with + `c = GoldilocksField::from_u64(value)` (we *fold the real Goldilocks reduction* + at capture time so the literal stored is canonical). `i64` negatives go through + `-Self::from(abs)` (element.rs:157) → records `Neg(Const1(abs))`; or we can + constant-fold to `Const1(p - abs)`. Either is correct; constant-folding negatives + keeps the IR smaller. Examples captured this way: `inv_2_32` (templates.rs:30-36, + a `from(INV_SHIFT_32)`), `SHIFT_16` (cpu.rs:69), `AddLinearTerm` coefficients + `1<<16`, `1<<8`, `1<<24` (templates.rs:266-326), bus `LinearTerm` coefficients + (lookup.rs:656,772). +* `FieldElement::one()`/`zero()` (element.rs:550,556) → `F::one()`/`F::zero()`. + For `SymField` → `Const1(1)`/`Const1(0)`; for `SymExt` → `Const3([1,0,0])`/ + `Const3([0,0,0])`. The literal `FieldElement::::one()` appears all over + (templates.rs:98, cpu.rs:146). +* `.to_extension::()` (element.rs:566) → `>::embed(value)`. + Used by the adapter's verifier path `...evaluate(...).to_extension()` + (transition.rs:431). For `F=SymField, L=SymExt` this records `Embed(child)`. + **However** — see Section 4: in the *prover* capture we run the adapter with + FF=F (base), and in the *verifier* capture we run FF=E (already D3); we will + capture the constraint's **base value** (the `evaluate` result, dim D1) and let + the interpreter/accumulator handle the embed, mirroring how the real prover + keeps base constraints in `base_evals: &mut [FieldElement]` + (evaluator.rs:106-110). So `to_extension` is mostly *not* in the captured graph + for base constraints; it only appears if a constraint body itself calls + `to_extension`, which none of the algebraic ones do (they return D1). + +### Q4 — SYMBOLIC FRAME + +Capture needs a `TableView` (and `Frame`) +whose column reads return `Leaf` nodes. `TableView` is +`{ data: Vec>>, aux_data: Vec>> }` +(table.rs:397-399) and reads go through `get_main_evaluation_element(row, col)` +(table.rs:410) / `get_aux_evaluation_element` (table.rs:414). So we build a +symbolic frame by filling each cell with a `FieldElement::from_raw(SymId)` whose +id is a recorded `Leaf::Main { step, row, col }` / `Leaf::Aux { ... }`: + +```rust +fn symbolic_frame(num_steps, rows_per_step, num_main, num_aux) -> Frame { + let steps = (0..num_steps).map(|step| { + let data = (0..rows_per_step).map(|r| + (0..num_main).map(|c| + FieldElement::::from_raw(record_leaf(Leaf::Main{step,row:r,col:c})) + ).collect()).collect(); + let aux_data = (0..rows_per_step).map(|r| + (0..num_aux).map(|c| + FieldElement::::from_raw(record_leaf(Leaf::Aux{step,row:r,col:c})) + ).collect()).collect(); + TableView::new(data, aux_data) + }).collect(); + Frame::new(steps) +} +``` + +`num_steps` = `offsets.len()` (= 2 for LogUp tables, `transition_offsets: +vec![0,1]` lookup.rs:909). `rows_per_step` = step_size/blowup (1 for these +tables). The two `TransitionEvaluationContext` variants needed for capture +(Q5/Q6): a `Prover { frame: &Frame, periodic_values: +&[FieldElement], rap_challenges: &[FieldElement], +logup_alpha_powers, logup_table_offset, packing_shifts: &PackingShifts }`. +Each uniform input (periodic, rap, alpha pow, table offset, shifts) is also a +recorded `Leaf` (`Periodic`, `Rap`, `AlphaPow`, `TableOffset`, `Shift`). The +shift constants `PackingShifts::::new()` (lookup.rs:54) call +`FieldElement::::from(SHIFT_8/16)` and `&shift_8 * &shift_16` — those +record `Const1` + `Mul` automatically; but to keep the IR clean we instead +construct `PackingShifts` with `Leaf::Shift{0/1/2}` ids so the interpreter +injects the real precomputed constants at eval time (they're loop-invariant and +the existing code precomputes them once, lookup.rs:64). Both are correct; the +`Leaf::Shift` version matches the existing precompute and keeps shifts uniform. + +--- + +## 4. Capturing the algebraic constraints (the ~29 structs, via the adapter) + +The ~29 algebraic constraints implement the user-facing `TransitionConstraint` +trait and are wrapped by `TransitionConstraintAdapter` (transition.rs:393). +Their bodies are generic `evaluate(&self, step: &TableView) -> +FieldElement` (transition.rs:352). **We do NOT touch any body.** Capture +calls each constraint's `evaluate::(sym_step)` directly and +reads the returned `SymId` (the root for that constraint). + +**Count, verified by grep: there are 33 (not ~29) algebraic +`impl TransitionConstraint` structs**, not +just the CPU ones the team-lead memo listed. Beyond templates.rs/cpu.rs they +span prover/src/tables/: branch.rs:519, commit.rs:837, cpu32.rs:645, +dvrm.rs:1219, ec_scalar.rs:291, ecdas.rs:{363,402,426}, ecsm.rs:{663,698,791,816}, +eq.rs:262, keccak.rs:503, load.rs:572, lt.rs:536, memw_aligned.rs:708, +memw_register.rs:388, memw.rs:921, mul.rs:847, shift.rs:914, store.rs:282 +(plus the 11 in templates.rs/cpu.rs). The "zero body edits" win therefore +applies to **all 33**, including the large ones (keccak, ecsm, dvrm, mul) — a +bigger payoff than the memo implied, but those large bodies also drive risk 5/6 +(node count / GPU scratch). + +These constraints all return a base-field (`FF=F`) value, so we capture them as +**dim-D1 roots** placed in `roots[0..num_base]`, matching the prover's base +split (evaluator.rs:50, evaluator.rs:106). **Safe-op audit (first-hand + grep, +load-bearing for feasibility):** every body uses only `clone`, `+`, `-`, `*`, +`neg` (via `-x`), `FieldElement::from`, `one()`, `get_main_evaluation_element` — +e.g. `IsBitConstraint` (templates.rs:92-107: `&cond * &x * (one - x)`), +`AddConstraint` (templates.rs:442-467 + the carry helpers, which multiply by the +constant `inv_2_32`), `ProductZeroConstraint` (cpu.rs:105-112), `Arg2Constraint` +(cpu.rs:277-303), `BranchRvdConstraint`/`NextPcAddConstraint` +(cpu.rs:394-446, cpu.rs:518-571). Crucially, a grep over the **entire** +`prover/src` (non-test) finds **zero** `.inv()`/`.pow()`/`.div()`/`.sqrt()`/ +`.legendre_symbol()` calls and **zero** field-value conditionals (`== FieldElement`, +`.is_zero()`, `if …value()…`) across all 17 table files — so no body, and no +helper any body transitively calls, performs division/inversion/exponentiation +or branches on a field value. (The per-struct degree + body summary from the +enumeration sub-agent is appended at the end.) + +**Framework glue changes** (minimal, additive): + +1. New trait method on `TransitionConstraint` with a default that **panics**, and + override it for the adapter is *not* the route — instead add a free function + `capture_user_constraint>(c: &T, step: &TableView) -> SymId` + that just calls `c.evaluate::(step)`. Because the adapter + stores `T` (transition.rs:393, `TransitionConstraintAdapter(pub T)`), but + the AIR only keeps `Box` (lookup.rs:813), + we cannot recover `T` from the boxed object. **Therefore capture must run at + the point where concrete constraint types still exist — i.e. inside each + table's constraint-list builder** (e.g. `create_all_cpu_constraints` + cpu.rs:619), *before* `.boxed()`. See Section 9. + +2. Add a capture entry point to the `TransitionConstraintEvaluator` trait: + `fn capture(&self, ctx: &SymCaptureCtx) -> SymId;` with a default that calls + `evaluate_verifier` against a symbolic context... **but** `evaluate_verifier` + needs `&mut [FieldElement]` slots, and for the adapter it calls + `self.0.evaluate(...).to_extension()` (transition.rs:431). Running *that* + under sym types records the constraint plus a trailing `Embed`, giving a D3 + root. That is acceptable for capture purposes (the embed is a no-op cost on + D1→D3 and the accumulator can treat the root as D3). **This is the cleaner, + object-safe route:** add `fn capture(&self, ctx, &mut [SymId])` to + `TransitionConstraintEvaluator`, default-implemented by calling a sym version + of `evaluate_verifier`. The adapter's `capture` runs + `self.0.evaluate::(frame.step).to_extension()` → D3 root, OR + (better, to keep base/ext split) runs `evaluate` and stores the D1 root for + `idx < num_base`. We implement the latter: `capture` mirrors `evaluate_prover` + (transition.rs:439) — D1 root into base slot for base constraints, D3 for the + LogUp ones. This keeps the captured program's `num_base` aligned with + `air.num_base_transition_constraints()` (lookup.rs:1025). + +The recommended concrete design: add **one** method +`fn capture(&self, ctx: &SymCaptureContext, base_roots: &mut Vec, +ext_roots: &mut Vec)` to `TransitionConstraintEvaluator` +(crypto/stark/src/constraints/transition.rs). Default impl: run the verifier- +style body symbolically and push a D3 root. Adapter override +(transition.rs:395): run `self.0.evaluate::` and push a D1 root +when `idx < base_roots.capacity-marker`, else D3. The two LogUp framework +structs override `capture` to run their `evaluate_*_constraint` inner fns under +sym types (Section 5). + +--- + +## 5. Capturing the LogUp / extension framework constraints (Q5 — the crux) + +The two LogUp constraints do **not** go through the adapter; they +`impl TransitionConstraintEvaluator` directly and branch on the +`TransitionEvaluationContext` enum (lookup.rs:1741, lookup.rs:1868). The decisive +question: **are their helpers generic enough to run under SymField/SymExt?** + +**Verdict: YES — they are fully capturable, no hand-emit needed.** Evidence: + +* `compute_multiplicity_from_step, B: IsField>` (lookup.rs:1679) + — generic; body is `multiplicity.evaluate_with(|col| step.get_main_evaluation_element(0,col).clone())` + → `Multiplicity::evaluate_with` (lookup.rs:1252) uses only `one()`, `+`, `-`, + `*`, `FieldElement::from(coeff)`. All recordable. ✓ +* `compute_fingerprint_from_step, B: IsField>` (lookup.rs:1689) + — generic; body builds `FieldElement::::from(bus_id)` then loops + `bv.accumulate_fingerprint_from_step(...)` (lookup.rs:738) which uses + `get_main_evaluation_element`, `Packing::accumulate_fingerprint_with` + (lookup.rs:272: only `+`,`*`, shift consts), and `z - &linear_combination`. + All recordable. ✓ +* `evaluate_batched_term_constraint, B: IsField>` + (lookup.rs:1759) — generic inner fn; computes `c * fp_a * fp_b - term_a - + term_b`. ✓ +* `evaluate_accumulated_constraint, B: IsField>` + (lookup.rs:1887) — generic; `delta * f - m*sign` etc. ✓ + +**The two sign/branch subtleties, and why they're still capturable:** + +1. **`is_sender` sign logic** (lookup.rs:1780-1790, lookup.rs:1927-1932, + lookup.rs:1954-1956): these are `if interaction.is_sender { term } else { + -term }` — branching on a **compile-time-known `bool` field of the + interaction struct**, NOT on a field *value*. During capture `is_sender` is a + concrete `bool`, so the branch is resolved at capture time and we record + either `term` or `Neg(term)`. ✓ No data dependence. + +2. **`result != FieldElement::::zero()` short-circuit** in + `accumulate_fingerprint_from_step` (lookup.rs:790) and the column-major + variant (lookup.rs:675): this branches on a *field value* via `PartialEq` → + `F::eq`. For `SymField` we make `eq` return **`false` always** (Q2), so the + capture path *always records the multiply* (`*acc += result * alpha_powers[..]`). + This is **correct and conservative**: the skip is a runtime optimization for + rows where the value happens to be zero; the IR must be valid for *all* rows, + so it must include the multiply unconditionally. The interpreter then always + does the multiply — slightly more work than the optimized CPU path on + all-zero rows, but bit-identical results. ✓ (If we wanted to preserve the + optimization we could detect "operand is a `Const1(0)` node" at capture time + and constant-fold, recovering the bus-id-padding skip statically. Recommended + as a cheap IR peephole.) + +**Building the capture context.** We construct +`TransitionEvaluationContext::Prover { frame: &Frame, +rap_challenges: &[FieldElement], logup_alpha_powers: +&[FieldElement], logup_table_offset: &FieldElement, +packing_shifts: &PackingShifts, periodic_values: +&[FieldElement] }` (the enum at traits.rs:77-84). Every slice element +is a `Leaf` node (`Rap{idx}`, `AlphaPow{idx}`, `TableOffset`, `Shift{}`). The +frame has 2 steps (acc uses `frame.get_evaluation_step(0)` and `(1)`, +lookup.rs:1972-1973). We call the constraint's `evaluate_verifier` (or the new +`capture`) with this Prover context; the matching `match` arm +(lookup.rs:1794, lookup.rs:1963) fires the generic inner fn under sym types and +returns a D3 root. **No fallback hand-emit is required** — this is the key win +over a hand-written LogUp IR. + +One caveat to call out: `evaluate_verifier` writes into `transition_evaluations: +&mut [FieldElement]` (lookup.rs:1827). Under sym types `E=SymExt`, so the +result is a `FieldElement` whose value is the root `SymId` — we read it +back from the slot. The slice must be pre-filled with a sentinel; we size it to +`num_transition_constraints` and read `slot[constraint_idx]` after the call. ✓ + +--- + +## 6. CPU interpreter + +New module `crypto/stark/src/symbolic/interp.rs`. Two entry points, one shared +core. + +**Core:** `fn eval_program(prog: &ConstraintProgram, inputs: &Inputs, out: &mut Outputs)` +walks `prog.nodes` in id order, computing each node into a value array. Because +nodes are topologically ordered (id i references < i) a single forward pass with +a `Vec` (len = nodes.len()) suffices; `Value` is an enum +`{ D1(FieldElement), D3(FieldElement) }` with auto-embed on mixed ops. +`inputs` resolves `Leaf`s: `Main`/`Aux` from the current frame step/row/col, +`Periodic/Rap/AlphaPow/TableOffset/Shift` from the per-proof uniform arrays +(Section 8). Final: `out.base[c] = values[roots[c]]` for `c` (verifier.rs:198, `into_frame`) so *all* +reads are D3; we run `eval_program` with an `Inputs` whose `Main` leaves resolve +to the OOD frame's D3 cells (interpreter reads them as D3 directly — the program +is the same, only the leaf-resolution dimension differs). Output is the +`Vec>` consumed by the zerofier fold (verifier.rs:218-225), +untouched. + +**Base/ext handling.** The interpreter must do D1×D3 the cheap way (3 muls, +matching `IsSubFieldOf::mul` extensions_goldilocks.rs:413) and D3×D3 via +`dot_product_3` (one `Degree3...::mul` extensions_goldilocks.rs:297). We reuse +the real `FieldElement` / `FieldElement` arithmetic +inside `Value`, so the interpreter's per-node cost equals the boxed path's — the +IR overhead is just the opcode dispatch (a `match` per node), which is cheap +relative to a Goldilocks mul. For the prover the program is run with +`F=GoldilocksField, E=Degree3...`; for the verifier with `F=E=Degree3...`. + +--- + +## 7. GPU interpreter sketch + +One universal Goldilocks kernel, modeled on the gpu_lde TypeId+transmute seam. + +**Host seam** (`crypto/stark/src/symbolic/gpu_interp.rs`): a +`try_eval_program_gpu(prog, lde_trace, uniforms, out) -> Option<()>` that, +exactly like check_base_layout (gpu_lde.rs:106) / the barycentric dispatchers +(gpu_lde.rs:811), gates on `TypeId::of::() == GoldilocksField` and +`TypeId::of::() == Degree3...` (gpu_lde.rs:826-831), a size threshold, and a +device-resident main/aux LDE handle (`lde_trace.gpu_main()`/`gpu_aux()`, +gpu_lde.rs:832,915). On mismatch → `None` → CPU interpreter fallback. The +program's three POD buffers (`ops: &[u32]`, `consts: &[u64]`, `roots: &[u32]`, +Section 2) plus the uniform arrays (rap challenges, alpha powers, table offset, +periodic columns, shift consts — all reinterpreted to `&[u64]` via the same +`#[repr(transparent)]` pattern as weights_to_u64 gpu_lde.rs:196) are H2D-copied +**once** (they're constant across all LDE rows). The columns are already on +device from the R1 LDE keep-handles (gpu_lde.rs:459, `GpuLdeBase`/`GpuLdeExt3`). + +**Device kernel** (new file under crypto/math-cuda/src/, e.g. +`symbolic_interp.cu` + a `math_cuda::symbolic` Rust wrapper): one thread per LDE +row. Each thread allocates a small per-node scratch in registers/shared/local +memory (`nodes.len()` Goldilocks values — programs are small, ~hundreds of nodes +per table) and runs the same forward pass as the CPU core, using the existing +math-cuda Goldilocks device primitives: base mul/add/sub (the same reduce128 +identities as goldilocks.rs:197), and ext3 mul as device `dot_product_3` +(mirroring goldilocks.rs:290). The kernel writes `out[row*num_constraints + c]` +for each root. **What crosses the host/device boundary:** program buffers + uniforms +(small, once); columns (already resident); output = `num_constraints × lde_size` +ext3 values (or, with the base/ext split, `num_base × lde_size` base + the rest +ext3) — D2H once. The accumulation `Σ βᵢ·Cᵢ·Zᵢ⁻¹` can stay on host (cheap) or be +fused into the kernel later; for v1 keep it on host to minimize surface, matching +how `apply_ext3_scalar` post-processes on host (gpu_lde.rs:694). + +The single-field design means **one kernel** handles every table — the per-table +difference is entirely in the data buffers (`ops/consts/roots`), so there is no +per-table CUDA codegen. This is the whole point of the interpreter approach. + +--- + +## 8. Inputs plumbing (Q6) + +Periodic values, rap_challenges, logup_alpha_powers, logup_table_offset, and +packing_shifts vary **per proof** but are **constant across all rows** of one +table's evaluation. They are already computed once per `evaluate_transitions` +call: `logup_alpha_powers` (evaluator.rs:53), `packing_shifts` (evaluator.rs:64), +`rap_challenges` (passed in), `logup_table_offset` (evaluator.rs:47), +`lde_periodic_columns` (evaluator.rs:251 — note periodic is **per-row**, indexed +by `col[i]`, so it is a row-varying leaf resolved like a column). They become IR +**leaf inputs** with these resolutions in the interpreter's `Inputs`: + +| Leaf | CPU resolution | GPU resolution | +|---|---|---| +| `Main{step,row,col}` | `frame.get_evaluation_step(step).get_main_evaluation_element(row,col)` | device LDE main column, strided by step·lde_step_size (frame.fill_from_lde logic, frame.rs:117) | +| `Aux{...}` | `...get_aux_evaluation_element` | device LDE aux column | +| `Periodic{idx}` | `periodic_buf[idx]` (= `lde_periodic_columns[idx][i]`) | device periodic column | +| `Rap{idx}` | `rap_challenges[idx]` | uniform buffer slot | +| `AlphaPow{idx}` | `logup_alpha_powers[idx]` | uniform buffer slot | +| `TableOffset` | `logup_table_offset` | uniform buffer slot | +| `Shift{which}` | `packing_shifts.{shift_8,16,24}` | uniform buffer slot | + +At capture time, the leaf *indices* (which rap challenge, which alpha power) are +fixed by how the constraint reads them (`rap_challenges[0]` = z, lookup.rs:1769; +`alpha_powers[alpha_offset]` walked in packing, lookup.rs:294). So the program +encodes the exact indices; the interpreter just gathers from the per-proof arrays. +The arrays' lengths are known at eval time (`max_bus_elements` → +`compute_alpha_powers` count, evaluator.rs:55). No re-capture per proof. + +--- + +## 9. Coexistence & object-safety + +* **Where capture runs.** Because the AIR only stores `Box` (lookup.rs:813) and the adapter erases the + concrete `T` (transition.rs:393), the cleanest object-safe route is to add a + `capture` method to the **`TransitionConstraintEvaluator` trait** (which the + boxed objects *do* expose). The adapter's `capture` (transition.rs:395) calls + `self.0.evaluate::` — concrete `T` is in scope there. The two + LogUp structs override `capture` to run their generic inner fns. Then a single + pass over `air.transition_constraints()` (the existing `Vec>`, + traits.rs:314) captures the whole program. This means **the AIR builds its + `ConstraintProgram` once in a new default method** + `AIR::constraint_program(&self) -> ConstraintProgram` that iterates the boxed + constraints and calls `capture` on each within a `with_arena` scope. No table + builder needs editing. + +* **`capture` and object safety.** Adding `fn capture(&self, ctx: + &SymCaptureContext, base: &mut Vec, ext: &mut Vec)` to the + trait keeps it object-safe (no generics in the method signature; `SymField`/ + `SymExt` are concrete). The default impl runs the verifier-shaped body + symbolically. ✓ + +* **Generic boxed path retained as fallback.** `compute_transition_prover` + (traits.rs:254) and `compute_transition` (traits.rs:223) stay. A feature flag + `symbolic-interp` (or a runtime toggle) selects, inside + `evaluate_transitions` (evaluator.rs:100) and the verifier (verifier.rs:209), + whether to call the IR interpreter or the boxed path. Default off until the + differential test (Section 12) is green; then default on. + +* **TypeId gating for GPU.** The GPU interpreter only engages for the real + `GoldilocksField`/`Degree3...` instantiation (Section 7), identical to + gpu_lde.rs:119-152. For any other field the host code transparently uses the + CPU interpreter or the boxed path. + +* **Cache the program.** `ConstraintProgram` is built once per AIR and stored in + the AIR (or in `ConstraintEvaluator::new`, evaluator.rs:188, alongside + `boundary_constraints`). It is immutable and `Send + Sync` (POD), so it's + shared across all Rayon workers and reused across proofs of the same table + shape. + +--- + +## 10. Exhaustive file-by-file change list + +**New files:** + +* `crypto/stark/src/symbolic/mod.rs` — module root, re-exports. +* `crypto/stark/src/symbolic/sym_field.rs` — + `pub struct SymField; pub struct SymExt; #[derive(Clone,Copy,Default,Debug)] + pub struct SymId{id:u32,dim:Dim}`; `impl IsField for SymField/SymExt`; + `impl IsSubFieldOf for SymField`; `impl ByteConversion for SymId` + (unimplemented stubs); the `thread_local! ARENA` + `with_arena` + + `record(Op)->SymId` (hash-consing) + `record_leaf(Leaf)->SymId`. +* `crypto/stark/src/symbolic/ir.rs` — `Dim`, `Leaf`, `Op`, `ConstraintProgram`, + serialization (`to_pod()` → `(Vec, Vec, Vec)`). +* `crypto/stark/src/symbolic/capture.rs` — `SymCaptureContext` + (builds the symbolic `Frame`/`TableView`/uniform leaves, Q4), + `fn capture_program(constraints: &[Box>], + layout, num_base, ...) -> ConstraintProgram`. +* `crypto/stark/src/symbolic/interp.rs` — `Value`, `Inputs`, `Outputs`, + `fn eval_program(prog,&Inputs,&mut Outputs)` (CPU core + prover & verifier + adapters). +* `crypto/stark/src/symbolic/gpu_interp.rs` — `try_eval_program_gpu(...) + -> Option<()>` (TypeId gate + H2D uniforms + kernel launch + D2H), guarded by + the cuda feature. +* `crypto/math-cuda/src/symbolic_interp.rs` (+ `.cu`) — `math_cuda::symbolic:: + eval_program_*` device wrapper and the one universal Goldilocks/ext3 kernel. + +**Modified files:** + +* `crypto/stark/src/constraints/transition.rs` — add + `fn capture(&self, ctx: &SymCaptureContext, base: &mut Vec, ext: &mut + Vec)` to `TransitionConstraintEvaluator` (default impl runs verifier- + shaped body symbolically); override in `TransitionConstraintAdapter` + (transition.rs:395) to run `self.0.evaluate::`. +* `crypto/stark/src/lookup.rs` — override `capture` for + `LookupBatchedTermConstraint` (lookup.rs:1741) and + `LookupAccumulatedConstraint` (lookup.rs:1868) to run their generic inner fns + under sym types. The inner fns are **unchanged** (already generic). +* `crypto/stark/src/traits.rs` — add default method + `fn constraint_program(&self) -> ConstraintProgram` (iterates + `self.transition_constraints()` + `with_arena`). +* `crypto/stark/src/constraints/evaluator.rs` — in `evaluate` (evaluator.rs:216) + build/fetch the cached `ConstraintProgram`; in `evaluate_transitions` + (evaluator.rs:100) replace `air.compute_transition_prover(&ctx, base_buf, + transition_buf)` with `eval_program(...)` (behind the feature/toggle), with the + GPU dispatch tried first (gpu_interp `try_eval_program_gpu`, else CPU). +* `crypto/stark/src/verifier.rs` — at verifier.rs:209 replace + `air.compute_transition(&ctx)` with the verifier interpreter (same toggle). +* `crypto/stark/src/lib.rs` (or `mod.rs`) — `pub mod symbolic;`. +* `crypto/math/src/field/...` — **no change** (SymField lives in the stark + crate; it only needs the public `IsField`/`IsSubFieldOf` traits, which are + already public). If `ByteConversion` for `SymId` must be impl'd where the + trait is defined due to orphan rules, add a thin impl in math; otherwise keep + in stark (SymId is a stark type, ByteConversion is a math trait — the impl is + allowed in stark since SymId is local). ✓ orphan-rule-safe. + +**Key new signatures:** +```rust +impl IsField for SymField { type BaseType = SymId; fn mul(a:&SymId,b:&SymId)->SymId {record(Op::Mul(a.id,b.id),Dim::D1)} ... } +impl IsSubFieldOf for SymField { fn mul(a:&SymId,b:&SymId)->SymId {record(Op::Mul(a.id,b.id),Dim::D3)} fn embed(a:SymId)->SymId{record(Op::Embed(a.id),Dim::D3)} ... } +pub fn capture_program(cs: &[Box>], layout:(usize,usize), num_base:usize, offsets:&[usize], step_size:usize) -> ConstraintProgram; +pub fn eval_program(prog:&ConstraintProgram, inp:&Inputs<'_>, out:&mut Outputs<'_>); +pub(crate) fn try_eval_program_gpu(prog:&ConstraintProgram, lde:&LDETraceTable, uni:&Uniforms, out:&mut [FieldElement]) -> Option<()>; +``` + +--- + +## 11. Risks & unknowns, ranked + +1. **IsField-contract friction is LOW — feasibility CONFIRMED.** The decisive + finding: capture never instantiates `AIR`, only calls + `evaluate::` and the LogUp inner fns, whose bounds are only + `IsSubFieldOf + IsField` (transition.rs:354, lookup.rs:1759/1887/1679/1689). + So `SymField` needs **no** `IsFFTField`/`IsPrimeField` — the + `TWO_ADICITY`/root/canonical/from_hex methods are unreachable. The remaining + `IsField` methods that can't be symbolic (`inv`, `div`, `eq`) are either never + reached (`inv`/`div`: no division in any constraint body, verified by grep + + reading templates.rs/cpu.rs/lookup.rs) or handled by a conservative `eq=false` + (the only `eq` use is a runtime zero-skip optimization, lookup.rs:675/790, + which capture must *not* take). `ByteConversion`/`to_subfield_vec` are + bound-satisfier stubs that never run. **Residual risk:** a future constraint + body that calls `.inv()`/`.pow()`/branches on a field value would panic at + capture; mitigate with the loud `unimplemented!()` + a CI check. + +2. **LogUp capturability is HIGH-confidence YES.** The helpers are already + generic over `A: IsSubFieldOf, B` (lookup.rs:1679/1689) and the constraint + inner fns too (lookup.rs:1759/1887); `is_sender` is a compile-time bool, not a + field value (lookup.rs:1780); the only field-value branch is the zero-skip, + handled by `eq=false`. So **no hand-emit of LogUp IR is needed** — this is the + approach's biggest advantage over hand-writing. **Residual risk:** the `eq` + short-circuit means the captured IR always multiplies even by `Const1(0)` + bus-padding; mitigate with a constant-fold peephole (detect `Mul(_,Const(0))`/ + `Add(x,Const(0))` at capture) so the IR matches the optimized path's node + count and the GPU kernel doesn't waste lanes. Low effort, high value. + +3. **Bit-for-bit equivalence of the interpreter vs the boxed path.** The + interpreter reuses the real `FieldElement` arithmetic, so per-op results are + identical; the risk is in *order of operations* (field add/mul are + associative/commutative in value but the existing code's specific fold order + is what the OOD/LDE evaluations must match for the proof to verify). Since we + capture the *exact* call sequence the body executes (recording in evaluation + order), the IR's forward-pass order equals the body's order. **Residual + risk:** the zero-skip fold (lookup.rs:672) changes the *additive grouping* on + zero rows; with `eq=false` we always add, which is value-identical (adding 0). + So equivalence holds. Validate empirically (Section 12). + +4. **Capture-time arena correctness with hash-consing.** A wrong `cse` key + (e.g. not distinguishing D1 vs D3 nodes with the same operands) would alias + nodes of different dimension. Mitigate: include `Dim` in the `cse` key, or + never CSE across dims. Low risk, but must be tested. + +5. **GPU per-thread scratch pressure.** Each thread needs `nodes.len()` + Goldilocks values live. If a table's program is large (hundreds of nodes × + ext3 = hundreds × 24 bytes), register/shared pressure could limit occupancy. + Mitigate: liveness analysis to reuse scratch slots (a node's value is dead + after its last use), or spill to local memory. This is a perf risk, not a + correctness risk, and v1 can keep the accumulation on host. Medium. + +6. **Unknown: exact node counts per table.** Not yet measured, and there are + **33** algebraic constraints across many tables — the largest bodies + (keccak.rs, ecsm.rs, dvrm.rs, mul.rs, commit.rs) are big polynomial circuits + and will dominate node count. ADD/LogUp with hash-consing should be small (low + hundreds), but keccak/ecsm could be thousands of nodes, directly amplifying + risk 5 (GPU per-thread scratch). Resolve by instrumenting `capture_program` + to print `nodes.len()` per table during the differential test, and prioritize + liveness-based scratch reuse for the large tables. + +**Overall feasibility verdict: HIGH.** The SymField approach is sound; the +IsField-contract friction is manageable (the unreachable-method insight is the +crux) and LogUp captures cleanly with zero hand-emit. + +--- + +## 12. Effort estimate & validation strategy + +**Effort (engineer-days, by workstream):** + +* W1 — `SymField`/`SymExt`/`SymId` + arena + hash-consing + IsField/IsSubFieldOf + impls + stubs: **2–3 d**. (Mechanical once the unreachable-method set is fixed.) +* W2 — IR types + serialization + capture context (symbolic frame/uniforms) + + `capture` trait method + adapter/LogUp overrides + `AIR::constraint_program`: + **3–4 d**. (LogUp override is the fiddly part but the inner fns are unchanged.) +* W3 — CPU interpreter (core + prover slot in evaluator.rs + verifier slot) + + feature toggle: **3–4 d**. +* W4 — Differential test harness + peephole constant-fold + fix discrepancies: + **2–3 d**. +* W5 — GPU host seam + universal kernel + math-cuda wrapper + D2H wiring: + **6–10 d** (the largest and riskiest; v1 keeps accumulation on host). + +**Total: ~16–24 engineer-days**, with W1–W4 (~10–14 d) delivering a working, +validated CPU IR interpreter and W5 the GPU kernel. + +**Validation strategy (bit-for-bit, on a real table):** + +1. **Per-row prover diff.** In `evaluate_transitions` (evaluator.rs:79), for each + LDE row run BOTH `air.compute_transition_prover(&ctx, base_a, ext_a)` and + `eval_program(prog, ..., base_b, ext_b)`, and `assert_eq!` the base and ext + buffers element-by-element. Gate behind a `debug-checks`-style feature so it's + on in tests, off in production. Run against the existing test + `cargo test --release -p lambda-vm-prover test_prove_elfs_test_sb_sh_8` + (from project memory) for the CPU table (which exercises ADD/IS_BIT/LogUp). +2. **Per-constraint verifier diff.** At verifier.rs:209 compare + `air.compute_transition(&ctx)` vs the verifier interpreter at the single OOD + point; `assert_eq!` the full `Vec>`. Cheapest oracle (one + point). +3. **End-to-end.** With the interpreter as the live path, run the full + prove→verify test suite; a passing verify is the strongest equivalence check + (the composition poly and FRI depend on every `Cᵢ`). Run across all tables + (CPU, MEMW, LOAD, DECODE, MUL, BRANCH, REGISTER, PAGE, BITWISE, LT, HALT) so + every constraint shape and every `Packing`/`Multiplicity` variant is covered. +4. **GPU diff.** Compare `try_eval_program_gpu` output against the CPU + interpreter output (not the boxed path) element-wise, reusing the + `test-cuda-faults` style harness (gpu_lde.rs:1001) to also exercise the + CPU-fallback path. +5. **Node-count instrumentation.** Log `prog.nodes.len()` per table to size GPU + scratch and confirm hash-consing is effective (risk 4/5/6). + +--- + +## Appendix — full constraint enumeration (verified by reading every body) + +**33 algebraic `impl TransitionConstraint` +structs + 2 framework LogUp `TransitionConstraintEvaluator` structs.** Every one +uses ONLY capturable ops: field `+ - *` and negation, `FieldElement::from(u64/ +i64)`, `one()`/`zero()`, `.clone()`, `get_main_evaluation_element`/ +`get_aux_evaluation_element`, and `to_extension()`. **Zero** uses of `.inv()`, +`.pow()`, `/`, `.sqrt()`, field-value conditionals, or data-dependent loops. +Every conditional branches on **metadata** (`carry_idx`, `is_sender`, kind +enums), never on a field value. Helper fns (`carry_chain`, +`compute_multiplicity_from_step`, `compute_fingerprint_from_step`) contain only +statically-bounded loops. → For SymField this confirms the `IsField` impl needs +only add/sub/mul/neg/from(u64)/one/zero (+ `to_extension`/`embed`) to be +functional; `inv`/`pow`/`div`/real-`ByteConversion` can be `unimplemented!()` +stubs because no body invokes them. + +Algebraic structs (file:line): +- prover/src/constraints/cpu.rs: ProductZeroConstraint:96, Arg2ExclusiveConstraint:132, + MemFlagsBitConstraint:168, RegNotReadIsZeroConstraint:211, Arg2Constraint:266, + RvdEqResConstraint:331, BranchRvdConstraint:422, BranchCondConstraint:464, + NextPcAddConstraint:546 +- prover/src/constraints/templates.rs: IsBitConstraint:80, AddConstraint:470 + (AddOperand/AddLinearTerm with i64 coeffs → `from(i64)` Const nodes) +- prover/src/tables/: BranchConstraint(branch.rs:519), CommitConstraint(commit.rs:837), + Cpu32Constraint(cpu32.rs:645), DvrmConstraint(dvrm.rs:1219), EqXorConstraint(eq.rs:262), + MulZeroConstraint(ec_scalar.rs:291), ConvCarry(ecdas.rs:363), ColIsZero(ecdas.rs:402), + MulZero(ecdas.rs:426), ConvCarry(ecsm.rs:663), ColIsZero(ecsm.rs:698), + CarryBit(ecsm.rs:791), OverflowRequired(ecsm.rs:816), + KeccakAddressNoOverflowConstraint(keccak.rs:503), LoadConstraint(load.rs:572), + LtConstraint(lt.rs:536), MemwConstraint(memw.rs:921), MemwAlignedConstraint(memw_aligned.rs:708), + MemwRegisterMuSumIsBit(memw_register.rs:388), MulConstraint(mul.rs:847), + ShiftConstraint(shift.rs:914), StoreConstraint(store.rs:282) + +LogUp framework (lookup.rs): LookupBatchedTermConstraint:1741 +(`c·fp_a·fp_b − sign_a·m_a·fp_b − sign_b·m_b·fp_a`, degree 3), +LookupAccumulatedConstraint:1868 (running sum over acc col at row 0 AND row 1, +1–2 absorbed interactions, degree 2–3). + +**Multi-kind dispatch structs — IMPORTANT for the IR.** Several "structs" are +really one type that, via a kind-enum matched on **metadata at capture time**, +implements many distinct constraint kinds; each kind must capture to **its own +IR root** (the capture pass iterates the boxed objects, and each boxed object's +`constraint_idx()` already gives it a distinct root slot, so this falls out +naturally — but the plan's `roots` count is driven by `num_transition_constraints`, +not by the 33 struct count): ShiftConstraint(7 kinds), Cpu32Constraint(8), +LtConstraint(6), LoadConstraint(6), MulConstraint(6), DvrmConstraint(11), +BranchConstraint(5), MemwConstraint(3), MemwAlignedConstraint(3), +StoreConstraint(2). The total transition-constraint *count* (and thus root +count) is therefore well above 33; the IR's `roots` vector is sized by +`air.num_transition_constraints()` (traits.rs:286), which the capture pass +already respects by writing `roots[constraint_idx()]`. diff --git a/thoughts/gpu-constraint-eval/roadmap.md b/thoughts/gpu-constraint-eval/roadmap.md new file mode 100644 index 000000000..a59b6dc0c --- /dev/null +++ b/thoughts/gpu-constraint-eval/roadmap.md @@ -0,0 +1,164 @@ +# GPU constraint evaluation — implementation status & execution plan + +**Handoff doc.** Self-contained enough to continue without the originating discussion. +Describes the code as currently built, the decisions already made, and a detailed +checkbox plan to take it to a working, GPU-validated constraint evaluator. + +> **Chosen capture front-end: Plan B (explicit `IrBuilder` + per-constraint `capture()`).** +> Two spikes were built to compare: Plan A (symbolic field) = PR #737 / branch +> `spike/constraint-ir-symfield`; Plan B (builder) = PR #739 / branch +> `spike/constraint-ir-builder`. Both pass the same bit-for-bit diff test and reuse the +> same IR + interpreter. **Plan B is the production direction** (cleaner end-state — no fake +> `IsField`, no thread-local arena, explicit/auditable). Plan A remains as PR #737 for +> reference / comparison. + +--- + +## Goal & motivation + +Evaluate STARK **transition constraints on the GPU**, end-to-end, producing the +composition-polynomial evaluations **on-device**. The point is **data residency**, not +constraint-eval speed (constraints are not the prover bottleneck): once LDE/Merkle/FRI run +on the GPU, evaluating constraints on the CPU forces a D2H round-trip of the (large) LDE +trace, which dominates. Keeping eval on-device removes that transfer. + +## Architecture (decided) + +Capture each table's constraints **once** into a flat, single-field **Goldilocks IR** +(typed `Dim1`=base `u64` / `Dim3`=degree-3 extension `[u64;3]` op-DAG), then **interpret** +that IR on CPU and GPU. One universal kernel; the per-table difference is data. Modeled on +OpenVM's `cuda-backend` (cloned at `others/openvm-stark-backend`, the closest reference — +same FRI-STARK / LDE-quotient protocol; better-matched than SP1). + +### Decisions already made (don't relitigate without reason) +- **Capture front-end = Plan B (explicit builder).** Each constraint implements an + object-safe `Capture { fn capture(&self, &mut IrBuilder) }`, translating its `evaluate` + body into builder calls (`main`/`aux`/`add`/`sub`/`mul`/`neg`/`const_*`/`emit`). No fake + field, no arena, explicit and auditable. (Plan A — a recording "symbolic field" that + captures with zero body edits — was spiked first to validate the IR/interpreter cheaply; + kept as PR #737 + `plan-symbolic-field.md` for reference. We chose B for the cleaner + production end-state.) +- **Backend = interpreter, not codegen** for v1. Codegen stays available later from the same IR. +- **GPU value array = global memory, no register allocation** to start (simplest, works for + all program sizes). Add register allocation only if profiling needs it (Phase 6). +- **Keep the existing boxed CPU path** as the default + differential oracle behind a toggle + (the `capture()` methods are added alongside `evaluate`, which stays). +- **Device field arithmetic already exists** — reuse `crypto/math-cuda/kernels/ext3.cuh` + (`ext3::{add,sub,mul,mul_base}`, where `mul_base` = base×ext) and `kernels/goldilocks.cuh`. + Do **not** build new field math. + +--- + +## Phase 0 — CPU spikes ✅ DONE (two draft PRs; Plan B is the production base) + +Both spikes build, are fmt/clippy clean, and pass a bit-for-bit diff test (capture → +interpret == real `evaluate`, 1000 random rows) for `IsBit`/`Add`/`ProductZero`. They cover +**base-field algebraic constraints only**, single step (offset 0, row 0), main columns only +— no aux, no next-row, no LogUp, no uniforms, not wired into the prover, no GPU. + +**Shared (identical in both):** the IR and the CPU interpreter. +- `ir.rs` — `enum Dim { D1, D3 }`; `enum Op { Const1(u64), Const3([u64;3]), Var { main: bool, offset: u8, row: u8, col: u16 }, Add(u32,u32), Sub(u32,u32), Mul(u32,u32), Neg(u32), Embed(u32) }`; `struct ConstraintProgram { nodes: Vec, dims: Vec, roots: Vec }`. Typing: `(D1,D1)->D1`, any `D3` operand -> `D3` (auto-embed); `Embed: D1->D3`. +- `interp.rs` — `eval_program_base(prog, main_row) -> FieldElement`: forward pass over nodes into a `Value { D1 | D3 }` array, reusing real `FieldElement` arithmetic; resolves `Var{col}` from the row. + +**Plan B — the production base (PR #739, branch `spike/constraint-ir-builder`).** Module +`crypto/stark/src/constraint_ir/`: +- `ir.rs`, `interp.rs` — the shared IR + interpreter above (reused verbatim). +- `builder.rs` — `IrBuilder` (hash-conses nodes on `(Op, Dim)`, dedups base constants by value, dim-join `(D1,D1)->D1` else `D3`, reserves id 0 = `Const1(0)`) + `Expr { id, dim }`. Methods: `main(offset,col)`/`aux(offset,col)`, `const_base`/`const_signed`/`one`, `add`/`sub`/`mul`/`neg`, `emit(constraint_idx, e)`, `finish() -> ConstraintProgram`. +- `mod.rs` — object-safe `pub trait Capture { fn capture(&self, &mut IrBuilder); }`. +- Constraint impls (added **alongside** the unchanged `evaluate`, non-destructive): `IsBitConstraint`, `AddConstraint` (incl. `AddOperand`/`AddLinearTerm` lo/hi-limb mapping with i64 coeffs + the `inv_2_32` constant), `ProductZeroConstraint`. +- `prover/src/tests/constraint_ir_tests.rs` — the diff test. Node counts: product_zero **4**, is_bit_uncond **5**, is_bit_cond **7**, add_carry_0 **14**, add_carry_1 **21** (minimal — the builder only emits leaves for columns actually read). +- Run: `cargo test -p lambda-vm-prover constraint_ir_tests -- --nocapture` + +**Plan A — reference only (PR #737, branch `spike/constraint-ir-symfield`).** Module +`crypto/stark/src/symbolic/` (`sym_field.rs` recording field + capture). Retired the +"can a symbolic type satisfy `IsField`?" question (yes — needs only `IsField` + +`IsSubFieldOf`; capture never builds `AIR`). Not the production path. + +--- + +## Phase 1 — Full Plan-B capture coverage (all constraints, prover + verifier) + +Goal: implement `Capture` for **every** constraint of a real table (all ~33 algebraic + the +2 LogUp), for both prover and verifier, validated on CPU. The GPU runs this same IR, so +completeness/correctness must be nailed here first. + +- [ ] **Extend the IR** (`constraint_ir/ir.rs`): add leaf `Op` variants for the per-proof/per-row uniforms — `Periodic { idx }` (D1), `RapChallenge { idx }` (D3), `AlphaPow { idx }` (D3), `TableOffset` (D3), `Shift { which: u8 }` (D1). `Op::Var` already carries `offset`/`row`/`main` for next-row + aux reads. +- [ ] **Extend `IrBuilder`** (`constraint_ir/builder.rs`): add leaf constructors for the uniforms (`challenge`, `alpha_power`, `periodic`, `table_offset`, `shift`) and `const_ext([u64;3])`; ensure `aux(offset, col)` supports `offset=1` (next row). Make `emit` index `roots` by `constraint_idx` (the spike stores in emit order — switch to indexed for the full per-table program). +- [ ] **`Capture` for the remaining ~30 algebraic constraints** — mechanical translation of each `evaluate`/`compute` body to builder calls. Files: `prover/src/constraints/cpu.rs` (Arg2Exclusive, MemFlagsBit, RegNotReadIsZero, Arg2, RvdEqRes, BranchRvd, BranchCond, NextPcAdd), `prover/src/tables/{branch,commit,cpu32,dvrm,eq,ec_scalar,ecdas,ecsm,keccak,load,lt,memw,memw_aligned,memw_register,mul,shift,store}.rs`. The multi-kind ones (Dvrm 11 / Cpu32 8 / Shift 7 / Lt·Load·Mul 6) are the bulk — their `compute()` loops are statically bounded, so they unroll into builder calls at capture time. +- [ ] **`Capture` for the 2 LogUp constraints (the crux)** — `LookupBatchedTermConstraint` and `LookupAccumulatedConstraint` (`crypto/stark/src/lookup.rs`). Translate their bodies to builder calls: fingerprint = `challenge − Σ alpha_power·col` (mirror `compute_fingerprint_from_step`/`Packing::accumulate_fingerprint_with`), multiplicity (mirror `Multiplicity::evaluate_with`), `is_sender` as a compile-time `add` vs `neg`, the `c·fp_a·fp_b − …` / accumulated formulas. The accumulated one reads aux at offset 0 **and** 1 → use `aux(1, col)`. This is more work than Plan A's auto-capture (Plan A's inner fns were already field-generic) but is explicit and lives in one place. +- [ ] **Per-table program** (`crypto/stark/src/traits.rs`): `fn constraint_program(&self) -> ConstraintProgram` — iterate `self.transition_constraints()`, call `capture` on each into one `IrBuilder`, `roots[constraint_idx()]`, `num_base = num_base_transition_constraints()`. (Requires the object-safe `Capture` to be reachable from the boxed `TransitionConstraintEvaluator` — add `capture` to that trait, which is object-safe and matches the production design.) +- [ ] **Full interpreter** (`constraint_ir/interp.rs`): generalize to `eval_program(prog, inputs) -> (base: Vec>, ext: Vec>)` matching the `compute_transition_prover` contract — resolve all leaf kinds + offsets + aux; Dim1/Dim3 with auto-embed; add a verifier entry (all-D3 frame at the OOD point). +- [ ] **Acceptance test:** for the CPU table + ≥1 LogUp-heavy table, capture the full program, interpret per-row over a real LDE, and `assert_eq!` against `air.compute_transition_prover(...)` bit-for-bit; same for the verifier vs `air.compute_transition(...)` at the OOD point. + +--- + +## Phase 2 — Wire interpreter into prover/verifier (CPU), behind a toggle + +- [ ] Add a `constraint-ir` Cargo feature (or runtime env toggle) in `crypto/stark/Cargo.toml`. +- [ ] Cache the `ConstraintProgram` once in `ConstraintEvaluator::new` (`crypto/stark/src/constraints/evaluator.rs`). +- [ ] In `evaluate_transitions` (same file), behind the toggle, replace the `air.compute_transition_prover(&ctx, base_buf, transition_buf)` call (~line 100) with the IR interpreter; keep the boxed path as default + oracle. Leave the `Σ βᵢ·Cᵢ·Zᵢ⁻¹ + boundary` accumulation untouched. +- [ ] Verifier: same swap at `crypto/stark/src/verifier.rs` (`air.compute_transition`). +- [ ] **Acceptance:** full prove→verify suite passes with the toggle ON, across all tables — `cargo test --release -p lambda-vm-prover` (incl. `test_prove_elfs_*`). This is the **CPU end-to-end** checkpoint; the IR is now proven complete and correct independent of GPU. + +--- + +## Phase 3 — Device field primitives ✅ ALREADY EXIST + +Reuse `crypto/math-cuda/kernels/ext3.cuh` (`ext3::Fe3`, `ext3::{add,sub,mul,mul_base}`) and +`kernels/goldilocks.cuh`. Already used by the GPU FRI/inverse/barycentric/deep kernels. +Only remaining: confirm a `neg` (else `ext3::sub(zero, x)`) and include the header — do it +as part of Phase 4. + +--- + +## Phase 4 — GPU interpreter kernel + +Start **stripped** (mirror OpenVM's `GLOBAL=true` kernel: global-memory value array, no +register allocation, no bit-packed codec). Reference: `others/openvm-stark-backend/crates/cuda-backend/cuda/src/quotient.cu` (`cukernel_quotient`) and `cuda/include/codec.cuh`. + +- [ ] **Device IR layout** (`crypto/stark/src/constraint_ir/` + `crypto/math-cuda`): serialize `ConstraintProgram` to a `#[repr(C)]` flat node array (`{ op_tag: u32, a: u32, b: u32, dim: u32 }`) + a constants table + `roots` + `num_base`. Plus per-proof uniform device buffers (rap challenges, alpha powers, table_offset, periodic columns, shift consts). +- [ ] **Kernel** (`crypto/math-cuda/kernels/constraint_interp.cu` + Rust wrapper in `crypto/math-cuda/src/`): one thread per LDE row (tiled). Forward pass over the node array into a **per-thread value array in global memory** (one slot per node, strided per thread). Resolve `Var{main/aux, offset, col}` from the device-resident LDE columns (`GpuLdeBase`/`GpuLdeExt3` keep-handles from `trace.gpu_main()`/`gpu_aux()`). Dim1 ops via `goldilocks.cuh`, Dim3 via `ext3.cuh` (`mul_base` for D1×D3). **Fused accumulation:** Horner `acc = acc*alpha + Cᵢ` over the constraint roots, then `acc *= inv_zeroifier[row]` → write the composition-poly evaluation. Output stays on device. +- [ ] **Host dispatch** (`crypto/stark/src/constraint_ir/gpu_interp.rs`): `try_eval_program_gpu(...) -> Option<...>` gated on `TypeId::of::() == GoldilocksField && TypeId::of::() == Degree3GoldilocksExtensionField` + a size threshold (mirror `crypto/stark/src/gpu_lde.rs:119-152`). Upload program + uniforms once; launch; leave output device-resident. Fall back to the CPU interpreter / boxed path otherwise. +- [ ] **Pipeline integration:** add a whole-table GPU entry (e.g. `AIR::compute_transitions_batched(lde) -> Option` tried by `evaluate_transitions` before the per-row loop) so the composition-poly evals are produced on-device and feed the existing GPU Merkle commit with **no D2H of the `Cᵢ` matrix**. Reconcile zerofier/boundary accounting with the CPU semantics. +- [ ] **Acceptance:** compiles under `cargo build -p lambda-vm-prover --features cuda`. + +--- + +## Phase 5 — "Working on GPU" (the deliverable) — runs on the CUDA machine + +- [ ] **GPU↔CPU parity test** (extend `prover/tests/cuda_path_integration.rs` / `cuda_fallback_tests.rs`): composition-poly evals on GPU == CPU interpreter == boxed path, per table, on real traces. +- [ ] **End-to-end GPU prove→verify** on a real ELF with `--features cuda`. A passing verify is the goal. +- [ ] **Benchmark** (bench server): prove time with GPU constraints vs CPU constraints — confirm the data-residency win (no LDE D2H for constraint eval). + +--- + +## Phase 6 — Optimizations (only if a profile demands) + +- [ ] **Register allocation** — port OpenVM's transpiler liveness + linear-scan (`others/openvm-stark-backend/crates/cuda-backend/src/transpiler/mod.rs`) to shrink the per-thread value array (local `FpExt[N]` for small programs, smaller global buffer for large ones like Dvrm/Shift/ecsm). +- [ ] **DCE / const-fold peephole** — fold `×Const(0)`/`+Const(0)`; drop dead nodes. +- [ ] **Bit-packed codec** — only if H2D bandwidth shows up (unlikely; the rule stream is tiny and uploaded once). +- [ ] **Selective codegen** — given few-but-large tables, codegen the 1–3 hottest tables (nvcc does register allocation, no per-op dispatch) if interpreter overhead is material. Hybrid: interpreter baseline + codegen the hot ones. + +--- + +## Gotchas / invariants + +- **Single field:** Goldilocks base + degree-3 extension only. The IR's `Dim1`/`Dim3` and + the `ext3.cuh` primitives cover everything. +- **Object safety:** generic methods can't live on `Box`. + Plan B's `Capture` trait is **non-generic** (concrete `IrBuilder`), so it's object-safe; + capture runs once at setup, the per-row hot path only interprets the (data) IR. +- **Don't D2H the `Cᵢ` matrix:** fuse the accumulation in the GPU kernel so only the + (small) composition-poly evaluation crosses on-device into Merkle. +- **LDE columns are already device-resident** (`GpuLdeBase`/`GpuLdeExt3`); read them in place. +- *(Plan A only, not B:)* the symbolic-field path needed `eq → false` to defeat the runtime + zero-skip during capture. Plan B has no such hack — it emits exactly what `capture` writes. + +## Reference material (in-repo) + +- `others/openvm-stark-backend/crates/cuda-backend/` — `src/transpiler/{mod.rs,codec.rs}`, `src/quotient/`, `cuda/src/quotient.cu`, `cuda/include/codec.cuh`. The closest working reference (BabyBear; for Goldilocks the only deltas are 64-bit constants needing a side table, degree-3 ext, and they run all-FpExt with no base/ext split). +- `crypto/stark/src/gpu_lde.rs` — the TypeId+transmute generic→concrete-Goldilocks GPU seam to mirror. +- `thoughts/gpu-constraint-eval/plan-builder-rewrite.md` — the full Plan B design (the chosen approach; Phases 1+ detail its remaining sections). +- `thoughts/gpu-constraint-eval/plan-symbolic-field.md` — Plan A (the reference/comparison spike, PR #737). +- `thoughts/gpu-constraint-eval/README.md` — motivation + the SP1/OpenVM/zisk survey. +- PRs: **#739** (Plan B, production base) · **#737** (Plan A, reference).