Avoid unpickling the extra state when not needed#3123
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Greptile SummaryThis PR introduces a pickle-safety layer for Transformer Engine's FP8 checkpoint extra state. A new
Confidence Score: 4/5Safe to merge for production code paths; a test helper in test_numerics.py has an incomplete env-var setup that will cause load_state_dict to throw if FP8 delayed-scaling is ever enabled in that helper. The core pickle-classification logic in _extra_state.py and the guards added to base.py and op.py are sound and well-tested. The test_numerics.py change saves and restores UNSAFE_PICKLE_EXTRA_STATE_ENV around load_state_dict but never sets it to '1', unlike every other test updated in this PR. Because _test_e2e_checkpointing_get_model creates a non-FP8 TransformerLayer today the extra state is always an empty tensor and the missing set is never exercised — but the pattern looks like an oversight that would immediately break if FP8 is introduced into that helper. tests/pytorch/test_numerics.py — the save/restore wrapper around load_state_dict is missing the conditional os.environ set. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["set_extra_state(state)"] --> B{state is None\nor numel==0?}
B -- Yes --> Z[return early]
B -- No --> C["state_bytes = state.tobytes()"]
C --> D["should_load_extra_state_pickle(state_bytes)"]
D --> E["_classify_extra_state_pickle(data)"]
E --> F["_classify_extra_state_pickle_impl(data)\n(walk opcodes via pickletools.genops)"]
F --> G{has_recipe_key?}
G -- No --> H[UNSAFE_LOAD\nlegacy TE 1.x]
G -- Yes --> I{STATEFUL_FP8_DELAYED_SCALING\nin policies?}
I -- Yes --> H
I -- No --> J{has_delayed\n_state_keys?}
J -- Yes --> H
J -- No --> K{policies empty?}
K -- Yes --> H
K -- No --> L[IGNORE\nstateless or DYNAMIC\nwithout delayed state]
H --> M{NVTE_ALLOW_UNSAFE\n_PICKLE_EXTRA_STATE=1?}
M -- Yes --> N["return True → pickle.loads"]
M -- No --> O[raise RuntimeError\nwith advisory]
L --> P[return False\nskip unpickling]
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
A["set_extra_state(state)"] --> B{state is None\nor numel==0?}
B -- Yes --> Z[return early]
B -- No --> C["state_bytes = state.tobytes()"]
C --> D["should_load_extra_state_pickle(state_bytes)"]
D --> E["_classify_extra_state_pickle(data)"]
E --> F["_classify_extra_state_pickle_impl(data)\n(walk opcodes via pickletools.genops)"]
F --> G{has_recipe_key?}
G -- No --> H[UNSAFE_LOAD\nlegacy TE 1.x]
G -- Yes --> I{STATEFUL_FP8_DELAYED_SCALING\nin policies?}
I -- Yes --> H
I -- No --> J{has_delayed\n_state_keys?}
J -- Yes --> H
J -- No --> K{policies empty?}
K -- Yes --> H
K -- No --> L[IGNORE\nstateless or DYNAMIC\nwithout delayed state]
H --> M{NVTE_ALLOW_UNSAFE\n_PICKLE_EXTRA_STATE=1?}
M -- Yes --> N["return True → pickle.loads"]
M -- No --> O[raise RuntimeError\nwith advisory]
L --> P[return False\nskip unpickling]
Reviews (6): Last reviewed commit: "Fix tests" | Re-trigger Greptile |
timmoon10
left a comment
There was a problem hiding this comment.
Overall this seems like a reasonable fix, although I have some design suggestions and nits. FP8 delayed scaling still has pickling, but at least we can avoid it for more modern recipes.
| """ | ||
|
|
||
| STATELESS = "stateless" | ||
| STATEFUL = "stateful" |
There was a problem hiding this comment.
We may have stateful recipes in the future, but we've learned our lesson not to naively pickle. We should make clear that this particular enum value represents stateful recipes with unsafe pickling.
| STATEFUL = "stateful" | |
| STATEFUL_FP8_DELAYED_SCALING = "stateful_fp8_delayed_scaling" |
Other possible names could be STATEFUL_PICKLE or STATEFUL_UNSAFE.
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: ksivamani <ksivamani@nvidia.com>
|
/te-ci pytorch |
Signed-off-by: ksivamani <ksivamani@nvidia.com>
|
/te-ci pytorch |
| old_unsafe_extra_state = os.environ.get(UNSAFE_PICKLE_EXTRA_STATE_ENV) | ||
| try: | ||
| block.load_state_dict(loaded_state_dict) | ||
| finally: | ||
| if old_unsafe_extra_state is None: | ||
| os.environ.pop(UNSAFE_PICKLE_EXTRA_STATE_ENV, None) | ||
| else: | ||
| os.environ[UNSAFE_PICKLE_EXTRA_STATE_ENV] = old_unsafe_extra_state |
There was a problem hiding this comment.
Missing env-var set makes the save/restore a no-op
The try/finally block saves and unconditionally restores UNSAFE_PICKLE_EXTRA_STATE_ENV, but never actually sets it to "1" before calling load_state_dict. Every other test fixed in this PR (test_checkpoint.py line 136, test_fusible_ops.py line 3222) follows the pattern: save → conditionally set to "1" → try/finally restore. Here the "set" step is absent, so the entire save/restore is a no-op.
_test_e2e_checkpointing_get_model creates a plain TransformerLayer without FP8, so fp8_checkpoint is False and the extra state is an empty tensor today, which avoids the runtime error. If this helper is ever extended with FP8 delayed-scaling (a natural step), load_state_dict will raise a RuntimeError because the env var will never be set.
Description
Avoids unpickling of the extra state if the recipe is stateless. Adds a guard prompting user to explicitly allow loading of the checkpoint when the unpickling is necessary.
Type of change
Changes
Please list the changes introduced in this PR: