From 0721ed2544596d905fd90786ca23a88a5e15a703 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Wed, 15 Apr 2026 11:39:19 +0800 Subject: [PATCH 01/16] [Debug] Add AutoswitchGEmm for Debug Precision Tool --- docs/debug/1_getting_started.rst | 48 +- docs/debug/2_config_file_structure.rst | 22 + docs/debug/3_api_features.rst | 1 + docs/debug/autoswitch_gemm_example.yaml | 72 +++ .../debug/features/autoswitch_gemm.py | 585 ++++++++++++++++++ .../debug/pytorch/debug_quantization.py | 87 ++- transformer_engine/pytorch/module/base.py | 2 +- 7 files changed, 800 insertions(+), 17 deletions(-) create mode 100644 docs/debug/autoswitch_gemm_example.yaml create mode 100644 transformer_engine/debug/features/autoswitch_gemm.py diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index cce2616998..ac36acf990 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -149,10 +149,11 @@ Inspecting the logs ------------------- -Let's look at the files with the logs. Two files will be created: +Let's look at the files with the logs. At least two files will be created: 1. debug logs. 2. statistics logs. +3. optional feature-specific logs (for example AutoswitchGemm metrics). Let's look inside them! @@ -214,6 +215,51 @@ The second log file (``nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank- INFO - transformer_layer.self_attention.layernorm_qkv_activation_std iteration=000004 value=0.9996 INFO - transformer_layer.self_attention.layernorm_qkv_activation_l1_norm iteration=000004 value=130776.7969 +AutoswitchGemm quick guide +-------------------------- + +``AutoswitchGemm`` monitors quantization quality and can dynamically switch selected GEMMs +to high precision when thresholds are exceeded. + +Minimal config example: + +.. 
code-block:: yaml + + autoswitch_fc_layers: + enabled: True + layers: + layer_types: [fc1, fc2] + transformer_engine: + AutoswitchGemm: + enabled: True + gemms: [fprop, dgrad, wgrad] + underflow_threshold_pct: 1.0 + mse_threshold: 1.0e-4 + # Needed only if the layer uses fp8 model parameters and + # you want fprop/dgrad to be able to switch to high precision. + allow_fp8_model_params_dequantized_weight: False + freq: 1 + +Behavior summary: + +1. For each ``(layer, gemm)``, AutoswitchGemm tracks the latest tensor metrics and applies + OR logic across monitored tensors: if any tensor breaches thresholds, that GEMM switches. +2. Metrics computed in iteration ``n`` are consumed in iteration ``n`` only. +3. If thresholds are not breached in the current iteration, the GEMM stays quantized. + +When AutoswitchGemm is enabled, an additional directory is created under ``log_dir``: + +``nvdlfw_inspect_autoswitchgemm_logs/nvdlfw_inspect_globalrank-<rank>.log`` + +It contains per-rank, per-iteration metrics such as: + +- ``<layer_name>_<gemm>_<tensor_name>_underflow_pct`` +- ``<layer_name>_<gemm>_<tensor_name>_mse`` +- ``<layer_name>_<gemm>_quantized_enabled`` +- ``<layer_name>_<gemm>_disable_until_iter`` +- ``<layer_name>_<gemm>_switch_blocked_fp8_model_params`` +- ``<layer_name>_<gemm>_fp8_model_params_dequantized_fallback`` + Logging using TensorBoard ------------------------- diff --git a/docs/debug/2_config_file_structure.rst b/docs/debug/2_config_file_structure.rst index 3ade970b57..28da6beab3 100644 --- a/docs/debug/2_config_file_structure.rst +++ b/docs/debug/2_config_file_structure.rst @@ -220,6 +220,28 @@ We can use both structs for tensors and GEMMs. The tensors_struct should be nest tensor_feature_param2: value gemm_feature_param1: value +AutoswitchGemm notes +-------------------- + +``AutoswitchGemm`` supports both global and per-GEMM configuration. + +- Use ``gemms: [...]`` for one shared policy. +- Use ``gemms_struct`` to set per-GEMM thresholds.
+ +If ``tensors``/``tensors_struct`` are omitted, monitored tensors are inferred from GEMMs: + +- ``fprop`` -> ``activation``, ``weight`` +- ``dgrad`` -> ``gradient``, ``weight`` +- ``wgrad`` -> ``activation``, ``gradient`` + +Other important keys and behavior: + +- ``underflow_threshold_pct``: switch to high precision when the underflow percentage exceeds this value (default ``5.0``). +- ``mse_threshold``: switch to high precision when the quantization MSE exceeds this value (default ``1e-4``). +- Metrics are consumed in the same iteration in which they are computed; there is no cross-iteration hold window. +- ``allow_fp8_model_params_dequantized_weight``: allows ``fprop``/``dgrad`` switching + for layers with FP8 model parameters by using dequantized temporary weights. + Enabling or Disabling Sections and Features ------------------------------------------- diff --git a/docs/debug/3_api_features.rst b/docs/debug/3_api_features.rst index a8a644d5b5..1972a3d1d8 100644 --- a/docs/debug/3_api_features.rst +++ b/docs/debug/3_api_features.rst @@ -10,6 +10,7 @@ Debug features .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats .. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats .. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM +.. autoapiclass:: transformer_engine.debug.features.autoswitch_gemm.AutoswitchGemm .. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer .. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling .. 
autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant diff --git a/docs/debug/autoswitch_gemm_example.yaml b/docs/debug/autoswitch_gemm_example.yaml new file mode 100644 index 0000000000..c24462a67e --- /dev/null +++ b/docs/debug/autoswitch_gemm_example.yaml @@ -0,0 +1,72 @@ +# Example config for transformer_engine.debug.features.autoswitch_gemm.AutoswitchGemm +# +# Usage: +# import nvdlfw_inspect.api as debug_api +# debug_api.initialize( +# config_file="docs/debug/autoswitch_gemm_example.yaml", +# feature_dirs=["transformer_engine/debug/features"], +# log_dir="./log", +# ) +# ... +# debug_api.step() # call once per training step + +autoswitch_attention_blocks: + enabled: True + layers: + # Match attention linear layers, e.g. *.qkv / *.proj + layer_name_regex_pattern: ".*(qkv|proj).*" + transformer_engine: + AutoswitchGemm: + enabled: True + + # Optional. If omitted, tensors are inferred from selected gemms: + # fprop -> [activation, weight], dgrad -> [gradient, weight], + # wgrad -> [activation, gradient]. + tensors: [activation, weight, gradient] + + # Per-GEMM switching policy. + gemms_struct: + - gemm: fprop + underflow_threshold_pct: 1.0 + mse_threshold: 1.0e-4 + - gemm: dgrad + underflow_threshold_pct: 1.5 + mse_threshold: 1.5e-4 + - gemm: wgrad + underflow_threshold_pct: 2.0 + mse_threshold: 2.0e-4 + + # For layers with fp8 model parameters: + # - False: keep fprop/dgrad quantized + # - True: allow high-precision switch via temporary dequantized weights + allow_fp8_model_params_dequantized_weight: False + + # Collect metrics every step after warmup. + freq: 1 + start_step: 10 + end_step: 5000 + + +autoswitch_mlp_blocks: + enabled: True + layers: + layer_types: [fc1, fc2] + transformer_engine: + AutoswitchGemm: + enabled: True + + # Simpler global policy (shared by selected GEMMs). + gemms: [fprop, wgrad] + tensors: [activation, weight, gradient] + + underflow_threshold_pct: 3.0 + mse_threshold: 3.0e-4 + + # Example sparse monitoring windows. 
+ freq: 2 + start_end_list: + - [0, 300] + - [800, 3000] + +# Autoswitch per-rank metrics are written to: +# <log_dir>/nvdlfw_inspect_autoswitchgemm_logs/nvdlfw_inspect_globalrank-<rank>.log diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py new file mode 100644 index 0000000000..807947f627 --- /dev/null +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -0,0 +1,585 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""AutoswitchGemm Feature support for nvidia-dlframework-inspect.""" + +import copy +import logging +import os +from typing import Dict, Optional, Set, Tuple + +import torch +import torch.distributed as dist + +import nvdlfw_inspect.api as debug_api +from nvdlfw_inspect.logging import get_logger +from nvdlfw_inspect.registry import Registry, api_method + +from transformer_engine.debug.features.api import TEConfigAPIMapper +from transformer_engine.debug.features.utils import next_enabled_iter + + +class _AutoswitchGemmMetricLogger: + """Writes per-rank autoswitch metrics to a dedicated log file.""" + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if _AutoswitchGemmMetricLogger._initialized: + return + self.root_dir = None + self.log_file = None + self.logger = None + _AutoswitchGemmMetricLogger._initialized = True + + @staticmethod + def _get_rank() -> int: + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() + return 0 + + def _expected_paths(self, root_log_dir: str) -> Tuple[str, str]: + rank = self._get_rank() + root_dir = os.path.join(root_log_dir, "nvdlfw_inspect_autoswitchgemm_logs") + log_file = os.path.join(root_dir, f"nvdlfw_inspect_globalrank-{rank}.log") + return root_dir, log_file + + def initialize(self, root_log_dir: str) -> None: + 
"""Initialize rank-local logger under autoswitch log directory.""" + root_dir, log_file = self._expected_paths(root_log_dir) + os.makedirs(root_dir, exist_ok=True) + + rank = self._get_rank() + logger_name = f"nvdlfw_inspect.autoswitchgemm.rank{rank}" + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + logger.propagate = False + + for handler in list(logger.handlers): + logger.removeHandler(handler) + handler.close() + + file_handler = logging.FileHandler(log_file, mode="a") + file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) + logger.addHandler(file_handler) + + self.root_dir = root_dir + self.log_file = log_file + self.logger = logger + + def ensure_initialized(self, root_log_dir: Optional[str]) -> bool: + """Ensure logger tracks current debug session's root log dir.""" + if not root_log_dir: + return False + expected_root_dir, expected_log_file = self._expected_paths(root_log_dir) + if ( + self.logger is None + or self.root_dir != expected_root_dir + or self.log_file != expected_log_file + or not os.path.isdir(expected_root_dir) + ): + self.initialize(root_log_dir) + return self.logger is not None + + def log_scalar( + self, + layer_name: str, + gemm: str, + metric_name: str, + iteration: int, + value: float, + ) -> None: + """Log metric in LogTensorStats-like `iteration/value` format.""" + if self.logger is None: + return + metric_key = f"{layer_name}_{gemm}_{metric_name}" + self.logger.info( + f"{metric_key} \t\t\t\t iteration={iteration:06d} \t\t\t\t value={value:.8f}" + ) + + +def _get_autoswitch_metric_logger() -> _AutoswitchGemmMetricLogger: + """Get singleton autoswitch metric logger.""" + return _AutoswitchGemmMetricLogger() + + +class _GemmSwitchState: + """Autoswitch state tracked independently for each (layer, gemm).""" + + def __init__(self): + self.disable_until_iter = -1 + self.last_applied_metric_snapshot = None + self.last_reason = "" + + 
+@Registry.register_feature(namespace="transformer_engine") +class AutoswitchGemm(TEConfigAPIMapper): + """ + Dynamically switches selected GEMMs between quantized and high-precision execution. + + The feature continuously monitors quantization quality for selected tensors and, + when quality degrades beyond configured thresholds, temporarily disables quantized + GEMM for the affected operation. + + The decision is made per `(layer_name, gemm)`: + + - `fp8_gemm_enabled(..., gemm="fprop")` controls FPROP GEMM + - `fp8_gemm_enabled(..., gemm="dgrad")` controls DGRAD GEMM + - `fp8_gemm_enabled(..., gemm="wgrad")` controls WGRAD GEMM + + The API name `fp8_gemm_enabled` is kept for backward compatibility with the + debug API; the switch applies to all quantized formats supported by TE. + When multiple tensors are monitored for a GEMM, their metrics are aggregated + with OR semantics: if any monitored tensor breaches thresholds, the GEMM + switches to high precision. + + Parameters + ---------- + + gemms / gemms_struct: List[str] + GEMMs to control: + + - fprop + - dgrad + - wgrad + + tensors / tensors_struct: Optional[List[str]] + Tensors to monitor: + + - activation + - weight + - gradient + + If omitted, tensors are inferred from selected GEMMs: + + - fprop -> activation, weight + - dgrad -> gradient, weight + - wgrad -> activation, gradient + + underflow_threshold_pct: float, default = 5.0 + Trigger switch to high precision if underflow percentage exceeds this value. + + mse_threshold: float, default = 1e-4 + Trigger switch to high precision if quantization MSE exceeds this value. + + The switch decision is same-iteration only: + metrics computed at iteration `n` are consumed in iteration `n`. + There is no cross-iteration hold window. 
+ + allow_fp8_model_params_dequantized_weight: bool, default = False + If True, allows `fprop`/`dgrad` to switch to high precision even when + fp8 model parameters are enabled by using a temporary dequantized weight + tensor for GEMM execution. + If False, `fprop`/`dgrad` stay quantized for such layers. + + freq/start_step/end_step/start_end_list: Optional + Sampling controls for tensor inspection calls. + + Example + ------- + .. code-block:: yaml + + example_autoswitch_gemm: + enabled: True + layers: + layer_types: [qkv] + transformer_engine: + AutoswitchGemm: + enabled: True + gemms: [fprop, dgrad, wgrad] + underflow_threshold_pct: 3.0 + mse_threshold: 1e-4 + # decision is computed and consumed in the same iteration + """ + + _GEMM_TO_TENSORS = { + "fprop": {"activation", "weight"}, + "dgrad": {"gradient", "weight"}, + "wgrad": {"activation", "gradient"}, + } + + # Mirrors DebugQuantizer's internal mapping. + _TENSOR_TO_GEMMS = { + "weight": ("fprop", "dgrad"), + "activation": ("fprop", "wgrad"), + "gradient": ("dgrad", "wgrad"), + "output": ("fprop", None), + "wgrad": ("wgrad", None), + "dgrad": ("dgrad", None), + } + + _DEFAULT_UNDERFLOW_THRESHOLD_PCT = 5.0 + _DEFAULT_MSE_THRESHOLD = 1e-4 + _DEFAULT_ALLOW_FP8_MODEL_PARAMS_DEQUANTIZED_WEIGHT = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._gemm_state: Dict[Tuple[str, str], _GemmSwitchState] = {} + self._latest_metrics: Dict[Tuple[str, str], Dict[str, float | int | str]] = {} + self._layer_has_fp8_model_params: Dict[str, bool] = {} + + def parse_config_and_api(self, config, **kwargs): + """ + Parse config for GEMM-routing and tensor-inspection APIs. + + Unlike the default TEConfigAPIMapper behavior, this implementation supports + tensor inspection even when `tensors` is omitted by inferring monitored + tensors from selected GEMMs. 
+ """ + processed_config = None + config_copy = copy.deepcopy(config) + + gemm = kwargs.get("gemm", None) + tensor_name = kwargs.get("tensor_name", None) + + if gemm is not None and tensor_name is None: + processed_config = self._process_transformer_engine_config(config_copy, **kwargs) + elif tensor_name is not None: + if "tensors" in config_copy or "tensors_struct" in config_copy: + processed_config = self._process_tensor_config(config_copy, tensor_name) + else: + monitored_tensors = self._infer_monitored_tensors(config_copy) + if tensor_name not in monitored_tensors: + return False, None + processed_config = config_copy + processed_config["tensor"] = tensor_name + + if not processed_config: + return False, None + + if "enabled" in processed_config: + processed_config.pop("enabled") + + return True, processed_config + + def _infer_monitored_tensors(self, config: Dict) -> Set[str]: + """Infer tensors to inspect from configured GEMMs.""" + configured_gemms = self._extract_configured_gemms(config) + if not configured_gemms: + configured_gemms = set(self._GEMM_TO_TENSORS.keys()) + + tensors = set() + for gemm in configured_gemms: + self._validate_gemm(gemm) + tensors.update(self._GEMM_TO_TENSORS[gemm]) + return tensors + + @staticmethod + def _extract_configured_gemms(config: Dict) -> Set[str]: + """Extract GEMM names from config keys `gemm`, `gemms`, and `gemms_struct`.""" + gemms = set() + if "gemm" in config: + gemms.add(config["gemm"]) + if "gemms" in config: + gemms.update(config["gemms"]) + if "gemms_struct" in config: + for cfg in config["gemms_struct"]: + if "gemm" in cfg: + gemms.add(cfg["gemm"]) + return gemms + + @staticmethod + def _config_float(config: Dict, key: str, default: Optional[float]) -> Optional[float]: + """Read optional float value from config.""" + value = config.get(key, default) + if value is None: + return None + return float(value) + + @staticmethod + def _config_bool(config: Dict, key: str, default: bool) -> bool: + """Read bool value 
from config.""" + value = config.get(key, default) + if isinstance(value, str): + return value.strip().lower() in ("1", "true", "yes", "on") + return bool(value) + + @staticmethod + def _get_root_log_dir() -> Optional[str]: + """Best-effort retrieval of nvdlfw_inspect root log directory.""" + try: + root_log_dir = getattr(get_logger(), "root_log_dir", None) + except Exception: # pylint: disable=broad-except + return None + return root_log_dir + + def _get_metrics_logger(self) -> Optional[_AutoswitchGemmMetricLogger]: + """Return initialized autoswitch metric logger if log dir is available.""" + metric_logger = _get_autoswitch_metric_logger() + if metric_logger.ensure_initialized(self._get_root_log_dir()): + return metric_logger + return None + + def _get_or_create_state(self, layer_name: str, gemm: str) -> _GemmSwitchState: + key = (layer_name, gemm) + if key not in self._gemm_state: + self._gemm_state[key] = _GemmSwitchState() + return self._gemm_state[key] + + def _update_metric( + self, + layer_name: str, + gemm: str, + iteration: int, + tensor_name: str, + underflow_pct: float, + mse: float, + ) -> None: + """Store the latest quality metric for a `(layer, gemm)` pair.""" + metric_logger = self._get_metrics_logger() + if metric_logger is not None: + metric_logger.log_scalar( + layer_name, gemm, f"{tensor_name}_underflow_pct", iteration, underflow_pct + ) + metric_logger.log_scalar(layer_name, gemm, f"{tensor_name}_mse", iteration, mse) + + key = (layer_name, gemm) + entry = self._latest_metrics.get(key) + + if entry is None or int(entry["iteration"]) < iteration: + self._latest_metrics[key] = { + "iteration": iteration, + "underflow_pct": underflow_pct, + "mse": mse, + "tensor_name": tensor_name, + } + return + + if int(entry["iteration"]) == iteration: + if underflow_pct >= float(entry["underflow_pct"]): + entry["underflow_pct"] = underflow_pct + entry["tensor_name"] = tensor_name + entry["mse"] = max(float(entry["mse"]), mse) + + @staticmethod + def 
_dequantize_like( + quantized_tensor, + dtype: torch.dtype, + shape: torch.Size, + ) -> Optional[torch.Tensor]: + """Best-effort dequantization helper used for quality metrics.""" + if quantized_tensor is None or not hasattr(quantized_tensor, "dequantize"): + return None + + try: + dequantized = quantized_tensor.dequantize(dtype=dtype) + except TypeError: + dequantized = quantized_tensor.dequantize() + if dequantized.dtype != dtype: + dequantized = dequantized.to(dtype) + + if dequantized.shape != shape: + expected_numel = 1 + for dim in shape: + expected_numel *= int(dim) + if dequantized.numel() != expected_numel: + return None + dequantized = dequantized.reshape(shape) + return dequantized + + @staticmethod + def _compute_metrics( + tensor: Optional[torch.Tensor], + quantized_tensor, + ) -> Optional[Tuple[float, float]]: + """Compute underflow percentage and MSE for one tensor.""" + if tensor is None or tensor.numel() == 0: + return None + + if not tensor.is_floating_point(): + return None + + dequantized = AutoswitchGemm._dequantize_like(quantized_tensor, tensor.dtype, tensor.shape) + if dequantized is None: + return None + + tensor_fp32 = tensor.float() + dequantized_fp32 = dequantized.float() + + underflow_count = torch.count_nonzero((tensor_fp32 != 0) & (dequantized_fp32 == 0)) + underflow_pct = (underflow_count.float() * 100.0 / tensor_fp32.numel()).item() + + mse = torch.mean((tensor_fp32 - dequantized_fp32) ** 2).item() + return underflow_pct, mse + + def _consume_new_metric_and_maybe_arm_switch( + self, + layer_name: str, + gemm: str, + iteration: int, + config: Dict, + state: _GemmSwitchState, + ) -> None: + """Consume current-iteration metrics and arm switch for this iteration only.""" + metric = self._latest_metrics.get((layer_name, gemm)) + if metric is None: + return + + metric_iter = int(metric["iteration"]) + if metric_iter != iteration: + # Autoswitch consumes metrics only in the same iteration they were produced. 
+ return + + metric_snapshot = ( + metric_iter, + float(metric["underflow_pct"]), + float(metric["mse"]), + str(metric["tensor_name"]), + ) + if metric_snapshot == state.last_applied_metric_snapshot: + return + state.last_applied_metric_snapshot = metric_snapshot + + underflow_threshold = self._config_float( + config, "underflow_threshold_pct", self._DEFAULT_UNDERFLOW_THRESHOLD_PCT + ) + mse_threshold = self._config_float(config, "mse_threshold", self._DEFAULT_MSE_THRESHOLD) + + reasons = [] + metric_underflow = float(metric["underflow_pct"]) + metric_mse = float(metric["mse"]) + + if underflow_threshold is not None and metric_underflow > underflow_threshold: + reasons.append( + f"underflow={metric_underflow:.4f}% > threshold={underflow_threshold:.4f}%" + ) + if mse_threshold is not None and metric_mse > mse_threshold: + reasons.append(f"mse={metric_mse:.6e} > threshold={mse_threshold:.6e}") + + if not reasons: + return + + state.disable_until_iter = iteration + state.last_reason = "; ".join(reasons) + + debug_api.log_message( + f"Feature={self.__class__.__name__}: switch {gemm} to high precision in" + f" iter={iteration}. Triggered by {metric['tensor_name']} at iter={metric_iter}:" + f" {state.last_reason}", + layer_name, + extra_cachable_args=(gemm, "switch"), + ) + + @api_method + def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): + """Decide whether selected GEMM should run quantized (True) or high precision (False).""" + state = self._get_or_create_state(layer_name, gemm) + metric_logger = self._get_metrics_logger() + + fp8_model_params_layer = self._layer_has_fp8_model_params.get(layer_name, False) + allow_fp8_model_params_fallback = self._config_bool( + config, + "allow_fp8_model_params_dequantized_weight", + self._DEFAULT_ALLOW_FP8_MODEL_PARAMS_DEQUANTIZED_WEIGHT, + ) + + # With fp8 model parameters enabled, fprop/dgrad can switch to high precision + # only when dequantized fallback is explicitly enabled in config. 
+ if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and not allow_fp8_model_params_fallback: + state.disable_until_iter = -1 + if metric_logger is not None: + metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) + metric_logger.log_scalar( + layer_name, gemm, "switch_blocked_fp8_model_params", iteration, 1.0 + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}: skip switch for {gemm} at" + f" iter={iteration} because fp8 model parameters are enabled.", + layer_name, + extra_cachable_args=(gemm, "skip_fp8_model_params"), + ) + return True, iteration + 1 + + if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and allow_fp8_model_params_fallback: + if metric_logger is not None: + metric_logger.log_scalar( + layer_name, gemm, "fp8_model_params_dequantized_fallback", iteration, 1.0 + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}: {gemm} allows fp8-model-params" + " dequantized-weight fallback.", + layer_name, + extra_cachable_args=(gemm, "fp8_model_params_dequantized_fallback"), + ) + + self._consume_new_metric_and_maybe_arm_switch(layer_name, gemm, iteration, config, state) + + if iteration <= state.disable_until_iter: + if metric_logger is not None: + metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 0.0) + metric_logger.log_scalar( + layer_name, gemm, "disable_until_iter", iteration, float(state.disable_until_iter) + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}: {gemm} forced high precision at" + f" iter={iteration} (disable_until={state.disable_until_iter}).", + layer_name, + extra_cachable_args=(gemm, "high_precision"), + ) + return False, iteration + 1 + + if metric_logger is not None: + metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) + return True, iteration + 1 + + @api_method + def inspect_tensor_enabled( + self, + config: Dict, + layer_name: str, + tensor_name: str, + iteration: int, + ): # pylint: 
disable=unused-argument + """Enable metric collection according to the standard freq/start/end controls.""" + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + return run_current, next_iter + + @api_method + def inspect_tensor( + self, + config: Dict, + layer_name: str, + tensor_name: str, + iteration: int, + tp_group: torch.distributed.ProcessGroup, # pylint: disable=unused-argument + tensor: Optional[torch.Tensor], + rowwise_quantized_tensor: Optional[torch.Tensor] = None, + columnwise_quantized_tensor: Optional[torch.Tensor] = None, + quantizer=None, # pylint: disable=unused-argument + tp_size: int = 1, # pylint: disable=unused-argument + ): + """Collect quantization quality metrics for autoswitch decisions.""" + if tensor_name == "weight" and tensor is None: + # Weight tensor unavailable in high precision indicates fp8 model params. + self._layer_has_fp8_model_params[layer_name] = True + + _ = config + gemms = self._TENSOR_TO_GEMMS.get(tensor_name, (None, None)) + + rowwise_gemm, columnwise_gemm = gemms + if rowwise_gemm is not None: + metrics = self._compute_metrics(tensor, rowwise_quantized_tensor) + if metrics is not None: + self._update_metric( + layer_name, rowwise_gemm, iteration, tensor_name, metrics[0], metrics[1] + ) + + if columnwise_gemm is not None: + metrics = self._compute_metrics(tensor, columnwise_quantized_tensor) + if metrics is not None: + self._update_metric( + layer_name, columnwise_gemm, iteration, tensor_name, metrics[0], metrics[1] + ) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index ed5fdd4660..7d52f3a875 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -320,14 +320,21 @@ def quantize( self.parent_quantizer.set_usage(rowwise=True) 
rowwise_gemm_tensor, columnwise_gemm_tensor = None, None - if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + parent_has_quantized_usage = ( + self.parent_quantizer is not None + and (self.parent_quantizer.rowwise_usage or self.parent_quantizer.columnwise_usage) + ) + if ( + STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan] + and parent_has_quantized_usage + ): quantized_tensor = self.parent_quantizer(tensor) # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized, # one tensor with columnwise=True and rowwise=True is computed # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. - if self.rowwise_tensor_plan == STANDARD_QUANTIZE: + if self.rowwise_tensor_plan == STANDARD_QUANTIZE and self.parent_quantizer.rowwise_usage: rowwise_gemm_tensor = quantized_tensor - if self.columnwise_tensor_plan == STANDARD_QUANTIZE: + if self.columnwise_tensor_plan == STANDARD_QUANTIZE and self.parent_quantizer.columnwise_usage: columnwise_gemm_tensor = quantized_tensor # 2. modify_tensor() is called, if it is used. @@ -562,25 +569,56 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None): if not self.output_tensor: self._update_parent_quantizer_usage() - def wrap_quantized_tensor(self, tensor: QuantizedTensor): + def wrap_quantized_tensor( + self, tensor: QuantizedTensor, dtype: Optional[torch.dtype] = None + ): """ Wraps the quantized tensor with the debug quantizer. It is used for weight tensors when fp8 model parameters are enabled. """ + if API_CALL_MODIFY in (self.rowwise_tensor_plan, self.columnwise_tensor_plan): + raise AssertionError( + "[NVTORCH INSPECT ERROR] Weight tensor with fp8 model parameters enabled cannot" + " be modified by modify_tensor()." 
+ ) + + dequantized_weight = None - assert ( + def _get_dequantized_weight(): + nonlocal dequantized_weight + if dequantized_weight is None: + output_dtype = dtype if dtype is not None else tensor.dtype + try: + dequantized_weight = tensor.dequantize(dtype=output_dtype) + except TypeError: + dequantized_weight = tensor.dequantize() + if dequantized_weight.dtype != output_dtype: + dequantized_weight = dequantized_weight.to(output_dtype) + return dequantized_weight + + if ( self.rowwise_tensor_plan == STANDARD_QUANTIZE and self.columnwise_tensor_plan == STANDARD_QUANTIZE - ), ( - "[NVTORCH INSPECT ERROR] Weight tensor with fp8 model parameters enabled cannot be" - " modified by any feature." - ) + ): + rowwise_tensor = tensor + columnwise_tensor = tensor + inspect_source = None + else: + rowwise_tensor = ( + tensor if self.rowwise_tensor_plan == STANDARD_QUANTIZE else _get_dequantized_weight() + ) + columnwise_tensor = ( + tensor + if self.columnwise_tensor_plan == STANDARD_QUANTIZE + else _get_dequantized_weight() + ) + inspect_source = _get_dequantized_weight() - self._call_inspect_tensor_api(None, tensor, tensor) + self._call_inspect_tensor_api(inspect_source, rowwise_tensor, columnwise_tensor) return DebugQuantizedTensor( - rowwise_gemm_tensor=tensor, - columnwise_gemm_tensor=tensor, + rowwise_gemm_tensor=rowwise_tensor, + columnwise_gemm_tensor=columnwise_tensor, quantizer=self, layer_name=self.layer_name, tensor_name=self.tensor_name, @@ -676,7 +714,8 @@ def size(self, *args): def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None): """Update usage of the tensor.""" - if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor: + same_storage = self.rowwise_gemm_tensor is self.columnwise_gemm_tensor + if not same_storage: # If the same object is used both for rowwise and columnwise gemms, # there is no benefit in erasing the usage of one of them. # And there are scenarios when not deleting the usage of one of them is needed. 
@@ -687,9 +726,27 @@ def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None self.columnwise_gemm_tensor = None if isinstance(self.rowwise_gemm_tensor, QuantizedTensor): - self.rowwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage) + if same_storage: + rowwise_rowwise_usage = rowwise_usage + rowwise_columnwise_usage = columnwise_usage + else: + # Keep rowwise storage focused on rowwise path. + rowwise_rowwise_usage = rowwise_usage + rowwise_columnwise_usage = False if columnwise_usage is not None else None + self.rowwise_gemm_tensor.update_usage( + rowwise_rowwise_usage, rowwise_columnwise_usage + ) if isinstance(self.columnwise_gemm_tensor, QuantizedTensor): - self.columnwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage) + if same_storage: + columnwise_rowwise_usage = rowwise_usage + columnwise_columnwise_usage = columnwise_usage + else: + # Keep columnwise storage focused on columnwise path. + columnwise_rowwise_usage = False if rowwise_usage is not None else None + columnwise_columnwise_usage = columnwise_usage + self.columnwise_gemm_tensor.update_usage( + columnwise_rowwise_usage, columnwise_columnwise_usage + ) if rowwise_usage and self.rowwise_gemm_tensor is None: raise RuntimeError( diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 1b237ece29..3363578de6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1442,7 +1442,7 @@ def get_weight_workspace( ) if isinstance(quantizer, DebugQuantizer): - tensor = quantizer.wrap_quantized_tensor(tensor) + tensor = quantizer.wrap_quantized_tensor(tensor, dtype=workspace_dtype) return tensor From e0d16645a8623a1e370fb3c1ac138711321952a7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 15 Apr 2026 03:41:10 +0000 Subject: [PATCH 02/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci --- .../debug/features/autoswitch_gemm.py | 18 ++++++++++--- .../debug/pytorch/debug_quantization.py | 27 ++++++++++--------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index 807947f627..f247807d1c 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -482,7 +482,11 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): # With fp8 model parameters enabled, fprop/dgrad can switch to high precision # only when dequantized fallback is explicitly enabled in config. - if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and not allow_fp8_model_params_fallback: + if ( + gemm in {"fprop", "dgrad"} + and fp8_model_params_layer + and not allow_fp8_model_params_fallback + ): state.disable_until_iter = -1 if metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) @@ -497,7 +501,11 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): ) return True, iteration + 1 - if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and allow_fp8_model_params_fallback: + if ( + gemm in {"fprop", "dgrad"} + and fp8_model_params_layer + and allow_fp8_model_params_fallback + ): if metric_logger is not None: metric_logger.log_scalar( layer_name, gemm, "fp8_model_params_dequantized_fallback", iteration, 1.0 @@ -515,7 +523,11 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): if metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 0.0) metric_logger.log_scalar( - layer_name, gemm, "disable_until_iter", iteration, float(state.disable_until_iter) + layer_name, + gemm, + "disable_until_iter", + iteration, + float(state.disable_until_iter), ) debug_api.log_message( 
f"Feature={self.__class__.__name__}: {gemm} forced high precision at" diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 7d52f3a875..3f499a02f9 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -320,9 +320,8 @@ def quantize( self.parent_quantizer.set_usage(rowwise=True) rowwise_gemm_tensor, columnwise_gemm_tensor = None, None - parent_has_quantized_usage = ( - self.parent_quantizer is not None - and (self.parent_quantizer.rowwise_usage or self.parent_quantizer.columnwise_usage) + parent_has_quantized_usage = self.parent_quantizer is not None and ( + self.parent_quantizer.rowwise_usage or self.parent_quantizer.columnwise_usage ) if ( STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan] @@ -332,9 +331,15 @@ def quantize( # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized, # one tensor with columnwise=True and rowwise=True is computed # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. - if self.rowwise_tensor_plan == STANDARD_QUANTIZE and self.parent_quantizer.rowwise_usage: + if ( + self.rowwise_tensor_plan == STANDARD_QUANTIZE + and self.parent_quantizer.rowwise_usage + ): rowwise_gemm_tensor = quantized_tensor - if self.columnwise_tensor_plan == STANDARD_QUANTIZE and self.parent_quantizer.columnwise_usage: + if ( + self.columnwise_tensor_plan == STANDARD_QUANTIZE + and self.parent_quantizer.columnwise_usage + ): columnwise_gemm_tensor = quantized_tensor # 2. modify_tensor() is called, if it is used. 
@@ -569,9 +574,7 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None): if not self.output_tensor: self._update_parent_quantizer_usage() - def wrap_quantized_tensor( - self, tensor: QuantizedTensor, dtype: Optional[torch.dtype] = None - ): + def wrap_quantized_tensor(self, tensor: QuantizedTensor, dtype: Optional[torch.dtype] = None): """ Wraps the quantized tensor with the debug quantizer. It is used for weight tensors when fp8 model parameters are enabled. @@ -605,7 +608,9 @@ def _get_dequantized_weight(): inspect_source = None else: rowwise_tensor = ( - tensor if self.rowwise_tensor_plan == STANDARD_QUANTIZE else _get_dequantized_weight() + tensor + if self.rowwise_tensor_plan == STANDARD_QUANTIZE + else _get_dequantized_weight() ) columnwise_tensor = ( tensor @@ -733,9 +738,7 @@ def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None # Keep rowwise storage focused on rowwise path. rowwise_rowwise_usage = rowwise_usage rowwise_columnwise_usage = False if columnwise_usage is not None else None - self.rowwise_gemm_tensor.update_usage( - rowwise_rowwise_usage, rowwise_columnwise_usage - ) + self.rowwise_gemm_tensor.update_usage(rowwise_rowwise_usage, rowwise_columnwise_usage) if isinstance(self.columnwise_gemm_tensor, QuantizedTensor): if same_storage: columnwise_rowwise_usage = rowwise_usage From 2907537d4bd069369434b8d9b270691fdb7651f5 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Wed, 15 Apr 2026 18:27:16 +0800 Subject: [PATCH 03/16] apply resolve_gemm_inputs_after_sampling before gemm --- .../debug/features/autoswitch_gemm.py | 36 ++-- .../debug/pytorch/gemm_runtime_hooks.py | 183 ++++++++++++++++++ transformer_engine/pytorch/module/linear.py | 35 ++++ 3 files changed, 238 insertions(+), 16 deletions(-) create mode 100644 transformer_engine/debug/pytorch/gemm_runtime_hooks.py diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index 
f247807d1c..b4a05662dd 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -169,9 +169,10 @@ class AutoswitchGemm(TEConfigAPIMapper): mse_threshold: float, default = 1e-4 Trigger switch to high precision if quantization MSE exceeds this value. - The switch decision is same-iteration only: - metrics computed at iteration `n` are consumed in iteration `n`. - There is no cross-iteration hold window. + The switch decision is same-iteration: + metrics computed at iteration `n` are consumed in iteration `n` + after all GEMM input tensors are prepared. + The switch is applied for one iteration. allow_fp8_model_params_dequantized_weight: bool, default = False If True, allows `fprop`/`dgrad` to switch to high precision even when @@ -417,14 +418,14 @@ def _consume_new_metric_and_maybe_arm_switch( config: Dict, state: _GemmSwitchState, ) -> None: - """Consume current-iteration metrics and arm switch for this iteration only.""" + """Consume current-iteration metrics and arm switch for one iteration.""" metric = self._latest_metrics.get((layer_name, gemm)) if metric is None: return metric_iter = int(metric["iteration"]) if metric_iter != iteration: - # Autoswitch consumes metrics only in the same iteration they were produced. + # Autoswitch consumes metrics only in the iteration they were produced. return metric_snapshot = ( @@ -461,14 +462,21 @@ def _consume_new_metric_and_maybe_arm_switch( debug_api.log_message( f"Feature={self.__class__.__name__}: switch {gemm} to high precision in" - f" iter={iteration}. Triggered by {metric['tensor_name']} at iter={metric_iter}:" + f" iter={iteration}. 
Triggered by {metric['tensor_name']} sampled at iter={metric_iter}:" f" {state.last_reason}", layer_name, extra_cachable_args=(gemm, "switch"), ) @api_method - def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): + def fp8_gemm_enabled( + self, + config, + layer_name: str, + gemm: str, + iteration: int, + final_decision: bool = False, + ): """Decide whether selected GEMM should run quantized (True) or high precision (False).""" state = self._get_or_create_state(layer_name, gemm) metric_logger = self._get_metrics_logger() @@ -488,7 +496,7 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): and not allow_fp8_model_params_fallback ): state.disable_until_iter = -1 - if metric_logger is not None: + if final_decision and metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) metric_logger.log_scalar( layer_name, gemm, "switch_blocked_fp8_model_params", iteration, 1.0 @@ -501,12 +509,8 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): ) return True, iteration + 1 - if ( - gemm in {"fprop", "dgrad"} - and fp8_model_params_layer - and allow_fp8_model_params_fallback - ): - if metric_logger is not None: + if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and allow_fp8_model_params_fallback: + if final_decision and metric_logger is not None: metric_logger.log_scalar( layer_name, gemm, "fp8_model_params_dequantized_fallback", iteration, 1.0 ) @@ -520,7 +524,7 @@ def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int): self._consume_new_metric_and_maybe_arm_switch(layer_name, gemm, iteration, config, state) if iteration <= state.disable_until_iter: - if metric_logger is not None: + if final_decision and metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 0.0) metric_logger.log_scalar( layer_name, @@ -537,7 +541,7 @@ def fp8_gemm_enabled(self, config, 
layer_name: str, gemm: str, iteration: int): ) return False, iteration + 1 - if metric_logger is not None: + if final_decision and metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) return True, iteration + 1 diff --git a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py new file mode 100644 index 0000000000..3d4ecf9866 --- /dev/null +++ b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py @@ -0,0 +1,183 @@ +"""Runtime GEMM hooks used by AutoswitchGemm.""" + +from __future__ import annotations + +from typing import Any, Optional + +import torch + +from transformer_engine.debug.pytorch.debug_state import TEDebugState +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer +from transformer_engine.pytorch.utils import cast_if_needed + +_AUTOSWITCH_FEATURE_NAME = "AutoswitchGemm" +_AUTOSWITCH_ENABLED_CACHE = {} + + +def _is_fp8_debug_quantizer(quantizer: Optional[Quantizer]) -> bool: + """Return True for DebugQuantizer objects wrapping an FP8/NVFP4 quantizer.""" + return ( + quantizer is not None + and quantizer.__class__.__name__ == "DebugQuantizer" + and getattr(quantizer, "parent_quantizer", None) is not None + ) + + +def _feature_block_enabled(feature_config: Any) -> bool: + """Return whether an Autoswitch feature block is enabled.""" + if isinstance(feature_config, dict): + return bool(feature_config.get("enabled", True)) + if isinstance(feature_config, bool): + return feature_config + return feature_config is not None + + +def _contains_enabled_autoswitch(config: Any, visited: Optional[set] = None) -> bool: + """Recursively check whether config contains enabled AutoswitchGemm feature.""" + if visited is None: + visited = set() + obj_id = id(config) + if obj_id in visited: + return False + visited.add(obj_id) + + if isinstance(config, dict): + for key, value in config.items(): + if key == 
_AUTOSWITCH_FEATURE_NAME and _feature_block_enabled(value): + return True + for value in config.values(): + if _contains_enabled_autoswitch(value, visited): + return True + return False + + if isinstance(config, (list, tuple, set)): + for item in config: + if _contains_enabled_autoswitch(item, visited): + return True + return False + + return False + + +def _autoswitch_feature_enabled() -> bool: + """Best-effort detection for whether AutoswitchGemm is enabled in debug config.""" + try: + import nvdlfw_inspect.api as debug_api + except ImportError: + return False + + manager = getattr(debug_api, "DEBUG_MANAGER", None) + if manager is None: + return False + + manager_id = id(manager) + cached = _AUTOSWITCH_ENABLED_CACHE.get(manager_id) + if cached is not None: + return cached + + candidate_configs = [] + for attr in ( + "config", + "_config", + "debug_config", + "_debug_config", + "user_config", + "_user_config", + "raw_config", + "_raw_config", + ): + value = getattr(manager, attr, None) + if value is not None: + candidate_configs.append(value) + + for attr_name, value in getattr(manager, "__dict__", {}).items(): + if "config" in attr_name.lower() and value is not None: + candidate_configs.append(value) + + if not candidate_configs: + # Keep previous behavior if manager internals cannot be introspected. 
+ _AUTOSWITCH_ENABLED_CACHE[manager_id] = True + return True + + enabled = any(_contains_enabled_autoswitch(config) for config in candidate_configs) + _AUTOSWITCH_ENABLED_CACHE[manager_id] = enabled + return enabled + + +def should_resolve_inputs_after_sampling( + lhs_quantizer: Optional[Quantizer], + rhs_quantizer: Optional[Quantizer], +) -> bool: + """Return True when runtime GEMM decision path should be applied.""" + if not (_is_fp8_debug_quantizer(lhs_quantizer) or _is_fp8_debug_quantizer(rhs_quantizer)): + return False + return _autoswitch_feature_enabled() + + +def _to_high_precision_gemm_input(tensor, dtype: torch.dtype): + """Convert GEMM input to high precision tensor if needed.""" + if hasattr(tensor, "get_tensor") and hasattr(tensor, "rowwise_gemm_tensor"): + rowwise_tensor = _to_high_precision_gemm_input(tensor.get_tensor(False), dtype) + columnwise_src = tensor.get_tensor(True) + if columnwise_src is tensor.get_tensor(False): + columnwise_tensor = rowwise_tensor + else: + columnwise_tensor = _to_high_precision_gemm_input(columnwise_src, dtype) + tensor.rowwise_gemm_tensor = rowwise_tensor + tensor.columnwise_gemm_tensor = columnwise_tensor + return tensor + + if dtype is None: + dtype = getattr(tensor, "dtype", None) + if isinstance(tensor, QuantizedTensorStorage): + if dtype is None: + return tensor.dequantize() + try: + return tensor.dequantize(dtype=dtype) + except TypeError: + return cast_if_needed(tensor.dequantize(), dtype) + if dtype is None: + return tensor + return cast_if_needed(tensor, dtype) + + +def resolve_gemm_inputs_after_sampling( + gemm_name: str, + lhs, + rhs, + lhs_quantizer: Optional[Quantizer], + rhs_quantizer: Optional[Quantizer], + target_dtype: torch.dtype, +): + """ + Make post-sampling GEMM precision decision and enforce OR logic across inputs. + + If any sampled input for this GEMM triggers high precision, both GEMM inputs are + converted to high precision tensors before kernel launch. 
+ """ + layer_name = ( + getattr(lhs_quantizer, "layer_name", None) or getattr(rhs_quantizer, "layer_name", None) + ) + if layer_name is None: + return lhs, rhs + + try: + import nvdlfw_inspect.api as debug_api + except ImportError: + return lhs, rhs + + iteration = TEDebugState.get_iteration() + enabled_ret = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=layer_name, + gemm=gemm_name, + iteration=iteration, + final_decision=True, + ) + quantized_enabled = enabled_ret[0] if isinstance(enabled_ret, tuple) else enabled_ret + if quantized_enabled: + return lhs, rhs + + return ( + _to_high_precision_gemm_input(lhs, target_dtype), + _to_high_precision_gemm_input(rhs, target_dtype), + ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 8510f6cf8f..1dadb57f39 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -73,6 +73,10 @@ mark_not_offload, mark_activation_offload, ) +from ...debug.pytorch.gemm_runtime_hooks import ( + resolve_gemm_inputs_after_sampling, + should_resolve_inputs_after_sampling, +) from ...debug.pytorch.debug_state import TEDebugState __all__ = ["Linear"] @@ -335,6 +339,15 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ + if debug and should_resolve_inputs_after_sampling(weight_quantizer, input_quantizer): + weightmat, inputmat_total = resolve_gemm_inputs_after_sampling( + "fprop", + weightmat, + inputmat_total, + weight_quantizer, + input_quantizer, + activation_dtype, + ) nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( weightmat, @@ -760,6 +773,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weight_for_dgrad = weight if isinstance(weight_for_dgrad, QuantizedTensorStorage): weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + if ctx.debug and 
should_resolve_inputs_after_sampling( + ctx.weight_quantizer, ctx.grad_output_quantizer + ): + weight_for_dgrad, grad_output = resolve_gemm_inputs_after_sampling( + "dgrad", + weight_for_dgrad, + grad_output, + ctx.weight_quantizer, + ctx.grad_output_quantizer, + ctx.activation_dtype, + ) gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, @@ -920,6 +944,17 @@ def wgrad_gemm( some advanced communication/compute overlapping. """ + if ctx.debug and should_resolve_inputs_after_sampling( + ctx.input_quantizer, ctx.grad_output_quantizer + ): + x, dy = resolve_gemm_inputs_after_sampling( + "wgrad", + x, + dy, + ctx.input_quantizer, + ctx.grad_output_quantizer, + ctx.activation_dtype, + ) nvtx_range_push(f"{nvtx_label}.wgrad_gemm") dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs) nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") From 13e14faea0d5a63b8f0edfcf4ff0ea47d5bbe747 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Mon, 20 Apr 2026 17:11:23 +0800 Subject: [PATCH 04/16] Loosen the autogemm condition --- .../debug/features/autoswitch_gemm.py | 2 +- .../debug/pytorch/gemm_runtime_hooks.py | 91 +------------------ 2 files changed, 3 insertions(+), 90 deletions(-) diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index b4a05662dd..962fdb30a8 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -66,7 +66,7 @@ def initialize(self, root_log_dir: str) -> None: logger.removeHandler(handler) handler.close() - file_handler = logging.FileHandler(log_file, mode="a") + file_handler = logging.FileHandler(log_file, mode="w") file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) logger.addHandler(file_handler) diff --git a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py index 3d4ecf9866..0a0c1d1068 100644 --- 
a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py +++ b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Optional import torch @@ -10,10 +10,6 @@ from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer from transformer_engine.pytorch.utils import cast_if_needed -_AUTOSWITCH_FEATURE_NAME = "AutoswitchGemm" -_AUTOSWITCH_ENABLED_CACHE = {} - - def _is_fp8_debug_quantizer(quantizer: Optional[Quantizer]) -> bool: """Return True for DebugQuantizer objects wrapping an FP8/NVFP4 quantizer.""" return ( @@ -23,95 +19,12 @@ def _is_fp8_debug_quantizer(quantizer: Optional[Quantizer]) -> bool: ) -def _feature_block_enabled(feature_config: Any) -> bool: - """Return whether an Autoswitch feature block is enabled.""" - if isinstance(feature_config, dict): - return bool(feature_config.get("enabled", True)) - if isinstance(feature_config, bool): - return feature_config - return feature_config is not None - - -def _contains_enabled_autoswitch(config: Any, visited: Optional[set] = None) -> bool: - """Recursively check whether config contains enabled AutoswitchGemm feature.""" - if visited is None: - visited = set() - obj_id = id(config) - if obj_id in visited: - return False - visited.add(obj_id) - - if isinstance(config, dict): - for key, value in config.items(): - if key == _AUTOSWITCH_FEATURE_NAME and _feature_block_enabled(value): - return True - for value in config.values(): - if _contains_enabled_autoswitch(value, visited): - return True - return False - - if isinstance(config, (list, tuple, set)): - for item in config: - if _contains_enabled_autoswitch(item, visited): - return True - return False - - return False - - -def _autoswitch_feature_enabled() -> bool: - """Best-effort detection for whether AutoswitchGemm is enabled in debug config.""" - try: - import nvdlfw_inspect.api as debug_api - except ImportError: - return False - 
- manager = getattr(debug_api, "DEBUG_MANAGER", None) - if manager is None: - return False - - manager_id = id(manager) - cached = _AUTOSWITCH_ENABLED_CACHE.get(manager_id) - if cached is not None: - return cached - - candidate_configs = [] - for attr in ( - "config", - "_config", - "debug_config", - "_debug_config", - "user_config", - "_user_config", - "raw_config", - "_raw_config", - ): - value = getattr(manager, attr, None) - if value is not None: - candidate_configs.append(value) - - for attr_name, value in getattr(manager, "__dict__", {}).items(): - if "config" in attr_name.lower() and value is not None: - candidate_configs.append(value) - - if not candidate_configs: - # Keep previous behavior if manager internals cannot be introspected. - _AUTOSWITCH_ENABLED_CACHE[manager_id] = True - return True - - enabled = any(_contains_enabled_autoswitch(config) for config in candidate_configs) - _AUTOSWITCH_ENABLED_CACHE[manager_id] = enabled - return enabled - - def should_resolve_inputs_after_sampling( lhs_quantizer: Optional[Quantizer], rhs_quantizer: Optional[Quantizer], ) -> bool: """Return True when runtime GEMM decision path should be applied.""" - if not (_is_fp8_debug_quantizer(lhs_quantizer) or _is_fp8_debug_quantizer(rhs_quantizer)): - return False - return _autoswitch_feature_enabled() + return _is_fp8_debug_quantizer(lhs_quantizer) or _is_fp8_debug_quantizer(rhs_quantizer) def _to_high_precision_gemm_input(tensor, dtype: torch.dtype): From a1c4977ba9eeed235a5e00b1fa9152207d8335f4 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Mon, 27 Apr 2026 19:20:54 +0800 Subject: [PATCH 05/16] autoswitch on current iterater --- .../debug/features/autoswitch_gemm.py | 19 +- .../debug/pytorch/gemm_runtime_hooks.py | 163 ++++++++++++++++-- transformer_engine/pytorch/module/linear.py | 8 +- 3 files changed, 173 insertions(+), 17 deletions(-) diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index 
962fdb30a8..218a158823 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -66,7 +66,7 @@ def initialize(self, root_log_dir: str) -> None: logger.removeHandler(handler) handler.close() - file_handler = logging.FileHandler(log_file, mode="w") + file_handler = logging.FileHandler(log_file, mode="a") file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")) logger.addHandler(file_handler) @@ -481,6 +481,11 @@ def fp8_gemm_enabled( state = self._get_or_create_state(layer_name, gemm) metric_logger = self._get_metrics_logger() + # Keep plan-time behavior quantized. Autoswitch decisions are applied only + # at final decision points right before GEMM launch. + if not final_decision: + return True, iteration + 1 + fp8_model_params_layer = self._layer_has_fp8_model_params.get(layer_name, False) allow_fp8_model_params_fallback = self._config_bool( config, @@ -545,6 +550,18 @@ def fp8_gemm_enabled( metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) return True, iteration + 1 + @api_method + def modify_tensor_enabled( + self, + config: Dict, # pylint: disable=unused-argument + layer_name: str, # pylint: disable=unused-argument + gemm: str, # pylint: disable=unused-argument + tensor_name: str, # pylint: disable=unused-argument + iteration: int, # pylint: disable=unused-argument + ): + """AutoswitchGemm does not participate in modify_tensor routing.""" + return False, None + @api_method def inspect_tensor_enabled( self, diff --git a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py index 0a0c1d1068..a4f0e03ba8 100644 --- a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py +++ b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py @@ -2,6 +2,9 @@ from __future__ import annotations +import copy +import os +from datetime import datetime from typing import Optional import torch @@ 
-29,16 +32,28 @@ def should_resolve_inputs_after_sampling( def _to_high_precision_gemm_input(tensor, dtype: torch.dtype): """Convert GEMM input to high precision tensor if needed.""" + if tensor is None: + return None + if hasattr(tensor, "get_tensor") and hasattr(tensor, "rowwise_gemm_tensor"): - rowwise_tensor = _to_high_precision_gemm_input(tensor.get_tensor(False), dtype) - columnwise_src = tensor.get_tensor(True) - if columnwise_src is tensor.get_tensor(False): + # Clone wrapper before replacing internals to avoid mutating cached/reused + # DebugQuantizedTensor objects across multiple GEMM calls in one step. + tensor_copy = copy.copy(tensor) + + # Convert both GEMM views explicitly once autoswitch requests high precision. + # This avoids mixed rowwise/columnwise dtypes when a later GEMM selects + # the opposite view. + rowwise_src = getattr(tensor, "rowwise_gemm_tensor", None) + columnwise_src = getattr(tensor, "columnwise_gemm_tensor", None) + rowwise_tensor = _to_high_precision_gemm_input(rowwise_src, dtype) + columnwise_tensor = _to_high_precision_gemm_input(columnwise_src, dtype) + if rowwise_tensor is None: + rowwise_tensor = columnwise_tensor + if columnwise_tensor is None: columnwise_tensor = rowwise_tensor - else: - columnwise_tensor = _to_high_precision_gemm_input(columnwise_src, dtype) - tensor.rowwise_gemm_tensor = rowwise_tensor - tensor.columnwise_gemm_tensor = columnwise_tensor - return tensor + tensor_copy.rowwise_gemm_tensor = rowwise_tensor + tensor_copy.columnwise_gemm_tensor = columnwise_tensor + return tensor_copy if dtype is None: dtype = getattr(tensor, "dtype", None) @@ -54,6 +69,118 @@ def _to_high_precision_gemm_input(tensor, dtype: torch.dtype): return cast_if_needed(tensor, dtype) +def _parent_quantizer(quantizer: Optional[Quantizer]) -> Optional[Quantizer]: + """Return the quantizer that performs real quantization.""" + if quantizer is None: + return None + parent = getattr(quantizer, "parent_quantizer", None) + return parent if 
parent is not None else quantizer + + +def _to_quantized_gemm_input(tensor, quantizer: Optional[Quantizer], dtype: torch.dtype): + """Convert GEMM input to a quantized DebugQuantizedTensor-compatible object.""" + if tensor is None: + return None + + if hasattr(tensor, "get_tensor") and hasattr(tensor, "rowwise_gemm_tensor"): + tensor_copy = copy.copy(tensor) + rowwise_src = getattr(tensor, "rowwise_gemm_tensor", None) + columnwise_src = getattr(tensor, "columnwise_gemm_tensor", None) + rowwise_tensor = _to_quantized_gemm_input(rowwise_src, quantizer, dtype) + if columnwise_src is rowwise_src: + columnwise_tensor = rowwise_tensor + else: + columnwise_tensor = _to_quantized_gemm_input(columnwise_src, quantizer, dtype) + if rowwise_tensor is None: + rowwise_tensor = columnwise_tensor + if columnwise_tensor is None: + columnwise_tensor = rowwise_tensor + tensor_copy.rowwise_gemm_tensor = rowwise_tensor + tensor_copy.columnwise_gemm_tensor = columnwise_tensor + return tensor_copy + + if isinstance(tensor, QuantizedTensorStorage): + return tensor + + quantizer = _parent_quantizer(quantizer) + if quantizer is None or not isinstance(tensor, torch.Tensor): + return tensor + + # Use an isolated quantizer copy so runtime coercion does not perturb module state. 
+ quantizer = quantizer.copy() if hasattr(quantizer, "copy") else copy.copy(quantizer) + quantizer.set_usage(rowwise=True, columnwise=True) + return quantizer(cast_if_needed(tensor, dtype)) + + +def _selected_gemm_tensor(tensor, transpose: bool): + """Return the tensor view that general_gemm will pass to the backend.""" + if hasattr(tensor, "get_tensor") and hasattr(tensor, "rowwise_gemm_tensor"): + return tensor.get_tensor(transpose) + return tensor + + +def _is_quantized_gemm_tensor(tensor) -> bool: + """Return True if the selected GEMM operand is a quantized tensor.""" + return isinstance(tensor, QuantizedTensorStorage) + + +def _selected_transposes_for_gemm(gemm_name: str) -> tuple[bool, bool]: + """Return DebugQuantizedTensor view selection for known TE GEMM layouts.""" + # general_gemm selects A.get_tensor(not transa) and B.get_tensor(transb). + # Linear fprop uses TN, dgrad uses NN, and wgrad uses NT. + if gemm_name == "fprop": + return False, False + if gemm_name == "dgrad": + return True, False + if gemm_name == "wgrad": + return True, True + return False, False + + +def _selected_gemm_quantization_state(gemm_name: str, lhs, rhs) -> tuple[bool, bool]: + """Return whether the actual selected GEMM operands are quantized.""" + lhs_transpose, rhs_transpose = _selected_transposes_for_gemm(gemm_name) + lhs_tensor = _selected_gemm_tensor(lhs, lhs_transpose) + rhs_tensor = _selected_gemm_tensor(rhs, rhs_transpose) + return _is_quantized_gemm_tensor(lhs_tensor), _is_quantized_gemm_tensor(rhs_tensor) + + +def _log_final_gemm_decision( + layer_name: str, + gemm_name: str, + iteration: int, + quantized_enabled: bool, + lhs_quantized: bool, + rhs_quantized: bool, +) -> None: + """Write final AutoswitchGemm decision to the autoswitch rank-local log.""" + rank = os.getenv("RANK", "0") + if rank != "0": + return + try: + from nvdlfw_inspect.logging import get_logger + + root_log_dir = getattr(get_logger(), "root_log_dir", None) + except Exception: # pylint: 
disable=broad-except + root_log_dir = None + if not root_log_dir: + return + + log_dir = os.path.join(root_log_dir, "nvdlfw_inspect_autoswitchgemm_logs") + log_file = os.path.join(log_dir, f"nvdlfw_inspect_globalrank-{rank}.log") + os.makedirs(log_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] + message = ( + f"{timestamp} - INFO - {layer_name}_{gemm_name}_final_decision " + f"\t\t\t\t iteration={iteration:06d} " + f"\t\t\t\t quantized_enabled={int(bool(quantized_enabled))} " + f"lhs_quantized={int(lhs_quantized)} " + f"rhs_quantized={int(rhs_quantized)}" + ) + with open(log_file, mode="a", encoding="utf-8") as log: + log.write(message + "\n") + + def resolve_gemm_inputs_after_sampling( gemm_name: str, lhs, @@ -88,9 +215,19 @@ def resolve_gemm_inputs_after_sampling( ) quantized_enabled = enabled_ret[0] if isinstance(enabled_ret, tuple) else enabled_ret if quantized_enabled: - return lhs, rhs - - return ( - _to_high_precision_gemm_input(lhs, target_dtype), - _to_high_precision_gemm_input(rhs, target_dtype), + lhs_out = _to_quantized_gemm_input(lhs, lhs_quantizer, target_dtype) + rhs_out = _to_quantized_gemm_input(rhs, rhs_quantizer, target_dtype) + else: + lhs_out = _to_high_precision_gemm_input(lhs, target_dtype) + rhs_out = _to_high_precision_gemm_input(rhs, target_dtype) + + lhs_quantized, rhs_quantized = _selected_gemm_quantization_state(gemm_name, lhs_out, rhs_out) + _log_final_gemm_decision( + layer_name, + gemm_name, + iteration, + bool(quantized_enabled), + lhs_quantized, + rhs_quantized, ) + return lhs_out, rhs_out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1dadb57f39..bc282956f0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -339,8 +339,10 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ + fprop_weightmat = weightmat + 
fprop_inputmat_total = inputmat_total if debug and should_resolve_inputs_after_sampling(weight_quantizer, input_quantizer): - weightmat, inputmat_total = resolve_gemm_inputs_after_sampling( + fprop_weightmat, fprop_inputmat_total = resolve_gemm_inputs_after_sampling( "fprop", weightmat, inputmat_total, @@ -350,8 +352,8 @@ def forward( ) nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( - weightmat, - inputmat_total, + fprop_weightmat, + fprop_inputmat_total, quantization_params=output_quantizer, out_dtype=activation_dtype, bias=bias, From c42dac0675fecff7adc5c25e4b3a1d8728d3f0ec Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Tue, 12 May 2026 14:31:13 +0800 Subject: [PATCH 06/16] support cuda graph --- .../debug/features/autoswitch_gemm.py | 86 ++++++++++++++++++- .../debug/pytorch/gemm_runtime_hooks.py | 30 ++++++- .../pytorch/module/grouped_linear.py | 71 +++++++++++++-- .../pytorch/module/layernorm_linear.py | 47 +++++++++- 4 files changed, 219 insertions(+), 15 deletions(-) diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index 218a158823..0c54ad8e8b 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -7,6 +7,7 @@ import copy import logging import os +import weakref from typing import Dict, Optional, Set, Tuple import torch @@ -20,6 +21,64 @@ from transformer_engine.debug.features.utils import next_enabled_iter +_AUTOSWITCH_FEATURE_INSTANCES = weakref.WeakSet() +_AUTOSWITCH_SAMPLING_CONFIGS = [] +_AUTOSWITCH_SAMPLING_CONFIG_KEYS = set() +_AUTOSWITCH_DISABLE_UNTIL_BY_GEMM = {} + + +def _register_sampling_config(config: Dict) -> None: + """Track AutoswitchGemm sampling schedules for runtime eager/graph routing.""" + schedule = { + "start_step": config.get("start_step", None), + "end_step": config.get("end_step", None), + "start_end_list": config.get("start_end_list", None), + 
"freq": config.get("freq", 1), + } + key = repr(schedule) + if key not in _AUTOSWITCH_SAMPLING_CONFIG_KEYS: + _AUTOSWITCH_SAMPLING_CONFIG_KEYS.add(key) + _AUTOSWITCH_SAMPLING_CONFIGS.append(schedule) + + +def _is_sampling_iteration(iteration: int) -> bool: + """Return True if any AutoswitchGemm config samples on this iteration.""" + for schedule in _AUTOSWITCH_SAMPLING_CONFIGS: + run_current, _ = next_enabled_iter( + schedule["start_step"], + schedule["end_step"], + schedule["start_end_list"], + schedule["freq"], + iteration, + ) + if run_current: + return True + return False + + +def autoswitch_gemm_should_force_eager(iteration: Optional[int] = None) -> bool: + """ + Return True when AutoswitchGemm needs eager execution for the whole iteration. + + This is used by Megatron CUDA graph routing. Sampling iterations must be eager + so tensor inspection can run; high-precision windows must also be eager because + the captured graph represents the stable quantized path. + """ + if iteration is None: + try: + from transformer_engine.debug.pytorch.debug_state import TEDebugState + + iteration = TEDebugState.get_iteration() + except Exception: # pylint: disable=broad-except + return False + + if _is_sampling_iteration(iteration): + return True + + max_disable_until = max(_AUTOSWITCH_DISABLE_UNTIL_BY_GEMM.values(), default=-1) + return iteration <= max_disable_until + + class _AutoswitchGemmMetricLogger: """Writes per-rank autoswitch metrics to a dedicated log file.""" @@ -172,7 +231,7 @@ class AutoswitchGemm(TEConfigAPIMapper): The switch decision is same-iteration: metrics computed at iteration `n` are consumed in iteration `n` after all GEMM input tensors are prepared. - The switch is applied for one iteration. + The switch is applied until the next sampling period. 
allow_fp8_model_params_dequantized_weight: bool, default = False If True, allows `fprop`/`dgrad` to switch to high precision even when @@ -225,6 +284,7 @@ def __init__(self, *args, **kwargs): self._gemm_state: Dict[Tuple[str, str], _GemmSwitchState] = {} self._latest_metrics: Dict[Tuple[str, str], Dict[str, float | int | str]] = {} self._layer_has_fp8_model_params: Dict[str, bool] = {} + _AUTOSWITCH_FEATURE_INSTANCES.add(self) def parse_config_and_api(self, config, **kwargs): """ @@ -258,6 +318,7 @@ def parse_config_and_api(self, config, **kwargs): if "enabled" in processed_config: processed_config.pop("enabled") + _register_sampling_config(processed_config) return True, processed_config def _infer_monitored_tensors(self, config: Dict) -> Set[str]: @@ -302,6 +363,15 @@ def _config_bool(config: Dict, key: str, default: bool) -> bool: return value.strip().lower() in ("1", "true", "yes", "on") return bool(value) + @staticmethod + def _config_positive_int(config: Dict, key: str, default: int) -> int: + """Read positive int value from config.""" + try: + value = int(config.get(key, default)) + except (TypeError, ValueError): + value = default + return max(1, value) + @staticmethod def _get_root_log_dir() -> Optional[str]: """Best-effort retrieval of nvdlfw_inspect root log directory.""" @@ -418,7 +488,7 @@ def _consume_new_metric_and_maybe_arm_switch( config: Dict, state: _GemmSwitchState, ) -> None: - """Consume current-iteration metrics and arm switch for one iteration.""" + """Consume current-iteration metrics and arm switch until the next sampling period.""" metric = self._latest_metrics.get((layer_name, gemm)) if metric is None: return @@ -455,14 +525,21 @@ def _consume_new_metric_and_maybe_arm_switch( reasons.append(f"mse={metric_mse:.6e} > threshold={mse_threshold:.6e}") if not reasons: + # A fresh sample without threshold breach clears any currently active switch. 
+ state.disable_until_iter = min(state.disable_until_iter, iteration - 1) + _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM[(layer_name, gemm)] = state.disable_until_iter + state.last_reason = "" return - state.disable_until_iter = iteration + hold_steps = self._config_positive_int(config, "freq", 1) + state.disable_until_iter = iteration + hold_steps - 1 + _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM[(layer_name, gemm)] = state.disable_until_iter state.last_reason = "; ".join(reasons) debug_api.log_message( f"Feature={self.__class__.__name__}: switch {gemm} to high precision in" - f" iter={iteration}. Triggered by {metric['tensor_name']} sampled at iter={metric_iter}:" + f" iter={iteration} through iter={state.disable_until_iter}. Triggered by" + f" {metric['tensor_name']} sampled at iter={metric_iter}:" f" {state.last_reason}", layer_name, extra_cachable_args=(gemm, "switch"), @@ -501,6 +578,7 @@ def fp8_gemm_enabled( and not allow_fp8_model_params_fallback ): state.disable_until_iter = -1 + _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM[(layer_name, gemm)] = state.disable_until_iter if final_decision and metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) metric_logger.log_scalar( diff --git a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py index a4f0e03ba8..6a7c8e0b27 100644 --- a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py +++ b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py @@ -77,6 +77,19 @@ def _parent_quantizer(quantizer: Optional[Quantizer]) -> Optional[Quantizer]: return parent if parent is not None else quantizer +def _can_quantize(tensor, quantizer: Optional[Quantizer]) -> bool: + """Return whether a tensor can be quantized by this quantizer.""" + if quantizer is None or not isinstance(tensor, torch.Tensor): + return False + is_quantizable = getattr(quantizer, "is_quantizable", None) + if callable(is_quantizable): + try: + return 
bool(is_quantizable(tensor)) + except Exception: # pylint: disable=broad-except + return False + return True + + def _to_quantized_gemm_input(tensor, quantizer: Optional[Quantizer], dtype: torch.dtype): """Convert GEMM input to a quantized DebugQuantizedTensor-compatible object.""" if tensor is None: @@ -105,11 +118,14 @@ def _to_quantized_gemm_input(tensor, quantizer: Optional[Quantizer], dtype: torc quantizer = _parent_quantizer(quantizer) if quantizer is None or not isinstance(tensor, torch.Tensor): return tensor + tensor = cast_if_needed(tensor, dtype) + if not _can_quantize(tensor, quantizer): + return tensor # Use an isolated quantizer copy so runtime coercion does not perturb module state. quantizer = quantizer.copy() if hasattr(quantizer, "copy") else copy.copy(quantizer) quantizer.set_usage(rowwise=True, columnwise=True) - return quantizer(cast_if_needed(tensor, dtype)) + return quantizer(tensor) def _selected_gemm_tensor(tensor, transpose: bool): @@ -152,6 +168,7 @@ def _log_final_gemm_decision( quantized_enabled: bool, lhs_quantized: bool, rhs_quantized: bool, + actual_precision: str, ) -> None: """Write final AutoswitchGemm decision to the autoswitch rank-local log.""" rank = os.getenv("RANK", "0") @@ -170,10 +187,13 @@ def _log_final_gemm_decision( log_file = os.path.join(log_dir, f"nvdlfw_inspect_globalrank-{rank}.log") os.makedirs(log_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] + requested_precision = "fp8" if quantized_enabled else "bf16" message = ( f"{timestamp} - INFO - {layer_name}_{gemm_name}_final_decision " f"\t\t\t\t iteration={iteration:06d} " f"\t\t\t\t quantized_enabled={int(bool(quantized_enabled))} " + f"requested_precision={requested_precision} " + f"precision={actual_precision} " f"lhs_quantized={int(lhs_quantized)} " f"rhs_quantized={int(rhs_quantized)}" ) @@ -222,6 +242,13 @@ def resolve_gemm_inputs_after_sampling( rhs_out = _to_high_precision_gemm_input(rhs, target_dtype) lhs_quantized, 
rhs_quantized = _selected_gemm_quantization_state(gemm_name, lhs_out, rhs_out) + if quantized_enabled and not (lhs_quantized and rhs_quantized): + lhs_out = _to_high_precision_gemm_input(lhs_out, target_dtype) + rhs_out = _to_high_precision_gemm_input(rhs_out, target_dtype) + lhs_quantized, rhs_quantized = _selected_gemm_quantization_state( + gemm_name, lhs_out, rhs_out + ) + actual_precision = "fp8" if lhs_quantized and rhs_quantized else "bf16" _log_final_gemm_decision( layer_name, gemm_name, @@ -229,5 +256,6 @@ def resolve_gemm_inputs_after_sampling( bool(quantized_enabled), lhs_quantized, rhs_quantized, + actual_precision, ) return lhs_out, rhs_out diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 2cce6c3ef8..31d7e6f26f 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -53,6 +53,14 @@ ) from ...debug.pytorch.debug_quantization import DebugQuantizer from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.gemm_runtime_hooks import ( + resolve_gemm_inputs_after_sampling, + should_resolve_inputs_after_sampling, +) +from ...debug.pytorch.gemm_runtime_hooks import ( + resolve_gemm_inputs_after_sampling, + should_resolve_inputs_after_sampling, +) __all__ = ["GroupedLinear"] @@ -215,9 +223,22 @@ def forward( use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator # Perform GEMM + fprop_weights = list(weights_fp8) + fprop_inputmats = list(inputmats) + if debug: + for i in range(num_gemms): + if should_resolve_inputs_after_sampling(weight_quantizers[i], input_quantizers[i]): + fprop_weights[i], fprop_inputmats[i] = resolve_gemm_inputs_after_sampling( + "fprop", + weights_fp8[i], + inputmats[i], + weight_quantizers[i], + input_quantizers[i], + activation_dtype, + ) general_grouped_gemm( - weights_fp8, - inputmats, + fprop_weights, + fprop_inputmats, [out], output_quantizers, 
activation_dtype, @@ -457,9 +478,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], for weight in weights_for_dgrad: if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) + dgrad_weights = list(weights_for_dgrad) + dgrad_grad_outputs = list(grad_output) + if ctx.debug: + for i in range(N): + if should_resolve_inputs_after_sampling( + ctx.weight_quantizers[i], ctx.grad_output_quantizers[i] + ): + dgrad_weights[i], dgrad_grad_outputs[i] = ( + resolve_gemm_inputs_after_sampling( + "dgrad", + weights_for_dgrad[i], + grad_output[i], + ctx.weight_quantizers[i], + ctx.grad_output_quantizers[i], + ctx.activation_dtype, + ) + ) general_grouped_gemm( - weights_for_dgrad, - grad_output, + dgrad_weights, + dgrad_grad_outputs, [dgrad], ctx.grad_input_quantizers, ctx.activation_dtype, @@ -537,6 +575,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: inputmats_dequant.append(cast_if_needed(inputmat, ctx.activation_dtype)) inputmats = inputmats_dequant + wgrad_inputmats = list(inputmats) + wgrad_grad_outputs = list(grad_output) + if ctx.debug: + for i in range(N): + if should_resolve_inputs_after_sampling( + ctx.input_quantizers[i], ctx.grad_output_quantizers[i] + ): + wgrad_inputmats[i], wgrad_grad_outputs[i] = ( + resolve_gemm_inputs_after_sampling( + "wgrad", + inputmats[i], + grad_output[i], + ctx.input_quantizers[i], + ctx.grad_output_quantizers[i], + ctx.activation_dtype, + ) + ) grouped_gemm_wgrad = functools.partial( general_grouped_gemm, quantization_params=ctx.grad_weight_quantizers, @@ -555,9 +610,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) # WGRAD if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): - ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) + ctx.wgrad_store.put( + [wgrad_inputmats, wgrad_grad_outputs, wgrad_list], grouped_gemm_wgrad + ) else: - _, 
grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) + _, grad_biases_, _ = grouped_gemm_wgrad( + wgrad_inputmats, wgrad_grad_outputs, wgrad_list + ) for i in range(ctx.num_gemms): if grad_biases[i] is None: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index dc021ca6b7..ce61ca695a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -63,6 +63,10 @@ restore_from_func_ctx, ) from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.gemm_runtime_hooks import ( + resolve_gemm_inputs_after_sampling, + should_resolve_inputs_after_sampling, +) from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..cpu_offload import ( is_cpu_offload_enabled, @@ -365,10 +369,21 @@ def forward( # Forward GEMM # Note: y = x * w^T # ------------------------------------------------------ + fprop_weightmat = weightmat + fprop_ln_out_total = ln_out_total + if debug and should_resolve_inputs_after_sampling(weight_quantizer, input_quantizer): + fprop_weightmat, fprop_ln_out_total = resolve_gemm_inputs_after_sampling( + "fprop", + weightmat, + ln_out_total, + weight_quantizer, + input_quantizer, + activation_dtype, + ) nvtx_range_push(f"{nvtx_label}.gemm") gemm_out, *_, reduce_scatter_out = general_gemm( - weightmat, - ln_out_total, + fprop_weightmat, + fprop_ln_out_total, quantization_params=output_quantizer, out_dtype=activation_dtype, bias=bias, @@ -772,9 +787,22 @@ def backward( weight_for_dgrad = origin_weight if isinstance(weight_for_dgrad, QuantizedTensorStorage): weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + weight_for_dgrad_gemm = weight_for_dgrad + grad_output_for_dgrad = grad_output + if ctx.debug and should_resolve_inputs_after_sampling( + ctx.weight_quantizer, ctx.grad_output_quantizer + ): + weight_for_dgrad_gemm, grad_output_for_dgrad = 
resolve_gemm_inputs_after_sampling( + "dgrad", + weight_for_dgrad, + grad_output, + ctx.weight_quantizer, + ctx.grad_output_quantizer, + ctx.activation_dtype, + ) gemm_out, *_, reduce_scatter_out = general_gemm( - weight_for_dgrad, - grad_output, + weight_for_dgrad_gemm, + grad_output_for_dgrad, layout="NN", grad=True, quantization_params=ctx.grad_input_quantizer, @@ -933,6 +961,17 @@ def wgrad_gemm( some advanced communication/compute overlapping. """ + if ctx.debug and should_resolve_inputs_after_sampling( + ctx.input_quantizer, ctx.grad_output_quantizer + ): + x, dy = resolve_gemm_inputs_after_sampling( + "wgrad", + x, + dy, + ctx.input_quantizer, + ctx.grad_output_quantizer, + ctx.activation_dtype, + ) nvtx_range_push(f"{nvtx_label}.wgrad_gemm") dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs) nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") From d80f096e41fe6dba0cac7dd1ed17fefc0d217bf6 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Wed, 13 May 2026 15:12:19 +0800 Subject: [PATCH 07/16] print actual precision --- .../debug/pytorch/gemm_runtime_hooks.py | 56 ++++++++++++++++++- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py index 6a7c8e0b27..ce885e916b 100644 --- a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py +++ b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py @@ -140,6 +140,44 @@ def _is_quantized_gemm_tensor(tensor) -> bool: return isinstance(tensor, QuantizedTensorStorage) +def _precision_name_from_class_name(class_name: str) -> str: + """Map quantizer/tensor class names to user-facing precision labels.""" + lowered = class_name.lower() + if "mxfp8" in lowered: + return "mxfp8" + if "nvfp4" in lowered: + return "nvfp4" + if "float8blockwise" in lowered or "blockwise" in lowered: + return "fp8_blockwise" + if "float8" in lowered or "fp8" in lowered: + return "fp8" + return "quantized" + + +def 
_precision_name_from_quantizer(quantizer: Optional[Quantizer]) -> str: + """Return requested quantized precision based on the underlying quantizer.""" + quantizer = _parent_quantizer(quantizer) + if quantizer is None: + return "quantized" + return _precision_name_from_class_name(quantizer.__class__.__name__) + + +def _precision_name_from_tensor(tensor) -> str: + """Return actual precision based on the selected GEMM operand.""" + if isinstance(tensor, QuantizedTensorStorage): + return _precision_name_from_class_name(tensor.__class__.__name__) + dtype = getattr(tensor, "dtype", None) + if dtype is None: + return "unknown" + if dtype == torch.bfloat16: + return "bf16" + if dtype == torch.float16: + return "fp16" + if dtype == torch.float32: + return "fp32" + return str(dtype).replace("torch.", "") + + def _selected_transposes_for_gemm(gemm_name: str) -> tuple[bool, bool]: """Return DebugQuantizedTensor view selection for known TE GEMM layouts.""" # general_gemm selects A.get_tensor(not transa) and B.get_tensor(transb). 
@@ -161,6 +199,16 @@ def _selected_gemm_quantization_state(gemm_name: str, lhs, rhs) -> tuple[bool, b return _is_quantized_gemm_tensor(lhs_tensor), _is_quantized_gemm_tensor(rhs_tensor) +def _selected_gemm_precision(gemm_name: str, lhs, rhs) -> str: + """Return the actual precision label for selected GEMM operands.""" + lhs_transpose, rhs_transpose = _selected_transposes_for_gemm(gemm_name) + lhs_precision = _precision_name_from_tensor(_selected_gemm_tensor(lhs, lhs_transpose)) + rhs_precision = _precision_name_from_tensor(_selected_gemm_tensor(rhs, rhs_transpose)) + if lhs_precision == rhs_precision: + return lhs_precision + return f"{lhs_precision}+{rhs_precision}" + + def _log_final_gemm_decision( layer_name: str, gemm_name: str, @@ -168,6 +216,7 @@ def _log_final_gemm_decision( quantized_enabled: bool, lhs_quantized: bool, rhs_quantized: bool, + requested_precision: str, actual_precision: str, ) -> None: """Write final AutoswitchGemm decision to the autoswitch rank-local log.""" @@ -187,7 +236,6 @@ def _log_final_gemm_decision( log_file = os.path.join(log_dir, f"nvdlfw_inspect_globalrank-{rank}.log") os.makedirs(log_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] - requested_precision = "fp8" if quantized_enabled else "bf16" message = ( f"{timestamp} - INFO - {layer_name}_{gemm_name}_final_decision " f"\t\t\t\t iteration={iteration:06d} " @@ -234,6 +282,9 @@ def resolve_gemm_inputs_after_sampling( final_decision=True, ) quantized_enabled = enabled_ret[0] if isinstance(enabled_ret, tuple) else enabled_ret + requested_precision = ( + _precision_name_from_quantizer(lhs_quantizer) if quantized_enabled else "bf16" + ) if quantized_enabled: lhs_out = _to_quantized_gemm_input(lhs, lhs_quantizer, target_dtype) rhs_out = _to_quantized_gemm_input(rhs, rhs_quantizer, target_dtype) @@ -248,7 +299,7 @@ def resolve_gemm_inputs_after_sampling( lhs_quantized, rhs_quantized = _selected_gemm_quantization_state( gemm_name, lhs_out, rhs_out ) - 
actual_precision = "fp8" if lhs_quantized and rhs_quantized else "bf16" + actual_precision = _selected_gemm_precision(gemm_name, lhs_out, rhs_out) _log_final_gemm_decision( layer_name, gemm_name, @@ -256,6 +307,7 @@ def resolve_gemm_inputs_after_sampling( bool(quantized_enabled), lhs_quantized, rhs_quantized, + requested_precision, actual_precision, ) return lhs_out, rhs_out From ebfc70b010354c360ef4ceefcb97e6eb763f9f14 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Wed, 13 May 2026 17:54:46 +0800 Subject: [PATCH 08/16] fix cuda graph --- .../debug/features/autoswitch_gemm.py | 67 +++++++++++++++++-- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index 0c54ad8e8b..f526655ae9 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -25,6 +25,8 @@ _AUTOSWITCH_SAMPLING_CONFIGS = [] _AUTOSWITCH_SAMPLING_CONFIG_KEYS = set() _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM = {} +_AUTOSWITCH_CONFIG_FILE_LOADED = False +_AUTOSWITCH_DISABLE_UNTIL_ENV = "NVTE_AUTOSWITCH_GEMM_DISABLE_UNTIL" def _register_sampling_config(config: Dict) -> None: @@ -41,8 +43,47 @@ def _register_sampling_config(config: Dict) -> None: _AUTOSWITCH_SAMPLING_CONFIGS.append(schedule) +def _register_sampling_configs_from_file() -> None: + """Best-effort loading of AutoswitchGemm sampling schedules from NVDFW config.""" + global _AUTOSWITCH_CONFIG_FILE_LOADED + if _AUTOSWITCH_CONFIG_FILE_LOADED: + return + _AUTOSWITCH_CONFIG_FILE_LOADED = True + + config_file = os.getenv("NVDFW_CONFIG_FILE") + if not config_file or not os.path.exists(config_file): + return + + try: + import yaml + + with open(config_file, encoding="utf-8") as config_stream: + config = yaml.safe_load(config_stream) + except Exception: # pylint: disable=broad-except + return + + def _walk(node): + if isinstance(node, dict): + transformer_engine_config = 
node.get("transformer_engine") + if isinstance(transformer_engine_config, dict): + autoswitch_config = transformer_engine_config.get("AutoswitchGemm") + if isinstance(autoswitch_config, dict) and autoswitch_config.get( + "enabled", True + ): + _register_sampling_config(autoswitch_config) + for value in node.values(): + _walk(value) + elif isinstance(node, list): + for value in node: + _walk(value) + + _walk(config) + + def _is_sampling_iteration(iteration: int) -> bool: """Return True if any AutoswitchGemm config samples on this iteration.""" + if not _AUTOSWITCH_SAMPLING_CONFIGS: + _register_sampling_configs_from_file() for schedule in _AUTOSWITCH_SAMPLING_CONFIGS: run_current, _ = next_enabled_iter( schedule["start_step"], @@ -56,6 +97,23 @@ def _is_sampling_iteration(iteration: int) -> bool: return False +def _update_global_disable_until(layer_name: str, gemm: str, disable_until_iter: int) -> None: + """Update process-wide AutoswitchGemm high-precision window state.""" + _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM[(layer_name, gemm)] = disable_until_iter + max_disable_until = max(_AUTOSWITCH_DISABLE_UNTIL_BY_GEMM.values(), default=-1) + os.environ[_AUTOSWITCH_DISABLE_UNTIL_ENV] = str(max_disable_until) + + +def _get_global_disable_until() -> int: + """Return process-wide high-precision window end iteration.""" + max_disable_until = max(_AUTOSWITCH_DISABLE_UNTIL_BY_GEMM.values(), default=-1) + try: + env_disable_until = int(os.getenv(_AUTOSWITCH_DISABLE_UNTIL_ENV, "-1")) + except ValueError: + env_disable_until = -1 + return max(max_disable_until, env_disable_until) + + def autoswitch_gemm_should_force_eager(iteration: Optional[int] = None) -> bool: """ Return True when AutoswitchGemm needs eager execution for the whole iteration. 
@@ -75,8 +133,7 @@ def autoswitch_gemm_should_force_eager(iteration: Optional[int] = None) -> bool: if _is_sampling_iteration(iteration): return True - max_disable_until = max(_AUTOSWITCH_DISABLE_UNTIL_BY_GEMM.values(), default=-1) - return iteration <= max_disable_until + return iteration <= _get_global_disable_until() class _AutoswitchGemmMetricLogger: @@ -527,13 +584,13 @@ def _consume_new_metric_and_maybe_arm_switch( if not reasons: # A fresh sample without threshold breach clears any currently active switch. state.disable_until_iter = min(state.disable_until_iter, iteration - 1) - _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM[(layer_name, gemm)] = state.disable_until_iter + _update_global_disable_until(layer_name, gemm, state.disable_until_iter) state.last_reason = "" return hold_steps = self._config_positive_int(config, "freq", 1) state.disable_until_iter = iteration + hold_steps - 1 - _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM[(layer_name, gemm)] = state.disable_until_iter + _update_global_disable_until(layer_name, gemm, state.disable_until_iter) state.last_reason = "; ".join(reasons) debug_api.log_message( @@ -578,7 +635,7 @@ def fp8_gemm_enabled( and not allow_fp8_model_params_fallback ): state.disable_until_iter = -1 - _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM[(layer_name, gemm)] = state.disable_until_iter + _update_global_disable_until(layer_name, gemm, state.disable_until_iter) if final_decision and metric_logger is not None: metric_logger.log_scalar(layer_name, gemm, "quantized_enabled", iteration, 1.0) metric_logger.log_scalar( From 4c94bcf9a953cfbcb1a19270b8ea8946951b423e Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Thu, 14 May 2026 10:45:08 +0800 Subject: [PATCH 09/16] update docs --- docs/debug/1_getting_started.rst | 54 ++++++++++++++---- docs/debug/2_config_file_structure.rst | 40 +++++++++++++- docs/debug/autoswitch_gemm_example.yaml | 73 ++++++++++--------------- 3 files changed, 108 insertions(+), 59 deletions(-) diff --git a/docs/debug/1_getting_started.rst 
b/docs/debug/1_getting_started.rst index ac36acf990..158a3b9afb 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -219,33 +219,53 @@ AutoswitchGemm quick guide -------------------------- ``AutoswitchGemm`` monitors quantization quality and can dynamically switch selected GEMMs -to high precision when thresholds are exceeded. +to high precision when thresholds are exceeded. It supports the normal FP8 paths as well +as block-scaled formats such as FP8 blockwise and MXFP8, as long as the selected TE module +routes the GEMM through the AutoswitchGemm runtime hooks. -Minimal config example: +Example config matching attention and MLP linears: .. code-block:: yaml - autoswitch_fc_layers: + log_tensor_stats_all: enabled: True layers: - layer_types: [fc1, fc2] + layer_types: [linear_qkv, linear_proj, linear_fc1, linear_fc2] transformer_engine: + LogTensorStats: + enabled: True + stats: [max, min, mean, std, dynamic_range, cur_amax] + tensors: [activation, gradient, weight] + freq: 10 + start_step: 10 AutoswitchGemm: enabled: True gemms: [fprop, dgrad, wgrad] - underflow_threshold_pct: 1.0 - mse_threshold: 1.0e-4 - # Needed only if the layer uses fp8 model parameters and - # you want fprop/dgrad to be able to switch to high precision. - allow_fp8_model_params_dequantized_weight: False - freq: 1 + tensors: [activation, weight, gradient] + underflow_threshold_pct: 5 + mse_threshold: 0.1 + allow_fp8_model_params_dequantized_weight: True + freq: 10 + start_step: 10 Behavior summary: 1. For each ``(layer, gemm)``, AutoswitchGemm tracks the latest tensor metrics and applies OR logic across monitored tensors: if any tensor breaches thresholds, that GEMM switches. -2. Metrics computed in iteration ``n`` are consumed in iteration ``n`` only. -3. If thresholds are not breached in the current iteration, the GEMM stays quantized. +2. Sampling is controlled by ``start_step``, ``end_step`` / ``start_end_list``, and + ``freq``. 
For example, ``start_step: 10`` and ``freq: 10`` samples at steps + 10, 20, 30, ... +3. A threshold breach at sampling step ``n`` keeps the affected ``(layer, gemm)`` in + high precision through ``n + freq - 1``. The next sampling step refreshes the + decision; if thresholds are not breached, the GEMM returns to quantized execution. +4. If model parameters are stored in a quantized format, set + ``allow_fp8_model_params_dequantized_weight: True`` to allow ``fprop`` and + ``dgrad`` to switch by using temporary dequantized weights. +5. When CUDA Graphs are used, sampling and high-precision windows must run in eager + mode. Quantized windows can continue using CUDA Graphs if the training framework + supports this routing. Megatron-LM support for this workflow depends on the + ``autogemm`` branch: + https://github.com/shangxiaokang/Megatron-LM/tree/autogemm When AutoswitchGemm is enabled, an additional directory is created under ``log_dir``: @@ -259,6 +279,16 @@ It contains per-rank, per-iteration metrics such as: - ``__disable_until_iter`` - ``__switch_blocked_fp8_model_params`` - ``__fp8_model_params_dequantized_fallback`` +- ``__final_decision`` with fields such as + ``requested_precision``, ``precision``, ``lhs_quantized``, and ``rhs_quantized``. + +A typical Megatron-LM launch exports the debug config and log directory: + +.. code-block:: bash + + export ENABLE_NVDFW_INSPECT=1 + export NVDFW_CONFIG_FILE=/path/to/nvdlfw_inspect_30b.yaml + export NVDFW_LOG_DIR=/path/to/output/nvdlfw_logs Logging using TensorBoard ------------------------- diff --git a/docs/debug/2_config_file_structure.rst b/docs/debug/2_config_file_structure.rst index 28da6beab3..394f66bc5c 100644 --- a/docs/debug/2_config_file_structure.rst +++ b/docs/debug/2_config_file_structure.rst @@ -238,9 +238,45 @@ Other important keys: - ``underflow_threshold_pct``: switch trigger based on underflow percentage. - ``mse_threshold``: switch trigger based on quantization MSE. 
-- metrics are consumed in the same iteration where they are computed. +- ``freq``: sampling interval. A sampled threshold breach at iteration ``n`` keeps + that ``(layer, gemm)`` in high precision through ``n + freq - 1``. +- ``start_step`` / ``end_step`` / ``start_end_list``: sampling windows. If ``end_step`` + is omitted, sampling continues according to ``freq`` after ``start_step``. - ``allow_fp8_model_params_dequantized_weight``: allows ``fprop``/``dgrad`` switching - for layers with FP8 model parameters by using dequantized temporary weights. + for layers with quantized model parameters by using temporary dequantized weights. +- ``AutoswitchGemm`` should use the same ``freq`` / sampling window as companion + tensor-inspection features such as ``LogTensorStats`` when they share the same + layers and tensors. + +Example for attention and MLP linear layers: + +.. code-block:: yaml + + log_tensor_stats_all: + enabled: True + layers: + layer_types: [linear_qkv, linear_proj, linear_fc1, linear_fc2] + transformer_engine: + LogTensorStats: + enabled: True + stats: [max, min, mean, std, dynamic_range, cur_amax] + tensors: [activation, gradient, weight] + freq: 10 + start_step: 10 + AutoswitchGemm: + enabled: True + gemms: [fprop, dgrad, wgrad] + tensors: [activation, weight, gradient] + underflow_threshold_pct: 5 + mse_threshold: 0.1 + allow_fp8_model_params_dequantized_weight: True + freq: 10 + start_step: 10 + +For CUDA Graph training, sampling and high-precision windows must be executed in eager +mode. Quantized windows may continue to use CUDA Graphs if the training framework routes +them separately. 
The Megatron-LM integration used by this example depends on: +https://github.com/shangxiaokang/Megatron-LM/tree/autogemm Enabling or Disabling Sections and Features ------------------------------------------- diff --git a/docs/debug/autoswitch_gemm_example.yaml b/docs/debug/autoswitch_gemm_example.yaml index c24462a67e..9b2fc11e3e 100644 --- a/docs/debug/autoswitch_gemm_example.yaml +++ b/docs/debug/autoswitch_gemm_example.yaml @@ -10,63 +10,46 @@ # ... # debug_api.step() # call once per training step -autoswitch_attention_blocks: +log_tensor_stats_all: enabled: True layers: - # Match attention linear layers, e.g. *.qkv / *.proj - layer_name_regex_pattern: ".*(qkv|proj).*" + # Names may be inferred by Megatron/TE. This matches attention linears and + # common MLP/MoE linears used by Qwen3-style models. + layer_types: [linear_qkv, linear_proj, linear_fc1, linear_fc2] transformer_engine: - AutoswitchGemm: + LogTensorStats: enabled: True - - # Optional. If omitted, tensors are inferred from selected gemms: - # fprop -> [activation, weight], dgrad -> [gradient, weight], - # wgrad -> [activation, gradient]. - tensors: [activation, weight, gradient] - - # Per-GEMM switching policy. - gemms_struct: - - gemm: fprop - underflow_threshold_pct: 1.0 - mse_threshold: 1.0e-4 - - gemm: dgrad - underflow_threshold_pct: 1.5 - mse_threshold: 1.5e-4 - - gemm: wgrad - underflow_threshold_pct: 2.0 - mse_threshold: 2.0e-4 - - # For layers with fp8 model parameters: - # - False: keep fprop/dgrad quantized - # - True: allow high-precision switch via temporary dequantized weights - allow_fp8_model_params_dequantized_weight: False - - # Collect metrics every step after warmup. - freq: 1 + stats: [max, min, mean, std, dynamic_range, cur_amax] + tensors: [activation, gradient, weight] + # Match AutoswitchGemm's schedule when both features share the same + # inspect_tensor_enabled API calls. 
+ freq: 10 start_step: 10 - end_step: 5000 - -autoswitch_mlp_blocks: - enabled: True - layers: - layer_types: [fc1, fc2] - transformer_engine: AutoswitchGemm: enabled: True - # Simpler global policy (shared by selected GEMMs). - gemms: [fprop, wgrad] + # Enable all GEMM paths. If tensors are omitted, AutoswitchGemm infers: + # fprop -> [activation, weight] + # dgrad -> [gradient, weight] + # wgrad -> [activation, gradient] + gemms: [fprop, dgrad, wgrad] tensors: [activation, weight, gradient] - underflow_threshold_pct: 3.0 - mse_threshold: 3.0e-4 + # Switch to high precision when any monitored tensor for the GEMM + # exceeds either threshold. + underflow_threshold_pct: 5 + mse_threshold: 0.1 + + # If model parameters are stored in a quantized format, fprop/dgrad can + # switch to high precision by using temporary dequantized weights. + allow_fp8_model_params_dequantized_weight: True - # Example sparse monitoring windows. - freq: 2 - start_end_list: - - [0, 300] - - [800, 3000] + # Start sampling at step 10, then sample every 10 steps. A threshold + # breach at step N keeps that (layer, GEMM) in high precision through + # step N + freq - 1. The next sampling step refreshes the decision. 
+ freq: 10 + start_step: 10 # Autoswitch per-rank metrics are written to: # /nvdlfw_inspect_autoswitchgemm_logs/nvdlfw_inspect_globalrank-.log From 1dd13a92ce33b027baaf23ec553a58b3ecea910c Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Thu, 14 May 2026 13:51:25 +0800 Subject: [PATCH 10/16] fix nvpf4 dequantize --- transformer_engine/debug/pytorch/gemm_runtime_hooks.py | 10 ++++++++-- transformer_engine/pytorch/module/grouped_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_linear.py | 4 +++- transformer_engine/pytorch/module/linear.py | 4 +++- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py index ce885e916b..3e792ef6bd 100644 --- a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py +++ b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py @@ -47,6 +47,8 @@ def _to_high_precision_gemm_input(tensor, dtype: torch.dtype): columnwise_src = getattr(tensor, "columnwise_gemm_tensor", None) rowwise_tensor = _to_high_precision_gemm_input(rowwise_src, dtype) columnwise_tensor = _to_high_precision_gemm_input(columnwise_src, dtype) + if rowwise_tensor is None and columnwise_tensor is None: + return tensor if rowwise_tensor is None: rowwise_tensor = columnwise_tensor if columnwise_tensor is None: @@ -58,12 +60,16 @@ def _to_high_precision_gemm_input(tensor, dtype: torch.dtype): if dtype is None: dtype = getattr(tensor, "dtype", None) if isinstance(tensor, QuantizedTensorStorage): - if dtype is None: - return tensor.dequantize() try: + if dtype is None: + return tensor.dequantize() return tensor.dequantize(dtype=dtype) except TypeError: return cast_if_needed(tensor.dequantize(), dtype) + except NotImplementedError as err: + if "column-wise NVFP4" in str(err): + return None + raise if dtype is None: return tensor return cast_if_needed(tensor, dtype) diff --git a/transformer_engine/pytorch/module/grouped_linear.py 
b/transformer_engine/pytorch/module/grouped_linear.py index 31d7e6f26f..181d575d4c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -272,7 +272,9 @@ def forward( else: for inputmat in inputmats: if isinstance(inputmat, QuantizedTensorStorage): - if backward_override is not None: + if debug: + inputmat.update_usage(rowwise_usage=True, columnwise_usage=True) + elif backward_override is not None: # In dequantized mode we should dequantize directly from # fprop quantized layouts without retargeting usage. inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ce61ca695a..9cf2a6224f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -449,7 +449,9 @@ def forward( # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data # can be allgathered. - if ( + if debug: + ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) + elif ( isinstance(ln_out, (MXFP8TensorStorage, Float8BlockwiseQTensorStorage)) or not ctx.ln_out_needs_gather ): diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index bc282956f0..6fce8a6a78 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -418,7 +418,9 @@ def forward( and own_quantized_input and isinstance(inputmat, QuantizedTensorStorage) ): - if backward_override is not None: + if debug: + inputmat.update_usage(rowwise_usage=True, columnwise_usage=True) + elif backward_override is not None: # In dequantized mode we should dequantize directly from the # fprop quantized tensor layout without retargeting usage. 
inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) From a34c86d1a91c0d1741d0b7fb29ae2b7da2465ebc Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Thu, 14 May 2026 15:23:02 +0800 Subject: [PATCH 11/16] fix cuda graph --- .../debug/pytorch/debug_quantization.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 3f499a02f9..e7210a28e2 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -40,6 +40,18 @@ HIGH_PRECISION = "High Precision" +def _autoswitch_gemm_runtime_decision_active() -> bool: + """Return True when AutoswitchGemm needs per-GEMM final decisions this iteration.""" + try: + from transformer_engine.debug.features.autoswitch_gemm import ( + autoswitch_gemm_should_force_eager, + ) + + return bool(autoswitch_gemm_should_force_eager(TEDebugState.get_iteration())) + except Exception: # pylint: disable=broad-except + return False + + class DebugQuantizer(Quantizer): """ DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect. @@ -425,6 +437,10 @@ def any_feature_enabled(self) -> bool: """Returns bool if there is at least one API call enabled.""" if self.output_tensor: return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY + if self.parent_quantizer is not None and _autoswitch_gemm_runtime_decision_active(): + # AutoswitchGemm may need final precision decisions during the hold window even when + # inspect_tensor is disabled for this non-sampling iteration. 
+ return True # pylint: disable=too-many-boolean-expressions if ( self.inspect_tensor_enabled From 588c96ce2ab1d02542ad84758e7fadae08ef5e48 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Wed, 20 May 2026 14:19:19 +0800 Subject: [PATCH 12/16] compatible with finial_decision --- .../debug/pytorch/gemm_runtime_hooks.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py index 3e792ef6bd..c3d2e4b393 100644 --- a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py +++ b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py @@ -281,12 +281,21 @@ def resolve_gemm_inputs_after_sampling( return lhs, rhs iteration = TEDebugState.get_iteration() - enabled_ret = debug_api.transformer_engine.fp8_gemm_enabled( - layer_name=layer_name, - gemm=gemm_name, - iteration=iteration, - final_decision=True, - ) + try: + enabled_ret = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=layer_name, + gemm=gemm_name, + iteration=iteration, + final_decision=True, + ) + except TypeError as err: + if "final_decision" not in str(err): + raise + enabled_ret = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=layer_name, + gemm=gemm_name, + iteration=iteration, + ) quantized_enabled = enabled_ret[0] if isinstance(enabled_ret, tuple) else enabled_ret requested_precision = ( _precision_name_from_quantizer(lhs_quantizer) if quantized_enabled else "bf16" From 42f1a56c04b28bf12bb9c3b1aa7d69cb668c192e Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Wed, 20 May 2026 15:03:43 +0800 Subject: [PATCH 13/16] set allow_fp8_model_params_dequantized_weight true --- transformer_engine/debug/features/autoswitch_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index f526655ae9..00c91a1247 100644 --- 
a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -290,7 +290,7 @@ class AutoswitchGemm(TEConfigAPIMapper): after all GEMM input tensors are prepared. The switch is applied until the next sampling period. - allow_fp8_model_params_dequantized_weight: bool, default = False + allow_fp8_model_params_dequantized_weight: bool, default = True If True, allows `fprop`/`dgrad` to switch to high precision even when fp8 model parameters are enabled by using a temporary dequantized weight tensor for GEMM execution. @@ -334,7 +334,7 @@ class AutoswitchGemm(TEConfigAPIMapper): _DEFAULT_UNDERFLOW_THRESHOLD_PCT = 5.0 _DEFAULT_MSE_THRESHOLD = 1e-4 - _DEFAULT_ALLOW_FP8_MODEL_PARAMS_DEQUANTIZED_WEIGHT = False + _DEFAULT_ALLOW_FP8_MODEL_PARAMS_DEQUANTIZED_WEIGHT = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) From 7a702bc9a206d7738facb8796b6a832416a01f86 Mon Sep 17 00:00:00 2001 From: Xiaokang Shang Date: Wed, 10 Jun 2026 15:24:09 +0800 Subject: [PATCH 14/16] control logging with NVTE_AUTOSWITCH_GEMM_LOGGING --- .../debug/features/autoswitch_gemm.py | 100 ++++++++++++------ .../debug/pytorch/gemm_runtime_hooks.py | 23 ++++ 2 files changed, 91 insertions(+), 32 deletions(-) diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index 00c91a1247..387d6c06c0 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -27,10 +27,19 @@ _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM = {} _AUTOSWITCH_CONFIG_FILE_LOADED = False _AUTOSWITCH_DISABLE_UNTIL_ENV = "NVTE_AUTOSWITCH_GEMM_DISABLE_UNTIL" +_AUTOSWITCH_LOGGING_ENV = "NVTE_AUTOSWITCH_GEMM_LOGGING" +_AUTOSWITCH_LOGGING_ENABLED = False + + +def _env_flag_enabled(name: str, default: bool = False) -> bool: + """Interpret common boolean environment flag values.""" + default_value = "1" if default else "0" + return os.getenv(name, 
default_value).strip().lower() in {"1", "true", "yes", "on"} def _register_sampling_config(config: Dict) -> None: """Track AutoswitchGemm sampling schedules for runtime eager/graph routing.""" + global _AUTOSWITCH_LOGGING_ENABLED schedule = { "start_step": config.get("start_step", None), "end_step": config.get("end_step", None), @@ -41,6 +50,13 @@ def _register_sampling_config(config: Dict) -> None: if key not in _AUTOSWITCH_SAMPLING_CONFIG_KEYS: _AUTOSWITCH_SAMPLING_CONFIG_KEYS.add(key) _AUTOSWITCH_SAMPLING_CONFIGS.append(schedule) + if any(bool(config.get(key, False)) for key in ("log_metrics", "log_decisions", "verbose")): + _AUTOSWITCH_LOGGING_ENABLED = True + + +def autoswitch_gemm_logging_enabled() -> bool: + """Return True when verbose AutoswitchGemm logging is globally enabled.""" + return _AUTOSWITCH_LOGGING_ENABLED or _env_flag_enabled(_AUTOSWITCH_LOGGING_ENV, False) def _register_sampling_configs_from_file() -> None: @@ -438,8 +454,18 @@ def _get_root_log_dir() -> Optional[str]: return None return root_log_dir - def _get_metrics_logger(self) -> Optional[_AutoswitchGemmMetricLogger]: + def _logging_enabled(self, config: Optional[Dict] = None) -> bool: + """Return True when verbose AutoswitchGemm logging is enabled.""" + if config is not None: + for key in ("log_metrics", "log_decisions", "verbose"): + if key in config: + return self._config_bool(config, key, False) + return autoswitch_gemm_logging_enabled() + + def _get_metrics_logger(self, config: Optional[Dict] = None) -> Optional[_AutoswitchGemmMetricLogger]: """Return initialized autoswitch metric logger if log dir is available.""" + if not self._logging_enabled(config): + return None metric_logger = _get_autoswitch_metric_logger() if metric_logger.ensure_initialized(self._get_root_log_dir()): return metric_logger @@ -459,9 +485,10 @@ def _update_metric( tensor_name: str, underflow_pct: float, mse: float, + config: Optional[Dict] = None, ) -> None: """Store the latest quality metric for a `(layer, 
gemm)` pair.""" - metric_logger = self._get_metrics_logger() + metric_logger = self._get_metrics_logger(config) if metric_logger is not None: metric_logger.log_scalar( layer_name, gemm, f"{tensor_name}_underflow_pct", iteration, underflow_pct @@ -593,14 +620,15 @@ def _consume_new_metric_and_maybe_arm_switch( _update_global_disable_until(layer_name, gemm, state.disable_until_iter) state.last_reason = "; ".join(reasons) - debug_api.log_message( - f"Feature={self.__class__.__name__}: switch {gemm} to high precision in" - f" iter={iteration} through iter={state.disable_until_iter}. Triggered by" - f" {metric['tensor_name']} sampled at iter={metric_iter}:" - f" {state.last_reason}", - layer_name, - extra_cachable_args=(gemm, "switch"), - ) + if self._logging_enabled(config): + debug_api.log_message( + f"Feature={self.__class__.__name__}: switch {gemm} to high precision in" + f" iter={iteration} through iter={state.disable_until_iter}. Triggered by" + f" {metric['tensor_name']} sampled at iter={metric_iter}:" + f" {state.last_reason}", + layer_name, + extra_cachable_args=(gemm, "switch"), + ) @api_method def fp8_gemm_enabled( @@ -613,7 +641,7 @@ def fp8_gemm_enabled( ): """Decide whether selected GEMM should run quantized (True) or high precision (False).""" state = self._get_or_create_state(layer_name, gemm) - metric_logger = self._get_metrics_logger() + metric_logger = self._get_metrics_logger(config) # Keep plan-time behavior quantized. Autoswitch decisions are applied only # at final decision points right before GEMM launch. 
@@ -641,12 +669,13 @@ def fp8_gemm_enabled( metric_logger.log_scalar( layer_name, gemm, "switch_blocked_fp8_model_params", iteration, 1.0 ) - debug_api.log_message( - f"Feature={self.__class__.__name__}: skip switch for {gemm} at" - f" iter={iteration} because fp8 model parameters are enabled.", - layer_name, - extra_cachable_args=(gemm, "skip_fp8_model_params"), - ) + if self._logging_enabled(config): + debug_api.log_message( + f"Feature={self.__class__.__name__}: skip switch for {gemm} at" + f" iter={iteration} because fp8 model parameters are enabled.", + layer_name, + extra_cachable_args=(gemm, "skip_fp8_model_params"), + ) return True, iteration + 1 if gemm in {"fprop", "dgrad"} and fp8_model_params_layer and allow_fp8_model_params_fallback: @@ -654,12 +683,13 @@ def fp8_gemm_enabled( metric_logger.log_scalar( layer_name, gemm, "fp8_model_params_dequantized_fallback", iteration, 1.0 ) - debug_api.log_message( - f"Feature={self.__class__.__name__}: {gemm} allows fp8-model-params" - " dequantized-weight fallback.", - layer_name, - extra_cachable_args=(gemm, "fp8_model_params_dequantized_fallback"), - ) + if self._logging_enabled(config): + debug_api.log_message( + f"Feature={self.__class__.__name__}: {gemm} allows fp8-model-params" + " dequantized-weight fallback.", + layer_name, + extra_cachable_args=(gemm, "fp8_model_params_dequantized_fallback"), + ) self._consume_new_metric_and_maybe_arm_switch(layer_name, gemm, iteration, config, state) @@ -673,12 +703,13 @@ def fp8_gemm_enabled( iteration, float(state.disable_until_iter), ) - debug_api.log_message( - f"Feature={self.__class__.__name__}: {gemm} forced high precision at" - f" iter={iteration} (disable_until={state.disable_until_iter}).", - layer_name, - extra_cachable_args=(gemm, "high_precision"), - ) + if self._logging_enabled(config): + debug_api.log_message( + f"Feature={self.__class__.__name__}: {gemm} forced high precision at" + f" iter={iteration} (disable_until={state.disable_until_iter}).", + 
layer_name, + extra_cachable_args=(gemm, "high_precision"), + ) return False, iteration + 1 if final_decision and metric_logger is not None: @@ -734,7 +765,6 @@ def inspect_tensor( # Weight tensor unavailable in high precision indicates fp8 model params. self._layer_has_fp8_model_params[layer_name] = True - _ = config gemms = self._TENSOR_TO_GEMMS.get(tensor_name, (None, None)) rowwise_gemm, columnwise_gemm = gemms @@ -742,12 +772,18 @@ def inspect_tensor( metrics = self._compute_metrics(tensor, rowwise_quantized_tensor) if metrics is not None: self._update_metric( - layer_name, rowwise_gemm, iteration, tensor_name, metrics[0], metrics[1] + layer_name, rowwise_gemm, iteration, tensor_name, metrics[0], metrics[1], config ) if columnwise_gemm is not None: metrics = self._compute_metrics(tensor, columnwise_quantized_tensor) if metrics is not None: self._update_metric( - layer_name, columnwise_gemm, iteration, tensor_name, metrics[0], metrics[1] + layer_name, + columnwise_gemm, + iteration, + tensor_name, + metrics[0], + metrics[1], + config, ) diff --git a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py index c3d2e4b393..6436ece062 100644 --- a/transformer_engine/debug/pytorch/gemm_runtime_hooks.py +++ b/transformer_engine/debug/pytorch/gemm_runtime_hooks.py @@ -13,6 +13,27 @@ from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer from transformer_engine.pytorch.utils import cast_if_needed +_AUTOSWITCH_LOGGING_ENV = "NVTE_AUTOSWITCH_GEMM_LOGGING" + + +def _env_flag_enabled(name: str, default: bool = False) -> bool: + """Interpret common boolean environment flag values.""" + default_value = "1" if default else "0" + return os.getenv(name, default_value).strip().lower() in {"1", "true", "yes", "on"} + + +def _autoswitch_logging_enabled() -> bool: + """Return True when verbose AutoswitchGemm runtime logging is enabled.""" + try: + from 
transformer_engine.debug.features.autoswitch_gemm import ( + autoswitch_gemm_logging_enabled, + ) + + return bool(autoswitch_gemm_logging_enabled()) + except Exception: # pylint: disable=broad-except + return _env_flag_enabled(_AUTOSWITCH_LOGGING_ENV, False) + + def _is_fp8_debug_quantizer(quantizer: Optional[Quantizer]) -> bool: """Return True for DebugQuantizer objects wrapping an FP8/NVFP4 quantizer.""" return ( @@ -226,6 +247,8 @@ def _log_final_gemm_decision( actual_precision: str, ) -> None: """Write final AutoswitchGemm decision to the autoswitch rank-local log.""" + if not _autoswitch_logging_enabled(): + return rank = os.getenv("RANK", "0") if rank != "0": return From 8386078803a7aebee82ef4d4003ed17cd670fe2a Mon Sep 17 00:00:00 2001 From: "xshang@nvidia.com" Date: Wed, 24 Jun 2026 16:09:27 +0800 Subject: [PATCH 15/16] direct_high_precision_in_hold_window --- docs/debug/1_getting_started.rst | 7 +++- docs/debug/autoswitch_gemm_example.yaml | 5 +++ .../debug/features/autoswitch_gemm.py | 42 +++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index 158a3b9afb..aae7e53ab2 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -245,6 +245,7 @@ Example config matching attention and MLP linears: underflow_threshold_pct: 5 mse_threshold: 0.1 allow_fp8_model_params_dequantized_weight: True + direct_high_precision_in_hold_window: True freq: 10 start_step: 10 @@ -261,7 +262,11 @@ Behavior summary: 4. If model parameters are stored in a quantized format, set ``allow_fp8_model_params_dequantized_weight: True`` to allow ``fprop`` and ``dgrad`` to switch by using temporary dequantized weights. -5. When CUDA Graphs are used, sampling and high-precision windows must run in eager +5. Set ``direct_high_precision_in_hold_window: True`` to directly select + high-precision tensor plans on non-sampling hold-window iterations. 
This + bypasses runtime quantize->dequantize conversion when high-precision source + tensors are available. +6. When CUDA Graphs are used, sampling and high-precision windows must run in eager mode. Quantized windows can continue using CUDA Graphs if the training framework supports this routing. Megatron-LM support for this workflow depends on the ``autogemm`` branch: diff --git a/docs/debug/autoswitch_gemm_example.yaml b/docs/debug/autoswitch_gemm_example.yaml index 9b2fc11e3e..a749b494a7 100644 --- a/docs/debug/autoswitch_gemm_example.yaml +++ b/docs/debug/autoswitch_gemm_example.yaml @@ -45,6 +45,11 @@ log_tensor_stats_all: # switch to high precision by using temporary dequantized weights. allow_fp8_model_params_dequantized_weight: True + # Optional: in hold-window non-sampling steps, route directly to + # high-precision plans when source tensors are available in bf16/fp16. + # This avoids quantize->dequantize conversion in runtime hooks. + direct_high_precision_in_hold_window: True + # Start sampling at step 10, then sample every 10 steps. A threshold # breach at step N keeps that (layer, GEMM) in high precision through # step N + freq - 1. The next sampling step refreshes the decision. diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index 387d6c06c0..df978d44a5 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -312,6 +312,12 @@ class AutoswitchGemm(TEConfigAPIMapper): tensor for GEMM execution. If False, `fprop`/`dgrad` stay quantized for such layers. + direct_high_precision_in_hold_window: bool, default = False + If True, non-sampling iterations within an active hold window return + high-precision plans directly from `fp8_gemm_enabled(..., final_decision=False)`. + This bypasses the quantize->dequantize runtime conversion path for tensors + that are available in high precision. 
+ freq/start_step/end_step/start_end_list: Optional Sampling controls for tensor inspection calls. @@ -351,6 +357,7 @@ class AutoswitchGemm(TEConfigAPIMapper): _DEFAULT_UNDERFLOW_THRESHOLD_PCT = 5.0 _DEFAULT_MSE_THRESHOLD = 1e-4 _DEFAULT_ALLOW_FP8_MODEL_PARAMS_DEQUANTIZED_WEIGHT = True + _DEFAULT_DIRECT_HIGH_PRECISION_IN_HOLD_WINDOW = False def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -454,6 +461,18 @@ def _get_root_log_dir() -> Optional[str]: return None return root_log_dir + @staticmethod + def _is_sampling_iteration_for_config(config: Dict, iteration: int) -> bool: + """Return True when this config schedules tensor sampling in `iteration`.""" + run_current, _ = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + return run_current + def _logging_enabled(self, config: Optional[Dict] = None) -> bool: """Return True when verbose AutoswitchGemm logging is enabled.""" if config is not None: @@ -646,6 +665,29 @@ def fp8_gemm_enabled( # Keep plan-time behavior quantized. Autoswitch decisions are applied only # at final decision points right before GEMM launch. 
if not final_decision: + direct_high_precision_in_hold_window = self._config_bool( + config, + "direct_high_precision_in_hold_window", + self._DEFAULT_DIRECT_HIGH_PRECISION_IN_HOLD_WINDOW, + ) + if ( + direct_high_precision_in_hold_window + and not self._is_sampling_iteration_for_config(config, iteration) + and iteration <= state.disable_until_iter + ): + fp8_model_params_layer = self._layer_has_fp8_model_params.get(layer_name, False) + allow_fp8_model_params_fallback = self._config_bool( + config, + "allow_fp8_model_params_dequantized_weight", + self._DEFAULT_ALLOW_FP8_MODEL_PARAMS_DEQUANTIZED_WEIGHT, + ) + if ( + gemm in {"fprop", "dgrad"} + and fp8_model_params_layer + and not allow_fp8_model_params_fallback + ): + return True, iteration + 1 + return False, iteration + 1 return True, iteration + 1 fp8_model_params_layer = self._layer_has_fp8_model_params.get(layer_name, False) From 91d25f0f65d7b61e85f37ed40cf68407283c7a77 Mon Sep 17 00:00:00 2001 From: "xshang@nvidia.com" Date: Fri, 26 Jun 2026 13:49:37 +0800 Subject: [PATCH 16/16] hold_window_scope: layer/global --- docs/debug/1_getting_started.rst | 7 ++- docs/debug/autoswitch_gemm_example.yaml | 5 ++ .../debug/features/autoswitch_gemm.py | 52 ++++++++++++++++++- 3 files changed, 61 insertions(+), 3 deletions(-) diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index aae7e53ab2..8a5aa7f8ad 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -246,6 +246,7 @@ Example config matching attention and MLP linears: mse_threshold: 0.1 allow_fp8_model_params_dequantized_weight: True direct_high_precision_in_hold_window: True + hold_window_scope: layer freq: 10 start_step: 10 @@ -266,7 +267,11 @@ Behavior summary: high-precision tensor plans on non-sampling hold-window iterations. This bypasses runtime quantize->dequantize conversion when high-precision source tensors are available. -6. 
When CUDA Graphs are used, sampling and high-precision windows must run in eager +6. Set ``hold_window_scope`` to control eager routing scope during hold windows + in CUDA graph mode: + ``global`` (default) forces eager globally when any layer is in hold window, + while ``layer`` only forces eager for layers that have active hold windows. +7. When CUDA Graphs are used, sampling and high-precision windows must run in eager mode. Quantized windows can continue using CUDA Graphs if the training framework supports this routing. Megatron-LM support for this workflow depends on the ``autogemm`` branch: diff --git a/docs/debug/autoswitch_gemm_example.yaml b/docs/debug/autoswitch_gemm_example.yaml index a749b494a7..021c78200e 100644 --- a/docs/debug/autoswitch_gemm_example.yaml +++ b/docs/debug/autoswitch_gemm_example.yaml @@ -50,6 +50,11 @@ log_tensor_stats_all: # This avoids quantize->dequantize conversion in runtime hooks. direct_high_precision_in_hold_window: True + # Optional: hold-window eager routing scope in CUDA graph mode. + # global: any triggered layer forces eager globally (default behavior) + # layer: only triggered layers force eager; others can stay on graphs + hold_window_scope: layer + # Start sampling at step 10, then sample every 10 steps. A threshold # breach at step N keeps that (layer, GEMM) in high precision through # step N + freq - 1. The next sampling step refreshes the decision. 
diff --git a/transformer_engine/debug/features/autoswitch_gemm.py b/transformer_engine/debug/features/autoswitch_gemm.py index df978d44a5..696d117e75 100644 --- a/transformer_engine/debug/features/autoswitch_gemm.py +++ b/transformer_engine/debug/features/autoswitch_gemm.py @@ -28,7 +28,9 @@ _AUTOSWITCH_CONFIG_FILE_LOADED = False _AUTOSWITCH_DISABLE_UNTIL_ENV = "NVTE_AUTOSWITCH_GEMM_DISABLE_UNTIL" _AUTOSWITCH_LOGGING_ENV = "NVTE_AUTOSWITCH_GEMM_LOGGING" +_AUTOSWITCH_HOLD_WINDOW_SCOPE_ENV = "NVTE_AUTOSWITCH_GEMM_HOLD_WINDOW_SCOPE" _AUTOSWITCH_LOGGING_ENABLED = False +_AUTOSWITCH_HOLD_WINDOW_SCOPE = "global" def _env_flag_enabled(name: str, default: bool = False) -> bool: @@ -40,6 +42,7 @@ def _env_flag_enabled(name: str, default: bool = False) -> bool: def _register_sampling_config(config: Dict) -> None: """Track AutoswitchGemm sampling schedules for runtime eager/graph routing.""" global _AUTOSWITCH_LOGGING_ENABLED + global _AUTOSWITCH_HOLD_WINDOW_SCOPE schedule = { "start_step": config.get("start_step", None), "end_step": config.get("end_step", None), @@ -53,6 +56,20 @@ def _register_sampling_config(config: Dict) -> None: if any(bool(config.get(key, False)) for key in ("log_metrics", "log_decisions", "verbose")): _AUTOSWITCH_LOGGING_ENABLED = True + scope = str(config.get("hold_window_scope", "global")).strip().lower() + if scope not in {"global", "layer"}: + scope = "global" + if scope == "layer": + _AUTOSWITCH_HOLD_WINDOW_SCOPE = "layer" + + +def _get_hold_window_scope() -> str: + """Return hold-window eager-routing scope: `global` or `layer`.""" + env_scope = os.getenv(_AUTOSWITCH_HOLD_WINDOW_SCOPE_ENV, "").strip().lower() + if env_scope in {"global", "layer"}: + return env_scope + return _AUTOSWITCH_HOLD_WINDOW_SCOPE + def autoswitch_gemm_logging_enabled() -> bool: """Return True when verbose AutoswitchGemm logging is globally enabled.""" @@ -130,9 +147,31 @@ def _get_global_disable_until() -> int: return max(max_disable_until, env_disable_until) -def 
autoswitch_gemm_should_force_eager(iteration: Optional[int] = None) -> bool: +def _get_layer_disable_until(layer_name: str, layer_number: Optional[int] = None) -> int: + """Return layer-local high-precision window end iteration.""" + if not layer_name and layer_number is None: + return -1 + + max_disable_until = -1 + layer_marker = f".layers.{layer_number}." if layer_number is not None else None + for (gemm_layer_name, _), disable_until in _AUTOSWITCH_DISABLE_UNTIL_BY_GEMM.items(): + if layer_name and ( + gemm_layer_name == layer_name or gemm_layer_name.startswith(f"{layer_name}.") + ): + max_disable_until = max(max_disable_until, disable_until) + continue + if layer_marker and layer_marker in gemm_layer_name: + max_disable_until = max(max_disable_until, disable_until) + return max_disable_until + + +def autoswitch_gemm_should_force_eager( + iteration: Optional[int] = None, + layer_name: Optional[str] = None, + layer_number: Optional[int] = None, +) -> bool: """ - Return True when AutoswitchGemm needs eager execution for the whole iteration. + Return True when AutoswitchGemm needs eager execution for current routing target. This is used by Megatron CUDA graph routing. Sampling iterations must be eager so tensor inspection can run; high-precision windows must also be eager because @@ -149,6 +188,9 @@ def autoswitch_gemm_should_force_eager(iteration: Optional[int] = None) -> bool: if _is_sampling_iteration(iteration): return True + if _get_hold_window_scope() == "layer": + return iteration <= _get_layer_disable_until(layer_name or "", layer_number) + return iteration <= _get_global_disable_until() @@ -318,6 +360,12 @@ class AutoswitchGemm(TEConfigAPIMapper): This bypasses the quantize->dequantize runtime conversion path for tensors that are available in high precision. + hold_window_scope: str, default = "global" + Controls eager routing scope for hold-window iterations in CUDA graph mode. 
+ ``global`` keeps current behavior (any triggered layer forces eager globally). + ``layer`` enables per-layer eager routing (only layers that triggered hold + windows force eager; other layers can stay on graph replay). + freq/start_step/end_step/start_end_list: Optional Sampling controls for tensor inspection calls.