Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/public/playground/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@
}
}
</style>
<script type="module" crossorigin src="./assets/index-BYbIYCc0.js"></script>
<script type="module" crossorigin src="./assets/index-Bl6wtlTo.js"></script>
</head>
<body>
<div class="container">
Expand Down
31 changes: 19 additions & 12 deletions src/benchmark/Benchmark.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,31 @@ export class Benchmark {
const times: number[] = [];
const gpuTimes: number[] = [];

// Preallocate GPU buffers for the target size so iterations measure
// steady-state sort performance (buffer reuse) rather than allocation.
let sorter: BitonicSorter | RadixSorter | null = null;
if (algorithm === 'bitonic') {
if (!this.bitonicSorter) {
this.bitonicSorter = new BitonicSorter(this.context);
}
sorter = this.bitonicSorter;
this.bitonicSorter.preallocate(size);
} else if (algorithm === 'radix') {
if (!this.radixSorter) {
this.radixSorter = new RadixSorter(this.context);
}
sorter = this.radixSorter;
this.radixSorter.preallocate(size);
}

for (let i = 0; i < iterations; i++) {
const data = Benchmark.generateRandomData(size);

if (algorithm === 'js-native') {
const time = this.runNativeSort(data);
times.push(time);
} else if (algorithm === 'bitonic') {
if (!this.bitonicSorter) {
this.bitonicSorter = new BitonicSorter(this.context);
}
const result = await this.bitonicSorter.sort(data);
times.push(result.totalTimeMs);
gpuTimes.push(result.gpuTimeMs);
} else if (algorithm === 'radix') {
if (!this.radixSorter) {
this.radixSorter = new RadixSorter(this.context);
}
const result = await this.radixSorter.sort(data);
} else if (sorter) {
const result = await sorter.sort(data);
times.push(result.totalTimeMs);
gpuTimes.push(result.gpuTimeMs);
}
Expand Down
17 changes: 11 additions & 6 deletions src/shaders/scan.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,28 @@ fn scan_block_sums(

// Add block prefixes to each block's local scan results
// This is the third step of two-level scan
// Each thread handles 2 elements (matching blelloch_scan's 512 elements per workgroup)
@compute @workgroup_size(SCAN_WORKGROUP_SIZE)
fn add_block_prefixes(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>
) {
let tid = local_id.x;
let gid = global_id.x;
let block_id = workgroup_id.x;
let n = scan_uniforms.data_size;

// Get the prefix for this block (sum of all previous blocks)
let block_prefix = block_sums[block_id];

// Add block prefix to each element in this block
let idx = gid;
if (idx < n) {
scan_output[idx] = scan_output[idx] + block_prefix;
// Add block prefix to each element in this block (2 elements per thread)
let block_start = block_id * (SCAN_WORKGROUP_SIZE * 2u);
let idx0 = block_start + tid;
let idx1 = block_start + tid + SCAN_WORKGROUP_SIZE;

if (idx0 < n) {
scan_output[idx0] = scan_output[idx0] + block_prefix;
}
if (idx1 < n) {
scan_output[idx1] = scan_output[idx1] + block_prefix;
}
}
7 changes: 4 additions & 3 deletions src/shared/random.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@ export function fillRandomUint32Array(data: Uint32Array): Uint32Array {
if (typeof crypto !== 'undefined' && typeof crypto.getRandomValues === 'function') {
for (let offset = 0; offset < data.length; offset += MAX_CRYPTO_FILL_U32) {
const chunkLength = Math.min(MAX_CRYPTO_FILL_U32, data.length - offset);
const chunk = new Uint32Array(chunkLength);
crypto.getRandomValues(chunk);
data.set(chunk, offset);
// Fill in-place via subarray view — avoids per-chunk allocation + copy
crypto.getRandomValues(
data.subarray(offset, offset + chunkLength) as Uint32Array<ArrayBuffer>
);
}
return data;
}
Expand Down
78 changes: 45 additions & 33 deletions src/sorting/BitonicSorter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -244,49 +244,61 @@ export class BitonicSorter {
const numWorkgroups = Math.ceil(paddedSize / WORKGROUP_SIZE);
// Safe integer log2 - paddedSize is guaranteed to be power of 2
const numStages = Math.trunc(Math.log2(paddedSize));

// First, do local sort within each workgroup
{
const localPipeline = this.localPipeline;
if (!localPipeline) {
throw new ShaderCompilationError('Local pipeline not initialized');
}

const uniformData = new Uint32Array([0, 0, paddedSize, 0]);
this.device.queue.writeBuffer(uniformBuffer, 0, uniformData);

const commandEncoder = this.device.createCommandEncoder();
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(localPipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(numWorkgroups);
passEncoder.end();
this.device.queue.submit([commandEncoder.finish()]);
}

// Then do global merge stages
// Safe integer log2 - WORKGROUP_SIZE is guaranteed to be power of 2
const localStages = Math.trunc(Math.log2(WORKGROUP_SIZE));

const localPipeline = this.localPipeline;
const globalPipeline = this.globalPipeline;
if (!globalPipeline) {
throw new ShaderCompilationError('Global pipeline not initialized');
if (!localPipeline || !globalPipeline) {
throw new ShaderCompilationError('Sort pipelines not initialized');
}

// Pre-compute all uniform values (local pass + all global passes) into a
// single buffer, then batch every dispatch into one command encoder with
// copyBufferToBuffer updating the uniform between passes. This eliminates
// per-pass queue submissions (can be 100+ for large arrays).
const passes: Array<{ stage: number; passNum: number; isLocal: boolean }> = [
{ stage: 0, passNum: 0, isLocal: true },
];
for (let stage = localStages; stage < numStages; stage++) {
for (let passNum = stage; passNum >= 0; passNum--) {
const uniformData = new Uint32Array([stage, passNum, paddedSize, 0]);
this.device.queue.writeBuffer(uniformBuffer, 0, uniformData);

const commandEncoder = this.device.createCommandEncoder();
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(globalPipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(numWorkgroups);
passEncoder.end();
this.device.queue.submit([commandEncoder.finish()]);
passes.push({ stage, passNum, isLocal: false });
}
}

const uniformData = new Uint32Array(passes.length * 4);
for (let i = 0; i < passes.length; i++) {
const p = passes[i];
uniformData[i * 4] = p.stage;
uniformData[i * 4 + 1] = p.passNum;
uniformData[i * 4 + 2] = paddedSize;
uniformData[i * 4 + 3] = 0;
}

const uniformDataBuffer = bufferScope.track(
this.device.createBuffer({
label: 'bitonic-uniform-data',
size: BufferManager.alignSize(uniformData.byteLength, 4),
usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
})
);
this.device.queue.writeBuffer(uniformDataBuffer, 0, uniformData);

// Single command encoder for all passes — compute passes within an
// encoder are ordered and each sees the writes of previous passes.
const commandEncoder = this.device.createCommandEncoder();
for (let i = 0; i < passes.length; i++) {
// Update uniform for this pass via encoder-level copy (ordered)
commandEncoder.copyBufferToBuffer(uniformDataBuffer, i * 16, uniformBuffer, 0, 16);

const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(passes[i].isLocal ? localPipeline : globalPipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(numWorkgroups);
passEncoder.end();
}
this.device.queue.submit([commandEncoder.finish()]);

// Wait for GPU to finish
await this.device.queue.onSubmittedWorkDone();

Expand Down
4 changes: 3 additions & 1 deletion src/sorting/RadixSorter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,14 @@ export class RadixSorter {

const gpuStartTime = performance.now();

// Reusable zero buffer for histogram clearing (avoids per-pass allocation)
const zeroHistogram = new Uint32Array(histogramSize);

// Perform 8 passes (4 bits each)
for (let pass = 0; pass < NUM_PASSES; pass++) {
const bitOffset = pass * BITS_PER_PASS;

// Clear histogram
const zeroHistogram = new Uint32Array(histogramSize);
this.device.queue.writeBuffer(histogramBuffer, 0, zeroHistogram);

// Update uniforms
Expand Down
Loading
Loading