diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index bd0852bb4..eed0e512a 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -1,11 +1,10 @@ use std::marker::PhantomData; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[cfg(feature = "instruments")] use std::time::{Duration, Instant}; use crypto::fiat_shamir::is_transcript::IsStarkTranscript; use math::fft::bit_reversing::{in_place_bit_reverse_permute, reverse_index}; -#[cfg(any(test, feature = "test-utils", feature = "debug-checks"))] use math::fft::bowers_fft::LayerTwiddles; use math::fft::errors::FFTError; use math::fft::two_half_fft::TwoHalfTwiddles; @@ -292,11 +291,55 @@ pub(crate) struct LdeTwiddles { two_half_inv: TwoHalfTwiddles, two_half_fwd: TwoHalfTwiddles, coset_weights: Vec>, + /// Composition half-extension cache, initialized only when the degree-2 + /// decomposition path actually runs on CPU. + composition: OnceLock>, +} + +pub(crate) struct CompositionLdeTwiddles { + /// Inverse twiddles for the g²-coset halves of size `lde_size/2`. + inv: LayerTwiddles, + /// Forward twiddles for the full g-coset of size `lde_size`. + fwd: LayerTwiddles, + /// Weights `g⁻ʲ/(lde_size/2)` for the composition half-extension. + weights: Vec>, +} + +impl CompositionLdeTwiddles { + fn new(half_size: usize, offset: &FieldElement) -> Self { + // Composition half-extension weights: g⁻ʲ / half_size. The constraint- + // quotient halves live on the g²-coset of size `half_size`; the unnormalized + // iFFT yields `n·cⱼ·(g²)ʲ` and these weights turn that into `cⱼ·gʲ` for the + // forward FFT onto the g-coset. + let half_size_fe = FieldElement::::from(half_size as u64); + let inv_half_size_offset = (&half_size_fe * offset) + .inv() + .expect("half_size and coset offset are non-zero"); + let half_size_inv = offset * &inv_half_size_offset; + let offset_inv = &half_size_fe * &inv_half_size_offset; + let weights = { + let mut w = Vec::with_capacity(half_size); + let mut cur = half_size_inv; + for _ in 0..half_size { + w.push(cur.clone()); + cur = &cur * &offset_inv; + } + w + }; + + Self { + inv: LayerTwiddles::::new_inverse(half_size.trailing_zeros() as u64) + .expect("valid composition inverse twiddles"), + fwd: LayerTwiddles::::new((half_size * 2).trailing_zeros() as u64) + .expect("valid composition forward twiddles"), + weights, + } + } } impl LdeTwiddles { /// Construct twiddles and coset weights for a domain of the given size and blowup factor. - fn new(domain: &Domain) -> Self { + pub(crate) fn new(domain: &Domain) -> Self { let domain_size = domain.interpolation_domain_size; let lde_size = domain_size * domain.blowup_factor; @@ -326,8 +369,22 @@ impl LdeTwiddles { two_half_fwd: TwoHalfTwiddles::::new(lde_size.trailing_zeros() as usize, false) .expect("valid forward two-half twiddles"), coset_weights, + composition: OnceLock::new(), } } + + fn composition(&self, domain: &Domain) -> &CompositionLdeTwiddles { + let lde_size = domain.interpolation_domain_size * domain.blowup_factor; + let half_size = lde_size / 2; + debug_assert_eq!(self.coset_weights.len(), domain.interpolation_domain_size); + self.composition + .get_or_init(|| CompositionLdeTwiddles::new(half_size, &domain.coset_offset)) + } + + #[cfg(test)] + pub(crate) fn has_composition_cache(&self) -> bool { + self.composition.get().is_some() + } } /// Number of tables to process concurrently in `multi_prove`. @@ -1120,6 +1177,7 @@ pub trait IsStarkProver< fn decompose_and_extend_d2( constraint_evaluations: &[FieldElement], domain: &Domain, + twiddles: &LdeTwiddles, ) -> Vec>> where FieldElement: AsBytes + Sync + Send, @@ -1150,9 +1208,8 @@ pub trait IsStarkProver< (&two_inv * &sum, &inv_2x[i] * &diff) }); - // Step 3: Extend each part from N evals on g²-coset to 2N evals on g-coset. - // The squared coset offset is g² (= coset_offset²). - let coset_offset_squared = &domain.coset_offset * &domain.coset_offset; + // Step 3: Extend each part from n evals on the g²-coset to 2n evals on the + // g-coset (the full LDE domain). // GPU fast path: batch both halves into one ext3 LDE call. Requires // `cuda` feature and a qualifying size. Falls through to CPU when not. @@ -1163,36 +1220,38 @@ pub trait IsStarkProver< return vec![lde_h0, lde_h1]; } + let composition_twiddles = twiddles.composition(domain); let (lde_h0, lde_h1) = crate::par::join( - || Self::extend_half_to_lde(&h0_evals, &coset_offset_squared, domain), - || Self::extend_half_to_lde(&h1_evals, &coset_offset_squared, domain), + || Self::extend_half_to_lde(&h0_evals, composition_twiddles), + || Self::extend_half_to_lde(&h1_evals, composition_twiddles), ); vec![lde_h0, lde_h1] } - /// Given N evaluations of a degree-], - squared_offset: &FieldElement, - domain: &Domain, + twiddles: &CompositionLdeTwiddles, ) -> Vec> where FieldElement: AsBytes, FieldElement: AsBytes, { - // iFFT on the N-point squared coset to get coefficients - let poly = Polynomial::interpolate_offset_fft(half_evals, squared_offset) - .expect("iFFT should succeed"); - // Evaluate on the full LDE domain (2N points on the g-coset) - evaluate_polynomial_on_lde_domain( - &poly, - domain.blowup_factor, - domain.interpolation_domain_size, - &domain.coset_offset, + debug_assert_eq!(half_evals.len(), twiddles.weights.len()); + Polynomial::coset_lde_full::( + half_evals, + 2, + &twiddles.weights, + &twiddles.inv, + &twiddles.fwd, ) - .expect("LDE evaluation should succeed") + .expect("coset extension") } /// Returns the result of the second round of the STARK Prove protocol. @@ -1200,6 +1259,7 @@ pub trait IsStarkProver< air: &dyn AIR, pub_inputs: &PI, domain: &Domain, + twiddles: &LdeTwiddles, round_1_result: &Round1, transition_coefficients: &[FieldElement], boundary_coefficients: &[FieldElement], @@ -1242,7 +1302,7 @@ pub trait IsStarkProver< // H₀(x²) = (H(x) + H(-x)) / 2 // H₁(x²) = (H(x) - H(-x)) / (2x) // On the LDE coset {g·ω^i}, we have -g·ω^i = g·ω^{i+N} since ω^N = -1. - Self::decompose_and_extend_d2(&constraint_evaluations, domain) + Self::decompose_and_extend_d2(&constraint_evaluations, domain, twiddles) } else if number_of_parts == 1 { // Degree bound equals trace length: constraint evals are the LDE directly. vec![constraint_evaluations] @@ -2373,6 +2433,7 @@ pub trait IsStarkProver< &round_1_result, table_transcript, domain, + &twiddle_caches[idx], )?; #[cfg(feature = "instruments")] @@ -2460,6 +2521,7 @@ pub trait IsStarkProver< round_1_result: &Round1, transcript: &mut (impl IsStarkTranscript + Clone), domain: &Domain, + twiddles: &LdeTwiddles, ) -> Result, ProvingError> where FieldElement: AsBytes, @@ -2500,6 +2562,7 @@ pub trait IsStarkProver< air, pub_inputs, domain, + twiddles, round_1_result, &transition_coefficients, &boundary_coefficients, diff --git a/crypto/stark/src/tests/prover_tests.rs b/crypto/stark/src/tests/prover_tests.rs index 7c8972eeb..318dacb81 100644 --- a/crypto/stark/src/tests/prover_tests.rs +++ b/crypto/stark/src/tests/prover_tests.rs @@ -7,7 +7,7 @@ use crate::{ simple_fibonacci::{self, FibonacciAIR, FibonacciPublicInputs}, }, proof::options::ProofOptions, - prover::{IsStarkProver, Prover, evaluate_polynomial_on_lde_domain}, + prover::{IsStarkProver, LdeTwiddles, Prover, evaluate_polynomial_on_lde_domain}, test_utils::multi_prove_ram, tests::domain_cache_stats, tests::trace_test_helpers::get_trace_evaluations, @@ -22,6 +22,42 @@ use math::{ type Felt = FieldElement; +/// The fused composition half-extension (`extend_half_to_lde`) must produce exactly +/// the same g-coset evaluations as the reference it replaces: iFFT on the g²-coset → +/// coefficients → evaluate on the g-coset LDE. Both yield the unique degree-` = (0..n).map(|i| Felt::from((i as u64) * 7 + 1)).collect(); + + // Reference: iFFT(g²) → coeffs → evaluate on the g-coset of size 2n. + let poly = Polynomial::interpolate_offset_fft(&half, &g2).unwrap(); + let reference = evaluate_polynomial_on_lde_domain(&poly, 2, n, &g).unwrap(); + + // Fused: coset_lde_full with weights wⱼ = g⁻ʲ / n. + let n_inv = Felt::from(n as u64).inv().unwrap(); + let g_inv = g.inv().unwrap(); + let mut weights = Vec::with_capacity(n); + let mut w = n_inv; + for _ in 0..n { + weights.push(w); + w = &w * &g_inv; + } + let inv = LayerTwiddles::::new_inverse(n.trailing_zeros() as u64).unwrap(); + let fwd = LayerTwiddles::::new((2 * n).trailing_zeros() as u64).unwrap(); + let fused = Polynomial::coset_lde_full::(&half, 2, &weights, &inv, &fwd).unwrap(); + + assert_eq!(reference, fused, "mismatch at n={n}"); + } +} + #[test] fn test_domain_constructor() { let trace = simple_fibonacci::fibonacci_trace([Felt::from(1), Felt::from(1)], 8); @@ -232,10 +268,15 @@ fn test_decompose_and_extend_d2_matches_original() { .collect(); // --- New path: algebraic decomposition --- + let twiddles = LdeTwiddles::new(&domain); + assert!(!twiddles.has_composition_cache()); let new_result = Prover::::decompose_and_extend_d2( &constraint_evaluations, &domain, + &twiddles, ); + #[cfg(not(feature = "cuda"))] + assert!(twiddles.has_composition_cache()); assert_eq!(new_result.len(), 2); assert_eq!(new_result[0].len(), original[0].len());