diff --git a/src/spatialdata_plot/pl/_color.py b/src/spatialdata_plot/pl/_color.py new file mode 100644 index 00000000..c3e6a21d --- /dev/null +++ b/src/spatialdata_plot/pl/_color.py @@ -0,0 +1,1212 @@ +"""Color resolution, palettes, and colormap helpers (extracted from utils.py, see #696).""" + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from copy import copy +from typing import Any, Literal + +import matplotlib +import numpy as np +import pandas as pd +import spatialdata as sd +from anndata import AnnData +from cycler import Cycler, cycler +from geopandas import GeoDataFrame +from matplotlib import colors, rcParams +from matplotlib.cm import ScalarMappable +from matplotlib.colors import ( + ColorConverter, + Colormap, + LinearSegmentedColormap, + ListedColormap, + Normalize, + to_rgba, +) +from numpy.random import default_rng +from pandas.api.types import CategoricalDtype, is_bool_dtype, is_numeric_dtype, is_string_dtype +from pandas.core.arrays.categorical import Categorical +from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation +from scanpy.plotting.palettes import default_20, default_28, default_102 +from skimage.color import label2rgb +from skimage.morphology import erosion, footprint_rectangle +from skimage.util import map_array +from spatialdata import ( + get_values, +) +from spatialdata._core.query.relational_query import _locate_value +from spatialdata._types import ArrayLike +from spatialdata.models import ( + SpatialElement, +) + +from spatialdata_plot._logging import logger +from spatialdata_plot.pl.render_params import ( + CmapParams, + Color, + ColorLike, + OutlineParams, +) +from spatialdata_plot.pl.utils import ( + _MPL_SINGLE_LETTER_COLORS, + _build_alignment_dtype_hint, + _ensure_one_to_one_mapping, + _format_element_name, + to_hex, +) + + +def _is_color_like(color: Any) -> bool: + """Check if a value is a valid color. 
def _make_continuous_mappable(vmin: float, vmax: float, cmap: Any) -> ScalarMappable:
    """Return a ``ScalarMappable`` covering ``[vmin, vmax]`` for a continuous colorbar.

    When ``vmin == vmax`` the range is widened by 0.5 on each side before the
    ``Normalize`` instance is created, so the normalization never receives a
    zero-width interval.
    """
    if vmin == vmax:
        lower, upper = vmin - 0.5, vmax + 0.5
    else:
        lower, upper = vmin, vmax
    scale = Normalize(vmin=lower, vmax=upper)
    return ScalarMappable(norm=scale, cmap=cmap)
+ """ + arr = np.asarray(mask) + if outline_color_source_vector is not None: + outline_color_source_vector = outline_color_source_vector[arr] + return outline_color_vector[arr], outline_color_source_vector + + +def _align_outline_vector_to_length( + outline_color_vector: Any, + outline_color_source_vector: pd.Series | None, + n: int, +) -> tuple[Any, pd.Series | None]: + """Pad or truncate the outline color vector(s) to length ``n``. + + Used when the outline column annotates a different row count than the rendered + element (cross-table case, or rasterize-induced label drop). Missing entries + are padded with NaN so downstream code maps them to ``na_color``. + """ + if outline_color_vector is None or len(outline_color_vector) == n: + return outline_color_vector, outline_color_source_vector + if len(outline_color_vector) > n: + if outline_color_source_vector is not None: + outline_color_source_vector = outline_color_source_vector[:n] + return outline_color_vector[:n], outline_color_source_vector + pad = n - len(outline_color_vector) + if outline_color_source_vector is not None: + # Categorical: downstream picks one hex per category from rows that *have* a + # category. NaN-padded rows contribute no category, so the per-row hex pad is + # immaterial; pad with NaN to skip the allocation. + padded_vec = np.concatenate([np.asarray(outline_color_vector), np.full(pad, np.nan, dtype=object)]) + outline_color_source_vector = pd.Categorical( + list(outline_color_source_vector) + [None] * pad, + categories=outline_color_source_vector.categories, + ) + else: + # Continuous: numeric vector, pad with NaN so cmap maps padded rows to na_color. 
+ padded_vec = np.concatenate([np.asarray(outline_color_vector, dtype=float), np.full(pad, np.nan)]) + return padded_vec, outline_color_source_vector + + +def _color_vector_to_rgba( + color_vector: Any | None, + color_source_vector: pd.Series | None, + cmap_params: CmapParams, + n_rows: int, +) -> np.ndarray: + """Convert a fill/outline `color_vector` (categorical hex strings or continuous numerics) to (N, 4) RGBA. + + Mirrors the per-row mapping done inside :func:`_get_collection_shape` so that + callers can pre-materialize an outline-color array. NaN/non-finite entries are + painted with ``cmap_params.na_color``. + """ + na_rgba = colors.to_rgba(cmap_params.na_color.get_hex_with_alpha()) + if color_vector is None: + rgba = np.empty((n_rows, 4), dtype=float) + rgba[:] = na_rgba + return rgba + + if color_source_vector is not None: + # Categorical: color_vector contains hex strings aligned to color_source_vector + return np.asarray(ColorConverter().to_rgba_array(list(color_vector))) + + arr = np.asarray(color_vector) + if arr.ndim == 2 and arr.shape[1] in (3, 4) and np.issubdtype(arr.dtype, np.number): + return np.asarray(ColorConverter().to_rgba_array(arr)) + + rgba = np.empty((len(arr), 4), dtype=float) + rgba[:] = na_rgba + if np.issubdtype(arr.dtype, np.number): + finite_mask = np.isfinite(arr) + if finite_mask.any(): + norm = cmap_params.norm + if norm.vmin is None or norm.vmax is None: + vmin = float(np.nanmin(arr[finite_mask])) + vmax = float(np.nanmax(arr[finite_mask])) + if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax: + vmin, vmax = 0.0, 1.0 + used_norm = Normalize(vmin=vmin, vmax=vmax, clip=False) + else: + used_norm = norm + rgba[finite_mask] = cmap_params.cmap(used_norm(arr[finite_mask])) + return rgba + + # Object dtype: mix of numerics and color-like specs (apply cmap to the numeric subset only) + series = pd.Series(arr, copy=False) + num = pd.to_numeric(series, errors="coerce").to_numpy() + is_num = np.isfinite(num) + if 
def _prepare_cmap_norm(
    cmap: Colormap | str | None = None,
    norm: Normalize | None = None,
    na_color: Color = Color(),
) -> CmapParams:
    """Bundle a colormap, a normalization and the NA color into ``CmapParams``.

    ``cmap=None`` is resolved through ``rcParams["image.cmap"]`` and string names
    through ``matplotlib.colormaps``. Both the colormap and a user-supplied ``norm``
    are copied, so setting the "bad" color here happens on the copy rather than on
    the object passed in. ``cmap_is_default`` records whether the caller left
    ``cmap`` unset.
    """
    # TODO: check refactoring norm out here as it gets overwritten later
    cmap_is_default = cmap is None
    resolved = rcParams["image.cmap"] if cmap_is_default else cmap
    if isinstance(resolved, str):
        resolved = matplotlib.colormaps[resolved]
    resolved = copy(resolved)

    assert isinstance(resolved, Colormap), f"Invalid type of `cmap`: {type(resolved)}, expected `Colormap`."

    if norm is None:
        norm = Normalize(vmin=None, vmax=None, clip=False)
    else:
        norm = copy(norm)

    resolved.set_bad(na_color.get_hex_with_alpha())

    return CmapParams(cmap=resolved, norm=norm, na_color=na_color, cmap_is_default=cmap_is_default)
RGBA array or #RRGGBBAA): that alpha is used + 3) If outline_color (w/o implying an alpha) and/or outline_width is specified: alpha of outlines set to 1.0 + """ + # A) User doesn't want to see outlines + if ( + outline_alpha == 0.0 + or (isinstance(outline_alpha, tuple) and np.all(np.array(outline_alpha) == 0.0)) + or not (outline_alpha or outline_width or outline_color) + ): + return (0.0, 0.0), OutlineParams(None, 1.5, None, 0.5) + + # B) User wants to see at least 1 outline + if isinstance(outline_width, tuple): + if len(outline_width) != 2: + raise ValueError( + f"Tuple of length {len(outline_width)} was passed for outline_width. When specifying multiple outlines," + " please pass a tuple of exactly length 2." + ) + if not outline_color: + outline_color = (Color("#000000"), Color("#ffffff")) + elif not isinstance(outline_color, tuple): + raise ValueError( + "No tuple was passed for outline_color, while two outlines were specified by using the outline_width " + "argument. Please specify the outline colors in a tuple of length two." + ) + + if isinstance(outline_color, tuple): + if len(outline_color) != 2: + raise ValueError( + f"Tuple of length {len(outline_color)} was passed for outline_color. When specifying multiple outlines," + " please pass a tuple of exactly length 2." + ) + if not outline_width: + outline_width = (1.5, 0.5) + elif not isinstance(outline_width, tuple): + raise ValueError( + "No tuple was passed for outline_width, while two outlines were specified by using the outline_color " + "argument. Please specify the outline widths in a tuple of length two." 
+ ) + + if isinstance(outline_width, float | int): + outline_width = (outline_width, 0.0) + elif not outline_width: + outline_width = (1.5, 0.0) + if isinstance(outline_color, Color): + outline_color = (outline_color, None) + elif not outline_color: + outline_color = (Color("#000000ff"), None) + + assert isinstance(outline_color, tuple), "outline_color is not a tuple" # shut up mypy + assert isinstance(outline_width, tuple), "outline_width is not a tuple" + + for ow in outline_width: + if not isinstance(ow, int | float): + raise TypeError(f"Invalid type of `outline_width`: {type(ow)}, expected `int` or `float`.") + + if outline_alpha: + if isinstance(outline_alpha, int | float): + # for a single outline: second width value is 0.0 + outline_alpha = (outline_alpha, 0.0) if outline_width[1] == 0.0 else (outline_alpha, outline_alpha) + else: + # if alpha wasn't explicitly specified by the user + outer_ol_alpha = outline_color[0].get_alpha_as_float() if isinstance(outline_color[0], Color) else 1.0 + inner_ol_alpha = outline_color[1].get_alpha_as_float() if isinstance(outline_color[1], Color) else 1.0 + outline_alpha = (outer_ol_alpha, inner_ol_alpha) + + # handle possible linewidths of 0.0 => outline won't be rendered in the first place + if outline_width[0] == 0.0: + outline_alpha = (0.0, outline_alpha[1]) + if outline_width[1] == 0.0: + outline_alpha = (outline_alpha[0], 0.0) + + if outline_alpha[0] > 0.0 or outline_alpha[1] > 0.0: + kwargs.pop("edgecolor", None) # remove edge from kwargs if present + kwargs.pop("alpha", None) # remove alpha from kwargs if present + + return outline_alpha, OutlineParams( + outline_color[0], + outline_width[0], + outline_color[1], + outline_width[1], + ) + + +def _get_colors_for_categorical_obs( + categories: Sequence[str | int], + palette: ListedColormap | str | list[str] | None = None, + alpha: float = 1.0, + cmap_params: CmapParams | None = None, +) -> list[str]: + """ + Return a list of colors for a categorical observation. 
+ + Parameters + ---------- + adata + AnnData object + value_to_plot + Name of a valid categorical observation + categories + categories of the categorical observation. + + Returns + ------- + None + """ + len_cat = len(categories) + + # check if default matplotlib palette has enough colors + if palette is None: + if cmap_params is not None and not cmap_params.cmap_is_default: + palette = cmap_params.cmap + elif len(rcParams["axes.prop_cycle"].by_key()["color"]) >= len_cat: + cc = rcParams["axes.prop_cycle"]() + palette = [next(cc)["color"] for _ in range(len_cat)] + elif len_cat <= 20: + palette = default_20 + elif len_cat <= 28: + palette = default_28 + elif len_cat <= len(default_102): # 103 colors + palette = default_102 + else: + palette = ["grey" for _ in range(len_cat)] + logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") + else: + # raise error when user didn't provide the right number of colors in palette + if isinstance(palette, list) and len(palette) != len(categories): + raise ValueError( + f"The number of provided values in the palette ({len(palette)}) doesn't agree with the number of " + f"categories that should be colored ({categories})." 
+ ) + + # otherwise, single channels turn out grey + color_idx = np.linspace(0, 1, len_cat) if len_cat > 1 else [0.7] + + if isinstance(palette, str): + palette = [to_hex(palette)] + elif isinstance(palette, list): + palette = [to_hex(x) for x in palette] + elif isinstance(palette, ListedColormap): + palette = [to_hex(x) for x in palette(color_idx, alpha=alpha)] + elif isinstance(palette, LinearSegmentedColormap): + palette = [to_hex(palette(x, alpha=alpha)) for x in color_idx] # type: ignore[attr-defined] + else: + raise TypeError(f"Palette is {type(palette)} but should be string or list.") + + return palette[:len_cat] # type: ignore[return-value] + + +def _infer_color_data_kind( + series: pd.Series, + value_to_plot: str, + element_name: list[str] | str | None, + table_name: str | None, + warn_on_object_to_categorical: bool = False, +) -> tuple[Literal["numeric", "categorical"], pd.Series | pd.Categorical]: + element_label = _format_element_name(element_name) + + if isinstance(series.dtype, pd.CategoricalDtype): + return "categorical", pd.Categorical(series) + + if is_bool_dtype(series.dtype): + return "numeric", series.astype(float) + + if is_numeric_dtype(series.dtype): + return "numeric", pd.to_numeric(series, errors="coerce") + + if is_string_dtype(series.dtype) or series.dtype == object: + non_na = series[~pd.isna(series)] + if len(non_na) == 0: + return "numeric", pd.to_numeric(series, errors="coerce") + + numeric_like = pd.to_numeric(non_na, errors="coerce") + has_numeric = numeric_like.notna().any() + has_non_numeric = numeric_like.isna().any() + + if has_numeric and has_non_numeric: + invalid_examples = non_na[numeric_like.isna()].astype(str).unique()[:3] + location = f" in table '{table_name}'" if table_name is not None else "" + raise TypeError( + f"Column '{value_to_plot}' for element '{element_label}'{location} contains both numeric and " + f"non-numeric values (e.g. {', '.join(invalid_examples)}). 
" + "Please ensure that the column stores consistent data." + ) + + if has_numeric: + return "numeric", pd.to_numeric(series, errors="coerce") + + if warn_on_object_to_categorical: + logger.warning( + f"Converting copy of '{value_to_plot}' column to categorical dtype for categorical plotting. " + "Consider converting before plotting." + ) + + return "categorical", pd.Categorical(series) + + return "numeric", pd.to_numeric(series, errors="coerce") + + +def _extract_color_column( + table: AnnData, + value_key: str, + *, + origin: str, + element: GeoDataFrame, + element_name: str, + table_layer: str | None = None, +) -> pd.Series: + """Read one color column from ``table`` aligned to ``element`` order, without copying the table. + + Equivalent to ``get_values(value_key, sdata=..., element_name=..., table_name=...)[value_key]`` but + skips the table->element join, whose ``table[indices, :].copy()`` does an expensive out-of-order + sparse CSR row-gather. Restricts to rows annotating ``element_name`` (via ``region_key``), then + reindexes to the element's instance order (``NaN`` for instances with no table row), preserving the + categorical dtype of ``obs`` columns so the downstream legend path is unchanged. 
+ """ + attrs = table.uns["spatialdata_attrs"] + region_key, instance_key = attrs["region_key"], attrs["instance_key"] + mask = table.obs[region_key].to_numpy() == element_name + inst = table.obs[instance_key].to_numpy()[mask] + if origin == "var": + source = table.layers[table_layer] if table_layer is not None else table.X + col = source[:, table.var_names.get_loc(value_key)] + col = np.asarray(col.todense()).ravel() if hasattr(col, "todense") else np.asarray(col).ravel() + values = pd.Series(col[mask], index=inst) + else: # obs column; .values keeps a Categorical categorical so the legend path still sees one + values = pd.Series(table.obs[value_key].values[mask], index=inst) + return values.reindex(element.index) + + +def _set_color_source_vec( + sdata: sd.SpatialData, + element: SpatialElement | None, + value_to_plot: str | None, + na_color: Color, + element_name: list[str] | str | None = None, + groups: list[str] | str | None = None, + palette: dict[str, str] | list[str] | str | None = None, + cmap_params: CmapParams | None = None, + alpha: float = 1.0, + table_name: str | None = None, + table_layer: str | None = None, + render_type: Literal["points", "labels"] | None = None, + coordinate_system: str | None = None, + preloaded_color_data: pd.Series | None = None, +) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: + if value_to_plot is None and element is not None: + color = np.full(len(element), na_color.get_hex_with_alpha()) + return color, color, False + + # Figure out where to get the color from + origins = _locate_value( + value_key=value_to_plot, + sdata=sdata, + element_name=element_name, + table_name=table_name, + ) + + # When both the element's own dataframe and the chosen table contain a + # column with this name, an explicit `table_name=` resolves the ambiguity — + # keep only the table origin and skip the multi-origin error below. 
+ explicit_table_shadows_df = table_name is not None and any(o.origin == "df" for o in origins) + if explicit_table_shadows_df: + origins = [o for o in origins if o.origin != "df"] + + if len(origins) > 1: + raise ValueError( + f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. " + "Please keep it in exactly one place (preferably on the points parquet for speed) to avoid ambiguity." + ) + + if len(origins) == 1 and value_to_plot is not None: + if table_name is not None: + _ensure_one_to_one_mapping( + sdata=sdata, + element=element, + element_name=element_name, + table_name=table_name, + ) + if preloaded_color_data is not None: + color_source_vector = preloaded_color_data + elif ( + isinstance(element, GeoDataFrame) + and isinstance(element_name, str) + and table_name is not None + and table_name in sdata.tables + and origins[0].origin in ("obs", "var") + ): + # Fast path: read the single aligned column directly instead of joining/copying the + # whole annotating table (the join's out-of-order sparse row-gather dominates large renders). + color_source_vector = _extract_color_column( + sdata[table_name], + value_to_plot, + origin=origins[0].origin, + element=element, + element_name=element_name, + table_layer=table_layer, + ) + elif explicit_table_shadows_df: + # Pass the table as `element` so upstream `get_values` skips the + # element-column lookup and avoids the multi-origin error. 
+ color_source_vector = get_values( + value_key=value_to_plot, + element=sdata[table_name], + element_name=element_name, + table_layer=table_layer, + )[value_to_plot] + else: + color_source_vector = get_values( + value_key=value_to_plot, + sdata=sdata, + element_name=element_name, + table_name=table_name, + table_layer=table_layer, + )[value_to_plot] + + color_series = ( + color_source_vector if isinstance(color_source_vector, pd.Series) else pd.Series(color_source_vector) + ) + + if color_series.isna().all(): + element_label = _format_element_name(element_name) + dtype_hint = _build_alignment_dtype_hint(sdata, element, color_series, table_name) + hint_suffix = f" {dtype_hint.strip()}" if dtype_hint else "" + logger.warning( + f"Column '{value_to_plot}' for element '{element_label}' contains only NaN values; " + f"rendering with na_color.{hint_suffix}" + ) + na_color_arr = np.full(len(color_series), na_color.get_hex_with_alpha()) + return na_color_arr, na_color_arr, False + + kind, processed = _infer_color_data_kind( + series=color_series, + value_to_plot=value_to_plot, + element_name=element_name, + table_name=table_name, + warn_on_object_to_categorical=table_name is not None, + ) + + if kind == "numeric": + numeric_vector = processed + if ( + not isinstance(element, GeoDataFrame) + and isinstance(palette, list) + and palette[0] is not None + or isinstance(element, GeoDataFrame) + and isinstance(palette, list) + ): + logger.warning( + "Ignoring categorical palette which is given for a continuous variable. " + "Consider using `cmap` to pass a ColorMap." 
+ ) + return None, numeric_vector, False + + assert isinstance(processed, pd.Categorical) + if not processed.ordered: + # ensure deterministic category order when the source is unordered (e.g., from a Python set) + processed = processed.reorder_categories(sorted(processed.categories)) + color_source_vector = processed # convert, e.g., `pd.Series` + + # When the value lives on the element's own DataFrame (origin="df"), + # there is no reason to look up a table for .uns colors. + value_from_element = origins[0].origin == "df" + + # Use the provided table_name parameter, fall back to only one present + table_to_use: str | None + if value_from_element: + table_to_use = None + elif table_name is not None and table_name in sdata.tables: + table_to_use = table_name + elif table_name is not None and table_name not in sdata.tables: + logger.warning(f"Table '{table_name}' not found in `sdata.tables`. Falling back to default behavior.") + table_to_use = None + else: + table_keys = list(sdata.tables.keys()) + if len(table_keys) == 1: + table_to_use = table_keys[0] + elif len(table_keys) > 1: + table_to_use = table_keys[0] + logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.") + else: + table_to_use = None + + adata_for_mapping = sdata[table_to_use] if table_to_use is not None else None + + # Check if custom colors exist in the resolved table's .uns slot + if ( + value_to_plot is not None + and table_to_use is not None + and _has_colors_in_uns(sdata, table_to_use, value_to_plot) + ): + # Extract colors directly from the table's .uns slot + # Convert Color to ColorLike (str) for the function + na_color_like: ColorLike = na_color.get_hex() if isinstance(na_color, Color) else na_color + color_mapping = _extract_colors_from_table_uns( + sdata=sdata, + table_name=table_to_use, + col_to_colorby=value_to_plot, + color_source_vector=color_source_vector, + na_color=na_color_like, + ) + if color_mapping is not None: + if isinstance(palette, 
str): + palette = [palette] + color_mapping = _modify_categorical_color_mapping( + mapping=color_mapping, + groups=groups, + palette=palette, + ) + else: + logger.warning(f"Failed to extract colors for '{value_to_plot}', falling back to default mapping.") + # Fall back to the existing method if extraction fails + color_mapping = _get_categorical_color_mapping( + adata=sdata[table_to_use], + cluster_key=value_to_plot, + color_source_vector=color_source_vector, + cmap_params=cmap_params, + alpha=alpha, + groups=groups, + palette=palette, + na_color=na_color, + render_type=render_type, + ) + else: + color_mapping = None + + if color_mapping is None: + # Use the existing color mapping method + color_mapping = _get_categorical_color_mapping( + adata=adata_for_mapping, + cluster_key=value_to_plot, + color_source_vector=color_source_vector, + cmap_params=cmap_params, + alpha=alpha, + groups=groups, + palette=palette, + na_color=na_color, + render_type=render_type, + ) + + color_source_vector = color_source_vector.set_categories(color_mapping.keys()) + if color_mapping is None: + raise ValueError("Unable to create color palette.") + + # do not rename categories, as colors need not be unique + # pd.Categorical.map() demotes to object dtype when mapped values aren't unique + # (e.g. two categories share a color). Wrapping back in pd.Categorical ensures + # downstream consumers always receive a Categorical for categorical data. 
def _map_color_seg(
    seg: ArrayLike,
    cell_id: ArrayLike,
    color_vector: ArrayLike | pd.Series[CategoricalDtype],
    color_source_vector: pd.Series[CategoricalDtype],
    cmap_params: CmapParams,
    na_color: Color,
    seg_erosionpx: int | None = None,
    seg_boundaries: bool = False,
    outline_color: Color | None = None,
    outline_color_vector: ArrayLike | pd.Series[CategoricalDtype] | None = None,
    outline_color_source_vector: pd.Series[CategoricalDtype] | None = None,
) -> ArrayLike:
    """Color a label image per label and return it with an extra alpha-like channel.

    The labels in ``seg`` listed in ``cell_id`` are remapped with ``map_array`` to a
    value image, which ``label2rgb`` then paints with a per-label color table.

    Parameters
    ----------
    seg
        Label image to color.
    cell_id
        Label values present in ``seg`` that the color vectors refer to.
        NOTE(review): assumed to be aligned row-by-row with ``color_vector`` — confirm with callers.
    color_vector
        Fill colors. A categorical dtype is colored by category, a numeric dtype
        goes through ``cmap_params.norm`` / ``cmap_params.cmap``, anything else is
        treated as literal colors (or, failing that, pushed through the cmap).
    color_source_vector
        Source values of ``color_vector``; only used to detect the "no color
        requested" case, which gets random colors from a generator seeded with 42.
    cmap_params
        Supplies ``norm`` and ``cmap`` for continuous values.
    na_color
        Color that marks missing values; also the default for outline categories
        that have no row.
    seg_erosionpx
        If not ``None``, pixels whose value is unchanged by an erosion with a square
        footprint of this size are set to 0.
    seg_boundaries
        If true, the last channel is a float mask of the non-zero value-image pixels.
    outline_color
        Uniform outline color, used when ``seg_boundaries`` is true and no
        ``outline_color_vector`` is given.
    outline_color_vector
        Per-row outline colors (hex strings or numerics), used when
        ``seg_boundaries`` is true.
    outline_color_source_vector
        Categorical source of ``outline_color_vector``; ``None`` selects the
        continuous outline path.

    Returns
    -------
    The colored image stacked with a final channel via ``np.dstack``, or a
    directly built RGBA array for the uniform outline color.
    """
    cell_id = np.array(cell_id)

    if isinstance(color_vector.dtype, pd.CategoricalDtype):
        # Case A: user wants to plot a categorical column
        # codes are shifted by one so that 0 stays free for the background label
        val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1)
        cols = colors.to_rgba_array(color_vector.categories)
    elif pd.api.types.is_numeric_dtype(color_vector.dtype):
        # Case B: user wants to plot a continuous column
        if isinstance(color_vector, pd.Series):
            color_vector = color_vector.to_numpy()
        # normalize only the not nan values, else the whole array would contain only nan values
        normed_color_vector = color_vector.copy().astype(float)
        normed_color_vector[~np.isnan(normed_color_vector)] = cmap_params.norm(
            normed_color_vector[~np.isnan(normed_color_vector)]
        )
        cols = cmap_params.cmap(normed_color_vector)
        val_im = map_array(seg, cell_id, cell_id)
    else:
        # Case C: User didn't specify any colors
        # (both vectors hold nothing but the na_color, and the user left na_color untouched)
        if color_source_vector is not None and (
            set(color_vector) == set(color_source_vector)
            and len(set(color_vector)) == 1
            and set(color_vector) == {na_color.get_hex_with_alpha()}
            and not na_color.color_modified_by_user()
        ):
            val_im = map_array(seg, cell_id, cell_id)
            # fixed seed keeps the random per-row colors reproducible between calls
            RNG = default_rng(42)
            cols = RNG.random((len(color_vector), 3))
        else:
            # Case D: User didn't specify a column to color by, but modified the na_color
            val_im = map_array(seg, cell_id, cell_id)
            first_value = color_vector.iloc[0] if isinstance(color_vector, pd.Series) else color_vector[0]
            if _is_color_like(first_value):
                # we have color-like values (e.g., hex or named colors)
                assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like."
                cols = colors.to_rgba_array(color_vector)
            else:
                cols = cmap_params.cmap(cmap_params.norm(color_vector))

    if seg_erosionpx is not None:
        # zero every pixel the erosion leaves unchanged; only pixels altered by the erosion keep their value
        val_im[val_im == erosion(val_im, footprint_rectangle((seg_erosionpx, seg_erosionpx)))] = 0

    if seg_boundaries and outline_color_vector is not None:
        # Column-driven outline: build per-label colors from the outline vector and overlay
        # on the eroded ring. Two cases (mirroring _set_color_source_vec's return contract):
        # - categorical: outline_color_source_vector is the source Categorical; outline_color_vector
        #   holds hex strings aligned to cells.
        # - continuous: outline_color_source_vector is None; outline_color_vector is numeric.
        if outline_color_source_vector is not None:
            cat = pd.Categorical(outline_color_source_vector)
            cat_codes = cat.codes
            outline_val_im: ArrayLike = map_array(seg, cell_id, cat_codes + 1)
            color_arr = np.asarray(outline_color_vector, dtype=object)
            # Pick the first per-cell hex for each category in one vectorized pass
            # (avoids `K × O(N)` Python loops on large label sets).
            cat_colors: list[Any] = [na_color.get_hex_with_alpha()] * len(cat.categories)
            unique_codes, first_indices = np.unique(cat_codes, return_index=True)
            for code, idx in zip(unique_codes, first_indices, strict=True):
                # code -1 marks rows without a category; those keep the na_color default
                if code >= 0:
                    cat_colors[code] = color_arr[idx]
            outline_cols = colors.to_rgba_array(cat_colors)
        else:
            # Continuous: numeric values normalized via cmap
            ov = (
                outline_color_vector.to_numpy()
                if isinstance(outline_color_vector, pd.Series)
                else np.asarray(outline_color_vector)
            )
            normed = ov.copy().astype(float)
            finite = ~np.isnan(normed)
            if finite.any():
                normed[finite] = cmap_params.norm(normed[finite])
            outline_cols = cmap_params.cmap(normed)
            outline_val_im = map_array(seg, cell_id, cell_id)
        if seg_erosionpx is not None:
            # same erosion rule as for the fill value image above
            outline_val_im[
                outline_val_im == erosion(outline_val_im, footprint_rectangle((seg_erosionpx, seg_erosionpx)))
            ] = 0
        outline_seg_im = label2rgb(
            label=outline_val_im,
            colors=outline_cols,
            bg_label=0,
            bg_color=(1, 1, 1),
            image_alpha=0,
        )
        # the alpha mask is taken from the fill value image (val_im), not from outline_val_im
        outline_mask = val_im > 0
        alpha_channel = outline_mask.astype(float)
        return np.dstack((outline_seg_im, alpha_channel))

    if seg_boundaries and outline_color is not None:
        # Uniform outline color requested: skip label2rgb, build RGBA directly
        outline_rgba = colors.to_rgba(outline_color.get_hex_with_alpha())
        outline_mask = val_im > 0
        rgba = np.zeros((*val_im.shape, 4), dtype=float)
        rgba[outline_mask, :3] = outline_rgba[:3]
        rgba[outline_mask, 3] = outline_rgba[3]
        return rgba

    seg_im: ArrayLike = label2rgb(
        label=val_im,
        colors=cols,
        bg_label=0,
        bg_color=(1, 1, 1),  # transparency doesn't really work
        image_alpha=0,
    )

    if seg_boundaries:
        # Data-driven outline: use seg_im colors on the eroded ring, transparent elsewhere
        outline_mask = val_im > 0
        alpha_channel = outline_mask.astype(float)
        return np.dstack((seg_im, alpha_channel))

    if len(val_im.shape) != len(seg_im.shape):
        # turn the value image into a 0/1 mask with a trailing axis so it can be stacked as the last channel
        val_im = np.expand_dims((val_im > 0).astype(int), axis=-1)
    return np.dstack((seg_im, val_im))
+ if cluster_key in adata.obs and hasattr(adata.obs[cluster_key], "cat"): + all_cats = adata.obs[cluster_key].cat.categories.tolist() + keep_idx = [i for i, c in enumerate(all_cats) if c in color_source_vector.categories] + colors = [to_hex(to_rgba(all_colors[i])[:3]) for i in keep_idx] + else: + colors = [to_hex(to_rgba(c)[:3]) for c in all_colors] + + categories = color_source_vector.categories.tolist() + ["NaN"] + + if len(categories) > len(colors): + return dict(zip(categories, colors + [na_color.get_hex_with_alpha()], strict=True)) + + return dict(zip(categories, colors, strict=True)) + + return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params) + + +def _has_colors_in_uns( + sdata: sd.SpatialData, + table_name: str | None, + col_to_colorby: str, +) -> bool: + """ + Check if _colors exists in the specified table's .uns slot. + + Parameters + ---------- + sdata + SpatialData object containing tables + table_name + Name of the table to check. If None, uses the first available table. + col_to_colorby + Name of the categorical column (e.g., "celltype") + + Returns + ------- + True if _colors exists in the table's .uns, False otherwise + """ + color_key = f"{col_to_colorby}_colors" + + # Determine which table to use + if table_name is not None: + if table_name not in sdata.tables: + return False + table_to_use = table_name + else: + if len(sdata.tables.keys()) == 0: + return False + # When no table is specified, check all tables for the color key + return any(color_key in adata.uns for adata in sdata.tables.values()) + + adata = sdata.tables[table_to_use] + return color_key in adata.uns + + +def _extract_colors_from_table_uns( + sdata: sd.SpatialData, + table_name: str | None, + col_to_colorby: str, + color_source_vector: ArrayLike | pd.Series[CategoricalDtype], + na_color: ColorLike, +) -> Mapping[str, str] | None: + """ + Extract categorical colors from the _colors pattern in adata.uns. 
+ + This function looks for colors stored in the format _colors in the + specified table's .uns slot and creates a mapping from categories to colors. + + Parameters + ---------- + sdata + SpatialData object containing tables + table_name + Name of the table to look in. If None, uses the first available table. + col_to_colorby + Name of the categorical column (e.g., "celltype") + color_source_vector + Categorical vector containing the categories to map + na_color + Color to use for NaN/missing values + + Returns + ------- + Mapping from category names to hex colors, or None if colors not found + """ + color_key = f"{col_to_colorby}_colors" + + # Determine which table to use + if table_name is not None: + if table_name not in sdata.tables: + logger.warning(f"Table '{table_name}' not found in sdata. Available tables: {list(sdata.tables.keys())}") + return None + table_to_use = table_name + else: + if len(sdata.tables) == 0: + logger.warning("No tables found in sdata.") + return None + # No explicit table provided: search all tables for the color key + candidate_tables: list[str] = [ + name + for name, ad in sdata.tables.items() + if color_key in ad.uns # type: ignore[union-attr] + ] + if not candidate_tables: + logger.debug(f"Color key '{color_key}' not found in any table uns.") + return None + table_to_use = candidate_tables[0] + if len(candidate_tables) > 1: + logger.warning( + f"Color key '{color_key}' found in multiple tables {candidate_tables}; using table '{table_to_use}'." + ) + logger.info(f"No table name provided, using '{table_to_use}' for color extraction.") + + adata = sdata.tables[table_to_use] + + # Check if the color pattern exists + if color_key not in adata.uns: + logger.debug(f"Color key '{color_key}' not found in table '{table_to_use}' uns.") + return None + + # Extract colors and categories + stored_colors = adata.uns[color_key] + # Drop categories not present in the current subset (e.g. 
when plotting + # per-coordinate-system) so that positional color lookups stay aligned. + color_source_vector = color_source_vector.remove_unused_categories() + categories = color_source_vector.categories.tolist() + + # Validate na_color format and convert to hex string + if isinstance(na_color, Color): + na_color_hex = na_color.get_hex() + else: + na_color_str = str(na_color) + if "#" not in na_color_str: + logger.warning("Expected `na_color` to be a hex color, converting...") + na_color_hex = to_hex(to_rgba(na_color)[:3]) + else: + na_color_hex = na_color_str + + # Strip alpha channel from na_color if present + if len(na_color_hex) == 9: # #rrggbbaa format + na_color_hex = na_color_hex[:7] # Keep only #rrggbb + + def _to_hex_no_alpha(color_value: Any) -> str | None: + try: + rgba = to_rgba(color_value)[:3] + hex_color: str = to_hex(rgba) + if len(hex_color) == 9: + hex_color = hex_color[:7] + return hex_color + except (TypeError, ValueError) as e: + logger.warning(f"Error converting color '{color_value}' to hex format: {e}") + return None + + color_mapping: dict[str, str] = {} + + if isinstance(stored_colors, Mapping): + for category in categories: + raw_color = stored_colors.get(category) + if raw_color is None: + logger.warning(f"No color specified for '{category}' in '{color_key}', using na_color.") + color_mapping[category] = na_color_hex + continue + hex_color = _to_hex_no_alpha(raw_color) + color_mapping[category] = hex_color if hex_color is not None else na_color_hex + logger.info(f"Successfully extracted {len(color_mapping)} colors from '{color_key}' in table '{table_to_use}'.") + else: + try: + hex_colors = [_to_hex_no_alpha(color) for color in stored_colors] + except TypeError: + logger.warning(f"Unsupported color storage for '{color_key}'. Expected sequence or mapping.") + return None + + # Map by the category's position in the *full* table, not in the + # (possibly subset) color_source_vector, so colors stay consistent + # across coordinate systems. 
+ all_cats = ( + adata.obs[col_to_colorby].cat.categories.tolist() + if col_to_colorby in adata.obs and hasattr(adata.obs[col_to_colorby], "cat") + else categories + ) + # Map category -> index once (O(K)) instead of a per-category list scan + # (was O(K^2) via list.index). all_cats comes from pandas .categories, + # which is unique, so a plain dict comprehension is sufficient. + cat_to_idx: dict[Any, int] = {c: i for i, c in enumerate(all_cats)} + for category in categories: + idx = cat_to_idx.get(category) + if idx is not None and idx < len(hex_colors) and hex_colors[idx] is not None: + hex_color = hex_colors[idx] + assert hex_color is not None # type narrowing for mypy + color_mapping[category] = hex_color + else: + logger.warning(f"Not enough colors provided for category '{category}', using na_color.") + color_mapping[category] = na_color_hex + logger.info(f"Successfully extracted {len(hex_colors)} colors from '{color_key}' in table '{table_to_use}'.") + + color_mapping["NaN"] = na_color_hex + return color_mapping + + +def _modify_categorical_color_mapping( + mapping: Mapping[str, str], + groups: list[str] | str | None = None, + palette: dict[str, str] | list[str] | str | None = None, +) -> Mapping[str, str]: + if groups is None or isinstance(groups, list) and groups[0] is None: + return mapping + + if palette is None or isinstance(palette, list) and palette[0] is None: + # subset base mapping to only those specified in groups + modified_mapping = {key: mapping[key] for key in mapping if key in groups or key == "NaN"} + elif len(palette) == len(groups) and isinstance(groups, list) and isinstance(palette, list): + modified_mapping = dict(zip(groups, palette, strict=True)) + else: + raise ValueError(f"Expected palette to be of length `{len(groups)}`, found `{len(palette)}`.") + + return modified_mapping + + +def _get_default_categorial_color_mapping( + color_source_vector: ArrayLike | pd.Series[CategoricalDtype], + cmap_params: CmapParams | None = None, +) -> 
Mapping[str, str]: + len_cat = len(color_source_vector.categories.unique()) + # Try to use provided colormap first + if cmap_params is not None and cmap_params.cmap is not None and not cmap_params.cmap_is_default: + # Generate evenly spaced indices for the colormap + color_idx = np.linspace(0, 1, len_cat) + if isinstance(cmap_params.cmap, ListedColormap): + palette = [to_hex(x) for x in cmap_params.cmap(color_idx)] + elif isinstance(cmap_params.cmap, LinearSegmentedColormap): + palette = [to_hex(cmap_params.cmap(x)) for x in color_idx] + else: + # Fall back to default palettes if cmap is not of expected type + palette = None + else: + palette = None + + # Fall back to default palettes if needed + if palette is None: + if len_cat <= 20: + palette = default_20 + elif len_cat <= 28: + palette = default_28 + elif len_cat <= len(default_102): # 103 colors + palette = default_102 + else: + palette = ["grey"] * len_cat + logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") + + return dict(zip(color_source_vector.categories, palette[:len_cat], strict=True)) + + +def _get_categorical_color_mapping( + adata: AnnData | None, + na_color: Color, + cluster_key: str | None = None, + color_source_vector: ArrayLike | pd.Series[CategoricalDtype] | None = None, + cmap_params: CmapParams | None = None, + alpha: float = 1, + groups: list[str] | str | None = None, + palette: dict[str, str] | list[str] | str | None = None, + render_type: Literal["points", "labels"] | None = None, +) -> Mapping[str, str]: + if not isinstance(color_source_vector, Categorical): + raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}") + + # Dict palette (e.g. 
from make_palette_from_data): use directly as category→color mapping + if isinstance(palette, dict): + na_color_hex = na_color.get_hex_with_alpha() if isinstance(na_color, Color) else str(na_color) + if isinstance(groups, str): + groups = [groups] + if groups is not None: + mapping = {cat: palette.get(cat, na_color_hex) for cat in groups if cat in color_source_vector.categories} + else: + mapping = {cat: palette.get(cat, na_color_hex) for cat in color_source_vector.categories} + mapping["NaN"] = na_color_hex + return mapping + + if isinstance(groups, str): + groups = [groups] + + if not palette and render_type == "points" and cmap_params is not None and not cmap_params.cmap_is_default: + palette = cmap_params.cmap + + color_idx = color_idx = np.linspace(0, 1, len(color_source_vector.categories)) + if isinstance(palette, ListedColormap): + palette = [to_hex(x) for x in palette(color_idx, alpha=alpha)] + elif isinstance(palette, LinearSegmentedColormap): + palette = [to_hex(palette(x, alpha=alpha)) for x in color_idx] # type: ignore[attr-defined] + return dict(zip(color_source_vector.categories, palette, strict=True)) + + if isinstance(palette, str): + palette = [palette] + + if cluster_key is None: + # user didn't specify a column to use for coloring + base_mapping = _get_default_categorial_color_mapping( + color_source_vector=color_source_vector, cmap_params=cmap_params + ) + else: + base_mapping = _generate_base_categorial_color_mapping( + adata=adata, + cluster_key=cluster_key, + color_source_vector=color_source_vector, + na_color=na_color, + cmap_params=cmap_params, + ) + + return _modify_categorical_color_mapping(mapping=base_mapping, groups=groups, palette=palette) + + +def _maybe_set_colors( + source: AnnData, + target: AnnData, + key: str, + palette: str | ListedColormap | Cycler | Sequence[Any] | None = None, +) -> None: + color_key = f"{key}_colors" + try: + if palette is not None: + raise KeyError("Unable to copy the palette when there was other 
def _get_linear_colormap(colors: list[str], background: str) -> list[LinearSegmentedColormap]:
    """Build one 256-step colormap per color, each running from ``background`` to that color.

    Each colormap is named after its target color.
    """
    colormaps: list[LinearSegmentedColormap] = []
    for target in colors:
        ramp = LinearSegmentedColormap.from_list(target, [background, target], N=256)
        colormaps.append(ramp)
    return colormaps
spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams, _DsReduction from spatialdata_plot.pl.utils import ( - _DS_REDUCTION_FUNCS, - _ax_show_and_transform, - _convert_alpha_to_datashader_range, - _create_image_from_datashader_result, - _datashader_aggregate_with_function, - _datashader_map_aggregate_to_color, - _datshader_get_how_kw_for_spread, - _hex_no_alpha, - _make_continuous_mappable, + _fast_extent, + to_hex, ) # --------------------------------------------------------------------------- @@ -244,7 +260,7 @@ def _ds_shade_continuous( if spread_px is not None: # markers overlay (don't accumulate): spread with "max" so overlapping dots keep the true # value range instead of summing and inflating the colorbar (see as_points). - spread_how = "max" if uniform_alpha else _datshader_get_how_kw_for_spread(ds_reduction) + spread_how = "max" if uniform_alpha else _datashader_get_how_kw_for_spread(ds_reduction) agg = ds.tf.spread(agg, px=spread_px, how=spread_how) reduction_bounds = (agg.min(), agg.max()) @@ -601,3 +617,419 @@ def _build_ds_colorbar( vmin = reduction_bounds[0].values if norm.vmin is None else norm.vmin vmax = reduction_bounds[1].values if norm.vmax is None else norm.vmax return _make_continuous_mappable(vmin, vmax, cmap) + + +# --------------------------------------------------------------------------- +# Datashader reduction constants +# --------------------------------------------------------------------------- + + +_DS_REDUCTION_FUNCS: dict[str, Any] = { + "sum": ds.sum, + "mean": ds.mean, + "any": ds.any, + "count": ds.count, + "std": ds.std, + "var": ds.var, + "max": ds.max, + "min": ds.min, +} + + +# --------------------------------------------------------------------------- +# Color / alpha helpers +# --------------------------------------------------------------------------- + + +def _hex_no_alpha(hex: str) -> str: + """ + Return a hex color string without an alpha component. 
+ + Parameters + ---------- + hex : str + The input hex color string. Must be in one of the following formats: + - "#RRGGBB": a hex color without an alpha channel. + - "#RRGGBBAA": a hex color with an alpha channel that will be removed. + + Returns + ------- + str + The hex color string in "#RRGGBB" format. + """ + if not isinstance(hex, str): + raise TypeError("Input must be a string") + if not hex.startswith("#"): + raise ValueError("Invalid hex color: must start with '#'") + + hex_digits = hex[1:] + length = len(hex_digits) + + if length == 6: + if not all(c in "0123456789abcdefABCDEF" for c in hex_digits): + raise ValueError("Invalid hex color: contains non-hex characters") + return hex # Already in #RRGGBB format. + + if length == 8: + if not all(c in "0123456789abcdefABCDEF" for c in hex_digits): + raise ValueError("Invalid hex color: contains non-hex characters") + # Return only the first 6 characters, stripping the alpha. + return "#" + hex_digits[:6] + + raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'") + + +def _convert_alpha_to_datashader_range(alpha: float) -> float: + """Convert alpha from the range [0, 1] to the range [0, 255] used in datashader.""" + # prevent a value of 255, bc that led to fully colored test plots instead of just colored points/shapes + return min([254, alpha * 255]) + + +# --------------------------------------------------------------------------- +# Canvas geometry helpers +# --------------------------------------------------------------------------- + + +def _ax_show_and_transform( + array: MaskedArray[tuple[int, ...], Any] | npt.NDArray[Any], + trans_data: CompositeGenericTransform, + ax: Axes, + alpha: float | None = None, + cmap: ListedColormap | LinearSegmentedColormap | None = None, + zorder: int = 0, + norm: Normalize | None = None, + interpolation: str | None = None, +) -> matplotlib.image.AxesImage: + # ``extent`` uses mpl's pixel-grid convention; world placement happens via + # 
``set_transform(trans_data)`` afterwards. + image_extent = (-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5) + # ``alpha`` is applied only when no cmap is set, so RGBA arrays already + # carrying per-pixel alpha (e.g. datashader output) are not double-attenuated. + imshow_kwargs: dict[str, Any] = {"zorder": zorder, "extent": image_extent, "norm": norm} + if not cmap and alpha is not None: + imshow_kwargs["alpha"] = alpha + else: + imshow_kwargs["cmap"] = cmap + if interpolation is not None: + imshow_kwargs["interpolation"] = interpolation + im = ax.imshow(array, **imshow_kwargs) + im.set_transform(trans_data) + return im + + +def _compute_datashader_canvas_params( + x_ext: list[Any], + y_ext: list[Any], + fig_params: FigParams, +) -> tuple[Any, Any, list[Any], list[Any], Any]: + """Compute datashader canvas dimensions from spatial extents. + + Shared logic used by both the dask-based and pandas-based entry points. + """ + # Compute canvas size in pixels, capped at the figure's display resolution. + # Using np.max ensures the canvas never exceeds display pixels on either axis, + # preventing pixel-based operations (spread, line_width) from being downscaled + # to sub-pixel size when the data aspect ratio differs from the figure's. 
+ plot_width = x_ext[1] - x_ext[0] + plot_height = y_ext[1] - y_ext[0] + plot_width_px = int(round(fig_params.fig.get_size_inches()[0] * fig_params.fig.dpi)) + plot_height_px = int(round(fig_params.fig.get_size_inches()[1] * fig_params.fig.dpi)) + factor: float + factor = np.max([plot_width / plot_width_px, plot_height / plot_height_px]) + plot_width = int(np.round(plot_width / factor)) + plot_height = int(np.round(plot_height / factor)) + + return plot_width, plot_height, x_ext, y_ext, factor + + +def _get_extent_and_range_for_datashader_canvas( + spatial_element: SpatialElement, + coordinate_system: str, + fig_params: FigParams, +) -> tuple[Any, Any, list[Any], list[Any], Any]: + extent = _fast_extent(spatial_element, coordinate_system) + x_ext = [float(extent["x"][0]), float(extent["x"][1])] + y_ext = [float(extent["y"][0]), float(extent["y"][1])] + return _compute_datashader_canvas_params(x_ext, y_ext, fig_params) + + +def _datashader_canvas_from_dataframe( + df: pd.DataFrame, + fig_params: FigParams, +) -> tuple[Any, Any, list[Any], list[Any], Any]: + """Compute datashader canvas params directly from a pandas DataFrame. + + Avoids the overhead of ``get_extent()`` (which requires a dask-backed + SpatialElement) by reading min/max from the already-materialised data. + """ + if len(df) == 0: + # Empty input (e.g., a bounding_box_query with no overlap) — caller + # should short-circuit; return zero-sized canvas params as a sentinel. 
def _prepare_transformation(
    element: DataArray | GeoDataFrame | dask.dataframe.core.DataFrame,
    coordinate_system: str,
    ax: Axes | None = None,
) -> tuple[
    matplotlib.transforms.Affine2D,
    matplotlib.transforms.CompositeGenericTransform | None,
]:
    """Turn an element's transformation to ``coordinate_system`` into matplotlib transforms.

    Returns
    -------
    The x/y affine transform of the element, and that transform composed with
    ``ax.transData`` (``None`` when no axes is given).
    """
    xy_axes = ("x", "y")
    element_transform = get_transformation(element, get_all=True)[coordinate_system]
    matrix = element_transform.to_affine_matrix(input_axes=xy_axes, output_axes=xy_axes)
    affine = mtransforms.Affine2D(matrix=matrix)

    if ax is None:
        return affine, None
    return affine, affine + ax.transData
+ transformation = TransformSequence([transformation, Translation([x_min, y_min], ("x", "y"))]) + rgba_image = Image2DModel.parse( + rgba_image_data, + dims=("c", "y", "x"), + transformations={"global": transformation}, + ) + + _, trans_data = _prepare_transformation(rgba_image, "global", ax) + + rgba_image = np.transpose(rgba_image.data.compute(), (1, 2, 0)) # type: ignore[attr-defined] + rgba_image = ma.masked_array(rgba_image) # type conversion for mypy + + return rgba_image, trans_data + + +# --------------------------------------------------------------------------- +# Aggregation / shading helpers +# --------------------------------------------------------------------------- + + +def _datashader_aggregate_with_function( + reduction: _DsReduction | None, + cvs: Canvas, + spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame, + col_for_color: str | None, + element_type: Literal["points", "shapes"], +) -> DataArray: + """ + When shapes or points are colored by a continuous value during rendering with datashader. + + This function performs the aggregation using the user-specified reduction method. + + Parameters + ---------- + reduction: String specifying the datashader reduction method to be used. + If None, "sum" is used as default. + cvs: Canvas object previously created with ds.Canvas() + spatial_element: geo or dask dataframe with the shapes or points to render + col_for_color: name of the column containing the values by which to color + element_type: tells us if this function is called from _render_shapes() or _render_points() + """ + if reduction is None: + reduction = "sum" + + try: + reduction_function = _DS_REDUCTION_FUNCS[reduction](column=col_for_color) + except KeyError as e: + raise ValueError( + f"Reduction '{reduction}' is not supported. Please use one of: {', '.join(_DS_REDUCTION_FUNCS.keys())}." 
+ ) from e + + element_function_map = { + "points": cvs.points, + "shapes": cvs.polygons, + } + + try: + element_function = element_function_map[element_type] + except KeyError as e: + raise ValueError(f"Element type '{element_type}' is not supported. Use 'points' or 'shapes'.") from e + + if element_type == "points": + points_aggregate = element_function(spatial_element, "x", "y", agg=reduction_function) + if reduction == "any": + # replace False/True by nan/1 + points_aggregate = points_aggregate.astype(int) + points_aggregate = points_aggregate.where(points_aggregate > 0) + return points_aggregate + + # is shapes + return element_function(spatial_element, geometry="geometry", agg=reduction_function) + + +def _datashader_get_how_kw_for_spread( + reduction: _DsReduction | None, +) -> str: + # Get the best input for the how argument of ds.tf.spread(), needed for numerical values + reduction = reduction or "sum" + + reduction_to_how_map = { + "sum": "add", + "mean": "source", + "any": "source", + "count": "add", + "std": "source", + "var": "source", + "max": "max", + "min": "min", + } + + if reduction not in reduction_to_how_map: + raise ValueError( + f"Reduction {reduction} is not supported, please use one of the following: sum, mean, any, count" + ", std, var, max, min." + ) + + return reduction_to_how_map[reduction] + + +def _apply_cmap_alpha_to_datashader_result( + result: Any, + agg: DataArray, + cmap: str | list[str] | Colormap, + span: list[float] | tuple[float, float] | None, +) -> Any: + """Apply the colormap's alpha channel to a datashader RGBA result. + + Datashader ignores the per-entry alpha channel of matplotlib colormaps, + so pixels that the cmap marks as transparent (alpha=0) are rendered + opaque. This function post-processes the shaded RGBA output to restore + the cmap's intended transparency. See :issue:`376`. + """ + if not isinstance(cmap, Colormap): + return result + + # Quick check: does this cmap have any transparent entries? 
+ test_vals = np.linspace(0, 1, min(cmap.N, 256)) + cmap_alphas = cmap(test_vals)[:, 3] + if np.all(cmap_alphas >= 1.0): + return result + + # Get or ensure we have an (H, W, 4) uint8 array + if hasattr(result, "values"): + # datashader Image — uint32 packed, convert via to_numpy() + rgba = result.to_numpy().base + if rgba is None: + return result + else: + rgba = result + + if rgba.ndim != 3 or rgba.shape[2] != 4: + return result + + # Normalise aggregate values to [0, 1] using the same span datashader used + agg_vals = agg.values.astype(np.float64) + valid = np.isfinite(agg_vals) + if not valid.any(): + return result + + if span is not None: + lo, hi = float(span[0]), float(span[1]) + else: + lo = float(np.nanmin(agg_vals)) + hi = float(np.nanmax(agg_vals)) + + if hi <= lo or not np.isfinite(lo) or not np.isfinite(hi): + return result + + normed = np.clip((agg_vals - lo) / (hi - lo), 0.0, 1.0) + + # Look up cmap alpha for each pixel + desired_alpha = cmap(normed)[:, :, 3] + + # Zero out pixels where the cmap wants transparency + transparent = valid & (desired_alpha < 1.0) + if transparent.any(): + # Scale the existing alpha by the cmap's alpha + rgba[transparent, 3] = (rgba[transparent, 3].astype(np.float32) * desired_alpha[transparent]).astype(np.uint8) + + return result + + +def _datashader_map_aggregate_to_color( + agg: DataArray, + cmap: str | list[str] | ListedColormap, + color_key: list[str] | dict[str, str] | None = None, + min_alpha: float = 40, + span: None | list[float] = None, + clip: bool = True, + how: str = "linear", +) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]: + """ds.tf.shade() part, ensuring correct clipping behavior. + + If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results. + This ensures the correct clipping behavior, because else datashader would always automatically clip. 
+ + ``how`` controls the count-to-color mapping passed to :func:`datashader.transfer_functions.shade` + (``"linear"`` by default; ``"log"``/``"cbrt"``/``"eq_hist"`` compress dynamic range). The split-shade + branch used for ``norm.clip=False`` always uses ``"linear"`` since per-segment shading would otherwise + interact poorly with rank-based mappings. + """ + if not clip and isinstance(cmap, Colormap) and span is not None: + # in case we use datashader together with a Normalize object where clip=False + # why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372 + agg_in = agg.where((agg >= span[0]) & (agg <= span[1])) + img_in = ds.tf.shade( + agg_in, + cmap=cmap, + span=(span[0], span[1]), + how="linear", + color_key=color_key, + min_alpha=min_alpha, + ) + + agg_under = agg.where(agg < span[0]) + img_under = ds.tf.shade( + agg_under, + cmap=[to_hex(cmap.get_under())[:7]], + min_alpha=min_alpha, + color_key=color_key, + ) + + agg_over = agg.where(agg > span[1]) + img_over = ds.tf.shade( + agg_over, + cmap=[to_hex(cmap.get_over())[:7]], + min_alpha=min_alpha, + color_key=color_key, + ) + + # stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0 + stack = img_under.to_numpy().base + if stack is None: + stack = img_in.to_numpy().base + else: + stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0] + img_over = img_over.to_numpy().base + if img_over is not None: + stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0] + + return _apply_cmap_alpha_to_datashader_result(stack, agg, cmap, span) + + result = ds.tf.shade( + agg, + cmap=cmap, + color_key=color_key, + min_alpha=min_alpha, + span=span, + how=how, + ) + return _apply_cmap_alpha_to_datashader_result(result, agg, cmap, span) diff --git a/src/spatialdata_plot/pl/_geometry.py b/src/spatialdata_plot/pl/_geometry.py new file mode 100644 index 00000000..70b2dcf3 --- /dev/null +++ 
b/src/spatialdata_plot/pl/_geometry.py @@ -0,0 +1,519 @@ +"""Shape geometry and matplotlib patch construction (extracted from utils.py, see #696).""" + +from __future__ import annotations + +import math +from typing import Any + +import matplotlib.patches as mpatches +import matplotlib.path as mpath +import numpy as np +import pandas as pd +import shapely +from geopandas import GeoDataFrame +from matplotlib import colors +from matplotlib.collections import PatchCollection +from matplotlib.colors import ColorConverter, Normalize +from scipy.spatial import ConvexHull +from shapely.errors import GEOSException + +from spatialdata_plot._logging import logger +from spatialdata_plot.pl.render_params import ShapesRenderParams +from spatialdata_plot.pl.utils import _extract_scalar_value + + +def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, float]: + # Extract the vertices from the PathPatch + path = pathpatch.get_path() + vertices = path.vertices + x = vertices[:, 0] + y = vertices[:, 1] + + area = 0.5 * np.sum(x[:-1] * y[1:] - x[1:] * y[:-1]) + + # Calculate the centroid coordinates + centroid_x = np.sum((x[:-1] + x[1:]) * (x[:-1] * y[1:] - x[1:] * y[:-1])) / (6 * area) + centroid_y = np.sum((y[:-1] + y[1:]) * (x[:-1] * y[1:] - x[1:] * y[:-1])) / (6 * area) + + return centroid_x, centroid_y + + +def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor: float) -> None: + scale_value = _extract_scalar_value(scale_factor, default=1.0) + centroid = _get_centroid_of_pathpatch(pathpatch) + vertices = pathpatch.get_path().vertices + scaled_vertices = np.array([centroid + (vertex - centroid) * scale_value for vertex in vertices]) + pathpatch.get_path().vertices = scaled_vertices + + +def _normalize_geom(geom: Any) -> Any: + """Canonicalize ring orientation so matplotlib's fill rules render holes correctly. + + ``shapely.normalize`` (shapely>=2) is preferred; falls back to ``geom.normalize()``. 
+ None/empty geometries and geometries that fail to normalize are returned unchanged. + """ + if geom is None or getattr(geom, "is_empty", False): + return geom + normalize_func = getattr(shapely, "normalize", None) + if callable(normalize_func): + try: + return normalize_func(geom) + except (GEOSException, TypeError, ValueError): + return geom + if hasattr(geom, "normalize"): + try: + return geom.normalize() + except (GEOSException, TypeError, ValueError): + return geom + return geom + + +def _make_patch_from_multipolygon(mp: shapely.MultiPolygon) -> list[mpatches.PathPatch]: + """ + Create PathPatches from a MultiPolygon, preserving holes robustly. + + This follows the same strategy as GeoPandas' internal Polygon plotting: + each (multi)polygon part becomes a compound Path composed of the exterior + ring and all interior rings. Orientation is handled by prior geometry + normalization rather than manual ring reversal. + """ + patches: list[mpatches.PathPatch] = [] + + for poly in mp.geoms: + if poly.is_empty: + continue + + # Ensure 2D vertices in case geometries carry Z + exterior = np.asarray(poly.exterior.coords)[..., :2] + interiors = [np.asarray(ring.coords)[..., :2] for ring in poly.interiors] + + if len(interiors) == 0: + # Simple polygon without holes + patches.append(mpatches.Polygon(exterior, closed=True)) + continue + + # Build a compound path: exterior + all interior rings + compound_path = mpath.Path.make_compound_path( + mpath.Path(exterior, closed=True), + *[mpath.Path(ring, closed=True) for ring in interiors], + ) + patches.append(mpatches.PathPatch(compound_path)) + + return patches + + +def _build_shape_patches( + shapes: GeoDataFrame, + scale: float, +) -> tuple[list[mpatches.Patch], list[int], int]: + """Build matplotlib patches from shape geometries, once. 
+ + Patch geometry is independent of colour/alpha, so it can be built a single time and + shared across the fill and outline ``PatchCollection``s in :func:`_render_shapes` + instead of being rebuilt per layer (the dominant cost for shape elements). + + Returns + ------- + patches + The matplotlib patches (a MultiPolygon expands to several patches). + patch_row_idx + For each patch, the index into the empty-filtered, re-indexed shapes — used to + look up the per-shape colour. + n_shapes + Number of shapes after empty filtering (used for the single-colour broadcast rule). + """ + df: GeoDataFrame | pd.DataFrame = shapes if isinstance(shapes, GeoDataFrame) else pd.DataFrame(shapes) + if "geometry" not in df.columns: + return [], [], 0 + + # Normalize ring orientation, then drop empty geometries (both vectorized; fall + # back to per-geometry normalization only if the bulk call rejects an input). + geom_array = df["geometry"].to_numpy() + try: + geom_array = shapely.normalize(geom_array) + except (GEOSException, TypeError, ValueError): + geom_array = np.array([_normalize_geom(g) for g in geom_array], dtype=object) + keep = ~shapely.is_empty(geom_array) + geoms = geom_array[keep] + radii = df["radius"].to_numpy()[keep] if "radius" in df.columns else None + + # Resolve the scale scalar once instead of per shape. 
def _get_collection_shape(
    shapes: list[GeoDataFrame],
    c: Any,
    s: float,
    norm: Any,
    render_params: ShapesRenderParams,
    fill_alpha: None | float = None,
    outline_alpha: None | float = None,
    outline_color: None | str | list[float] | np.ndarray = "white",
    linewidth: float = 0.0,
    prebuilt_patches: tuple[list[mpatches.Patch], list[int], int] | None = None,
    **kwargs: Any,
) -> PatchCollection:
    """
    Build a PatchCollection for shapes, resolving fill and outline colours.

    The fill colour input ``c`` is interpreted, in this order, as:

    - per-row numeric colours (``(n, 3)`` or ``(n, 4)`` array),
    - a continuous numeric vector of length ``n`` (possibly with NaNs),
    - an object vector of length ``n`` mixing numbers and colour specs,
    - a single colour or a list of colour specs,

    where ``n`` is ``len(shapes)``.

    Only NaNs are painted with na_color; finite values are mapped via norm+cmap.

    Parameters
    ----------
    shapes
        The shapes to draw. Only ``len(shapes)`` is used here; the object itself
        is forwarded to ``_build_shape_patches`` when no prebuilt patches are given.
    c
        Fill colour specification (see the cases above).
    s
        Scale forwarded to ``_build_shape_patches`` when patches are built here.
    norm
        Used as-is when it is a ``Normalize`` instance. Otherwise a linear norm
        is built from the finite min/max of the data, falling back to ``[0, 1]``
        when there is no finite value or min equals max.
    render_params
        Only ``render_params.cmap_params.na_color`` is read.
    fill_alpha
        If not ``None``, replaces the alpha of every fill colour whose alpha is
        currently greater than zero.
    outline_alpha
        Alpha written into the outline colours. A falsy or non-positive value
        disables outlines (``edgecolor=None``).
    outline_color
        A single colour spec or an ``(len(shapes), 4)`` RGBA array.
    linewidth
        Passed to ``PatchCollection`` as ``lw``.
    prebuilt_patches
        Optional ``(patches, patch_row_idx, n_shapes)`` triple, as returned by
        ``_build_shape_patches``, to reuse instead of rebuilding the geometry.
    **kwargs
        Must contain ``"cmap"``. All entries (including ``cmap``) are forwarded
        to ``PatchCollection``.

    Returns
    -------
    PatchCollection
        The collection of patches; an empty collection when no patch was built.

    Raises
    ------
    KeyError
        If ``kwargs`` has no ``"cmap"`` entry.

    .. note::
        When ``outline_color`` is passed as an ``(N, 4)`` RGBA array of dtype ``float``,
        its alpha channel is mutated in place to apply ``outline_alpha``. Pass a copy
        if you need to retain the original buffer.
    """
    cmap = kwargs["cmap"]

    # Resolve na color once
    na_rgba = colors.to_rgba(render_params.cmap_params.na_color.get_hex_with_alpha())

    # Try to interpret c as numpy array
    c_arr = np.asarray(c)
    fill_c: np.ndarray

    def _as_rgba_array(x: Any) -> np.ndarray:
        # Let matplotlib parse any accepted colour spec into an (m, 4) RGBA array.
        return np.asarray(ColorConverter().to_rgba_array(x))

    # Case A: per-row numeric colors given as Nx3 or Nx4 float array
    if (
        c_arr.ndim == 2
        and c_arr.shape[0] == len(shapes)
        and c_arr.shape[1] in (3, 4)
        and np.issubdtype(c_arr.dtype, np.number)
    ):
        fill_c = _as_rgba_array(c_arr)

    # Case B: continuous numeric vector len == n_shapes (possibly with NaNs)
    elif c_arr.ndim == 1 and len(c_arr) == len(shapes) and np.issubdtype(c_arr.dtype, np.number):
        finite_mask = np.isfinite(c_arr)

        # Select or build a normalization that ignores NaNs for scaling
        if isinstance(norm, Normalize):
            used_norm: Normalize = norm
        else:
            if finite_mask.any():
                vmin = float(np.nanmin(c_arr[finite_mask]))
                vmax = float(np.nanmax(c_arr[finite_mask]))
                # A constant (or non-finite) range cannot be normalized; use [0, 1].
                if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
                    vmin, vmax = 0.0, 1.0
            else:
                vmin, vmax = 0.0, 1.0
            used_norm = colors.Normalize(vmin=vmin, vmax=vmax, clip=False)

        # Map finite values through cmap(norm(.)); NaNs get na_color
        fill_c = np.empty((len(c_arr), 4), dtype=float)
        fill_c[:] = na_rgba
        if finite_mask.any():
            fill_c[finite_mask] = cmap(used_norm(c_arr[finite_mask]))

    # Object vector of length n_shapes: numbers and explicit colour specs mixed
    elif c_arr.ndim == 1 and len(c_arr) == len(shapes) and c_arr.dtype == object:
        # Split into numeric vs color-like
        c_series = pd.Series(c_arr, copy=False)
        # Entries that cannot be parsed as numbers become NaN and drop out of is_num.
        num = pd.to_numeric(c_series, errors="coerce").to_numpy()
        is_num = np.isfinite(num)

        # init with na color
        fill_c = np.empty((len(c_series), 4), dtype=float)
        fill_c[:] = na_rgba

        # numeric entries via cmap(norm)
        if is_num.any():
            if isinstance(norm, Normalize):
                used_norm = norm
            else:
                vmin = float(np.nanmin(num[is_num])) if is_num.any() else 0.0
                vmax = float(np.nanmax(num[is_num])) if is_num.any() else 1.0
                if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax:
                    vmin, vmax = 0.0, 1.0
                used_norm = colors.Normalize(vmin=vmin, vmax=vmax, clip=False)
            fill_c[is_num] = cmap(used_norm(num[is_num]))

        # non-numeric, non-NaN entries as explicit colors
        non_numeric_color_mask = (~is_num) & c_series.notna().to_numpy()
        if non_numeric_color_mask.any():
            fill_c[non_numeric_color_mask] = ColorConverter().to_rgba_array(c_series[non_numeric_color_mask].tolist())

    # Case C: single color or list of color-like specs (strings or tuples)
    else:
        fill_c = _as_rgba_array(c)

    # Apply optional fill alpha without destroying existing transparency
    if fill_alpha is not None:
        # Fully transparent entries (alpha == 0) are deliberately left untouched.
        nonzero_alpha = fill_c[..., -1] > 0
        fill_c[nonzero_alpha, -1] = fill_alpha

    # Outline handling
    if outline_alpha and outline_alpha > 0.0:
        outline_arr = np.asarray(outline_color) if not isinstance(outline_color, str) else None
        if outline_arr is not None and outline_arr.ndim == 2 and outline_arr.shape == (len(shapes), 4):
            # Per-shape RGBA array. Mutate in place when already float so we don't allocate twice
            # on the hot path; otherwise upcast to a fresh float buffer.
            outline_c_array = outline_arr if outline_arr.dtype == float else outline_arr.astype(float)
        else:
            outline_c_array = _as_rgba_array(outline_color)
        outline_c_array[..., -1] = outline_alpha
        outline_c = outline_c_array.tolist()
    else:
        # Placeholder list; an all-None list is turned into edgecolor=None below.
        outline_c = [None] * fill_c.shape[0]

    # Build (or reuse) the matplotlib patches. Geometry is colour-independent, so the
    # caller can build it once via `_build_shape_patches` and share it across the fill
    # and outline collections instead of rebuilding it on every call.
    patches, patch_row_idx, n_shapes = (
        prebuilt_patches if prebuilt_patches is not None else _build_shape_patches(shapes, s)
    )

    if not patches:
        return PatchCollection([])

    # Expand the per-shape fill colours to per-patch (a MultiPolygon owns several
    # patches). Preserve the single-colour broadcast used for multi-shape elements.
    broadcast_single = n_shapes > 1 and len(fill_c) == 1
    patch_fill = np.repeat(fill_c, len(patches), axis=0) if broadcast_single else fill_c[patch_row_idx]

    # NOTE(review): unlike the fill colours, ``outline_c`` is passed per input row and is
    # not expanded through ``patch_row_idx``. With a per-shape outline array and shapes
    # that expand to several patches, edge colours may not line up -- confirm intended.
    return PatchCollection(
        patches,
        snap=False,
        lw=linewidth,
        facecolor=patch_fill,
        edgecolor=None if all(o is None for o in outline_c) else outline_c,
        **kwargs,
    )
def _convert_shapes(
    shapes: GeoDataFrame,
    target_shape: str,
    max_extent: float,
    warn_above_extent_fraction: float = 0.5,
) -> GeoDataFrame:
    """Convert shapes in a GeoDataFrame to the target_shape, using positional indexing.

    Parameters
    ----------
    shapes
        GeoDataFrame whose geometries are Points, Polygons or MultiPolygons. It is
        not modified; a copy with a fresh ``0..n-1`` index is converted and returned.
    target_shape
        ``"circle"``, ``"hex"`` or ``"visium_hex"``. Any other value selects the
        square conversion.
    max_extent
        Reference extent used only to decide whether the size notice is logged.
    warn_above_extent_fraction
        A notice is logged when a circle fitted around a (multi)polygon has a
        diameter larger than this fraction of ``max_extent``. Values outside
        ``[0, 1]`` are silently replaced by ``0.5``.

    Returns
    -------
    GeoDataFrame
        The converted copy. A ``radius`` column is added (filled with NaN) when
        the input has none; it is updated only for rows converted to circles.

    Raises
    ------
    ValueError
        If a geometry is neither a Point, a Polygon nor a MultiPolygon.
    """
    if warn_above_extent_fraction < 0.0 or warn_above_extent_fraction > 1.0:
        warn_above_extent_fraction = 0.5
    # Set by the *_to_circle helpers below (via ``nonlocal``) when a fitted circle is large.
    warn_shape_size = False

    # work on a copy with a clean positional index
    shapes = shapes.reset_index(drop=True).copy()

    # Every converter returns ``(geometry, radius)``; ``radius`` is None unless the
    # result is a circle, i.e. a Point plus a radius.

    def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
        # Six vertices at 30, 90, ..., 330 degrees on the circle of the given radius.
        verts = [
            (
                center.x + radius * math.cos(math.radians(a)),
                center.y + radius * math.sin(math.radians(a)),
            )
            for a in range(30, 390, 60)
        ]
        return shapely.Polygon(verts), None

    def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
        # Four vertices at 45, 135, 225 and 315 degrees: an axis-aligned square
        # inscribed in the circle.
        verts = [
            (
                center.x + radius * math.cos(math.radians(a)),
                center.y + radius * math.sin(math.radians(a)),
            )
            for a in range(45, 360, 90)
        ]
        return shapely.Polygon(verts), None

    def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]:
        return center, radius

    def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]:
        # Centre = mean of the convex-hull vertices of the exterior ring; radius = distance
        # to the farthest hull vertex (an enclosing circle, not necessarily the smallest one).
        # NOTE(review): ConvexHull presumably raises on degenerate (e.g. collinear) rings;
        # that case is not handled here -- confirm inputs are always proper polygons.
        coords = np.array(polygon.exterior.coords)
        hull_pts = coords[ConvexHull(coords).vertices]
        center = np.mean(hull_pts, axis=0)
        radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1)))
        nonlocal warn_shape_size
        if 2 * radius > max_extent * warn_above_extent_fraction:
            warn_shape_size = True
        return shapely.Point(center), radius

    def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
        c, r = _polygon_to_circle(polygon)
        return _circle_to_hexagon(c, r)

    def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
        c, r = _polygon_to_circle(polygon)
        return _circle_to_square(c, r)

    def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]:
        # Same construction as _polygon_to_circle, over the exterior rings of all parts.
        pts = []
        for poly in multipolygon.geoms:
            pts.extend(poly.exterior.coords)
        pts_array = np.array(pts)
        hull_pts = pts_array[ConvexHull(pts_array).vertices]
        center = np.mean(hull_pts, axis=0)
        radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1)))
        nonlocal warn_shape_size
        if 2 * radius > max_extent * warn_above_extent_fraction:
            warn_shape_size = True
        return shapely.Point(center), radius

    def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
        c, r = _multipolygon_to_circle(multipolygon)
        return _circle_to_hexagon(c, r)

    def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
        c, r = _multipolygon_to_circle(multipolygon)
        return _circle_to_square(c, r)

    # choose conversion methods
    conversion_methods: dict[str, Any]
    if target_shape == "circle":
        conversion_methods = {
            "Point": _circle_to_circle,
            "Polygon": _polygon_to_circle,
            "MultiPolygon": _multipolygon_to_circle,
        }
    elif target_shape == "hex":
        conversion_methods = {
            "Point": _circle_to_hexagon,
            "Polygon": _polygon_to_hexagon,
            "MultiPolygon": _multipolygon_to_hexagon,
        }
    elif target_shape == "visium_hex":
        # estimate hex radius from point spacing when possible
        point_centers = []
        non_point_count = 0
        for geom in shapes.geometry:
            if geom.geom_type == "Point":
                point_centers.append((geom.x, geom.y))
            else:
                non_point_count += 1
        if non_point_count > 0:
            logger.warning("visium_hex supports Points best. Non-Point geometries will use regular hex conversion.")
        if len(point_centers) >= 2:
            centers = np.array(point_centers, dtype=float)
            # pairwise min distance
            # (each point is compared with the points after it, so every pair is visited once)
            dmin = np.inf
            for i in range(len(centers)):
                diffs = centers[i + 1 :] - centers[i]
                if diffs.size:
                    d = np.min(np.linalg.norm(diffs, axis=1))
                    dmin = min(dmin, d)
            if not np.isfinite(dmin) or dmin <= 0:
                # fallback
                # (coincident points give dmin == 0, which would collapse the hexagons)
                conversion_methods = {
                    "Point": _circle_to_hexagon,
                    "Polygon": _polygon_to_hexagon,
                    "MultiPolygon": _multipolygon_to_hexagon,
                }
            else:
                # A regular hexagon with circumradius R has neighbours sqrt(3) * R away,
                # so this radius makes hexagons of the closest pair of points share an edge.
                hex_radius = dmin / math.sqrt(3.0)

                def _circle_to_visium_hex(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
                    # The per-row radius is ignored in favour of the spacing-derived one.
                    return _circle_to_hexagon(center, hex_radius)

                def _polygon_to_visium_hex(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
                    return _polygon_to_hexagon(polygon)

                def _multipolygon_to_visium_hex(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
                    return _multipolygon_to_hexagon(multipolygon)

                conversion_methods = {
                    "Point": _circle_to_visium_hex,
                    "Polygon": _polygon_to_visium_hex,
                    "MultiPolygon": _multipolygon_to_visium_hex,
                }
        else:
            # Fewer than two points: no spacing to estimate from, use plain hexagons.
            conversion_methods = {
                "Point": _circle_to_hexagon,
                "Polygon": _polygon_to_hexagon,
                "MultiPolygon": _multipolygon_to_hexagon,
            }
    else:
        # Any other target_shape value falls through to the square conversion.
        conversion_methods = {
            "Point": _circle_to_square,
            "Polygon": _polygon_to_square,
            "MultiPolygon": _multipolygon_to_square,
        }

    # ensure radius column exists if needed
    if "radius" not in shapes.columns:
        shapes["radius"] = np.nan

    # convert all geometries using positional indexing
    for i in range(len(shapes)):
        geom = shapes.geometry.iloc[i]
        gtype = geom.geom_type
        if gtype == "Point":
            r = shapes["radius"].iloc[i]
            # A missing or non-finite radius is treated as 0.
            r = float(r) if np.isfinite(r) else 0.0
            converted, radius = conversion_methods["Point"](geom, r)  # type: ignore[arg-type]
        elif gtype == "Polygon":
            converted, radius = conversion_methods["Polygon"](geom)  # type: ignore[arg-type]
        elif gtype == "MultiPolygon":
            converted, radius = conversion_methods["MultiPolygon"](geom)  # type: ignore[arg-type]
        else:
            raise ValueError(f"Converting shape {gtype} to {target_shape} is not supported.")
        # Label-based ``.at`` is safe here because the index was reset to 0..n-1 above.
        shapes.at[i, "geometry"] = converted
        if radius is not None:
            shapes.at[i, "radius"] = radius

    if warn_shape_size:
        # The flag is raised on a strict ``>`` comparison although the message says ">=".
        logger.info(
            f"At least one converted shape spans >= {warn_above_extent_fraction * 100:.0f}% of the "
            "original total bound. Results may be suboptimal."
        )

    return shapes
+ + +def _check_obs_var_shadow( + sdata: SpatialData | None, + element_name: str | None, + value_to_plot: str | None, + table_name: str | None, +) -> None: + """Raise if ``value_to_plot`` exists in both ``table.obs.columns`` and ``table.var_names``. + + Upstream ``_get_table_origins`` uses an ``elif`` chain, so a key that lives in + both locations is silently resolved to ``obs`` — masking the user's likely + intent of plotting gene expression. Catch this here before any value fetch. + Any ``None`` parameter short-circuits the check. + """ + if ( + value_to_plot is None + or table_name is None + or element_name is None + or sdata is None + or table_name not in sdata.tables + ): + return + if table_name not in get_element_annotators(sdata, element_name): + return + table = sdata.tables[table_name] + if value_to_plot in table.obs.columns and value_to_plot in table.var_names: + raise ValueError( + f"`color={value_to_plot!r}` is ambiguous: it exists in both " + f"`table[{table_name!r}].obs.columns` and `table[{table_name!r}].var_names`. " + "Rename one of them (or drop the obs column) so the intended source is unambiguous." 
+ ) + + +def _gate_palette_and_groups( + element_params: dict[str, Any], + param_dict: dict[str, Any], +) -> None: + """Set palette/groups on element_params only when col_for_color is present, else warn.""" + has_col = element_params.get("col_for_color") is not None + element_params["palette"] = param_dict["palette"] if has_col else None + if not has_col and param_dict["groups"] is not None: + logger.warning(_GROUPS_IGNORED_WARNING) + element_params["groups"] = param_dict["groups"] if has_col else None + + +def _validate_show_parameters( + coordinate_systems: list[str] | str | None, + legend_fontsize: int | float | _FontSize | None, + legend_fontweight: int | _FontWeight, + legend_loc: str | None, + legend_fontoutline: int | None, + na_in_legend: bool, + colorbar: bool, + colorbar_params: dict[str, object] | None, + wspace: float | None, + hspace: float, + ncols: int, + frameon: bool | None, + figsize: tuple[float, float] | None, + dpi: int | None, + fig: Figure | None, + title: list[str] | str | None, + pad_extent: int | float, + ax: list[Axes] | Axes | None, + return_ax: bool, + save: str | Path | None, + show: bool | None, + scalebar_dx: float | None, + scalebar_units: str, + scalebar_params: dict[str, Any] | None, + legend_params: dict[str, Any] | None, +) -> None: + if coordinate_systems is not None and not isinstance(coordinate_systems, list | str): + raise TypeError("Parameter 'coordinate_systems' must be a string or a list of strings.") + + font_weights = ["light", "normal", "medium", "semibold", "bold", "heavy", "black"] + if legend_fontweight is not None and ( + not isinstance(legend_fontweight, int | str) + or (isinstance(legend_fontweight, str) and legend_fontweight not in font_weights) + ): + readable_font_weights = ", ".join(font_weights[:-1]) + ", or " + font_weights[-1] + raise TypeError( + "Parameter 'legend_fontweight' must be an integer or one of", + f"the following strings: {readable_font_weights}.", + ) + + font_sizes = [ + "xx-small", + 
"x-small", + "small", + "medium", + "large", + "x-large", + "xx-large", + ] + + if legend_fontsize is not None and ( + not isinstance(legend_fontsize, int | float | str) + or (isinstance(legend_fontsize, str) and legend_fontsize not in font_sizes) + ): + readable_font_sizes = ", ".join(font_sizes[:-1]) + ", or " + font_sizes[-1] + raise TypeError( + "Parameter 'legend_fontsize' must be an integer, a float, or ", + f"one of the following strings: {readable_font_sizes}.", + ) + + if legend_loc is not None and not isinstance(legend_loc, str): + raise TypeError("Parameter 'legend_loc' must be a string.") + + if legend_fontoutline is not None and not isinstance(legend_fontoutline, int): + raise TypeError("Parameter 'legend_fontoutline' must be an integer.") + + if not isinstance(na_in_legend, bool): + raise TypeError("Parameter 'na_in_legend' must be a boolean.") + + if not isinstance(colorbar, bool): + raise TypeError("Parameter 'colorbar' must be a boolean.") + + if colorbar_params is not None and not isinstance(colorbar_params, dict): + raise TypeError("Parameter 'colorbar_params' must be a dictionary or None.") + + if wspace is not None and not isinstance(wspace, float): + raise TypeError("Parameter 'wspace' must be a float.") + + if not isinstance(hspace, float): + raise TypeError("Parameter 'hspace' must be a float.") + + if not isinstance(ncols, int): + raise TypeError("Parameter 'ncols' must be an integer.") + + if frameon is not None and not isinstance(frameon, bool): + raise TypeError("Parameter 'frameon' must be a boolean.") + + if figsize is not None and ( + not isinstance(figsize, tuple | list | np.ndarray) + or len(figsize) != 2 + or not all(isinstance(x, int | float) and not isinstance(x, bool) for x in figsize) + ): + raise TypeError("Parameter 'figsize' must be a tuple, list, or numpy array of two numbers.") + + if dpi is not None and not isinstance(dpi, int): + raise TypeError("Parameter 'dpi' must be an integer.") + + if fig is not None and not 
isinstance(fig, Figure): + raise TypeError("Parameter 'fig' must be a matplotlib.figure.Figure.") + + if title is not None and not isinstance(title, list | str): + raise TypeError("Parameter 'title' must be a string or a list of strings.") + + if not isinstance(pad_extent, int | float): + raise TypeError("Parameter 'pad_extent' must be numeric.") + + if ax is not None and not isinstance(ax, Axes | list): + raise TypeError("Parameter 'ax' must be a matplotlib.axes.Axes or a list of Axes.") + + if not isinstance(return_ax, bool): + raise TypeError("Parameter 'return_ax' must be a boolean.") + + if save is not None and not isinstance(save, str | Path): + raise TypeError("Parameter 'save' must be a string or a pathlib.Path.") + + if show is not None and not isinstance(show, bool): + raise TypeError("Parameter 'show' must be a boolean or None.") + + if scalebar_dx is not None: + if not isinstance(scalebar_dx, int | float) or isinstance(scalebar_dx, bool): + raise TypeError("Parameter 'scalebar_dx' must be a number or None.") + if scalebar_dx <= 0: + raise ValueError("Parameter 'scalebar_dx' must be > 0.") + if not isinstance(scalebar_units, str): + raise TypeError("Parameter 'scalebar_units' must be a string.") + + if scalebar_params is not None and not isinstance(scalebar_params, dict): + raise TypeError("Parameter 'scalebar_params' must be a dictionary or None.") + + if legend_params is not None: + if not isinstance(legend_params, dict): + raise TypeError("Parameter 'legend_params' must be a dictionary or None.") + # `loc` is matplotlib.Legend's native key; `location` aligns with colorbar_params / scalebar_params. + allowed_legend_keys = {"loc", "location", "fontsize", "fontweight", "fontoutline", "na_in_legend"} + unknown = set(legend_params) - allowed_legend_keys + if unknown: + raise ValueError( + f"Unknown legend_params key(s): {sorted(unknown)}. Allowed keys: {sorted(allowed_legend_keys)}." 
+ ) + + +def _check_color_column_collision( + sdata: SpatialData, + elements: list[str], + color: str, + element_type: str, +) -> None: + """Raise if ``color`` is a color-like string that also names a column in the element or its tables.""" + matches: list[str] = [] + for el in elements: + if element_type in {"shapes", "points"}: + try: + el_cols = sdata[el].columns + except (KeyError, AttributeError): + el_cols = () + if color in el_cols: + matches.append(f"element '{el}'") + continue + try: + tables = get_element_annotators(sdata, el) + except (KeyError, ValueError): + tables = set() + for t in tables: + adata = sdata[t] + if color in adata.obs.columns or color in adata.var_names: + matches.append(f"table '{t}' (annotating '{el}')") + break + if matches: + locations = ", ".join(matches) + raise ValueError( + f"`color={color!r}` is ambiguous: it is a valid matplotlib color name AND a column " + f"name in {locations}. Disambiguate by either passing an unambiguous color form " + f"(hex string like '#ffa500' or an RGB(A) tuple), or by renaming the column." + ) + + +def _resolve_gene_symbols( + adata: AnnData, + col_for_color: str, + gene_symbols: str, +) -> str: + """Resolve a gene symbol to its var_name using an alternate var column. + + Mimics scanpy's ``gene_symbols`` behaviour: look up *col_for_color* in + ``adata.var[gene_symbols]`` and return the corresponding ``var_name`` + (i.e. the var index value). + """ + if gene_symbols not in adata.var.columns: + raise KeyError(f"Column '{gene_symbols}' not found in `adata.var`. Cannot use it as `gene_symbols` lookup.") + mask = adata.var[gene_symbols] == col_for_color + if not mask.any(): + raise KeyError(f"'{col_for_color}' not found in `adata.var['{gene_symbols}']`.") + n_matches = mask.sum() + if n_matches > 1: + logger.warning( + f"Gene symbol '{col_for_color}' maps to {n_matches} var_names in column '{gene_symbols}'. " + f"Using the first match: '{adata.var.index[mask][0]}'." 
def _resolve_obsp_key(table: AnnData, connectivity_key: str) -> str | None:
    """Return the key under which ``connectivity_key`` is stored in ``table.obsp``.

    The key is tried verbatim first and then with a ``_connectivities`` suffix,
    so both a full key and its prefix are accepted. ``None`` means neither exists.
    """
    for candidate in (connectivity_key, f"{connectivity_key}_connectivities"):
        if candidate in table.obsp:
            return candidate
    return None


def _require_obsp_key(table: AnnData, key: str, *, param_name: str) -> str:
    """Like ``_resolve_obsp_key``, but raise ``KeyError`` when nothing matches.

    ``param_name`` is only used to name the offending argument in the error text.
    """
    match = _resolve_obsp_key(table, key)
    if match is not None:
        return match
    raise KeyError(
        f"`{param_name}='{key}'` not found in `table.obsp`. "
        f"Tried '{key}' and '{key}_connectivities'. "
        f"Available obsp keys: {list(table.obsp.keys())}."
    )
+ ) + else: + tables = get_element_annotators(sdata, element_name) + if len(tables) == 0: + raise KeyError( + f"Element '{element_name}' has no annotating tables. " + f"Cannot use column '{col_for_color}' for coloring. " + "Please ensure the element is annotated by at least one table." + ) + # Now check which tables contain the column + resolved_var_name: str | None = None + if gene_symbols is not None and not any(gene_symbols in sdata[t].var.columns for t in tables): + available = sorted({c for t in tables for c in sdata[t].var.columns}) + raise KeyError( + f"Column '{gene_symbols}' specified in `gene_symbols=` was not found in " + f"`adata.var` of any table annotating element '{element_name}'. " + f"Available var columns: {available}" + ) + for annotates in tables.copy(): + if col_for_color not in sdata[annotates].obs.columns and col_for_color not in sdata[annotates].var_names: + if gene_symbols is not None: + try: + resolved_var_name = _resolve_gene_symbols(sdata[annotates], col_for_color, gene_symbols) + except KeyError: + tables.remove(annotates) + else: + tables.remove(annotates) + if len(tables) == 0: + raise KeyError( + f"Unable to locate color key '{col_for_color}' for element '{element_name}'. " + "Please ensure the key exists in a table annotating this element." 
+ ) + table_name = next(iter(tables)) + if len(tables) > 1: + logger.warning(f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.") + if resolved_var_name is not None: + col_for_color = resolved_var_name + return col_for_color, table_name + + +def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]: + colorbar = param_dict.get("colorbar", "auto") + if colorbar not in {True, False, None, "auto"}: + raise TypeError("Parameter 'colorbar' must be one of True, False or 'auto'.") + + colorbar_params = param_dict.get("colorbar_params") + if colorbar_params is not None and not isinstance(colorbar_params, dict): + raise TypeError("Parameter 'colorbar_params' must be a dictionary or None.") + + element = param_dict.get("element") + if element is not None and not isinstance(element, str): + raise ValueError( + "Parameter 'element' must be a string. If you want to display more elements, pass `element` " + "as `None` or chain pl.render(...).pl.render(...).pl.show()" + ) + if element_type == "images": + param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].images.keys()) + elif element_type == "labels": + param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].labels.keys()) + elif element_type == "points": + param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].points.keys()) + elif element_type == "shapes": + param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].shapes.keys()) + + channel = param_dict.get("channel") + if channel is not None and not isinstance(channel, list | str | int): + raise TypeError("Parameter 'channel' must be a string, an integer, or a list of strings or integers.") + if isinstance(channel, list): + if not all(isinstance(c, str | int) for c in channel): + raise TypeError("Each item in 'channel' list must be a string or an integer.") + if not all(isinstance(c, 
type(channel[0])) for c in channel): + raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.") + + elif "channel" in param_dict: + param_dict["channel"] = [channel] if channel is not None else None + + contour_px = param_dict.get("contour_px") + if contour_px and not isinstance(contour_px, int): + raise TypeError("Parameter 'contour_px' must be an integer.") + + color = param_dict.get("color") + if color and element_type in { + "shapes", + "points", + "labels", + "graph", + }: + if not isinstance(color, str | tuple | list): + raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.") + if _is_color_like(color): + if isinstance(color, str): + _check_color_column_collision(param_dict["sdata"], param_dict["element"], color, element_type) + param_dict["col_for_color"] = None + param_dict["color"] = Color(color) + if param_dict["color"].alpha_is_user_defined(): + if element_type == "points" and param_dict.get("alpha") is None: + param_dict["alpha"] = param_dict["color"].get_alpha_as_float() + elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None: + param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float() + else: + logger.info( + f"Alpha implied by color '{color}' is ignored since the parameter 'alpha' or 'fill_alpha' " + "is set and its value takes precedence." 
+ ) + elif isinstance(color, str): + param_dict["col_for_color"] = color + param_dict["color"] = None + else: + raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.") + elif "color" in param_dict and element_type != "images": + param_dict["col_for_color"] = None + + outline_width = param_dict.get("outline_width") + if outline_width: + # outline_width only exists for shapes at the moment + if isinstance(outline_width, tuple): + for ow in outline_width: + if isinstance(ow, float | int): + if ow < 0: + raise ValueError("Parameter 'outline_width' cannot contain negative values.") + else: + raise TypeError("Parameter 'outline_width' must contain only numerics when it is a tuple.") + elif not isinstance(outline_width, float | int): + raise TypeError("Parameter 'outline_width' must be numeric or a tuple of two numerics.") + if isinstance(outline_width, float | int) and outline_width < 0: + raise ValueError("Parameter 'outline_width' cannot be negative.") + + outline_alpha = param_dict.get("outline_alpha") + if outline_alpha: + if isinstance(outline_alpha, tuple): + if element_type != "shapes": + raise ValueError("Parameter 'outline_alpha' must be a single numeric.") + if len(outline_alpha) == 1: + if not isinstance(outline_alpha[0], float | int) or not 0 <= outline_alpha[0] <= 1: + raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.") + param_dict["outline_alpha"] = outline_alpha[0] + elif len(outline_alpha) < 1: + raise ValueError("Empty tuple is not supported as input for outline_alpha!") + else: + if len(outline_alpha) > 2: + logger.warning( + f"Tuple of length {len(outline_alpha)} was passed for outline_alpha, only first two positions " + "are used since more than 2 outlines are not supported!" 
+ ) + if ( + not isinstance(outline_alpha[0], float | int) + or not isinstance(outline_alpha[1], float | int) + or not 0 <= outline_alpha[0] <= 1 + or not 0 <= outline_alpha[1] <= 1 + ): + raise TypeError("Parameter 'outline_alpha' must contain numeric values between 0 and 1.") + param_dict["outline_alpha"] = (outline_alpha[0], outline_alpha[1]) + elif not isinstance(outline_alpha, float | int) or not 0 <= outline_alpha <= 1: + raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.") + + outline_color = param_dict.get("outline_color") + if "outline_color" in param_dict and element_type in {"shapes", "labels"}: + param_dict["col_for_outline_color"] = None + if outline_color: + if not isinstance(outline_color, str | tuple | list): + raise TypeError("Parameter 'outline_color' must be a string or a tuple/list of floats or colors.") + if isinstance(outline_color, tuple | list): + if len(outline_color) < 1: + raise ValueError("Empty tuple is not supported as input for outline_color!") + if len(outline_color) == 1: + param_dict["outline_color"] = Color(outline_color[0]) + elif len(outline_color) == 2: + # assuming the case of 2 outlines + param_dict["outline_color"] = (Color(outline_color[0]), Color(outline_color[1])) + elif len(outline_color) in [3, 4]: + # assuming RGB(A) array + param_dict["outline_color"] = Color(outline_color) + else: + raise ValueError( + f"Tuple/List of length {len(outline_color)} was passed for outline_color. Valid options would be: " + "tuple of 2 colors (for 2 outlines) or an RGB(A) array, aka a list/tuple of 3-4 floats." 
+ ) + elif isinstance(outline_color, str) and element_type in {"shapes", "labels"}: + if _is_color_like(outline_color): + _check_color_column_collision(param_dict["sdata"], param_dict["element"], outline_color, element_type) + param_dict["outline_color"] = Color(outline_color) + else: + if isinstance(param_dict.get("outline_width"), tuple): + raise ValueError( + "Coloring outlines by a column is not supported with two outlines. " + "Pass a scalar `outline_width` or a literal color for `outline_color`." + ) + param_dict["col_for_outline_color"] = outline_color + param_dict["outline_color"] = None + else: + param_dict["outline_color"] = Color(outline_color) + + if contour_px is not None and contour_px < 2: + raise ValueError( + "Parameter 'contour_px' must be >= 2; values below 2 produce no visible outline " + "(a 1x1 erosion is the identity transformation)." + ) + + alpha = param_dict.get("alpha") + if alpha is not None: + if not isinstance(alpha, float | int): + raise TypeError("Parameter 'alpha' must be numeric.") + if not 0 <= alpha <= 1: + raise ValueError("Parameter 'alpha' must be between 0 and 1.") + elif element_type == "points": + # set default alpha for points if not given by user explicitly or implicitly (as part of color) + param_dict["alpha"] = 1.0 + + fill_alpha = param_dict.get("fill_alpha") + if fill_alpha is not None: + if not isinstance(fill_alpha, float | int): + raise TypeError("Parameter 'fill_alpha' must be numeric.") + if fill_alpha < 0: + raise ValueError("Parameter 'fill_alpha' cannot be negative.") + elif element_type == "shapes": + # set default fill_alpha for shapes if not given by user explicitly or implicitly (as part of color) + param_dict["fill_alpha"] = 1.0 + elif element_type == "labels": + # set default fill_alpha for labels if not given by user explicitly or implicitly (as part of color) + param_dict["fill_alpha"] = 0.4 + + cmap = param_dict.get("cmap") + palette = param_dict.get("palette") + if cmap is not None and palette is not 
None: + raise ValueError("Both `palette` and `cmap` are specified. Please specify only one of them.") + param_dict["cmap"] = cmap + + groups = param_dict.get("groups") + if groups is not None: + if not isinstance(groups, list | str): + raise TypeError("Parameter 'groups' must be a string or a list of strings.") + if isinstance(groups, str): + param_dict["groups"] = [groups] + elif not all(isinstance(g, str) for g in groups): + raise TypeError("Each item in 'groups' must be a string.") + + palette = param_dict["palette"] + + # dict palettes (e.g. from make_palette_from_data) bypass groups validation + if isinstance(palette, dict): + from matplotlib.colors import is_color_like + + invalid = [f"'{k}': '{v}'" for k, v in palette.items() if not is_color_like(v)] + if invalid: + raise ValueError(f"Dict palette contains invalid color values: {', '.join(invalid)}.") + elif isinstance(palette, list): + if not all(isinstance(p, str) for p in palette): + raise ValueError("If specified, parameter 'palette' must contain only strings.") + elif isinstance(palette, str | type(None)) and "palette" in param_dict and element_type != "graph": + param_dict["palette"] = [palette] if palette is not None else None + + palette_group = param_dict.get("palette") + if element_type in ["shapes", "points", "labels"] and palette_group is not None and not isinstance(palette, dict): + groups = param_dict.get("groups") + if groups is not None and len(groups) != len(palette_group): + raise ValueError( + f"The length of 'palette' and 'groups' must be the same, length is {len(palette_group)} and" + f"{len(groups)} respectively." 
+ ) + + if isinstance(cmap, list): + if not all(isinstance(c, Colormap | str) for c in cmap): + raise TypeError("Each item in 'cmap' list must be a string or a Colormap.") + elif isinstance(cmap, Colormap | str | type(None)): + if "cmap" in param_dict and element_type != "graph": + param_dict["cmap"] = [cmap] if cmap is not None else None + else: + raise TypeError("Parameter 'cmap' must be a string, a Colormap, or a list of these types.") + + # validation happens within Color constructor (images don't use na_color) + if "na_color" in param_dict: + param_dict["na_color"] = Color(param_dict.get("na_color")) + + norm = param_dict.get("norm") + if norm is not None: + if element_type == "images": + if isinstance(norm, list): + if not norm: + raise ValueError("Parameter 'norm' list must not be empty.") + if not all(isinstance(n, Normalize) for n in norm): + raise TypeError("Every item in 'norm' list must be a Normalize instance.") + elif not isinstance(norm, Normalize): + raise TypeError("Parameter 'norm' must be a Normalize or a list of Normalize instances.") + elif element_type == "labels" and not isinstance(norm, Normalize): + raise TypeError("Parameter 'norm' must be of type Normalize.") + if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize): + raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.") + if element_type == "graph" and not isinstance(norm, Normalize): + raise TypeError("Parameter 'norm' must be a Normalize instance.") + + scale = param_dict.get("scale") + if scale is not None: + if element_type in {"images", "labels"} and not isinstance(scale, str): + raise TypeError("Parameter 'scale' must be a string if specified.") + if element_type == "shapes": + if not isinstance(scale, float | int): + raise TypeError("Parameter 'scale' must be numeric.") + if scale < 0: + raise ValueError("Parameter 'scale' must be a positive number.") + + size = param_dict.get("size") + if size: + if not isinstance(size, float | 
int): + raise TypeError("Parameter 'size' must be numeric.") + if size < 0: + raise ValueError("Parameter 'size' must be a positive number.") + + shape = param_dict.get("shape") + if element_type == "shapes" and shape is not None: + valid_shapes = {"circle", "hex", "visium_hex", "square"} + if not isinstance(shape, str): + raise TypeError(f"Parameter 'shape' must be a String from {valid_shapes} if not None.") + if shape not in valid_shapes: + raise ValueError(f"'{shape}' is not supported for 'shape', please choose from {valid_shapes}.") + + table_name = param_dict.get("table_name") + table_layer = param_dict.get("table_layer") + if table_name and not isinstance(param_dict["table_name"], str): + raise TypeError("Parameter 'table_name' must be a string.") + + if table_layer and not isinstance(param_dict["table_layer"], str): + raise TypeError("Parameter 'table_layer' must be a string.") + + def _ensure_table_and_layer_exist_in_sdata( + sdata: SpatialData, table_name: str | None, table_layer: str | None + ) -> bool: + """Ensure that table_name and table_layer are valid; throw error if not.""" + if table_name: + if table_layer: + if table_layer in sdata.tables[table_name].layers: + return True + raise ValueError(f"Layer '{table_layer}' not found in table '{table_name}'.") + return True # using sdata.tables[table_name].X + + if table_layer: + # user specified a layer but we have no tables => invalid + if len(sdata.tables) == 0: + raise ValueError("Trying to use 'table_layer' but no tables are present in the SpatialData object.") + if len(sdata.tables) == 1: + single_table_name = list(sdata.tables.keys())[0] + if table_layer in sdata.tables[single_table_name].layers: + return True + raise ValueError(f"Layer '{table_layer}' not found in table '{single_table_name}'.") + # more than one tables, try to find which one has the given layer + found_table = False + for tname in sdata.tables: + if table_layer in sdata.tables[tname].layers: + if found_table: + raise ValueError( + 
"Trying to guess 'table_name' based on 'table_layer', but found multiple matches." + ) + found_table = True + + if found_table: + return True + + raise ValueError(f"Layer '{table_layer}' not found in any table.") + + return True # not using any table + + _ensure_table_and_layer_exist_in_sdata(param_dict.get("sdata"), table_name, table_layer) + + method = param_dict.get("method") + if method not in ["matplotlib", "datashader", None]: + raise ValueError("If specified, parameter 'method' must be either 'matplotlib' or 'datashader'.") + + valid_ds_reduction_methods = [ + "sum", + "mean", + "any", + "count", + # "m2", -> not intended to be used alone (see https://datashader.org/api.html#datashader.reductions.m2) + # "mode", -> not supported for points (see https://datashader.org/api.html#datashader.reductions.mode) + "std", + "var", + "max", + "min", + ] + ds_reduction = param_dict.get("ds_reduction") + if ds_reduction and (ds_reduction not in valid_ds_reduction_methods): + raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.") + + if element_type == "graph": + for key in ("connectivity_key",): + val = param_dict.get(key) + if val is not None and not isinstance(val, str): + raise TypeError(f"Parameter '{key}' must be a string.") + + for key in ("obsp_key", "weight_key", "group_key"): + val = param_dict.get(key) + if val is not None and not isinstance(val, str): + raise TypeError(f"Parameter '{key}' must be a string or None.") + + for key in ("edge_width", "edge_alpha"): + val = param_dict.get(key) + if val == "weight": + continue + if not isinstance(val, float | int): + raise TypeError(f"Parameter '{key}' must be numeric or the literal string 'weight'.") + if val < 0: + raise ValueError(f"Parameter '{key}' cannot be negative.") + + linestyle = param_dict.get("linestyle") + if linestyle is not None and not isinstance(linestyle, str | list | tuple): + raise TypeError("Parameter 'linestyle' must be a string or a sequence 
of strings.") + + for key in ("include_self_loops", "rasterize"): + val = param_dict.get(key) + if val is not None and not isinstance(val, bool): + raise TypeError(f"Parameter '{key}' must be a boolean.") + + return param_dict + + +def _resolve_color_panels(color: Any) -> tuple[Any, list[str] | None]: + """Split a ``color`` argument into a scalar color and an optional multi-panel key list. + + Returns ``(scalar_color, panel_keys)``. When ``panel_keys`` is ``None`` the call is a + normal single-color render and ``scalar_color`` is the (unchanged) color to use. When + ``panel_keys`` is a list, the render must be expanded into one panel per key. + + A list of all-strings is treated as multi-panel keys; a length-1 list normalizes to a + scalar color; an all-numeric list stays a single RGB(A) color. Empty, duplicate, or + mixed str/number lists raise ``ValueError``. + """ + if not isinstance(color, list): + return color, None + if all(isinstance(c, str) for c in color): + if len(color) == 0: + raise ValueError("`color` was given an empty list; provide at least one column/key name.") + duplicate_keys = sorted(k for k, n in Counter(color).items() if n > 1) + if duplicate_keys: + raise ValueError(f"`color` contains duplicate keys {duplicate_keys}; each multi-panel key must be unique.") + if len(color) == 1: + return color[0], None + return None, list(color) + if any(isinstance(c, str) for c in color): + raise ValueError( + "`color` list must be either all column/key names (str) for a multi-panel plot, " + "or 3-4 floats for a single RGB(A) color, not a mix of both." + ) + return color, None + + +def _expand_color_panels( + sdata: SpatialData, + color: Any, + render_fn_name: str, + validate: Callable[[Any], dict[str, Any]], +) -> list[tuple[str | None, dict[str, Any]]]: + """Resolve ``color`` into validated per-panel render params for the multi-panel ``color=[...]`` feature. 
+ + ``validate`` is a callback that runs the render function's own parameter validation for a single + color value and returns its per-element ``params_dict``. Returns a list of ``(panel_key, params_dict)`` + pairs: a single ``(None, params_dict)`` for the scalar case, or one entry per key for a key list. + + Enforces that only one ``render_*`` call per figure may pass a color list, and aggregates per-key + validation errors into a single message. Used by ``render_shapes`` and ``render_labels``. + """ + color, panel_keys = _resolve_color_panels(color) + if panel_keys is not None and any( + getattr(params, "panel_key", None) is not None for params in getattr(sdata, "plotting_tree", {}).values() + ): + raise ValueError( + "Only one `render_*` call may use a list of color keys per figure. Other chained render " + "calls must use a single (scalar) color; they are drawn into every panel as a shared layer." + ) + + color_specs = [(None, color)] if panel_keys is None else [(key, key) for key in panel_keys] + panel_param_dicts: list[tuple[str | None, dict[str, Any]]] = [] + key_errors: dict[str, str] = {} + for panel_key, color_value in color_specs: + try: + params_dict = validate(color_value) + except (KeyError, ValueError) as e: + if panel_keys is None: + raise + key_errors[panel_key] = str(e) # type: ignore[index] + continue + panel_param_dicts.append((panel_key, params_dict)) + if key_errors: + details = "\n".join(f" - {key!r}: {msg}" for key, msg in key_errors.items()) + raise ValueError(f"Invalid color key(s) for multi-panel `{render_fn_name}`:\n{details}") + return panel_param_dicts + + +def _validate_as_points_size(size: float) -> None: + """Validate the centroid marker `size` used by ``render_shapes``/``render_labels`` with ``as_points=True``.""" + if isinstance(size, bool) or not isinstance(size, (int, float)): + raise TypeError("Parameter 'size' must be numeric.") + if size <= 0: + raise ValueError("Parameter 'size' must be a positive number.") + + +def 
def _validate_label_render_params(
    sdata: sd.SpatialData,
    element: str | None,
    cmap: list[Colormap | str] | Colormap | str | None,
    color: ColorLike | None,
    fill_alpha: float | int | None,
    contour_px: int | None,
    groups: list[str] | str | None,
    palette: dict[str, str] | list[str] | str | None,
    na_color: ColorLike | None,
    norm: Normalize | None,
    outline_alpha: float | int,
    outline_color: ColorLike | None,
    scale: str | None,
    table_name: str | None,
    table_layer: str | None,
    colorbar: bool | str | None,
    colorbar_params: dict[str, object] | None,
    gene_symbols: str | None = None,
) -> dict[str, dict[str, Any]]:
    """Type-check the label-render arguments and expand them into one parameter dict per element.

    The arguments are bundled and run through ``_type_check_params(..., "labels")``. For every
    element name in the checked ``"element"`` entry, a dict of resolved parameters is built;
    fill and outline color columns (if any) are resolved against a table with
    ``_validate_col_for_column_table``.

    Returns a mapping ``element name -> resolved parameters``.
    """
    checked: dict[str, Any] = _type_check_params(
        {
            "sdata": sdata,
            "element": element,
            "fill_alpha": fill_alpha,
            "contour_px": contour_px,
            "groups": groups,
            "palette": palette,
            "color": color,
            "na_color": na_color,
            "outline_alpha": outline_alpha,
            "outline_color": outline_color,
            "cmap": cmap,
            "norm": norm,
            "scale": scale,
            "table_name": table_name,
            "table_layer": table_layer,
            "colorbar": colorbar,
            "colorbar_params": colorbar_params,
        },
        "labels",
    )

    # Values copied verbatim from the checked parameters into every element's dict.
    shared_keys = (
        "na_color",
        "cmap",
        "norm",
        "fill_alpha",
        "scale",
        "outline_alpha",
        "outline_color",
        "contour_px",
        "table_layer",
    )

    per_element: dict[str, dict[str, Any]] = {}
    for name in checked["element"]:
        # Indexing raises if the element does not exist in the SpatialData object.
        _ = checked["sdata"][name]

        resolved: dict[str, Any] = {key: checked[key] for key in shared_keys}
        per_element[name] = resolved

        # Fill color: either a literal Color (or None), or a column resolved against a table.
        resolved["table_name"] = None
        resolved["color"] = checked["color"]
        resolved["col_for_color"] = None
        fill_column = checked["col_for_color"]
        if fill_column is not None:
            fill_column, fill_table = _validate_col_for_column_table(
                sdata, name, fill_column, checked["table_name"], labels=True, gene_symbols=gene_symbols
            )
            resolved["table_name"] = fill_table
            resolved["col_for_color"] = fill_column

        # Outline color column, resolved the same way but tracked under its own table name.
        resolved["col_for_outline_color"] = None
        resolved["outline_table_name"] = None
        outline_column = checked.get("col_for_outline_color")
        if outline_column is not None:
            outline_column, outline_table = _validate_col_for_column_table(
                sdata,
                name,
                outline_column,
                checked["table_name"],
                labels=True,
                gene_symbols=gene_symbols,
            )
            resolved["col_for_outline_color"] = outline_column
            resolved["outline_table_name"] = outline_table

        _gate_palette_and_groups(resolved, checked)
        resolved["colorbar"] = checked["colorbar"]
        resolved["colorbar_params"] = checked["colorbar_params"]

    return per_element
def _validate_points_render_params(
    sdata: sd.SpatialData,
    element: str | None,
    alpha: float | int | None,
    color: ColorLike | None,
    groups: list[str] | str | None,
    palette: dict[str, str] | list[str] | str | None,
    na_color: ColorLike | None,
    cmap: list[Colormap | str] | Colormap | str | None,
    norm: Normalize | None,
    size: float | int,
    table_name: str | None,
    table_layer: str | None,
    ds_reduction: str | None,
    colorbar: bool | str | None,
    colorbar_params: dict[str, object] | None,
    gene_symbols: str | None = None,
    density: bool = False,
    density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear",
    transfunc: Callable[[float], float] | None = None,
    method: str | None = None,
) -> dict[str, dict[str, Any]]:
    """Type-check the points-render arguments and expand them into one parameter dict per element.

    ``density`` and ``density_how`` are checked first, then the remaining arguments go through
    ``_type_check_params(..., "points")``. With ``density=True`` the matplotlib backend and a
    literal ``color`` are rejected, and ``size``, ``transfunc``, ``norm`` limits and
    ``ds_reduction`` trigger a ``UserWarning`` stating that they are ignored.

    Returns a mapping ``element name -> resolved parameters``.
    """
    if not isinstance(density, bool):
        raise TypeError("Parameter 'density' must be a bool.")
    density_modes = ("linear", "log", "cbrt", "eq_hist")
    if density_how not in density_modes:
        raise ValueError(f"Parameter 'density_how' must be one of {density_modes}; got {density_how!r}.")

    checked: dict[str, Any] = _type_check_params(
        {
            "sdata": sdata,
            "element": element,
            "alpha": alpha,
            "color": color,
            "groups": groups,
            "palette": palette,
            "na_color": na_color,
            "cmap": cmap,
            "norm": norm,
            "size": size,
            "table_name": table_name,
            "table_layer": table_layer,
            "ds_reduction": ds_reduction,
            "colorbar": colorbar,
            "colorbar_params": colorbar_params,
        },
        "points",
    )

    if density:
        if method == "matplotlib":
            raise ValueError(
                "density=True requires the datashader backend; got method='matplotlib'. "
                "Either drop method= or set method='datashader'."
            )
        # After type checking, a literal color is a Color instance with no color column.
        # Together with density it could mean a single-hue cmap or a one-entry palette,
        # so the caller has to pick one explicitly.
        literal_color_given = checked["color"] is not None and checked["col_for_color"] is None
        if literal_color_given:
            raise ValueError(
                "density=True with a literal color is ambiguous. Pass cmap= to recolor the "
                "density, or palette= to assign a categorical color, but not color=."
            )

        # Parameters that have no effect on a count-based density: warn and carry on.
        norm_has_limits = isinstance(norm, Normalize) and (norm.vmin is not None or norm.vmax is not None)
        ignored_notices = (
            (size != 1.0, "size is ignored when density=True; spreading would distort the count signal."),
            (
                transfunc is not None,
                "transfunc is ignored when density=True (no continuous color vector to transform).",
            ),
            (
                norm_has_limits,
                "norm.vmin/vmax are ignored when density=True; use density_how= to control intensity mapping.",
            ),
            (ds_reduction is not None, "datashader_reduction is ignored when density=True; counts are forced."),
        )
        for applies, notice in ignored_notices:
            if applies:
                # Emitted from this function's own frame so that stacklevel keeps its meaning.
                warnings.warn(notice, UserWarning, stacklevel=3)

    # Values copied verbatim from the checked parameters into every element's dict.
    shared_keys = ("na_color", "cmap", "norm", "color", "size", "alpha", "table_layer")

    per_element: dict[str, dict[str, Any]] = {}
    for name in checked["element"]:
        # Indexing raises if the element does not exist in the SpatialData object.
        _ = checked["sdata"][name]

        resolved: dict[str, Any] = {key: checked[key] for key in shared_keys}
        per_element[name] = resolved

        resolved["table_name"] = None
        resolved["col_for_color"] = None
        color_column = checked["col_for_color"]
        if color_column is not None:
            color_column, color_table = _validate_col_for_column_table(
                sdata, name, color_column, checked["table_name"], gene_symbols=gene_symbols
            )
            resolved["table_name"] = color_table
            resolved["col_for_color"] = color_column

        _gate_palette_and_groups(resolved, checked)
        for key in ("ds_reduction", "colorbar", "colorbar_params"):
            resolved[key] = checked[key]

    return per_element
def _validate_shape_render_params(
    sdata: sd.SpatialData,
    element: str | None,
    fill_alpha: float | int | None,
    groups: list[str] | str | None,
    palette: dict[str, str] | list[str] | str | None,
    color: ColorLike | None,
    na_color: ColorLike | None,
    outline_width: float | int | tuple[float | int, float | int] | None,
    outline_color: ColorLike | tuple[ColorLike] | None,
    outline_alpha: float | int | tuple[float | int, float | int] | None,
    cmap: list[Colormap | str] | Colormap | str | None,
    norm: Normalize | None,
    scale: float | int,
    table_name: str | None,
    table_layer: str | None,
    shape: Literal["circle", "hex", "visium_hex", "square"] | None,
    method: str | None,
    ds_reduction: str | None,
    colorbar: bool | str | None,
    colorbar_params: dict[str, object] | None,
    gene_symbols: str | None = None,
) -> dict[str, dict[str, Any]]:
    """Type-check the shape-render arguments and expand them into one parameter dict per element.

    The arguments are bundled and run through ``_type_check_params(..., "shapes")``. For every
    element name in the checked ``"element"`` entry, a dict of resolved parameters is built;
    fill and outline color columns (if any) are resolved against a table with
    ``_validate_col_for_column_table``.

    Returns a mapping ``element name -> resolved parameters``.
    """
    checked: dict[str, Any] = _type_check_params(
        {
            "sdata": sdata,
            "element": element,
            "fill_alpha": fill_alpha,
            "groups": groups,
            "palette": palette,
            "color": color,
            "na_color": na_color,
            "outline_width": outline_width,
            "outline_color": outline_color,
            "outline_alpha": outline_alpha,
            "cmap": cmap,
            "norm": norm,
            "scale": scale,
            "table_name": table_name,
            "table_layer": table_layer,
            "shape": shape,
            "method": method,
            "ds_reduction": ds_reduction,
            "colorbar": colorbar,
            "colorbar_params": colorbar_params,
        },
        "shapes",
    )

    # Values copied verbatim from the checked parameters into every element's dict.
    shared_keys = (
        "fill_alpha",
        "na_color",
        "outline_width",
        "outline_color",
        "outline_alpha",
        "cmap",
        "norm",
        "scale",
        "table_layer",
        "shape",
        "color",
    )
    trailing_keys = ("method", "ds_reduction", "colorbar", "colorbar_params")

    per_element: dict[str, dict[str, Any]] = {}
    for name in checked["element"]:
        # Indexing raises if the element does not exist in the SpatialData object.
        _ = checked["sdata"][name]

        resolved: dict[str, Any] = {key: checked[key] for key in shared_keys}
        per_element[name] = resolved

        # Fill color column, if the user passed a column name instead of a literal color.
        resolved["table_name"] = None
        resolved["col_for_color"] = None
        fill_column = checked["col_for_color"]
        if fill_column is not None:
            fill_column, fill_table = _validate_col_for_column_table(
                sdata, name, fill_column, checked["table_name"], gene_symbols=gene_symbols
            )
            resolved["table_name"] = fill_table
            resolved["col_for_color"] = fill_column

        # Outline color column, tracked under its own table name.
        resolved["col_for_outline_color"] = None
        resolved["outline_table_name"] = None
        outline_column = checked.get("col_for_outline_color")
        if outline_column is not None:
            outline_column, outline_table = _validate_col_for_column_table(
                sdata, name, outline_column, checked["table_name"], gene_symbols=gene_symbols
            )
            resolved["col_for_outline_color"] = outline_column
            resolved["outline_table_name"] = outline_table

        _gate_palette_and_groups(resolved, checked)
        for key in trailing_keys:
            resolved[key] = checked[key]

    return per_element
| None = None, + na_color: ColorLike | None = "default", + cmap: Colormap | str | None = None, + norm: Normalize | None = None, + linestyle: str | Sequence[str] = "solid", + include_self_loops: bool = False, + rasterize: bool = True, +) -> dict[str, Any]: + """Validate and resolve parameters for render_graph.""" + param_dict: dict[str, Any] = { + "sdata": sdata, + "element": element, + "color": color, + "groups": groups, + "palette": palette, + "na_color": na_color, + "cmap": cmap, + "norm": norm if norm is not None else Normalize(clip=False), + "table_name": table_name, + "connectivity_key": connectivity_key, + "obsp_key": obsp_key, + "weight_key": weight_key, + "group_key": group_key, + "edge_width": edge_width, + "edge_alpha": edge_alpha, + "linestyle": linestyle, + "include_self_loops": include_self_loops, + "rasterize": rasterize, + } + param_dict = _type_check_params(param_dict, "graph") + + if param_dict["table_name"] is None: + candidates = [tname for tname in sdata.tables if _resolve_obsp_key(sdata[tname], connectivity_key) is not None] + if len(candidates) == 0: + raise ValueError( + f"No table found with connectivity key '{connectivity_key}' in obsp. " + f"Available tables: {list(sdata.tables.keys())}." + ) + if len(candidates) > 1: + raise ValueError( + f"Multiple tables contain connectivity key '{connectivity_key}': {candidates}. " + "Please specify `table_name` explicitly." + ) + param_dict["table_name"] = candidates[0] + + if param_dict["table_name"] not in sdata.tables: + raise KeyError(f"Table '{param_dict['table_name']}' not found. Available: {list(sdata.tables.keys())}.") + + table = sdata[param_dict["table_name"]] + connectivity_obsp_key = _require_obsp_key(table, connectivity_key, param_name="connectivity_key") + + _, region_key, _ = get_table_keys(table) + if region_key is None: + raise ValueError( + f"Table '{param_dict['table_name']}' has no `region_key`; cannot associate its observations " + "with a spatial element. 
Re-parse the table with `TableModel.parse(..., region_key=...)`." + ) + + if param_dict["element"] is None: + regions = table.obs[region_key].unique().tolist() + spatial_regions = [r for r in regions if r in sdata.shapes or r in sdata.points or r in sdata.labels] + if len(spatial_regions) == 0: + raise ValueError( + f"Table '{param_dict['table_name']}' does not annotate any spatial element. Region values: {regions}." + ) + if len(spatial_regions) > 1: + raise ValueError( + f"Table '{param_dict['table_name']}' annotates multiple spatial elements: {spatial_regions}. " + "Please specify `element` explicitly." + ) + param_dict["element"] = spatial_regions[0] + elif not ( + param_dict["element"] in sdata.shapes + or param_dict["element"] in sdata.points + or param_dict["element"] in sdata.labels + ): + raise KeyError( + f"Element '{param_dict['element']}' not found in shapes, points, or labels. " + f"Available: shapes={list(sdata.shapes.keys())}, " + f"points={list(sdata.points.keys())}, labels={list(sdata.labels.keys())}." + ) + + # _type_check_params normalised string groups → list; renormalise the working set here. + if param_dict["groups"] is not None and param_dict["group_key"] is None: + raise ValueError("`groups` requires `group_key` to be specified.") + if param_dict["group_key"] is not None and param_dict["group_key"] not in table.obs.columns: + raise KeyError( + f"`group_key='{param_dict['group_key']}'` not found in table obs columns. " + f"Available: {list(table.obs.columns)}." 
+ ) + if param_dict["groups"] is not None and param_dict["group_key"] is not None: + groups_set: set[Any] = set(param_dict["groups"]) + available_groups = set(table.obs[param_dict["group_key"]].dropna().unique()) + missing_groups = groups_set - available_groups + if missing_groups: + try: + missing_str = str(sorted(missing_groups)) + except TypeError: + missing_str = str(list(missing_groups)) + if missing_groups == groups_set: + logger.warning( + f"None of the requested groups {missing_str} were found in column " + f"'{param_dict['group_key']}'. Resulting plot will contain no edges." + ) + else: + logger.warning( + f"Groups {missing_str} not found in column '{param_dict['group_key']}' and will be ignored." + ) + + # After _type_check_params: col_for_color is the non-color string user passed via `color=`; + # color is either a Color (user gave a real color) or None (user gave a column name or nothing). + col_for_color = param_dict.get("col_for_color") + if col_for_color is not None and col_for_color not in table.obs.columns: + raise ValueError( + f"`color='{col_for_color}'` is not a matplotlib color and was not found in " + f"`table.obs` columns. Available obs columns: {list(table.obs.columns)}." + ) + + color_is_obs_col = col_for_color is not None + if obsp_key is not None and color_is_obs_col: + raise ValueError( + "Cannot set both `color` (as an obs column) and `obsp_key` for edge coloring. " + "Pick one source: scalar color, obs-column color, or obsp-matrix color." + ) + if obsp_key is not None and param_dict["color"] is not None: + raise ValueError( + "Cannot set both `color` and `obsp_key` for edge coloring. " + "Use `obsp_key` for matrix-driven coloring with `cmap`/`norm`, " + "or `color` for a scalar / obs-column-driven coloring." 
+ ) + + color_obsp_key: str | None = None + obs_col: str | None = None + color_source: str = "scalar" + cmap_params: CmapParams | None = None + palette_map: dict[str, str] | None = None + + if obsp_key is not None: + color_obsp_key = _require_obsp_key(table, obsp_key, param_name="obsp_key") + color_source = "obsp" + cmap_params = _prepare_cmap_norm(cmap=cmap, norm=param_dict["norm"]) + elif color_is_obs_col: + obs_col = col_for_color + obs_values = table.obs[obs_col] + if isinstance(obs_values.dtype, pd.CategoricalDtype) or obs_values.dtype == object: + color_source = "obs_categorical" + categories = ( + obs_values.cat.categories.tolist() + if isinstance(obs_values.dtype, pd.CategoricalDtype) + else sorted(obs_values.dropna().unique().tolist()) + ) + if isinstance(palette, dict): + missing = [c for c in categories if c not in palette] + if missing: + raise KeyError( + f"Palette dict is missing entries for categories: {missing}. " + f"Available categories: {categories}." + ) + palette_map = {c: palette[c] for c in categories} + else: + cat_colors = _get_colors_for_categorical_obs(categories=categories, palette=palette) + palette_map = dict(zip(categories, cat_colors, strict=True)) + else: + color_source = "obs_continuous" + cmap_params = _prepare_cmap_norm(cmap=cmap, norm=param_dict["norm"]) + + # When edge_width/edge_alpha="weight" but weight_key isn't given, fall back to the + # connectivity matrix so binary graphs still produce a per-edge array. 
def _validate_image_render_params(
    sdata: sd.SpatialData,
    element: str | None,
    channel: list[str] | list[int] | str | int | None,
    alpha: float | int | None,
    palette: list[str] | str | None,
    cmap: list[Colormap | str] | Colormap | str | None,
    norm: list[Normalize] | Normalize | None,
    scale: str | None,
    colorbar: bool | str | None,
    colorbar_params: dict[str, object] | None,
) -> dict[str, dict[str, Any]]:
    """Validate the image-rendering arguments and resolve them per image element.

    The raw arguments are first passed through ``_type_check_params`` (with the
    ``"images"`` element type); the checked values are then turned into one
    parameter dict per element. A ``palette`` or ``cmap`` list with a single
    entry is repeated once per channel.

    Parameters
    ----------
    sdata
        Object the image elements are looked up in (``sdata[element_name]``).
    element
        Element selection, resolved by ``_type_check_params``.
    channel
        Channel name(s) or index/indices to render; ``None`` keeps all channels.
    alpha, colorbar, colorbar_params
        Forwarded unchanged (after type checking) into every element's dict.
    palette, cmap, norm
        Per-channel colour specifications; their lengths are checked against the
        number of channels that will be rendered.
    scale
        Scale name; only validated for multiscale (``DataTree``) elements.

    Returns
    -------
    dict
        Mapping ``element name -> parameter dict`` with the keys ``channel``,
        ``alpha``, ``palette``, ``cmap``, ``norm``, ``scale``, ``colorbar`` and
        ``colorbar_params``.

    Raises
    ------
    TypeError
        If ``channel`` is an empty list or its items are not all of the same type
        as its first item.
    ValueError
        If a requested channel does not exist, if ``palette`` / ``cmap`` / ``norm``
        lengths do not match the number of channels, or if ``scale`` is not a
        valid scale of a multiscale image.
    """
    param_dict: dict[str, Any] = {
        "sdata": sdata,
        "element": element,
        "channel": channel,
        "alpha": alpha,
        "palette": palette,
        "cmap": cmap,
        "norm": norm,
        "scale": scale,
        "colorbar": colorbar,
        "colorbar_params": colorbar_params,
    }
    param_dict = _type_check_params(param_dict, "images")

    element_params: dict[str, dict[str, Any]] = {}
    for el in param_dict["element"]:
        element_params[el] = {}
        spatial_element = param_dict["sdata"][el]

        # robustly get channel names from image or multiscale image
        spatial_element_ch = (
            spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values
        )

        channel = param_dict["channel"]
        if channel is not None:
            # Normalize channel to always be a list of str or a list of int
            if isinstance(channel, str | int):
                channel = [channel]

            # Reject empty lists and lists mixing item types
            if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)):
                raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.")

            invalid = [c for c in channel if c not in spatial_element_ch]
            if invalid:
                raise ValueError(
                    f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}"
                )
        element_params[el]["channel"] = channel

        element_params[el]["alpha"] = param_dict["alpha"]

        # Channels that will actually be rendered: the requested subset, or all of them.
        channels_to_use = spatial_element_ch if channel is None else channel
        expected_len = len(channels_to_use)

        palette = param_dict["palette"]
        assert isinstance(palette, list | type(None))  # if present, was converted to list, just to make sure

        if isinstance(palette, list):
            # case A: single palette for all channels
            if len(palette) == 1:
                palette = palette * expected_len
            # case B: one palette per channel (either given or derived from channel length)
            if len(palette) != expected_len:
                # Report the channel *count* next to the palette length; the names are
                # appended separately so the two numbers are directly comparable.
                raise ValueError(
                    f"Palette length ({len(palette)}) does not match channel length ({expected_len}). "
                    f"Channels: {', '.join(str(c) for c in channels_to_use)}."
                )
        element_params[el]["palette"] = palette

        cmap = param_dict["cmap"]
        if cmap is not None:
            # NOTE(review): `len(cmap)` presumes `_type_check_params` wrapped a single cmap in a list — confirm.
            if len(cmap) == 1:
                cmap = cmap * expected_len
            if len(cmap) != expected_len:
                raise ValueError(
                    f"Length of 'cmap' list ({len(cmap)}) must match the number of channels ({expected_len})."
                )
        element_params[el]["cmap"] = cmap

        norm = param_dict["norm"]
        # A single-entry norm list is accepted as-is; only longer lists must match the channel count.
        if isinstance(norm, list) and len(norm) > 1 and len(norm) != expected_len:
            raise ValueError(f"Length of 'norm' list ({len(norm)}) must match the number of channels ({expected_len}).")
        element_params[el]["norm"] = norm

        scale = param_dict["scale"]
        if scale and isinstance(param_dict["sdata"][el], DataTree):
            valid_scales = list(param_dict["sdata"][el].keys())
            if scale not in valid_scales and scale != "full":
                raise ValueError(
                    f"Scale '{scale}' does not exist in image '{el}'. Valid scales: {valid_scales + ['full']}."
                )
        element_params[el]["scale"] = scale
        element_params[el]["colorbar"] = param_dict["colorbar"]
        element_params[el]["colorbar_params"] = param_dict["colorbar_params"]

    return element_params
@@ -61,24 +76,13 @@ from spatialdata_plot.pl.utils import ( _RENDER_CMD_TO_CS_FLAG, _draw_scalebar, - _expand_color_panels, _get_cs_contents, _get_elements_to_be_rendered, _get_extent_fast, _get_valid_cs, _get_wanted_render_elements, - _maybe_set_colors, _mpl_ax_contains_elements, - _prepare_cmap_norm, _prepare_params_plot, - _set_outline, - _validate_as_points_size, - _validate_graph_render_params, - _validate_image_render_params, - _validate_label_render_params, - _validate_points_render_params, - _validate_shape_render_params, - _validate_show_parameters, _verify_plotting_tree, save_fig, ) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 71b41e67..54fc2f9e 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -32,12 +32,38 @@ from xarray import DataTree from spatialdata_plot._logging import _log_context, logger +from spatialdata_plot.pl._color import ( + _align_outline_vector_to_length, + _apply_mask_to_outline_vectors, + _color_vector_to_rgba, + _get_colors_for_categorical_obs, + _get_linear_colormap, + _make_continuous_mappable, + _map_color_seg, + _maybe_set_colors, + _prepare_cmap_norm, + _set_color_source_vec, +) from spatialdata_plot.pl._datashader import ( + _ax_show_and_transform, _build_ds_colorbar, + _datashader_canvas_from_dataframe, + _get_extent_and_range_for_datashader_canvas, + _hex_no_alpha, + _prepare_transformation, _render_ds_image, _render_ds_outlines, _shade_datashader_aggregate, ) +from spatialdata_plot.pl._geometry import ( + _build_shape_patches, + _convert_shapes, + _get_collection_shape, + _validate_polygons, +) +from spatialdata_plot.pl._validate import ( + _check_obs_var_shadow, +) from spatialdata_plot.pl.render_params import ( ChannelLegendEntry, CmapParams, @@ -53,35 +79,15 @@ _DsReduction, ) from spatialdata_plot.pl.utils import ( - _align_outline_vector_to_length, - _apply_mask_to_outline_vectors, - _ax_show_and_transform, - _build_shape_patches, - 
_check_obs_var_shadow, - _color_vector_to_rgba, - _convert_shapes, - _datashader_canvas_from_dataframe, _decorate_axs, _fast_extent, - _get_collection_shape, - _get_colors_for_categorical_obs, - _get_extent_and_range_for_datashader_canvas, - _get_linear_colormap, - _hex_no_alpha, _join_table_for_element, - _make_continuous_mappable, - _map_color_seg, - _maybe_set_colors, _mpl_ax_contains_elements, _multiscale_to_spatial_image, _pixel_to_coord, - _prepare_cmap_norm, - _prepare_transformation, _rasterize_if_necessary, _rasterize_if_necessary_datashader, - _set_color_source_vec, _stream_label_centroid_stats, - _validate_polygons, ) _Normalize = Normalize | abc.Sequence[Normalize] diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a02b1484..323d6c3e 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1,75 +1,46 @@ from __future__ import annotations -import math import os import warnings -from collections import Counter, OrderedDict -from collections.abc import Callable, Mapping, Sequence -from copy import copy +from collections import OrderedDict +from collections.abc import Mapping, Sequence from functools import partial from pathlib import Path from typing import Any, Literal import dask import datashader as ds -import matplotlib -import matplotlib.patches as mpatches -import matplotlib.path as mpath import matplotlib.pyplot as plt -import matplotlib.ticker -import matplotlib.transforms as mtransforms import numpy as np -import numpy.ma as ma -import numpy.typing as npt import pandas as pd -import shapely import spatialdata as sd from anndata import AnnData -from cycler import Cycler, cycler from dask.array.core import slices_from_chunks -from datashader.core import Canvas from geopandas import GeoDataFrame from matplotlib import colors, patheffects, rcParams from matplotlib.axes import Axes -from matplotlib.cm import ScalarMappable from matplotlib.collections import PatchCollection from 
matplotlib.colors import ( - ColorConverter, Colormap, - LinearSegmentedColormap, ListedColormap, - Normalize, - to_rgba, ) from matplotlib.figure import Figure from matplotlib.gridspec import GridSpec -from matplotlib.transforms import CompositeGenericTransform from matplotlib_scalebar.scalebar import ScaleBar -from numpy.ma.core import MaskedArray -from numpy.random import default_rng -from pandas.api.types import CategoricalDtype, is_bool_dtype, is_numeric_dtype, is_string_dtype +from pandas.api.types import CategoricalDtype, is_numeric_dtype from pandas.core.arrays.categorical import Categorical from scanpy import settings from scanpy.plotting._tools.scatterplots import _add_categorical_legend -from scanpy.plotting._utils import add_colors_for_categorical_sample_annotation -from scanpy.plotting.palettes import default_20, default_28, default_102 -from scipy.spatial import ConvexHull -from shapely.errors import GEOSException -from skimage.color import label2rgb -from skimage.morphology import erosion, footprint_rectangle -from skimage.util import map_array from spatialdata import ( SpatialData, get_element_annotators, get_extent, - get_values, join_spatialelement_table, rasterize, ) from spatialdata import ( deepcopy as sd_deepcopy, ) -from spatialdata._core.query.relational_query import _locate_value from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel, @@ -81,33 +52,25 @@ get_table_keys, ) from spatialdata.transformations.operations import get_transformation -from spatialdata.transformations.transformations import Scale, Translation -from spatialdata.transformations.transformations import Sequence as TransformSequence from xarray import DataArray, DataTree from spatialdata_plot._logging import logger from spatialdata_plot.pl.render_params import ( - CmapParams, Color, ColorbarSpec, - ColorLike, FigParams, GraphRenderParams, ImageRenderParams, LabelsRenderParams, - OutlineParams, PointsRenderParams, ScalebarParams, 
ShapesRenderParams, - _DsReduction, _FontSize, _FontWeight, ) to_hex = partial(colors.to_hex, keep_alpha=True) -_GROUPS_IGNORED_WARNING = "Parameter 'groups' is ignored when 'color' is a literal color, not a column name." - _RENDER_CMD_TO_CS_FLAG: dict[str, str] = { "render_images": "has_images", "render_shapes": "has_shapes", @@ -116,50 +79,6 @@ } -def _check_obs_var_shadow( - sdata: SpatialData | None, - element_name: str | None, - value_to_plot: str | None, - table_name: str | None, -) -> None: - """Raise if ``value_to_plot`` exists in both ``table.obs.columns`` and ``table.var_names``. - - Upstream ``_get_table_origins`` uses an ``elif`` chain, so a key that lives in - both locations is silently resolved to ``obs`` — masking the user's likely - intent of plotting gene expression. Catch this here before any value fetch. - Any ``None`` parameter short-circuits the check. - """ - if ( - value_to_plot is None - or table_name is None - or element_name is None - or sdata is None - or table_name not in sdata.tables - ): - return - if table_name not in get_element_annotators(sdata, element_name): - return - table = sdata.tables[table_name] - if value_to_plot in table.obs.columns and value_to_plot in table.var_names: - raise ValueError( - f"`color={value_to_plot!r}` is ambiguous: it exists in both " - f"`table[{table_name!r}].obs.columns` and `table[{table_name!r}].var_names`. " - "Rename one of them (or drop the obs column) so the intended source is unambiguous." 
- ) - - -def _gate_palette_and_groups( - element_params: dict[str, Any], - param_dict: dict[str, Any], -) -> None: - """Set palette/groups on element_params only when col_for_color is present, else warn.""" - has_col = element_params.get("col_for_color") is not None - element_params["palette"] = param_dict["palette"] if has_col else None - if not has_col and param_dict["groups"] is not None: - logger.warning(_GROUPS_IGNORED_WARNING) - element_params["groups"] = param_dict["groups"] if has_col else None - - def _extract_scalar_value(value: Any, default: float = 0.0) -> float: """ Extract a scalar float value from various data types. @@ -257,52 +176,6 @@ def _get_coordinate_system_mapping(sdata: SpatialData) -> dict[str, list[str]]: _MPL_SINGLE_LETTER_COLORS = frozenset("bgrcmykw") -def _is_color_like(color: Any) -> bool: - """Check if a value is a valid color. - - We reject several matplotlib shorthand notations that are likely to collide - with column or gene names. For discussion, see: - - - https://github.com/scverse/spatialdata-plot/issues/211 - - https://github.com/scverse/spatialdata-plot/issues/327 - - Rejected shorthands: - - - Greyscale strings: ``"0"``, ``"0.5"``, ``"1"`` (floats in [0, 1]) - - Short hex: ``"#RGB"`` / ``"#RGBA"`` (only ``#RRGGBB`` / ``#RRGGBBAA`` accepted) - - Single-letter colors: ``"b"``, ``"g"``, ``"r"``, ``"c"``, ``"m"``, ``"y"``, ``"k"``, ``"w"`` - - CN cycle notation: ``"C0"``, ``"C1"``, … - - ``tab:`` prefixed colors: ``"tab:blue"``, ``"tab:orange"``, … - - ``xkcd:`` prefixed colors: ``"xkcd:sky blue"``, … - """ - if isinstance(color, str): - # greyscale strings - try: - num_value = float(color) - if 0 <= num_value <= 1: - return False - except ValueError: - pass - - # short hex - if color.startswith("#") and len(color) not in [7, 9]: - return False - - # single-letter color shortcuts - if color in _MPL_SINGLE_LETTER_COLORS: - return False - - # CN cycle notation (C0, C1, …) - if len(color) >= 2 and color[0] == "C" and 
color[1:].isdigit(): - return False - - # tab: and xkcd: prefixed colors - if color.startswith(("tab:", "xkcd:")): - return False - - return bool(colors.is_color_like(color)) - - def _prepare_params_plot( # this param is inferred when `pl.show`` is called num_panels: int, @@ -418,30 +291,6 @@ def _get_cs_contents(sdata: sd.SpatialData) -> pd.DataFrame: return cs_contents -def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, float]: - # Extract the vertices from the PathPatch - path = pathpatch.get_path() - vertices = path.vertices - x = vertices[:, 0] - y = vertices[:, 1] - - area = 0.5 * np.sum(x[:-1] * y[1:] - x[1:] * y[:-1]) - - # Calculate the centroid coordinates - centroid_x = np.sum((x[:-1] + x[1:]) * (x[:-1] * y[1:] - x[1:] * y[:-1])) / (6 * area) - centroid_y = np.sum((y[:-1] + y[1:]) * (x[:-1] * y[1:] - x[1:] * y[:-1])) / (6 * area) - - return centroid_x, centroid_y - - -def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor: float) -> None: - scale_value = _extract_scalar_value(scale_factor, default=1.0) - centroid = _get_centroid_of_pathpatch(pathpatch) - vertices = pathpatch.get_path().vertices - scaled_vertices = np.array([centroid + (vertex - centroid) * scale_value for vertex in vertices]) - pathpatch.get_path().vertices = scaled_vertices - - def _join_table_for_element( sdata: sd.SpatialData, element: str, @@ -490,355 +339,6 @@ def _join_table_for_element( return element_dict[element], joined_table -def _make_continuous_mappable(vmin: float, vmax: float, cmap: Any) -> ScalarMappable: - """Build a ``ScalarMappable`` for a continuous colorbar, with a ±0.5 fallback when ``vmin == vmax``.""" - if vmin == vmax: - vmin, vmax = vmin - 0.5, vmax + 0.5 - return ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap=cmap) - - -def _apply_mask_to_outline_vectors( - outline_color_vector: Any, - outline_color_source_vector: pd.Series | None, - mask: Any, -) -> tuple[Any, pd.Series | None]: - """Apply a 
boolean ``keep`` mask to outline color vector(s). - - Used to keep outline data aligned with the fill data after a ``groups`` - or rasterize-based filter is applied to the rendered element. - """ - arr = np.asarray(mask) - if outline_color_source_vector is not None: - outline_color_source_vector = outline_color_source_vector[arr] - return outline_color_vector[arr], outline_color_source_vector - - -def _align_outline_vector_to_length( - outline_color_vector: Any, - outline_color_source_vector: pd.Series | None, - n: int, -) -> tuple[Any, pd.Series | None]: - """Pad or truncate the outline color vector(s) to length ``n``. - - Used when the outline column annotates a different row count than the rendered - element (cross-table case, or rasterize-induced label drop). Missing entries - are padded with NaN so downstream code maps them to ``na_color``. - """ - if outline_color_vector is None or len(outline_color_vector) == n: - return outline_color_vector, outline_color_source_vector - if len(outline_color_vector) > n: - if outline_color_source_vector is not None: - outline_color_source_vector = outline_color_source_vector[:n] - return outline_color_vector[:n], outline_color_source_vector - pad = n - len(outline_color_vector) - if outline_color_source_vector is not None: - # Categorical: downstream picks one hex per category from rows that *have* a - # category. NaN-padded rows contribute no category, so the per-row hex pad is - # immaterial; pad with NaN to skip the allocation. - padded_vec = np.concatenate([np.asarray(outline_color_vector), np.full(pad, np.nan, dtype=object)]) - outline_color_source_vector = pd.Categorical( - list(outline_color_source_vector) + [None] * pad, - categories=outline_color_source_vector.categories, - ) - else: - # Continuous: numeric vector, pad with NaN so cmap maps padded rows to na_color. 
- padded_vec = np.concatenate([np.asarray(outline_color_vector, dtype=float), np.full(pad, np.nan)]) - return padded_vec, outline_color_source_vector - - -def _color_vector_to_rgba( - color_vector: Any | None, - color_source_vector: pd.Series | None, - cmap_params: CmapParams, - n_rows: int, -) -> np.ndarray: - """Convert a fill/outline `color_vector` (categorical hex strings or continuous numerics) to (N, 4) RGBA. - - Mirrors the per-row mapping done inside :func:`_get_collection_shape` so that - callers can pre-materialize an outline-color array. NaN/non-finite entries are - painted with ``cmap_params.na_color``. - """ - na_rgba = colors.to_rgba(cmap_params.na_color.get_hex_with_alpha()) - if color_vector is None: - rgba = np.empty((n_rows, 4), dtype=float) - rgba[:] = na_rgba - return rgba - - if color_source_vector is not None: - # Categorical: color_vector contains hex strings aligned to color_source_vector - return np.asarray(ColorConverter().to_rgba_array(list(color_vector))) - - arr = np.asarray(color_vector) - if arr.ndim == 2 and arr.shape[1] in (3, 4) and np.issubdtype(arr.dtype, np.number): - return np.asarray(ColorConverter().to_rgba_array(arr)) - - rgba = np.empty((len(arr), 4), dtype=float) - rgba[:] = na_rgba - if np.issubdtype(arr.dtype, np.number): - finite_mask = np.isfinite(arr) - if finite_mask.any(): - norm = cmap_params.norm - if norm.vmin is None or norm.vmax is None: - vmin = float(np.nanmin(arr[finite_mask])) - vmax = float(np.nanmax(arr[finite_mask])) - if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax: - vmin, vmax = 0.0, 1.0 - used_norm = Normalize(vmin=vmin, vmax=vmax, clip=False) - else: - used_norm = norm - rgba[finite_mask] = cmap_params.cmap(used_norm(arr[finite_mask])) - return rgba - - # Object dtype: mix of numerics and color-like specs (apply cmap to the numeric subset only) - series = pd.Series(arr, copy=False) - num = pd.to_numeric(series, errors="coerce").to_numpy() - is_num = np.isfinite(num) - if 
is_num.any(): - norm = cmap_params.norm - if norm.vmin is None or norm.vmax is None: - vmin = float(np.nanmin(num[is_num])) - vmax = float(np.nanmax(num[is_num])) - if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax: - vmin, vmax = 0.0, 1.0 - used_norm = Normalize(vmin=vmin, vmax=vmax, clip=False) - else: - used_norm = norm - rgba[is_num] = cmap_params.cmap(used_norm(num[is_num])) - color_mask = (~is_num) & series.notna().to_numpy() - if color_mask.any(): - rgba[color_mask] = ColorConverter().to_rgba_array(series[color_mask].tolist()) - return rgba - - -def _normalize_geom(geom: Any) -> Any: - """Canonicalize ring orientation so matplotlib's fill rules render holes correctly. - - ``shapely.normalize`` (shapely>=2) is preferred; falls back to ``geom.normalize()``. - None/empty geometries and geometries that fail to normalize are returned unchanged. - """ - if geom is None or getattr(geom, "is_empty", False): - return geom - normalize_func = getattr(shapely, "normalize", None) - if callable(normalize_func): - try: - return normalize_func(geom) - except (GEOSException, TypeError, ValueError): - return geom - if hasattr(geom, "normalize"): - try: - return geom.normalize() - except (GEOSException, TypeError, ValueError): - return geom - return geom - - -def _build_shape_patches( - shapes: GeoDataFrame, - scale: float, -) -> tuple[list[mpatches.Patch], list[int], int]: - """Build matplotlib patches from shape geometries, once. - - Patch geometry is independent of colour/alpha, so it can be built a single time and - shared across the fill and outline ``PatchCollection``s in :func:`_render_shapes` - instead of being rebuilt per layer (the dominant cost for shape elements). - - Returns - ------- - patches - The matplotlib patches (a MultiPolygon expands to several patches). - patch_row_idx - For each patch, the index into the empty-filtered, re-indexed shapes — used to - look up the per-shape colour. 
def _norm_for_finite_values(norm: Any, finite_values: np.ndarray) -> Normalize:
    """Return ``norm`` if it is a ``Normalize``; otherwise build one spanning ``finite_values`` (0..1 fallback)."""
    if isinstance(norm, Normalize):
        return norm
    vmin, vmax = 0.0, 1.0
    if finite_values.size:
        lo = float(np.nanmin(finite_values))
        hi = float(np.nanmax(finite_values))
        # Keep the 0..1 fallback for non-finite or degenerate (lo == hi) ranges.
        if np.isfinite(lo) and np.isfinite(hi) and lo != hi:
            vmin, vmax = lo, hi
    return colors.Normalize(vmin=vmin, vmax=vmax, clip=False)


def _get_collection_shape(
    shapes: list[GeoDataFrame],
    c: Any,
    s: float,
    norm: Any,
    render_params: ShapesRenderParams,
    fill_alpha: None | float = None,
    outline_alpha: None | float = None,
    outline_color: None | str | list[float] | np.ndarray = "white",
    linewidth: float = 0.0,
    prebuilt_patches: tuple[list[mpatches.Patch], list[int], int] | None = None,
    **kwargs: Any,
) -> PatchCollection:
    """
    Build a PatchCollection for shapes with correct handling of.

    - continuous numeric vectors with NaNs,
    - per-row RGBA arrays,
    - a single color or a list of color specs.

    Only NaNs are painted with na_color; finite values are mapped via norm+cmap.

    Parameters
    ----------
    shapes
        The shapes to draw; ``len(shapes)`` is used as the number of rows.
    c
        Fill colour specification (see the cases above).
    s
        Scale factor, forwarded to ``_build_shape_patches`` when no prebuilt
        patches are supplied.
    norm
        Used as-is when it is a ``Normalize``; otherwise a norm is derived from
        the finite values of ``c``.
    render_params
        Supplies ``cmap_params.na_color`` for NaN / non-finite entries.
    fill_alpha
        If given, replaces the alpha of every fill colour whose alpha is > 0.
    outline_alpha, outline_color
        Outlines are only drawn when ``outline_alpha`` is > 0; it then overrides
        the alpha channel of ``outline_color``.
    linewidth
        Line width of the collection.
    prebuilt_patches
        ``(patches, patch_row_idx, n_shapes)`` as returned by
        ``_build_shape_patches``, to avoid rebuilding the geometry.
    **kwargs
        Must contain ``cmap``; everything is forwarded to ``PatchCollection``.

    .. note::
        When ``outline_color`` is passed as an ``(N, 4)`` RGBA array of dtype ``float``,
        its alpha channel is mutated in place to apply ``outline_alpha``. Pass a copy
        if you need to retain the original buffer.
    """
    cmap = kwargs["cmap"]

    # Resolve na color once
    na_rgba = colors.to_rgba(render_params.cmap_params.na_color.get_hex_with_alpha())

    # Try to interpret c as numpy array
    c_arr = np.asarray(c)
    fill_c: np.ndarray

    def _as_rgba_array(x: Any) -> np.ndarray:
        return np.asarray(ColorConverter().to_rgba_array(x))

    per_row = c_arr.ndim >= 1 and len(c_arr) == len(shapes)

    # Case A: per-row numeric colors given as Nx3 or Nx4 float array
    if per_row and c_arr.ndim == 2 and c_arr.shape[1] in (3, 4) and np.issubdtype(c_arr.dtype, np.number):
        fill_c = _as_rgba_array(c_arr)

    # Case B: continuous numeric vector len == n_shapes (possibly with NaNs)
    elif per_row and c_arr.ndim == 1 and np.issubdtype(c_arr.dtype, np.number):
        finite_mask = np.isfinite(c_arr)

        # Map finite values through cmap(norm(.)); NaNs get na_color
        fill_c = np.empty((len(c_arr), 4), dtype=float)
        fill_c[:] = na_rgba
        if finite_mask.any():
            used_norm = _norm_for_finite_values(norm, c_arr[finite_mask])
            fill_c[finite_mask] = cmap(used_norm(c_arr[finite_mask]))

    # Object vector: a mix of numerics and explicit color specs
    elif per_row and c_arr.ndim == 1 and c_arr.dtype == object:
        c_series = pd.Series(c_arr, copy=False)
        num = pd.to_numeric(c_series, errors="coerce").to_numpy()
        is_num = np.isfinite(num)

        # init with na color
        fill_c = np.empty((len(c_series), 4), dtype=float)
        fill_c[:] = na_rgba

        # numeric entries via cmap(norm)
        if is_num.any():
            used_norm = _norm_for_finite_values(norm, num[is_num])
            fill_c[is_num] = cmap(used_norm(num[is_num]))

        # non-numeric, non-NaN entries as explicit colors
        non_numeric_color_mask = (~is_num) & c_series.notna().to_numpy()
        if non_numeric_color_mask.any():
            fill_c[non_numeric_color_mask] = ColorConverter().to_rgba_array(c_series[non_numeric_color_mask].tolist())

    # Case C: single color or list of color-like specs (strings or tuples)
    else:
        fill_c = _as_rgba_array(c)

    # Apply optional fill alpha without destroying existing transparency
    if fill_alpha is not None:
        nonzero_alpha = fill_c[..., -1] > 0
        fill_c[nonzero_alpha, -1] = fill_alpha

    # Outline handling
    outline_rgba: np.ndarray | None = None
    if outline_alpha and outline_alpha > 0.0:
        outline_arr = np.asarray(outline_color) if not isinstance(outline_color, str) else None
        if outline_arr is not None and outline_arr.ndim == 2 and outline_arr.shape == (len(shapes), 4):
            # Per-shape RGBA array. Mutate in place when already float so we don't allocate twice
            # on the hot path; otherwise upcast to a fresh float buffer.
            outline_rgba = outline_arr if outline_arr.dtype == float else outline_arr.astype(float)
        else:
            outline_rgba = _as_rgba_array(outline_color)
        outline_rgba[..., -1] = outline_alpha

    # Build (or reuse) the matplotlib patches. Geometry is colour-independent, so the
    # caller can build it once via `_build_shape_patches` and share it across the fill
    # and outline collections instead of rebuilding it on every call.
    patches, patch_row_idx, n_shapes = (
        prebuilt_patches if prebuilt_patches is not None else _build_shape_patches(shapes, s)
    )

    if not patches:
        return PatchCollection([])

    # Expand the per-shape fill colours to per-patch (a MultiPolygon owns several
    # patches). Preserve the single-colour broadcast used for multi-shape elements.
    # NOTE(review): `patch_row_idx` indexes the empty-filtered shapes while `fill_c` has one
    # row per input shape — presumably callers drop empty geometries beforehand; confirm.
    broadcast_single = n_shapes > 1 and len(fill_c) == 1
    patch_fill = np.repeat(fill_c, len(patches), axis=0) if broadcast_single else fill_c[patch_row_idx]

    edgecolor: Any = None
    if outline_rgba is not None and len(outline_rgba) > 0:
        # Per-shape outline colours need the same per-patch expansion as the fill;
        # otherwise the colour list is shorter than the patch list whenever a
        # MultiPolygon contributes several patches, and colours land on the wrong shapes.
        per_shape_outline = len(outline_rgba) > 1 and len(outline_rgba) == n_shapes
        patch_outline = outline_rgba[patch_row_idx] if per_shape_outline else outline_rgba
        edgecolor = patch_outline.tolist()

    return PatchCollection(
        patches,
        snap=False,
        lw=linewidth,
        facecolor=patch_fill,
        edgecolor=edgecolor,
        **kwargs,
    )
- - norm = Normalize(vmin=None, vmax=None, clip=False) if norm is None else copy(norm) - - cmap.set_bad(na_color.get_hex_with_alpha()) - - return CmapParams( - cmap=cmap, - norm=norm, - na_color=na_color, - cmap_is_default=cmap_is_default, - ) - - -def _set_outline( - outline_alpha: float | int | tuple[float | int, float | int] | None, - outline_width: int | float | tuple[float | int, float | int] | None, - outline_color: Color | tuple[Color, Color | None] | None, - **kwargs: Any, -) -> tuple[tuple[float, float], OutlineParams]: - """Create OutlineParams object for shapes, including possibility of double outline. - - Rules for outline rendering: - 1) outline_alpha always takes precedence if given by the user. - In absence of outline_alpha: - 2) If outline_color is specified and implying an alpha (e.g. RGBA array or #RRGGBBAA): that alpha is used - 3) If outline_color (w/o implying an alpha) and/or outline_width is specified: alpha of outlines set to 1.0 - """ - # A) User doesn't want to see outlines - if ( - outline_alpha == 0.0 - or (isinstance(outline_alpha, tuple) and np.all(np.array(outline_alpha) == 0.0)) - or not (outline_alpha or outline_width or outline_color) - ): - return (0.0, 0.0), OutlineParams(None, 1.5, None, 0.5) - - # B) User wants to see at least 1 outline - if isinstance(outline_width, tuple): - if len(outline_width) != 2: - raise ValueError( - f"Tuple of length {len(outline_width)} was passed for outline_width. When specifying multiple outlines," - " please pass a tuple of exactly length 2." - ) - if not outline_color: - outline_color = (Color("#000000"), Color("#ffffff")) - elif not isinstance(outline_color, tuple): - raise ValueError( - "No tuple was passed for outline_color, while two outlines were specified by using the outline_width " - "argument. Please specify the outline colors in a tuple of length two." 
- ) - - if isinstance(outline_color, tuple): - if len(outline_color) != 2: - raise ValueError( - f"Tuple of length {len(outline_color)} was passed for outline_color. When specifying multiple outlines," - " please pass a tuple of exactly length 2." - ) - if not outline_width: - outline_width = (1.5, 0.5) - elif not isinstance(outline_width, tuple): - raise ValueError( - "No tuple was passed for outline_width, while two outlines were specified by using the outline_color " - "argument. Please specify the outline widths in a tuple of length two." - ) - - if isinstance(outline_width, float | int): - outline_width = (outline_width, 0.0) - elif not outline_width: - outline_width = (1.5, 0.0) - if isinstance(outline_color, Color): - outline_color = (outline_color, None) - elif not outline_color: - outline_color = (Color("#000000ff"), None) - - assert isinstance(outline_color, tuple), "outline_color is not a tuple" # shut up mypy - assert isinstance(outline_width, tuple), "outline_width is not a tuple" - - for ow in outline_width: - if not isinstance(ow, int | float): - raise TypeError(f"Invalid type of `outline_width`: {type(ow)}, expected `int` or `float`.") - - if outline_alpha: - if isinstance(outline_alpha, int | float): - # for a single outline: second width value is 0.0 - outline_alpha = (outline_alpha, 0.0) if outline_width[1] == 0.0 else (outline_alpha, outline_alpha) - else: - # if alpha wasn't explicitly specified by the user - outer_ol_alpha = outline_color[0].get_alpha_as_float() if isinstance(outline_color[0], Color) else 1.0 - inner_ol_alpha = outline_color[1].get_alpha_as_float() if isinstance(outline_color[1], Color) else 1.0 - outline_alpha = (outer_ol_alpha, inner_ol_alpha) - - # handle possible linewidths of 0.0 => outline won't be rendered in the first place - if outline_width[0] == 0.0: - outline_alpha = (0.0, outline_alpha[1]) - if outline_width[1] == 0.0: - outline_alpha = (outline_alpha[0], 0.0) - - if outline_alpha[0] > 0.0 or outline_alpha[1] > 
0.0: - kwargs.pop("edgecolor", None) # remove edge from kwargs if present - kwargs.pop("alpha", None) # remove alpha from kwargs if present - - return outline_alpha, OutlineParams( - outline_color[0], - outline_width[0], - outline_color[1], - outline_width[1], - ) +def _build_alignment_dtype_hint( + sdata: sd.SpatialData | None, + element: object, + color_series: pd.Series, + table_name: str | None, +) -> str: + """Build a diagnostic hint string for dtype mismatches between element and table indices.""" + el_dtype = getattr(getattr(element, "index", None), "dtype", None) + if el_dtype is None or table_name is None or sdata is None or table_name not in sdata.tables: + return "" + try: + _, _, instance_key = get_table_keys(sdata.tables[table_name]) + except (KeyError, ValueError): + return "" + tbl_dtype = sdata.tables[table_name].obs[instance_key].dtype + if el_dtype != tbl_dtype: + return f" (hint: element index dtype is {el_dtype}, '{instance_key}' dtype is {tbl_dtype})" + return "" -def _get_colors_for_categorical_obs( - categories: Sequence[str | int], +def _decorate_axs( + ax: Axes, + cax: PatchCollection, + fig_params: FigParams, + value_to_plot: str | None, + color_source_vector: pd.Series[CategoricalDtype] | Categorical, + color_vector: pd.Series[CategoricalDtype] | Categorical, + adata: AnnData | None = None, palette: ListedColormap | str | list[str] | None = None, alpha: float = 1.0, - cmap_params: CmapParams | None = None, -) -> list[str]: - """ - Return a list of colors for a categorical observation. - - Parameters - ---------- - adata - AnnData object - value_to_plot - Name of a valid categorical observation - categories - categories of the categorical observation. 
+ na_color: Color = Color("default"), + legend_fontsize: int | float | _FontSize | None = None, + legend_fontweight: int | _FontWeight = "bold", + legend_loc: str | None = "right margin", + legend_fontoutline: int | None = None, + na_in_legend: bool = True, + colorbar: bool = True, + colorbar_params: dict[str, object] | None = None, + colorbar_requests: list[ColorbarSpec] | None = None, + colorbar_label: str | None = None, + legend_title: str | None = None, +) -> Axes: + if value_to_plot is not None: + # if only dots were plotted without an associated value + # there is not need to plot a legend or a colorbar - Returns - ------- - None - """ - len_cat = len(categories) - - # check if default matplotlib palette has enough colors - if palette is None: - if cmap_params is not None and not cmap_params.cmap_is_default: - palette = cmap_params.cmap - elif len(rcParams["axes.prop_cycle"].by_key()["color"]) >= len_cat: - cc = rcParams["axes.prop_cycle"]() - palette = [next(cc)["color"] for _ in range(len_cat)] - elif len_cat <= 20: - palette = default_20 - elif len_cat <= 28: - palette = default_28 - elif len_cat <= len(default_102): # 103 colors - palette = default_102 + if legend_fontoutline is not None: + path_effect = [patheffects.withStroke(linewidth=legend_fontoutline, foreground="w")] else: - palette = ["grey" for _ in range(len_cat)] - logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") - else: - # raise error when user didn't provide the right number of colors in palette - if isinstance(palette, list) and len(palette) != len(categories): - raise ValueError( - f"The number of provided values in the palette ({len(palette)}) doesn't agree with the number of " - f"categories that should be colored ({categories})." 
- ) + path_effect = [] - # otherwise, single channels turn out grey - color_idx = np.linspace(0, 1, len_cat) if len_cat > 1 else [0.7] - - if isinstance(palette, str): - palette = [to_hex(palette)] - elif isinstance(palette, list): - palette = [to_hex(x) for x in palette] - elif isinstance(palette, ListedColormap): - palette = [to_hex(x) for x in palette(color_idx, alpha=alpha)] - elif isinstance(palette, LinearSegmentedColormap): - palette = [to_hex(palette(x, alpha=alpha)) for x in color_idx] # type: ignore[attr-defined] - else: - raise TypeError(f"Palette is {type(palette)} but should be string or list.") + # Adding legends + if color_source_vector is not None and isinstance(color_source_vector.dtype, pd.CategoricalDtype): + # order of clusters should agree to palette order + clusters = color_source_vector.remove_unused_categories().unique() + clusters = clusters[~clusters.isnull()] + # derive mapping from color_source_vector and color_vector + group_to_color_matching = pd.DataFrame( + { + "cats": color_source_vector.remove_unused_categories(), + "color": color_vector, + } + ) + color_mapping = group_to_color_matching.drop_duplicates("cats").set_index("cats")["color"].to_dict() + _add_categorical_legend( + ax, + pd.Categorical(values=color_source_vector, categories=clusters), + palette=color_mapping, + legend_loc=legend_loc, + legend_fontweight=legend_fontweight, + legend_fontsize=legend_fontsize, + legend_fontoutline=path_effect, + na_color=[na_color.get_hex()], + na_in_legend=na_in_legend, + multi_panel=fig_params.axs is not None, + ) + # scanpy's helper doesn't accept a title; set it post-hoc so the user can + # disambiguate fill vs outline when both legends are drawn. 
+ if legend_title is not None and (legend := ax.get_legend()) is not None: + legend.set_title(legend_title) + elif colorbar and colorbar_requests is not None and cax is not None: + colorbar_requests.append( + ColorbarSpec( + ax=ax, + mappable=cax, + params=colorbar_params, + label=colorbar_label, + alpha=alpha, + ) + ) - return palette[:len_cat] # type: ignore[return-value] + return ax def _format_element_names(element_name: list[str] | str | None) -> str: @@ -1169,946 +575,54 @@ def _validate_table_instance_uniqueness( ) -def _infer_color_data_kind( - series: pd.Series, - value_to_plot: str, - element_name: list[str] | str | None, - table_name: str | None, - warn_on_object_to_categorical: bool = False, -) -> tuple[Literal["numeric", "categorical"], pd.Series | pd.Categorical]: - element_label = _format_element_name(element_name) - - if isinstance(series.dtype, pd.CategoricalDtype): - return "categorical", pd.Categorical(series) - - if is_bool_dtype(series.dtype): - return "numeric", series.astype(float) - - if is_numeric_dtype(series.dtype): - return "numeric", pd.to_numeric(series, errors="coerce") - - if is_string_dtype(series.dtype) or series.dtype == object: - non_na = series[~pd.isna(series)] - if len(non_na) == 0: - return "numeric", pd.to_numeric(series, errors="coerce") - - numeric_like = pd.to_numeric(non_na, errors="coerce") - has_numeric = numeric_like.notna().any() - has_non_numeric = numeric_like.isna().any() - - if has_numeric and has_non_numeric: - invalid_examples = non_na[numeric_like.isna()].astype(str).unique()[:3] - location = f" in table '{table_name}'" if table_name is not None else "" - raise TypeError( - f"Column '{value_to_plot}' for element '{element_label}'{location} contains both numeric and " - f"non-numeric values (e.g. {', '.join(invalid_examples)}). " - "Please ensure that the column stores consistent data." 
- ) +def _get_list( + var: Any, + _type: type[Any] | tuple[type[Any], ...], + ref_len: int | None = None, + name: str | None = None, +) -> list[Any]: + """ + Get a list from a variable. - if has_numeric: - return "numeric", pd.to_numeric(series, errors="coerce") + Parameters + ---------- + var + Variable to convert to a list. + _type + Type of the elements in the list. + ref_len + Reference length of the list. + name + Name of the variable. - if warn_on_object_to_categorical: - logger.warning( - f"Converting copy of '{value_to_plot}' column to categorical dtype for categorical plotting. " - "Consider converting before plotting." + Returns + ------- + List + """ + if isinstance(var, _type): + return [var] if ref_len is None else ([var] * ref_len) + if isinstance(var, list): + if ref_len is not None and ref_len != len(var): + raise ValueError( + f"Variable: `{name}` has length: {len(var)}, which is not equal to reference length: {ref_len}." ) + for v in var: + if not isinstance(v, _type): + raise ValueError(f"Variable: `{name}` has invalid type: {type(v)}, expected: {_type}.") + return var - return "categorical", pd.Categorical(series) + raise ValueError(f"Can't make a list from variable: `{var}`") - return "numeric", pd.to_numeric(series, errors="coerce") - -def _build_alignment_dtype_hint( - sdata: sd.SpatialData | None, - element: object, - color_series: pd.Series, - table_name: str | None, -) -> str: - """Build a diagnostic hint string for dtype mismatches between element and table indices.""" - el_dtype = getattr(getattr(element, "index", None), "dtype", None) - if el_dtype is None or table_name is None or sdata is None or table_name not in sdata.tables: - return "" - try: - _, _, instance_key = get_table_keys(sdata.tables[table_name]) - except (KeyError, ValueError): - return "" - tbl_dtype = sdata.tables[table_name].obs[instance_key].dtype - if el_dtype != tbl_dtype: - return f" (hint: element index dtype is {el_dtype}, '{instance_key}' dtype is {tbl_dtype})" 
- return "" - - -def _extract_color_column( - table: AnnData, - value_key: str, - *, - origin: str, - element: GeoDataFrame, - element_name: str, - table_layer: str | None = None, -) -> pd.Series: - """Read one color column from ``table`` aligned to ``element`` order, without copying the table. - - Equivalent to ``get_values(value_key, sdata=..., element_name=..., table_name=...)[value_key]`` but - skips the table->element join, whose ``table[indices, :].copy()`` does an expensive out-of-order - sparse CSR row-gather. Restricts to rows annotating ``element_name`` (via ``region_key``), then - reindexes to the element's instance order (``NaN`` for instances with no table row), preserving the - categorical dtype of ``obs`` columns so the downstream legend path is unchanged. - """ - attrs = table.uns["spatialdata_attrs"] - region_key, instance_key = attrs["region_key"], attrs["instance_key"] - mask = table.obs[region_key].to_numpy() == element_name - inst = table.obs[instance_key].to_numpy()[mask] - if origin == "var": - source = table.layers[table_layer] if table_layer is not None else table.X - col = source[:, table.var_names.get_loc(value_key)] - col = np.asarray(col.todense()).ravel() if hasattr(col, "todense") else np.asarray(col).ravel() - values = pd.Series(col[mask], index=inst) - else: # obs column; .values keeps a Categorical categorical so the legend path still sees one - values = pd.Series(table.obs[value_key].values[mask], index=inst) - return values.reindex(element.index) - - -def _set_color_source_vec( - sdata: sd.SpatialData, - element: SpatialElement | None, - value_to_plot: str | None, - na_color: Color, - element_name: list[str] | str | None = None, - groups: list[str] | str | None = None, - palette: dict[str, str] | list[str] | str | None = None, - cmap_params: CmapParams | None = None, - alpha: float = 1.0, - table_name: str | None = None, - table_layer: str | None = None, - render_type: Literal["points", "labels"] | None = None, - 
coordinate_system: str | None = None, - preloaded_color_data: pd.Series | None = None, -) -> tuple[ArrayLike | pd.Series | None, ArrayLike, bool]: - if value_to_plot is None and element is not None: - color = np.full(len(element), na_color.get_hex_with_alpha()) - return color, color, False - - # Figure out where to get the color from - origins = _locate_value( - value_key=value_to_plot, - sdata=sdata, - element_name=element_name, - table_name=table_name, - ) - - # When both the element's own dataframe and the chosen table contain a - # column with this name, an explicit `table_name=` resolves the ambiguity — - # keep only the table origin and skip the multi-origin error below. - explicit_table_shadows_df = table_name is not None and any(o.origin == "df" for o in origins) - if explicit_table_shadows_df: - origins = [o for o in origins if o.origin != "df"] - - if len(origins) > 1: - raise ValueError( - f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. " - "Please keep it in exactly one place (preferably on the points parquet for speed) to avoid ambiguity." - ) - - if len(origins) == 1 and value_to_plot is not None: - if table_name is not None: - _ensure_one_to_one_mapping( - sdata=sdata, - element=element, - element_name=element_name, - table_name=table_name, - ) - if preloaded_color_data is not None: - color_source_vector = preloaded_color_data - elif ( - isinstance(element, GeoDataFrame) - and isinstance(element_name, str) - and table_name is not None - and table_name in sdata.tables - and origins[0].origin in ("obs", "var") - ): - # Fast path: read the single aligned column directly instead of joining/copying the - # whole annotating table (the join's out-of-order sparse row-gather dominates large renders). 
- color_source_vector = _extract_color_column( - sdata[table_name], - value_to_plot, - origin=origins[0].origin, - element=element, - element_name=element_name, - table_layer=table_layer, - ) - elif explicit_table_shadows_df: - # Pass the table as `element` so upstream `get_values` skips the - # element-column lookup and avoids the multi-origin error. - color_source_vector = get_values( - value_key=value_to_plot, - element=sdata[table_name], - element_name=element_name, - table_layer=table_layer, - )[value_to_plot] - else: - color_source_vector = get_values( - value_key=value_to_plot, - sdata=sdata, - element_name=element_name, - table_name=table_name, - table_layer=table_layer, - )[value_to_plot] - - color_series = ( - color_source_vector if isinstance(color_source_vector, pd.Series) else pd.Series(color_source_vector) - ) - - if color_series.isna().all(): - element_label = _format_element_name(element_name) - dtype_hint = _build_alignment_dtype_hint(sdata, element, color_series, table_name) - hint_suffix = f" {dtype_hint.strip()}" if dtype_hint else "" - logger.warning( - f"Column '{value_to_plot}' for element '{element_label}' contains only NaN values; " - f"rendering with na_color.{hint_suffix}" - ) - na_color_arr = np.full(len(color_series), na_color.get_hex_with_alpha()) - return na_color_arr, na_color_arr, False - - kind, processed = _infer_color_data_kind( - series=color_series, - value_to_plot=value_to_plot, - element_name=element_name, - table_name=table_name, - warn_on_object_to_categorical=table_name is not None, - ) - - if kind == "numeric": - numeric_vector = processed - if ( - not isinstance(element, GeoDataFrame) - and isinstance(palette, list) - and palette[0] is not None - or isinstance(element, GeoDataFrame) - and isinstance(palette, list) - ): - logger.warning( - "Ignoring categorical palette which is given for a continuous variable. " - "Consider using `cmap` to pass a ColorMap." 
- ) - return None, numeric_vector, False - - assert isinstance(processed, pd.Categorical) - if not processed.ordered: - # ensure deterministic category order when the source is unordered (e.g., from a Python set) - processed = processed.reorder_categories(sorted(processed.categories)) - color_source_vector = processed # convert, e.g., `pd.Series` - - # When the value lives on the element's own DataFrame (origin="df"), - # there is no reason to look up a table for .uns colors. - value_from_element = origins[0].origin == "df" - - # Use the provided table_name parameter, fall back to only one present - table_to_use: str | None - if value_from_element: - table_to_use = None - elif table_name is not None and table_name in sdata.tables: - table_to_use = table_name - elif table_name is not None and table_name not in sdata.tables: - logger.warning(f"Table '{table_name}' not found in `sdata.tables`. Falling back to default behavior.") - table_to_use = None - else: - table_keys = list(sdata.tables.keys()) - if len(table_keys) == 1: - table_to_use = table_keys[0] - elif len(table_keys) > 1: - table_to_use = table_keys[0] - logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.") - else: - table_to_use = None - - adata_for_mapping = sdata[table_to_use] if table_to_use is not None else None - - # Check if custom colors exist in the resolved table's .uns slot - if ( - value_to_plot is not None - and table_to_use is not None - and _has_colors_in_uns(sdata, table_to_use, value_to_plot) - ): - # Extract colors directly from the table's .uns slot - # Convert Color to ColorLike (str) for the function - na_color_like: ColorLike = na_color.get_hex() if isinstance(na_color, Color) else na_color - color_mapping = _extract_colors_from_table_uns( - sdata=sdata, - table_name=table_to_use, - col_to_colorby=value_to_plot, - color_source_vector=color_source_vector, - na_color=na_color_like, - ) - if color_mapping is not None: - if isinstance(palette, 
str): - palette = [palette] - color_mapping = _modify_categorical_color_mapping( - mapping=color_mapping, - groups=groups, - palette=palette, - ) - else: - logger.warning(f"Failed to extract colors for '{value_to_plot}', falling back to default mapping.") - # Fall back to the existing method if extraction fails - color_mapping = _get_categorical_color_mapping( - adata=sdata[table_to_use], - cluster_key=value_to_plot, - color_source_vector=color_source_vector, - cmap_params=cmap_params, - alpha=alpha, - groups=groups, - palette=palette, - na_color=na_color, - render_type=render_type, - ) - else: - color_mapping = None - - if color_mapping is None: - # Use the existing color mapping method - color_mapping = _get_categorical_color_mapping( - adata=adata_for_mapping, - cluster_key=value_to_plot, - color_source_vector=color_source_vector, - cmap_params=cmap_params, - alpha=alpha, - groups=groups, - palette=palette, - na_color=na_color, - render_type=render_type, - ) - - color_source_vector = color_source_vector.set_categories(color_mapping.keys()) - if color_mapping is None: - raise ValueError("Unable to create color palette.") - - # do not rename categories, as colors need not be unique - # pd.Categorical.map() demotes to object dtype when mapped values aren't unique - # (e.g. two categories share a color). Wrapping back in pd.Categorical ensures - # downstream consumers always receive a Categorical for categorical data. 
- color_vector = pd.Categorical(color_source_vector.map(color_mapping, na_action="ignore")) - # nan handling: only add the NA category if needed, and store it as a hex string - na_color_hex = na_color.get_hex_with_alpha() if isinstance(na_color, Color) else str(na_color) - if color_vector.isna().any(): - if na_color_hex not in color_vector.categories: - color_vector = color_vector.add_categories(na_color_hex) - color_vector[pd.isna(color_vector)] = na_color_hex - - return color_source_vector, color_vector, True - - if table_name is None: - raise KeyError( - f"Unable to locate color key '{value_to_plot}' for element '{element_name}'. " - "Please ensure the key exists in a table annotating this element." - ) - raise KeyError( - f"Unable to locate color key '{value_to_plot}' in table '{table_name}' for element '{element_name}'." - ) - - -def _map_color_seg( - seg: ArrayLike, - cell_id: ArrayLike, - color_vector: ArrayLike | pd.Series[CategoricalDtype], - color_source_vector: pd.Series[CategoricalDtype], - cmap_params: CmapParams, - na_color: Color, - seg_erosionpx: int | None = None, - seg_boundaries: bool = False, - outline_color: Color | None = None, - outline_color_vector: ArrayLike | pd.Series[CategoricalDtype] | None = None, - outline_color_source_vector: pd.Series[CategoricalDtype] | None = None, -) -> ArrayLike: - cell_id = np.array(cell_id) - - if isinstance(color_vector.dtype, pd.CategoricalDtype): - # Case A: users wants to plot a categorical column - val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1) - cols = colors.to_rgba_array(color_vector.categories) - elif pd.api.types.is_numeric_dtype(color_vector.dtype): - # Case B: user wants to plot a continous column - if isinstance(color_vector, pd.Series): - color_vector = color_vector.to_numpy() - # normalize only the not nan values, else the whole array would contain only nan values - normed_color_vector = color_vector.copy().astype(float) - normed_color_vector[~np.isnan(normed_color_vector)] 
= cmap_params.norm( - normed_color_vector[~np.isnan(normed_color_vector)] - ) - cols = cmap_params.cmap(normed_color_vector) - val_im = map_array(seg, cell_id, cell_id) - else: - # Case C: User didn't specify any colors - if color_source_vector is not None and ( - set(color_vector) == set(color_source_vector) - and len(set(color_vector)) == 1 - and set(color_vector) == {na_color.get_hex_with_alpha()} - and not na_color.color_modified_by_user() - ): - val_im = map_array(seg, cell_id, cell_id) - RNG = default_rng(42) - cols = RNG.random((len(color_vector), 3)) - else: - # Case D: User didn't specify a column to color by, but modified the na_color - val_im = map_array(seg, cell_id, cell_id) - first_value = color_vector.iloc[0] if isinstance(color_vector, pd.Series) else color_vector[0] - if _is_color_like(first_value): - # we have color-like values (e.g., hex or named colors) - assert all(_is_color_like(c) for c in color_vector), "Not all values are color-like." - cols = colors.to_rgba_array(color_vector) - else: - cols = cmap_params.cmap(cmap_params.norm(color_vector)) - - if seg_erosionpx is not None: - val_im[val_im == erosion(val_im, footprint_rectangle((seg_erosionpx, seg_erosionpx)))] = 0 - - if seg_boundaries and outline_color_vector is not None: - # Column-driven outline: build per-label colors from the outline vector and overlay - # on the eroded ring. Two cases (mirroring _set_color_source_vec's return contract): - # - categorical: outline_color_source_vector is the source Categorical; outline_color_vector - # holds hex strings aligned to cells. - # - continuous: outline_color_source_vector is None; outline_color_vector is numeric. 
- if outline_color_source_vector is not None: - cat = pd.Categorical(outline_color_source_vector) - cat_codes = cat.codes - outline_val_im: ArrayLike = map_array(seg, cell_id, cat_codes + 1) - color_arr = np.asarray(outline_color_vector, dtype=object) - # Pick the first per-cell hex for each category in one vectorized pass - # (avoids `K × O(N)` Python loops on large label sets). - cat_colors: list[Any] = [na_color.get_hex_with_alpha()] * len(cat.categories) - unique_codes, first_indices = np.unique(cat_codes, return_index=True) - for code, idx in zip(unique_codes, first_indices, strict=True): - if code >= 0: - cat_colors[code] = color_arr[idx] - outline_cols = colors.to_rgba_array(cat_colors) - else: - # Continuous: numeric values normalized via cmap - ov = ( - outline_color_vector.to_numpy() - if isinstance(outline_color_vector, pd.Series) - else np.asarray(outline_color_vector) - ) - normed = ov.copy().astype(float) - finite = ~np.isnan(normed) - if finite.any(): - normed[finite] = cmap_params.norm(normed[finite]) - outline_cols = cmap_params.cmap(normed) - outline_val_im = map_array(seg, cell_id, cell_id) - if seg_erosionpx is not None: - outline_val_im[ - outline_val_im == erosion(outline_val_im, footprint_rectangle((seg_erosionpx, seg_erosionpx))) - ] = 0 - outline_seg_im = label2rgb( - label=outline_val_im, - colors=outline_cols, - bg_label=0, - bg_color=(1, 1, 1), - image_alpha=0, - ) - outline_mask = val_im > 0 - alpha_channel = outline_mask.astype(float) - return np.dstack((outline_seg_im, alpha_channel)) - - if seg_boundaries and outline_color is not None: - # Uniform outline color requested: skip label2rgb, build RGBA directly - outline_rgba = colors.to_rgba(outline_color.get_hex_with_alpha()) - outline_mask = val_im > 0 - rgba = np.zeros((*val_im.shape, 4), dtype=float) - rgba[outline_mask, :3] = outline_rgba[:3] - rgba[outline_mask, 3] = outline_rgba[3] - return rgba - - seg_im: ArrayLike = label2rgb( - label=val_im, - colors=cols, - bg_label=0, - 
bg_color=(1, 1, 1), # transparency doesn't really work - image_alpha=0, - ) - - if seg_boundaries: - # Data-driven outline: use seg_im colors on the eroded ring, transparent elsewhere - outline_mask = val_im > 0 - alpha_channel = outline_mask.astype(float) - return np.dstack((seg_im, alpha_channel)) - - if len(val_im.shape) != len(seg_im.shape): - val_im = np.expand_dims((val_im > 0).astype(int), axis=-1) - return np.dstack((seg_im, val_im)) - - -def _generate_base_categorial_color_mapping( - adata: AnnData | None, - cluster_key: str, - color_source_vector: ArrayLike | pd.Series[CategoricalDtype], - na_color: Color, - cmap_params: CmapParams | None = None, -) -> Mapping[str, str]: - if adata is not None and cluster_key in adata.uns and f"{cluster_key}_colors" in adata.uns: - all_colors = adata.uns[f"{cluster_key}_colors"] - - # When plotting per-coordinate-system, the color_source_vector may carry - # categories from other coordinate systems that aren't present in the - # current subset. Drop them so that categories and colors stay aligned. - color_source_vector = color_source_vector.remove_unused_categories() - - # The stored colors in .uns correspond 1-to-1 to the *full* set of - # categories in adata.obs[cluster_key]. Subset to the categories that - # are still present after removing unused ones. 
- if cluster_key in adata.obs and hasattr(adata.obs[cluster_key], "cat"): - all_cats = adata.obs[cluster_key].cat.categories.tolist() - keep_idx = [i for i, c in enumerate(all_cats) if c in color_source_vector.categories] - colors = [to_hex(to_rgba(all_colors[i])[:3]) for i in keep_idx] - else: - colors = [to_hex(to_rgba(c)[:3]) for c in all_colors] - - categories = color_source_vector.categories.tolist() + ["NaN"] - - if len(categories) > len(colors): - return dict(zip(categories, colors + [na_color.get_hex_with_alpha()], strict=True)) - - return dict(zip(categories, colors, strict=True)) - - return _get_default_categorial_color_mapping(color_source_vector=color_source_vector, cmap_params=cmap_params) - - -def _has_colors_in_uns( - sdata: sd.SpatialData, - table_name: str | None, - col_to_colorby: str, -) -> bool: - """ - Check if _colors exists in the specified table's .uns slot. - - Parameters - ---------- - sdata - SpatialData object containing tables - table_name - Name of the table to check. If None, uses the first available table. - col_to_colorby - Name of the categorical column (e.g., "celltype") - - Returns - ------- - True if _colors exists in the table's .uns, False otherwise - """ - color_key = f"{col_to_colorby}_colors" - - # Determine which table to use - if table_name is not None: - if table_name not in sdata.tables: - return False - table_to_use = table_name - else: - if len(sdata.tables.keys()) == 0: - return False - # When no table is specified, check all tables for the color key - return any(color_key in adata.uns for adata in sdata.tables.values()) - - adata = sdata.tables[table_to_use] - return color_key in adata.uns - - -def _extract_colors_from_table_uns( - sdata: sd.SpatialData, - table_name: str | None, - col_to_colorby: str, - color_source_vector: ArrayLike | pd.Series[CategoricalDtype], - na_color: ColorLike, -) -> Mapping[str, str] | None: - """ - Extract categorical colors from the _colors pattern in adata.uns. 
- - This function looks for colors stored in the format _colors in the - specified table's .uns slot and creates a mapping from categories to colors. - - Parameters - ---------- - sdata - SpatialData object containing tables - table_name - Name of the table to look in. If None, uses the first available table. - col_to_colorby - Name of the categorical column (e.g., "celltype") - color_source_vector - Categorical vector containing the categories to map - na_color - Color to use for NaN/missing values - - Returns - ------- - Mapping from category names to hex colors, or None if colors not found - """ - color_key = f"{col_to_colorby}_colors" - - # Determine which table to use - if table_name is not None: - if table_name not in sdata.tables: - logger.warning(f"Table '{table_name}' not found in sdata. Available tables: {list(sdata.tables.keys())}") - return None - table_to_use = table_name - else: - if len(sdata.tables) == 0: - logger.warning("No tables found in sdata.") - return None - # No explicit table provided: search all tables for the color key - candidate_tables: list[str] = [ - name - for name, ad in sdata.tables.items() - if color_key in ad.uns # type: ignore[union-attr] - ] - if not candidate_tables: - logger.debug(f"Color key '{color_key}' not found in any table uns.") - return None - table_to_use = candidate_tables[0] - if len(candidate_tables) > 1: - logger.warning( - f"Color key '{color_key}' found in multiple tables {candidate_tables}; using table '{table_to_use}'." - ) - logger.info(f"No table name provided, using '{table_to_use}' for color extraction.") - - adata = sdata.tables[table_to_use] - - # Check if the color pattern exists - if color_key not in adata.uns: - logger.debug(f"Color key '{color_key}' not found in table '{table_to_use}' uns.") - return None - - # Extract colors and categories - stored_colors = adata.uns[color_key] - # Drop categories not present in the current subset (e.g. 
when plotting - # per-coordinate-system) so that positional color lookups stay aligned. - color_source_vector = color_source_vector.remove_unused_categories() - categories = color_source_vector.categories.tolist() - - # Validate na_color format and convert to hex string - if isinstance(na_color, Color): - na_color_hex = na_color.get_hex() - else: - na_color_str = str(na_color) - if "#" not in na_color_str: - logger.warning("Expected `na_color` to be a hex color, converting...") - na_color_hex = to_hex(to_rgba(na_color)[:3]) - else: - na_color_hex = na_color_str - - # Strip alpha channel from na_color if present - if len(na_color_hex) == 9: # #rrggbbaa format - na_color_hex = na_color_hex[:7] # Keep only #rrggbb - - def _to_hex_no_alpha(color_value: Any) -> str | None: - try: - rgba = to_rgba(color_value)[:3] - hex_color: str = to_hex(rgba) - if len(hex_color) == 9: - hex_color = hex_color[:7] - return hex_color - except (TypeError, ValueError) as e: - logger.warning(f"Error converting color '{color_value}' to hex format: {e}") - return None - - color_mapping: dict[str, str] = {} - - if isinstance(stored_colors, Mapping): - for category in categories: - raw_color = stored_colors.get(category) - if raw_color is None: - logger.warning(f"No color specified for '{category}' in '{color_key}', using na_color.") - color_mapping[category] = na_color_hex - continue - hex_color = _to_hex_no_alpha(raw_color) - color_mapping[category] = hex_color if hex_color is not None else na_color_hex - logger.info(f"Successfully extracted {len(color_mapping)} colors from '{color_key}' in table '{table_to_use}'.") - else: - try: - hex_colors = [_to_hex_no_alpha(color) for color in stored_colors] - except TypeError: - logger.warning(f"Unsupported color storage for '{color_key}'. Expected sequence or mapping.") - return None - - # Map by the category's position in the *full* table, not in the - # (possibly subset) color_source_vector, so colors stay consistent - # across coordinate systems. 
- all_cats = ( - adata.obs[col_to_colorby].cat.categories.tolist() - if col_to_colorby in adata.obs and hasattr(adata.obs[col_to_colorby], "cat") - else categories - ) - # Map category -> index once (O(K)) instead of a per-category list scan - # (was O(K^2) via list.index). all_cats comes from pandas .categories, - # which is unique, so a plain dict comprehension is sufficient. - cat_to_idx: dict[Any, int] = {c: i for i, c in enumerate(all_cats)} - for category in categories: - idx = cat_to_idx.get(category) - if idx is not None and idx < len(hex_colors) and hex_colors[idx] is not None: - hex_color = hex_colors[idx] - assert hex_color is not None # type narrowing for mypy - color_mapping[category] = hex_color - else: - logger.warning(f"Not enough colors provided for category '{category}', using na_color.") - color_mapping[category] = na_color_hex - logger.info(f"Successfully extracted {len(hex_colors)} colors from '{color_key}' in table '{table_to_use}'.") - - color_mapping["NaN"] = na_color_hex - return color_mapping - - -def _modify_categorical_color_mapping( - mapping: Mapping[str, str], - groups: list[str] | str | None = None, - palette: dict[str, str] | list[str] | str | None = None, -) -> Mapping[str, str]: - if groups is None or isinstance(groups, list) and groups[0] is None: - return mapping - - if palette is None or isinstance(palette, list) and palette[0] is None: - # subset base mapping to only those specified in groups - modified_mapping = {key: mapping[key] for key in mapping if key in groups or key == "NaN"} - elif len(palette) == len(groups) and isinstance(groups, list) and isinstance(palette, list): - modified_mapping = dict(zip(groups, palette, strict=True)) - else: - raise ValueError(f"Expected palette to be of length `{len(groups)}`, found `{len(palette)}`.") - - return modified_mapping - - -def _get_default_categorial_color_mapping( - color_source_vector: ArrayLike | pd.Series[CategoricalDtype], - cmap_params: CmapParams | None = None, -) -> 
Mapping[str, str]: - len_cat = len(color_source_vector.categories.unique()) - # Try to use provided colormap first - if cmap_params is not None and cmap_params.cmap is not None and not cmap_params.cmap_is_default: - # Generate evenly spaced indices for the colormap - color_idx = np.linspace(0, 1, len_cat) - if isinstance(cmap_params.cmap, ListedColormap): - palette = [to_hex(x) for x in cmap_params.cmap(color_idx)] - elif isinstance(cmap_params.cmap, LinearSegmentedColormap): - palette = [to_hex(cmap_params.cmap(x)) for x in color_idx] - else: - # Fall back to default palettes if cmap is not of expected type - palette = None - else: - palette = None - - # Fall back to default palettes if needed - if palette is None: - if len_cat <= 20: - palette = default_20 - elif len_cat <= 28: - palette = default_28 - elif len_cat <= len(default_102): # 103 colors - palette = default_102 - else: - palette = ["grey"] * len_cat - logger.info("input has more than 103 categories. Uniform 'grey' color will be used for all categories.") - - return dict(zip(color_source_vector.categories, palette[:len_cat], strict=True)) - - -def _get_categorical_color_mapping( - adata: AnnData | None, - na_color: Color, - cluster_key: str | None = None, - color_source_vector: ArrayLike | pd.Series[CategoricalDtype] | None = None, - cmap_params: CmapParams | None = None, - alpha: float = 1, - groups: list[str] | str | None = None, - palette: dict[str, str] | list[str] | str | None = None, - render_type: Literal["points", "labels"] | None = None, -) -> Mapping[str, str]: - if not isinstance(color_source_vector, Categorical): - raise TypeError(f"Expected `categories` to be a `Categorical`, but got {type(color_source_vector).__name__}") - - # Dict palette (e.g. 
from make_palette_from_data): use directly as category→color mapping - if isinstance(palette, dict): - na_color_hex = na_color.get_hex_with_alpha() if isinstance(na_color, Color) else str(na_color) - if isinstance(groups, str): - groups = [groups] - if groups is not None: - mapping = {cat: palette.get(cat, na_color_hex) for cat in groups if cat in color_source_vector.categories} - else: - mapping = {cat: palette.get(cat, na_color_hex) for cat in color_source_vector.categories} - mapping["NaN"] = na_color_hex - return mapping - - if isinstance(groups, str): - groups = [groups] - - if not palette and render_type == "points" and cmap_params is not None and not cmap_params.cmap_is_default: - palette = cmap_params.cmap - - color_idx = color_idx = np.linspace(0, 1, len(color_source_vector.categories)) - if isinstance(palette, ListedColormap): - palette = [to_hex(x) for x in palette(color_idx, alpha=alpha)] - elif isinstance(palette, LinearSegmentedColormap): - palette = [to_hex(palette(x, alpha=alpha)) for x in color_idx] # type: ignore[attr-defined] - return dict(zip(color_source_vector.categories, palette, strict=True)) - - if isinstance(palette, str): - palette = [palette] - - if cluster_key is None: - # user didn't specify a column to use for coloring - base_mapping = _get_default_categorial_color_mapping( - color_source_vector=color_source_vector, cmap_params=cmap_params - ) - else: - base_mapping = _generate_base_categorial_color_mapping( - adata=adata, - cluster_key=cluster_key, - color_source_vector=color_source_vector, - na_color=na_color, - cmap_params=cmap_params, - ) - - return _modify_categorical_color_mapping(mapping=base_mapping, groups=groups, palette=palette) - - -def _maybe_set_colors( - source: AnnData, - target: AnnData, - key: str, - palette: str | ListedColormap | Cycler | Sequence[Any] | None = None, -) -> None: - color_key = f"{key}_colors" - try: - if palette is not None: - raise KeyError("Unable to copy the palette when there was other 
explicitly specified.") - target.uns[color_key] = source.uns[color_key] - except KeyError: - if isinstance(palette, str): - palette = ListedColormap([palette]) - if isinstance(palette, ListedColormap): # `scanpy` requires it - palette = cycler(color=palette.colors) - palette = None - add_colors_for_categorical_sample_annotation(target, key=key, force_update_colors=True, palette=palette) - - -def _decorate_axs( - ax: Axes, - cax: PatchCollection, - fig_params: FigParams, - value_to_plot: str | None, - color_source_vector: pd.Series[CategoricalDtype] | Categorical, - color_vector: pd.Series[CategoricalDtype] | Categorical, - adata: AnnData | None = None, - palette: ListedColormap | str | list[str] | None = None, - alpha: float = 1.0, - na_color: Color = Color("default"), - legend_fontsize: int | float | _FontSize | None = None, - legend_fontweight: int | _FontWeight = "bold", - legend_loc: str | None = "right margin", - legend_fontoutline: int | None = None, - na_in_legend: bool = True, - colorbar: bool = True, - colorbar_params: dict[str, object] | None = None, - colorbar_requests: list[ColorbarSpec] | None = None, - colorbar_label: str | None = None, - legend_title: str | None = None, -) -> Axes: - if value_to_plot is not None: - # if only dots were plotted without an associated value - # there is not need to plot a legend or a colorbar - - if legend_fontoutline is not None: - path_effect = [patheffects.withStroke(linewidth=legend_fontoutline, foreground="w")] - else: - path_effect = [] - - # Adding legends - if color_source_vector is not None and isinstance(color_source_vector.dtype, pd.CategoricalDtype): - # order of clusters should agree to palette order - clusters = color_source_vector.remove_unused_categories().unique() - clusters = clusters[~clusters.isnull()] - # derive mapping from color_source_vector and color_vector - group_to_color_matching = pd.DataFrame( - { - "cats": color_source_vector.remove_unused_categories(), - "color": color_vector, - } - ) - 
color_mapping = group_to_color_matching.drop_duplicates("cats").set_index("cats")["color"].to_dict() - _add_categorical_legend( - ax, - pd.Categorical(values=color_source_vector, categories=clusters), - palette=color_mapping, - legend_loc=legend_loc, - legend_fontweight=legend_fontweight, - legend_fontsize=legend_fontsize, - legend_fontoutline=path_effect, - na_color=[na_color.get_hex()], - na_in_legend=na_in_legend, - multi_panel=fig_params.axs is not None, - ) - # scanpy's helper doesn't accept a title; set it post-hoc so the user can - # disambiguate fill vs outline when both legends are drawn. - if legend_title is not None and (legend := ax.get_legend()) is not None: - legend.set_title(legend_title) - elif colorbar and colorbar_requests is not None and cax is not None: - colorbar_requests.append( - ColorbarSpec( - ax=ax, - mappable=cax, - params=colorbar_params, - label=colorbar_label, - alpha=alpha, - ) - ) - - return ax - - -def _get_list( - var: Any, - _type: type[Any] | tuple[type[Any], ...], - ref_len: int | None = None, - name: str | None = None, -) -> list[Any]: - """ - Get a list from a variable. - - Parameters - ---------- - var - Variable to convert to a list. - _type - Type of the elements in the list. - ref_len - Reference length of the list. - name - Name of the variable. - - Returns - ------- - List - """ - if isinstance(var, _type): - return [var] if ref_len is None else ([var] * ref_len) - if isinstance(var, list): - if ref_len is not None and ref_len != len(var): - raise ValueError( - f"Variable: `{name}` has length: {len(var)}, which is not equal to reference length: {ref_len}." - ) - for v in var: - if not isinstance(v, _type): - raise ValueError(f"Variable: `{name}` has invalid type: {type(v)}, expected: {_type}.") - return var - - raise ValueError(f"Can't make a list from variable: `{var}`") - - -def save_fig( - fig: Figure, - path: str | Path, - make_dir: bool = True, - ext: str = "png", - **kwargs: Any, -) -> None: - """ - Save a figure. 
+def save_fig( + fig: Figure, + path: str | Path, + make_dir: bool = True, + ext: str = "png", + **kwargs: Any, +) -> None: + """ + Save a figure. Parameters ---------- @@ -2150,76 +664,6 @@ def save_fig( fig.savefig(path, **kwargs) -def _get_linear_colormap(colors: list[str], background: str) -> list[LinearSegmentedColormap]: - return [LinearSegmentedColormap.from_list(c, [background, c], N=256) for c in colors] - - -def _validate_polygons(shapes: GeoDataFrame) -> GeoDataFrame: - """ - Convert Polygons with holes to MultiPolygons to keep interior rings during rendering. - - Parameters - ---------- - shapes - GeoDataFrame containing a `geometry` column. - - Returns - ------- - GeoDataFrame - ``shapes`` with holed Polygons converted to MultiPolygons. - """ - if "geometry" not in shapes: - return shapes - - converted_count = 0 - for idx, geom in shapes["geometry"].items(): - if isinstance(geom, shapely.Polygon) and len(geom.interiors) > 0: - shapes.at[idx, "geometry"] = shapely.MultiPolygon([geom]) - converted_count += 1 - - if converted_count > 0: - logger.info( - "Converted %d Polygon(s) with holes to MultiPolygon(s) for correct rendering.", - converted_count, - ) - - return shapes - - -def _make_patch_from_multipolygon(mp: shapely.MultiPolygon) -> list[mpatches.PathPatch]: - """ - Create PathPatches from a MultiPolygon, preserving holes robustly. - - This follows the same strategy as GeoPandas' internal Polygon plotting: - each (multi)polygon part becomes a compound Path composed of the exterior - ring and all interior rings. Orientation is handled by prior geometry - normalization rather than manual ring reversal. 
- """ - patches: list[mpatches.PathPatch] = [] - - for poly in mp.geoms: - if poly.is_empty: - continue - - # Ensure 2D vertices in case geometries carry Z - exterior = np.asarray(poly.exterior.coords)[..., :2] - interiors = [np.asarray(ring.coords)[..., :2] for ring in poly.interiors] - - if len(interiors) == 0: - # Simple polygon without holes - patches.append(mpatches.Polygon(exterior, closed=True)) - continue - - # Build a compound path: exterior + all interior rings - compound_path = mpath.Path.make_compound_path( - mpath.Path(exterior, closed=True), - *[mpath.Path(ring, closed=True) for ring in interiors], - ) - patches.append(mpatches.PathPatch(compound_path)) - - return patches - - def _mpl_ax_contains_elements(ax: Axes) -> bool: """Check if any objects have been plotted on the axes object. @@ -2319,2180 +763,223 @@ def _rasterize_if_necessary( # Rasterize when the source image is substantially larger than what the # current figure DPI × size requires. The +100 margin avoids rasterizing - # when the image is only slightly larger than the target. - do_rasterization = y_dims > target_y_dims + 100 or x_dims > target_x_dims + 100 - - if do_rasterization: - logger.info("Rasterizing image for faster rendering.") - # ``rasterize`` interprets ``target_unit_to_pixels`` in world units, not - # intrinsic pixels. Dividing by world extent keeps the result correct - # for any transformation (translation, scale, etc.). 
- world_x = float(extent["x"][1]) - float(extent["x"][0]) - world_y = float(extent["y"][1]) - float(extent["y"][0]) - target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x) - image = rasterize( - image, - ("y", "x"), - [extent["y"][0], extent["x"][0]], - [extent["y"][1], extent["x"][1]], - coordinate_system, - target_unit_to_pixels=target_unit_to_pixels, - ) - if hasattr(image.data, "compute"): - # rasterize is lazy; downstream reads the result once per channel (NaN check, - # compositing, draw), so materialize once instead of re-running the warp each time. - image = image.copy(data=image.data.compute()) - - return image - - -def _rasterize_if_necessary_datashader( - image: DataArray, - dpi: float, - width: float, - height: float, - coordinate_system: str, - extent: dict[str, tuple[float, float]], - downsample_method: str, -) -> DataArray: - """Downsample to canvas resolution with a configurable datashader reduction. - - Used by ``render_images(method='datashader')`` so sparse images (mostly - zeros, rare non-zero pixels) survive the downsample step instead of - being averaged away by the default mean aggregation. - """ - has_c_dim = len(image.shape) == 3 - y_dims, x_dims = (image.shape[1], image.shape[2]) if has_c_dim else image.shape - - target_y_dims = int(dpi * height) - target_x_dims = int(dpi * width) - - if y_dims <= target_y_dims and x_dims <= target_x_dims: - return image - - # spatialdata.rasterize is invoked solely to inherit the output coords and - # spatial transformation; its mean-aggregated values are overwritten below. - # TODO: this wastes a full per-channel resample pass. A future refactor can - # construct the target DataArray + transformation directly once spatialdata - # exposes a public geometry-only helper. 
- world_x = float(extent["x"][1]) - float(extent["x"][0]) - world_y = float(extent["y"][1]) - float(extent["y"][0]) - target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x) - base = rasterize( - image, - ("y", "x"), - [extent["y"][0], extent["x"][0]], - [extent["y"][1], extent["x"][1]], - coordinate_system, - target_unit_to_pixels=target_unit_to_pixels, - ) - - out_y, out_x = (base.shape[1], base.shape[2]) if has_c_dim else base.shape - # Materialize once: per-chunk reductions across channels would otherwise - # trigger repeated dask graph evaluations on the same source array. - src = image.compute() if hasattr(image.data, "compute") else image - cvs = ds.Canvas( - plot_width=out_x, - plot_height=out_y, - x_range=(float(extent["x"][0]), float(extent["x"][1])), - y_range=(float(extent["y"][0]), float(extent["y"][1])), - ) - base.values = np.asarray(cvs.raster(src, downsample_method=downsample_method).values).astype(base.dtype, copy=False) - return base - - -def _multiscale_to_spatial_image( - multiscale_image: DataTree, - dpi: float, - width: float, - height: float, - scale: str | None = None, - is_label: bool = False, -) -> DataArray: - """Extract the DataArray to be rendered from a multiscale image. - - From the `DataTree`, the scale that fits the given image size and dpi most is selected - and returned. In case the lowest resolution is still too high, a rasterization step is added. - - Parameters - ---------- - multiscale_image - `DataTree` that should be rendered - dpi - dpi of the target image - width - width of the target image in inches - height - height of the target image in inches - scale - specific scale that the user chose, if None the heuristic is used - is_label - When True, the multiscale image contains labels which don't contain the `c` dimension - - Returns - ------- - DataArray - To be rendered, extracted from the DataTree respecting the dpi and size of the target image. 
- """ - scales = [leaf.name for leaf in multiscale_image.leaves] - x_dims = [multiscale_image[scale].dims["x"] for scale in scales] - y_dims = [multiscale_image[scale].dims["y"] for scale in scales] - - if isinstance(scale, str): - if scale not in scales and scale != "full": - raise ValueError(f'Scale {scale} does not exist. Please select one of {scales} or set scale = "full"!') - optimal_scale = scale - if scale == "full": - # use scale with highest resolution - optimal_scale = scales[np.argmax(x_dims)] - else: - # sort scales ascending by x resolution - order = np.argsort(x_dims) - scales = [scales[i] for i in order] - x_dims = [x_dims[i] for i in order] - y_dims = [y_dims[i] for i in order] - - optimal_x = width * dpi - optimal_y = height * dpi - - # Pick the lowest-resolution scale where both x and y are >= the - # target pixel count. Falls back to highest available resolution. - optimal_scale = scales[-1] - for i, (xd, yd) in enumerate(zip(x_dims, y_dims, strict=True)): - if xd >= optimal_x and yd >= optimal_y: - optimal_scale = scales[i] - break - - # NOTE: problematic if there are cases with > 1 data variable - data_var_keys = list(multiscale_image[optimal_scale].data_vars) - image = multiscale_image[optimal_scale][data_var_keys[0]] - - return Labels2DModel.parse(image) if is_label else Image2DModel.parse(image, c_coords=image.coords["c"].values) - - -def _get_elements_to_be_rendered( - render_cmds: list[ - tuple[ - str, - ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams | GraphRenderParams, - ] - ], - cs_index: pd.DataFrame, - cs: str, -) -> list[str]: - """ - Get the names of the elements to be rendered in the plot. - - Parameters - ---------- - render_cmds - List of tuples containing the commands and their respective parameters. - cs_index - The cs_contents dataframe indexed by the "cs" column. - cs - The name of the coordinate system to query cs_index for. 
- - Returns - ------- - List of names of the SpatialElements to be rendered in the plot. - """ - elements_to_be_rendered: list[str] = [] - - cs_row = cs_index.loc[cs] if cs in cs_index.index else None - - for cmd, params in render_cmds: - if cmd == "render_graph": - # Graph doesn't have its own CS flag; include its element so - # _get_valid_cs keeps the coordinate system alive. - elements_to_be_rendered.append(params.element) - else: - key = _RENDER_CMD_TO_CS_FLAG.get(cmd) - if key and cs_row is not None and cs_row[key]: - elements_to_be_rendered.append(params.element) - - return elements_to_be_rendered - - -def _validate_show_parameters( - coordinate_systems: list[str] | str | None, - legend_fontsize: int | float | _FontSize | None, - legend_fontweight: int | _FontWeight, - legend_loc: str | None, - legend_fontoutline: int | None, - na_in_legend: bool, - colorbar: bool, - colorbar_params: dict[str, object] | None, - wspace: float | None, - hspace: float, - ncols: int, - frameon: bool | None, - figsize: tuple[float, float] | None, - dpi: int | None, - fig: Figure | None, - title: list[str] | str | None, - pad_extent: int | float, - ax: list[Axes] | Axes | None, - return_ax: bool, - save: str | Path | None, - show: bool | None, - scalebar_dx: float | None, - scalebar_units: str, - scalebar_params: dict[str, Any] | None, - legend_params: dict[str, Any] | None, -) -> None: - if coordinate_systems is not None and not isinstance(coordinate_systems, list | str): - raise TypeError("Parameter 'coordinate_systems' must be a string or a list of strings.") - - font_weights = ["light", "normal", "medium", "semibold", "bold", "heavy", "black"] - if legend_fontweight is not None and ( - not isinstance(legend_fontweight, int | str) - or (isinstance(legend_fontweight, str) and legend_fontweight not in font_weights) - ): - readable_font_weights = ", ".join(font_weights[:-1]) + ", or " + font_weights[-1] - raise TypeError( - "Parameter 'legend_fontweight' must be an integer or one 
of", - f"the following strings: {readable_font_weights}.", - ) - - font_sizes = [ - "xx-small", - "x-small", - "small", - "medium", - "large", - "x-large", - "xx-large", - ] - - if legend_fontsize is not None and ( - not isinstance(legend_fontsize, int | float | str) - or (isinstance(legend_fontsize, str) and legend_fontsize not in font_sizes) - ): - readable_font_sizes = ", ".join(font_sizes[:-1]) + ", or " + font_sizes[-1] - raise TypeError( - "Parameter 'legend_fontsize' must be an integer, a float, or ", - f"one of the following strings: {readable_font_sizes}.", - ) - - if legend_loc is not None and not isinstance(legend_loc, str): - raise TypeError("Parameter 'legend_loc' must be a string.") - - if legend_fontoutline is not None and not isinstance(legend_fontoutline, int): - raise TypeError("Parameter 'legend_fontoutline' must be an integer.") - - if not isinstance(na_in_legend, bool): - raise TypeError("Parameter 'na_in_legend' must be a boolean.") - - if not isinstance(colorbar, bool): - raise TypeError("Parameter 'colorbar' must be a boolean.") - - if colorbar_params is not None and not isinstance(colorbar_params, dict): - raise TypeError("Parameter 'colorbar_params' must be a dictionary or None.") - - if wspace is not None and not isinstance(wspace, float): - raise TypeError("Parameter 'wspace' must be a float.") - - if not isinstance(hspace, float): - raise TypeError("Parameter 'hspace' must be a float.") - - if not isinstance(ncols, int): - raise TypeError("Parameter 'ncols' must be an integer.") - - if frameon is not None and not isinstance(frameon, bool): - raise TypeError("Parameter 'frameon' must be a boolean.") - - if figsize is not None and ( - not isinstance(figsize, tuple | list | np.ndarray) - or len(figsize) != 2 - or not all(isinstance(x, int | float) and not isinstance(x, bool) for x in figsize) - ): - raise TypeError("Parameter 'figsize' must be a tuple, list, or numpy array of two numbers.") - - if dpi is not None and not isinstance(dpi, 
int): - raise TypeError("Parameter 'dpi' must be an integer.") - - if fig is not None and not isinstance(fig, Figure): - raise TypeError("Parameter 'fig' must be a matplotlib.figure.Figure.") - - if title is not None and not isinstance(title, list | str): - raise TypeError("Parameter 'title' must be a string or a list of strings.") - - if not isinstance(pad_extent, int | float): - raise TypeError("Parameter 'pad_extent' must be numeric.") - - if ax is not None and not isinstance(ax, Axes | list): - raise TypeError("Parameter 'ax' must be a matplotlib.axes.Axes or a list of Axes.") - - if not isinstance(return_ax, bool): - raise TypeError("Parameter 'return_ax' must be a boolean.") - - if save is not None and not isinstance(save, str | Path): - raise TypeError("Parameter 'save' must be a string or a pathlib.Path.") - - if show is not None and not isinstance(show, bool): - raise TypeError("Parameter 'show' must be a boolean or None.") - - if scalebar_dx is not None: - if not isinstance(scalebar_dx, int | float) or isinstance(scalebar_dx, bool): - raise TypeError("Parameter 'scalebar_dx' must be a number or None.") - if scalebar_dx <= 0: - raise ValueError("Parameter 'scalebar_dx' must be > 0.") - if not isinstance(scalebar_units, str): - raise TypeError("Parameter 'scalebar_units' must be a string.") - - if scalebar_params is not None and not isinstance(scalebar_params, dict): - raise TypeError("Parameter 'scalebar_params' must be a dictionary or None.") - - if legend_params is not None: - if not isinstance(legend_params, dict): - raise TypeError("Parameter 'legend_params' must be a dictionary or None.") - # `loc` is matplotlib.Legend's native key; `location` aligns with colorbar_params / scalebar_params. - allowed_legend_keys = {"loc", "location", "fontsize", "fontweight", "fontoutline", "na_in_legend"} - unknown = set(legend_params) - allowed_legend_keys - if unknown: - raise ValueError( - f"Unknown legend_params key(s): {sorted(unknown)}. 
Allowed keys: {sorted(allowed_legend_keys)}." - ) - - -def _check_color_column_collision( - sdata: SpatialData, - elements: list[str], - color: str, - element_type: str, -) -> None: - """Raise if ``color`` is a color-like string that also names a column in the element or its tables.""" - matches: list[str] = [] - for el in elements: - if element_type in {"shapes", "points"}: - try: - el_cols = sdata[el].columns - except (KeyError, AttributeError): - el_cols = () - if color in el_cols: - matches.append(f"element '{el}'") - continue - try: - tables = get_element_annotators(sdata, el) - except (KeyError, ValueError): - tables = set() - for t in tables: - adata = sdata[t] - if color in adata.obs.columns or color in adata.var_names: - matches.append(f"table '{t}' (annotating '{el}')") - break - if matches: - locations = ", ".join(matches) - raise ValueError( - f"`color={color!r}` is ambiguous: it is a valid matplotlib color name AND a column " - f"name in {locations}. Disambiguate by either passing an unambiguous color form " - f"(hex string like '#ffa500' or an RGB(A) tuple), or by renaming the column." - ) - - -def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]: - colorbar = param_dict.get("colorbar", "auto") - if colorbar not in {True, False, None, "auto"}: - raise TypeError("Parameter 'colorbar' must be one of True, False or 'auto'.") - - colorbar_params = param_dict.get("colorbar_params") - if colorbar_params is not None and not isinstance(colorbar_params, dict): - raise TypeError("Parameter 'colorbar_params' must be a dictionary or None.") - - element = param_dict.get("element") - if element is not None and not isinstance(element, str): - raise ValueError( - "Parameter 'element' must be a string. 
If you want to display more elements, pass `element` " - "as `None` or chain pl.render(...).pl.render(...).pl.show()" - ) - if element_type == "images": - param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].images.keys()) - elif element_type == "labels": - param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].labels.keys()) - elif element_type == "points": - param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].points.keys()) - elif element_type == "shapes": - param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].shapes.keys()) - - channel = param_dict.get("channel") - if channel is not None and not isinstance(channel, list | str | int): - raise TypeError("Parameter 'channel' must be a string, an integer, or a list of strings or integers.") - if isinstance(channel, list): - if not all(isinstance(c, str | int) for c in channel): - raise TypeError("Each item in 'channel' list must be a string or an integer.") - if not all(isinstance(c, type(channel[0])) for c in channel): - raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.") - - elif "channel" in param_dict: - param_dict["channel"] = [channel] if channel is not None else None - - contour_px = param_dict.get("contour_px") - if contour_px and not isinstance(contour_px, int): - raise TypeError("Parameter 'contour_px' must be an integer.") - - color = param_dict.get("color") - if color and element_type in { - "shapes", - "points", - "labels", - "graph", - }: - if not isinstance(color, str | tuple | list): - raise TypeError("Parameter 'color' must be a string or a tuple/list of floats.") - if _is_color_like(color): - if isinstance(color, str): - _check_color_column_collision(param_dict["sdata"], param_dict["element"], color, element_type) - param_dict["col_for_color"] = None - param_dict["color"] = Color(color) - if 
param_dict["color"].alpha_is_user_defined(): - if element_type == "points" and param_dict.get("alpha") is None: - param_dict["alpha"] = param_dict["color"].get_alpha_as_float() - elif element_type in {"shapes", "labels"} and param_dict.get("fill_alpha") is None: - param_dict["fill_alpha"] = param_dict["color"].get_alpha_as_float() - else: - logger.info( - f"Alpha implied by color '{color}' is ignored since the parameter 'alpha' or 'fill_alpha' " - "is set and its value takes precedence." - ) - elif isinstance(color, str): - param_dict["col_for_color"] = color - param_dict["color"] = None - else: - raise ValueError(f"{color} is not a valid RGB(A) array and therefore can't be used as 'color' value.") - elif "color" in param_dict and element_type != "images": - param_dict["col_for_color"] = None - - outline_width = param_dict.get("outline_width") - if outline_width: - # outline_width only exists for shapes at the moment - if isinstance(outline_width, tuple): - for ow in outline_width: - if isinstance(ow, float | int): - if ow < 0: - raise ValueError("Parameter 'outline_width' cannot contain negative values.") - else: - raise TypeError("Parameter 'outline_width' must contain only numerics when it is a tuple.") - elif not isinstance(outline_width, float | int): - raise TypeError("Parameter 'outline_width' must be numeric or a tuple of two numerics.") - if isinstance(outline_width, float | int) and outline_width < 0: - raise ValueError("Parameter 'outline_width' cannot be negative.") - - outline_alpha = param_dict.get("outline_alpha") - if outline_alpha: - if isinstance(outline_alpha, tuple): - if element_type != "shapes": - raise ValueError("Parameter 'outline_alpha' must be a single numeric.") - if len(outline_alpha) == 1: - if not isinstance(outline_alpha[0], float | int) or not 0 <= outline_alpha[0] <= 1: - raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.") - param_dict["outline_alpha"] = outline_alpha[0] - elif len(outline_alpha) < 1: 
- raise ValueError("Empty tuple is not supported as input for outline_alpha!") - else: - if len(outline_alpha) > 2: - logger.warning( - f"Tuple of length {len(outline_alpha)} was passed for outline_alpha, only first two positions " - "are used since more than 2 outlines are not supported!" - ) - if ( - not isinstance(outline_alpha[0], float | int) - or not isinstance(outline_alpha[1], float | int) - or not 0 <= outline_alpha[0] <= 1 - or not 0 <= outline_alpha[1] <= 1 - ): - raise TypeError("Parameter 'outline_alpha' must contain numeric values between 0 and 1.") - param_dict["outline_alpha"] = (outline_alpha[0], outline_alpha[1]) - elif not isinstance(outline_alpha, float | int) or not 0 <= outline_alpha <= 1: - raise TypeError("Parameter 'outline_alpha' must be numeric and between 0 and 1.") - - outline_color = param_dict.get("outline_color") - if "outline_color" in param_dict and element_type in {"shapes", "labels"}: - param_dict["col_for_outline_color"] = None - if outline_color: - if not isinstance(outline_color, str | tuple | list): - raise TypeError("Parameter 'outline_color' must be a string or a tuple/list of floats or colors.") - if isinstance(outline_color, tuple | list): - if len(outline_color) < 1: - raise ValueError("Empty tuple is not supported as input for outline_color!") - if len(outline_color) == 1: - param_dict["outline_color"] = Color(outline_color[0]) - elif len(outline_color) == 2: - # assuming the case of 2 outlines - param_dict["outline_color"] = (Color(outline_color[0]), Color(outline_color[1])) - elif len(outline_color) in [3, 4]: - # assuming RGB(A) array - param_dict["outline_color"] = Color(outline_color) - else: - raise ValueError( - f"Tuple/List of length {len(outline_color)} was passed for outline_color. Valid options would be: " - "tuple of 2 colors (for 2 outlines) or an RGB(A) array, aka a list/tuple of 3-4 floats." 
- ) - elif isinstance(outline_color, str) and element_type in {"shapes", "labels"}: - if _is_color_like(outline_color): - _check_color_column_collision(param_dict["sdata"], param_dict["element"], outline_color, element_type) - param_dict["outline_color"] = Color(outline_color) - else: - if isinstance(param_dict.get("outline_width"), tuple): - raise ValueError( - "Coloring outlines by a column is not supported with two outlines. " - "Pass a scalar `outline_width` or a literal color for `outline_color`." - ) - param_dict["col_for_outline_color"] = outline_color - param_dict["outline_color"] = None - else: - param_dict["outline_color"] = Color(outline_color) - - if contour_px is not None and contour_px < 2: - raise ValueError( - "Parameter 'contour_px' must be >= 2; values below 2 produce no visible outline " - "(a 1x1 erosion is the identity transformation)." - ) - - alpha = param_dict.get("alpha") - if alpha is not None: - if not isinstance(alpha, float | int): - raise TypeError("Parameter 'alpha' must be numeric.") - if not 0 <= alpha <= 1: - raise ValueError("Parameter 'alpha' must be between 0 and 1.") - elif element_type == "points": - # set default alpha for points if not given by user explicitly or implicitly (as part of color) - param_dict["alpha"] = 1.0 - - fill_alpha = param_dict.get("fill_alpha") - if fill_alpha is not None: - if not isinstance(fill_alpha, float | int): - raise TypeError("Parameter 'fill_alpha' must be numeric.") - if fill_alpha < 0: - raise ValueError("Parameter 'fill_alpha' cannot be negative.") - elif element_type == "shapes": - # set default fill_alpha for shapes if not given by user explicitly or implicitly (as part of color) - param_dict["fill_alpha"] = 1.0 - elif element_type == "labels": - # set default fill_alpha for labels if not given by user explicitly or implicitly (as part of color) - param_dict["fill_alpha"] = 0.4 - - cmap = param_dict.get("cmap") - palette = param_dict.get("palette") - if cmap is not None and palette is not 
None: - raise ValueError("Both `palette` and `cmap` are specified. Please specify only one of them.") - param_dict["cmap"] = cmap - - groups = param_dict.get("groups") - if groups is not None: - if not isinstance(groups, list | str): - raise TypeError("Parameter 'groups' must be a string or a list of strings.") - if isinstance(groups, str): - param_dict["groups"] = [groups] - elif not all(isinstance(g, str) for g in groups): - raise TypeError("Each item in 'groups' must be a string.") - - palette = param_dict["palette"] - - # dict palettes (e.g. from make_palette_from_data) bypass groups validation - if isinstance(palette, dict): - from matplotlib.colors import is_color_like - - invalid = [f"'{k}': '{v}'" for k, v in palette.items() if not is_color_like(v)] - if invalid: - raise ValueError(f"Dict palette contains invalid color values: {', '.join(invalid)}.") - elif isinstance(palette, list): - if not all(isinstance(p, str) for p in palette): - raise ValueError("If specified, parameter 'palette' must contain only strings.") - elif isinstance(palette, str | type(None)) and "palette" in param_dict and element_type != "graph": - param_dict["palette"] = [palette] if palette is not None else None - - palette_group = param_dict.get("palette") - if element_type in ["shapes", "points", "labels"] and palette_group is not None and not isinstance(palette, dict): - groups = param_dict.get("groups") - if groups is not None and len(groups) != len(palette_group): - raise ValueError( - f"The length of 'palette' and 'groups' must be the same, length is {len(palette_group)} and" - f"{len(groups)} respectively." 
- ) - - if isinstance(cmap, list): - if not all(isinstance(c, Colormap | str) for c in cmap): - raise TypeError("Each item in 'cmap' list must be a string or a Colormap.") - elif isinstance(cmap, Colormap | str | type(None)): - if "cmap" in param_dict and element_type != "graph": - param_dict["cmap"] = [cmap] if cmap is not None else None - else: - raise TypeError("Parameter 'cmap' must be a string, a Colormap, or a list of these types.") - - # validation happens within Color constructor (images don't use na_color) - if "na_color" in param_dict: - param_dict["na_color"] = Color(param_dict.get("na_color")) - - norm = param_dict.get("norm") - if norm is not None: - if element_type == "images": - if isinstance(norm, list): - if not norm: - raise ValueError("Parameter 'norm' list must not be empty.") - if not all(isinstance(n, Normalize) for n in norm): - raise TypeError("Every item in 'norm' list must be a Normalize instance.") - elif not isinstance(norm, Normalize): - raise TypeError("Parameter 'norm' must be a Normalize or a list of Normalize instances.") - elif element_type == "labels" and not isinstance(norm, Normalize): - raise TypeError("Parameter 'norm' must be of type Normalize.") - if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize): - raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.") - if element_type == "graph" and not isinstance(norm, Normalize): - raise TypeError("Parameter 'norm' must be a Normalize instance.") - - scale = param_dict.get("scale") - if scale is not None: - if element_type in {"images", "labels"} and not isinstance(scale, str): - raise TypeError("Parameter 'scale' must be a string if specified.") - if element_type == "shapes": - if not isinstance(scale, float | int): - raise TypeError("Parameter 'scale' must be numeric.") - if scale < 0: - raise ValueError("Parameter 'scale' must be a positive number.") - - size = param_dict.get("size") - if size: - if not isinstance(size, float | 
int): - raise TypeError("Parameter 'size' must be numeric.") - if size < 0: - raise ValueError("Parameter 'size' must be a positive number.") - - shape = param_dict.get("shape") - if element_type == "shapes" and shape is not None: - valid_shapes = {"circle", "hex", "visium_hex", "square"} - if not isinstance(shape, str): - raise TypeError(f"Parameter 'shape' must be a String from {valid_shapes} if not None.") - if shape not in valid_shapes: - raise ValueError(f"'{shape}' is not supported for 'shape', please choose from {valid_shapes}.") - - table_name = param_dict.get("table_name") - table_layer = param_dict.get("table_layer") - if table_name and not isinstance(param_dict["table_name"], str): - raise TypeError("Parameter 'table_name' must be a string.") - - if table_layer and not isinstance(param_dict["table_layer"], str): - raise TypeError("Parameter 'table_layer' must be a string.") - - def _ensure_table_and_layer_exist_in_sdata( - sdata: SpatialData, table_name: str | None, table_layer: str | None - ) -> bool: - """Ensure that table_name and table_layer are valid; throw error if not.""" - if table_name: - if table_layer: - if table_layer in sdata.tables[table_name].layers: - return True - raise ValueError(f"Layer '{table_layer}' not found in table '{table_name}'.") - return True # using sdata.tables[table_name].X - - if table_layer: - # user specified a layer but we have no tables => invalid - if len(sdata.tables) == 0: - raise ValueError("Trying to use 'table_layer' but no tables are present in the SpatialData object.") - if len(sdata.tables) == 1: - single_table_name = list(sdata.tables.keys())[0] - if table_layer in sdata.tables[single_table_name].layers: - return True - raise ValueError(f"Layer '{table_layer}' not found in table '{single_table_name}'.") - # more than one tables, try to find which one has the given layer - found_table = False - for tname in sdata.tables: - if table_layer in sdata.tables[tname].layers: - if found_table: - raise ValueError( - 
"Trying to guess 'table_name' based on 'table_layer', but found multiple matches." - ) - found_table = True - - if found_table: - return True - - raise ValueError(f"Layer '{table_layer}' not found in any table.") - - return True # not using any table - - _ensure_table_and_layer_exist_in_sdata(param_dict.get("sdata"), table_name, table_layer) - - method = param_dict.get("method") - if method not in ["matplotlib", "datashader", None]: - raise ValueError("If specified, parameter 'method' must be either 'matplotlib' or 'datashader'.") - - valid_ds_reduction_methods = [ - "sum", - "mean", - "any", - "count", - # "m2", -> not intended to be used alone (see https://datashader.org/api.html#datashader.reductions.m2) - # "mode", -> not supported for points (see https://datashader.org/api.html#datashader.reductions.mode) - "std", - "var", - "max", - "min", - ] - ds_reduction = param_dict.get("ds_reduction") - if ds_reduction and (ds_reduction not in valid_ds_reduction_methods): - raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.") - - if element_type == "graph": - for key in ("connectivity_key",): - val = param_dict.get(key) - if val is not None and not isinstance(val, str): - raise TypeError(f"Parameter '{key}' must be a string.") - - for key in ("obsp_key", "weight_key", "group_key"): - val = param_dict.get(key) - if val is not None and not isinstance(val, str): - raise TypeError(f"Parameter '{key}' must be a string or None.") - - for key in ("edge_width", "edge_alpha"): - val = param_dict.get(key) - if val == "weight": - continue - if not isinstance(val, float | int): - raise TypeError(f"Parameter '{key}' must be numeric or the literal string 'weight'.") - if val < 0: - raise ValueError(f"Parameter '{key}' cannot be negative.") - - linestyle = param_dict.get("linestyle") - if linestyle is not None and not isinstance(linestyle, str | list | tuple): - raise TypeError("Parameter 'linestyle' must be a string or a sequence 
of strings.") - - for key in ("include_self_loops", "rasterize"): - val = param_dict.get(key) - if val is not None and not isinstance(val, bool): - raise TypeError(f"Parameter '{key}' must be a boolean.") - - return param_dict - - -def _resolve_color_panels(color: Any) -> tuple[Any, list[str] | None]: - """Split a ``color`` argument into a scalar color and an optional multi-panel key list. - - Returns ``(scalar_color, panel_keys)``. When ``panel_keys`` is ``None`` the call is a - normal single-color render and ``scalar_color`` is the (unchanged) color to use. When - ``panel_keys`` is a list, the render must be expanded into one panel per key. - - A list of all-strings is treated as multi-panel keys; a length-1 list normalizes to a - scalar color; an all-numeric list stays a single RGB(A) color. Empty, duplicate, or - mixed str/number lists raise ``ValueError``. - """ - if not isinstance(color, list): - return color, None - if all(isinstance(c, str) for c in color): - if len(color) == 0: - raise ValueError("`color` was given an empty list; provide at least one column/key name.") - duplicate_keys = sorted(k for k, n in Counter(color).items() if n > 1) - if duplicate_keys: - raise ValueError(f"`color` contains duplicate keys {duplicate_keys}; each multi-panel key must be unique.") - if len(color) == 1: - return color[0], None - return None, list(color) - if any(isinstance(c, str) for c in color): - raise ValueError( - "`color` list must be either all column/key names (str) for a multi-panel plot, " - "or 3-4 floats for a single RGB(A) color, not a mix of both." - ) - return color, None - - -def _expand_color_panels( - sdata: SpatialData, - color: Any, - render_fn_name: str, - validate: Callable[[Any], dict[str, Any]], -) -> list[tuple[str | None, dict[str, Any]]]: - """Resolve ``color`` into validated per-panel render params for the multi-panel ``color=[...]`` feature. 
- - ``validate`` is a callback that runs the render function's own parameter validation for a single - color value and returns its per-element ``params_dict``. Returns a list of ``(panel_key, params_dict)`` - pairs: a single ``(None, params_dict)`` for the scalar case, or one entry per key for a key list. - - Enforces that only one ``render_*`` call per figure may pass a color list, and aggregates per-key - validation errors into a single message. Used by ``render_shapes`` and ``render_labels``. - """ - color, panel_keys = _resolve_color_panels(color) - if panel_keys is not None and any( - getattr(params, "panel_key", None) is not None for params in getattr(sdata, "plotting_tree", {}).values() - ): - raise ValueError( - "Only one `render_*` call may use a list of color keys per figure. Other chained render " - "calls must use a single (scalar) color; they are drawn into every panel as a shared layer." - ) - - color_specs = [(None, color)] if panel_keys is None else [(key, key) for key in panel_keys] - panel_param_dicts: list[tuple[str | None, dict[str, Any]]] = [] - key_errors: dict[str, str] = {} - for panel_key, color_value in color_specs: - try: - params_dict = validate(color_value) - except (KeyError, ValueError) as e: - if panel_keys is None: - raise - key_errors[panel_key] = str(e) # type: ignore[index] - continue - panel_param_dicts.append((panel_key, params_dict)) - if key_errors: - details = "\n".join(f" - {key!r}: {msg}" for key, msg in key_errors.items()) - raise ValueError(f"Invalid color key(s) for multi-panel `{render_fn_name}`:\n{details}") - return panel_param_dicts - - -def _validate_as_points_size(size: float) -> None: - """Validate the centroid marker `size` used by ``render_shapes``/``render_labels`` with ``as_points=True``.""" - if isinstance(size, bool) or not isinstance(size, (int, float)): - raise TypeError("Parameter 'size' must be numeric.") - if size <= 0: - raise ValueError("Parameter 'size' must be a positive number.") - - -def 
_validate_label_render_params( - sdata: sd.SpatialData, - element: str | None, - cmap: list[Colormap | str] | Colormap | str | None, - color: ColorLike | None, - fill_alpha: float | int | None, - contour_px: int | None, - groups: list[str] | str | None, - palette: dict[str, str] | list[str] | str | None, - na_color: ColorLike | None, - norm: Normalize | None, - outline_alpha: float | int, - outline_color: ColorLike | None, - scale: str | None, - table_name: str | None, - table_layer: str | None, - colorbar: bool | str | None, - colorbar_params: dict[str, object] | None, - gene_symbols: str | None = None, -) -> dict[str, dict[str, Any]]: - param_dict: dict[str, Any] = { - "sdata": sdata, - "element": element, - "fill_alpha": fill_alpha, - "contour_px": contour_px, - "groups": groups, - "palette": palette, - "color": color, - "na_color": na_color, - "outline_alpha": outline_alpha, - "outline_color": outline_color, - "cmap": cmap, - "norm": norm, - "scale": scale, - "table_name": table_name, - "table_layer": table_layer, - "colorbar": colorbar, - "colorbar_params": colorbar_params, - } - param_dict = _type_check_params(param_dict, "labels") - - element_params: dict[str, dict[str, Any]] = {} - for el in param_dict["element"]: - # ensure that the element exists in the SpatialData object - _ = param_dict["sdata"][el] - - element_params[el] = {} - element_params[el]["na_color"] = param_dict["na_color"] - element_params[el]["cmap"] = param_dict["cmap"] - element_params[el]["norm"] = param_dict["norm"] - element_params[el]["fill_alpha"] = param_dict["fill_alpha"] - element_params[el]["scale"] = param_dict["scale"] - element_params[el]["outline_alpha"] = param_dict["outline_alpha"] - element_params[el]["outline_color"] = param_dict["outline_color"] - element_params[el]["contour_px"] = param_dict["contour_px"] - element_params[el]["table_layer"] = param_dict["table_layer"] - - element_params[el]["table_name"] = None - element_params[el]["color"] = param_dict["color"] # 
literal Color or None - element_params[el]["col_for_color"] = None - if (col_for_color := param_dict["col_for_color"]) is not None: - col_for_color, table_name = _validate_col_for_column_table( - sdata, el, col_for_color, param_dict["table_name"], labels=True, gene_symbols=gene_symbols - ) - element_params[el]["table_name"] = table_name - element_params[el]["col_for_color"] = col_for_color - - element_params[el]["col_for_outline_color"] = None - element_params[el]["outline_table_name"] = None - if (col_for_outline_color := param_dict.get("col_for_outline_color")) is not None: - col_for_outline_color, outline_table_name = _validate_col_for_column_table( - sdata, - el, - col_for_outline_color, - param_dict["table_name"], - labels=True, - gene_symbols=gene_symbols, - ) - element_params[el]["col_for_outline_color"] = col_for_outline_color - element_params[el]["outline_table_name"] = outline_table_name - - _gate_palette_and_groups(element_params[el], param_dict) - element_params[el]["colorbar"] = param_dict["colorbar"] - element_params[el]["colorbar_params"] = param_dict["colorbar_params"] - - return element_params - - -def _validate_points_render_params( - sdata: sd.SpatialData, - element: str | None, - alpha: float | int | None, - color: ColorLike | None, - groups: list[str] | str | None, - palette: dict[str, str] | list[str] | str | None, - na_color: ColorLike | None, - cmap: list[Colormap | str] | Colormap | str | None, - norm: Normalize | None, - size: float | int, - table_name: str | None, - table_layer: str | None, - ds_reduction: str | None, - colorbar: bool | str | None, - colorbar_params: dict[str, object] | None, - gene_symbols: str | None = None, - density: bool = False, - density_how: Literal["linear", "log", "cbrt", "eq_hist"] = "linear", - transfunc: Callable[[float], float] | None = None, - method: str | None = None, -) -> dict[str, dict[str, Any]]: - if not isinstance(density, bool): - raise TypeError("Parameter 'density' must be a bool.") - allowed_how 
= ("linear", "log", "cbrt", "eq_hist") - if density_how not in allowed_how: - raise ValueError(f"Parameter 'density_how' must be one of {allowed_how}; got {density_how!r}.") - - param_dict: dict[str, Any] = { - "sdata": sdata, - "element": element, - "alpha": alpha, - "color": color, - "groups": groups, - "palette": palette, - "na_color": na_color, - "cmap": cmap, - "norm": norm, - "size": size, - "table_name": table_name, - "table_layer": table_layer, - "ds_reduction": ds_reduction, - "colorbar": colorbar, - "colorbar_params": colorbar_params, - } - param_dict = _type_check_params(param_dict, "points") - - if density: - if method == "matplotlib": - raise ValueError( - "density=True requires the datashader backend; got method='matplotlib'. " - "Either drop method= or set method='datashader'." - ) - # Literal color (resolved into param_dict["color"] as a Color instance, with - # col_for_color set to None) is ambiguous with density: it could mean a - # single-hue cmap or a one-entry palette. Force the user to choose. - if param_dict["color"] is not None and param_dict["col_for_color"] is None: - raise ValueError( - "density=True with a literal color is ambiguous. Pass cmap= to recolor the " - "density, or palette= to assign a categorical color, but not color=." - ) - # Warn-and-ignore: these parameters do not interact meaningfully with a - # count-based density and are silently dropped to keep the API consistent. 
- if size != 1.0: - warnings.warn( - "size is ignored when density=True; spreading would distort the count signal.", - UserWarning, - stacklevel=3, - ) - if transfunc is not None: - warnings.warn( - "transfunc is ignored when density=True (no continuous color vector to transform).", - UserWarning, - stacklevel=3, - ) - if isinstance(norm, Normalize) and (norm.vmin is not None or norm.vmax is not None): - warnings.warn( - "norm.vmin/vmax are ignored when density=True; use density_how= to control intensity mapping.", - UserWarning, - stacklevel=3, - ) - if ds_reduction is not None: - warnings.warn( - "datashader_reduction is ignored when density=True; counts are forced.", - UserWarning, - stacklevel=3, - ) - - element_params: dict[str, dict[str, Any]] = {} - for el in param_dict["element"]: - # ensure that the element exists in the SpatialData object - _ = param_dict["sdata"][el] - - element_params[el] = {} - element_params[el]["na_color"] = param_dict["na_color"] - element_params[el]["cmap"] = param_dict["cmap"] - element_params[el]["norm"] = param_dict["norm"] - element_params[el]["color"] = param_dict["color"] - element_params[el]["size"] = param_dict["size"] - element_params[el]["alpha"] = param_dict["alpha"] - element_params[el]["table_layer"] = param_dict["table_layer"] - - element_params[el]["table_name"] = None - element_params[el]["col_for_color"] = None - col_for_color = param_dict["col_for_color"] - if col_for_color is not None: - col_for_color, table_name = _validate_col_for_column_table( - sdata, el, col_for_color, param_dict["table_name"], gene_symbols=gene_symbols - ) - element_params[el]["table_name"] = table_name - element_params[el]["col_for_color"] = col_for_color - - _gate_palette_and_groups(element_params[el], param_dict) - element_params[el]["ds_reduction"] = param_dict["ds_reduction"] - element_params[el]["colorbar"] = param_dict["colorbar"] - element_params[el]["colorbar_params"] = param_dict["colorbar_params"] - - return element_params - - 
-def _validate_shape_render_params( - sdata: sd.SpatialData, - element: str | None, - fill_alpha: float | int | None, - groups: list[str] | str | None, - palette: dict[str, str] | list[str] | str | None, - color: ColorLike | None, - na_color: ColorLike | None, - outline_width: float | int | tuple[float | int, float | int] | None, - outline_color: ColorLike | tuple[ColorLike] | None, - outline_alpha: float | int | tuple[float | int, float | int] | None, - cmap: list[Colormap | str] | Colormap | str | None, - norm: Normalize | None, - scale: float | int, - table_name: str | None, - table_layer: str | None, - shape: Literal["circle", "hex", "visium_hex", "square"] | None, - method: str | None, - ds_reduction: str | None, - colorbar: bool | str | None, - colorbar_params: dict[str, object] | None, - gene_symbols: str | None = None, -) -> dict[str, dict[str, Any]]: - param_dict: dict[str, Any] = { - "sdata": sdata, - "element": element, - "fill_alpha": fill_alpha, - "groups": groups, - "palette": palette, - "color": color, - "na_color": na_color, - "outline_width": outline_width, - "outline_color": outline_color, - "outline_alpha": outline_alpha, - "cmap": cmap, - "norm": norm, - "scale": scale, - "table_name": table_name, - "table_layer": table_layer, - "shape": shape, - "method": method, - "ds_reduction": ds_reduction, - "colorbar": colorbar, - "colorbar_params": colorbar_params, - } - param_dict = _type_check_params(param_dict, "shapes") - - element_params: dict[str, dict[str, Any]] = {} - for el in param_dict["element"]: - # ensure that the element exists in the SpatialData object - _ = param_dict["sdata"][el] - - element_params[el] = {} - element_params[el]["fill_alpha"] = param_dict["fill_alpha"] - element_params[el]["na_color"] = param_dict["na_color"] - element_params[el]["outline_width"] = param_dict["outline_width"] - element_params[el]["outline_color"] = param_dict["outline_color"] - element_params[el]["outline_alpha"] = param_dict["outline_alpha"] - 
element_params[el]["cmap"] = param_dict["cmap"] - element_params[el]["norm"] = param_dict["norm"] - element_params[el]["scale"] = param_dict["scale"] - element_params[el]["table_layer"] = param_dict["table_layer"] - element_params[el]["shape"] = param_dict["shape"] - - element_params[el]["color"] = param_dict["color"] - - element_params[el]["table_name"] = None - element_params[el]["col_for_color"] = None - col_for_color = param_dict["col_for_color"] - if col_for_color is not None: - col_for_color, table_name = _validate_col_for_column_table( - sdata, el, col_for_color, param_dict["table_name"], gene_symbols=gene_symbols - ) - element_params[el]["table_name"] = table_name - element_params[el]["col_for_color"] = col_for_color - - element_params[el]["col_for_outline_color"] = None - element_params[el]["outline_table_name"] = None - col_for_outline_color = param_dict.get("col_for_outline_color") - if col_for_outline_color is not None: - col_for_outline_color, outline_table_name = _validate_col_for_column_table( - sdata, el, col_for_outline_color, param_dict["table_name"], gene_symbols=gene_symbols - ) - element_params[el]["col_for_outline_color"] = col_for_outline_color - element_params[el]["outline_table_name"] = outline_table_name - - _gate_palette_and_groups(element_params[el], param_dict) - element_params[el]["method"] = param_dict["method"] - element_params[el]["ds_reduction"] = param_dict["ds_reduction"] - element_params[el]["colorbar"] = param_dict["colorbar"] - element_params[el]["colorbar_params"] = param_dict["colorbar_params"] - - return element_params - - -def _resolve_gene_symbols( - adata: AnnData, - col_for_color: str, - gene_symbols: str, -) -> str: - """Resolve a gene symbol to its var_name using an alternate var column. - - Mimics scanpy's ``gene_symbols`` behaviour: look up *col_for_color* in - ``adata.var[gene_symbols]`` and return the corresponding ``var_name`` - (i.e. the var index value). 
- """ - if gene_symbols not in adata.var.columns: - raise KeyError(f"Column '{gene_symbols}' not found in `adata.var`. Cannot use it as `gene_symbols` lookup.") - mask = adata.var[gene_symbols] == col_for_color - if not mask.any(): - raise KeyError(f"'{col_for_color}' not found in `adata.var['{gene_symbols}']`.") - n_matches = mask.sum() - if n_matches > 1: - logger.warning( - f"Gene symbol '{col_for_color}' maps to {n_matches} var_names in column '{gene_symbols}'. " - f"Using the first match: '{adata.var.index[mask][0]}'." - ) - return str(adata.var.index[mask][0]) - - -def _validate_graph_render_params( - sdata: SpatialData, - element: str | None, - connectivity_key: str, - table_name: str | None, - color: ColorLike | None, - edge_width: float | Literal["weight"], - edge_alpha: float | Literal["weight"], - groups: list[str] | str | None, - group_key: str | None, - obsp_key: str | None = None, - weight_key: str | None = None, - palette: dict[str, str] | list[str] | str | None = None, - na_color: ColorLike | None = "default", - cmap: Colormap | str | None = None, - norm: Normalize | None = None, - linestyle: str | Sequence[str] = "solid", - include_self_loops: bool = False, - rasterize: bool = True, -) -> dict[str, Any]: - """Validate and resolve parameters for render_graph.""" - param_dict: dict[str, Any] = { - "sdata": sdata, - "element": element, - "color": color, - "groups": groups, - "palette": palette, - "na_color": na_color, - "cmap": cmap, - "norm": norm if norm is not None else Normalize(clip=False), - "table_name": table_name, - "connectivity_key": connectivity_key, - "obsp_key": obsp_key, - "weight_key": weight_key, - "group_key": group_key, - "edge_width": edge_width, - "edge_alpha": edge_alpha, - "linestyle": linestyle, - "include_self_loops": include_self_loops, - "rasterize": rasterize, - } - param_dict = _type_check_params(param_dict, "graph") - - if param_dict["table_name"] is None: - candidates = [tname for tname in sdata.tables if 
_resolve_obsp_key(sdata[tname], connectivity_key) is not None] - if len(candidates) == 0: - raise ValueError( - f"No table found with connectivity key '{connectivity_key}' in obsp. " - f"Available tables: {list(sdata.tables.keys())}." - ) - if len(candidates) > 1: - raise ValueError( - f"Multiple tables contain connectivity key '{connectivity_key}': {candidates}. " - "Please specify `table_name` explicitly." - ) - param_dict["table_name"] = candidates[0] - - if param_dict["table_name"] not in sdata.tables: - raise KeyError(f"Table '{param_dict['table_name']}' not found. Available: {list(sdata.tables.keys())}.") - - table = sdata[param_dict["table_name"]] - connectivity_obsp_key = _require_obsp_key(table, connectivity_key, param_name="connectivity_key") - - _, region_key, _ = get_table_keys(table) - if region_key is None: - raise ValueError( - f"Table '{param_dict['table_name']}' has no `region_key`; cannot associate its observations " - "with a spatial element. Re-parse the table with `TableModel.parse(..., region_key=...)`." - ) - - if param_dict["element"] is None: - regions = table.obs[region_key].unique().tolist() - spatial_regions = [r for r in regions if r in sdata.shapes or r in sdata.points or r in sdata.labels] - if len(spatial_regions) == 0: - raise ValueError( - f"Table '{param_dict['table_name']}' does not annotate any spatial element. Region values: {regions}." - ) - if len(spatial_regions) > 1: - raise ValueError( - f"Table '{param_dict['table_name']}' annotates multiple spatial elements: {spatial_regions}. " - "Please specify `element` explicitly." - ) - param_dict["element"] = spatial_regions[0] - elif not ( - param_dict["element"] in sdata.shapes - or param_dict["element"] in sdata.points - or param_dict["element"] in sdata.labels - ): - raise KeyError( - f"Element '{param_dict['element']}' not found in shapes, points, or labels. 
" - f"Available: shapes={list(sdata.shapes.keys())}, " - f"points={list(sdata.points.keys())}, labels={list(sdata.labels.keys())}." - ) - - # _type_check_params normalised string groups → list; renormalise the working set here. - if param_dict["groups"] is not None and param_dict["group_key"] is None: - raise ValueError("`groups` requires `group_key` to be specified.") - if param_dict["group_key"] is not None and param_dict["group_key"] not in table.obs.columns: - raise KeyError( - f"`group_key='{param_dict['group_key']}'` not found in table obs columns. " - f"Available: {list(table.obs.columns)}." - ) - if param_dict["groups"] is not None and param_dict["group_key"] is not None: - groups_set: set[Any] = set(param_dict["groups"]) - available_groups = set(table.obs[param_dict["group_key"]].dropna().unique()) - missing_groups = groups_set - available_groups - if missing_groups: - try: - missing_str = str(sorted(missing_groups)) - except TypeError: - missing_str = str(list(missing_groups)) - if missing_groups == groups_set: - logger.warning( - f"None of the requested groups {missing_str} were found in column " - f"'{param_dict['group_key']}'. Resulting plot will contain no edges." - ) - else: - logger.warning( - f"Groups {missing_str} not found in column '{param_dict['group_key']}' and will be ignored." - ) - - # After _type_check_params: col_for_color is the non-color string user passed via `color=`; - # color is either a Color (user gave a real color) or None (user gave a column name or nothing). - col_for_color = param_dict.get("col_for_color") - if col_for_color is not None and col_for_color not in table.obs.columns: - raise ValueError( - f"`color='{col_for_color}'` is not a matplotlib color and was not found in " - f"`table.obs` columns. Available obs columns: {list(table.obs.columns)}." 
- ) - - color_is_obs_col = col_for_color is not None - if obsp_key is not None and color_is_obs_col: - raise ValueError( - "Cannot set both `color` (as an obs column) and `obsp_key` for edge coloring. " - "Pick one source: scalar color, obs-column color, or obsp-matrix color." - ) - if obsp_key is not None and param_dict["color"] is not None: - raise ValueError( - "Cannot set both `color` and `obsp_key` for edge coloring. " - "Use `obsp_key` for matrix-driven coloring with `cmap`/`norm`, " - "or `color` for a scalar / obs-column-driven coloring." - ) - - color_obsp_key: str | None = None - obs_col: str | None = None - color_source: str = "scalar" - cmap_params: CmapParams | None = None - palette_map: dict[str, str] | None = None - - if obsp_key is not None: - color_obsp_key = _require_obsp_key(table, obsp_key, param_name="obsp_key") - color_source = "obsp" - cmap_params = _prepare_cmap_norm(cmap=cmap, norm=param_dict["norm"]) - elif color_is_obs_col: - obs_col = col_for_color - obs_values = table.obs[obs_col] - if isinstance(obs_values.dtype, pd.CategoricalDtype) or obs_values.dtype == object: - color_source = "obs_categorical" - categories = ( - obs_values.cat.categories.tolist() - if isinstance(obs_values.dtype, pd.CategoricalDtype) - else sorted(obs_values.dropna().unique().tolist()) - ) - if isinstance(palette, dict): - missing = [c for c in categories if c not in palette] - if missing: - raise KeyError( - f"Palette dict is missing entries for categories: {missing}. " - f"Available categories: {categories}." 
- ) - palette_map = {c: palette[c] for c in categories} - else: - cat_colors = _get_colors_for_categorical_obs(categories=categories, palette=palette) - palette_map = dict(zip(categories, cat_colors, strict=True)) - else: - color_source = "obs_continuous" - cmap_params = _prepare_cmap_norm(cmap=cmap, norm=param_dict["norm"]) - - # When edge_width/edge_alpha="weight" but weight_key isn't given, fall back to the - # connectivity matrix so binary graphs still produce a per-edge array. - resolved_weight_key: str | None = None - if edge_width == "weight" or edge_alpha == "weight": - resolved_weight_key = _require_obsp_key( - table, weight_key if weight_key is not None else connectivity_key, param_name="weight_key" - ) - - edge_color = param_dict["color"] if param_dict["color"] is not None else Color("grey") - parsed_na_color = param_dict["na_color"] - - return { - "element": param_dict["element"], - "connectivity_key": connectivity_key, - "connectivity_obsp_key": connectivity_obsp_key, - "obsp_key": color_obsp_key, - "obs_col": obs_col, - "cmap_params": cmap_params, - "palette_map": palette_map, - "na_color": parsed_na_color, - "color_source": color_source, - "table_name": param_dict["table_name"], - "weight_key": resolved_weight_key, - "color": edge_color, - "edge_width": edge_width, - "edge_alpha": edge_alpha, - "groups": param_dict["groups"], - "group_key": param_dict["group_key"], - } - - -def _resolve_obsp_key(table: AnnData, connectivity_key: str) -> str | None: - """Resolve connectivity_key to an actual obsp key. 
Accepts full key or prefix.""" - if connectivity_key in table.obsp: - return connectivity_key - suffixed = f"{connectivity_key}_connectivities" - if suffixed in table.obsp: - return suffixed - return None - - -def _require_obsp_key(table: AnnData, key: str, *, param_name: str) -> str: - """Resolve key (with prefix fallback) or raise KeyError.""" - resolved = _resolve_obsp_key(table, key) - if resolved is None: - raise KeyError( - f"`{param_name}='{key}'` not found in `table.obsp`. " - f"Tried '{key}' and '{key}_connectivities'. " - f"Available obsp keys: {list(table.obsp.keys())}." - ) - return resolved - - -def _validate_col_for_column_table( - sdata: SpatialData, - element_name: str, - col_for_color: str | None, - table_name: str | None, - labels: bool = False, - gene_symbols: str | None = None, -) -> tuple[str | None, str | None]: - if col_for_color is None: - return None, None - - if not labels and col_for_color in sdata[element_name].columns and table_name is None: - return col_for_color, None - if table_name is not None: - tables = get_element_annotators(sdata, element_name) - if table_name not in tables: - logger.warning(f"Table '{table_name}' does not annotate element '{element_name}'.") - raise KeyError(f"Table '{table_name}' does not annotate element '{element_name}'.") - if col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names: - if gene_symbols is not None: - col_for_color = _resolve_gene_symbols(sdata[table_name], col_for_color, gene_symbols) - else: - raise KeyError( - f"Column '{col_for_color}' not found in obs/var of table '{table_name}' " - f"for element '{element_name}'." - ) - else: - tables = get_element_annotators(sdata, element_name) - if len(tables) == 0: - raise KeyError( - f"Element '{element_name}' has no annotating tables. " - f"Cannot use column '{col_for_color}' for coloring. " - "Please ensure the element is annotated by at least one table." 
- ) - # Now check which tables contain the column - resolved_var_name: str | None = None - if gene_symbols is not None and not any(gene_symbols in sdata[t].var.columns for t in tables): - available = sorted({c for t in tables for c in sdata[t].var.columns}) - raise KeyError( - f"Column '{gene_symbols}' specified in `gene_symbols=` was not found in " - f"`adata.var` of any table annotating element '{element_name}'. " - f"Available var columns: {available}" - ) - for annotates in tables.copy(): - if col_for_color not in sdata[annotates].obs.columns and col_for_color not in sdata[annotates].var_names: - if gene_symbols is not None: - try: - resolved_var_name = _resolve_gene_symbols(sdata[annotates], col_for_color, gene_symbols) - except KeyError: - tables.remove(annotates) - else: - tables.remove(annotates) - if len(tables) == 0: - raise KeyError( - f"Unable to locate color key '{col_for_color}' for element '{element_name}'. " - "Please ensure the key exists in a table annotating this element." 
- ) - table_name = next(iter(tables)) - if len(tables) > 1: - logger.warning(f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.") - if resolved_var_name is not None: - col_for_color = resolved_var_name - return col_for_color, table_name - - -def _validate_image_render_params( - sdata: sd.SpatialData, - element: str | None, - channel: list[str] | list[int] | str | int | None, - alpha: float | int | None, - palette: list[str] | str | None, - cmap: list[Colormap | str] | Colormap | str | None, - norm: list[Normalize] | Normalize | None, - scale: str | None, - colorbar: bool | str | None, - colorbar_params: dict[str, object] | None, -) -> dict[str, dict[str, Any]]: - param_dict: dict[str, Any] = { - "sdata": sdata, - "element": element, - "channel": channel, - "alpha": alpha, - "palette": palette, - "cmap": cmap, - "norm": norm, - "scale": scale, - "colorbar": colorbar, - "colorbar_params": colorbar_params, - } - param_dict = _type_check_params(param_dict, "images") - - element_params: dict[str, dict[str, Any]] = {} - for el in param_dict["element"]: - element_params[el] = {} - spatial_element = param_dict["sdata"][el] - - # robustly get channel names from image or multiscale image - spatial_element_ch = ( - spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values - ) - channel = param_dict["channel"] - if channel is not None: - # Normalize channel to always be a list of str or a list of int - if isinstance(channel, str): - channel = [channel] - - if isinstance(channel, int): - channel = [channel] - - # If channel is a list, ensure all elements are the same type - if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)): - raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.") - - invalid = [c for c in channel if c not in spatial_element_ch] - if invalid: - raise ValueError( - f"Invalid channel(s): 
{', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}" - ) - element_params[el]["channel"] = channel - else: - element_params[el]["channel"] = None - - element_params[el]["alpha"] = param_dict["alpha"] - - palette = param_dict["palette"] - assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure - - if isinstance(palette, list): - # case A: single palette for all channels - if len(palette) == 1: - palette_length = len(channel) if channel is not None else len(spatial_element_ch) - palette = palette * palette_length - # case B: one palette per channel (either given or derived from channel length) - channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel - if channels_to_use is not None and len(palette) != len(channels_to_use): - raise ValueError( - f"Palette length ({len(palette)}) does not match channel length " - f"({', '.join(str(c) for c in channels_to_use)})." - ) - element_params[el]["palette"] = palette - - expected_len = len(channel) if channel is not None else len(spatial_element_ch) - - cmap = param_dict["cmap"] - if cmap is not None: - if len(cmap) == 1: - cmap = cmap * expected_len - if len(cmap) != expected_len: - raise ValueError( - f"Length of 'cmap' list ({len(cmap)}) must match the number of channels ({expected_len})." - ) - element_params[el]["cmap"] = cmap - - norm = param_dict["norm"] - if isinstance(norm, list) and len(norm) > 1 and len(norm) != expected_len: - raise ValueError(f"Length of 'norm' list ({len(norm)}) must match the number of channels ({expected_len}).") - element_params[el]["norm"] = norm - scale = param_dict["scale"] - if scale and isinstance(param_dict["sdata"][el], DataTree): - valid_scales = list(param_dict["sdata"][el].keys()) - if scale not in valid_scales and scale != "full": - raise ValueError( - f"Scale '{scale}' does not exist in image '{el}'. Valid scales: {valid_scales + ['full']}." 
- ) - element_params[el]["scale"] = scale - element_params[el]["colorbar"] = param_dict["colorbar"] - element_params[el]["colorbar_params"] = param_dict["colorbar_params"] - - return element_params - - -def _get_wanted_render_elements( - sdata: SpatialData, - sdata_wanted_elements: list[str], - params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, - cs: str, - element_type: Literal["images", "labels", "points", "shapes"], -) -> tuple[list[str], list[str], bool]: - wants_elements = True - if element_type in [ - "images", - "labels", - "points", - "shapes", - ]: # Prevents eval security risk - wanted_elements: list[str] = [params.element] - wanted_elements_on_cs = [ - element for element in wanted_elements if cs in set(get_transformation(sdata[element], get_all=True).keys()) - ] - - sdata_wanted_elements.extend(wanted_elements_on_cs) - return sdata_wanted_elements, wanted_elements_on_cs, wants_elements - - raise ValueError(f"Unknown element type {element_type}") - - -def _ax_show_and_transform( - array: MaskedArray[tuple[int, ...], Any] | npt.NDArray[Any], - trans_data: CompositeGenericTransform, - ax: Axes, - alpha: float | None = None, - cmap: ListedColormap | LinearSegmentedColormap | None = None, - zorder: int = 0, - norm: Normalize | None = None, - interpolation: str | None = None, -) -> matplotlib.image.AxesImage: - # ``extent`` uses mpl's pixel-grid convention; world placement happens via - # ``set_transform(trans_data)`` afterwards. - image_extent = (-0.5, array.shape[1] - 0.5, array.shape[0] - 0.5, -0.5) - # ``alpha`` is applied only when no cmap is set, so RGBA arrays already - # carrying per-pixel alpha (e.g. datashader output) are not double-attenuated. 
- imshow_kwargs: dict[str, Any] = {"zorder": zorder, "extent": image_extent, "norm": norm} - if not cmap and alpha is not None: - imshow_kwargs["alpha"] = alpha - else: - imshow_kwargs["cmap"] = cmap - if interpolation is not None: - imshow_kwargs["interpolation"] = interpolation - im = ax.imshow(array, **imshow_kwargs) - im.set_transform(trans_data) - return im - - -def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = None) -> ListedColormap: - """ - Modify colormap so that 0s are transparent. - - Parameters - ---------- - cmap (Colormap | str): A matplotlib Colormap instance or a colormap name string. - steps (int): The number of steps in the colormap. - - Returns - ------- - ListedColormap: A new colormap instance with modified alpha values. - """ - if isinstance(cmap, str): - cmap = plt.get_cmap(cmap) - - colors = cmap(np.arange(steps or cmap.N)) - colors[0, :] = [1.0, 1.0, 1.0, 0.0] + # when the image is only slightly larger than the target. + do_rasterization = y_dims > target_y_dims + 100 or x_dims > target_x_dims + 100 - return ListedColormap(colors) + if do_rasterization: + logger.info("Rasterizing image for faster rendering.") + # ``rasterize`` interprets ``target_unit_to_pixels`` in world units, not + # intrinsic pixels. Dividing by world extent keeps the result correct + # for any transformation (translation, scale, etc.). + world_x = float(extent["x"][1]) - float(extent["x"][0]) + world_y = float(extent["y"][1]) - float(extent["y"][0]) + target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x) + image = rasterize( + image, + ("y", "x"), + [extent["y"][0], extent["x"][0]], + [extent["y"][1], extent["x"][1]], + coordinate_system, + target_unit_to_pixels=target_unit_to_pixels, + ) + if hasattr(image.data, "compute"): + # rasterize is lazy; downstream reads the result once per channel (NaN check, + # compositing, draw), so materialize once instead of re-running the warp each time. 
+ image = image.copy(data=image.data.compute()) + return image -def _compute_datashader_canvas_params( - x_ext: list[Any], - y_ext: list[Any], - fig_params: FigParams, -) -> tuple[Any, Any, list[Any], list[Any], Any]: - """Compute datashader canvas dimensions from spatial extents. - Shared logic used by both the dask-based and pandas-based entry points. - """ - # Compute canvas size in pixels, capped at the figure's display resolution. - # Using np.max ensures the canvas never exceeds display pixels on either axis, - # preventing pixel-based operations (spread, line_width) from being downscaled - # to sub-pixel size when the data aspect ratio differs from the figure's. - plot_width = x_ext[1] - x_ext[0] - plot_height = y_ext[1] - y_ext[0] - plot_width_px = int(round(fig_params.fig.get_size_inches()[0] * fig_params.fig.dpi)) - plot_height_px = int(round(fig_params.fig.get_size_inches()[1] * fig_params.fig.dpi)) - factor: float - factor = np.max([plot_width / plot_width_px, plot_height / plot_height_px]) - plot_width = int(np.round(plot_width / factor)) - plot_height = int(np.round(plot_height / factor)) - - return plot_width, plot_height, x_ext, y_ext, factor - - -def _get_extent_and_range_for_datashader_canvas( - spatial_element: SpatialElement, +def _rasterize_if_necessary_datashader( + image: DataArray, + dpi: float, + width: float, + height: float, coordinate_system: str, - fig_params: FigParams, -) -> tuple[Any, Any, list[Any], list[Any], Any]: - extent = _fast_extent(spatial_element, coordinate_system) - x_ext = [float(extent["x"][0]), float(extent["x"][1])] - y_ext = [float(extent["y"][0]), float(extent["y"][1])] - return _compute_datashader_canvas_params(x_ext, y_ext, fig_params) - - -def _datashader_canvas_from_dataframe( - df: pd.DataFrame, - fig_params: FigParams, -) -> tuple[Any, Any, list[Any], list[Any], Any]: - """Compute datashader canvas params directly from a pandas DataFrame. 
+ extent: dict[str, tuple[float, float]], + downsample_method: str, +) -> DataArray: + """Downsample to canvas resolution with a configurable datashader reduction. - Avoids the overhead of ``get_extent()`` (which requires a dask-backed - SpatialElement) by reading min/max from the already-materialised data. + Used by ``render_images(method='datashader')`` so sparse images (mostly + zeros, rare non-zero pixels) survive the downsample step instead of + being averaged away by the default mean aggregation. """ - if len(df) == 0: - # Empty input (e.g., a bounding_box_query with no overlap) — caller - # should short-circuit; return zero-sized canvas params as a sentinel. - return 0, 0, [0.0, 0.0], [0.0, 0.0], 1.0 - x_ext = [float(df["x"].min()), float(df["x"].max())] - y_ext = [float(df["y"].min()), float(df["y"].max())] - return _compute_datashader_canvas_params(x_ext, y_ext, fig_params) - - -def _create_image_from_datashader_result( - ds_result: ds.transfer_functions.Image | np.ndarray[Any, np.dtype[np.uint8]], - factor: float, - ax: Axes, - x_min: float = 0.0, - y_min: float = 0.0, -) -> tuple[MaskedArray[tuple[int, ...], Any], matplotlib.transforms.Transform]: - # create SpatialImage from datashader output to get it back to original size - rgba_image_data = ds_result.copy() if isinstance(ds_result, np.ndarray) else ds_result.to_numpy().base - rgba_image_data = np.transpose(rgba_image_data, (2, 0, 1)) - transformation: Scale | TransformSequence = Scale([1, factor, factor], ("c", "y", "x")) - if x_min != 0.0 or y_min != 0.0: - # Canvas pixel (0, 0) corresponds to world (x_min, y_min). Without this - # translation the rgba would render at the world origin instead of at - # the element's actual position. 
- transformation = TransformSequence([transformation, Translation([x_min, y_min], ("x", "y"))]) - rgba_image = Image2DModel.parse( - rgba_image_data, - dims=("c", "y", "x"), - transformations={"global": transformation}, - ) - - _, trans_data = _prepare_transformation(rgba_image, "global", ax) + has_c_dim = len(image.shape) == 3 + y_dims, x_dims = (image.shape[1], image.shape[2]) if has_c_dim else image.shape - rgba_image = np.transpose(rgba_image.data.compute(), (1, 2, 0)) # type: ignore[attr-defined] - rgba_image = ma.masked_array(rgba_image) # type conversion for mypy + target_y_dims = int(dpi * height) + target_x_dims = int(dpi * width) - return rgba_image, trans_data + if y_dims <= target_y_dims and x_dims <= target_x_dims: + return image + # spatialdata.rasterize is invoked solely to inherit the output coords and + # spatial transformation; its mean-aggregated values are overwritten below. + # TODO: this wastes a full per-channel resample pass. A future refactor can + # construct the target DataArray + transformation directly once spatialdata + # exposes a public geometry-only helper. + world_x = float(extent["x"][1]) - float(extent["x"][0]) + world_y = float(extent["y"][1]) - float(extent["y"][0]) + target_unit_to_pixels = min(target_y_dims / world_y, target_x_dims / world_x) + base = rasterize( + image, + ("y", "x"), + [extent["y"][0], extent["x"][0]], + [extent["y"][1], extent["x"][1]], + coordinate_system, + target_unit_to_pixels=target_unit_to_pixels, + ) -_DS_REDUCTION_FUNCS: dict[str, Any] = { - "sum": ds.sum, - "mean": ds.mean, - "any": ds.any, - "count": ds.count, - "std": ds.std, - "var": ds.var, - "max": ds.max, - "min": ds.min, -} + out_y, out_x = (base.shape[1], base.shape[2]) if has_c_dim else base.shape + # Materialize once: per-chunk reductions across channels would otherwise + # trigger repeated dask graph evaluations on the same source array. 
+ src = image.compute() if hasattr(image.data, "compute") else image + cvs = ds.Canvas( + plot_width=out_x, + plot_height=out_y, + x_range=(float(extent["x"][0]), float(extent["x"][1])), + y_range=(float(extent["y"][0]), float(extent["y"][1])), + ) + base.values = np.asarray(cvs.raster(src, downsample_method=downsample_method).values).astype(base.dtype, copy=False) + return base -def _datashader_aggregate_with_function( - reduction: _DsReduction | None, - cvs: Canvas, - spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame, - col_for_color: str | None, - element_type: Literal["points", "shapes"], +def _multiscale_to_spatial_image( + multiscale_image: DataTree, + dpi: float, + width: float, + height: float, + scale: str | None = None, + is_label: bool = False, ) -> DataArray: - """ - When shapes or points are colored by a continuous value during rendering with datashader. + """Extract the DataArray to be rendered from a multiscale image. - This function performs the aggregation using the user-specified reduction method. + From the `DataTree`, the scale that fits the given image size and dpi most is selected + and returned. In case the lowest resolution is still too high, a rasterization step is added. Parameters ---------- - reduction: String specifying the datashader reduction method to be used. - If None, "sum" is used as default. - cvs: Canvas object previously created with ds.Canvas() - spatial_element: geo or dask dataframe with the shapes or points to render - col_for_color: name of the column containing the values by which to color - element_type: tells us if this function is called from _render_shapes() or _render_points() - """ - if reduction is None: - reduction = "sum" - - try: - reduction_function = _DS_REDUCTION_FUNCS[reduction](column=col_for_color) - except KeyError as e: - raise ValueError( - f"Reduction '{reduction}' is not supported. Please use one of: {', '.join(_DS_REDUCTION_FUNCS.keys())}." 
- ) from e - - element_function_map = { - "points": cvs.points, - "shapes": cvs.polygons, - } - - try: - element_function = element_function_map[element_type] - except KeyError as e: - raise ValueError(f"Element type '{element_type}' is not supported. Use 'points' or 'shapes'.") from e - - if element_type == "points": - points_aggregate = element_function(spatial_element, "x", "y", agg=reduction_function) - if reduction == "any": - # replace False/True by nan/1 - points_aggregate = points_aggregate.astype(int) - points_aggregate = points_aggregate.where(points_aggregate > 0) - return points_aggregate - - # is shapes - return element_function(spatial_element, geometry="geometry", agg=reduction_function) - - -def _datshader_get_how_kw_for_spread( - reduction: _DsReduction | None, -) -> str: - # Get the best input for the how argument of ds.tf.spread(), needed for numerical values - reduction = reduction or "sum" - - reduction_to_how_map = { - "sum": "add", - "mean": "source", - "any": "source", - "count": "add", - "std": "source", - "var": "source", - "max": "max", - "min": "min", - } - - if reduction not in reduction_to_how_map: - raise ValueError( - f"Reduction {reduction} is not supported, please use one of the following: sum, mean, any, count" - ", std, var, max, min." 
- ) - - return reduction_to_how_map[reduction] - + multiscale_image + `DataTree` that should be rendered + dpi + dpi of the target image + width + width of the target image in inches + height + height of the target image in inches + scale + specific scale that the user chose, if None the heuristic is used + is_label + When True, the multiscale image contains labels which don't contain the `c` dimension -def _prepare_transformation( - element: DataArray | GeoDataFrame | dask.dataframe.core.DataFrame, - coordinate_system: str, - ax: Axes | None = None, -) -> tuple[ - matplotlib.transforms.Affine2D, - matplotlib.transforms.CompositeGenericTransform | None, -]: - trans = get_transformation(element, get_all=True)[coordinate_system] - affine_trans = trans.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y")) - trans = mtransforms.Affine2D(matrix=affine_trans) - trans_data = trans + ax.transData if ax is not None else None - - return trans, trans_data - - -def _apply_cmap_alpha_to_datashader_result( - result: Any, - agg: DataArray, - cmap: str | list[str] | Colormap, - span: list[float] | tuple[float, float] | None, -) -> Any: - """Apply the colormap's alpha channel to a datashader RGBA result. - - Datashader ignores the per-entry alpha channel of matplotlib colormaps, - so pixels that the cmap marks as transparent (alpha=0) are rendered - opaque. This function post-processes the shaded RGBA output to restore - the cmap's intended transparency. See :issue:`376`. + Returns + ------- + DataArray + To be rendered, extracted from the DataTree respecting the dpi and size of the target image. """ - if not isinstance(cmap, Colormap): - return result - - # Quick check: does this cmap have any transparent entries? 
- test_vals = np.linspace(0, 1, min(cmap.N, 256)) - cmap_alphas = cmap(test_vals)[:, 3] - if np.all(cmap_alphas >= 1.0): - return result - - # Get or ensure we have an (H, W, 4) uint8 array - if hasattr(result, "values"): - # datashader Image — uint32 packed, convert via to_numpy() - rgba = result.to_numpy().base - if rgba is None: - return result - else: - rgba = result - - if rgba.ndim != 3 or rgba.shape[2] != 4: - return result - - # Normalise aggregate values to [0, 1] using the same span datashader used - agg_vals = agg.values.astype(np.float64) - valid = np.isfinite(agg_vals) - if not valid.any(): - return result + scales = [leaf.name for leaf in multiscale_image.leaves] + x_dims = [multiscale_image[scale].dims["x"] for scale in scales] + y_dims = [multiscale_image[scale].dims["y"] for scale in scales] - if span is not None: - lo, hi = float(span[0]), float(span[1]) + if isinstance(scale, str): + if scale not in scales and scale != "full": + raise ValueError(f'Scale {scale} does not exist. 
Please select one of {scales} or set scale = "full"!') + optimal_scale = scale + if scale == "full": + # use scale with highest resolution + optimal_scale = scales[np.argmax(x_dims)] else: - lo = float(np.nanmin(agg_vals)) - hi = float(np.nanmax(agg_vals)) - - if hi <= lo or not np.isfinite(lo) or not np.isfinite(hi): - return result - - normed = np.clip((agg_vals - lo) / (hi - lo), 0.0, 1.0) - - # Look up cmap alpha for each pixel - desired_alpha = cmap(normed)[:, :, 3] - - # Zero out pixels where the cmap wants transparency - transparent = valid & (desired_alpha < 1.0) - if transparent.any(): - # Scale the existing alpha by the cmap's alpha - rgba[transparent, 3] = (rgba[transparent, 3].astype(np.float32) * desired_alpha[transparent]).astype(np.uint8) - - return result - - -def _datashader_map_aggregate_to_color( - agg: DataArray, - cmap: str | list[str] | ListedColormap, - color_key: list[str] | dict[str, str] | None = None, - min_alpha: float = 40, - span: None | list[float] = None, - clip: bool = True, - how: str = "linear", -) -> ds.tf.Image | np.ndarray[Any, np.dtype[np.uint8]]: - """ds.tf.shade() part, ensuring correct clipping behavior. - - If necessary (norm.clip=False), split shading in 3 parts and in the end, stack results. - This ensures the correct clipping behavior, because else datashader would always automatically clip. + # sort scales ascending by x resolution + order = np.argsort(x_dims) + scales = [scales[i] for i in order] + x_dims = [x_dims[i] for i in order] + y_dims = [y_dims[i] for i in order] - ``how`` controls the count-to-color mapping passed to :func:`datashader.transfer_functions.shade` - (``"linear"`` by default; ``"log"``/``"cbrt"``/``"eq_hist"`` compress dynamic range). The split-shade - branch used for ``norm.clip=False`` always uses ``"linear"`` since per-segment shading would otherwise - interact poorly with rank-based mappings. 
- """ - if not clip and isinstance(cmap, Colormap) and span is not None: - # in case we use datashader together with a Normalize object where clip=False - # why we need this is documented in https://github.com/scverse/spatialdata-plot/issues/372 - agg_in = agg.where((agg >= span[0]) & (agg <= span[1])) - img_in = ds.tf.shade( - agg_in, - cmap=cmap, - span=(span[0], span[1]), - how="linear", - color_key=color_key, - min_alpha=min_alpha, - ) + optimal_x = width * dpi + optimal_y = height * dpi - agg_under = agg.where(agg < span[0]) - img_under = ds.tf.shade( - agg_under, - cmap=[to_hex(cmap.get_under())[:7]], - min_alpha=min_alpha, - color_key=color_key, - ) + # Pick the lowest-resolution scale where both x and y are >= the + # target pixel count. Falls back to highest available resolution. + optimal_scale = scales[-1] + for i, (xd, yd) in enumerate(zip(x_dims, y_dims, strict=True)): + if xd >= optimal_x and yd >= optimal_y: + optimal_scale = scales[i] + break - agg_over = agg.where(agg > span[1]) - img_over = ds.tf.shade( - agg_over, - cmap=[to_hex(cmap.get_over())[:7]], - min_alpha=min_alpha, - color_key=color_key, - ) + # NOTE: problematic if there are cases with > 1 data variable + data_var_keys = list(multiscale_image[optimal_scale].data_vars) + image = multiscale_image[optimal_scale][data_var_keys[0]] - # stack the 3 arrays manually: go from under, through in to over and always overlay the values where alpha=0 - stack = img_under.to_numpy().base - if stack is None: - stack = img_in.to_numpy().base - else: - stack[stack[:, :, 3] == 0] = img_in.to_numpy().base[stack[:, :, 3] == 0] - img_over = img_over.to_numpy().base - if img_over is not None: - stack[stack[:, :, 3] == 0] = img_over[stack[:, :, 3] == 0] - - return _apply_cmap_alpha_to_datashader_result(stack, agg, cmap, span) - - result = ds.tf.shade( - agg, - cmap=cmap, - color_key=color_key, - min_alpha=min_alpha, - span=span, - how=how, - ) - return _apply_cmap_alpha_to_datashader_result(result, agg, cmap, 
span) + return Labels2DModel.parse(image) if is_label else Image2DModel.parse(image, c_coords=image.coords["c"].values) -def _hex_no_alpha(hex: str) -> str: +def _get_elements_to_be_rendered( + render_cmds: list[ + tuple[ + str, + ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams | GraphRenderParams, + ] + ], + cs_index: pd.DataFrame, + cs: str, +) -> list[str]: """ - Return a hex color string without an alpha component. + Get the names of the elements to be rendered in the plot. Parameters ---------- - hex : str - The input hex color string. Must be in one of the following formats: - - "#RRGGBB": a hex color without an alpha channel. - - "#RRGGBBAA": a hex color with an alpha channel that will be removed. + render_cmds + List of tuples containing the commands and their respective parameters. + cs_index + The cs_contents dataframe indexed by the "cs" column. + cs + The name of the coordinate system to query cs_index for. Returns ------- - str - The hex color string in "#RRGGBB" format. + List of names of the SpatialElements to be rendered in the plot. """ - if not isinstance(hex, str): - raise TypeError("Input must be a string") - if not hex.startswith("#"): - raise ValueError("Invalid hex color: must start with '#'") - - hex_digits = hex[1:] - length = len(hex_digits) - - if length == 6: - if not all(c in "0123456789abcdefABCDEF" for c in hex_digits): - raise ValueError("Invalid hex color: contains non-hex characters") - return hex # Already in #RRGGBB format. - - if length == 8: - if not all(c in "0123456789abcdefABCDEF" for c in hex_digits): - raise ValueError("Invalid hex color: contains non-hex characters") - # Return only the first 6 characters, stripping the alpha. 
- return "#" + hex_digits[:6] - - raise ValueError("Invalid hex color length: must be either '#RRGGBB' or '#RRGGBBAA'") - - -def _convert_shapes( - shapes: GeoDataFrame, - target_shape: str, - max_extent: float, - warn_above_extent_fraction: float = 0.5, -) -> GeoDataFrame: - """Convert shapes in a GeoDataFrame to the target_shape, using positional indexing.""" - if warn_above_extent_fraction < 0.0 or warn_above_extent_fraction > 1.0: - warn_above_extent_fraction = 0.5 - warn_shape_size = False - - # work on a copy with a clean positional index - shapes = shapes.reset_index(drop=True).copy() - - def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: - verts = [ - ( - center.x + radius * math.cos(math.radians(a)), - center.y + radius * math.sin(math.radians(a)), - ) - for a in range(30, 390, 60) - ] - return shapely.Polygon(verts), None - - def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: - verts = [ - ( - center.x + radius * math.cos(math.radians(a)), - center.y + radius * math.sin(math.radians(a)), - ) - for a in range(45, 360, 90) - ] - return shapely.Polygon(verts), None - - def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]: - return center, radius - - def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]: - coords = np.array(polygon.exterior.coords) - hull_pts = coords[ConvexHull(coords).vertices] - center = np.mean(hull_pts, axis=0) - radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1))) - nonlocal warn_shape_size - if 2 * radius > max_extent * warn_above_extent_fraction: - warn_shape_size = True - return shapely.Point(center), radius - - def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: - c, r = _polygon_to_circle(polygon) - return _circle_to_hexagon(c, r) - - def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: - c, r = 
_polygon_to_circle(polygon) - return _circle_to_square(c, r) - - def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]: - pts = [] - for poly in multipolygon.geoms: - pts.extend(poly.exterior.coords) - pts_array = np.array(pts) - hull_pts = pts_array[ConvexHull(pts_array).vertices] - center = np.mean(hull_pts, axis=0) - radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1))) - nonlocal warn_shape_size - if 2 * radius > max_extent * warn_above_extent_fraction: - warn_shape_size = True - return shapely.Point(center), radius - - def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: - c, r = _multipolygon_to_circle(multipolygon) - return _circle_to_hexagon(c, r) - - def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: - c, r = _multipolygon_to_circle(multipolygon) - return _circle_to_square(c, r) - - # choose conversion methods - conversion_methods: dict[str, Any] - if target_shape == "circle": - conversion_methods = { - "Point": _circle_to_circle, - "Polygon": _polygon_to_circle, - "MultiPolygon": _multipolygon_to_circle, - } - elif target_shape == "hex": - conversion_methods = { - "Point": _circle_to_hexagon, - "Polygon": _polygon_to_hexagon, - "MultiPolygon": _multipolygon_to_hexagon, - } - elif target_shape == "visium_hex": - # estimate hex radius from point spacing when possible - point_centers = [] - non_point_count = 0 - for geom in shapes.geometry: - if geom.geom_type == "Point": - point_centers.append((geom.x, geom.y)) - else: - non_point_count += 1 - if non_point_count > 0: - logger.warning("visium_hex supports Points best. 
Non-Point geometries will use regular hex conversion.") - if len(point_centers) >= 2: - centers = np.array(point_centers, dtype=float) - # pairwise min distance - dmin = np.inf - for i in range(len(centers)): - diffs = centers[i + 1 :] - centers[i] - if diffs.size: - d = np.min(np.linalg.norm(diffs, axis=1)) - dmin = min(dmin, d) - if not np.isfinite(dmin) or dmin <= 0: - # fallback - conversion_methods = { - "Point": _circle_to_hexagon, - "Polygon": _polygon_to_hexagon, - "MultiPolygon": _multipolygon_to_hexagon, - } - else: - hex_radius = dmin / math.sqrt(3.0) + elements_to_be_rendered: list[str] = [] - def _circle_to_visium_hex(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: - return _circle_to_hexagon(center, hex_radius) + cs_row = cs_index.loc[cs] if cs in cs_index.index else None - def _polygon_to_visium_hex(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: - return _polygon_to_hexagon(polygon) + for cmd, params in render_cmds: + if cmd == "render_graph": + # Graph doesn't have its own CS flag; include its element so + # _get_valid_cs keeps the coordinate system alive. 
+ elements_to_be_rendered.append(params.element) + else: + key = _RENDER_CMD_TO_CS_FLAG.get(cmd) + if key and cs_row is not None and cs_row[key]: + elements_to_be_rendered.append(params.element) - def _multipolygon_to_visium_hex(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: - return _multipolygon_to_hexagon(multipolygon) + return elements_to_be_rendered - conversion_methods = { - "Point": _circle_to_visium_hex, - "Polygon": _polygon_to_visium_hex, - "MultiPolygon": _multipolygon_to_visium_hex, - } - else: - conversion_methods = { - "Point": _circle_to_hexagon, - "Polygon": _polygon_to_hexagon, - "MultiPolygon": _multipolygon_to_hexagon, - } - else: - conversion_methods = { - "Point": _circle_to_square, - "Polygon": _polygon_to_square, - "MultiPolygon": _multipolygon_to_square, - } - - # ensure radius column exists if needed - if "radius" not in shapes.columns: - shapes["radius"] = np.nan - - # convert all geometries using positional indexing - for i in range(len(shapes)): - geom = shapes.geometry.iloc[i] - gtype = geom.geom_type - if gtype == "Point": - r = shapes["radius"].iloc[i] - r = float(r) if np.isfinite(r) else 0.0 - converted, radius = conversion_methods["Point"](geom, r) # type: ignore[arg-type] - elif gtype == "Polygon": - converted, radius = conversion_methods["Polygon"](geom) # type: ignore[arg-type] - elif gtype == "MultiPolygon": - converted, radius = conversion_methods["MultiPolygon"](geom) # type: ignore[arg-type] - else: - raise ValueError(f"Converting shape {gtype} to {target_shape} is not supported.") - shapes.at[i, "geometry"] = converted - if radius is not None: - shapes.at[i, "radius"] = radius - - if warn_shape_size: - logger.info( - f"At least one converted shape spans >= {warn_above_extent_fraction * 100:.0f}% of the " - "original total bound. Results may be suboptimal." 
- ) - return shapes +def _get_wanted_render_elements( + sdata: SpatialData, + sdata_wanted_elements: list[str], + params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, + cs: str, + element_type: Literal["images", "labels", "points", "shapes"], +) -> tuple[list[str], list[str], bool]: + wants_elements = True + if element_type in [ + "images", + "labels", + "points", + "shapes", + ]: # Prevents eval security risk + wanted_elements: list[str] = [params.element] + wanted_elements_on_cs = [ + element for element in wanted_elements if cs in set(get_transformation(sdata[element], get_all=True).keys()) + ] + sdata_wanted_elements.extend(wanted_elements_on_cs) + return sdata_wanted_elements, wanted_elements_on_cs, wants_elements -def _convert_alpha_to_datashader_range(alpha: float) -> float: - """Convert alpha from the range [0, 1] to the range [0, 255] used in datashader.""" - # prevent a value of 255, bc that led to fully colored test plots instead of just colored points/shapes - return min([254, alpha * 255]) + raise ValueError(f"Unknown element type {element_type}") # --- Per-cell measurements into the annotating table (centroid / area / equivalent diameter) --- @@ -4723,6 +1210,28 @@ def _resolve_measure_table(sdata: SpatialData, element_name: str, table_name: st return str(annotators[0]) +def set_zero_in_cmap_to_transparent(cmap: Colormap | str, steps: int | None = None) -> ListedColormap: + """ + Modify colormap so that 0s are transparent. + + Parameters + ---------- + cmap (Colormap | str): A matplotlib Colormap instance or a colormap name string. + steps (int): The number of steps in the colormap. + + Returns + ------- + ListedColormap: A new colormap instance with modified alpha values. 
+ """ + if isinstance(cmap, str): + cmap = plt.get_cmap(cmap) + + colors = cmap(np.arange(steps or cmap.N)) + colors[0, :] = [1.0, 1.0, 1.0, 0.0] + + return ListedColormap(colors) + + def measure_obs( sdata: SpatialData, element: str | None = None, diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 7c0a1286..4f9e0578 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -1106,8 +1106,8 @@ def test_datashader_canvas_preserves_resolution_under_bbox_query(): # units per canvas pixel) at offset must match the no-offset baseline. from spatialdata import bounding_box_query + from spatialdata_plot.pl._datashader import _datashader_canvas_from_dataframe from spatialdata_plot.pl.render_params import FigParams - from spatialdata_plot.pl.utils import _datashader_canvas_from_dataframe baseline_sdata = _make_offset_points_sdata(offset=(0.0, 0.0)) offset_sdata = _make_offset_points_sdata(offset=(10000.0, 18000.0)) @@ -1176,8 +1176,8 @@ def test_datashader_canvas_from_empty_dataframe_does_not_crash(): # crash with ``ValueError: cannot convert float NaN to integer`` when fed # an empty DataFrame (NaN min()/max() → int cast). The helper now returns # a zero-sized sentinel so callers can short-circuit cleanly. 
+ from spatialdata_plot.pl._datashader import _datashader_canvas_from_dataframe from spatialdata_plot.pl.render_params import FigParams - from spatialdata_plot.pl.utils import _datashader_canvas_from_dataframe empty_df = pd.DataFrame({"x": pd.Series(dtype=float), "y": pd.Series(dtype=float)}) fig = plt.figure(figsize=(6, 6), dpi=100) diff --git a/tests/pl/test_render_shapes.py b/tests/pl/test_render_shapes.py index c0fa36f5..3d7a9db8 100644 --- a/tests/pl/test_render_shapes.py +++ b/tests/pl/test_render_shapes.py @@ -442,16 +442,14 @@ def test_render_shapes_raises_for_missing_column_in_table(self, sdata_blobs_shap def test_plot_shapes_unannotated_by_table_render_with_na_color(self, sdata_blobs_shapes_annotated: SpatialData): # Regression for #710: blobs_polygons instance 0 has no row in the table (instance_id starts # at 1), so coloring by a table column must render it with na_color, not drop it. - sdata_blobs_shapes_annotated.pl.render_shapes( - "blobs_polygons", color="channel_0_sum", na_color="red" - ).pl.show() + sdata_blobs_shapes_annotated.pl.render_shapes("blobs_polygons", color="channel_0_sum", na_color="red").pl.show() - def test_plot_shapes_unannotated_by_table_hidden_with_na_color_none(self, sdata_blobs_shapes_annotated: SpatialData): + def test_plot_shapes_unannotated_by_table_hidden_with_na_color_none( + self, sdata_blobs_shapes_annotated: SpatialData + ): # Counterpart to the test above: na_color=None makes the unannotated polygon (instance 0) # transparent, opting back into hiding it. 
- sdata_blobs_shapes_annotated.pl.render_shapes( - "blobs_polygons", color="channel_0_sum", na_color=None - ).pl.show() + sdata_blobs_shapes_annotated.pl.render_shapes("blobs_polygons", color="channel_0_sum", na_color=None).pl.show() def test_plot_can_plot_shapes_after_spatial_query(self, sdata_blobs: SpatialData): # subset to only shapes, should be unnecessary after rasterizeation of multiscale images is included @@ -1226,7 +1224,7 @@ def test_groups_filtering_preserves_transformation(sdata_blobs: SpatialData): re-assign to sdata_filt -> GeoDataFrame re-wrap — then asserts that ``_prepare_transformation`` can still retrieve the correct transformation. """ - from spatialdata_plot.pl.utils import _prepare_transformation + from spatialdata_plot.pl._datashader import _prepare_transformation scale_factor = 2.5 cs = "not_global" @@ -1753,9 +1751,7 @@ def test_render_shapes_as_points_applies_non_identity_transform(sdata_blobs: Spa set_transformation(sdata_blobs["blobs_circles"], Scale([3.0, 5.0], axes=("x", "y")), "scaled") fig, ax = plt.subplots() - sdata_blobs.pl.render_shapes("blobs_circles", as_points=True, size=50).pl.show( - ax=ax, coordinate_systems="scaled" - ) + sdata_blobs.pl.render_shapes("blobs_circles", as_points=True, size=50).pl.show(ax=ax, coordinate_systems="scaled") coll = ax.collections[0] dots = coll.get_offset_transform().transform(np.asarray(coll.get_offsets())) cs = sd.get_centroids(sdata_blobs["blobs_circles"], coordinate_system="scaled").compute()[["x", "y"]].to_numpy() diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index 33c07739..a5ec0306 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -13,13 +13,13 @@ import spatialdata_plot from spatialdata_plot.pl import measure_obs -from spatialdata_plot.pl.render_params import Color, ColorLike -from spatialdata_plot.pl.utils import ( +from spatialdata_plot.pl._datashader import ( _apply_cmap_alpha_to_datashader_result, _datashader_map_aggregate_to_color, - _set_outline, - 
set_zero_in_cmap_to_transparent, ) +from spatialdata_plot.pl.render_params import Color, ColorLike +from spatialdata_plot.pl._color import _set_outline +from spatialdata_plot.pl.utils import set_zero_in_cmap_to_transparent from tests.conftest import DPI, PlotTester, PlotTesterMeta sc.pl.set_rcParams_defaults() @@ -162,7 +162,7 @@ def test_plot_transparent_cmap_shapes_clip_false(self, sdata_blobs: SpatialData) def test_is_color_like(color_result: tuple[ColorLike, bool]): color, result = color_result - assert spatialdata_plot.pl.utils._is_color_like(color) == result + assert spatialdata_plot.pl._color._is_color_like(color) == result @pytest.mark.parametrize( @@ -505,9 +505,7 @@ def _add_shapes_table(sdata: SpatialData, element: str = "blobs_polygons", name: adata = AnnData(np.zeros((len(gdf), 1), dtype=np.float32)) adata.obs["instance_id"] = list(gdf.index) adata.obs["region"] = element - sdata[name] = TableModel.parse( - adata, region_key="region", instance_key="instance_id", region=element - ) + sdata[name] = TableModel.parse(adata, region_key="region", instance_key="instance_id", region=element) return sdata @@ -547,9 +545,7 @@ def test_writes_centroid_area_diameter_for_labels(self, sdata_blobs: SpatialData # area is the pixel count (positive integers); diameter = 2*sqrt(area/pi) area = table.obs["area"].to_numpy() assert (area > 0).all() - np.testing.assert_allclose( - table.obs["equivalent_diameter"].to_numpy(), 2.0 * np.sqrt(area / np.pi), rtol=1e-12 - ) + np.testing.assert_allclose(table.obs["equivalent_diameter"].to_numpy(), 2.0 * np.sqrt(area / np.pi), rtol=1e-12) def test_writes_for_shapes(self, sdata_blobs: SpatialData) -> None: _add_shapes_table(sdata_blobs, "blobs_polygons") @@ -635,7 +631,7 @@ def test_existing_nonnumeric_column_raises_before_any_write(self, sdata_blobs: S def test_sparse_high_label_ids(self, sdata_blobs: SpatialData) -> None: # #5: sparse/high label ids (max id >> n_labels) are measured correctly (dense relabelling). 
arr = np.asarray(sdata_blobs["blobs_labels"].data) - hi = (arr.astype(np.int64) * 1000) # ids become 1000, 2000, ... ; max id is huge, few labels + hi = arr.astype(np.int64) * 1000 # ids become 1000, 2000, ... ; max id is huge, few labels measure_obs(sd_hi := _labels_sdata(hi), "lab", table_name="t") measure_obs(sd_lo := _labels_sdata(arr.astype(np.int64)), "lab", table_name="t") # relabelling values does not move pixels -> identical centroid set @@ -762,7 +758,7 @@ def _annotated_shapes(n: int = 30, *, shuffle: bool = False, drop: int = 0, seed def test_matches_get_values(self, key: str, origin: str): from spatialdata import get_values - from spatialdata_plot.pl.utils import _extract_color_column + from spatialdata_plot.pl._color import _extract_color_column sdata = self._annotated_shapes() old = pd.Series(get_values(value_key=key, sdata=sdata, element_name="shapes", table_name="table")[key]) @@ -777,7 +773,7 @@ def test_matches_get_values(self, key: str, origin: str): def test_shuffled_table_order_realigns(self): from spatialdata import get_values - from spatialdata_plot.pl.utils import _extract_color_column + from spatialdata_plot.pl._color import _extract_color_column sdata = self._annotated_shapes(shuffle=True) old = pd.Series(get_values(value_key="g0", sdata=sdata, element_name="shapes", table_name="table")["g0"]) @@ -785,7 +781,7 @@ def test_shuffled_table_order_realigns(self): np.testing.assert_allclose(old.to_numpy(float), new.to_numpy(float)) def test_missing_instances_become_nan(self): - from spatialdata_plot.pl.utils import _extract_color_column + from spatialdata_plot.pl._color import _extract_color_column sdata = self._annotated_shapes(drop=5) # 5 shapes have no annotating table row new = _extract_color_column(sdata["table"], "g0", origin="var", element=sdata["shapes"], element_name="shapes")