Skip to content

NVFP4: cache GEMM-swizzled weight scale factors across micro-batches#3093

Open
cael-ling wants to merge 7 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-weight-swizzle-cache
Open

NVFP4: cache GEMM-swizzled weight scale factors across micro-batches#3093
cael-ling wants to merge 7 commits into
NVIDIA:mainfrom
cael-ling:feature/nvfp4-weight-swizzle-cache

Conversation

@cael-ling

Copy link
Copy Markdown
Contributor

Description

For block-scaled NVFP4, a cached weight is used in two GEMMs per step — fprop (row-wise scales) and dgrad (column-wise scales) — and each GEMM needs its scale factors in the GEMM-swizzled layout. Today that swizzle is recomputed lazily inside general_gemm on every micro-batch and thrown away, so with N micro-batches the weight scale swizzle runs 2*N times per step even though the weight is quantized only once, which hurts performance. (Activation quantizers already set optimize_for_gemm=True and were pre-swizzled; only the weight was missed.)

This PR sets weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path so the swizzle is done once at quantize time, persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused by every GEMM — 2*N2 swizzles per step.

  • Applied to Linear, LayerNormLinear, LayerNormMLP (fc1 + fc2) and GroupedLinear (per expert).

  • Gated to the cached path (is_first_microbatch is not None) with fsdp_group is None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled scale layout, so pre-swizzling is unsupported there.

  • No-op for recipes whose scales do not require swizzling (e.g. per-tensor FP8).

  • Swizzling is a pure layout permutation, so numerics are unchanged.

  • New tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py: asserts the cached eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) for Linear / LayerNormLinear / GroupedLinear, and that _with_gemm_swizzled_scales is set and persisted on the cached workspace.

  • pytest tests/pytorch/test_numerics.py -k "linear or layernorm or mlp" — no regressions.

  • pytest tests/pytorch/test_grouped_linear.py -k "not grouped_tensor and not fused_path" — no regressions.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…obatches

For block-scaled NVFP4 a cached weight participates in two GEMMs per step:
fprop (rowwise scales) and dgrad (columnwise scales). The GEMM-ready scale
swizzle was recomputed lazily inside every GEMM and discarded, so with N
microbatches the weight scale swizzle ran 2*N times per step even though the
weight is quantized only once.

Because weight RHT is disabled, the weight scales are not swizzled by the
cast-fusion path; with optimize_for_gemm off they also skip the post-quantize
fallback swizzle, so the only swizzle site left for the weight is the lazy one
inside general_gemm (swizzle_scales_for_gemm), which re-runs on every GEMM.
(Activation input/grad_output quantizers already set optimize_for_gemm=True, so
they were pre-swizzled via cast-fusion/fallback; only the weight was missed.)

Set weight_quantizer.optimize_for_gemm=True on the cached, non-FSDP path so the
swizzle is done once at quantize time (via the post-quantize fallback),
persisted on the cached workspace (_with_gemm_swizzled_scales=True), and reused
by every GEMM (swizzle_scales_for_gemm early-returns) -> 2 swizzles per step
instead of 2*N. Applied to Linear, LayerNormLinear, LayerNormMLP (fc1+fc2) and
GroupedLinear (per expert).

Gated to the cached path (is_first_microbatch is not None) with fsdp_group is
None and not is_fsdp2: FSDP/FSDP2 all-gather weights using the un-swizzled
scale layout, so pre-swizzling is unsupported there. No-op for recipes whose
scales do not require swizzling (e.g. per-tensor FP8). Swizzling is a pure
layout permutation, so numerics are unchanged.

Add tests/pytorch/nvfp4/test_nvfp4_weight_swizzle_cache.py verifying the cached
eager-swizzle path matches the lazy-swizzle baseline (fprop + dgrad) for
Linear/LayerNormLinear/GroupedLinear and that the swizzled flag is persisted.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling cael-ling requested a review from ksivaman as a code owner June 5, 2026 14:29
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 5, 2026
@greptile-apps

greptile-apps Bot commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR caches GEMM-swizzled weight scale factors for block-scaled recipes (NVFP4, MXFP8) so the per-weight swizzle is performed once at quantize time rather than lazily inside every GEMM call. The optimization reduces the swizzle cost from 2*N to 2 per training step (where N is the number of micro-batches) on the non-FSDP, cached-weight path.

  • Sets weight_quantizer.optimize_for_gemm = True inside _get_weight_quantizers() for Linear, LayerNormLinear, LayerNormMLP (fc1 + fc2), and GroupedLinear (per expert), gated to not self.primary_weights_in_fp8 so that FSDP2 with quantized-parameter all-gather and direct optimizer updates on FP8 weights are unaffected.
  • Adds a new parametrized test file (test_weight_swizzle_in_layers.py) covering workspace-caching flag propagation, primary-FP8-weight exclusion, and bit-exact numerical parity between cached and uncached paths across all four layer types and both swizzling recipes (MXFP8, NVFP4).

Confidence Score: 5/5

Safe to merge; the change is a pure layout optimization with no impact on numerics, and the gating via primary_weights_in_fp8 correctly preserves the unswizzled layout for FSDP2 all-gather and optimizer-update paths.

The optimization is additive and well-scoped: it sets a flag on an existing quantizer property, is a no-op for recipes that do not swizzle, and the new test suite confirms bit-exact parity between the cached pre-swizzled path and the lazy baseline across all four layer types.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/linear.py Adds optimize_for_gemm = True to the weight quantizer in _get_weight_quantizers() when primary weights are not in FP8; logic is correct and idempotent. Comment has minor typos.
transformer_engine/pytorch/module/layernorm_linear.py Mirrors the linear.py change for the single weight quantizer; correct gating via primary_weights_in_fp8.
transformer_engine/pytorch/module/layernorm_mlp.py Sets optimize_for_gemm = True on both fc1 and fc2 weight quantizers; comment accurately describes the FSDP2/quantized_model_init exception.
transformer_engine/pytorch/module/grouped_linear.py Iterates over all per-expert quantizers and sets optimize_for_gemm = True; consistent with existing internal flag gating.
tests/pytorch/test_weight_swizzle_in_layers.py New test file covering all four layer types and both swizzling recipes with three test scenarios.
qa/L0_pytorch_unittest/test.sh Adds the new test file to the L0 CI suite.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[_get_weight_quantizers called during set_meta_tensor] --> B{fp8 or fp8_calibration?}
    B -- No --> C[Return None quantizer no-op]
    B -- Yes --> D{primary_weights_in_fp8?}
    D -- Yes quantized_model_init / FSDP2 --> E[leave optimize_for_gemm = False]
    D -- No normal cached path --> F[Set optimize_for_gemm = True]
    F --> G[Scales pre-swizzled once at quantize time]
    G --> H{is_first_microbatch?}
    H -- None no caching --> I[Swizzle on every GEMM old lazy path]
    H -- True / False cached path --> J[Workspace cached Swizzled scales reused 2N to 2 swizzles per step]
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A[_get_weight_quantizers called during set_meta_tensor] --> B{fp8 or fp8_calibration?}
    B -- No --> C[Return None quantizer no-op]
    B -- Yes --> D{primary_weights_in_fp8?}
    D -- Yes quantized_model_init / FSDP2 --> E[leave optimize_for_gemm = False]
    D -- No normal cached path --> F[Set optimize_for_gemm = True]
    F --> G[Scales pre-swizzled once at quantize time]
    G --> H{is_first_microbatch?}
    H -- None no caching --> I[Swizzle on every GEMM old lazy path]
    H -- True / False cached path --> J[Workspace cached Swizzled scales reused 2N to 2 swizzles per step]
Loading

Reviews (4): Last reviewed commit: "Merge branch 'main' into feature/nvfp4-w..." | Re-trigger Greptile

Comment on lines +67 to +72
with te.autocast(enabled=True, recipe=recipe):
out = module(x, is_first_microbatch=is_first)
out.sum().backward()
return out.detach().float(), x.grad.detach().float()


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Missing LayerNormMLP test coverage

layernorm_mlp.py is one of four files modified by this PR, yet the test suite parametrizes only over ["Linear", "LayerNormLinear"] for both test_weight_swizzle_cache_numerics and test_lazy_path_not_swizzled. The fc1/fc2 two-quantizer path in LayerNormMLP is structurally different from the single-quantizer modules: it independently gates fc1_weight_quantizer.optimize_for_gemm and fc2_weight_quantizer.optimize_for_gemm using separate cache_name_fc1/cache_name_fc2 variables. If either gating expression were wrong (e.g. swapping fc1/fc2 names), existing tests would not catch it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added LayerNormMLP coverage (fc1+fc2 two-quantizer path) to both parametrized tests.

@vthumbe1503 vthumbe1503 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apart from FSDP2 condition being irrelevant, LGTM

Comment on lines +178 to +185
@pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
def test_lazy_path_not_swizzled(kind):
"""Without weight caching (is_first_microbatch=None) no workspace is created
and the optimization stays off — guards against accidentally always-on."""
torch.manual_seed(0)
device = "cuda"
recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
module = _make_module(kind, 1024, 1024, device)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:

Suggested change
@pytest.mark.parametrize("kind", ["Linear", "LayerNormLinear"])
def test_lazy_path_not_swizzled(kind):
"""Without weight caching (is_first_microbatch=None) no workspace is created
and the optimization stays offguards against accidentally always-on."""
torch.manual_seed(0)
device = "cuda"
recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
module = _make_module(kind, 1024, 1024, device)
@pytest.mark.parametrize("layer_type", ["Linear", "LayerNormLinear"])
def test_lazy_path_not_swizzled(kind):
"""Without weight caching (is_first_microbatch=None) no workspace is created
and the optimization stays offguards against accidentally always-on."""
torch.manual_seed(0)
device = "cuda"
recipe = NVFP4BlockScaling(disable_stochastic_rounding=True)
module = _make_module(layer_type, 1024, 1024, device)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — test_lazy_path_not_swizzled now parametrizes over all four module kinds.

x = x.detach().clone().requires_grad_(True)
module.zero_grad(set_to_none=True) # per-micro-batch grads (no accumulation)
with te.autocast(enabled=True, recipe=recipe):
out = module(x, is_first_microbatch=is_first)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If absence of m_splits argument is the only reason for creating new test for grouped_linear below, then can we add a check on the module in terms of passing m_splits only if module is GroupedLinear, instead of duplicating the test?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — folded GroupedLinear into the parametrized test_weight_swizzle_cache_numerics, passing m_splits only for GroupedLinear (in _step); removed the duplicated grouped-only test.

Comment on lines +1738 to +1749
# Pre-swizzle (and cache) the weight scale factors when the quantized
# weights are cached across microbatches, so the per-GEMM scale swizzle
# (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
# from 2*num_microbatches kernels to 2 per step per expert. Gated to the
# cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled
# scales; see NVFP4Tensor.fsdp_pre_all_gather), so pre-swizzling is
# unsupported there. No-op for non-swizzled recipes (e.g. per-tensor FP8).
if cache_weight and self.fsdp_group is None and not self.is_fsdp2:
for weight_quantizer in weight_quantizers:
if weight_quantizer is not None:
weight_quantizer.optimize_for_gemm = True

@vthumbe1503 vthumbe1503 Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think the comment is relevant In case of FSDP/FSDP2,
For FSDP, The scales are not sharded, and the whole scales are replicated across ranks today. So it doesnt matter if scales are swizzled or not. cc: @denera. Also NVFP4 pre allgather is function specific to FSDP2 not FSDP.
For FSDP2, we havent been caching weights as it causes memory bloating. And weight caching as a mechanism doesnt fit well with fsdp2. This was done for linear and layer_norm_linear but apparently not for grouped_linear in this PR #2805. But fixing that for grouped_linear might be byond scope of this PR. Even if weight caching is still kept as it is, current behavior is to save the entire weight instead of shard in the workspace and so swizzling being present shouldnt cause any issue.

Suggested change
# Pre-swizzle (and cache) the weight scale factors when the quantized
# weights are cached across microbatches, so the per-GEMM scale swizzle
# (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
# from 2*num_microbatches kernels to 2 per step per expert. Gated to the
# cached, non-FSDP path (FSDP/FSDP2 all-gather weights with un-swizzled
# scales; see NVFP4Tensor.fsdp_pre_all_gather), so pre-swizzling is
# unsupported there. No-op for non-swizzled recipes (e.g. per-tensor FP8).
if cache_weight and self.fsdp_group is None and not self.is_fsdp2:
for weight_quantizer in weight_quantizers:
if weight_quantizer is not None:
weight_quantizer.optimize_for_gemm = True
# Pre-swizzle (and cache) the weight scale factors when the quantized
# weights are cached across microbatches, so the per-GEMM scale swizzle
# (fprop rowwise + dgrad columnwise, redone every microbatch) collapses
# from 2*num_microbatches kernels to 2 per step per expert.
# No-op for non-swizzled recipes (e.g. per-tensor FP8).
if cache_weight:
for weight_quantizer in weight_quantizers:
if weight_quantizer is not None:
weight_quantizer.optimize_for_gemm = True

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same applies in other files.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied to all four module files.

cael-ling and others added 2 commits June 17, 2026 00:36
Drop the FSDP/FSDP2 gating on optimize_for_gemm in Linear, LayerNormLinear,
LayerNormMLP and GroupedLinear. FSDP1 replicates (does not shard) the scale
factors, so the swizzle layout is irrelevant there, and weights are not cached
under FSDP2; the guard only added a misleading comment and dead conditions.
Pre-swizzle the weight scales whenever the quantized weight is cached.

Tests:
- Fold the GroupedLinear case into the parametrized
  test_weight_swizzle_cache_numerics by passing m_splits only for
  GroupedLinear, removing the duplicated grouped-only test.
- Add LayerNormMLP coverage (fc1 + fc2 two-quantizer path), generalizing
  the cached-workspace-count assertion per module type.
- Parametrize test_lazy_path_not_swizzled over all four module kinds.

Signed-off-by: Cael Ling <caell@nvidia.com>
@cael-ling

Copy link
Copy Markdown
Contributor Author

Pushed a commit addressing the review: removed the irrelevant FSDP gating across all four modules, merged the GroupedLinear test, and added LayerNormMLP coverage. Please take a look, thanks. @vthumbe1503

@cael-ling cael-ling requested a review from vthumbe1503 June 17, 2026 07:45
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

@vthumbe1503 vthumbe1503 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making the changes. I think that optimize_for_gemm(weight preswizzling) should be enabled for mostly all use-cases except for primary weights in fp8/fp4.

For primary weights in fp8/fp4 it wont work because of

  1. Dequantization needs in optimizer step update which wont work on swizzled weights
  2. FSDP2 allgather is currently supported only for unswizzled weights.

So lets enable it for most use-cases @cael-ling instead of restricting it to weight caching only.
cael-ling#1

…ache

Enable weight swizzling for most cases
@greptile-apps

greptile-apps Bot commented Jun 23, 2026

Copy link
Copy Markdown
Contributor

Want your agent to iterate on Greptile's feedback? Try greploops.

@vthumbe1503

Copy link
Copy Markdown
Collaborator

/te-ci L1 pytorch

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants