[PyTorch] Preserve fprop operands for dequantized backward override #3141

Open
negvet wants to merge 2 commits into NVIDIA:main from negvet:fix_dequantized_override_save_original_input

negvet commented Jun 23, 2026

Collaborator

Description

Follow-up to #2644, which introduced NVTE_BACKWARD_OVERRIDE=high_precision|dequantized.

high_precision is intended to use the original, unquantized tensor in backward, while dequantized is intended to use the tensor obtained by dequantizing the forward-quantized operand. However, save_original_input=True could override the dequantized behavior in Linear and GroupedLinear, causing backward to use the original input instead of the fprop-quantized operand.

This PR makes the override semantics explicit:

  • backward_override="high_precision" forces save_original_input=True
  • backward_override="dequantized" forces save_original_input=False
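
The rule stated above can be sketched as a small pure-Python function. This is a minimal illustration, not the code in linear.py or grouped_linear.py: the helper name `resolve_save_original_input` is hypothetical, and the fall-through to the module's constructor value when no override is set is taken from the behavior described in this PR.

```python
from typing import Optional


def resolve_save_original_input(
    backward_override: Optional[str],
    save_original_input: bool,
) -> bool:
    """Hypothetical helper: return the effective save_original_input flag.

    The recipe-level override wins over the module-level setting.
    """
    if backward_override == "high_precision":
        # Backward must see the original, unquantized input.
        return True
    elif backward_override == "dequantized":
        # Backward must see the fprop-quantized operand (dequantized later).
        return False
    # No override: keep whatever the module was constructed with.
    return save_original_input


# The case this PR fixes: the module asks for the original input,
# but the recipe asks for dequantized backward.
print(resolve_save_original_input("dequantized", True))
```

With this ordering, the module-level flag only matters when no override is set, which is the explicit semantics the PR aims for.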

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Evgeny <etsykunov@nvidia.com>
negvet requested a review from ksivaman as a code owner June 23, 2026 13:19

negvet commented Jun 23, 2026

Collaborator Author

cc @zianglih

negvet commented Jun 23, 2026

Collaborator Author

/te-ci L0 L1

greptile-apps Bot commented Jun 23, 2026

Contributor

Greptile Summary

This PR fixes a semantic conflict between backward_override="dequantized" and save_original_input=True in Linear and GroupedLinear modules. Before the fix, a user-specified save_original_input=True on the module would silently override the recipe's dequantized backward behavior, causing backward to use the original (unquantized) input instead of the fprop-quantized operand.

  • linear.py / grouped_linear.py: Adds elif backward_override == "dequantized": save_original_input = False alongside the pre-existing high_precision branch, making both override semantics explicit and authoritative.
  • Tests: Three new parametrized tests verify (a) dequantized correctly ignores save_original_input=True for both Linear and GroupedLinear, and (b) high_precision correctly forces save_original_input=True even when the module was constructed with False.
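
To see why the saved operand matters, here is a toy sketch of the weight gradient computed from each kind of saved input. Everything in it is an assumption made for illustration: `fake_quantize` snaps values to a coarse grid as a stand-in for FP8 quantization, and the single-output "layer" is plain Python, not Transformer Engine code.

```python
def fake_quantize(x, step=0.25):
    """Toy stand-in for FP8 quantization: snap each value to a coarse grid."""
    return [round(v / step) * step for v in x]


def saved_operand(x, save_original_input):
    """Return what forward would keep for backward in this toy model."""
    return list(x) if save_original_input else fake_quantize(x)


def weight_grad(saved_x, grad_out):
    """wgrad for a one-output linear layer: dW[i] = grad_out * x[i]."""
    return [grad_out * v for v in saved_x]


x = [0.1, 0.6, -0.3]
grad_out = 2.0

dw_original = weight_grad(saved_operand(x, True), grad_out)
dw_dequantized = weight_grad(saved_operand(x, False), grad_out)

# The two gradients differ, so silently saving the original input
# changes the numerics that "dequantized" is supposed to produce.
print(dw_original)
print(dw_dequantized)
```

In this model the two gradients differ whenever quantization is lossy, which is why a test can tell the two paths apart by comparing against a save_original_input=False reference.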

Confidence Score: 4/5

Safe to merge; the two-line production change is a targeted guard that correctly forces save_original_input=False when backward_override=dequantized, matching the already-existing pattern for high_precision.

The fix and its test coverage are clean and well-structured. The only gap is that GroupedLinear + high_precision overriding save_original_input=False is tested for Linear but has no equivalent test for GroupedLinear, leaving a small blind spot in case that path regresses.

The test file would benefit from a GroupedLinear counterpart to test_linear_backward_override_high_precision_forces_save_original_input; production files are straightforward and need no further review.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/linear.py | Adds elif backward_override == "dequantized": save_original_input = False in _linear_forward_impl; mirrors the existing high_precision branch and closes the semantic gap. |
| transformer_engine/pytorch/module/grouped_linear.py | Same two-line fix inside _GroupedLinear.forward; correctly placed in the general (non-grouped-tensor) path that actually reads save_original_input. The grouped-tensor fast-path already hardcodes ctx.save_original_input = False and is skipped when backward_override is not None. |
| tests/pytorch/test_backward_override.py | Three new tests covering dequantized override for Linear and GroupedLinear (verifies quantized saved operand + numerical parity with save_original_input=False reference) and high_precision override for Linear (verifies saved operand is a plain torch.Tensor). Mirrors the fix symmetrically. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["forward() called\n(Linear / GroupedLinear)"] --> B{fp8 enabled?}
    B -- No --> C["backward_override = None"]
    B -- Yes --> D["backward_override =\nrecipe.backward_override"]
    C --> E{override value?}
    D --> E
    E -- high_precision --> F["save_original_input = True\n(force original tensor)"]
    E -- dequantized --> G["save_original_input = False\n(NEW: force quantized tensor)"]
    E -- None --> H["save_original_input =\nmodule constructor value"]
    F --> I["Save plain torch.Tensor\n(original input) for backward"]
    G --> J["Save QuantizedTensor\n(rowwise-only FP8) for backward"]
    H --> K{constructor value?}
    K -- True --> I
    K -- False --> J
    I --> L["Backward uses original\nunquantized activations"]
    J --> M["Backward dequantizes\nfprop-quantized operand"]

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True)
assert_close(dx_test, dx_ref, rtol=0, atol=0, check_dtype=True)
for test_dw, ref_dw in zip(dw_test, dw_ref):
    assert_close(test_dw, ref_dw, rtol=0, atol=0, check_dtype=True)

P2 Missing GroupedLinear + high_precision test

There is a symmetric gap in the new test suite: test_linear_backward_override_high_precision_forces_save_original_input verifies that high_precision overrides save_original_input=False for te.Linear, but no equivalent test exists for te.GroupedLinear. The high_precision branch in grouped_linear.py has been in place since #2644 and the lack of coverage means a future regression there would go undetected by this test file.
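
The gap can be made concrete by enumerating the module and override combinations and subtracting the ones the new tests cover. This is only a bookkeeping sketch: the `covered` set is transcribed from the review summary above, not read from tests/pytorch/test_backward_override.py.

```python
from itertools import product

modules = ["Linear", "GroupedLinear"]
overrides = ["high_precision", "dequantized"]

# Combinations exercised by the three new tests, per the review summary.
covered = {
    ("Linear", "dequantized"),
    ("GroupedLinear", "dequantized"),
    ("Linear", "high_precision"),
}

# Any (module, override) pair without a dedicated test.
missing = [case for case in product(modules, overrides) if case not in covered]
print(missing)  # [('GroupedLinear', 'high_precision')]
```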

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
