diff --git a/crypto/math-cuda/kernels/keccak.cu b/crypto/math-cuda/kernels/keccak.cu index c22bc4d05..86c0e519e 100644 --- a/crypto/math-cuda/kernels/keccak.cu +++ b/crypto/math-cuda/kernels/keccak.cu @@ -317,6 +317,40 @@ extern "C" __global__ void keccak_fri_leaves_ext3( // children: nodes[parent_begin + n_pairs .. parent_begin + 3 * n_pairs] // parents: nodes[parent_begin .. parent_begin + n_pairs] // +// --------------------------------------------------------------------------- +// Row-major base leaf hashing. +// +// Input layout: data[row * m + col] for `num_rows` rows and `m` columns. +// For leaf `tid`, reads the bit-reversed row `br(tid)` — a contiguous slice +// of `m` elements starting at data[br * m]. Coalesced when multiple threads +// in the same warp process consecutive `tid` values (they read non-overlapping +// rows, each a contiguous block of m u64s in order). +// --------------------------------------------------------------------------- +extern "C" __global__ void keccak256_leaves_base_row_major( + const uint64_t *data, + uint64_t m, + uint64_t num_rows, + uint64_t log_num_rows, + uint8_t *hashed_leaves_out) +{ + uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= num_rows) return; + uint64_t br = __brevll(tid) >> (64 - log_num_rows); + const uint64_t *row = data + br * m; + + uint64_t st[25]; + #pragma unroll + for (int i = 0; i < 25; ++i) st[i] = 0; + + uint32_t rate_pos = 0; + for (uint64_t c = 0; c < m; ++c) { + uint64_t canon = goldilocks::canonical(row[c]); + uint64_t lane = bswap64(canon); + absorb_lane(st, rate_pos, lane); + } + finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32); +} + // Each thread hashes one child pair → one parent. Keccak-256 of the // concatenation of two 32-byte siblings, identical to // `FieldElementVectorBackend::hash_new_parent` on host. diff --git a/crypto/math-cuda/kernels/ntt.cu b/crypto/math-cuda/kernels/ntt.cu index cf5e1df2c..13c1af688 100644 --- a/crypto/math-cuda/kernels/ntt.cu +++ b/crypto/math-cuda/kernels/ntt.cu @@ -285,3 +285,129 @@ extern "C" __global__ void ntt_dit_8_levels(uint64_t *x, // Store back to the remapped row. x[row] = tile[threadIdx.x]; } + +// ============================================================================ +// ROW-MAJOR BATCHED KERNELS +// +// Data layout: data[row * m + col] for n rows and m columns. +// threadIdx.x = column index → consecutive threads access consecutive columns +// of the same row → coalesced global memory access. +// Twiddle factors depend only on the butterfly position, not the column → +// one twiddle load is broadcast across the entire warp. +// ============================================================================ + +// Bit-reverse permute rows: swap row `row` with row `br(row)`. +// Grid: gridDim.x = ceil(m / 256), gridDim.y = min(n, 65535). +// Grid-stride loop over rows so a capped gridDim.y covers all n rows. +extern "C" __global__ void bit_reverse_row_major(uint64_t *data, + uint64_t n, + uint64_t log_n, + uint64_t m) +{ + uint64_t col = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (col >= m) return; + for (uint64_t row = blockIdx.y; row < n; row += gridDim.y) { + uint64_t rev = __brevll(row) >> (64 - log_n); + if (row < rev) { + uint64_t tmp = data[row * m + col]; + data[row * m + col] = data[rev * m + col]; + data[rev * m + col] = tmp; + } + } +} + +// One DIT butterfly level on row-major data. +// Grid: gridDim.x = ceil(m / blockDim.x), gridDim.y = min(ceil(n/2 / blockDim.y), 65535). +// blockDim.x covers columns (coalescing), blockDim.y covers butterfly pairs. +// Grid-stride loop over butterfly-pair tiles so capped gridDim.y covers all n/2 pairs. +extern "C" __global__ void ntt_dit_level_row_major(uint64_t *data, + const uint64_t *tw, + uint64_t n, + uint64_t log_n, + uint64_t level, + uint64_t m) +{ + uint64_t col = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + uint64_t n_half = n >> 1; + if (col >= m) return; + + uint64_t half = 1ULL << level; + uint64_t block_size = half << 1; + + for (uint64_t bfly_base = blockIdx.y * blockDim.y; + bfly_base < n_half; + bfly_base += (uint64_t)gridDim.y * blockDim.y) { + uint64_t butterfly = bfly_base + threadIdx.y; + if (butterfly >= n_half) break; + + uint64_t block_idx = butterfly >> level; + uint64_t k = butterfly & (half - 1); + uint64_t i0 = block_idx * block_size + k; + uint64_t i1 = i0 + half; + + // Same twiddle for all columns at this butterfly position (broadcast). + uint64_t w = tw[k << (log_n - level - 1)]; + + uint64_t u = data[i0 * m + col]; + uint64_t v = mul(w, data[i1 * m + col]); + data[i0 * m + col] = add(u, v); + data[i1 * m + col] = sub(u, v); + } +} + +// Pointwise multiply row-major: data[row * m + col] *= weights[row]. +// One weight per row, broadcast across all m columns. +// Grid: gridDim.x = ceil(m / 256), gridDim.y = min(n, 65535). +// Grid-stride loop over rows. +extern "C" __global__ void pointwise_mul_row_major(uint64_t *data, + const uint64_t *weights, + uint64_t n, + uint64_t m) +{ + uint64_t col = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (col >= m) return; + for (uint64_t row = blockIdx.y; row < n; row += gridDim.y) + data[row * m + col] = mul(data[row * m + col], weights[row]); +} + +// ── Row-major → column-major transpose (for GpuLdeBase handle) ─────────────── +// +// Converts the row-major LDE output to the column-major layout that downstream +// GPU kernels (DEEP, barycentric) require for the device handle. +// +// src[r * cols + c] → dst[c * out_stride + r] +// +// Grid: gridDim.x = ceil(cols/32), gridDim.y = min(ceil(rows/32), 65535). +// Grid-strides over row tiles so all rows are covered when rows > 65535*32. + +#define MTILE 32 +#define MTILE_P (MTILE + 1) + +extern "C" __global__ void matrix_transpose_strided( + const uint64_t *__restrict__ src, + uint64_t *__restrict__ dst, + uint32_t rows, + uint32_t cols, + uint64_t out_stride) +{ + __shared__ uint64_t tile[MTILE][MTILE_P]; + + for (uint32_t row_base = blockIdx.y * MTILE; row_base < rows; + row_base += gridDim.y * MTILE) { + uint32_t x = blockIdx.x * MTILE + threadIdx.x; + uint32_t y = row_base + threadIdx.y; + + if (x < cols && y < rows) + tile[threadIdx.y][threadIdx.x] = src[(uint64_t)y * cols + x]; + + __syncthreads(); + + uint32_t tx = row_base + threadIdx.x; + uint32_t ty = blockIdx.x * MTILE + threadIdx.y; + + if (tx < rows && ty < cols) + dst[(uint64_t)ty * out_stride + tx] = tile[threadIdx.x][threadIdx.y]; + + __syncthreads(); + } +} diff --git a/crypto/math-cuda/src/device.rs b/crypto/math-cuda/src/device.rs index 17e2f9f82..e9db7657e 100644 --- a/crypto/math-cuda/src/device.rs +++ b/crypto/math-cuda/src/device.rs @@ -140,8 +140,14 @@ pub struct Backend { pub ntt_dit_8_levels_batched: CudaFunction, pub pointwise_mul_batched: CudaFunction, pub scalar_mul_batched: CudaFunction, + // row-major NTT kernels + pub bit_reverse_row_major: CudaFunction, + pub ntt_dit_level_row_major: CudaFunction, + pub pointwise_mul_row_major: CudaFunction, + pub matrix_transpose_strided: CudaFunction, // keccak.ptx + pub keccak256_leaves_base_row_major: CudaFunction, pub keccak256_leaves_base_batched: CudaFunction, pub keccak256_leaves_ext3_batched: CudaFunction, pub keccak_comp_poly_leaves_ext3: CudaFunction, @@ -237,6 +243,12 @@ impl Backend { ntt_dit_8_levels_batched: ntt.load_function("ntt_dit_8_levels_batched")?, pointwise_mul_batched: ntt.load_function("pointwise_mul_batched")?, scalar_mul_batched: ntt.load_function("scalar_mul_batched")?, + bit_reverse_row_major: ntt.load_function("bit_reverse_row_major")?, + ntt_dit_level_row_major: ntt.load_function("ntt_dit_level_row_major")?, + pointwise_mul_row_major: ntt.load_function("pointwise_mul_row_major")?, + matrix_transpose_strided: ntt.load_function("matrix_transpose_strided")?, + keccak256_leaves_base_row_major: keccak + .load_function("keccak256_leaves_base_row_major")?, keccak256_leaves_base_batched: keccak.load_function("keccak256_leaves_base_batched")?, keccak256_leaves_ext3_batched: keccak.load_function("keccak256_leaves_ext3_batched")?, keccak_comp_poly_leaves_ext3: keccak.load_function("keccak_comp_poly_leaves_ext3")?, diff --git a/crypto/math-cuda/src/lde.rs b/crypto/math-cuda/src/lde.rs index ee5dc3fce..8c319f764 100644 --- a/crypto/math-cuda/src/lde.rs +++ b/crypto/math-cuda/src/lde.rs @@ -216,6 +216,384 @@ fn launch_pointwise_mul_batched( Ok(()) } +// ── Row-major NTT helpers ──────────────────────────────────────────────────── + +fn launch_bit_reverse_row_major( + stream: &CudaStream, + be: &Backend, + buf: &mut CudaSlice, + n: u64, + log_n: u64, + m: u64, +) -> Result<()> { + let cfg = LaunchConfig { + grid_dim: ((m as u32).div_ceil(256), (n as u32).min(65535), 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.bit_reverse_row_major) + .arg(buf) + .arg(&n) + .arg(&log_n) + .arg(&m) + .launch(cfg)?; + } + Ok(()) +} + +fn launch_pointwise_mul_row_major( + stream: &CudaStream, + be: &Backend, + buf: &mut CudaSlice, + weights: &CudaSlice, + n: u64, + m: u64, +) -> Result<()> { + let cfg = LaunchConfig { + grid_dim: ((m as u32).div_ceil(256), (n as u32).min(65535), 1), + block_dim: (256, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.pointwise_mul_row_major) + .arg(buf) + .arg(weights) + .arg(&n) + .arg(&m) + .launch(cfg)?; + } + Ok(()) +} + +fn run_row_major_ntt_body( + stream: &CudaStream, + be: &Backend, + buf: &mut CudaSlice, + tw: &CudaSlice, + n: u64, + log_n: u64, + m: u64, +) -> Result<()> { + let col_tile: u32 = 32.min(m as u32); + let row_tile: u32 = (256 / col_tile).max(1); + for level in 0..log_n { + let cfg = LaunchConfig { + grid_dim: ( + (m as u32).div_ceil(col_tile), + ((n >> 1) as u32).div_ceil(row_tile).min(65535), + 1, + ), + block_dim: (col_tile, row_tile, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.ntt_dit_level_row_major) + .arg(&mut *buf) + .arg(tw) + .arg(&n) + .arg(&log_n) + .arg(&level) + .arg(&m) + .launch(cfg)?; + } + } + Ok(()) +} + +fn launch_keccak_base_row_major( + stream: &CudaStream, + be: &Backend, + buf: &CudaSlice, + m: u64, + num_rows: u64, + log_num_rows: u64, + leaves_out: &mut cudarc::driver::CudaViewMut<'_, u8>, +) -> Result<()> { + // The keccak kernel is register-heavy (Keccak state `uint64_t st[25]`), so it + // must launch with the keccak-tuned block dim (128). `for_num_elems` uses 1024 + // threads/block, which exceeds the per-block register budget and fails the + // launch with CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES — silently dropping the whole + // R1 GPU path to the CPU fallback (no device handle for rounds 2-4). + let cfg = keccak_launch_cfg(num_rows); + unsafe { + stream + .launch_builder(&be.keccak256_leaves_base_row_major) + .arg(buf) + .arg(&m) + .arg(&num_rows) + .arg(&log_num_rows) + .arg(leaves_out) + .launch(cfg)?; + } + Ok(()) +} + +/// Transpose row-major `lde_size × cols` → column-major with stride `lde_size`, +/// returning the new device buffer. Used to convert the row-major LDE output to +/// the column-major layout expected by downstream GPU kernels (DEEP, barycentric). +/// No synchronize — callers on the same stream are ordered; other streams must +/// synchronize themselves. +fn launch_row_to_col_major( + stream: &Arc, + be: &Backend, + src: &CudaSlice, + lde_size: usize, + cols: usize, + lde_u64: u64, +) -> Result> { + let mut dst = stream.alloc_zeros::(lde_size * cols)?; + let cfg = LaunchConfig { + grid_dim: ( + (cols as u32).div_ceil(32), + (lde_size as u32).div_ceil(32).min(65535), + 1, + ), + block_dim: (32, 32, 1), + shared_mem_bytes: 0, + }; + unsafe { + stream + .launch_builder(&be.matrix_transpose_strided) + .arg(src) + .arg(&mut dst) + .arg(&(lde_size as u32)) + .arg(&(cols as u32)) + .arg(&lde_u64) + .launch(cfg)?; + } + Ok(dst) +} + +/// Row-major LDE + Keccak + Merkle, all on-device. +/// +/// Input: `row_major` is a flat `n * m` slice in row-major order. +/// Returns (merkle_nodes, GpuLdeBase handle, row-major LDE Vec). +/// Single H2D, row-major NTT, single D2H — no CPU-side extract or transpose. +/// The returned handle is column-major (as required by downstream GPU kernels): +/// after D2H, `buf` is transposed on-device to column-major for the handle. +pub fn coset_lde_row_major_with_merkle_tree_keep( + row_major: &[u64], + n: usize, + m: usize, + blowup_factor: usize, + weights: &[u64], +) -> Result<(Vec, GpuLdeBase, Vec)> { + assert_eq!(row_major.len(), n * m); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + let lde_size = n * blowup_factor; + assert_u32_domain(lde_size, "coset_lde_row_major lde_size"); + + let nodes_bytes = KeccakCommit::FullTree.total_nodes_bytes(lde_size); + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let m_u64 = m as u64; + + let be = backend()?; + let stream = be.next_stream(); + + // H2D into a zeroed lde_size*m buffer; only the first n*m rows carry data, + // the remainder are already zero (zero-padding for LDE expansion). + let mut buf = stream.alloc_zeros::(lde_size * m)?; + stream.memcpy_htod(row_major, &mut buf.slice_mut(0..n * m))?; + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + // iNTT: bit-reverse rows → per-level DIT. + launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, n_u64, log_n, m_u64)?; + run_row_major_ntt_body( + stream.as_ref(), + be, + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + m_u64, + )?; + + // Coset weights: one weight per row, broadcast across all m columns. + launch_pointwise_mul_row_major(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, m_u64)?; + + // Forward NTT at lde_size. + launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, lde_u64, log_lde, m_u64)?; + run_row_major_ntt_body( + stream.as_ref(), + be, + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + m_u64, + )?; + + // Keccak + Merkle on-device. + let mut nodes_dev = unsafe { stream.alloc::(nodes_bytes) }?; + let leaves_offset = KeccakCommit::FullTree.leaves_offset_bytes(lde_size); + { + let mut leaves_view = nodes_dev.slice_mut(leaves_offset..leaves_offset + lde_size * 32); + launch_keccak_base_row_major( + stream.as_ref(), + be, + &buf, + m_u64, + lde_u64, + log_lde, + &mut leaves_view, + )?; + } + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; + + // D2H the row-major LDE first (before the handle transpose). Release the + // staging lock before the Merkle nodes transfer to minimise lock contention. + let lde_out = { + let staging_slot = be.pinned_staging(); + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(lde_size * m, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(lde_size * m) }; + stream.memcpy_dtoh(&buf, pinned)?; + stream.synchronize()?; + let out = pinned[..lde_size * m].to_vec(); + drop(staging); + out + }; + + let mut nodes_out = vec![0u8; nodes_bytes]; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, &mut nodes_out)?; + + // Transpose row-major buf → column-major for the handle. Downstream kernels + // (DEEP, barycentric) expect buf[c * lde_size + r] (column-major). + let col_major_dev = launch_row_to_col_major(&stream, be, &buf, lde_size, m, lde_u64)?; + // Synchronize before returning: the handle crosses stream boundaries — downstream + // consumers call be.next_stream() and read handle.buf on a different stream. + // Without this, a barycentric or DEEP kernel can start before the transpose finishes. + stream.synchronize()?; + + let handle = GpuLdeBase { + buf: Arc::new(col_major_dev), + m, + lde_size, + }; + Ok((nodes_out, handle, lde_out)) +} + +/// Row-major ext3 LDE + Keccak + Merkle, all on-device. +/// +/// `Fp3` is `[u64; 3]` in memory, so row-major ext3 with `m` ext3 columns is +/// identical to row-major base-field with `m3 = m * 3`. The same row-major NTT +/// and Keccak kernels handle all three components simultaneously — no extra +/// de-interleave step. +/// +/// Input: `row_major` is `n * m` ext3 elements as flat `n * m * 3` u64s +/// (element [row][col] components k=0,1,2 at `row_major[(row*m + col)*3 + k]`). +/// Returns (merkle_nodes, GpuLdeExt3 handle, row-major ext3 LDE Vec). +pub fn coset_lde_ext3_row_major_with_merkle_tree_keep( + row_major: &[u64], + n: usize, + m: usize, + blowup_factor: usize, + weights: &[u64], +) -> Result<(Vec, GpuLdeExt3, Vec)> { + let m3 = m * 3; + assert_eq!(row_major.len(), n * m3); + assert!(n.is_power_of_two()); + assert_eq!(weights.len(), n); + assert!(blowup_factor.is_power_of_two()); + let lde_size = n * blowup_factor; + assert_u32_domain(lde_size, "coset_lde_ext3_row_major lde_size"); + + let nodes_bytes = KeccakCommit::FullTree.total_nodes_bytes(lde_size); + let log_n = n.trailing_zeros() as u64; + let log_lde = lde_size.trailing_zeros() as u64; + let n_u64 = n as u64; + let lde_u64 = lde_size as u64; + let m3_u64 = m3 as u64; + + let be = backend()?; + let stream = be.next_stream(); + + let mut buf = stream.alloc_zeros::(lde_size * m3)?; + stream.memcpy_htod(row_major, &mut buf.slice_mut(0..n * m3))?; + + let inv_tw = be.inv_twiddles_for(log_n)?; + let fwd_tw = be.fwd_twiddles_for(log_lde)?; + let weights_dev = stream.clone_htod(weights)?; + + // iNTT + coset weights + forward NTT — same row-major kernels as base-field + // but with m3 = m*3 (all 3 components processed simultaneously). + launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, n_u64, log_n, m3_u64)?; + run_row_major_ntt_body( + stream.as_ref(), + be, + &mut buf, + inv_tw.as_ref(), + n_u64, + log_n, + m3_u64, + )?; + launch_pointwise_mul_row_major(stream.as_ref(), be, &mut buf, &weights_dev, n_u64, m3_u64)?; + launch_bit_reverse_row_major(stream.as_ref(), be, &mut buf, lde_u64, log_lde, m3_u64)?; + run_row_major_ntt_body( + stream.as_ref(), + be, + &mut buf, + fwd_tw.as_ref(), + lde_u64, + log_lde, + m3_u64, + )?; + + // Keccak: same row-major kernel — each leaf reads m3 consecutive u64s (= m ext3 elements). + let mut nodes_dev = unsafe { stream.alloc::(nodes_bytes) }?; + let leaves_offset = KeccakCommit::FullTree.leaves_offset_bytes(lde_size); + { + let mut leaves_view = nodes_dev.slice_mut(leaves_offset..leaves_offset + lde_size * 32); + launch_keccak_base_row_major( + stream.as_ref(), + be, + &buf, + m3_u64, + lde_u64, + log_lde, + &mut leaves_view, + )?; + } + crate::merkle::build_inner_tree_levels(stream.as_ref(), be, &mut nodes_dev, lde_size)?; + + let lde_out = { + let staging_slot = be.pinned_staging(); + let mut staging = staging_slot.lock().unwrap(); + staging.ensure_capacity(lde_size * m3, &be.ctx)?; + let pinned = unsafe { staging.as_mut_slice(lde_size * m3) }; + stream.memcpy_dtoh(&buf, pinned)?; + stream.synchronize()?; + let out = pinned[..lde_size * m3].to_vec(); + drop(staging); + out + }; + + let mut nodes_out = vec![0u8; nodes_bytes]; + d2h_bytes_via_pinned_hashes(&stream, be, &nodes_dev, &mut nodes_out)?; + + let col_major_dev = launch_row_to_col_major(&stream, be, &buf, lde_size, m3, lde_u64)?; + stream.synchronize()?; + + let handle = GpuLdeExt3 { + buf: Arc::new(col_major_dev), + m, + lde_size, + }; + Ok((nodes_out, handle, lde_out)) +} + /// Handle to a base-field LDE kept live on device after R1 commit. /// Layout: `m` columns, each `lde_size` u64s, column `c` at byte offset /// `c * lde_size * 8` within `buf`. Freed when `buf` Arc drops. @@ -644,29 +1022,6 @@ pub fn coset_lde_batch_base_into_with_merkle_tree( .map(|_| ()) } -/// Fused LDE + leaf-hash + Merkle tree build. If `keep_device_buf` is true, -/// returns an `Arc>` wrapping the LDE device buffer so callers -/// (R2–R4 GPU paths) can reuse the LDE without a re-H2D. -pub fn coset_lde_batch_base_into_with_merkle_tree_keep( - columns: &[&[u64]], - blowup_factor: usize, - weights: &[u64], - outputs: &mut [&mut [u64]], - merkle_nodes_out: &mut [u8], -) -> Result { - let opt = coset_lde_batch_base_into_with_merkle_tree_inner( - columns, - blowup_factor, - weights, - outputs, - merkle_nodes_out, - KeccakCommit::FullTree, - true, - )?; - let handle = opt.expect("keep_device_buf=true must return Some"); - Ok(handle) -} - fn coset_lde_batch_base_into_with_merkle_tree_inner( columns: &[&[u64]], blowup_factor: usize, @@ -876,30 +1231,6 @@ pub fn coset_lde_batch_ext3_into_with_merkle_tree( .map(|_| ()) } -/// Ext3 variant of [`coset_lde_batch_base_into_with_merkle_tree_keep`] — -/// returns an `Arc>` handle to the de-interleaved LDE device -/// buffer. -pub fn coset_lde_batch_ext3_into_with_merkle_tree_keep( - columns: &[&[u64]], - n: usize, - blowup_factor: usize, - weights: &[u64], - outputs: &mut [&mut [u64]], - merkle_nodes_out: &mut [u8], -) -> Result { - let opt = coset_lde_batch_ext3_into_with_merkle_tree_inner( - columns, - n, - blowup_factor, - weights, - outputs, - merkle_nodes_out, - KeccakCommit::FullTree, - true, - )?; - Ok(opt.expect("keep_device_buf=true must return Some")) -} - #[allow(clippy::too_many_arguments)] fn coset_lde_batch_ext3_into_with_merkle_tree_inner( columns: &[&[u64]], diff --git a/crypto/math-cuda/tests/barycentric_cpu_gpu_parity.rs b/crypto/math-cuda/tests/barycentric_cpu_gpu_parity.rs new file mode 100644 index 000000000..1b85494bb --- /dev/null +++ b/crypto/math-cuda/tests/barycentric_cpu_gpu_parity.rs @@ -0,0 +1,285 @@ +//! GPU barycentric kernels (`barycentric_base` / `barycentric_ext3`) must produce +//! the same OOD evaluation as the CPU formula in `get_trace_evaluations_from_lde` +//! (`interpolate_coset_eval_ext_with_g_n_inv`). Covers base field and ext3. +//! +//! Note: `barycentric_ext3` expects the pre-strided input in component-major layout +//! (`[all-a, all-b, all-c]`), not interleaved. Passing interleaved data produces +//! wrong results without any error — the test catches this silently. + +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::field::traits::{IsFFTField, IsPrimeField}; +use math::polynomial::barycentric_inv_denoms; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; + +type Fp = FieldElement; +type Fp3 = FieldElement; + +fn rand_fp(rng: &mut ChaCha8Rng) -> Fp { + Fp::from_raw(rng.r#gen::()) +} +fn rand_fp3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([rand_fp(rng), rand_fp(rng), rand_fp(rng)]) +} + +/// Build coset points `[g * ω^0, g * ω^1, ..., g * ω^{n-1}]` from a +/// coset offset `g` and the primitive root `ω` of the trace domain. +fn coset_points(n: usize, coset_offset: u64) -> Vec { + let log_n = n.trailing_zeros() as u64; + let omega = GoldilocksField::get_primitive_root_of_unity(log_n).unwrap(); + let g = Fp::from_raw(coset_offset); + let mut pts = Vec::with_capacity(n); + let mut cur = g; + for _ in 0..n { + pts.push(cur); + cur = &cur * ω + } + pts +} + +/// CPU barycentric eval for a single base-field column. +/// Mirrors the prover's `get_trace_evaluations_from_lde` inner loop: +/// col_scale[i] = point[i] * inv_denom[i] +/// sum = Σ lde[i*blowup] * col_scale[i] (Fp × Fp3 → Fp3) +/// result = (n_inv * g_n_inv) * (z^N - g^N) * sum +fn cpu_barycentric_base( + lde_col: &[Fp], + blowup: usize, + coset_pts: &[Fp], + z: &Fp3, + coset_offset: &Fp, +) -> Fp3 { + let n = coset_pts.len(); + let n_inv = Fp::from(n as u64).inv().unwrap(); + let g_n = coset_offset.pow(n as u64); + let g_n_inv = g_n.inv().unwrap(); + let z_pow_n = z.pow(n as u64); + + let inv_denoms = + barycentric_inv_denoms::(z, coset_pts); + + let col_scale: Vec = coset_pts + .iter() + .zip(inv_denoms.iter()) + .map(|(pt, inv_d)| pt * inv_d) + .collect(); + + let sum = col_scale + .iter() + .enumerate() + .fold(Fp3::from(0u64), |acc, (i, scale)| { + acc + &lde_col[i * blowup] * scale + }); + + let vanishing = z_pow_n.sub_subfield(&g_n); + let scalar = &n_inv * &g_n_inv; + &scalar * &(&vanishing * &sum) +} + +/// GPU barycentric eval for a single column via `barycentric_base` kernel, +/// followed by the host-side vanishing scaling that the prover applies. +fn gpu_barycentric_base( + lde_col: &[Fp], + blowup: usize, + coset_pts: &[Fp], + z: &Fp3, + coset_offset: &Fp, +) -> Fp3 { + let n = coset_pts.len(); + + let n_inv = Fp::from(n as u64).inv().unwrap(); + let g_n = coset_offset.pow(n as u64); + let g_n_inv = g_n.inv().unwrap(); + let z_pow_n = z.pow(n as u64); + + let inv_denoms_fp3 = + barycentric_inv_denoms::(z, coset_pts); + + // Pack for GPU: coset_points as u64, inv_denoms interleaved ext3 u64. + let pts_u64: Vec = coset_pts.iter().map(|p| *p.value()).collect(); + let inv_u64: Vec = inv_denoms_fp3 + .iter() + .flat_map(|e| { + [ + *e.value()[0].value(), + *e.value()[1].value(), + *e.value()[2].value(), + ] + }) + .collect(); + + // Pre-strided column (trace points at stride blowup). + let pre_strided: Vec = (0..n).map(|i| *lde_col[i * blowup].value()).collect(); + + let raw = math_cuda::barycentric::barycentric_base(&pre_strided, n, &pts_u64, &inv_u64, n, 1) + .expect("GPU barycentric_base"); + + // raw is 3 u64s (ext3 interleaved): the unscaled sum S. + // The prover then applies: result = scalar * (vanishing * S) + // where scalar = n_inv * g_n_inv, vanishing = z^N - g^N. + let s = Fp3::new([ + Fp::from_raw(raw[0]), + Fp::from_raw(raw[1]), + Fp::from_raw(raw[2]), + ]); + let vanishing = z_pow_n.sub_subfield(&g_n); + let scalar = &n_inv * &g_n_inv; + &scalar * &(&vanishing * &s) +} + +#[test] +fn gpu_barycentric_base_matches_cpu() { + const COSET_OFFSET: u64 = 7; + + for log_n in [4usize, 6, 8] { + for blowup in [2usize, 4] { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64((log_n * 100 + blowup) as u64); + + let lde_col: Vec = (0..lde_size).map(|_| rand_fp(&mut rng)).collect(); + let z = rand_fp3(&mut rng); + let coset_offset = Fp::from_raw(COSET_OFFSET); + let pts = coset_points(n, COSET_OFFSET); + + let cpu = cpu_barycentric_base(&lde_col, blowup, &pts, &z, &coset_offset); + let gpu = gpu_barycentric_base(&lde_col, blowup, &pts, &z, &coset_offset); + + for k in 0..3 { + let cpu_k = *cpu.value()[k].value(); + let gpu_k = *gpu.value()[k].value(); + let cpu_c = GoldilocksField::canonical(&cpu_k); + let gpu_c = GoldilocksField::canonical(&gpu_k); + assert_eq!( + cpu_c, gpu_c, + "component {k} mismatch: log_n={log_n} blowup={blowup} \ + cpu={cpu_c} gpu={gpu_c}" + ); + } + } + } +} + +// ── Ext3 aux path ───────────────────────────────────────────────────────────── + +/// CPU barycentric for a single ext3 column (aux trace path). +fn cpu_barycentric_ext3( + lde_col: &[Fp3], + blowup: usize, + coset_pts: &[Fp], + z: &Fp3, + coset_offset: &Fp, +) -> Fp3 { + let n = coset_pts.len(); + let n_inv = Fp::from(n as u64).inv().unwrap(); + let g_n = coset_offset.pow(n as u64); + let g_n_inv = g_n.inv().unwrap(); + let z_pow_n = z.pow(n as u64); + + let inv_denoms = + barycentric_inv_denoms::(z, coset_pts); + + let col_scale: Vec = coset_pts + .iter() + .zip(inv_denoms.iter()) + .map(|(pt, inv_d)| pt * inv_d) + .collect(); + + let sum = col_scale + .iter() + .enumerate() + .fold(Fp3::from(0u64), |acc, (i, scale)| { + acc + scale * &lde_col[i * blowup] + }); + + let vanishing = z_pow_n.sub_subfield(&g_n); + let scalar = &n_inv * &g_n_inv; + &scalar * &(&vanishing * &sum) +} + +/// GPU barycentric for a single ext3 column via `barycentric_ext3` kernel. +fn gpu_barycentric_ext3( + lde_col: &[Fp3], + blowup: usize, + coset_pts: &[Fp], + z: &Fp3, + coset_offset: &Fp, +) -> Fp3 { + let n = coset_pts.len(); + let n_inv = Fp::from(n as u64).inv().unwrap(); + let g_n = coset_offset.pow(n as u64); + let g_n_inv = g_n.inv().unwrap(); + let z_pow_n = z.pow(n as u64); + + let inv_denoms_fp3 = + barycentric_inv_denoms::(z, coset_pts); + + let pts_u64: Vec = coset_pts.iter().map(|p| *p.value()).collect(); + let inv_u64: Vec = inv_denoms_fp3 + .iter() + .flat_map(|e| { + [ + *e.value()[0].value(), + *e.value()[1].value(), + *e.value()[2].value(), + ] + }) + .collect(); + + // Pre-strided ext3 in the de-interleaved (component-major) layout the + // kernel expects: slab k at offset k*n holds component k of all n points. + let mut pre_strided: Vec = vec![0u64; 3 * n]; + for i in 0..n { + let e = &lde_col[i * blowup]; + pre_strided[i] = *e.value()[0].value(); + pre_strided[n + i] = *e.value()[1].value(); + pre_strided[2 * n + i] = *e.value()[2].value(); + } + + let raw = math_cuda::barycentric::barycentric_ext3(&pre_strided, n, &pts_u64, &inv_u64, n, 1) + .expect("GPU barycentric_ext3"); + + let s = Fp3::new([ + Fp::from_raw(raw[0]), + Fp::from_raw(raw[1]), + Fp::from_raw(raw[2]), + ]); + let vanishing = z_pow_n.sub_subfield(&g_n); + let scalar = &n_inv * &g_n_inv; + &scalar * &(&vanishing * &s) +} + +#[test] +fn gpu_barycentric_ext3_matches_cpu() { + const COSET_OFFSET: u64 = 7; + + for log_n in [4usize, 6, 8] { + for blowup in [2usize, 4] { + let n = 1usize << log_n; + let lde_size = n * blowup; + let mut rng = ChaCha8Rng::seed_from_u64((log_n * 100 + blowup + 5000) as u64); + + let lde_col: Vec = (0..lde_size).map(|_| rand_fp3(&mut rng)).collect(); + let z = rand_fp3(&mut rng); + let coset_offset = Fp::from_raw(COSET_OFFSET); + let pts = coset_points(n, COSET_OFFSET); + + let cpu = cpu_barycentric_ext3(&lde_col, blowup, &pts, &z, &coset_offset); + let gpu = gpu_barycentric_ext3(&lde_col, blowup, &pts, &z, &coset_offset); + + for k in 0..3 { + let cpu_k = *cpu.value()[k].value(); + let gpu_k = *gpu.value()[k].value(); + let cpu_c = GoldilocksField::canonical(&cpu_k); + let gpu_c = GoldilocksField::canonical(&gpu_k); + assert_eq!( + cpu_c, gpu_c, + "ext3 component {k} mismatch: log_n={log_n} blowup={blowup} \ + cpu={cpu_c} gpu={gpu_c}" + ); + } + } + } +} diff --git a/crypto/math-cuda/tests/merkle_root_parity.rs b/crypto/math-cuda/tests/merkle_root_parity.rs new file mode 100644 index 000000000..72e2aaea4 --- /dev/null +++ b/crypto/math-cuda/tests/merkle_root_parity.rs @@ -0,0 +1,382 @@ +//! GPU LDE + GPU Keccak leaf hash + GPU Merkle tree must produce the same root +//! as the CPU row-major LDE path (`coset_lde_full_expand_row_major` + +//! `commit_rows_bit_reversed`). Covers base field (main trace) and ext3 (aux trace). +//! +//! Two non-obvious layout details caught while writing these tests: +//! - `build_merkle_tree_on_device` stores the tree top-down: root at `nodes[0..32]`, +//! leaves in the tail (not the end). +//! - `keccak_leaves_ext3` expects component-major layout `[all-a, all-b, all-c]`, +//! not the interleaved `[a,b,c per element]` that `coset_lde_batch_ext3_into` produces. + +use math::fft::two_half_fft::TwoHalfTwiddles; +use math::field::element::FieldElement; +use math::field::extensions_goldilocks::Degree3GoldilocksExtensionField; +use math::field::goldilocks::GoldilocksField; +use math::polynomial::Polynomial; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use stark::prover::{IsStarkProver, Prover}; + +type Fp3 = FieldElement; + +type Fp = FieldElement; + +fn coset_weights(n: usize, g: u64) -> Vec { + let inv_n = Fp::from(n as u64).inv().unwrap(); + let g_fp = Fp::from_raw(g); + let mut w = Vec::with_capacity(n); + let mut cur = inv_n; + for _ in 0..n { + w.push(cur); + cur = &cur * &g_fp; + } + w +} + +fn coset_weights_u64(n: usize, g: u64) -> Vec { + coset_weights(n, g).iter().map(|w| *w.value()).collect() +} + +/// Run GPU batch LDE + GPU Keccak leaf hashing + GPU Merkle tree build. +/// Returns the 32-byte root extracted from the node array. +fn gpu_merkle_root(columns: &[Vec], blowup: usize, weights: &[u64]) -> [u8; 32] { + let col_slices: Vec<&[u64]> = columns.iter().map(|c| c.as_slice()).collect(); + let lde_columns = + math_cuda::lde::coset_lde_batch_base(&col_slices, blowup, weights).expect("GPU batch LDE"); + + let n_lde = lde_columns[0].len(); + let num_cols = lde_columns.len(); + + // Pack into column-major flat layout: [col * stride + row]. + let mut flat = vec![0u64; num_cols * n_lde]; + for (c, col) in lde_columns.iter().enumerate() { + for (r, &v) in col.iter().enumerate() { + flat[c * n_lde + r] = v; + } + } + + let gpu_leaves = math_cuda::merkle::keccak_leaves_base(&flat, n_lde, num_cols, n_lde) + .expect("GPU keccak leaves"); + let nodes = + math_cuda::merkle::build_merkle_tree_on_device(&gpu_leaves).expect("GPU Merkle tree"); + + // `build_merkle_tree_on_device` places the root at index 0 (the leaves + // live in the tail), so the root is the first 32 bytes of the node array. + let mut root = [0u8; 32]; + root.copy_from_slice(&nodes[0..32]); + root +} + +/// Run the new CPU row-major LDE (`coset_lde_full_expand_row_major`) + +/// `commit_rows_bit_reversed` and return the Merkle root. +fn cpu_row_major_merkle_root( + columns: &[Vec], + blowup: usize, + weights: &[Fp], + inv_tw: &TwoHalfTwiddles, + fwd_tw: &TwoHalfTwiddles, +) -> [u8; 32] { + let n = columns[0].len(); + let num_cols = columns.len(); + + // Build row-major buffer: data[row * num_cols + col] = columns[col][row]. + let mut buf: Vec = vec![Fp::from(0u64); n * num_cols]; + for (c, col) in columns.iter().enumerate() { + for (r, &v) in col.iter().enumerate() { + buf[r * num_cols + c] = Fp::from_raw(v); + } + } + + Polynomial::::coset_lde_full_expand_row_major::( + &mut buf, num_cols, blowup, weights, inv_tw, fwd_tw, + ) + .expect("CPU row-major LDE"); + + let (_, root) = + Prover::::commit_rows_bit_reversed(&buf, num_cols) + .expect("CPU commit"); + + root +} + +#[test] +fn gpu_and_cpu_row_major_merkle_roots_match() { + const COSET_OFFSET: u64 = 7; + + for log_n in [4usize, 6, 8, 10] { + for blowup in [2usize, 4] { + for num_cols in [1usize, 3, 8] { + let n = 1usize << log_n; + let log_lde = (n * blowup).trailing_zeros() as usize; + let mut rng = + ChaCha8Rng::seed_from_u64((log_n * 1000 + blowup * 100 + num_cols) as u64); + + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rng.r#gen::()).collect()) + .collect(); + + let weights_u64 = coset_weights_u64(n, COSET_OFFSET); + let weights_fp = coset_weights(n, COSET_OFFSET); + let inv_tw = + TwoHalfTwiddles::::new(log_n, true).expect("inv twiddles"); + let fwd_tw = + TwoHalfTwiddles::::new(log_lde, false).expect("fwd twiddles"); + + let gpu_root = gpu_merkle_root(&columns, blowup, &weights_u64); + let cpu_root = + cpu_row_major_merkle_root(&columns, blowup, &weights_fp, &inv_tw, &fwd_tw); + + assert_eq!( + gpu_root, cpu_root, + "root mismatch: log_n={log_n} blowup={blowup} num_cols={num_cols}" + ); + } + } + } +} + +// ── Ext3 helpers ───────────────────────────────────────────────────────────── + +fn rand_ext3(rng: &mut ChaCha8Rng) -> Fp3 { + Fp3::new([ + FieldElement::::from_raw(rng.r#gen::()), + FieldElement::::from_raw(rng.r#gen::()), + FieldElement::::from_raw(rng.r#gen::()), + ]) +} + +fn ext3_to_u64s(col: &[Fp3]) -> Vec { + let mut out = Vec::with_capacity(col.len() * 3); + for e in col { + out.push(*e.value()[0].value()); + out.push(*e.value()[1].value()); + out.push(*e.value()[2].value()); + } + out +} + +/// GPU ext3 LDE + Keccak leaf hash + Merkle tree → root. +fn gpu_ext3_merkle_root(columns: &[Vec], blowup: usize, weights: &[u64]) -> [u8; 32] { + let n = columns[0].len(); + let lde_size = n * blowup; + let num_cols = columns.len(); + + let flat_inputs: Vec> = columns.iter().map(|c| ext3_to_u64s(c)).collect(); + let input_slices: Vec<&[u64]> = flat_inputs.iter().map(|v| v.as_slice()).collect(); + + let mut flat_outputs: Vec> = (0..num_cols).map(|_| vec![0u64; 3 * lde_size]).collect(); + { + let mut out_slices: Vec<&mut [u64]> = + flat_outputs.iter_mut().map(|v| v.as_mut_slice()).collect(); + math_cuda::lde::coset_lde_batch_ext3_into( + &input_slices, + n, + blowup, + weights, + &mut out_slices, + ) + .expect("GPU ext3 LDE"); + } + + // Repack from interleaved [a,b,c per element] to component-major + // [all-a, all-b, all-c] as keccak_leaves_ext3 expects. + let mut flat_for_keccak = vec![0u64; num_cols * 3 * lde_size]; + for (c, out) in flat_outputs.iter().enumerate() { + for r in 0..lde_size { + flat_for_keccak[(c * 3) * lde_size + r] = out[r * 3]; + flat_for_keccak[(c * 3 + 1) * lde_size + r] = out[r * 3 + 1]; + flat_for_keccak[(c * 3 + 2) * lde_size + r] = out[r * 3 + 2]; + } + } + + let gpu_leaves = + math_cuda::merkle::keccak_leaves_ext3(&flat_for_keccak, lde_size, num_cols, lde_size) + .expect("GPU ext3 keccak leaves"); + let nodes = + math_cuda::merkle::build_merkle_tree_on_device(&gpu_leaves).expect("GPU Merkle tree"); + + let mut root = [0u8; 32]; + root.copy_from_slice(&nodes[0..32]); + root +} + +/// CPU row-major ext3 LDE + `commit_rows_bit_reversed` → root. +fn cpu_ext3_row_major_merkle_root( + columns: &[Vec], + blowup: usize, + weights: &[FieldElement], + inv_tw: &TwoHalfTwiddles, + fwd_tw: &TwoHalfTwiddles, +) -> [u8; 32] { + let n = columns[0].len(); + let num_cols = columns.len(); + + let mut buf: Vec = vec![Fp3::from(0u64); n * num_cols]; + for (c, col) in columns.iter().enumerate() { + for (r, v) in col.iter().enumerate() { + buf[r * num_cols + c] = *v; + } + } + + Polynomial::::coset_lde_full_expand_row_major::( + &mut buf, num_cols, blowup, weights, inv_tw, fwd_tw, + ) + .expect("CPU ext3 row-major LDE"); + + let (_, root) = + Prover::::commit_rows_bit_reversed( + &buf, num_cols, + ) + .expect("CPU ext3 commit"); + + root +} + +#[test] +fn gpu_and_cpu_ext3_merkle_roots_match() { + const COSET_OFFSET: u64 = 7; + + for log_n in [4usize, 6, 8] { + for blowup in [2usize, 4] { + for num_cols in [1usize, 3, 5] { + let n = 1usize << log_n; + let log_lde = (n * blowup).trailing_zeros() as usize; + let mut rng = ChaCha8Rng::seed_from_u64( + (log_n * 1000 + blowup * 100 + num_cols) as u64 + 9999, + ); + + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rand_ext3(&mut rng)).collect()) + .collect(); + + let weights_u64 = coset_weights_u64(n, COSET_OFFSET); + let weights_fp = coset_weights(n, COSET_OFFSET); + let inv_tw = + TwoHalfTwiddles::::new(log_n, true).expect("inv twiddles"); + let fwd_tw = + TwoHalfTwiddles::::new(log_lde, false).expect("fwd twiddles"); + + let gpu_root = gpu_ext3_merkle_root(&columns, blowup, &weights_u64); + let cpu_root = + cpu_ext3_row_major_merkle_root(&columns, blowup, &weights_fp, &inv_tw, &fwd_tw); + + assert_eq!( + gpu_root, cpu_root, + "ext3 root mismatch: log_n={log_n} blowup={blowup} num_cols={num_cols}" + ); + } + } + } +} + +// ── New row-major pipeline tests ───────────────────────────────────────────── + +#[test] +fn new_row_major_pipeline_base_root_matches_cpu() { + const COSET_OFFSET: u64 = 7; + + for log_n in [4usize, 6, 8, 10] { + for blowup in [2usize, 4] { + for num_cols in [1usize, 3, 8] { + let n = 1usize << log_n; + let log_lde = (n * blowup).trailing_zeros() as usize; + let mut rng = ChaCha8Rng::seed_from_u64( + (log_n * 1000 + blowup * 100 + num_cols) as u64 + 10000, + ); + + let row_major: Vec = (0..n * num_cols).map(|_| rng.r#gen::()).collect(); + + let weights_u64 = coset_weights_u64(n, COSET_OFFSET); + let weights_fp = coset_weights(n, COSET_OFFSET); + let inv_tw = + TwoHalfTwiddles::::new(log_n, true).expect("inv twiddles"); + let fwd_tw = + TwoHalfTwiddles::::new(log_lde, false).expect("fwd twiddles"); + + let (nodes, _handle, _lde) = + math_cuda::lde::coset_lde_row_major_with_merkle_tree_keep( + &row_major, + n, + num_cols, + blowup, + &weights_u64, + ) + .expect("new row-major GPU pipeline"); + let mut gpu_root = [0u8; 32]; + gpu_root.copy_from_slice(&nodes[0..32]); + + let cpu_root = cpu_row_major_merkle_root( + &(0..num_cols) + .map(|c| (0..n).map(|r| row_major[r * num_cols + c]).collect()) + .collect::>>(), + blowup, + &weights_fp, + &inv_tw, + &fwd_tw, + ); + + assert_eq!( + gpu_root, cpu_root, + "new row-major pipeline root mismatch: log_n={log_n} blowup={blowup} num_cols={num_cols}" + ); + } + } + } +} + +#[test] +fn new_row_major_pipeline_ext3_root_matches_cpu() { + const COSET_OFFSET: u64 = 7; + + for log_n in [4usize, 6, 8] { + for blowup in [2usize, 4] { + for num_cols in [1usize, 3, 5] { + let n = 1usize << log_n; + let log_lde = (n * blowup).trailing_zeros() as usize; + let mut rng = ChaCha8Rng::seed_from_u64( + (log_n * 1000 + blowup * 100 + num_cols) as u64 + 20000, + ); + + let columns: Vec> = (0..num_cols) + .map(|_| (0..n).map(|_| rand_ext3(&mut rng)).collect()) + .collect(); + + let mut row_major: Vec = Vec::with_capacity(n * num_cols * 3); + for r in 0..n { + for col in &columns { + row_major.push(*col[r].value()[0].value()); + row_major.push(*col[r].value()[1].value()); + row_major.push(*col[r].value()[2].value()); + } + } + + let weights_u64 = coset_weights_u64(n, COSET_OFFSET); + let weights_fp = coset_weights(n, COSET_OFFSET); + let inv_tw = + TwoHalfTwiddles::::new(log_n, true).expect("inv twiddles"); + let fwd_tw = + TwoHalfTwiddles::::new(log_lde, false).expect("fwd twiddles"); + + let (nodes, _handle, _lde) = + math_cuda::lde::coset_lde_ext3_row_major_with_merkle_tree_keep( + &row_major, + n, + num_cols, + blowup, + &weights_u64, + ) + .expect("new ext3 row-major GPU pipeline"); + let mut gpu_root = [0u8; 32]; + gpu_root.copy_from_slice(&nodes[0..32]); + + let cpu_root = + cpu_ext3_row_major_merkle_root(&columns, blowup, &weights_fp, &inv_tw, &fwd_tw); + + assert_eq!( + gpu_root, cpu_root, + "new ext3 row-major pipeline root mismatch: log_n={log_n} blowup={blowup} num_cols={num_cols}" + ); + } + } + } +} diff --git a/crypto/stark/src/gpu_lde.rs b/crypto/stark/src/gpu_lde.rs index 36756b40b..a9925e252 100644 --- a/crypto/stark/src/gpu_lde.rs +++ b/crypto/stark/src/gpu_lde.rs @@ -41,13 +41,25 @@ use crate::trace::LDETraceTable; const DEFAULT_GPU_LDE_THRESHOLD: usize = 1 << 19; fn gpu_lde_threshold() -> usize { - static CACHED: OnceLock = OnceLock::new(); - *CACHED.get_or_init(|| { + // In test builds re-read the env var on every call so tests can switch + // between GPU and CPU paths in the same process (OnceLock can't be reset). + #[cfg(test)] + { std::env::var("LAMBDA_VM_GPU_LDE_THRESHOLD") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(DEFAULT_GPU_LDE_THRESHOLD) - }) + } + #[cfg(not(test))] + { + static CACHED: OnceLock = OnceLock::new(); + *CACHED.get_or_init(|| { + std::env::var("LAMBDA_VM_GPU_LDE_THRESHOLD") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(DEFAULT_GPU_LDE_THRESHOLD) + }) + } } /// Incremented by the `try_expand_*` functions per base-field column handed to @@ -451,120 +463,140 @@ pub fn gpu_leaf_hash_calls() -> u64 { GPU_LEAF_HASH_CALLS.load(Ordering::Relaxed) } -/// Fused base-field path: LDE + Keccak-256 leaf hash + Merkle tree build, -/// all on device, with the LDE buffer retained for R2–R4 GPU reuse. On -/// success: `columns[c]` is resized to `lde_size` with the LDE output, and -/// the returned `(tree, GpuLdeBase)` pair is the host-side tree plus a -/// device-resident handle to the LDE buffer. -pub(crate) fn try_expand_leaf_and_tree_batched_keep( - columns: &mut [Vec>], +/// Row-major GPU path: single H2D → row-major NTT → row-major Keccak → +/// Merkle → single D2H. No column extraction or CPU-side transpose. +pub(crate) fn try_expand_leaf_and_tree_row_major_keep( + row_major: &[FieldElement], + n: usize, + m: usize, blowup_factor: usize, weights: &[FieldElement], -) -> Option<(MerkleTree, math_cuda::lde::GpuLdeBase)> +) -> Option<( + MerkleTree, + math_cuda::lde::GpuLdeBase, + Vec>, +)> where F: IsField + 'static, E: IsField + 'static, B: IsMerkleTreeBackend, { - let (n, lde_size) = match check_base_layout::(columns, blowup_factor) { - LayoutDispatch::Empty | LayoutDispatch::Skip => return None, - LayoutDispatch::Run { n, lde_size } => (n, lde_size), - }; - let num_columns = columns.len(); - let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; - let node_byte_len = total_nodes - .checked_mul(32) - .expect("node byte length overflow"); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + if row_major.len() != n * m || m == 0 || n == 0 { + return None; + } - // SAFETY: layout-checked above. - let raw_columns = unsafe { columns_to_u64_base::(columns) }; + let raw: &[u64] = unsafe { from_raw_parts(row_major.as_ptr() as *const u64, n * m) }; let weights_u64 = unsafe { weights_to_u64::(weights) }; - let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - GPU_LDE_CALLS.fetch_add(num_columns as u64, Ordering::Relaxed); + GPU_LDE_CALLS.fetch_add(m as u64, Ordering::Relaxed); GPU_LEAF_HASH_CALLS.fetch_add(1, Ordering::Relaxed); GPU_MERKLE_TREE_CALLS.fetch_add(1, Ordering::Relaxed); - let handle_result = { - let mut raw_outputs = unsafe { presize_and_view_base::(columns, lde_size) }; - let nodes_bytes: &mut [u8] = - unsafe { from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) }; - math_cuda::lde::coset_lde_batch_base_into_with_merkle_tree_keep( - &slices, - blowup_factor, - &weights_u64, - &mut raw_outputs, - nodes_bytes, + let (nodes_bytes, handle, lde_u64) = math_cuda::lde::coset_lde_row_major_with_merkle_tree_keep( + raw, + n, + m, + blowup_factor, + &weights_u64, + ) + .ok()?; + + // Transmute Vec → Vec> (zero-copy, E == GoldilocksField). + let lde_out: Vec> = unsafe { + let mut v = std::mem::ManuallyDrop::new(lde_u64); + Vec::from_raw_parts( + v.as_mut_ptr() as *mut FieldElement, + v.len(), + v.capacity(), ) }; - let handle = match handle_result { - Ok(h) => h, - Err(_) => { - restore_columns_on_err(columns, n); - return None; - } - }; + let nodes: Vec<[u8; 32]> = nodes_bytes + .chunks_exact(32) + .map(|c| c.try_into().expect("32-byte chunk")) + .collect(); let tree = MerkleTree::::from_precomputed_nodes(nodes)?; - Some((tree, handle)) + Some((tree, handle, lde_out)) } -/// Fused ext3 path: LDE + Keccak-256 leaf hash + Merkle tree build over -/// ext3 columns via the three-slab decomposition, with the ext3 LDE device -/// buffer (de-interleaved 3-slab layout) retained for downstream GPU rounds. -/// `B::Node = [u8; 32]` by construction for `BatchKeccak256Backend`. -pub(crate) fn try_expand_leaf_and_tree_batched_ext3_keep( - columns: &mut [Vec>], +/// Row-major ext3 GPU path: single H2D → row-major NTT (m*3 base-field cols) → +/// row-major Keccak → Merkle → single D2H → transpose to GpuLdeExt3 handle. +/// Same optimization as the base-field path: no extract_columns, no CPU transpose. +pub(crate) fn try_expand_leaf_and_tree_ext3_row_major_keep( + row_major: &[FieldElement], + n: usize, + m: usize, blowup_factor: usize, weights: &[FieldElement], -) -> Option<(MerkleTree, math_cuda::lde::GpuLdeExt3)> +) -> Option<( + MerkleTree, + math_cuda::lde::GpuLdeExt3, + Vec>, +)> where F: IsField + 'static, E: IsField + 'static, B: IsMerkleTreeBackend, { - let (n, lde_size) = match check_ext3_layout::(columns, blowup_factor) { - LayoutDispatch::Empty | LayoutDispatch::Skip => return None, - LayoutDispatch::Run { n, lde_size } => (n, lde_size), - }; - let num_columns = columns.len(); - let (mut nodes, total_nodes) = alloc_merkle_nodes(lde_size)?; - let node_byte_len = total_nodes - .checked_mul(32) - .expect("node byte length overflow"); + let lde_size = n.saturating_mul(blowup_factor); + if lde_size < gpu_lde_threshold() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + if TypeId::of::() != TypeId::of::() { + return None; + } + if row_major.len() != n * m || m == 0 || n == 0 { + return None; + } - // SAFETY: layout-checked above. - let raw_columns = unsafe { columns_to_u64_ext3::(columns) }; + // Fp3 = [u64; 3] in memory — reinterpret as flat u64 slice (m3 = m*3). + let m3 = m * 3; + let raw: &[u64] = unsafe { from_raw_parts(row_major.as_ptr() as *const u64, n * m3) }; let weights_u64 = unsafe { weights_to_u64::(weights) }; - let slices: Vec<&[u64]> = raw_columns.iter().map(|c| c.as_slice()).collect(); - GPU_LDE_CALLS.fetch_add((num_columns * 3) as u64, Ordering::Relaxed); + GPU_LDE_CALLS.fetch_add((m * 3) as u64, Ordering::Relaxed); GPU_LEAF_HASH_CALLS.fetch_add(1, Ordering::Relaxed); GPU_MERKLE_TREE_CALLS.fetch_add(1, Ordering::Relaxed); - let handle_result = { - let mut raw_outputs = unsafe { presize_and_view_ext3::(columns, lde_size) }; - let nodes_bytes: &mut [u8] = - unsafe { from_raw_parts_mut(nodes.as_mut_ptr() as *mut u8, node_byte_len) }; - math_cuda::lde::coset_lde_batch_ext3_into_with_merkle_tree_keep( - &slices, + let (nodes_bytes, handle, lde_u64) = + math_cuda::lde::coset_lde_ext3_row_major_with_merkle_tree_keep( + raw, n, + m, blowup_factor, &weights_u64, - &mut raw_outputs, - nodes_bytes, ) - }; - let handle = match handle_result { - Ok(h) => h, - Err(_) => { - restore_columns_on_err(columns, n); - return None; - } + .ok()?; + + // Transmute Vec → Vec> (zero-copy, E == Fp3 = [u64;3]). + let lde_out: Vec> = unsafe { + let mut v = std::mem::ManuallyDrop::new(lde_u64); + Vec::from_raw_parts( + v.as_mut_ptr() as *mut FieldElement, + v.len() / 3, + v.capacity() / 3, + ) }; + let nodes: Vec<[u8; 32]> = nodes_bytes + .chunks_exact(32) + .map(|c| c.try_into().expect("32-byte chunk")) + .collect(); let tree = MerkleTree::::from_precomputed_nodes(nodes)?; - Some((tree, handle)) + Some((tree, handle, lde_out)) } /// Ext3 specialisation of [`try_expand_columns_batched`]. `E` is known to be diff --git a/crypto/stark/src/prover.rs b/crypto/stark/src/prover.rs index eed0e512a..9b716ff13 100644 --- a/crypto/stark/src/prover.rs +++ b/crypto/stark/src/prover.rs @@ -892,29 +892,40 @@ pub trait IsStarkProver< { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; - // Fused GPU path (cuda only): extract columns and try the on-device - // pipeline; on success it returns the LDE + tree directly. + // Fused GPU path (cuda only): row-major NTT — single H2D from the + // already-row-major trace, no column extraction, no transpose. + // Falls back to CPU if GPU path returns None. #[cfg(feature = "cuda")] if precomputed.is_none() { - let mut columns = trace.extract_columns_main(lde_size); + let (trace_slice, num_cols) = trace.main_data_row_major(); + let n = if num_cols > 0 { + trace_slice.len() / num_cols + } else { + 0 + }; #[cfg(feature = "instruments")] let t_sub = Instant::now(); - if let Some((tree, handle)) = - crate::gpu_lde::try_expand_leaf_and_tree_batched_keep::< + if let Some((tree, handle, main_data)) = + crate::gpu_lde::try_expand_leaf_and_tree_row_major_keep::< Field, Field, BatchedMerkleTreeBackend, - >(&mut columns, domain.blowup_factor, &twiddles.coset_weights) + >( + trace_slice, + n, + num_cols, + domain.blowup_factor, + &twiddles.coset_weights, + ) { #[cfg(feature = "instruments")] let main_lde_dur = t_sub.elapsed(); let root = tree.root; #[cfg(feature = "instruments")] - crate::instruments::accum_r1_main(main_lde_dur, main_lde_dur); - let (main_data, total_cols) = columns_to_row_major(&columns); + crate::instruments::accum_r1_main(main_lde_dur, std::time::Duration::ZERO); return Ok(( TableCommit::plain(tree, root), - (main_data, total_cols), + (main_data, num_cols), Some(handle), )); } @@ -2210,20 +2221,22 @@ pub trait IsStarkProver< if air.has_aux_trace() { let lde_size = domain.interpolation_domain_size * domain.blowup_factor; - // Fused GPU path (cuda only): extract columns and try the - // on-device ext3 pipeline; on success it returns directly. + // Fused GPU path (cuda only): row-major ext3 NTT — single + // H2D, no column extraction, no CPU transpose. #[cfg(feature = "cuda")] { - let mut columns = trace.extract_columns_aux(lde_size); + let (trace_slice, num_cols) = trace.aux_data_row_major(); + let n = if num_cols > 0 { trace_slice.len() / num_cols } else { 0 }; #[cfg(feature = "instruments")] let t_sub = Instant::now(); - if let Some((tree, handle)) = - crate::gpu_lde::try_expand_leaf_and_tree_batched_ext3_keep::< + if let Some((tree, handle, aux_data)) = + crate::gpu_lde::try_expand_leaf_and_tree_ext3_row_major_keep::< Field, FieldExtension, BatchedMerkleTreeBackend, >( - &mut columns, domain.blowup_factor, &twiddles.coset_weights + trace_slice, n, num_cols, domain.blowup_factor, + &twiddles.coset_weights, ) { #[cfg(feature = "instruments")] @@ -2231,10 +2244,9 @@ pub trait IsStarkProver< let root = tree.root; #[cfg(feature = "instruments")] crate::instruments::accum_r1_aux(aux_lde_dur, Duration::ZERO); - let (aux_data, total_cols) = columns_to_row_major(&columns); return Ok(( Some(TableCommit::plain(tree, root)), - (aux_data, total_cols), + (aux_data, num_cols), Some(handle), )); } diff --git a/crypto/stark/src/tests/prover_tests.rs b/crypto/stark/src/tests/prover_tests.rs index 318dacb81..ab3589702 100644 --- a/crypto/stark/src/tests/prover_tests.rs +++ b/crypto/stark/src/tests/prover_tests.rs @@ -562,3 +562,41 @@ fn test_deep_poly_direct_2n_matches_interpolate_fft_extend() { ); } } + +#[test] +fn commit_rows_bit_reversed_matches_commit_columns_bit_reversed() { + type F = GoldilocksField; + type FE = FieldElement; + + for num_cols in [1usize, 3, 7] { + for log_rows in [4usize, 6, 8] { + let num_rows = 1usize << log_rows; + + let columns: Vec> = (0..num_cols) + .map(|c| { + (0..num_rows) + .map(|r| FE::from((c * num_rows + r) as u64 * 6700417 + 1)) + .collect() + }) + .collect(); + + // Row-major interleaving: data[row * num_cols + col] = columns[col][row]. + let mut row_major: Vec = Vec::with_capacity(num_rows * num_cols); + for r in 0..num_rows { + for col in &columns { + row_major.push(col[r]); + } + } + + let (_, root_col) = Prover::::commit_columns_bit_reversed(&columns) + .expect("column-major commit must succeed"); + let (_, root_row) = Prover::::commit_rows_bit_reversed(&row_major, num_cols) + .expect("row-major commit must succeed"); + + assert_eq!( + root_col, root_row, + "commit root mismatch: num_cols={num_cols} log_rows={log_rows}" + ); + } + } +} diff --git a/prover/src/instruments.rs b/prover/src/instruments.rs index f15223e18..a33fd3dad 100644 --- a/prover/src/instruments.rs +++ b/prover/src/instruments.rs @@ -77,19 +77,27 @@ pub fn print_report( row_top("Round 1", round1, total); row_sub(" Main trace commits", mp.main_commits, total); row_sub( - " Main expand_columns_to_lde", + " Main LDE (fused GPU: LDE+Keccak+Merkle / CPU: LDE only)", mp.round1_sub.main_lde, total, ); - row_sub(" Main commit (Merkle)", mp.round1_sub.main_merkle, total); + row_sub( + " Main commit (Merkle, CPU only)", + mp.round1_sub.main_merkle, + total, + ); row_sub(" Aux trace build (parallel)", mp.aux_build, total); row_sub(" Aux trace commit", mp.aux_commit, total); row_sub( - " Aux expand_columns_to_lde", + " Aux LDE (fused GPU: LDE+Keccak+Merkle / CPU: LDE only)", mp.round1_sub.aux_lde, total, ); - row_sub(" Aux commit (Merkle)", mp.round1_sub.aux_merkle, total); + row_sub( + " Aux commit (Merkle, CPU only)", + mp.round1_sub.aux_merkle, + total, + ); row_top("Rounds 2\u{2013}4", mp.rounds_2_4, total); // Merge split tables: MEMW[0..4] → MEMW x5 diff --git a/prover/src/tables/branch.rs b/prover/src/tables/branch.rs index 9443a81a1..6ecd42e78 100644 --- a/prover/src/tables/branch.rs +++ b/prover/src/tables/branch.rs @@ -33,9 +33,9 @@ use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing} use stark::table::TableView; use stark::trace::TraceTable; -use std::collections::HashMap; - -use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, SHIFT_16, VmTable, alu_op}; +use super::types::{ + BusId, FE, FxHashMap, GoldilocksExtension, GoldilocksField, SHIFT_16, VmTable, alu_op, +}; // ========================================================================= // Column indices for BRANCH table @@ -161,7 +161,7 @@ pub fn generate_branch_trace( operations: &[BranchOperation], ) -> TraceTable { // Deduplicate operations: (pc, offset, register, jalr) -> multiplicity - let mut op_map: HashMap = HashMap::new(); + let mut op_map: FxHashMap = FxHashMap::default(); for op in operations { *op_map.entry(op.clone()).or_insert(0) += 1; } diff --git a/prover/src/tables/dvrm.rs b/prover/src/tables/dvrm.rs index 3da78dff5..2dcaae453 100644 --- a/prover/src/tables/dvrm.rs +++ b/prover/src/tables/dvrm.rs @@ -36,11 +36,9 @@ use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing} use stark::table::TableView; use stark::trace::TraceTable; -use std::collections::HashMap; - use super::types::{ - BusId, FE, GoldilocksExtension, GoldilocksField, NEG_INV_2_16, NEG_INV_2_32, NEG_INV_2_48, - NEG_INV_2_64, SHIFT_16, VmTable, alu_op, + BusId, FE, FxHashMap, GoldilocksExtension, GoldilocksField, NEG_INV_2_16, NEG_INV_2_32, + NEG_INV_2_48, NEG_INV_2_64, SHIFT_16, VmTable, alu_op, }; // ========================================================================= @@ -288,7 +286,7 @@ pub fn generate_dvrm_trace( operations: &[(DvrmOperation, bool)], ) -> TraceTable { // Deduplicate: (n, d, signed) -> (mu_q, mu_r) - let mut op_map: HashMap = HashMap::new(); + let mut op_map: FxHashMap = FxHashMap::default(); for (op, wants_remainder) in operations { let entry = op_map.entry(op.clone()).or_default(); diff --git a/prover/src/tables/lt.rs b/prover/src/tables/lt.rs index 02ed029bd..be0b60773 100644 --- a/prover/src/tables/lt.rs +++ b/prover/src/tables/lt.rs @@ -33,9 +33,9 @@ use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing} use stark::table::TableView; use stark::trace::TraceTable; -use std::collections::HashMap; - -use super::types::{BusId, FE, GoldilocksExtension, GoldilocksField, SHIFT_16, VmTable, alu_op}; +use super::types::{ + BusId, FE, FxHashMap, GoldilocksExtension, GoldilocksField, SHIFT_16, VmTable, alu_op, +}; // ========================================================================= // Column indices for LT table @@ -164,7 +164,7 @@ pub fn generate_lt_trace( operations: &[LtOperation], ) -> TraceTable { // Deduplicate operations: (lhs, rhs, signed) -> multiplicity - let mut op_map: HashMap = HashMap::new(); + let mut op_map: FxHashMap = FxHashMap::default(); for op in operations { *op_map.entry(op.clone()).or_insert(0) += 1; } diff --git a/prover/src/tables/mul.rs b/prover/src/tables/mul.rs index 33679211c..197a0d334 100644 --- a/prover/src/tables/mul.rs +++ b/prover/src/tables/mul.rs @@ -37,11 +37,9 @@ use stark::lookup::{BusInteraction, BusValue, LinearTerm, Multiplicity, Packing} use stark::table::TableView; use stark::trace::TraceTable; -use std::collections::HashMap; - use super::types::{ - BusId, FE, GoldilocksExtension, GoldilocksField, INV_2_32, INV_2_64, INV_2_96, INV_2_128, - NEG_INV_2_16, NEG_INV_2_32, NEG_INV_2_48, NEG_INV_2_64, NEG_INV_2_80, NEG_INV_2_96, + BusId, FE, FxHashMap, GoldilocksExtension, GoldilocksField, INV_2_32, INV_2_64, INV_2_96, + INV_2_128, NEG_INV_2_16, NEG_INV_2_32, NEG_INV_2_48, NEG_INV_2_64, NEG_INV_2_80, NEG_INV_2_96, NEG_INV_2_112, NEG_INV_2_128, SHIFT_16, VmTable, alu_op, }; @@ -296,7 +294,7 @@ pub fn generate_mul_trace( operations: &[(MulOperation, bool)], ) -> TraceTable { // Deduplicate: (lhs, lhs_signed, rhs, rhs_signed) -> (mu_lo, mu_hi) - let mut op_map: HashMap = HashMap::new(); + let mut op_map: FxHashMap = FxHashMap::default(); for (op, wants_hi) in operations { let entry = op_map.entry(op.clone()).or_default(); diff --git a/prover/src/tables/types.rs b/prover/src/tables/types.rs index d6091d0fd..2cd5db0f0 100644 --- a/prover/src/tables/types.rs +++ b/prover/src/tables/types.rs @@ -968,6 +968,61 @@ impl DecodeEntry { } } +// ========================================================================= +// Fast hashing for op-dedup multiplicity maps +// ========================================================================= + +/// Fast non-cryptographic hash for the op-dedup hot path. Skipping SipHash is +/// safe here: the maps are per-chunk (bounded ≤ `max_rows`), keyed by the +/// prover's own trace, and collisions only cost probes, never soundness. +#[derive(Default)] +pub struct FxHasher(u64); + +impl FxHasher { + const SEED: u64 = 0x51_7c_c1_b7_27_22_0a_95; + + #[inline] + fn add(&mut self, word: u64) { + self.0 = (self.0.rotate_left(5) ^ word).wrapping_mul(Self::SEED); + } +} + +impl std::hash::Hasher for FxHasher { + #[inline] + fn write(&mut self, bytes: &[u8]) { + for &b in bytes { + self.add(b as u64); + } + } + #[inline] + fn write_u8(&mut self, i: u8) { + self.add(i as u64); + } + #[inline] + fn write_u16(&mut self, i: u16) { + self.add(i as u64); + } + #[inline] + fn write_u32(&mut self, i: u32) { + self.add(i as u64); + } + #[inline] + fn write_u64(&mut self, i: u64) { + self.add(i); + } + #[inline] + fn write_usize(&mut self, i: usize) { + self.add(i as u64); + } + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} + +/// `HashMap` keyed with [`FxHasher`]. +pub type FxHashMap = std::collections::HashMap>; + /// The fully sign-extended 64-bit immediate for an instruction (0 when none). fn imm_from_instruction(instruction: Instruction) -> u64 { match instruction { diff --git a/prover/tests/cuda_path_integration.rs b/prover/tests/cuda_path_integration.rs index 0f7c1f3c7..6de671cf1 100644 --- a/prover/tests/cuda_path_integration.rs +++ b/prover/tests/cuda_path_integration.rs @@ -66,3 +66,25 @@ fn gpu_path_fires_end_to_end() { let ok = verify(&proof, &elf).expect("verify"); assert!(ok, "GPU-produced proof failed verification"); } + +#[test] +#[ignore = "requires GPU; run with --ignored --nocapture"] +fn gpu_and_cpu_proofs_both_verify() { + let elf = asm_elf_bytes("fib_iterative_1M"); + + let proof_gpu = prove(&elf).expect("GPU prove"); + assert!( + verify(&proof_gpu, &elf).expect("GPU verify"), + "GPU proof failed" + ); + + // Force CPU path by pushing the GPU threshold above any real table size. + // SAFETY: no other thread reads this env var during the test. + unsafe { std::env::set_var("LAMBDA_VM_GPU_LDE_THRESHOLD", "999999999") }; + let proof_cpu = prove(&elf).expect("CPU prove"); + unsafe { std::env::remove_var("LAMBDA_VM_GPU_LDE_THRESHOLD") }; + assert!( + verify(&proof_cpu, &elf).expect("CPU verify"), + "CPU proof failed" + ); +}