Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
cc48f43
add row major batched lde fft primitives
jotabulacios Jun 8, 2026
46c0080
Make LDETraceTable row-major
jotabulacios Jun 8, 2026
a8db4ee
Wire prover to row-major batched LDE
jotabulacios Jun 8, 2026
c55dc05
read trace row major
jotabulacios Jun 8, 2026
59f7c0d
Move the batched-FFT and row-major-LDE unit tests into corresponding …
jotabulacios Jun 8, 2026
4bd73d4
resolve conflicts
jotabulacios Jun 8, 2026
ab786f5
Merge branch 'main' into perf/cpu-lde-rework
jotabulacios Jun 8, 2026
08eaa8b
fix disk-spill EmptyCommitment in row-major LDE
jotabulacios Jun 9, 2026
8dc41fb
Merge branch 'main' into perf/cpu-lde-rework
diegokingston Jun 10, 2026
bd1b9fd
Parallelize trace build and speed up op-dedup bookkeeping
jotabulacios Jun 9, 2026
e6a7297
Skip the identity multiply by alpha_powers[0] in LogUp fingerprints
jotabulacios Jun 10, 2026
26a2bcd
solve conflicts
jotabulacios Jun 16, 2026
665d82a
Remove dead FFT module and gate legacy twiddles
jotabulacios Jun 17, 2026
b69edd3
Harden parallel row-major bit-reverse permute
jotabulacios Jun 18, 2026
47ded68
Guard columns_to_row_major; clarify hasher doc
jotabulacios Jun 18, 2026
14bac6b
solve conflicts
jotabulacios Jun 18, 2026
18d1058
Merge branch 'main' into perf/cpu-lde-rework
diegokingston Jun 18, 2026
56fe834
Merge branch 'main' into perf/cpu-lde-rework
diegokingston Jun 18, 2026
3770047
Merge branch 'main' into perf/cpu-lde-rework
MauroToscano Jun 18, 2026
2191f65
Deduplicate commit_rows_bit_reversed and bit_reverse_vec
jotabulacios Jun 24, 2026
94b3f29
add gpu tests
jotabulacios Jun 24, 2026
8dc297f
Add GPU/CPU Merkle root parity test
jotabulacios Jun 24, 2026
b3d915d
Add GPU/CPU Merkle root parity tests for base and ext3 aux trace
jotabulacios Jun 24, 2026
fd5da8b
Add GPU/CPU barycentric OOD parity tests
jotabulacios Jun 24, 2026
c795edc
Fix ext3 pre-strided layout in barycentric parity test
jotabulacios Jun 24, 2026
b3d7366
Fix instruments double-billing GPU fused pipeline in R1
jotabulacios Jun 24, 2026
cb84e23
Add test verifying GPU and CPU proofs both pass verification
jotabulacios Jun 25, 2026
627856b
Clean up verbose comments in parity tests
jotabulacios Jun 25, 2026
38f5600
GPU R1 GPU: eliminate extract_columns + columns_to_row_major via on-d…
jotabulacios Jun 25, 2026
85db0e9
Revert "GPU R1 GPU: eliminate extract_columns + columns_to_row_major …
jotabulacios Jun 25, 2026
74f97db
GPU R1: row-major NTT kernel — no transpose, coalesced column access
jotabulacios Jun 25, 2026
71c310b
Fix GPU R1 row-major: transpose buf to col-major for device handle
jotabulacios Jun 25, 2026
a52f737
Fix keccak row-major launch config: use 128-thread block, not 1024
jotabulacios Jun 25, 2026
1fd3ad9
GPU R1 aux: row-major ext3 NTT reusing base-field kernels with m*3
jotabulacios Jun 25, 2026
1b24c66
Clean up GPU row-major LDE: extract transpose helper, fix zero-pad al…
jotabulacios Jun 25, 2026
ff24aae
Clean up GPU row-major LDE: extract helper, fix alloc, trim comments
jotabulacios Jun 25, 2026
ff71a65
Fix gpu_lde_threshold OnceLock: re-read env var in test builds
jotabulacios Jun 26, 2026
09bc2dc
Fix cross-stream race: synchronize after transpose before returning h…
jotabulacios Jun 26, 2026
53af5d9
Add parity tests for new row-major GPU pipeline
jotabulacios Jun 26, 2026
2613718
Remove dead batched-keep GPU LDE functions
jotabulacios Jun 26, 2026
e5d5952
fix lint
jotabulacios Jun 26, 2026
0e1c694
fix conflicts
jotabulacios Jun 26, 2026
8c87f74
fix conflicts
jotabulacios Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions crypto/math-cuda/kernels/keccak.cu
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,40 @@ extern "C" __global__ void keccak_fri_leaves_ext3(
// children: nodes[parent_begin + n_pairs .. parent_begin + 3 * n_pairs]
// parents: nodes[parent_begin .. parent_begin + n_pairs]
//
// ---------------------------------------------------------------------------
// Row-major base leaf hashing.
//
// Input layout: data[row * m + col] for `num_rows` rows and `m` columns.
// One thread per leaf: thread `tid` hashes the bit-reversed row `br(tid)`,
// i.e. the `m` contiguous u64s starting at data[br(tid) * m], and the digest
// is written by finalize_keccak256 at byte offset tid * 32 of the output.
//
// Memory access: each thread walks its own row sequentially. Threads of the
// same warp read *different* rows, so at a given loop step their addresses
// are whole rows apart — this is not a warp-coalesced access pattern.
//
// Parameters:
//   data              row-major matrix (read only)
//   m                 number of columns, i.e. lanes absorbed per leaf
//   num_rows          number of rows / leaves; threads with tid >= num_rows exit
//   log_num_rows      bit width of the bit reversal (presumably log2(num_rows);
//                     the caller must keep the two consistent)
//   hashed_leaves_out output buffer, indexed at tid * 32 bytes
//
// Launch: 1-D grid/block in x covering at least `num_rows` threads.
// ---------------------------------------------------------------------------
extern "C" __global__ void keccak256_leaves_base_row_major(
    const uint64_t *data,
    uint64_t m,
    uint64_t num_rows,
    uint64_t log_num_rows,
    uint8_t *hashed_leaves_out)
{
    uint64_t tid = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= num_rows) return;

    // A 64-bit shift by 64 is undefined, which is what `64 - log_num_rows`
    // would produce for a single-row input. With zero index bits the only
    // valid row maps to itself, so use 0 directly in that case.
    uint64_t br = (log_num_rows == 0)
                      ? 0
                      : (__brevll(tid) >> (64 - log_num_rows));
    const uint64_t *row = data + br * m;

    // Sponge state, zero-initialised before absorbing.
    uint64_t st[25];
#pragma unroll
    for (int i = 0; i < 25; ++i) st[i] = 0;

    uint32_t rate_pos = 0;
    for (uint64_t c = 0; c < m; ++c) {
        // Each element goes through goldilocks::canonical and a 64-bit byte
        // swap before being absorbed (presumably to match the host-side
        // serialisation of field elements — confirm against the CPU hasher).
        uint64_t canon = goldilocks::canonical(row[c]);
        uint64_t lane = bswap64(canon);
        absorb_lane(st, rate_pos, lane);
    }
    finalize_keccak256(st, rate_pos, hashed_leaves_out + tid * 32);
}

// Each thread hashes one child pair → one parent. Keccak-256 of the
// concatenation of two 32-byte siblings, identical to
// `FieldElementVectorBackend::hash_new_parent` on host.
Expand Down
126 changes: 126 additions & 0 deletions crypto/math-cuda/kernels/ntt.cu
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,129 @@ extern "C" __global__ void ntt_dit_8_levels(uint64_t *x,
// Store back to the remapped row.
x[row] = tile[threadIdx.x];
}

// ============================================================================
// ROW-MAJOR BATCHED KERNELS
//
// Data layout: data[row * m + col] for n rows and m columns.
// threadIdx.x = column index → consecutive threads access consecutive columns
// of the same row → coalesced global memory access.
// Twiddle factors depend only on the butterfly position, not the column →
// one twiddle load is broadcast across the entire warp.
// ============================================================================

// Bit-reverse permute rows in place: swap row `row` with row `br(row)`, where
// `br` reverses the low `log_n` bits of the row index.
//
// Data layout: data[row * m + col] for n rows and m columns.
// Grid: gridDim.x = ceil(m / 256), gridDim.y = min(n, 65535).
// Each thread owns one column (x dimension) and strides over rows by
// gridDim.y, so a capped gridDim.y still covers all n rows.
//
// Parameters:
//   data   row-major matrix, permuted in place
//   n      number of rows
//   log_n  number of index bits to reverse (caller keeps it consistent with n)
//   m      number of columns
extern "C" __global__ void bit_reverse_row_major(uint64_t *data,
                                                 uint64_t n,
                                                 uint64_t log_n,
                                                 uint64_t m)
{
    uint64_t col = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (col >= m) return;

    // With zero index bits every index reverses to 0, so no `row < rev` swap
    // can fire. Returning early also avoids the undefined 64-bit shift by 64
    // that `64 - log_n` would otherwise request.
    if (log_n == 0) return;

    const uint64_t shift = 64 - log_n;
    for (uint64_t row = blockIdx.y; row < n; row += gridDim.y) {
        uint64_t rev = __brevll(row) >> shift;
        // Swap each pair exactly once: only the lower index of the pair acts.
        if (row < rev) {
            uint64_t a = row * m + col;
            uint64_t b = rev * m + col;
            uint64_t tmp = data[a];
            data[a] = data[b];
            data[b] = tmp;
        }
    }
}

// Applies a single DIT butterfly level, in place, to row-major data
// (data[row * m + col], n rows, m columns).
//
// Grid: gridDim.x = ceil(m / blockDim.x),
//       gridDim.y = min(ceil(n/2 / blockDim.y), 65535).
// The x dimension selects the column; the y dimension selects the butterfly
// pair. Pair tiles are visited with a stride of gridDim.y * blockDim.y so a
// capped gridDim.y still reaches all n/2 pairs.
//
// For pair index p at level `level` (half = 2^level):
//   k  = p mod half,  i0 = (p / half) * 2*half + k,  i1 = i0 + half
//   w  = tw[k << (log_n - level - 1)]
//   (data[i0], data[i1]) <- (u + w*v, u - w*v)   per column
// The twiddle index does not involve the column.
extern "C" __global__ void ntt_dit_level_row_major(uint64_t *data,
                                                   const uint64_t *tw,
                                                   uint64_t n,
                                                   uint64_t log_n,
                                                   uint64_t level,
                                                   uint64_t m)
{
    const uint64_t column = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;
    const uint64_t pair_count = n >> 1;
    if (column >= m) return;

    const uint64_t span = 1ULL << level;          // distance between the two rows of a pair
    const uint64_t group_rows = span << 1;        // rows covered by one butterfly group
    const uint64_t tile_stride = (uint64_t)gridDim.y * blockDim.y;

    uint64_t tile_start = blockIdx.y * blockDim.y;
    while (tile_start < pair_count) {
        const uint64_t pair = tile_start + threadIdx.y;
        // Tail tile: threads past the last pair have nothing left to do.
        if (pair >= pair_count) break;

        const uint64_t offset = pair & (span - 1);
        const uint64_t lo_row = (pair >> level) * group_rows + offset;
        const uint64_t hi_row = lo_row + span;

        // Twiddle depends only on the position inside the group.
        uint64_t w = tw[offset << (log_n - level - 1)];

        const uint64_t lo_idx = lo_row * m + column;
        const uint64_t hi_idx = hi_row * m + column;

        uint64_t u = data[lo_idx];
        uint64_t v = mul(w, data[hi_idx]);
        data[lo_idx] = add(u, v);
        data[hi_idx] = sub(u, v);

        tile_start += tile_stride;
    }
}

// Row-wise scaling of a row-major matrix, in place:
//   data[row * m + col] = mul(data[row * m + col], weights[row])
// Every column of a row is multiplied by that row's single weight.
//
// Grid: gridDim.x = ceil(m / 256), gridDim.y = min(n, 65535).
// The x dimension picks the column; rows are walked starting at blockIdx.y
// with a step of gridDim.y, so a capped gridDim.y still covers all n rows.
extern "C" __global__ void pointwise_mul_row_major(uint64_t *data,
                                                   const uint64_t *weights,
                                                   uint64_t n,
                                                   uint64_t m)
{
    const uint64_t column = (uint64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (column < m) {
        uint64_t r = blockIdx.y;
        while (r < n) {
            const uint64_t idx = r * m + column;
            data[idx] = mul(data[idx], weights[r]);
            r += gridDim.y;
        }
    }
}

// ── Row-major → column-major transpose (for GpuLdeBase handle) ───────────────
//
// Converts the row-major LDE output to the column-major layout that downstream
// GPU kernels (DEEP, barycentric) require for the device handle.
//
//   src[r * cols + c]  →  dst[c * out_stride + r]
//
// Grid: gridDim.x = ceil(cols/32), gridDim.y = min(ceil(rows/32), 65535).
// Block: the tile indexing expects blockDim = (MTILE, MTILE).
// Shared memory: one MTILE x (MTILE + 1) tile of u64 per block (static).
// Each block strides over row tiles by gridDim.y * MTILE so all rows are
// covered when rows > 65535*32.
//
// Parameters:
//   src        row-major input, rows * cols elements (read only)
//   dst        output; column c starts at dst[c * out_stride]
//   rows       number of input rows
//   cols       number of input columns
//   out_stride distance, in elements, between consecutive output columns

#define MTILE 32
#define MTILE_P (MTILE + 1)

extern "C" __global__ void matrix_transpose_strided(
    const uint64_t *__restrict__ src,
    uint64_t *__restrict__ dst,
    uint32_t rows,
    uint32_t cols,
    uint64_t out_stride)
{
    // Inner dimension is padded by one (MTILE_P) so the transposed read below
    // does not walk a single shared-memory column with a stride of MTILE.
    __shared__ uint64_t tile[MTILE][MTILE_P];

    // Row-tile bookkeeping is done in 64 bits: with a 32-bit counter,
    // `row_base += gridDim.y * MTILE` can wrap for `rows` close to 2^32 and
    // the loop would revisit tiles instead of terminating.
    const uint64_t row_step = (uint64_t)gridDim.y * MTILE;
    const uint64_t col_base = (uint64_t)blockIdx.x * MTILE;

    // The loop bounds depend only on blockIdx/gridDim, so every thread of a
    // block runs the same number of iterations and reaches both barriers.
    for (uint64_t row_base = (uint64_t)blockIdx.y * MTILE; row_base < rows;
         row_base += row_step) {
        // Load phase: threadIdx.x runs along a source row.
        const uint64_t x = col_base + threadIdx.x;
        const uint64_t y = row_base + threadIdx.y;

        if (x < cols && y < rows)
            tile[threadIdx.y][threadIdx.x] = src[y * cols + x];

        __syncthreads();

        // Store phase: threadIdx.x now runs along a destination column, so the
        // tile is read transposed.
        const uint64_t tx = row_base + threadIdx.x;
        const uint64_t ty = col_base + threadIdx.y;

        if (tx < rows && ty < cols)
            dst[ty * out_stride + tx] = tile[threadIdx.x][threadIdx.y];

        // Keep the next iteration's loads from overwriting a tile that other
        // threads are still reading.
        __syncthreads();
    }
}
12 changes: 12 additions & 0 deletions crypto/math-cuda/src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,14 @@ pub struct Backend {
pub ntt_dit_8_levels_batched: CudaFunction,
pub pointwise_mul_batched: CudaFunction,
pub scalar_mul_batched: CudaFunction,
// row-major NTT kernels
pub bit_reverse_row_major: CudaFunction,
pub ntt_dit_level_row_major: CudaFunction,
pub pointwise_mul_row_major: CudaFunction,
pub matrix_transpose_strided: CudaFunction,

// keccak.ptx
pub keccak256_leaves_base_row_major: CudaFunction,
pub keccak256_leaves_base_batched: CudaFunction,
pub keccak256_leaves_ext3_batched: CudaFunction,
pub keccak_comp_poly_leaves_ext3: CudaFunction,
Expand Down Expand Up @@ -237,6 +243,12 @@ impl Backend {
ntt_dit_8_levels_batched: ntt.load_function("ntt_dit_8_levels_batched")?,
pointwise_mul_batched: ntt.load_function("pointwise_mul_batched")?,
scalar_mul_batched: ntt.load_function("scalar_mul_batched")?,
bit_reverse_row_major: ntt.load_function("bit_reverse_row_major")?,
ntt_dit_level_row_major: ntt.load_function("ntt_dit_level_row_major")?,
pointwise_mul_row_major: ntt.load_function("pointwise_mul_row_major")?,
matrix_transpose_strided: ntt.load_function("matrix_transpose_strided")?,
keccak256_leaves_base_row_major: keccak
.load_function("keccak256_leaves_base_row_major")?,
keccak256_leaves_base_batched: keccak.load_function("keccak256_leaves_base_batched")?,
keccak256_leaves_ext3_batched: keccak.load_function("keccak256_leaves_ext3_batched")?,
keccak_comp_poly_leaves_ext3: keccak.load_function("keccak_comp_poly_leaves_ext3")?,
Expand Down
Loading
Loading