Add NVFP4 per-token quantization recipe#3045
Conversation
Rewrites the grouped multi-tensor cast as a K1 fused amax + K2 fused cast
pair and ships pytest correctness + sweep benches against the per-tensor
RHT+SR production baseline.
* common/cast/.../quantize_nvfp4_per_token_group.cu: K1+K2 fused
grouped kernel, reusing the single-tensor 4-stage TMA pipeline.
* common/gemm/nvfp4_per_token_post_scale.cu: row-wise post-scale
kernel for the cuBLASLT NVFP4 dequantize step (maybe updated due
to 2d quant of W).
* pytorch/csrc/extensions/nvfp4_per_token.cpp + pybind.cpp: new C++
grouped bulk binding and per-token GEMM entry; thin pybind layer.
* pytorch/custom_recipes/{gemm_nvfp4_per_token,
quantization_nvfp4_per_token_group}.py: Python wrappers.
* tests/pytorch/nvfp4/test_nvfp4_per_token{,_group}.py: byte-equal
cast tests + bf16-close GEMM tests.
* tests/pytorch/nvfp4/bench_nvfp4_per_token{,_group}.py: 6x3 sweep
over M in {1024..32768} x K in {2048,4096,8192}, eager + CUDA
Graphs columns, ratio against per-tensor RHT+SR baseline.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
6f17fe4 to
928ab1c
Compare
for more information, see https://pre-commit.ci
…uped) Wire `with_rht` / `random_sign_mask_t` through the per-token K1 (amax) and K2 (encode) kernels for both single-tensor and grouped paths. with_rht=False is byte-equal to the pre-RHT code path; when true, applies a 16-pt RHT on the columnwise direction in both K1 and K2 (rowwise stays raw) with outer amax + inner SF self-consistent. Implementation: per-thread fp32 FHT on CUDA cores, branchless fp32 sign-bit XOR for the +/-1 sign diagonal, 0.25 normalization folded into block_amax / block_scale (bit-exact). Tests cover K1, K2, composite + grouped vs a PyTorch fp32 reference and byte-equality regressions. Benches gain a --rht flag (2-way default, 3-way under --rht). Perf vs prod NVFP4Quantizer(rht+sr), Graph mode, 18 shapes M up to 32K: * single tensor : 0.49x-0.77x (no RHT), 0.59x-0.88x (+RHT) * grouped (N=8) : 0.41x-0.77x (no RHT), 0.50x-0.94x (+RHT) Also drops unused THREADS_X_TR / THREADS_Y_TR (nvcc warning NVIDIA#177-D). Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Add an optional fused-swizzle path to the NVFP4 per-token K2 encode
kernel: when with_swizzle=True the rowwise scale_inv is emitted directly
in the cuBLAS LT 128Mx4K swizzled tile layout, skipping the downstream
nvte_swizzle_scaling_factors launch. The colwise scale_inv stays in the
compact M-major layout (rowwise-only fusion for now).
The new code path is gated by a kWithSwizzle template parameter on
per_token_encode_kernel. The scatter epilogue uses thread mapping
b=tid&3, ty=tid>>2 to give each warp a coalesced 128-byte gmem store,
and packs two K-tiles into one uint64_t SMEM load (2-way bank conflict
instead of 4-way). Pre-existing code path is byte-equal.
with_swizzle is threaded through nvte_nvfp4_per_token_{quantize,encode},
their PyTorch bindings, and the nvfp4_per_token_{quantize,encode} Python
recipes. nvfp4_per_token_gemm takes new a_sf_swizzled / b_sf_swizzled
flags so the caller opts into the fast path per operand (mirrors prod
NVFP4 GEMM's per-operand swizzle).
Add tex.nvfp4_per_token_swizzle_rowwise_sf -- a thin wrapper around
nvte_swizzle_scaling_factors that does one standalone per-operand
swizzle launch. Bench-only; lets --qs attribute swizzle cost separately
from K1+K2 and from cuBLAS LT GEMM.
Bench (bench_nvfp4_per_token.py): add --qs mode (K1+K2 + standalone
swizzle, no GEMM) with two modifiers -- --pair (2 operands, matches one
prod GEMM call's quant+swizzle pipeline) and --fuse (adds a per-token
(fuse) column for the K2-fused path). The existing --swizzle end-to-end
mode also gains the fused-swizzle column. --pair / --fuse auto-imply
--qs to avoid silent fall-through to the default --composite table.
Tests (test_nvfp4_per_token.py): byte-equality of the fused-swizzle
rowwise SF vs a pure-Python permutation reference, byte-equality of all
other outputs (FP4 data, colwise SF, row/col amax) vs with_swizzle=False,
and numerical equivalence of the end-to-end GEMM via both code paths.
Perf at K=N=4096, Graph mode: fused-swizzle path is ~7-35% faster than
the unfused per-token pipeline (--qs) and reaches up to ~2.6x faster
than per-tensor at small M.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
The per-token cuBLASLt NVFP4 path needs a trailing post-scale kernel
(D *= alpha_a[i] * alpha_b[j]) that is HBM-bound on the M*N output. This
patch ships a forked-CUTLASS NVFP4 GEMM whose EVT epilogue folds the
per-row * per-col rescale into the in-TMEM accumulator -- a single launch
with no separate post-scale, no M*N HBM round-trip.
New C-API entry points (transformer_engine/common/gemm/nvfp4_cutlass_gemm.cu):
- nvte_nvfp4_cutlass_gemm: scalar (alpha, beta) NVFP4xNVFP4 -> BF16 GEMM
(CUTLASS analog of the cuBLASLt per-tensor path; used as test ground truth).
- nvte_nvfp4_cutlass_per_token_gemm: same mainloop, EVT epilogue
D[i,j] = bf16(NVFP4_DEQUANT_K * alpha_a[i] * alpha_b[j] * acc).
The outer 1/2688^2 factor (NVFP4 spec) is baked into the EVT explicitly,
matching the value cuBLASLt auto-folds via its amax slot.
Python bindings (tex.nvfp4_cutlass_gemm / tex.nvfp4_cutlass_per_token_gemm)
plus a/b_sf_swizzled flags for apples-to-apples --gemm-only benching.
Numerical correctness (tests/pytorch/nvfp4/test_nvfp4_cutlass_per_token_gemm.py):
- fused EVT == cuBLASLt per-token within bf16 ULP (rtol=2e-2), across
M,N,K = 256..1024.
- fused EVT with unity alphas == nvfp4_cutlass_gemm(alpha=1/2688^2) BIT-EXACT
(sanity check that the EVT tree and the baked constant are both correct).
Bench (tests/pytorch/nvfp4/bench_nvfp4_per_token.py --gemm-only) streamlined
to the only comparison that matters for shipping: ct_fused (per-token CUTLASS
fused) vs pten_gemm (prod per-tensor cuBLASLt), with the cf/pten ratio.
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Extends tests/pytorch/nvfp4/{bench,test}_nvfp4_cutlass_per_token_gemm
with end-to-end forward and backward coverage that aligns the prod
baseline with NVFP4BlockScaling real-ship defaults (input RHT-1D,
weight 2D no-RHT, grad RHT-cols + SR), so per-token (no RHT/SR) is
measured against an actually-shippable prod recipe rather than a
toy quantizer.
bench_nvfp4_per_token.py:
* --e2e-fwd: per-token quant (with_swizzle=True) + fused-EVT CUTLASS
GEMM vs NVFP4Quantizer + general_gemm (the real nn.Linear fwd
dispatch). Quant + GEMM inside the timing loop, N = K. Function
docstring carries an ASCII kernel-pipeline diagram for both paths
(per-call launch budget: per-token ~5 vs prod ~10).
* --e2e-bwd: real prod nn.Linear.bwd lifecycle. Timing loop = 1 x dY
quant + dgrad GEMM + wgrad GEMM; X and W are pre-quantized OUTSIDE
the loop (mirrors prod's reuse of fwd-saved QuantizedTensorStorage,
bwd never re-quantizes). pten side uses RHT-cols + SR grad
quantizer + general_gemm NN (dgrad) / NT (wgrad). Function docstring
carries an ASCII kernel-pipeline diagram (per-step launch budget:
per-token ~4 vs prod ~12).
* --gemm-only: 3-way table adds an lt_post column (cuBLASLt NVFP4 +
bf16 per-row*per-col post-scale, "Route 1") next to the existing
ct_fused fused-EVT path ("Route 2") and the prod pten_gemm
baseline. Headline ratio lp/cf decides whether to dispatch
per-token through cuBLASLt + post_scale or fused EVT; current
data shows ct_fused wins or ties at every shape we care about.
test_nvfp4_cutlass_per_token_gemm.py:
* Layer 2 fwd: per-token quant + fused-EVT GEMM vs BF16 fp32 ground
truth (rel_l2 < 0.30, robust to per-shape noise).
* Layer 3 fwd: dual-SNR table comparing per-token vs prod, both
measured against BF16 ground truth, with a per-token-vs-prod ratio.
* Layer 3 bwd: same dual-SNR pattern for dgrad and wgrad. Prod side
uses real-ship NVFP4BlockScaling grad quantizer (RHT cols + SR);
per-token side has no RHT/SR (numerical-floor comparison).
* Sanity micro-test for weight 2D quant plumbing through general_gemm
(catches breakage cheaper than the broader Layer 3 test).
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Jiaxing Qi <jqi@nvidia.com>
for more information, see https://pre-commit.ci
| DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT); | ||
| constexpr int dshmem_size = buff_size_aligned_in + TMA_SHMEM_ALIGNMENT; // + align pad | ||
|
|
||
| dim3 grid(static_cast<unsigned>(K / CHUNK_DIM_X), static_cast<unsigned>(M / CHUNK_DIM_Y), 1); |
There was a problem hiding this comment.
maybe use DIVUP here to handle the remainder case?
There was a problem hiding this comment.
This fast path has a hard precondition that M and K are exact multiples of CHUNK_DIM (128): validate() does NVTE_CHECK(M % CHUNK_DIM_Y == 0) / NVTE_CHECK(K % CHUNK_DIM_X == 0), and is_supported() returns false unless both hold — so any non-multiple shape is rejected / routed to the generic per-token fallback before it ever reaches this launcher.
| // After all 4 stages, emit one atomicMaxFloat per row slot + one per col slot. | ||
| // | ||
| // kWithRht=true: col-wise amax over RHT-rotated 16-row strips (per-thread | ||
| // FHT with random_sign_mask_t). Row direction never sees RHT. |
There was a problem hiding this comment.
typo: Row direction never sees RHT -> Row direction never uses RHT
| } | ||
| } | ||
| #else | ||
| NVTE_DEVICE_ERROR("Per-token amax kernel requires SM 10.0+ (Blackwell)."); |
There was a problem hiding this comment.
For these quantization kernel, TMA only require SM 9.0+ only. Is there any other constraints that limit to sm 10.0+?
There was a problem hiding this comment.
The CUDA_ARCH >= 1000 guard is intentional but not because of a hardware op in this kernel. Two reasons:
- The shared TE PTX wrappers it calls — cp_async_bulk_tensor_2d_global_to_shared and mbarrier_wait_parity_acquire_cta_shared_cta in util/ptx.cuh — are themselves guarded to >= 1000 and emit NVTE_DEVICE_ERROR below that. They were authored/validated only for the Blackwell path.
- The whole NVFP4 quantize path is host-gated to SM100 anyway (NVTE_ERROR("NVFP4 requires SM100 ...")), since NVFP4 is a Blackwell datatype and the downstream FP4 GEMM that consumes these scales only exists on SM100. So the amax kernel is never launched off <SM100; the per-arch guard just yields a clean error instead of an undefined symbol.
Add NN/NT GEMM layout dispatch so the per-token NVFP4 path covers dgrad and wgrad, and let per-token opt into RHT via NVFP4PerTokenBlockScaling(per_token_rht=...) while SR/2D stay disabled (kernels unimplemented at this commit). Extends the per-token CUTLASS GEMM, the torch NVFP4Quantizer, and the NVFP4Tensor plumbing, plus dgrad/wgrad numerical tests and a fwd+bwd module smoke test. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Thread a Philox rng_state and a kWithSr template flag through the per-token encode kernel (rowwise + colwise) and the nvte_nvfp4_per_token_encode/quantize C-API, mirroring the per-tensor SR path. Drop the SR mutex check in the torch NVFP4Quantizer and build the rng_state when stochastic rounding is requested. Add a per_token_sr recipe flag on NVFP4PerTokenBlockScaling wired through the quantizer factory, plus statistical tests (SR unbiasedness -- lower RMSE than RN when averaged -- and RN-determinism / SR-nondeterminism) folded into test_nvfp4_per_token.py. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Wire with_sr + rng_state through the grouped per-token C-API and cast dispatch, implement the SR FP4 cast in the grouped kernel, and drop the "per-token does not support SR" guard. Also fix two comment typos (sees -> uses) in quantize_nvfp4_per_token.cu per review. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
for more information, see https://pre-commit.ci
Introduce NVTE_NVFP4_PER_TOKEN_WEIGHT_2D (recipe.per_token_weight_2d), default off so the per-token path stays byte-equal. When enabled, only the forward WEIGHT switches to the per-tensor 2D cast (16x16 inner tile + scalar outer amax) re-dressed in per-token tensor layout: the scalar outer amax is broadcast across the per-row/col alpha vectors and the inner SF is the same 16-row-replicated 2D tile, so the existing per-token CUTLASS GEMM consumes it unchanged with no kernel modification. Activation/gradient casts stay per-token 1D. Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
Document the user-facing surface of the NVFP4 per-token recipe and add a runnable single-GPU example so the recipe can be exercised end to end. - docs/api/common.rst: list NVFP4PerTokenBlockScaling in the API reference. - docs/envvars.rst: document the NVTE_NVFP4_* knobs -- per-token activation (NVTE_NVFP4_PER_TOKEN) plus the RHT/SR/weight-2D opt-ins, and the per-tensor disable flags. - docs/features/.../nvfp4.rst: add a "Per-token NVFP4" section explaining the per-row/per-col outer-amax cast, its differences from the per-tensor default (RHT/SR off by default, forced-off knobs, unfused-norm requirement), and how to launch it with Megatron-Core. - recipe/__init__.py: document the per_token_rht/per_token_sr/per_token_weight_2d constructor kwargs and drop the stale "stochastic rounding unsupported" note. - pytorch/fp8.py: re-export NVFP4PerTokenBlockScaling. - examples/pytorch/nvfp4_per_token_megatron: single-GPU MoE example (run + sbatch + job-chain scripts and README) comparing per-token vs per-tensor vs BF16 with identical model/data/seed. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
Greptile SummaryThis PR adds an NVFP4 per-token quantization recipe for model pre-training. The default
Confidence Score: 4/5Safe to merge after addressing the columnwise default inconsistency in the diagnostic split functions. The PR is large but well-structured. The only new finding is a P1 inconsistency in default argument values for the diagnostic split functions nvfp4_per_token_amax and nvfp4_per_token_encode (columnwise=True vs composite's columnwise=False), which can cause a confusing ValueError for callers using the split API with defaults. Core production paths (composite quantize, CUTLASS GEMM dispatch, recipe routing) look correct. transformer_engine/pytorch/custom_recipes/quantization_nvfp4_per_token.py — columnwise default mismatch between split and composite APIs. Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant Recipe as NVFP4PerTokenBlockScaling
participant Quant as NVFP4Quantizer (Python)
participant QCpp as NVFP4Quantizer (C++)
participant K1 as nvte_nvfp4_per_token_amax (K1 kernel)
participant K2 as nvte_nvfp4_per_token_encode (K2 kernel)
participant Gemm as _nvfp4_per_token_gemm
participant CUTLASS as tex.nvfp4_cutlass_per_token_gemm
Recipe->>Quant: "nvfp4_per_token()=True, per_token_weight_2d"
Quant->>QCpp: create_tensor(M,K,rowwise/columnwise)
QCpp-->>Quant: NVFP4Tensor + row_amax(M,) or col_amax(K,)
Note over Quant,K2: Composite quantize path (production)
Quant->>K1: "nvte_nvfp4_per_token_quantize(input, with_swizzle=0)"
K1-->>Quant: inner_sf (E4M3), row_amax / col_amax vector
Quant->>K2: nvte_nvfp4_per_token_encode(inner_sf, row_amax)
K2-->>Quant: NVFP4 data + outer_sf
Note over Gemm,CUTLASS: TN forward GEMM (A rowwise, B columnwise)
Gemm->>Gemm: _nvfp4_per_token_select(A, rowwise)
Gemm->>Gemm: _nvfp4_per_token_select(B, columnwise)
Gemm->>CUTLASS: nvfp4_cutlass_per_token_gemm(A_data, A_sf, A_row_amax, B_data, B_sf, B_col_amax)
CUTLASS-->>Gemm: output (M, N)
Gemm-->>Gemm: reshape to N-D leading dims + optional bias add
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant Recipe as NVFP4PerTokenBlockScaling
participant Quant as NVFP4Quantizer (Python)
participant QCpp as NVFP4Quantizer (C++)
participant K1 as nvte_nvfp4_per_token_amax (K1 kernel)
participant K2 as nvte_nvfp4_per_token_encode (K2 kernel)
participant Gemm as _nvfp4_per_token_gemm
participant CUTLASS as tex.nvfp4_cutlass_per_token_gemm
Recipe->>Quant: "nvfp4_per_token()=True, per_token_weight_2d"
Quant->>QCpp: create_tensor(M,K,rowwise/columnwise)
QCpp-->>Quant: NVFP4Tensor + row_amax(M,) or col_amax(K,)
Note over Quant,K2: Composite quantize path (production)
Quant->>K1: "nvte_nvfp4_per_token_quantize(input, with_swizzle=0)"
K1-->>Quant: inner_sf (E4M3), row_amax / col_amax vector
Quant->>K2: nvte_nvfp4_per_token_encode(inner_sf, row_amax)
K2-->>Quant: NVFP4 data + outer_sf
Note over Gemm,CUTLASS: TN forward GEMM (A rowwise, B columnwise)
Gemm->>Gemm: _nvfp4_per_token_select(A, rowwise)
Gemm->>Gemm: _nvfp4_per_token_select(B, columnwise)
Gemm->>CUTLASS: nvfp4_cutlass_per_token_gemm(A_data, A_sf, A_row_amax, B_data, B_sf, B_col_amax)
CUTLASS-->>Gemm: output (M, N)
Gemm-->>Gemm: reshape to N-D leading dims + optional bias add
Reviews (4): Last reviewed commit: "Batch per-group metadata H2D in NVFP4 gr..." | Re-trigger Greptile |
| # Per-token NVFP4 dispatches to fused EVT GEMM that consumes per-row | ||
| # (M,) and per-col (N,) outer-amax vectors directly. cuBLASLt cannot, | ||
| # so this MUST short-circuit before the row-scaled-or-generic fork. | ||
| if _is_nvfp4_per_token_tensor(A) or _is_nvfp4_per_token_tensor(B): | ||
| if not (_is_nvfp4_per_token_tensor(A) and _is_nvfp4_per_token_tensor(B)): | ||
| raise NotImplementedError( | ||
| "NVFP4 per-token GEMM requires both A and B to be per-token tensors. " | ||
| "Mixing per-token + prod NVFP4 in one GEMM is not supported." | ||
| ) | ||
| out = _nvfp4_per_token_gemm( | ||
| A, | ||
| B, | ||
| transa=transa, | ||
| transb=transb, | ||
| out=out, | ||
| out_dtype=out_dtype, | ||
| bias=bias, | ||
| grad=grad, | ||
| accumulate=accumulate, | ||
| gelu=gelu, | ||
| quantization_params=quantization_params, | ||
| ub=ub, | ||
| extra_output=extra_output, | ||
| ) |
There was a problem hiding this comment.
alpha scalar silently ignored for per-token GEMM
general_gemm validates and stores alpha in kwargs["alpha"], but the per-token short-circuit path dispatches to _nvfp4_per_token_gemm which has no alpha parameter and never forwards the value. The C++ binding nvfp4_cutlass_per_token_gemm also lacks a global scalar alpha argument — only the per-row/per-col alpha_a/alpha_b vectors are supported. For all current TE module call sites alpha=1.0 is the invariant, so numerical output is unaffected today. If a caller ever passes alpha != 1.0 through general_gemm with per-token tensors, the result will be silently wrong instead of raising an error.
| for i, M_i in enumerate(split_sections): | ||
| if M_i <= 0: | ||
| raise ValueError(f"split_sections[{i}] must be > 0, got {M_i}") | ||
| if M_i % _PER_TOKEN_TILE != 0: | ||
| raise ValueError(f"split_sections[{i}] = {M_i} must be a multiple of {_PER_TOKEN_TILE}") |
There was a problem hiding this comment.
Public grouped-quantize API unconditionally rejects 0-token splits
split_sections[i] <= 0 raises ValueError, but in MoE training with dynamic token routing, experts commonly receive zero tokens in a given micro-batch. The general_grouped_gemm per-token loop already handles this by skipping the launch when m_splits[i] == 0, so the GEMM side is fine. If users call this Python wrapper directly (e.g., from bench scripts or custom MoE quantization pipelines), they must pre-filter empty experts. A comment or guard skipping allocation for empty splits would make the API usable in unbalanced-routing scenarios.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Replace the per-expert Python loop for the plain D = bf16(alpha_a * alpha_b * (A @ B^T)) path with a single ptr-array CUTLASS grouped kernel (SM100). The dispatcher in general_grouped_gemm routes to the native kernel when no accumulate/bias/gelu/output-quant is requested, and otherwise falls back to the per-expert loop (NVTE_NVFP4_PER_TOKEN_GROUPED_FALLBACK=1 forces the fallback). The launcher caches the SM count and reuses persistent device scratch + workspace buffers across launches to avoid per-call cudaMalloc/Free and cudaGetDeviceProperties overhead. Parity tests assert the grouped kernel matches the dense per-token GEMM bit-exact per group. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
Add an fp32-output accumulate variant to the dense and grouped per-token NVFP4 CUTLASS kernels. The EVT computes D = beta * C + dW, where beta=1 accumulates the weight gradient in place into the fp32 main_grad buffer (C aliases D) and beta=0 overwrites. This lets te.Linear / GroupedLinear wgrad accumulate straight into main_grad, mirroring the prod per-tensor cuBLAS LT path (C == D in place, beta = accumulate ? 1 : 0; output quantization disables accumulation). Add dense and grouped parity tests covering fp32 overwrite (matches the bf16 path cast to fp32) and bit-exact in-place accumulation. Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com> Co-authored-by: Jiaxing Qi <jqi@nvidia.com> Signed-off-by: Cael Ling <caell@nvidia.com>
The grouped per-token GEMM uploads ~14 small per-group metadata arrays (problem shapes, A/B/C/D strides, SFA/SFB layouts, data/SF/D pointers, alpha_a/alpha_b) to device before each launch. Issuing one cudaMemcpyAsync per array adds a fixed ~15-20us of host-side overhead that dominates the launch-bound regime (small / many-expert MoE). Pack all arrays into a process-persistent pageable host mirror at their 256B-aligned scratch offsets and ship them in a SINGLE H2D copy. The buffer is intentionally pageable: cudaMemcpyAsync from pageable host memory stages the source into the driver before returning, so the mirror is safe to overwrite on the next call even when the host runs ahead of the stream. Gated by NVTE_NVFP4_GROUPED_BATCHED_H2D (default on); set it to 0 to fall back to the per-array copies for A/B measurement in the same build. Pure-GEMM bench (SM100): the native kernel drops a constant ~15-20us per launch, lifting native-vs-fallback speedup from ~1.5-3.5x to ~2.0-4.0x; biggest relative win on small/few-expert shapes. Numerics unchanged: test_nvfp4_cutlass_per_token_gemm.py grouped cases pass bit-exact with the env on and off. Signed-off-by: Cael Ling <caell@nvidia.com>
Description
This PR adds an NVFP4 per-token quantization recipe for model pre-training. The default NVFP4BlockScaling recipe computes a single per-tensor outer amax (s_global) per tensor. The per-token variant instead computes a per-row outer amax (length M) for rowwise data and a per-col outer amax (length K) for columnwise data, giving each token/row its own global scale.
Changes
Ongoing work
The per-token recipe currently targets accuracy evaluation, not optimized production deployment:
2026-06-15: Native CUTLASS grouped per-token NVFP4 GEMM (single ptr-array launch) — done (b0b72cf)2026-06-15: Support fuse_wgrad_accumulation in NVFP4 per-token GEMM — done (c3ce651)Type of change
Checklist: