Skip to content

Enable NVFP4 RHT amax for grouped SReLU MLP#3133

Open
sraman-rgb wants to merge 3 commits into
NVIDIA:mainfrom
sraman-rgb:te-nvfp4-srelu-rht-hadamard
Open

Enable NVFP4 RHT amax for grouped SReLU MLP#3133
sraman-rgb wants to merge 3 commits into
NVIDIA:mainfrom
sraman-rgb:te-nvfp4-srelu-rht-hadamard

Conversation

@sraman-rgb

Copy link
Copy Markdown
Contributor

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sraman-rgb sraman-rgb requested a review from timmoon10 as a code owner June 16, 2026 18:42
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label Jun 16, 2026
Signed-off-by: Siddhartha Raman <sraman@nvidia.com>
@sraman-rgb sraman-rgb force-pushed the te-nvfp4-srelu-rht-hadamard branch from fa32e3b to 79def34 Compare June 16, 2026 18:45
@greptile-apps

greptile-apps Bot commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR extends NVFP4 RHT amax support to the grouped SReLU MLP path by adding grouped_gemm_act_hadamard_kernel to GroupedMLP_CuTeGEMMUnary and threading a new use_fc1_act_hadamard_srelu flag through the forward pass. The test helper test_grouped_mlp is generalised to accept an activation parameter and a new dedicated test test_grouped_mlp_nvfp4_rht_srelu is added.

  • grouped_mlp.py: Renames use_fc1_glu_hadamard to use_fc1_act_hadamard, adds a per-class grouped_gemm_act_hadamard_kernel override in GroupedMLP_CuTeGEMMUnary that reuses grouped_gemm_glu_hadamard_wrapper_sm100 with act_func=\"srelu\", and gates the new SReLU hadamard path on activation_is_srelu.
  • test_fusible_ops.py: Parametrises test_grouped_mlp over \"scaled_swiglu\" / \"scaled_srelu\", updates the reference activation computation, and replaces a hard-coded tolerance dict with quantization_tols(quantization) — but the replacement line is missing its indentation, causing a module-level IndentationError.

Confidence Score: 2/5

Not safe to merge: the test file will fail to import due to a missing indentation on the tols assignment, which breaks every test in the module.

The tols = quantization_tols(quantization) line in test_grouped_mlp has zero indentation instead of eight spaces. Python raises an IndentationError the moment the file is imported, making all tests in test_fusible_ops.py unreachable. The production code in grouped_mlp.py is clean and correct.

tests/pytorch/test_fusible_ops.py — line 3606 has a missing indentation that prevents the file from parsing.

Important Files Changed

Filename Overview
tests/pytorch/test_fusible_ops.py Adds activation parameter and SReLU reference path to test_grouped_mlp, plus a dedicated test_grouped_mlp_nvfp4_rht_srelu helper. A missing indentation on the tols assignment (line 3606) produces an IndentationError at import time, breaking all tests in this file.
transformer_engine/pytorch/ops/fused/grouped_mlp.py Renames use_fc1_glu_hadamard to use_fc1_act_hadamard and adds use_fc1_act_hadamard_srelu flag; adds grouped_gemm_act_hadamard_kernel to GroupedMLP_CuTeGEMMUnary that delegates to grouped_gemm_glu_hadamard_wrapper_sm100 with act_func="srelu". Logic and imports look correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[fuser_forward] --> B{use_nvfp4_rht_amax?}
    B -- No --> C[grouped_gemm_activation_kernel]
    B -- Yes --> D{swiglu or srelu?}
    D -- Neither --> C
    D -- Yes --> E[grouped_gemm_act_hadamard_kernel available?]
    E -- No --> C
    E -- Yes --> F{activation_is_srelu?}
    F -- Yes --> G[act_func = srelu]
    F -- No --> H[act_func = _cudnn_act_func]
    G --> I[grouped_gemm_glu_hadamard_wrapper_sm100]
    H --> I
    I --> J[_group_quantize_with_amax_for_grouped_mlp]
    C --> K[grouped_gemm_activation_kernel output]
    J --> L[FC2 GEMM]
    K --> L
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A[fuser_forward] --> B{use_nvfp4_rht_amax?}
    B -- No --> C[grouped_gemm_activation_kernel]
    B -- Yes --> D{swiglu or srelu?}
    D -- Neither --> C
    D -- Yes --> E[grouped_gemm_act_hadamard_kernel available?]
    E -- No --> C
    E -- Yes --> F{activation_is_srelu?}
    F -- Yes --> G[act_func = srelu]
    F -- No --> H[act_func = _cudnn_act_func]
    G --> I[grouped_gemm_glu_hadamard_wrapper_sm100]
    H --> I
    I --> J[_group_quantize_with_amax_for_grouped_mlp]
    C --> K[grouped_gemm_activation_kernel output]
    J --> L[FC2 GEMM]
    K --> L
Loading

Reviews (3): Last reviewed commit: "Update tests/pytorch/test_fusible_ops.py" | Re-trigger Greptile

Comment thread tests/pytorch/test_fusible_ops.py

@vthumbe1503 vthumbe1503 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM mostly except CUDNN guard update that I think is needed.

Comment thread tests/pytorch/test_fusible_ops.py Outdated
"""Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes."""
try:
from cudnn import (
grouped_gemm_glu_hadamard_wrapper_sm100,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need new cudnn version for supporting srelu in this kernel? If so, we should update it.

@vthumbe1503

Copy link
Copy Markdown
Collaborator

/te-ci pytorch

Co-authored-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
Comment on lines 3605 to +3606
# Loose tols for sanity checking
tols = {"rtol": 0.125, "atol": 0.25}
tols = quantization_tols(quantization)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 The tols assignment has zero indentation — it is at module scope rather than inside test_grouped_mlp. Python will raise an IndentationError when importing this file, which will break every test in the module. The line needs the same 8-space indent as the surrounding code.

Suggested change
# Loose tols for sanity checking
tols = {"rtol": 0.125, "atol": 0.25}
tols = quantization_tols(quantization)
# Loose tols for sanity checking
tols = quantization_tols(quantization)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants