Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
cc48f43
add row major batched lde fft primitives
jotabulacios Jun 8, 2026
46c0080
Make LDETraceTable row-major
jotabulacios Jun 8, 2026
a8db4ee
Wire prover to row-major batched LDE
jotabulacios Jun 8, 2026
c55dc05
read trace row major
jotabulacios Jun 8, 2026
59f7c0d
Move the batched-FFT and row-major-LDE unit tests into corresponding …
jotabulacios Jun 8, 2026
4bd73d4
resolve conflicts
jotabulacios Jun 8, 2026
ab786f5
Merge branch 'main' into perf/cpu-lde-rework
jotabulacios Jun 8, 2026
08eaa8b
fix disk-spill EmptyCommitment in row-major LDE
jotabulacios Jun 9, 2026
8dc41fb
Merge branch 'main' into perf/cpu-lde-rework
diegokingston Jun 10, 2026
bd1b9fd
Parallelize trace build and speed up op-dedup bookkeeping
jotabulacios Jun 9, 2026
e6a7297
Skip the identity multiply by alpha_powers[0] in LogUp fingerprints
jotabulacios Jun 10, 2026
26a2bcd
solve conflicts
jotabulacios Jun 16, 2026
665d82a
Remove dead FFT module and gate legacy twiddles
jotabulacios Jun 17, 2026
b69edd3
Harden parallel row-major bit-reverse permute
jotabulacios Jun 18, 2026
47ded68
Guard columns_to_row_major; clarify hasher doc
jotabulacios Jun 18, 2026
14bac6b
solve conflicts
jotabulacios Jun 18, 2026
18d1058
Merge branch 'main' into perf/cpu-lde-rework
diegokingston Jun 18, 2026
56fe834
Merge branch 'main' into perf/cpu-lde-rework
diegokingston Jun 18, 2026
3770047
Merge branch 'main' into perf/cpu-lde-rework
MauroToscano Jun 18, 2026
2191f65
Deduplicate commit_rows_bit_reversed and bit_reverse_vec
jotabulacios Jun 24, 2026
ea32589
solve conflicts
jotabulacios Jun 25, 2026
b5eb9d2
Merge branch 'main' into perf/cpu-lde-rework
MauroToscano Jun 25, 2026
d1bf465
Use default hasher for op dedup maps
MauroToscano Jun 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions crypto/math/src/fft/bit_reversing.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#[cfg(all(feature = "alloc", feature = "parallel"))]
use rayon::prelude::*;

/// In-place bit-reverse permutation algorithm. Requires input length to be a power of two.
pub fn in_place_bit_reverse_permute<E>(input: &mut [E]) {
for i in 0..input.len() {
Expand All @@ -16,3 +19,92 @@ pub fn reverse_index(i: usize, size: u64) -> usize {
i.reverse_bits() >> (usize::BITS - size.trailing_zeros())
}
}

/// Bit-reverse permutation over the *rows* of a flat row-major buffer.
///
/// This is the row-major counterpart of [`in_place_bit_reverse_permute`]:
/// `buf` is read as `n = buf.len() / num_cols` rows of `num_cols` consecutive
/// elements, and row `i` trades places with row `reverse_index(i, n)`. A pair
/// is only swapped when the partner index is the larger of the two, so every
/// pair is exchanged exactly once. Used by the batched row-major FFT/LDE.
///
/// Returns without touching `buf` when `buf` is empty, `num_cols == 0`, or
/// there is at most one row.
///
/// # Preconditions
///
/// * `buf.len()` should be a multiple of `num_cols`. This is only checked by a
///   `debug_assert!`; the row count is computed by integer division.
/// * The row count must be a power of two. This is checked by a runtime
///   `assert!` in every build profile, because the parallel path depends on
///   it for soundness (see below).
///
/// # Panics
///
/// Panics if the row count is greater than one and not a power of two.
///
/// # Parallel path
///
/// With the `parallel` feature and at least 2048 rows, the swaps are issued
/// from a Rayon parallel iterator through a shared raw pointer. Over a
/// power-of-two row count bit-reversal is an involution (`br(br(i)) == i`), so
/// every non-trivial orbit is a 2-cycle `{i, br(i)}`. Keeping only the indices
/// with `br(i) > i` picks one representative per orbit, hence the swapped row
/// pairs are pairwise disjoint and no synchronization between swaps is needed.
/// A non-power-of-two row count would break that disjointness and turn a bad
/// caller's input into a data race, which is why the check is not debug-only.
#[cfg(feature = "alloc")]
pub(crate) fn in_place_bit_reverse_permute_row_major<E: Send + Sync>(
    buf: &mut [E],
    num_cols: usize,
) {
    if buf.is_empty() || num_cols == 0 {
        return;
    }
    debug_assert!(
        buf.len().is_multiple_of(num_cols),
        "buf.len() must be a multiple of num_cols"
    );
    let rows = buf.len() / num_cols;
    if rows <= 1 {
        return;
    }
    // Soundness requirement, not merely a correctness one: the raw-pointer
    // path below is only race-free when bit-reversal is an involution, i.e.
    // when `rows` is a power of two. Fail loudly here instead.
    assert!(rows.is_power_of_two(), "row count must be a power of two");

    #[cfg(feature = "parallel")]
    {
        // Pairs are discovered on the fly inside the parallel loop rather than
        // being collected into an index list first.
        if rows >= 2048 {
            use core::sync::atomic::{AtomicPtr, Ordering};
            let base = AtomicPtr::new(buf.as_mut_ptr());
            (0..rows).into_par_iter().for_each(|row| {
                let partner = reverse_index(row, rows as u64);
                if partner <= row {
                    // Fixed point, or the pair is handled by the smaller index.
                    return;
                }
                let ptr = base.load(Ordering::Relaxed);
                let row_start = row * num_cols;
                let partner_start = partner * num_cols;
                // SAFETY: the two `num_cols`-element ranges starting at
                // `row_start` and `partner_start` are in bounds, disjoint from
                // each other, and never touched by another worker:
                //   1. `rows` is a power of two (asserted above), so
                //      bit-reversal is an involution and every non-trivial
                //      orbit is a 2-cycle `{i, br(i)}`. The `partner > row`
                //      filter keeps one representative per orbit, so the
                //      selected pairs are pairwise disjoint and
                //      `row != partner`. `partner = br(row) < rows`, so both
                //      rows lie inside the buffer.
                //   2. Rows are `num_cols` elements wide and laid out back to
                //      back, so distinct rows never overlap.
                //   3. The `Relaxed` load is sufficient: the pointer is stored
                //      once, before the parallel iterator starts, and is never
                //      written again; handing the job to Rayon's workers
                //      provides the happens-before edge for that store.
                unsafe {
                    let first = core::slice::from_raw_parts_mut(ptr.add(row_start), num_cols);
                    let second =
                        core::slice::from_raw_parts_mut(ptr.add(partner_start), num_cols);
                    first.swap_with_slice(second);
                }
            });
            return;
        }
    }

    // Sequential path: splitting at the start of the higher row hands out two
    // non-overlapping mutable views without any unsafe code.
    for row in 0..rows {
        let partner = reverse_index(row, rows as u64);
        if partner > row {
            let row_start = row * num_cols;
            let (head, tail) = buf.split_at_mut(partner * num_cols);
            head[row_start..row_start + num_cols].swap_with_slice(&mut tail[..num_cols]);
        }
    }
}
2 changes: 2 additions & 0 deletions crypto/math/src/fft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ pub mod bowers_fft;
pub mod errors;
#[cfg(feature = "alloc")]
pub mod roots_of_unity;
#[cfg(feature = "alloc")]
pub mod two_half_fft;

#[cfg(all(test, feature = "alloc"))]
pub(crate) mod test_helpers;
236 changes: 236 additions & 0 deletions crypto/math/src/fft/two_half_fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
//! Cache-blocked, transpose-free batched FFT (port of Plonky3's two-half
//! `Radix2DitParallel::dft_batch`).
//!
//! The flat Bowers DIF streams the whole `n·m` buffer (`n` rows of `m` columns)
//! with large strides at the early layers, thrashing cache for large `n`. This
//! kernel keeps every layer
//! cache-resident by interleaving bit-reversals: bit-reverse → first `mid` DIT
//! layers within `2^mid`-row chunks → bit-reverse → remaining layers within
//! `2^(log_n−mid)`-row chunks → bit-reverse. The bit-reversals turn the
//! large-stride butterflies into chunk-local ones — the cache win the flat
//! Bowers misses. Output is natural order, identical to a per-column
//! single-column Bowers FFT followed by `in_place_bit_reverse_permute_row_major`.
//!
//! Twiddles are precomputed once per size in [`TwoHalfTwiddles`] and reused
//! across calls (the trace LDE invokes this once per direction per domain, and
//! the same domain recurs across tables and rounds).

#[cfg(feature = "alloc")]
use crate::fft::bit_reversing::{
in_place_bit_reverse_permute, in_place_bit_reverse_permute_row_major,
};
#[cfg(feature = "alloc")]
use crate::fft::errors::FFTError;
#[cfg(feature = "alloc")]
use crate::field::{
element::FieldElement,
traits::{IsFFTField, IsField, IsSubFieldOf},
};
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
#[cfg(all(feature = "alloc", feature = "parallel"))]
use rayon::prelude::*;

/// Precomputed twiddles for a size-`2^log_n` two-half FFT in one direction.
///
/// `tw` is the flat geometric array `[ω⁰, ω¹, …, ω^(n/2−1)]` (`ω` the forward
/// root for the forward transform, its inverse for the inverse transform);
/// `bitrev_tw` is its bit-reversal permutation, used by the second-half layers.
/// Build once and share across calls of the same size and direction.
#[cfg(feature = "alloc")]
pub struct TwoHalfTwiddles<F: IsField> {
    /// Base-2 logarithm of the transform size these tables were built for.
    log_n: usize,
    /// Powers of the root in natural order: `[ω⁰, ω¹, …, ω^(n/2−1)]`.
    tw: Vec<FieldElement<F>>,
    /// The contents of `tw` after an in-place bit-reversal permutation.
    bitrev_tw: Vec<FieldElement<F>>,
}

#[cfg(feature = "alloc")]
impl<F: IsFFTField> TwoHalfTwiddles<F> {
    /// Precompute twiddles for a size-`2^log_n` transform. `inverse = true`
    /// selects the (unscaled) inverse transform (uses `ω⁻¹`); the `1/n`
    /// normalization is the caller's responsibility.
    ///
    /// For `log_n == 0` both tables are empty and no root of unity is looked up.
    ///
    /// # Errors
    ///
    /// Returns [`FFTError::InputError`] carrying
    /// * `log_n` when `2^log_n` is not representable as a `usize`;
    /// * `n = 2^log_n` when the field has no primitive root of unity of that
    ///   order, or when the root cannot be inverted.
    pub fn new(log_n: usize, inverse: bool) -> Result<Self, FFTError> {
        // A plain `1usize << log_n` panics in debug builds and wraps in release
        // builds once `log_n >= usize::BITS` (which would yield `n == 1` and a
        // bogus empty table labelled with the oversized `log_n`). Reject such
        // sizes through the error channel instead.
        let n = u32::try_from(log_n)
            .ok()
            .and_then(|shift| 1usize.checked_shl(shift))
            .ok_or(FFTError::InputError(log_n))?;
        let half = n / 2;
        // `omega` is unused when half == 0 (log_n == 0), so skip the lookup.
        let omega = if half == 0 {
            FieldElement::<F>::one()
        } else {
            let fwd = F::get_primitive_root_of_unity(log_n as u64)
                .map_err(|_| FFTError::InputError(n))?;
            if inverse {
                fwd.inv().map_err(|_| FFTError::InputError(n))?
            } else {
                fwd
            }
        };

        // Running product: tw[k] = omega^k for k in 0..half.
        let mut tw: Vec<FieldElement<F>> = Vec::with_capacity(half);
        let mut cur = FieldElement::<F>::one();
        for _ in 0..half {
            tw.push(cur.clone());
            cur = &cur * &omega;
        }
        // The second-half layers read the same powers in bit-reversed order.
        let mut bitrev_tw = tw.clone();
        in_place_bit_reverse_permute(&mut bitrev_tw);

        Ok(Self {
            log_n,
            tw,
            bitrev_tw,
        })
    }
}

/// Decimation-in-time butterfly applied to a pair of rows that share a single
/// twiddle factor.
///
/// For every position `k`, the pair `(lo[k], hi[k])` is replaced by
/// `(lo[k] + tw·hi[k], lo[k] − tw·hi[k])`, where `tw·hi[k]` is the mixed
/// `F × E → E` product. The two rows are walked in lockstep, so only the
/// common prefix is updated if their lengths differ.
#[cfg(feature = "alloc")]
#[inline]
fn dit_butterfly_rows<F, E>(
    lo: &mut [FieldElement<E>],
    hi: &mut [FieldElement<E>],
    tw: &FieldElement<F>,
) where
    F: IsSubFieldOf<E>,
    E: IsField,
{
    lo.iter_mut().zip(hi.iter_mut()).for_each(|(upper, lower)| {
        // Twiddle the lower operand once, then reuse it for both outputs.
        let scaled = tw * &*lower; // F × E → E
        let difference = &*upper - &scaled;
        *upper = &*upper + &scaled;
        *lower = difference;
    });
}

/// One first-half DIT layer, run inside a single row-chunk.
///
/// `chunk` holds rows of `m` elements each. At layer `layer` the chunk is cut
/// into blocks of `2^(layer+1)` rows; within a block, row `j` of the first
/// half is paired with row `j` of the second half and the pair is combined
/// with the twiddle `tw[j · 2^(log_n−1−layer)]`, where `tw` is the flat
/// `[ω^0..ω^(n/2−1)]` table.
#[cfg(feature = "alloc")]
fn dit_first_half_layer<F, E>(
    chunk: &mut [FieldElement<E>],
    m: usize,
    layer: usize,
    log_n: usize,
    tw: &[FieldElement<F>],
) where
    F: IsSubFieldOf<E>,
    E: IsField,
{
    // Number of butterfly pairs per block (= rows in each half of a block).
    let pairs = 1usize << layer;
    // Distance between consecutive twiddles of this layer inside `tw`.
    let tw_stride = 1usize << (log_n - 1 - layer);
    let half_len = pairs * m;

    for block in chunk.chunks_mut(2 * half_len) {
        let (upper, lower) = block.split_at_mut(half_len);
        for pair in 0..pairs {
            let twiddle = &tw[pair * tw_stride];
            let start = pair * m;
            let end = start + m;
            dit_butterfly_rows(&mut upper[start..end], &mut lower[start..end], twiddle);
        }
    }
}

/// One second-half DIT layer, run inside a single row-chunk.
///
/// `chunk` holds rows of `m` elements each and `thread` is the position of
/// this chunk among the second-half chunks. At layer `layer` the chunk is cut
/// into blocks of `2^(log_n−layer)` rows; the whole first half of a block is
/// combined with its second half using one twiddle per block, read from the
/// bit-reversed table: block `b` uses
/// `bitrev_tw[(thread << (layer − mid)) + b]`.
#[cfg(feature = "alloc")]
fn dit_second_half_layer<F, E>(
    chunk: &mut [FieldElement<E>],
    m: usize,
    layer: usize,
    log_n: usize,
    mid: usize,
    thread: usize,
    bitrev_tw: &[FieldElement<F>],
) where
    F: IsSubFieldOf<E>,
    E: IsField,
{
    // Flat length of one half of a block at this layer.
    let half_len = (1usize << (log_n - 1 - layer)) * m;
    // Twiddle index of the first block that belongs to this chunk.
    let tw_base = thread << (layer - mid);

    for (block, tw_idx) in chunk.chunks_mut(2 * half_len).zip(tw_base..) {
        let twiddle = &bitrev_tw[tw_idx];
        let (upper, lower) = block.split_at_mut(half_len);
        dit_butterfly_rows(upper, lower, twiddle);
    }
}

/// Batched, cache-blocked FFT over a row-major buffer, with no transposes.
///
/// `buf` is interpreted as `n` rows of `num_cols` consecutive elements, and
/// `tw` must hold the twiddles precomputed for that same `n` in the desired
/// direction (forward or inverse). On success `buf` contains the natural-order
/// DFT (matching a per-column single-column Bowers FFT followed by
/// `in_place_bit_reverse_permute_row_major`).
///
/// Inverse transforms are NOT scaled by `1/n`; that is the caller's
/// responsibility (e.g. folded into the coset-weight pass of the LDE).
///
/// An empty `buf`, `num_cols == 0`, or a single row is a no-op returning
/// `Ok(())`.
///
/// # Errors
///
/// Returns [`FFTError::InputError`] carrying
/// * `buf.len()` when it is not a multiple of `num_cols`;
/// * the row count when it is not a power of two, or when it does not match
///   the size `tw` was built for.
#[cfg(feature = "alloc")]
pub fn fft_batch_two_half<F, E>(
    buf: &mut [FieldElement<E>],
    num_cols: usize,
    tw: &TwoHalfTwiddles<F>,
) -> Result<(), FFTError>
where
    F: IsFFTField + IsSubFieldOf<E>,
    E: IsField,
    FieldElement<F>: Sync,
    FieldElement<E>: Send + Sync,
{
    // Nothing to transform.
    if buf.is_empty() || num_cols == 0 {
        return Ok(());
    }

    // Shape validation: whole rows, power-of-two row count, matching twiddles.
    let len = buf.len();
    if !len.is_multiple_of(num_cols) {
        return Err(FFTError::InputError(len));
    }
    let rows = len / num_cols;
    let log_rows = rows.trailing_zeros() as usize;
    if !rows.is_power_of_two() || log_rows != tw.log_n {
        return Err(FFTError::InputError(rows));
    }
    if log_rows == 0 {
        // A single row is its own transform.
        return Ok(());
    }

    // Layers [0, split) form the first half, [split, log_rows) the second.
    let split = log_rows.div_ceil(2);
    let natural_tw = tw.tw.as_slice();
    let reversed_tw = tw.bitrev_tw.as_slice();

    // Pass 1: bit-reverse the rows.
    in_place_bit_reverse_permute_row_major(buf, num_cols);

    // Pass 2: first-half layers, each confined to a chunk of 2^split rows.
    // Every chunk runs the same sequence of layers.
    let first_span = (1usize << split) * num_cols;
    #[cfg(feature = "parallel")]
    let first_pass = buf.par_chunks_mut(first_span);
    #[cfg(not(feature = "parallel"))]
    let first_pass = buf.chunks_mut(first_span);
    first_pass.for_each(|rows_chunk| {
        for layer in 0..split {
            dit_first_half_layer::<F, E>(rows_chunk, num_cols, layer, log_rows, natural_tw);
        }
    });

    // Pass 3: bit-reverse the rows again.
    in_place_bit_reverse_permute_row_major(buf, num_cols);

    // Pass 4: second-half layers, each confined to a chunk of
    // 2^(log_rows - split) rows. The chunk index selects the twiddles.
    let second_span = (1usize << (log_rows - split)) * num_cols;
    #[cfg(feature = "parallel")]
    let second_pass = buf.par_chunks_mut(second_span).enumerate();
    #[cfg(not(feature = "parallel"))]
    let second_pass = buf.chunks_mut(second_span).enumerate();
    second_pass.for_each(|(chunk_idx, rows_chunk)| {
        for layer in split..log_rows {
            dit_second_half_layer::<F, E>(
                rows_chunk,
                num_cols,
                layer,
                log_rows,
                split,
                chunk_idx,
                reversed_tw,
            );
        }
    });

    // Pass 5: final bit-reverse restores natural order.
    in_place_bit_reverse_permute_row_major(buf, num_cols);

    Ok(())
}
Loading
Loading