Skip to content
111 changes: 87 additions & 24 deletions crypto/stark/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
#[cfg(feature = "instruments")]
use std::time::{Duration, Instant};

use crypto::fiat_shamir::is_transcript::IsStarkTranscript;
use math::fft::bit_reversing::{in_place_bit_reverse_permute, reverse_index};
#[cfg(any(test, feature = "test-utils", feature = "debug-checks"))]
use math::fft::bowers_fft::LayerTwiddles;
use math::fft::errors::FFTError;
use math::fft::two_half_fft::TwoHalfTwiddles;
Expand Down Expand Up @@ -292,11 +291,55 @@ pub(crate) struct LdeTwiddles<F: IsFFTField> {
two_half_inv: TwoHalfTwiddles<F>,
two_half_fwd: TwoHalfTwiddles<F>,
coset_weights: Vec<FieldElement<F>>,
/// Composition half-extension cache, initialized only when the degree-2
/// decomposition path actually runs on CPU.
composition: OnceLock<CompositionLdeTwiddles<F>>,
}

pub(crate) struct CompositionLdeTwiddles<F: IsFFTField> {
/// Inverse twiddles for the g²-coset halves of size `lde_size/2`.
inv: LayerTwiddles<F>,
/// Forward twiddles for the full g-coset of size `lde_size`.
fwd: LayerTwiddles<F>,
/// Weights `g⁻ʲ/(lde_size/2)` for the composition half-extension.
weights: Vec<FieldElement<F>>,
}

impl<F: IsFFTField> CompositionLdeTwiddles<F> {
fn new(half_size: usize, offset: &FieldElement<F>) -> Self {
// Composition half-extension weights: g⁻ʲ / half_size. The constraint-
// quotient halves live on the g²-coset of size `half_size`; the unnormalized
// iFFT yields `n·cⱼ·(g²)ʲ` and these weights turn that into `cⱼ·gʲ` for the
// forward FFT onto the g-coset.
let half_size_fe = FieldElement::<F>::from(half_size as u64);
let inv_half_size_offset = (&half_size_fe * offset)
.inv()
.expect("half_size and coset offset are non-zero");
let half_size_inv = offset * &inv_half_size_offset;
let offset_inv = &half_size_fe * &inv_half_size_offset;
let weights = {
let mut w = Vec::with_capacity(half_size);
let mut cur = half_size_inv;
for _ in 0..half_size {
w.push(cur.clone());
cur = &cur * &offset_inv;
}
w
};

Self {
inv: LayerTwiddles::<F>::new_inverse(half_size.trailing_zeros() as u64)
.expect("valid composition inverse twiddles"),
fwd: LayerTwiddles::<F>::new((half_size * 2).trailing_zeros() as u64)
.expect("valid composition forward twiddles"),
weights,
}
}
}

impl<F: IsFFTField> LdeTwiddles<F> {
/// Construct twiddles and coset weights for a domain of the given size and blowup factor.
fn new(domain: &Domain<F>) -> Self {
pub(crate) fn new(domain: &Domain<F>) -> Self {
let domain_size = domain.interpolation_domain_size;
let lde_size = domain_size * domain.blowup_factor;

Expand Down Expand Up @@ -326,8 +369,22 @@ impl<F: IsFFTField> LdeTwiddles<F> {
two_half_fwd: TwoHalfTwiddles::<F>::new(lde_size.trailing_zeros() as usize, false)
.expect("valid forward two-half twiddles"),
coset_weights,
composition: OnceLock::new(),
}
}

fn composition(&self, domain: &Domain<F>) -> &CompositionLdeTwiddles<F> {
let lde_size = domain.interpolation_domain_size * domain.blowup_factor;
let half_size = lde_size / 2;
debug_assert_eq!(self.coset_weights.len(), domain.interpolation_domain_size);
self.composition
.get_or_init(|| CompositionLdeTwiddles::new(half_size, &domain.coset_offset))
}

#[cfg(test)]
pub(crate) fn has_composition_cache(&self) -> bool {
self.composition.get().is_some()
}
}

/// Number of tables to process concurrently in `multi_prove`.
Expand Down Expand Up @@ -1120,6 +1177,7 @@ pub trait IsStarkProver<
fn decompose_and_extend_d2(
constraint_evaluations: &[FieldElement<FieldExtension>],
domain: &Domain<Field>,
twiddles: &LdeTwiddles<Field>,
) -> Vec<Vec<FieldElement<FieldExtension>>>
where
FieldElement<Field>: AsBytes + Sync + Send,
Expand Down Expand Up @@ -1150,9 +1208,8 @@ pub trait IsStarkProver<
(&two_inv * &sum, &inv_2x[i] * &diff)
});

// Step 3: Extend each part from N evals on g²-coset to 2N evals on g-coset.
// The squared coset offset is g² (= coset_offset²).
let coset_offset_squared = &domain.coset_offset * &domain.coset_offset;
// Step 3: Extend each part from n evals on the g²-coset to 2n evals on the
// g-coset (the full LDE domain).

// GPU fast path: batch both halves into one ext3 LDE call. Requires
// `cuda` feature and a qualifying size. Falls through to CPU when not.
Expand All @@ -1163,43 +1220,46 @@ pub trait IsStarkProver<
return vec![lde_h0, lde_h1];
}

let composition_twiddles = twiddles.composition(domain);
let (lde_h0, lde_h1) = crate::par::join(
|| Self::extend_half_to_lde(&h0_evals, &coset_offset_squared, domain),
|| Self::extend_half_to_lde(&h1_evals, &coset_offset_squared, domain),
|| Self::extend_half_to_lde(&h0_evals, composition_twiddles),
|| Self::extend_half_to_lde(&h1_evals, composition_twiddles),
);
vec![lde_h0, lde_h1]
}

/// Given N evaluations of a degree-<N polynomial on the g²-coset,
/// extend to 2N evaluations on the g-coset (the full LDE domain).
/// This is: iFFT(N, offset=g²) → coefficients → FFT(2N, offset=g).
/// Extend `half_evals` — `n = lde_size/2` evaluations of a degree-`<n` polynomial
/// on the g²-coset — to `2n` evaluations on the g-coset (the full LDE domain).
///
/// Fused: iFFT(n) → coset reshift g²→g → forward FFT(2n) in a single pass with no
/// intermediate coefficient `Polynomial`. The twiddles and the weights `g⁻ʲ/n`
/// (which fold the 1/n normalization and the net g²→g shift) are cached lazily
/// once per domain in [`LdeTwiddles`].
fn extend_half_to_lde(
half_evals: &[FieldElement<FieldExtension>],
squared_offset: &FieldElement<Field>,
domain: &Domain<Field>,
twiddles: &CompositionLdeTwiddles<Field>,
) -> Vec<FieldElement<FieldExtension>>
where
FieldElement<Field>: AsBytes,
FieldElement<FieldExtension>: AsBytes,
{
// iFFT on the N-point squared coset to get coefficients
let poly = Polynomial::interpolate_offset_fft(half_evals, squared_offset)
.expect("iFFT should succeed");
// Evaluate on the full LDE domain (2N points on the g-coset)
evaluate_polynomial_on_lde_domain(
&poly,
domain.blowup_factor,
domain.interpolation_domain_size,
&domain.coset_offset,
debug_assert_eq!(half_evals.len(), twiddles.weights.len());
Polynomial::coset_lde_full::<Field>(
half_evals,
2,
&twiddles.weights,
&twiddles.inv,
&twiddles.fwd,
)
.expect("LDE evaluation should succeed")
.expect("coset extension")
}

/// Returns the result of the second round of the STARK Prove protocol.
fn round_2_compute_composition_polynomial(
air: &dyn AIR<Field = Field, FieldExtension = FieldExtension, PublicInputs = PI>,
pub_inputs: &PI,
domain: &Domain<Field>,
twiddles: &LdeTwiddles<Field>,
round_1_result: &Round1<Field, FieldExtension>,
transition_coefficients: &[FieldElement<FieldExtension>],
boundary_coefficients: &[FieldElement<FieldExtension>],
Expand Down Expand Up @@ -1242,7 +1302,7 @@ pub trait IsStarkProver<
// H₀(x²) = (H(x) + H(-x)) / 2
// H₁(x²) = (H(x) - H(-x)) / (2x)
// On the LDE coset {g·ω^i}, we have -g·ω^i = g·ω^{i+N} since ω^N = -1.
Self::decompose_and_extend_d2(&constraint_evaluations, domain)
Self::decompose_and_extend_d2(&constraint_evaluations, domain, twiddles)
} else if number_of_parts == 1 {
// Degree bound equals trace length: constraint evals are the LDE directly.
vec![constraint_evaluations]
Expand Down Expand Up @@ -2373,6 +2433,7 @@ pub trait IsStarkProver<
&round_1_result,
table_transcript,
domain,
&twiddle_caches[idx],
)?;

#[cfg(feature = "instruments")]
Expand Down Expand Up @@ -2460,6 +2521,7 @@ pub trait IsStarkProver<
round_1_result: &Round1<Field, FieldExtension>,
transcript: &mut (impl IsStarkTranscript<FieldExtension, Field> + Clone),
domain: &Domain<Field>,
twiddles: &LdeTwiddles<Field>,
) -> Result<StarkProof<Field, FieldExtension, PI>, ProvingError>
where
FieldElement<Field>: AsBytes,
Expand Down Expand Up @@ -2500,6 +2562,7 @@ pub trait IsStarkProver<
air,
pub_inputs,
domain,
twiddles,
round_1_result,
&transition_coefficients,
&boundary_coefficients,
Expand Down
43 changes: 42 additions & 1 deletion crypto/stark/src/tests/prover_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
simple_fibonacci::{self, FibonacciAIR, FibonacciPublicInputs},
},
proof::options::ProofOptions,
prover::{IsStarkProver, Prover, evaluate_polynomial_on_lde_domain},
prover::{IsStarkProver, LdeTwiddles, Prover, evaluate_polynomial_on_lde_domain},
test_utils::multi_prove_ram,
tests::domain_cache_stats,
tests::trace_test_helpers::get_trace_evaluations,
Expand All @@ -22,6 +22,42 @@ use math::{

type Felt = FieldElement<GoldilocksField>;

/// The fused composition half-extension (`extend_half_to_lde`) must produce exactly
/// the same g-coset evaluations as the reference it replaces: iFFT on the g²-coset →
/// coefficients → evaluate on the g-coset LDE. Both yield the unique degree-`<n`
/// polynomial evaluated on the g-coset of size `2n`, so they must be byte-identical.
#[test]
fn composition_extend_half_fused_matches_reference() {
use math::fft::bowers_fft::LayerTwiddles;
type F = GoldilocksField;

let g = Felt::from(3); // any non-zero coset offset
let g2 = &g * &g;

for n in [4usize, 8, 16, 32] {
let half: Vec<Felt> = (0..n).map(|i| Felt::from((i as u64) * 7 + 1)).collect();

// Reference: iFFT(g²) → coeffs → evaluate on the g-coset of size 2n.
let poly = Polynomial::interpolate_offset_fft(&half, &g2).unwrap();
let reference = evaluate_polynomial_on_lde_domain(&poly, 2, n, &g).unwrap();

// Fused: coset_lde_full with weights wⱼ = g⁻ʲ / n.
let n_inv = Felt::from(n as u64).inv().unwrap();
let g_inv = g.inv().unwrap();
let mut weights = Vec::with_capacity(n);
let mut w = n_inv;
for _ in 0..n {
weights.push(w);
w = &w * &g_inv;
}
let inv = LayerTwiddles::<F>::new_inverse(n.trailing_zeros() as u64).unwrap();
let fwd = LayerTwiddles::<F>::new((2 * n).trailing_zeros() as u64).unwrap();
let fused = Polynomial::coset_lde_full::<F>(&half, 2, &weights, &inv, &fwd).unwrap();

assert_eq!(reference, fused, "mismatch at n={n}");
}
}

#[test]
fn test_domain_constructor() {
let trace = simple_fibonacci::fibonacci_trace([Felt::from(1), Felt::from(1)], 8);
Expand Down Expand Up @@ -232,10 +268,15 @@ fn test_decompose_and_extend_d2_matches_original() {
.collect();

// --- New path: algebraic decomposition ---
let twiddles = LdeTwiddles::new(&domain);
assert!(!twiddles.has_composition_cache());
let new_result = Prover::<GoldilocksField, GoldilocksField, ()>::decompose_and_extend_d2(
&constraint_evaluations,
&domain,
&twiddles,
);
#[cfg(not(feature = "cuda"))]
assert!(twiddles.has_composition_cache());

assert_eq!(new_result.len(), 2);
assert_eq!(new_result[0].len(), original[0].len());
Expand Down
Loading