Skip to content

Avoid unpickling the extra state when not needed#3123

Open
ptrendx wants to merge 8 commits into
NVIDIA:mainfrom
ptrendx:pr_avoid_unpickle
Open

Avoid unpickling the extra state when not needed#3123
ptrendx wants to merge 8 commits into
NVIDIA:mainfrom
ptrendx:pr_avoid_unpickle

Conversation

@ptrendx

@ptrendx ptrendx commented Jun 12, 2026

Copy link
Copy Markdown
Member

Description

Avoids unpickling of the extra state if the recipe is stateless. Adds a guard prompting user to explicitly allow loading of the checkpoint when the unpickling is necessary.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Avoids unpickling of the stateless recipe extra state
  • Adds a guard and environment variable for the delayed scaling recipes

ptrendx added 2 commits June 12, 2026 05:24
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@greptile-apps

greptile-apps Bot commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces a pickle-safety layer for Transformer Engine's FP8 checkpoint extra state. A new _extra_state.py module inspects pickle payloads without executing them, classifying stateless recipes (Float8CurrentScaling, MXFP8BlockScaling, etc.) so their extra state is skipped entirely on both save and load, while legacy delayed-scaling payloads require an explicit NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1 opt-in.

  • _extra_state.py adds _classify_extra_state_pickle_impl which walks pickle opcodes via pickletools.genops to identify recipe class references and delayed-state keys without deserialising — stateless recipes return IGNORE, delayed-scaling or unknown payloads return UNSAFE_LOAD and raise unless the env var is set.
  • get_extra_state in both base.py and op.py now short-circuits for stateless recipes, avoiding unnecessary serialisation; set_extra_state routes all byte-tensor payloads through should_load_extra_state_pickle before calling pickle.loads.
  • Test files are updated to conditionally set the new env var around load_state_dict for FP8 quantization modes; test_recipe.py adds unit coverage for the classifier itself.

Confidence Score: 4/5

Safe to merge for production code paths; a test helper in test_numerics.py has an incomplete env-var setup that will cause load_state_dict to throw if FP8 delayed-scaling is ever enabled in that helper.

The core pickle-classification logic in _extra_state.py and the guards added to base.py and op.py are sound and well-tested. The test_numerics.py change saves and restores UNSAFE_PICKLE_EXTRA_STATE_ENV around load_state_dict but never sets it to '1', unlike every other test updated in this PR. Because _test_e2e_checkpointing_get_model creates a non-FP8 TransformerLayer today the extra state is always an empty tensor and the missing set is never exercised — but the pattern looks like an oversight that would immediately break if FP8 is introduced into that helper.

tests/pytorch/test_numerics.py — the save/restore wrapper around load_state_dict is missing the conditional os.environ set.

Important Files Changed

Filename Overview
transformer_engine/pytorch/_extra_state.py New module implementing pickle-free classification of extra-state payloads; logic is sound for first-party recipes but DYNAMIC/CustomRecipe handling has known asymmetries (see previous review threads).
transformer_engine/pytorch/module/base.py Adds early-return for stateless recipes in get_extra_state and routes set_extra_state through should_load_extra_state_pickle; save/load asymmetry for DYNAMIC/CustomRecipe remains (noted in previous review).
transformer_engine/pytorch/ops/op.py Applies same stateless-recipe skip and pickle guard as base.py; shares the same DYNAMIC save/load asymmetry.
tests/pytorch/test_numerics.py Save/restore env-var wrapper added around load_state_dict but the conditional set to '1' is missing, making the try/finally a no-op and inconsistent with the patterns used in test_checkpoint.py and test_fusible_ops.py.
tests/pytorch/test_checkpoint.py Correctly conditionally sets UNSAFE_PICKLE_EXTRA_STATE_ENV for fp8 quantization and restores it after load_state_dict.
tests/pytorch/test_fusible_ops.py Correctly conditionally sets UNSAFE_PICKLE_EXTRA_STATE_ENV for fp8/fp8_delayed_scaling quantization modes.
tests/pytorch/test_recipe.py New tests covering stateless ignore, DYNAMIC ignore, and UNSAFE_LOAD/opt-in cases; coverage is comprehensive for first-party recipes.
qa/L0_pytorch_unittest/test.sh Attention tests now run with NVTE_ALLOW_UNSAFE_PICKLE_EXTRA_STATE=1; looks correct.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["set_extra_state(state)"] --> B{state is None\nor numel==0?}
    B -- Yes --> Z[return early]
    B -- No --> C["state_bytes = state.tobytes()"]
    C --> D["should_load_extra_state_pickle(state_bytes)"]
    D --> E["_classify_extra_state_pickle(data)"]
    E --> F["_classify_extra_state_pickle_impl(data)\n(walk opcodes via pickletools.genops)"]
    F --> G{has_recipe_key?}
    G -- No --> H[UNSAFE_LOAD\nlegacy TE 1.x]
    G -- Yes --> I{STATEFUL_FP8_DELAYED_SCALING\nin policies?}
    I -- Yes --> H
    I -- No --> J{has_delayed\n_state_keys?}
    J -- Yes --> H
    J -- No --> K{policies empty?}
    K -- Yes --> H
    K -- No --> L[IGNORE\nstateless or DYNAMIC\nwithout delayed state]
    H --> M{NVTE_ALLOW_UNSAFE\n_PICKLE_EXTRA_STATE=1?}
    M -- Yes --> N["return True → pickle.loads"]
    M -- No --> O[raise RuntimeError\nwith advisory]
    L --> P[return False\nskip unpickling]
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A["set_extra_state(state)"] --> B{state is None\nor numel==0?}
    B -- Yes --> Z[return early]
    B -- No --> C["state_bytes = state.tobytes()"]
    C --> D["should_load_extra_state_pickle(state_bytes)"]
    D --> E["_classify_extra_state_pickle(data)"]
    E --> F["_classify_extra_state_pickle_impl(data)\n(walk opcodes via pickletools.genops)"]
    F --> G{has_recipe_key?}
    G -- No --> H[UNSAFE_LOAD\nlegacy TE 1.x]
    G -- Yes --> I{STATEFUL_FP8_DELAYED_SCALING\nin policies?}
    I -- Yes --> H
    I -- No --> J{has_delayed\n_state_keys?}
    J -- Yes --> H
    J -- No --> K{policies empty?}
    K -- Yes --> H
    K -- No --> L[IGNORE\nstateless or DYNAMIC\nwithout delayed state]
    H --> M{NVTE_ALLOW_UNSAFE\n_PICKLE_EXTRA_STATE=1?}
    M -- Yes --> N["return True → pickle.loads"]
    M -- No --> O[raise RuntimeError\nwith advisory]
    L --> P[return False\nskip unpickling]
Loading

Reviews (6): Last reviewed commit: "Fix tests" | Re-trigger Greptile

Comment thread transformer_engine/pytorch/_extra_state.py Outdated

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this seems like a reasonable fix, although I have some design suggestions and nits. FP8 delayed scaling still has pickling, but at least we can avoid it for more modern recipes.

Comment thread transformer_engine/common/recipe/__init__.py Outdated
Comment thread transformer_engine/pytorch/_extra_state.py Outdated
"""

STATELESS = "stateless"
STATEFUL = "stateful"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may have stateful recipes in the future, but we've learned our lesson not to naively pickle. We should make clear that this particular enum value represents stateful recipes with unsafe pickling.

Suggested change
STATEFUL = "stateful"
STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling"

Other possible names could be STATEFUL_PICKLE or STATEFUL_UNSAFE.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment thread transformer_engine/pytorch/_extra_state.py
@ksivaman

Copy link
Copy Markdown
Member

/te-ci pytorch

ksivaman added 2 commits June 23, 2026 09:13
Signed-off-by: ksivamani <ksivamani@nvidia.com>
@ksivaman

Copy link
Copy Markdown
Member

/te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines +852 to +859
old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV)
try:
block.load_state_dict(loaded_state_dict)
finally:
if old_unsafe_extra_state is None:
os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None)
else:
os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing env-var set makes the save/restore a no-op

The try/finally block saves and unconditionally restores UNSAFE_PICKLE_EXTRA_STATE_ENV, but never actually sets it to "1" before calling load_state_dict. Every other test fixed in this PR (test_checkpoint.py line 136, test_fusible_ops.py line 3222) follows the pattern: save → conditionally set to "1" → try/finally restore. Here the "set" step is absent, so the entire save/restore is a no-op.

_test_e2e_checkpointing_get_model creates a plain TransformerLayer without FP8, so fp8_checkpoint is False and the extra state is an empty tensor today, which avoids the runtime error. If this helper is ever extended with FP8 delayed-scaling (a natural step), load_state_dict will raise a RuntimeError because the env var will never be set.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants