Enable NVFP4 RHT amax for grouped SReLU MLP#3133
Conversation
Signed-off-by: Siddhartha Raman <sraman@nvidia.com>
fa32e3b to
79def34
Compare
Greptile SummaryThis PR extends NVFP4 RHT amax support to the grouped SReLU MLP path by adding
Confidence Score: 2/5Not safe to merge: the test file will fail to import due to a missing indentation on the The tests/pytorch/test_fusible_ops.py — line 3606 has a missing indentation that prevents the file from parsing. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fuser_forward] --> B{use_nvfp4_rht_amax?}
B -- No --> C[grouped_gemm_activation_kernel]
B -- Yes --> D{swiglu or srelu?}
D -- Neither --> C
D -- Yes --> E[grouped_gemm_act_hadamard_kernel available?]
E -- No --> C
E -- Yes --> F{activation_is_srelu?}
F -- Yes --> G[act_func = srelu]
F -- No --> H[act_func = _cudnn_act_func]
G --> I[grouped_gemm_glu_hadamard_wrapper_sm100]
H --> I
I --> J[_group_quantize_with_amax_for_grouped_mlp]
C --> K[grouped_gemm_activation_kernel output]
J --> L[FC2 GEMM]
K --> L
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A[fuser_forward] --> B{use_nvfp4_rht_amax?}
B -- No --> C[grouped_gemm_activation_kernel]
B -- Yes --> D{swiglu or srelu?}
D -- Neither --> C
D -- Yes --> E[grouped_gemm_act_hadamard_kernel available?]
E -- No --> C
E -- Yes --> F{activation_is_srelu?}
F -- Yes --> G[act_func = srelu]
F -- No --> H[act_func = _cudnn_act_func]
G --> I[grouped_gemm_glu_hadamard_wrapper_sm100]
H --> I
I --> J[_group_quantize_with_amax_for_grouped_mlp]
C --> K[grouped_gemm_activation_kernel output]
J --> L[FC2 GEMM]
K --> L
Reviews (3): Last reviewed commit: "Update tests/pytorch/test_fusible_ops.py" | Re-trigger Greptile |
vthumbe1503
left a comment
There was a problem hiding this comment.
LGTM mostly except CUDNN guard update that I think is needed.
| """Fused grouped GEMM activation kernel that also emits NVFP4 RHT amaxes.""" | ||
| try: | ||
| from cudnn import ( | ||
| grouped_gemm_glu_hadamard_wrapper_sm100, |
There was a problem hiding this comment.
Do we need new cudnn version for supporting srelu in this kernel? If so, we should update it.
|
/te-ci pytorch |
Co-authored-by: vthumbe1503 <vthumbe@nvidia.com> Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
| # Loose tols for sanity checking | ||
| tols = {"rtol": 0.125, "atol": 0.25} | ||
| tols = quantization_tols(quantization) |
There was a problem hiding this comment.
The
tols assignment has zero indentation — it is at module scope rather than inside test_grouped_mlp. Python will raise an IndentationError when importing this file, which will break every test in the module. The line needs the same 8-space indent as the surrounding code.
| # Loose tols for sanity checking | |
| tols = {"rtol": 0.125, "atol": 0.25} | |
| tols = quantization_tols(quantization) | |
| # Loose tols for sanity checking | |
| tols = quantization_tols(quantization) |
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: