Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions prover/src/tables/branch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,29 +153,20 @@ impl BranchOperation {

/// Generates the BRANCH trace table from a list of operations.
///
/// Duplicate operations (same pc, offset, register, jalr) are merged into a single row
/// with their multiplicities summed. The table is then padded to the next power of 2.
/// One row per operation with `μ = 1` (the spec types `μ` as a `Bit`). The table
/// is padded to the next power of two.
pub fn generate_branch_trace(
operations: &[BranchOperation],
) -> TraceTable<GoldilocksField, GoldilocksExtension> {
use std::collections::HashMap;

// Deduplicate operations: (pc, offset, register, jalr) -> multiplicity
let mut op_map: HashMap<BranchOperation, u64> = HashMap::new();
for op in operations {
*op_map.entry(op.clone()).or_insert(0) += 1;
}

let unique_ops: Vec<_> = op_map.into_iter().collect();
let num_rows = unique_ops.len().next_power_of_two().max(4);
let num_rows = operations.len().next_power_of_two().max(4);
let mut trace = TraceTable::new_main(
vec![FE::zero(); num_rows * cols::NUM_COLUMNS],
cols::NUM_COLUMNS,
1,
);
let table = &mut trace.main_table;

for (row_idx, (op, multiplicity)) in unique_ops.iter().enumerate() {
for (row_idx, op) in operations.iter().enumerate() {
// Compute next_pc
let next_pc_unmasked = op.compute_next_pc_unmasked();
let next_pc = op.compute_next_pc();
Expand Down Expand Up @@ -209,7 +200,8 @@ pub fn generate_branch_trace(
&[next_pc_low_0, next_pc_low_1],
);
table.set_byte(row_idx, cols::UNMASKED_LOW_BYTE, unmasked_low_byte);
table.set_u64(row_idx, cols::MU, *multiplicity);
// One row per op, so μ is the Bit the spec declares (1 real / 0 padding).
table.set_u64(row_idx, cols::MU, 1);
}

trace
Expand Down
18 changes: 12 additions & 6 deletions prover/src/tables/dvrm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,10 @@ impl DvrmOperation {
pub fn generate_dvrm_trace(
operations: &[(DvrmOperation, bool)],
) -> TraceTable<GoldilocksField, GoldilocksExtension> {
// Deduplicate: (n, d, signed) -> (mu_q, mu_r)
let mut op_map: HashMap<DvrmOperation, DvrmMultiplicities> = HashMap::new();
// Deduplicate: (n, d, signed) -> (mu_q, mu_r).
// Pre-size to the op count (an upper bound on unique ops) to avoid rehashing.
let mut op_map: HashMap<DvrmOperation, DvrmMultiplicities> =
HashMap::with_capacity(operations.len());

for (op, wants_remainder) in operations {
let entry = op_map.entry(op.clone()).or_default();
Expand All @@ -311,8 +313,12 @@ pub fn generate_dvrm_trace(
for (row_idx, (op, multiplicities)) in unique_ops.iter().enumerate() {
let q = op.compute_quotient();
let r = op.compute_remainder();
let n_sub_r = op.n_sub_r();
let abs_r = op.abs_r();
// Derive the rest from the single `r` above instead of the helper methods,
// each of which recomputes the integer division internally (6×/row → 1×).
let sign_r = op.signed && (r >> 63) == 1;
let n_sub_r = op.n.wrapping_sub(r);
let sign_n_sub_r = op.signed && (n_sub_r >> 63) == 1;
let abs_r = DvrmOperation::abs_value(r, sign_r);
let abs_d = op.abs_d();

// Fill n as DWordHL (4 halfwords)
Expand All @@ -339,11 +345,11 @@ pub fn generate_dvrm_trace(
// Fill n_sub_r as DWordHL (4 halfwords)
table.set_dword_hl(row_idx, cols::N_SUB_R_0, n_sub_r);

table.set_bool(row_idx, cols::SIGN_N_SUB_R, op.sign_n_sub_r());
table.set_bool(row_idx, cols::SIGN_N_SUB_R, sign_n_sub_r);
table.set_bool(row_idx, cols::SIGN_N, op.sign_n());
table.set_bool(row_idx, cols::SIGN_D, op.sign_d());
table.set_bool(row_idx, cols::SIGN_Q, op.sign_q());
table.set_bool(row_idx, cols::SIGN_R, op.sign_r());
table.set_bool(row_idx, cols::SIGN_R, sign_r);

// Multiplicities
table.set_u64(row_idx, cols::MU_Q, multiplicities.mu_q);
Expand Down
22 changes: 7 additions & 15 deletions prover/src/tables/lt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,29 +156,20 @@ impl LtOperation {

/// Generates the LT trace table from a list of operations.
///
/// Duplicate operations (same lhs, rhs, signed) are merged into a single row
/// with their multiplicities summed. The table is then padded to the next power of 2.
/// One row per operation with `μ = 1` (the spec types `μ` as a `Bit`). The table
/// is padded to the next power of two.
pub fn generate_lt_trace(
operations: &[LtOperation],
) -> TraceTable<GoldilocksField, GoldilocksExtension> {
use std::collections::HashMap;

// Deduplicate operations: (lhs, rhs, signed) -> multiplicity
let mut op_map: HashMap<LtOperation, u64> = HashMap::new();
for op in operations {
*op_map.entry(op.clone()).or_insert(0) += 1;
}

let unique_ops: Vec<_> = op_map.into_iter().collect();
let num_rows = unique_ops.len().next_power_of_two().max(4);
let num_rows = operations.len().next_power_of_two().max(4);
let mut trace = TraceTable::new_main(
vec![FE::zero(); num_rows * cols::NUM_COLUMNS],
cols::NUM_COLUMNS,
1,
);
let table = &mut trace.main_table;

for (row_idx, (op, multiplicity)) in unique_ops.iter().enumerate() {
for (row_idx, op) in operations.iter().enumerate() {
// Store input columns
table.set_dword_hhw(row_idx, cols::LHS_0, op.lhs);
table.set_dword_hhw(row_idx, cols::RHS_0, op.rhs);
Expand All @@ -205,8 +196,9 @@ pub fn generate_lt_trace(
table.set_bool(row_idx, cols::INVERT, op.invert);
table.set_bool(row_idx, cols::OUT, op.compute_out());

// All LT lookups go through the unified ALU bus → single multiplicity.
table.set_u64(row_idx, cols::MU, *multiplicity);
// All LT lookups go through the unified ALU bus. One row per op, so
// μ is the Bit the spec declares (1 for real rows, 0 for padding).
table.set_u64(row_idx, cols::MU, 1);
}

trace
Expand Down
6 changes: 4 additions & 2 deletions prover/src/tables/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,10 @@ impl MulOperation {
pub fn generate_mul_trace(
operations: &[(MulOperation, bool)],
) -> TraceTable<GoldilocksField, GoldilocksExtension> {
// Deduplicate: (lhs, lhs_signed, rhs, rhs_signed) -> (mu_lo, mu_hi)
let mut op_map: HashMap<MulOperation, MulMultiplicities> = HashMap::new();
// Deduplicate: (lhs, lhs_signed, rhs, rhs_signed) -> (mu_lo, mu_hi).
// Pre-size to the op count (an upper bound on unique ops) to avoid rehashing.
let mut op_map: HashMap<MulOperation, MulMultiplicities> =
HashMap::with_capacity(operations.len());

for (op, wants_hi) in operations {
let entry = op_map.entry(op.clone()).or_default();
Expand Down
144 changes: 96 additions & 48 deletions prover/src/tables/trace_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,71 @@ fn collect_cpu_ops(
///
/// Returns: (memw_ops, load_ops, lt_ops, shift_ops, bitwise_ops, commit_ops, keccak_ops,
/// cpu32_ops, ecsm_ops, ec_scalar_ops, ecdas_ops)
// NOTE(review): the `Returns: (memw_ops, load_ops, ...)` doc lines directly
// above list eleven vectors, which matches `collect_ops_from_cpu` rather than
// the four-tuple returned here — they look like they were separated from that
// function when this helper was inserted; confirm and move them back.
/// Gathers every chip operation that can be derived from a single
/// `CpuOperation` on its own, i.e. without threading memory or register state.
///
/// The returned tuple is, in order:
/// 1. the CPU range-check bitwise lookups (`collect_bitwise_ops` of each op),
/// 2. one CPU32 operation per word (`*W`) instruction,
/// 3. one LT operation per non-word instruction for which `is_lt()` holds,
/// 4. one SHIFT operation per non-word instruction for which `is_shift()` holds.
///
/// With the `parallel` feature each vector is built by a rayon parallel
/// iterator; the collected vectors keep program order, so the output matches
/// the sequential build.
fn collect_state_free_ops(
    cpu_ops: &[CpuOperation],
) -> (
    Vec<BitwiseOperation>,
    Vec<cpu32::Cpu32Operation>,
    Vec<LtOperation>,
    Vec<ShiftOperation>,
) {
    // LT dispatch: word instructions are excluded before `is_lt()` is consulted.
    let to_lt = |cpu_op: &CpuOperation| -> Option<LtOperation> {
        let fields = cpu_op.decode.fields;
        if fields.word_instr || !fields.is_lt() {
            return None;
        }
        Some(LtOperation::new_with_invert(
            cpu_op.rv1,
            cpu_op.arg2,
            fields.alu_signed(),
            fields.alu_signed2_or_invert(),
        ))
    };

    // SHIFT dispatch: same guard shape as LT. `fields.word_instr` is forwarded
    // as the last argument; it is necessarily `false` past the guard.
    let to_shift = |cpu_op: &CpuOperation| -> Option<ShiftOperation> {
        let fields = cpu_op.decode.fields;
        if fields.word_instr || !fields.is_shift() {
            return None;
        }
        Some(ShiftOperation::new(
            cpu_op.rv1,
            cpu_op.arg2,
            fields.alu_signed2_or_invert(),
            fields.alu_signed(),
            fields.word_instr,
        ))
    };

    #[cfg(feature = "parallel")]
    {
        use rayon::prelude::*;

        let range_checks: Vec<BitwiseOperation> = cpu_ops
            .par_iter()
            .flat_map_iter(|cpu_op| cpu_op.collect_bitwise_ops())
            .collect();
        let word_ops: Vec<cpu32::Cpu32Operation> = cpu_ops
            .par_iter()
            .filter(|cpu_op| cpu_op.decode.fields.word_instr)
            .map(build_cpu32_op)
            .collect();
        let lt_ops: Vec<LtOperation> = cpu_ops.par_iter().filter_map(to_lt).collect();
        let shift_ops: Vec<ShiftOperation> = cpu_ops.par_iter().filter_map(to_shift).collect();

        (range_checks, word_ops, lt_ops, shift_ops)
    }
    #[cfg(not(feature = "parallel"))]
    {
        let range_checks: Vec<BitwiseOperation> = cpu_ops
            .iter()
            .flat_map(|cpu_op| cpu_op.collect_bitwise_ops())
            .collect();
        let word_ops: Vec<cpu32::Cpu32Operation> = cpu_ops
            .iter()
            .filter(|cpu_op| cpu_op.decode.fields.word_instr)
            .map(build_cpu32_op)
            .collect();
        let lt_ops: Vec<LtOperation> = cpu_ops.iter().filter_map(to_lt).collect();
        let shift_ops: Vec<ShiftOperation> = cpu_ops.iter().filter_map(to_shift).collect();

        (range_checks, word_ops, lt_ops, shift_ops)
    }
}

#[allow(clippy::type_complexity)]
fn collect_ops_from_cpu(
cpu_ops: &[CpuOperation],
Expand All @@ -373,27 +438,26 @@ fn collect_ops_from_cpu(
Vec<ec_scalar::EcScalarOperation>,
Vec<ecdas::EcdasOperation>,
) {
// State-free chips (CPU range-check bitwise lookups + CPU32/LT/SHIFT dispatch)
// are collected in parallel; the loop below only does the state-dependent work
// (MEMW/register/commit/keccak/ecsm — which thread memory/register state).
let (cpu_bitwise_ops, cpu32_ops, lt_ops, shift_ops) = collect_state_free_ops(cpu_ops);

let mut memw_ops = Vec::with_capacity(cpu_ops.len() * 3);
let mut load_ops = Vec::with_capacity(cpu_ops.len() / 8 + 1);
let mut lt_ops = Vec::with_capacity(cpu_ops.len() / 10 + 1);
let mut shift_ops = Vec::with_capacity(cpu_ops.len() / 10 + 1);
let mut bitwise_ops = Vec::with_capacity(cpu_ops.len() * 4);
let mut commit_ops = Vec::new();
let mut keccak_ops = Vec::new();
let mut cpu32_ops = Vec::new();
let mut ecsm_ops = Vec::new();
let mut ec_scalar_ops = Vec::new();
let mut ecdas_ops = Vec::new();
let mut current_commit_index = 0u32;
let mut commit_ecall_count = 0u32;

for op in cpu_ops {
// Word (`*W`) instructions delegate to the CPU32 table (built in program
// order; its register accesses are still emitted via the shared register
// collector below so the MEMW table balances).
if op.decode.fields.word_instr {
cpu32_ops.push(build_cpu32_op(op));
}
// CPU32 register accesses are still emitted via the shared register
// collector below so the MEMW table balances; the CPU32 op itself is
// built in the state-free parallel pass.

// --- MEMW and LOAD (require state tracking, order matters) ---

Expand Down Expand Up @@ -474,41 +538,13 @@ fn collect_ops_from_cpu(
ec_scalar_ops.extend(ec_scalar_rows);
ecdas_ops.extend(ecdas_rows);
}

// --- ALU chip dispatch (no state tracking) ---
// Word (`*W`) instructions are delegated to CPU32 (which itself drives
// the ALU chips); the main CPU does not send the ALU bus for them, so we
// must not emit chip ops here. CPU32 op-generation is B5b.
let f = op.decode.fields;
if !f.word_instr {
// LT: SLT / BLT / BGE, dispatched on the unified ALU bus. `invert`
// (BGE/BGEU) is applied inside the LT chip (`out = lt XOR invert`).
if f.is_lt() {
lt_ops.push(LtOperation::new_with_invert(
op.rv1,
op.arg2,
f.alu_signed(),
f.alu_signed2_or_invert(),
));
}
// SHIFT: SLL/SRL/SRA. direction = invert bit (0 = left, 1 = right).
// The full arg2 goes on the ALU bus as in2; the chip uses its low
// byte for the (mod 32/64) computation.
if f.is_shift() {
shift_ops.push(ShiftOperation::new(
op.rv1,
op.arg2,
f.alu_signed2_or_invert(),
f.alu_signed(),
f.word_instr,
));
}
}

// Collect CPU range-check bitwise lookups (ARE_BYTES + IS_HALF).
bitwise_ops.extend(op.collect_bitwise_ops());
}

// CPU range-check lookups (ARE_BYTES + IS_HALF) were collected in the
// state-free pass above; merge them in. Order is irrelevant for the bitwise
// multiplicity accumulation.
bitwise_ops.extend(cpu_bitwise_ops);

// Each ecall generates count+1 operations (count real rows + 1 end row)
debug_assert_eq!(
commit_ops.len(),
Expand Down Expand Up @@ -2520,19 +2556,22 @@ struct CollectedOps {
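/// Chunk raw ops and generate one trace table per chunk. When `storage_mode`
/// is `Disk`, each chunk's main table is spilled to mmap as soon as that chunk
/// has been generated.
///
/// NOTE(review): with the `parallel` feature the chunks are generated
/// concurrently, so several not-yet-spilled tables can be resident at once.
/// Peak heap usage is then bounded by the number of in-flight chunks rather
/// than by a single chunk — confirm this is acceptable for `Disk` mode.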
fn chunk_and_generate<T>(
fn chunk_and_generate<T: Sync>(
ops: &[T],
max_rows: usize,
generate: impl Fn(&[T]) -> TraceTable<GoldilocksField, GoldilocksExtension>,
generate: impl Fn(&[T]) -> TraceTable<GoldilocksField, GoldilocksExtension> + Sync,
#[cfg(feature = "disk-spill")] storage_mode: StorageMode,
) -> Result<Vec<TraceTable<GoldilocksField, GoldilocksExtension>>, Error> {
let op_chunks: Vec<&[T]> = if ops.is_empty() {
vec![&[][..]]
} else {
ops.chunks(max_rows).collect()
};
let mut tables = Vec::with_capacity(op_chunks.len());
for chunk in op_chunks {

// Each chunk is independent, so generate them concurrently. `collect` into a
// `Result<Vec<_>>` preserves chunk order, so the output is byte-identical to
// the sequential build.
let gen_one = |chunk: &[T]| -> Result<TraceTable<GoldilocksField, GoldilocksExtension>, Error> {
#[allow(unused_mut)]
let mut t = generate(chunk);
#[cfg(feature = "disk-spill")]
Expand All @@ -2541,9 +2580,18 @@ fn chunk_and_generate<T>(
.spill_to_disk()
.map_err(|e| Error::Prover(format!("disk-spill trace: {e}")))?;
}
tables.push(t);
Ok(t)
};

#[cfg(feature = "parallel")]
{
use rayon::prelude::*;
op_chunks.into_par_iter().map(gen_one).collect()
}
#[cfg(not(feature = "parallel"))]
{
op_chunks.into_iter().map(gen_one).collect()
}
Ok(tables)
}

/// Phase 2: Collect and route all operations from CPU ops.
Expand Down
Loading
Loading