From 26f12e6b57fbd1724e4ec03bc2956c6caa048368 Mon Sep 17 00:00:00 2001
From: anon
Date: Thu, 28 May 2026 22:48:36 +0200
Subject: [PATCH 1/8] Stain normalization: Macenko/Vahadane decomposition + estimator

Fills the decomposition branches the Reinhard PR left as dispatch stubs;
no public signature from that PR changes.

- _validation.py: StainFittingError (carries image_key for cohort
  fitting), validate_stain_matrix, reorder_to_canonical,
  complement_third_column.
- _background.py: estimate_background_intensity (per-channel
  high-percentile white point; auto-used by the decomposition fits when
  not supplied).
- _mask.py: absorbance_foreground_mask (OD-space tissue mask) alongside
  the luminosity variant.
- _decomposition.py: MacenkoParams/VahadaneParams, Macenko (SVD angular
  extremes) and Vahadane (sklearn sparse NMF) stain-matrix fits,
  fit/apply. Apply maps source->reference absorbance via one (3,3)
  operator and stays lazy; the source matrix is fit on a coarse level,
  never the full image.
- _reference.py: max_concentrations field (decomposition only) so the
  target concentration scale travels with the reference.
- _normalize.py: method-keyed param resolver, filled fit/apply branches,
  and decompose_stains -> (hematoxylin, eosin, residual) concentration
  channels.

Decomposition correctness is gated by synthetic-recovery tests (planted
matrix recovered within an angle tolerance); macenko/vahadane also fit,
apply and decompose end-to-end on the Visium H&E image.

Mandate a detect_tissue mask for the fits (contract change).

The module's own absorbance/luminosity threshold masks degenerate on real
data - on the Visium H&E the absorbance mask kept 99.1% of pixels
(including the dark fiducial ring), which fed the fit garbage and was the
real cause of Macenko's >45deg validation failure on that image. Squidpy
already ships a tested detect_tissue; the stain fits now consume it
instead of thresholding their own mask.
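For reference, the synthetic-recovery check mentioned above can be sketched with plain NumPy. This is a minimal illustration, not the test shipped in this patch: the function names, the planted matrix values, the noise-free uniform concentrations and the 2 degree tolerance are all assumptions chosen for the example. The fit follows the same SVD angular-extreme idea as the Macenko helper in _decomposition.py.

```python
import numpy as np


def unit_columns(m):
    """Scale each column to unit Euclidean norm."""
    return m / np.linalg.norm(m, axis=0, keepdims=True)


def macenko_stain_matrix(od, alpha=1.0):
    """Recover a (3, 2) stain matrix from (N, 3) optical densities."""
    # The top two right singular vectors span the stain plane.
    _, _, vh = np.linalg.svd(od, full_matrices=False)
    plane = vh[:2].T  # (3, 2)
    # SVD signs are arbitrary: orient the basis toward the data so the
    # projected angles do not straddle the +-180 deg branch cut.
    signs = np.sign(od.mean(axis=0) @ plane)
    signs[signs == 0] = 1.0
    plane = plane * signs
    proj = od @ plane  # (N, 2)
    phi = np.arctan2(proj[:, 1], proj[:, 0])
    # Robust angular extremes: the alpha / (100 - alpha) percentiles.
    lo, hi = np.percentile(phi, [alpha, 100.0 - alpha])
    extremes = np.stack(
        [
            plane @ np.array([np.cos(lo), np.sin(lo)]),
            plane @ np.array([np.cos(hi), np.sin(hi)]),
        ],
        axis=1,
    )
    return unit_columns(extremes)


def angle_deg(u, v):
    """Unsigned angle between two vectors, in degrees."""
    cos = abs(u @ v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return float(np.degrees(np.arccos(np.clip(cos, 0.0, 1.0))))


# Plant a two-stain matrix (illustrative values only) and synthesise
# noise-free optical densities from random non-negative concentrations.
rng = np.random.default_rng(0)
planted = unit_columns(np.array([[0.65, 0.07], [0.70, 0.99], [0.29, 0.11]]))
concentrations = rng.uniform(0.0, 1.0, size=(20000, 2))
od = concentrations @ planted.T  # (N, 3)

recovered = macenko_stain_matrix(od, alpha=1.0)

# Column order is not fixed by the fit, so match each planted column to
# its closest recovered column before measuring the error.
worst = max(
    min(angle_deg(planted[:, i], recovered[:, j]) for j in range(2))
    for i in range(2)
)
print(f"worst planted-vs-recovered angle: {worst:.2f} deg")
```

The percentile extremes sit slightly inside the true stain cone, so a small non-zero angle is expected even without noise; that is why such a check compares against a tolerance rather than testing for equality.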
- fit_stain_reference / apply_stain_normalization / decompose_stains gain
  a tissue_mask_key argument. A tissue mask is required: the sdata-level
  functions resolve tissue_mask_key (or f"{image_key}_tissue") via
  resolve_tissue_mask(auto_create=False) and raise an actionable error
  asking the caller to run detect_tissue if none exists. This changes the
  Reinhard contract (was: auto luminosity mask).
- The DataArray-layer primitives (_tissue_od/fit/apply,
  fit/apply_reinhard) take an optional tissue_mask and fall back to the
  threshold mask when it is None, so the synthetic-image algorithm tests
  are unchanged.
- apply_reinhard now reduces its source statistics on a coarse fit_rgb
  (like apply_decomposition), so the mask and stats stay small on whole
  slides.
- Reuse: _choose_label_scale_for_image moves from _make_tiles to
  experimental/im/_utils.py; resolve_tissue_mask gains auto_create.

With a real detect_tissue mask (33% of the H&E, fiducial ring dropped)
Macenko fits cleanly and agrees with Vahadane - the earlier "Macenko
fails on this image" was a masking artifact - so the H&E smoke test
exercises both methods again.

Keep the background white: fixed I_0 default + output composite.

Two parity fixes so normalization no longer tints non-tissue/white pixels
(matching HistomicsTK):

- Default I_0 (background_intensity) is now a fixed full-white
  [255, 255, 255] (DEFAULT_BACKGROUND_INTENSITY), not an image-derived
  high percentile. The percentile estimate returned ~130 on a dim scan,
  so true-white pixels got negative absorbance and could only reconstruct
  as far as 130 (grey, then tinted). estimate_background_intensity stays
  as an opt-in helper for slides with a known non-white background.
- apply_stain_normalization gains preserve_background (default True): the
  global colour map would recolour every non-I_0 pixel, so non-tissue
  pixels are composited back from the source verbatim (HistomicsTK's
  mask_out). The composite stays lazy via an output-resolution tissue
  mask. Set preserve_background=False for full-frame normalization.

Verified: a colour-cast query slide normalized through the public API
keeps its background byte-identical to the input ([160.5,101.1,141.1] vs
[161.9,101.4,141.2]) while the tissue is recoloured.

decompose_stains: store each stain as its own image; optional residual.

Rather than one 3-channel image, decompose_stains now writes a separate
single-channel image per stain (image_key_added as a prefix ->
f"{prefix}_hematoxylin", f"{prefix}_eosin", f"{prefix}_residual") and
returns a dict of named (y, x) maps when not writing. Setting
include_residual=False (default True) drops the residual - a
decomposition-quality diagnostic (absorbance not explained by H or E:
extra chromogen, artifacts, or a poor fit), not a biological stain.
Provenance (method, stain matrix, white point) lives on the
StainReference, not on the element (custom element attrs don't survive
the zarr round-trip).

Co-Authored-By: Claude Opus 4.8
---
 docs/api.md                                  |   4 +
 src/squidpy/experimental/im/__init__.py      |   8 +
 src/squidpy/experimental/im/_make_tiles.py   |  22 +-
 .../experimental/im/_stain/__init__.py       |  30 +-
 .../experimental/im/_stain/_background.py    |  67 ++++
 .../experimental/im/_stain/_decomposition.py | 277 +++++++++++++++++
 src/squidpy/experimental/im/_stain/_mask.py  |  57 +++-
 .../experimental/im/_stain/_normalize.py     | 292 +++++++++++++++---
 .../experimental/im/_stain/_reference.py     |  37 ++-
 .../experimental/im/_stain/_reinhard.py      |  48 ++-
 .../experimental/im/_stain/_validation.py    | 124 ++++++++
 src/squidpy/experimental/im/_utils.py        |  39 ++-
 tests/experimental/test_stain_background.py  |  39 ++
 .../test_stain_decompose_public.py           | 149 +++++++++
 .../experimental/test_stain_decomposition.py | 125 ++++++++
 tests/experimental/test_stain_mask.py        |  27 +-
 tests/experimental/test_stain_normalize.py   |  87 +++++-
 tests/experimental/test_stain_reference.py   |  12 +
 tests/experimental/test_stain_validation.py  |  77 +++++
 19 files changed, 1425 insertions(+), 96 deletions(-)
create mode 100644 src/squidpy/experimental/im/_stain/_background.py create mode 100644 src/squidpy/experimental/im/_stain/_decomposition.py create mode 100644 src/squidpy/experimental/im/_stain/_validation.py create mode 100644 tests/experimental/test_stain_background.py create mode 100644 tests/experimental/test_stain_decompose_public.py create mode 100644 tests/experimental/test_stain_decomposition.py create mode 100644 tests/experimental/test_stain_validation.py diff --git a/docs/api.md b/docs/api.md index 1bf4fdf92..a2df60a37 100644 --- a/docs/api.md +++ b/docs/api.md @@ -152,6 +152,10 @@ See the {doc}`extensibility guide ` for how to implement a custo experimental.pl.tiling_qc experimental.im.fit_stain_reference experimental.im.apply_stain_normalization + experimental.im.decompose_stains + experimental.im.estimate_background_intensity experimental.im.StainReference experimental.im.ReinhardParams + experimental.im.MacenkoParams + experimental.im.VahadaneParams ``` diff --git a/src/squidpy/experimental/im/__init__.py b/src/squidpy/experimental/im/__init__.py index 1a661b53a..12bbb2da5 100644 --- a/src/squidpy/experimental/im/__init__.py +++ b/src/squidpy/experimental/im/__init__.py @@ -10,21 +10,29 @@ from ._qc_image import qc_image from ._qc_metrics import QCMetric from ._stain import ( + MacenkoParams, ReinhardParams, StainReference, + VahadaneParams, apply_stain_normalization, + decompose_stains, + estimate_background_intensity, fit_stain_reference, ) __all__ = [ "BackgroundDetectionParams", "FelzenszwalbParams", + "MacenkoParams", "QCMetric", "ReinhardParams", "StainReference", + "VahadaneParams", "WekaParams", "apply_stain_normalization", + "decompose_stains", "detect_tissue", + "estimate_background_intensity", "fit_stain_reference", "make_tiles", "make_tiles_from_spots", diff --git a/src/squidpy/experimental/im/_make_tiles.py b/src/squidpy/experimental/im/_make_tiles.py index 6323c48ad..87f93d33b 100644 --- a/src/squidpy/experimental/im/_make_tiles.py +++ 
b/src/squidpy/experimental/im/_make_tiles.py @@ -8,14 +8,14 @@ from dask.base import is_dask_collection from shapely.geometry import Polygon from spatialdata._logging import logger -from spatialdata.models import Labels2DModel, ShapesModel +from spatialdata.models import ShapesModel from spatialdata.models._utils import SpatialElement from spatialdata.transformations import get_transformation, set_transformation -from squidpy._utils import _yx_from_shape from squidpy._validators import assert_in_range, assert_key_in_sdata, assert_positive from squidpy.experimental.im._utils import ( TileGrid, + _choose_label_scale_for_image, get_element_data, get_mask_materialized, save_tile_grid_to_shapes, @@ -99,24 +99,6 @@ def _get_largest_scale_dimensions( return int(img_da.shape[-2]), int(img_da.shape[-1]) -def _choose_label_scale_for_image(label_node: Labels2DModel, target_hw: tuple[int, int]) -> str: - """Pick the label scale closest to the target image height/width.""" - if not hasattr(label_node, "keys"): - return "scale0" # single-scale labels default to their only scale - target_h, target_w = target_hw - best = None - best_diff = float("inf") - for k in label_node.keys(): - y, x = _yx_from_shape(label_node[k].image.shape) - diff = abs(y - target_h) + abs(x - target_w) - if diff == 0: - return k - if diff < best_diff: - best_diff = diff - best = k - return best or "scale0" - - def _save_tiles_to_shapes( sdata: sd.SpatialData, tg: TileGrid, diff --git a/src/squidpy/experimental/im/_stain/__init__.py b/src/squidpy/experimental/im/_stain/__init__.py index 86024e4f8..bc608bc0a 100644 --- a/src/squidpy/experimental/im/_stain/__init__.py +++ b/src/squidpy/experimental/im/_stain/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from squidpy.experimental.im._stain._background import estimate_background_intensity from squidpy.experimental.im._stain._constants import ( DEFAULT_LUMINOSITY_THRESHOLD, RUDERMAN_LAB_TO_LMS, @@ -15,9 +16,19 @@ rgb_to_sda, sda_to_rgb, ) 
-from squidpy.experimental.im._stain._mask import luminosity_foreground_mask +from squidpy.experimental.im._stain._decomposition import ( + MacenkoParams, + VahadaneParams, + apply_decomposition, + fit_decomposition, +) +from squidpy.experimental.im._stain._mask import ( + absorbance_foreground_mask, + luminosity_foreground_mask, +) from squidpy.experimental.im._stain._normalize import ( apply_stain_normalization, + decompose_stains, fit_stain_reference, ) from squidpy.experimental.im._stain._reference import StainMethod, StainReference @@ -26,6 +37,12 @@ apply_reinhard, fit_reinhard, ) +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + complement_third_column, + reorder_to_canonical, + validate_stain_matrix, +) __all__ = [ "DEFAULT_LUMINOSITY_THRESHOLD", @@ -35,16 +52,27 @@ "RUDERMAN_RGB_TO_LMS", "RUIFROK_HE", "SDA_SCALE", + "MacenkoParams", "ReinhardParams", + "StainFittingError", "StainMethod", "StainReference", + "VahadaneParams", + "absorbance_foreground_mask", + "apply_decomposition", "apply_reinhard", "apply_stain_normalization", + "complement_third_column", + "decompose_stains", + "estimate_background_intensity", + "fit_decomposition", "fit_reinhard", "fit_stain_reference", "lab_ruderman_to_rgb", "luminosity_foreground_mask", + "reorder_to_canonical", "rgb_to_lab_ruderman", "rgb_to_sda", "sda_to_rgb", + "validate_stain_matrix", ] diff --git a/src/squidpy/experimental/im/_stain/_background.py b/src/squidpy/experimental/im/_stain/_background.py new file mode 100644 index 000000000..23ff47e67 --- /dev/null +++ b/src/squidpy/experimental/im/_stain/_background.py @@ -0,0 +1,67 @@ +"""Background (white-point) intensity estimation for absorbance methods. + +The decomposition methods convert RGB to absorbance against a per-channel +white point ``I_0``. The default is a fixed full white (255); this module also +offers an opt-in estimate from the brightest (unstained) pixels of the slide. 
+""" + +from __future__ import annotations + +import numpy as np +import xarray as xr + +from squidpy.experimental.im._stain._conversion import _check_channel_dim +from squidpy.experimental.im._stain._validation import StainFittingError + +#: Default per-channel white point ``I_0`` for the absorbance methods. A fixed +#: full-white reference (8-bit), matching HistomicsTK (255/256) and the Macenko +#: literature (240). The absorbance origin must be at least as bright as the +#: slide background, otherwise unstained pixels get a non-zero absorbance and +#: cannot round-trip back to white. Estimate from the image (see +#: ``estimate_background_intensity``) only when the slide has a genuinely +#: non-white background you want to anchor to. +DEFAULT_BACKGROUND_INTENSITY: np.ndarray = np.array([255.0, 255.0, 255.0]) + + +def estimate_background_intensity(rgb: xr.DataArray, *, percentile: float = 99.0) -> np.ndarray: + """Estimate the per-channel white point from the brightest pixels. + + Parameters + ---------- + rgb + Image with a ``"c"`` dimension of length 3. Numpy- or dask-backed. + percentile + Per-channel intensity percentile to take as the white point. The + default (99) picks near-saturated background while ignoring the few + truly-saturated outlier pixels. + + Returns + ------- + Shape-``(3,)`` float64 white point, suitable as ``background_intensity`` + for :func:`~squidpy.experimental.im._stain._conversion.rgb_to_sda`. + + Notes + ----- + The exact percentile is computed eagerly (the input is materialised), so + the result is identical for numpy- and dask-backed inputs and independent + of chunking - important for reproducible references across a cohort. Pass + a coarse pyramid level for whole-slide images. + + Raises + ------ + StainFittingError + If the estimate is not strictly positive in every channel (e.g. a + blank/black image with no bright background). 
+ """ + if not 0.0 < percentile <= 100.0: + raise ValueError(f"`percentile` must be in (0, 100], got {percentile}.") + _check_channel_dim(rgb) + flat = np.asarray(rgb.transpose("c", "y", "x").data, dtype=np.float64).reshape(3, -1) + bg = np.percentile(flat, percentile, axis=1) + + if np.any(bg <= 0): + raise StainFittingError( + "estimated background intensity is non-positive; the image may be blank or all-tissue. " + "Pass an explicit `background_intensity` if this is expected." + ) + return bg diff --git a/src/squidpy/experimental/im/_stain/_decomposition.py b/src/squidpy/experimental/im/_stain/_decomposition.py new file mode 100644 index 000000000..072f9c871 --- /dev/null +++ b/src/squidpy/experimental/im/_stain/_decomposition.py @@ -0,0 +1,277 @@ +"""Macenko and Vahadane stain decomposition (fit + apply). + +Pure DataArray/numpy layer: no ``sdata``, no public export. The stain-matrix +fits run on tissue pixels (a bounded reduction at the chosen scale); the apply +transform is a single per-pixel matmul and stays lazy. 
+""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, fields +from typing import Any + +import numpy as np +import xarray as xr + +from squidpy.experimental.im._stain._conversion import ( + _apply_along_channel, + _check_channel_dim, + _working_dtype, + rgb_to_sda, + sda_to_rgb, +) +from squidpy.experimental.im._stain._mask import as_spatial_mask, foreground_mask_from_sda +from squidpy.experimental.im._stain._reference import StainMethod, StainReference +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + _unit_columns, + complement_third_column, + reorder_to_canonical, + validate_stain_matrix, +) + +_MAXC_PERCENTILE = 99.0 +_MAXC_FLOOR = 1e-6 + + +@dataclass(slots=True, frozen=True) +class MacenkoParams: + """Tuning knobs for Macenko stain-matrix fitting.""" + + alpha: float = 1.0 + """Angular percentile (deg) for the two stain directions; the extremes are taken at ``alpha`` / ``100 - alpha``.""" + + beta: float = 0.15 + """Mean-absorbance cutoff selecting tissue pixels (optical-density space).""" + + def __post_init__(self) -> None: + object.__setattr__(self, "alpha", float(self.alpha)) + object.__setattr__(self, "beta", float(self.beta)) + if not 0.0 < self.alpha < 50.0: + raise ValueError(f"`alpha` must be in (0, 50), got {self.alpha}.") + if self.beta < 0.0: + raise ValueError(f"`beta` must be >= 0, got {self.beta}.") + + +@dataclass(slots=True, frozen=True) +class VahadaneParams: + """Tuning knobs for Vahadane (sparse-NMF) stain-matrix fitting.""" + + beta: float = 0.15 + """Mean-absorbance cutoff selecting tissue pixels (optical-density space).""" + + lambda1: float = 0.1 + """L1 sparsity regularisation on the concentration factor of the NMF.""" + + n_iter: int = 200 + """Maximum NMF iterations.""" + + random_state: int | None = 0 + """Seed for NMF initialisation tie-breaking; fixed for reproducible fits.""" + + def __post_init__(self) -> None: + 
object.__setattr__(self, "beta", float(self.beta)) + object.__setattr__(self, "lambda1", float(self.lambda1)) + object.__setattr__(self, "n_iter", int(self.n_iter)) + if self.beta < 0.0: + raise ValueError(f"`beta` must be >= 0, got {self.beta}.") + if self.lambda1 < 0.0: + raise ValueError(f"`lambda1` must be >= 0, got {self.lambda1}.") + if self.n_iter < 1: + raise ValueError(f"`n_iter` must be >= 1, got {self.n_iter}.") + + +_MACENKO_DEFAULTS = MacenkoParams() +_VAHADANE_DEFAULTS = VahadaneParams() +_MACENKO_FIELDS = frozenset(f.name for f in fields(MacenkoParams)) +_VAHADANE_FIELDS = frozenset(f.name for f in fields(VahadaneParams)) + + +def _resolve_params(params: Any, cls: type, defaults: Any, valid_fields: frozenset[str]) -> Any: + if params is None: + return defaults + if isinstance(params, cls): + return params + if isinstance(params, Mapping): + unknown = set(params) - valid_fields + if unknown: + raise ValueError( + f"Unknown `method_params` field(s): {sorted(unknown)}; expected from {sorted(valid_fields)}." + ) + return cls(**params) + raise TypeError(f"`method_params` must be {cls.__name__}, Mapping, or None; got {type(params).__name__}.") + + +def _resolve_macenko_params(params: MacenkoParams | Mapping[str, Any] | None) -> MacenkoParams: + return _resolve_params(params, MacenkoParams, _MACENKO_DEFAULTS, _MACENKO_FIELDS) + + +def _resolve_vahadane_params(params: VahadaneParams | Mapping[str, Any] | None) -> VahadaneParams: + return _resolve_params(params, VahadaneParams, _VAHADANE_DEFAULTS, _VAHADANE_FIELDS) + + +def _tissue_od( + image_rgb: xr.DataArray, + background_intensity: np.ndarray, + beta: float, + *, + tissue_mask: np.ndarray | None = None, + image_key: str | None, +) -> np.ndarray: + """Flatten tissue pixels to an ``(N, 3)`` optical-density matrix. + + Reduces over the chosen scale only (bounded); the stain fits need the + tissue pixels resident for SVD/NMF, so this is the one materialising step. 
+ When ``tissue_mask`` (a ``(y, x)`` boolean aligned to ``image_rgb``) is + given it selects the tissue pixels; otherwise the absorbance threshold + ``beta`` is used. + """ + sda = rgb_to_sda(image_rgb, background_intensity) + mask = as_spatial_mask(tissue_mask, sda) if tissue_mask is not None else foreground_mask_from_sda(sda, beta) + od = np.asarray(sda.where(mask).transpose("c", "y", "x").data.reshape(3, -1)).T + od = od[np.all(np.isfinite(od), axis=1)] + if od.shape[0] == 0: + raise StainFittingError("no tissue pixels for stain fitting; the mask is empty.", image_key=image_key) + # Keep signed OD: pixels brighter than the estimated background carry + # negative absorbance that Macenko's SVD legitimately uses. Only Vahadane's + # NMF requires non-negativity, and clips locally. + return od + + +def _macenko_stain_matrix(od: np.ndarray, alpha: float) -> np.ndarray: + """Recover a ``(3, 2)`` H/E matrix via Macenko's angular-extreme method.""" + # right singular vectors of OD = principal absorbance directions through 0 + _, _, vh = np.linalg.svd(od, full_matrices=False) + plane = vh[:2].T # (3, 2) + # SVD sign is arbitrary; orient the basis into the data so the projected + # angles cluster around 0 instead of straddling the atan2 branch cut at + # +-180 deg (which would collapse the angular percentiles). 
+ signs = np.sign(od.mean(axis=0) @ plane) + signs[signs == 0] = 1.0 + plane = plane * signs + proj = od @ plane # (N, 2) + phi = np.arctan2(proj[:, 1], proj[:, 0]) + lo, hi = np.percentile(phi, [alpha, 100.0 - alpha]) + extremes = np.stack( + [plane @ np.array([np.cos(lo), np.sin(lo)]), plane @ np.array([np.cos(hi), np.sin(hi)])], + axis=1, + ) + return _unit_columns(extremes) + + +def _vahadane_stain_matrix(od: np.ndarray, params: VahadaneParams) -> np.ndarray: + """Recover a ``(3, 2)`` H/E matrix via sparse NMF (Vahadane).""" + from sklearn.decomposition import NMF + + nmf = NMF( + n_components=2, + init="nndsvda", + random_state=params.random_state, + alpha_W=params.lambda1, + l1_ratio=1.0, + max_iter=params.n_iter, + ) + nmf.fit(np.clip(od, 0.0, None)) # NMF requires non-negative absorbance + stains = nmf.components_.T # (3, 2) + if np.any(np.linalg.norm(stains, axis=0) < 1e-8): + raise StainFittingError("Vahadane NMF produced a zero-norm stain vector.") + return _unit_columns(stains) + + +def _stain_matrix(od: np.ndarray, method: StainMethod, params: Any, *, image_key: str | None) -> np.ndarray: + """Fit, canonicalise, complete and validate a ``(3, 3)`` stain matrix.""" + raw = _macenko_stain_matrix(od, params.alpha) if method == "macenko" else _vahadane_stain_matrix(od, params) + matrix = complement_third_column(reorder_to_canonical(raw)) + validate_stain_matrix(matrix, image_key=image_key) + return matrix + + +def _concentrations(od: np.ndarray, stain_matrix: np.ndarray) -> np.ndarray: + """Per-pixel stain concentrations ``(N, 3)`` from optical density.""" + return od @ np.linalg.pinv(stain_matrix).T + + +def _max_concentrations(concentrations: np.ndarray) -> np.ndarray: + """Robust per-stain (H, E) maximum concentrations ``(2,)`` from an ``(N, 3)`` array.""" + return np.maximum(np.percentile(concentrations[:, :2], _MAXC_PERCENTILE, axis=0), _MAXC_FLOOR) + + +def fit_decomposition( + image_rgb: xr.DataArray, + method: StainMethod, + params: Any, + 
background_intensity: np.ndarray, + *, + tissue_mask: np.ndarray | None = None, + image_key: str | None = None, +) -> StainReference: + """Fit a decomposition :class:`StainReference` (stain matrix + max concentrations).""" + od = _tissue_od(image_rgb, background_intensity, params.beta, tissue_mask=tissue_mask, image_key=image_key) + matrix = _stain_matrix(od, method, params, image_key=image_key) + return StainReference( + method=method, + stain_matrix=matrix, + background_intensity=np.asarray(background_intensity, dtype=np.float64), + max_concentrations=_max_concentrations(_concentrations(od, matrix)), + ) + + +def _matmul_kernel(x: np.ndarray, *, matrix: np.ndarray, dtype: np.dtype) -> np.ndarray: + return (x.astype(dtype, copy=False) @ matrix.T).astype(dtype, copy=False) + + +def apply_decomposition( + image_rgb: xr.DataArray, + reference: StainReference, + params: Any, + *, + fit_rgb: xr.DataArray | None = None, + tissue_mask: np.ndarray | None = None, +) -> xr.DataArray: + """Normalize a source image to a decomposition reference. + + Fits the *source's* own stain matrix and concentration scale, then maps + source absorbance to reference absorbance via a single ``(3, 3)`` linear + operator so the per-pixel transform stays lazy. + + The source matrix is a colour property, so it is fit on ``fit_rgb`` (a + coarse level) when given, while ``image_rgb`` (which may be full + resolution) is only ever touched by the lazy operator - never + materialised to fit a matrix. 
+ """ + _check_channel_dim(image_rgb) + if reference.max_concentrations is None: + raise ValueError("reference is missing max_concentrations; refit it with fit_stain_reference.") + bg = reference.background_intensity + + od_src = _tissue_od( + fit_rgb if fit_rgb is not None else image_rgb, bg, params.beta, tissue_mask=tissue_mask, image_key=None + ) + w_src = _stain_matrix(od_src, reference.method, params, image_key=None) + pinv_src = np.linalg.pinv(w_src) # reused for the source concentrations and the operator + maxc_src = _max_concentrations(od_src @ pinv_src.T) + + scale = np.ones(3) + scale[:2] = reference.max_concentrations / maxc_src + operator = reference.stain_matrix @ np.diag(scale) @ pinv_src + + sda = rgb_to_sda(image_rgb, bg) + dtype = _working_dtype(sda) + sda_out = _apply_along_channel(sda, _matmul_kernel, out_dtype=dtype, matrix=operator.astype(dtype), dtype=dtype) + return sda_to_rgb(sda_out, bg) + + +def decompose_to_concentrations( + image_rgb: xr.DataArray, stain_matrix: np.ndarray, background_intensity: np.ndarray +) -> xr.DataArray: + """Project an image onto a stain matrix, returning a 3-channel concentration image. + + Channels are ``(hematoxylin, eosin, residual)``; the residual is the + concentration along the complement vector and is a diagnostic, not a stain. + """ + _check_channel_dim(image_rgb) + sda = rgb_to_sda(image_rgb, background_intensity) + dtype = _working_dtype(sda) + pinv = np.linalg.pinv(stain_matrix) + return _apply_along_channel(sda, _matmul_kernel, out_dtype=dtype, matrix=pinv.astype(dtype), dtype=dtype) diff --git a/src/squidpy/experimental/im/_stain/_mask.py b/src/squidpy/experimental/im/_stain/_mask.py index 87b12e960..3e35cee16 100644 --- a/src/squidpy/experimental/im/_stain/_mask.py +++ b/src/squidpy/experimental/im/_stain/_mask.py @@ -1,11 +1,10 @@ """Foreground (tissue) masking for stain fitting. 
Method-agnostic on purpose: Reinhard fits its channel statistics over tissue -pixels only, and the Macenko/Vahadane fits added later need the same kind of -mask. The luminosity variant lives here; the absorbance variant -(``absorbance_foreground_mask``) is added beside it when decomposition lands, -both returning the same ``(y, x)`` boolean contract so downstream statistics -code stays mask-source-agnostic. +pixels only, and the Macenko/Vahadane fits need the same kind of mask. Two +variants live here - luminosity (Reinhard, intensity space) and absorbance +(decomposition, optical-density space) - both returning the same ``(y, x)`` +boolean contract so downstream statistics code stays mask-source-agnostic. """ from __future__ import annotations @@ -20,6 +19,7 @@ from squidpy.experimental.im._stain._conversion import ( _check_channel_dim, rgb_to_lab_ruderman, + rgb_to_sda, ) @@ -74,3 +74,50 @@ def luminosity_foreground_mask(rgb: xr.DataArray, threshold: float) -> xr.DataAr """ _check_channel_dim(rgb) return foreground_mask_from_lab(rgb_to_lab_ruderman(rgb), threshold) + + +def as_spatial_mask(mask: np.ndarray, like: xr.DataArray) -> xr.DataArray: + """Wrap a ``(y, x)`` boolean array as a DataArray aligned to ``like``'s y/x. + + Copies ``like``'s ``y``/``x`` coords (when present) so ``like.where(...)`` + aligns by coordinate rather than silently broadcasting. ``mask`` must match + ``like`` in the spatial dims. + """ + coords = {d: like.coords[d] for d in ("y", "x") if d in like.coords} + return xr.DataArray(np.asarray(mask, dtype=bool), dims=("y", "x"), coords=coords) + + +def foreground_mask_from_sda(sda: xr.DataArray, beta: float = 0.15) -> xr.DataArray: + """Tissue mask from an already-computed optical-density (SDA) image. 
+ + The absorbance-space sibling of :func:`foreground_mask_from_lab`: lets the + decomposition fit derive the mask from the same lazy ``sda`` graph it + already needs for the optical densities, so dask materialises the + RGB->SDA conversion once. ``True`` = tissue (mean absorbance ``> beta``). + """ + return sda.mean(dim="c") > beta + + +def absorbance_foreground_mask(rgb: xr.DataArray, background_intensity: np.ndarray, beta: float = 0.15) -> xr.DataArray: + """Boolean tissue mask in optical-density (absorbance) space. + + The convention the Macenko/Vahadane fits expect: a pixel is tissue if its + mean absorbance across channels exceeds ``beta``. Near-white background + has near-zero absorbance and is excluded. + + Parameters + ---------- + rgb + Image with a ``"c"`` dimension of length 3. Numpy- or dask-backed. + background_intensity + Per-channel white point ``I_0`` (shape ``(3,)``), as used by + :func:`~squidpy.experimental.im._stain._conversion.rgb_to_sda`. + beta + Mean-absorbance cutoff. Pixels with mean SDA ``> beta`` are tissue. + + Returns + ------- + Boolean ``(y, x)`` DataArray: ``True`` = tissue. Lazy if ``rgb`` was lazy. + """ + _check_channel_dim(rgb) + return foreground_mask_from_sda(rgb_to_sda(rgb, background_intensity), beta) diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py index 3178985b6..76c619755 100644 --- a/src/squidpy/experimental/im/_stain/_normalize.py +++ b/src/squidpy/experimental/im/_stain/_normalize.py @@ -5,9 +5,9 @@ re-exported publicly. Everything it calls is a pure DataArray-layer primitive (:mod:`._reinhard`, :mod:`._mask`, :mod:`._conversion`). -Both entry points dispatch on the fitting ``method``. Only ``"reinhard"`` is -implemented here; ``"macenko"``/``"vahadane"`` raise ``NotImplementedError`` -and are filled in without changing these signatures. 
+Both entry points dispatch on the fitting ``method`` (``"reinhard"`` colour +transfer, or ``"macenko"``/``"vahadane"`` absorbance decomposition); a third +entry, :func:`decompose_stains`, projects an image onto its stain matrix. """ from __future__ import annotations @@ -15,13 +15,24 @@ from collections.abc import Mapping from typing import Any, Literal +import numpy as np import spatialdata as sd import xarray as xr from spatialdata.models import Image2DModel from spatialdata.transformations import get_transformation from squidpy._utils import _get_scale_factors +from squidpy.experimental.im._stain._background import DEFAULT_BACKGROUND_INTENSITY from squidpy.experimental.im._stain._conversion import _check_channel_dim +from squidpy.experimental.im._stain._decomposition import ( + MacenkoParams, + VahadaneParams, + _resolve_macenko_params, + _resolve_vahadane_params, + apply_decomposition, + decompose_to_concentrations, + fit_decomposition, +) from squidpy.experimental.im._stain._reference import StainMethod, StainReference from squidpy.experimental.im._stain._reinhard import ( ReinhardParams, @@ -29,9 +40,19 @@ apply_reinhard, fit_reinhard, ) -from squidpy.experimental.im._utils import get_element_data +from squidpy.experimental.im._utils import ( + _choose_label_scale_for_image, + get_element_data, + get_mask_materialized, + resolve_tissue_mask, +) -_DECOMPOSITION_NOT_IMPLEMENTED = "macenko/vahadane decomposition is not yet implemented" +_VALID_METHODS = ("reinhard", "macenko", "vahadane") +_DECOMPOSITION_METHODS = ("macenko", "vahadane") +_CONCENTRATION_CHANNELS = ["hematoxylin", "eosin", "residual"] + +# Public union accepted by the method_params argument of the dispatchers. 
+MethodParams = ReinhardParams | MacenkoParams | VahadaneParams | Mapping[str, Any] | None def _resolve_image( @@ -49,13 +70,102 @@ def _resolve_image( return da +def _resolve_tissue_bool_mask( + sdata: sd.SpatialData, image_key: str, fit_da: xr.DataArray, tissue_mask_key: str | None +) -> np.ndarray: + """Return a ``(y, x)`` boolean tissue mask aligned to ``fit_da``. + + Consumes a :func:`~squidpy.experimental.im.detect_tissue` labels element + (mandatory - raises if none exists), picks the label scale closest to + ``fit_da``, materialises it, and nearest-resizes to ``fit_da``'s ``(y, x)`` + when the resolutions differ. The stain fits run on a coarse level, so the + mask stays small. + """ + mask_key = resolve_tissue_mask(sdata, image_key, "auto", tissue_mask_key, auto_create=False) + target_hw = (int(fit_da.sizes["y"]), int(fit_da.sizes["x"])) + label_scale = _choose_label_scale_for_image(sdata.labels[mask_key], target_hw) + mask = get_mask_materialized(sdata, mask_key, label_scale) > 0 + if mask.shape != target_hw: + from skimage.transform import resize + + mask = resize(mask, target_hw, order=0, preserve_range=True) > 0.5 + return mask + + +def _resolve_output_tissue_mask( + sdata: sd.SpatialData, image_key: str, target_da: xr.DataArray, tissue_mask_key: str | None +) -> xr.DataArray: + """Return a lazy ``(y, x)`` boolean tissue mask aligned to ``target_da``. + + Like :func:`_resolve_tissue_bool_mask` but kept lazy and at the (full-res) + output resolution, for compositing the original background back into the + normalized image without materialising the full frame. The label pyramid + shares the image's scale factors, so the matching level usually lines up + exactly; only a residual size mismatch forces a (small) eager resize. 
+ """ + mask_key = resolve_tissue_mask(sdata, image_key, "auto", tissue_mask_key, auto_create=False) + target_hw = (int(target_da.sizes["y"]), int(target_da.sizes["x"])) + label_scale = _choose_label_scale_for_image(sdata.labels[mask_key], target_hw) + coords = {d: target_da.coords[d] for d in ("y", "x") if d in target_da.coords} + mask = get_element_data(sdata.labels[mask_key], label_scale, "label", mask_key).squeeze() > 0 + if (int(mask.sizes["y"]), int(mask.sizes["x"])) == target_hw: + return mask.assign_coords(coords) + from skimage.transform import resize + + resized = resize(np.asarray(mask.data) > 0, target_hw, order=0, preserve_range=True) > 0.5 + return xr.DataArray(resized, dims=("y", "x"), coords=coords) + + +def _resolve_method_params(method: str, method_params: MethodParams) -> Any: + """Pick the right Params dataclass for ``method`` and resolve a mapping/instance/None.""" + if method == "reinhard": + return _resolve_reinhard_params(method_params) + if method == "macenko": + return _resolve_macenko_params(method_params) + if method == "vahadane": + return _resolve_vahadane_params(method_params) + raise ValueError(f"Unknown method {method!r}; expected one of {list(_VALID_METHODS)}.") + + +def _write_image( + sdata: sd.SpatialData, + source_node: Any, + image_key_added: str, + data_array: xr.DataArray, + *, + c_coords: list[Any] | None = None, +) -> None: + """Write a derived image element, preserving the source's transforms/pyramid. + + Reconstructs the element from the bare array (a derived DataArray would + carry the source's ``transform`` attr and collide with the transformations + we pass) plus the dims/channel-coords/transforms to preserve. The same + idiom as detect_tissue. ``_get_scale_factors`` returns ``[]`` for a + single-scale source; parse needs ``None`` there (an empty list builds a + degenerate single-level pyramid). 
+ """ + if image_key_added in sdata.images: + raise ValueError(f"image_key_added={image_key_added!r} already exists in sdata.images.") + if c_coords is None: + c_coords = data_array.coords["c"].values.tolist() if "c" in data_array.coords else None + sdata.images[image_key_added] = Image2DModel.parse( + data_array.data, + dims=data_array.dims, + c_coords=c_coords, + transformations=get_transformation(source_node, get_all=True), + scale_factors=_get_scale_factors(source_node) or None, + ) + + def fit_stain_reference( sdata: sd.SpatialData, image_key: str, *, method: StainMethod = "reinhard", scale: str | Literal["auto"] = "auto", - method_params: ReinhardParams | Mapping[str, Any] | None = None, + method_params: MethodParams = None, + background_intensity: np.ndarray | None = None, + tissue_mask_key: str | None = None, ) -> StainReference: """Fit a stain reference from an image in a :class:`~spatialdata.SpatialData` object. @@ -66,25 +176,45 @@ def fit_stain_reference( image_key Key of the RGB image in ``sdata.images`` to fit on. method - Fitting method. Only ``"reinhard"`` is implemented; ``"macenko"`` and - ``"vahadane"`` raise :class:`NotImplementedError`. + Fitting method: ``"reinhard"`` (colour transfer), ``"macenko"`` or + ``"vahadane"`` (stain-matrix decomposition). scale Scale level to fit on. ``"auto"`` (default) uses the coarsest level, which is cheap and sufficient for colour statistics. method_params - :class:`ReinhardParams` instance, a mapping of its fields, or ``None`` - for defaults. + A :class:`ReinhardParams`/:class:`MacenkoParams`/:class:`VahadaneParams` + instance, a mapping of its fields, or ``None`` for defaults. Must match + ``method``. + background_intensity + Per-channel white point ``I_0`` ``(3,)`` for the decomposition methods. + If ``None``, a fixed full-white ``[255, 255, 255]`` is used (the + HistomicsTK/Macenko convention), so unstained pixels round-trip to + white. 
+ Pass the output of :func:`estimate_background_intensity` only for slides with a + known non-white background. Ignored by Reinhard. + tissue_mask_key + Key of a tissue-label element in ``sdata.labels`` (as produced by + :func:`~squidpy.experimental.im.detect_tissue`) restricting the fit to + tissue pixels. If ``None``, ``f"{image_key}_tissue"`` is used. A tissue + mask is **required**: if neither exists, a :class:`KeyError` asks you to + run :func:`~squidpy.experimental.im.detect_tissue` first. Returns ------- The fitted :class:`StainReference`. Nothing is written to ``sdata``. """ + if method not in _VALID_METHODS: + raise ValueError(f"Unknown method {method!r}; expected one of {list(_VALID_METHODS)}.") da = _resolve_image(sdata, image_key, scale, prefer="coarsest") + params = _resolve_method_params(method, method_params) + tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key) if method == "reinhard": - return fit_reinhard(da, _resolve_reinhard_params(method_params)) - if method in {"macenko", "vahadane"}: - raise NotImplementedError(_DECOMPOSITION_NOT_IMPLEMENTED) - raise ValueError(f"Unknown method {method!r}; expected one of ['macenko', 'reinhard', 'vahadane'].") + return fit_reinhard(da, params, tissue_mask=tissue_mask) + bg = ( + DEFAULT_BACKGROUND_INTENSITY.copy() + if background_intensity is None + else np.asarray(background_intensity, np.float64) + ) + return fit_decomposition(da, method, params, bg, tissue_mask=tissue_mask, image_key=image_key) def apply_stain_normalization( @@ -93,8 +223,10 @@ def apply_stain_normalization( reference: StainReference, *, scale: str | Literal["auto"] = "auto", - method_params: ReinhardParams | Mapping[str, Any] | None = None, + method_params: MethodParams = None, image_key_added: str | None = None, + tissue_mask_key: str | None = None, + preserve_background: bool = True, ) -> xr.DataArray | None: """Normalize an image to a fitted stain reference. 
@@ -112,14 +244,23 @@ def apply_stain_normalization( so the result is not downsampled; source statistics are reduced lazily so memory stays bounded. method_params - :class:`ReinhardParams` instance, a mapping of its fields, or ``None`` - for defaults. + Params matching ``reference.method`` (instance, mapping, or ``None``). image_key_added If ``None`` (default), return the lazy normalized DataArray and leave ``sdata`` untouched. If given, write the result to ``sdata.images[image_key_added]`` (rebuilding the pyramid for multiscale sources, preserving transforms) and return ``None``. Raises if the key already exists. + tissue_mask_key + Key of a tissue-label element in ``sdata.labels`` restricting the + *source* statistics to tissue pixels. As for + :func:`fit_stain_reference`, a tissue mask is required (defaults to + ``f"{image_key}_tissue"``; raises if missing). + preserve_background + If ``True`` (default), non-tissue (background) pixels are passed through + unchanged from the source image, so the normalization recolours only + tissue. The colour map is a global linear transform that would otherwise + tint background/white pixels. Set ``False`` for full-frame normalization. Returns ------- @@ -127,32 +268,103 @@ def apply_stain_normalization( ``None``, otherwise ``None``. """ da = _resolve_image(sdata, image_key, scale, prefer="finest") + params = _resolve_method_params(reference.method, method_params) + # Source statistics (Reinhard mu/sigma or the decomposition source matrix) + # are reduced on a coarse level with a tissue mask; the lazy transform is + # then applied to the full-resolution `da`. 
+ fit_rgb = _resolve_image(sdata, image_key, scale, prefer="coarsest") + tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, fit_rgb, tissue_mask_key) if reference.method == "reinhard": - normalized = apply_reinhard(da, reference, _resolve_reinhard_params(method_params)) - elif reference.method in {"macenko", "vahadane"}: - raise NotImplementedError(_DECOMPOSITION_NOT_IMPLEMENTED) - else: # pragma: no cover - StainReference validates method on construction - raise ValueError(f"Unknown reference method {reference.method!r}.") + normalized = apply_reinhard(da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask) + else: + normalized = apply_decomposition(da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask) + + if preserve_background: + # Keep non-tissue pixels byte-identical to the source: the global colour + # map would otherwise recolour background/white pixels (HistomicsTK's + # `mask_out`). Stays lazy - the mask aligns to `da` without materialising. + keep = _resolve_output_tissue_mask(sdata, image_key, da, tissue_mask_key) + normalized = normalized.where(keep, da) if image_key_added is None: return normalized - if image_key_added in sdata.images: - raise ValueError(f"image_key_added={image_key_added!r} already exists in sdata.images.") + _write_image(sdata, sdata.images[image_key], image_key_added, normalized) + return None - node = sdata.images[image_key] - # Reconstruct the element explicitly from the underlying array: parse a - # DataArray would carry over the source's transform attr and collide with - # the transformations we pass, so we hand it the bare array plus the - # dims/channel coords/transforms we want to preserve (the same idiom as - # detect_tissue). `_get_scale_factors` returns [] for a single-scale - # source; parse needs None there (an empty list builds a degenerate - # single-level pyramid). 
- c_coords = normalized.coords["c"].values.tolist() if "c" in normalized.coords else None - sdata.images[image_key_added] = Image2DModel.parse( - normalized.data, - dims=normalized.dims, - c_coords=c_coords, - transformations=get_transformation(node, get_all=True), - scale_factors=_get_scale_factors(node) or None, - ) + +def decompose_stains( + sdata: sd.SpatialData, + image_key: str, + reference_or_method: StainReference | Literal["macenko", "vahadane"], + *, + scale: str | Literal["auto"] = "auto", + method_params: MethodParams = None, + background_intensity: np.ndarray | None = None, + image_key_added: str | None = None, + tissue_mask_key: str | None = None, + include_residual: bool = True, +) -> dict[str, xr.DataArray] | None: + """Decompose an image into separate per-stain concentration maps. + + Parameters + ---------- + sdata, image_key + The SpatialData object and the RGB image key to decompose. + reference_or_method + Either a decomposition :class:`StainReference` (its stain matrix and + white point are used) or a method name (``"macenko"``/``"vahadane"``) + to fit on this image first. The reference is the provenance record of + how the maps were produced (method, stain matrix, white point). + scale, method_params, background_intensity, tissue_mask_key + As for :func:`fit_stain_reference` (only used when a method name is + given; a reference is projected as-is and needs no tissue mask). + image_key_added + If ``None`` (default), return the concentration maps as a dict. If + given, used as a key *prefix*: each stain is written as its own + single-channel image ``sdata.images[f"{image_key_added}_{stain}"]`` + (e.g. ``f"{image_key_added}_hematoxylin"``), and ``None`` is returned. + Raises if any target key already exists. + include_residual + If ``True`` (default), also produce the ``"residual"`` map. 
The residual + is the absorbance along the complement direction - a diagnostic of + decomposition quality (extra chromogen, artifacts, or a poor fit), not a + biological stain. Set ``False`` to keep only ``hematoxylin``/``eosin``. + + Returns + ------- + If ``image_key_added`` is ``None``, a ``dict`` mapping each stain name to + its ``(y, x)`` concentration :class:`~xarray.DataArray` + (``"hematoxylin"``, ``"eosin"``, and ``"residual"`` unless dropped). + Otherwise ``None`` (the maps are written as separate images). + """ + da = _resolve_image(sdata, image_key, scale, prefer="finest") + if isinstance(reference_or_method, StainReference): + reference = reference_or_method + if reference.method not in _DECOMPOSITION_METHODS or reference.stain_matrix is None: + raise ValueError("decompose_stains requires a macenko/vahadane reference with a stain matrix.") + stain_matrix, bg = reference.stain_matrix, reference.background_intensity + else: + if reference_or_method not in _DECOMPOSITION_METHODS: + raise ValueError(f"method must be one of {list(_DECOMPOSITION_METHODS)}; got {reference_or_method!r}.") + reference = fit_stain_reference( + sdata, + image_key, + method=reference_or_method, + scale=scale, + method_params=method_params, + background_intensity=background_intensity, + tissue_mask_key=tissue_mask_key, + ) + stain_matrix, bg = reference.stain_matrix, reference.background_intensity + + concentrations = decompose_to_concentrations(da, stain_matrix, bg).assign_coords(c=_CONCENTRATION_CHANNELS) + names = ["hematoxylin", "eosin"] + (["residual"] if include_residual else []) + + if image_key_added is None: + return {name: concentrations.sel(c=name) for name in names} + + source = sdata.images[image_key] + for name in names: + # keep the c dim (length 1) so Image2DModel.parse accepts it + _write_image(sdata, source, f"{image_key_added}_{name}", concentrations.sel(c=[name]), c_coords=[name]) return None diff --git a/src/squidpy/experimental/im/_stain/_reference.py 
b/src/squidpy/experimental/im/_stain/_reference.py index 4b1a7fa06..80687fe61 100644 --- a/src/squidpy/experimental/im/_stain/_reference.py +++ b/src/squidpy/experimental/im/_stain/_reference.py @@ -27,7 +27,7 @@ def _coerce_finite(arr: Any, *, shape: tuple[int, ...], name: str) -> np.ndarray return out -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class StainReference: """Container for a fitted stain reference. @@ -48,7 +48,14 @@ class StainReference: decomposition methods (apply consumes it). Forbidden for Reinhard because Reinhard's color transfer operates in Ruderman Lab and does not model absorbance. There is no universal default; pass an - estimate from your data (PR 3 ships the estimator). + estimate from your data (see ``estimate_background_intensity``). + max_concentrations + Shape ``(2,)`` reference per-stain (H, E) maximum concentrations. + Decomposition only. Stored because at apply time the reference image + is gone, so the target concentration scale must travel with the + reference. Optional (Reinhard references and externally-built + decomposition references without it remain valid); forbidden for + Reinhard. """ method: StainMethod @@ -56,6 +63,25 @@ class StainReference: mu: np.ndarray | None = None sigma: np.ndarray | None = None background_intensity: np.ndarray | None = None + max_concentrations: np.ndarray | None = None + + def __eq__(self, other: object) -> bool: + # The numpy-array fields make the dataclass-generated __eq__ raise + # ("truth value of an array is ambiguous"), so compare explicitly: + # equal method plus element-wise-equal arrays. 
+ if not isinstance(other, StainReference): + return NotImplemented + if self.method != other.method: + return False + return all( + np.array_equal(getattr(self, name), getattr(other, name)) + for name in ("stain_matrix", "mu", "sigma", "background_intensity", "max_concentrations") + ) + + # eq=False keeps the default identity-based __hash__ (the array fields are + # unhashable, so a value-based hash is impossible); references remain usable + # as set members / dict keys by identity. + __hash__ = object.__hash__ def __post_init__(self) -> None: if self.method not in _VALID_METHODS: @@ -77,6 +103,11 @@ def __post_init__(self) -> None: if np.any(bg <= 0): raise ValueError("background_intensity must be strictly positive.") object.__setattr__(self, "background_intensity", bg) + if self.max_concentrations is not None: + maxc = _coerce_finite(self.max_concentrations, shape=(2,), name="max_concentrations") + if np.any(maxc <= 0): + raise ValueError("max_concentrations must be strictly positive.") + object.__setattr__(self, "max_concentrations", maxc) else: if self.mu is None or self.sigma is None: raise ValueError("method='reinhard' requires both mu and sigma.") @@ -87,6 +118,8 @@ def __post_init__(self) -> None: "method='reinhard' forbids background_intensity; Reinhard's color " "transfer is in Ruderman Lab and does not use a white point." 
) + if self.max_concentrations is not None: + raise ValueError("method='reinhard' forbids max_concentrations.") mu = _coerce_finite(self.mu, shape=(3,), name="mu") sigma = _coerce_finite(self.sigma, shape=(3,), name="sigma") if np.any(sigma <= 0): diff --git a/src/squidpy/experimental/im/_stain/_reinhard.py b/src/squidpy/experimental/im/_stain/_reinhard.py index a3f425c4a..cae89faed 100644 --- a/src/squidpy/experimental/im/_stain/_reinhard.py +++ b/src/squidpy/experimental/im/_stain/_reinhard.py @@ -22,7 +22,7 @@ lab_ruderman_to_rgb, rgb_to_lab_ruderman, ) -from squidpy.experimental.im._stain._mask import foreground_mask_from_lab +from squidpy.experimental.im._stain._mask import as_spatial_mask, foreground_mask_from_lab from squidpy.experimental.im._stain._reference import StainReference # Numerical safeguard against divide-by-zero on flat (constant-colour) @@ -112,34 +112,58 @@ def _transfer_kernel( return ((x - mu_src) / sigma_src * sigma_ref + mu_ref).astype(dtype, copy=False) -def fit_reinhard(image_rgb: xr.DataArray, params: ReinhardParams) -> StainReference: +def _reinhard_mask(lab: xr.DataArray, params: ReinhardParams, tissue_mask: np.ndarray | None) -> xr.DataArray | None: + """Resolve the tissue mask for the Reinhard stats: external mask wins, else + the param-driven luminosity mask (or ``None`` for vanilla Reinhard).""" + if tissue_mask is not None: + return as_spatial_mask(tissue_mask, lab) + if params.mask_background: + return foreground_mask_from_lab(lab, params.luminosity_threshold) + return None + + +def fit_reinhard( + image_rgb: xr.DataArray, params: ReinhardParams, *, tissue_mask: np.ndarray | None = None +) -> StainReference: """Fit Reinhard channel statistics on a reference image. Converts to Ruderman Lab, computes per-channel ``mu``/``sigma`` over - tissue pixels (or all pixels when ``mask_background=False``), and packs - them into a ``StainReference(method="reinhard")``. 
+ tissue pixels, and packs them into a ``StainReference(method="reinhard")``. + ``tissue_mask`` (a ``(y, x)`` boolean aligned to ``image_rgb``) selects the + tissue pixels when given; otherwise the ``mask_background`` / + ``luminosity_threshold`` params drive the mask. """ _check_channel_dim(image_rgb) lab = rgb_to_lab_ruderman(image_rgb) - mask = foreground_mask_from_lab(lab, params.luminosity_threshold) if params.mask_background else None - mu, sigma = _masked_channel_stats(lab, mask) + mu, sigma = _masked_channel_stats(lab, _reinhard_mask(lab, params, tissue_mask)) return StainReference(method="reinhard", mu=mu, sigma=sigma) -def apply_reinhard(image_rgb: xr.DataArray, reference: StainReference, params: ReinhardParams) -> xr.DataArray: +def apply_reinhard( + image_rgb: xr.DataArray, + reference: StainReference, + params: ReinhardParams, + *, + fit_rgb: xr.DataArray | None = None, + tissue_mask: np.ndarray | None = None, +) -> xr.DataArray: """Apply a Reinhard reference to a source image. Standardises by the source's own tissue statistics, rescales to the reference statistics, and converts back to RGB. The transform is applied - to every pixel (the map is global); only the statistics that define it - are tissue-only. Lazy if and only if the input is lazy. + to every pixel of ``image_rgb`` (the map is global); the defining + statistics are reduced on ``fit_rgb`` (a coarse level) when given, so the + full-resolution image is never materialised to compute them. + ``tissue_mask`` (aligned to ``fit_rgb``) selects the source tissue pixels. + Lazy if and only if ``image_rgb`` is lazy. 
""" _check_channel_dim(image_rgb) - lab = rgb_to_lab_ruderman(image_rgb) - mask = foreground_mask_from_lab(lab, params.luminosity_threshold) if params.mask_background else None - mu_src, sigma_src = _masked_channel_stats(lab, mask) + fit_lab = rgb_to_lab_ruderman(fit_rgb if fit_rgb is not None else image_rgb) + mu_src, sigma_src = _masked_channel_stats(fit_lab, _reinhard_mask(fit_lab, params, tissue_mask)) sigma_src = np.maximum(sigma_src, _SIGMA_FLOOR) + lab = rgb_to_lab_ruderman(image_rgb) + dtype = _working_dtype(lab) lab_out = _apply_along_channel( lab, diff --git a/src/squidpy/experimental/im/_stain/_validation.py b/src/squidpy/experimental/im/_stain/_validation.py new file mode 100644 index 000000000..3c2cd8554 --- /dev/null +++ b/src/squidpy/experimental/im/_stain/_validation.py @@ -0,0 +1,124 @@ +"""Stain-matrix validation and canonicalisation primitives. + +Pure numpy, no ``sdata``, no public export. Shared by the Macenko and +Vahadane fits so both produce a canonical ``(H, E, complement)`` matrix that +downstream apply/decompose code can treat method-agnostically. +""" + +from __future__ import annotations + +import numpy as np + +from squidpy.experimental.im._stain._constants import RUIFROK_HE + + +class StainFittingError(RuntimeError): + """A stain-matrix fit produced an invalid or degenerate result. + + Carries ``image_key`` so cohort fitting (a later PR) can attribute a + failure to a specific slide and skip or flag it by name. 
+ """ + + def __init__(self, reason: str, *, image_key: str | None = None) -> None: + self.reason = reason + self.image_key = image_key + prefix = f"[{image_key}] " if image_key is not None else "" + super().__init__(f"{prefix}{reason}") + + +def _canonical_he(reference: dict[str, np.ndarray]) -> np.ndarray: + """Stack the reference H and E unit vectors as columns of a ``(3, 2)``.""" + return np.stack([reference["hematoxylin"], reference["eosin"]], axis=1) + + +def angle_between_deg(u: np.ndarray, v: np.ndarray) -> float: + """Unsigned angle in degrees between two vectors (sign-agnostic).""" + cos = abs(float(u @ v)) / (np.linalg.norm(u) * np.linalg.norm(v)) + return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))) + + +def _unit_columns(matrix: np.ndarray) -> np.ndarray: + """Scale each column of ``matrix`` to unit L2 norm.""" + return matrix / np.linalg.norm(matrix, axis=0, keepdims=True) + + +def reorder_to_canonical(matrix: np.ndarray, reference: dict[str, np.ndarray] = RUIFROK_HE) -> np.ndarray: + """Order a ``(3, 2)`` stain matrix to ``(H, E)`` and fix column signs. + + Macenko's SVD and Vahadane's NMF recover the two stain directions in an + arbitrary order and sign. We assign each recovered column to whichever of + the canonical Ruifrok H/E vectors it is most colinear with, then flip its + sign so it points the same way as that reference (absorbance is positive). 
+ """ + w = np.asarray(matrix, dtype=np.float64) + if w.shape != (3, 2): + raise ValueError(f"stain matrix to reorder must have shape (3, 2); got {w.shape}.") + canonical = _canonical_he(reference) # (3, 2): [H, E] + cols = _unit_columns(w) + + # cosine of each recovered column against each canonical vector + sim = cols.T @ canonical # (2 recovered, 2 canonical) + # H is the recovered column most colinear with canonical H; the other column is E + h_idx = int(np.argmax(np.abs(sim[:, 0]))) + e_idx = 1 - h_idx + ordered = np.stack([w[:, h_idx], w[:, e_idx]], axis=1) + + # flip signs so each column points along its canonical reference + for j in range(2): + if ordered[:, j] @ canonical[:, j] < 0: + ordered[:, j] = -ordered[:, j] + return ordered + + +def complement_third_column(matrix: np.ndarray) -> np.ndarray: + """Extend a ``(3, 2)`` H/E matrix to ``(3, 3)`` with a complement column. + + The third column is the unit cross product of the H and E columns: the + residual direction orthogonal to both, used to capture absorbance not + explained by either stain. + """ + w = np.asarray(matrix, dtype=np.float64) + if w.shape != (3, 2): + raise ValueError(f"stain matrix to complement must have shape (3, 2); got {w.shape}.") + third = np.cross(w[:, 0], w[:, 1]) + norm = np.linalg.norm(third) + if norm < 1e-8: + raise StainFittingError("H and E stain vectors are colinear; cannot form a complement.") + third = third / norm + return np.column_stack([w, third]) + + +def validate_stain_matrix( + matrix: np.ndarray, + *, + reference: dict[str, np.ndarray] = RUIFROK_HE, + max_angle_deg: float = 45.0, + image_key: str | None = None, +) -> None: + """Raise :class:`StainFittingError` if a ``(3, 3)`` matrix is implausible. + + Guards against the failure modes of an unsupervised stain fit: a column + collapsed to zero, a rank-deficient (single-stain) matrix, or an H/E + direction rotated far from its Ruifrok canonical (a sign the fit latched + onto noise or a non-H&E chromogen). 
+ """ + w = np.asarray(matrix, dtype=np.float64) + if w.shape != (3, 3): + raise StainFittingError(f"stain matrix must have shape (3, 3); got {w.shape}.", image_key=image_key) + if not np.all(np.isfinite(w)): + raise StainFittingError("stain matrix contains non-finite values.", image_key=image_key) + + norms = np.linalg.norm(w, axis=0) + if np.any(norms < 1e-8): + raise StainFittingError("stain matrix has a zero-norm column.", image_key=image_key) + if np.linalg.matrix_rank(w, tol=1e-6) < 3: + raise StainFittingError("stain matrix is rank-deficient (stains are not separable).", image_key=image_key) + + canonical = _canonical_he(reference) + for name, j in (("hematoxylin", 0), ("eosin", 1)): + angle = angle_between_deg(w[:, j], canonical[:, j]) + if angle > max_angle_deg: + raise StainFittingError( + f"{name} stain vector deviates {angle:.1f} deg from its canonical (max {max_angle_deg}).", + image_key=image_key, + ) diff --git a/src/squidpy/experimental/im/_utils.py b/src/squidpy/experimental/im/_utils.py index b58dea1bf..a18770741 100644 --- a/src/squidpy/experimental/im/_utils.py +++ b/src/squidpy/experimental/im/_utils.py @@ -9,9 +9,11 @@ from shapely import box from spatialdata import SpatialData from spatialdata._logging import logger -from spatialdata.models import ShapesModel +from spatialdata.models import Labels2DModel, ShapesModel from spatialdata.transformations import get_transformation, set_transformation +from squidpy._utils import _yx_from_shape + class TileGrid: """Immutable tile grid definition with cached bounds and centroids.""" @@ -252,22 +254,45 @@ def get_mask_materialized(sdata: SpatialData, mask_key: str, scale: str) -> np.n return np.asarray(arr.compute()) +def _choose_label_scale_for_image(label_node: Labels2DModel, target_hw: tuple[int, int]) -> str: + """Pick the label scale closest to the target image height/width.""" + if not hasattr(label_node, "keys"): + return "scale0" # single-scale labels default to their only scale + target_h, 
target_w = target_hw + best = None + best_diff = float("inf") + for k in label_node.keys(): + y, x = _yx_from_shape(label_node[k].image.shape) + diff = abs(y - target_h) + abs(x - target_w) + if diff == 0: + return k + if diff < best_diff: + best_diff = diff + best = k + return best or "scale0" + + def resolve_tissue_mask( sdata: SpatialData, image_key: str, scale: str, tissue_mask_key: str | None = None, + *, + auto_create: bool = True, ) -> str: """Return the key of a tissue mask in ``sdata.labels``, creating one if needed. If *tissue_mask_key* is given and exists, it is returned as-is. - Otherwise falls back to ``f"{image_key}_tissue"``, running ``detect_tissue`` - to create it when missing. + Otherwise falls back to ``f"{image_key}_tissue"``. When that key is missing, + the behaviour depends on *auto_create*: if ``True`` (default) ``detect_tissue`` + is run to create it; if ``False`` a :class:`KeyError` is raised asking the + caller to run ``detect_tissue`` first. Raises ------ KeyError - If *tissue_mask_key* is given but not found in ``sdata.labels``. + If *tissue_mask_key* is given but not found in ``sdata.labels``, or if + *auto_create* is ``False`` and no tissue mask exists. Exception Any exception raised by ``detect_tissue`` if auto-creation fails. Callers needing graceful fallback should wrap this call in try/except. @@ -279,6 +304,12 @@ def resolve_tissue_mask( mask_key = f"{image_key}_tissue" if mask_key not in sdata.labels: + if not auto_create: + raise KeyError( + f"No tissue mask found in sdata.labels (looked for {mask_key!r}). Run " + f"`squidpy.experimental.im.detect_tissue(sdata, {image_key!r})` first, " + "or pass an explicit `tissue_mask_key`." 
+ ) from squidpy.experimental.im._detect_tissue import detect_tissue detect_tissue(sdata=sdata, image_key=image_key, scale=scale, inplace=True, new_labels_key=mask_key) diff --git a/tests/experimental/test_stain_background.py b/tests/experimental/test_stain_background.py new file mode 100644 index 000000000..9043cc91a --- /dev/null +++ b/tests/experimental/test_stain_background.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import dask.array as da +import numpy as np +import pytest +import xarray as xr + +from squidpy.experimental.im._stain._background import estimate_background_intensity +from squidpy.experimental.im._stain._validation import StainFittingError + + +def _da(values: np.ndarray, *, chunked: bool) -> xr.DataArray: + data = da.from_array(values, chunks=(3, 8, 8)) if chunked else values + return xr.DataArray(data, dims=("c", "y", "x")) + + +@pytest.mark.parametrize("chunked", [False, True]) +def test_recovers_white_point(chunked: bool) -> None: + rng = np.random.default_rng(0) + # mostly bright background near (240, 245, 250), a darker tissue blob + values = np.empty((3, 32, 32)) + values[0] = 240.0 + values[1] = 245.0 + values[2] = 250.0 + values[:, :8, :8] = rng.uniform(20.0, 60.0, size=(3, 8, 8)) # tissue + bg = estimate_background_intensity(_da(values, chunked=chunked)) + assert bg.shape == (3,) + np.testing.assert_allclose(bg, [240.0, 245.0, 250.0], atol=1.0) + + +def test_blank_image_raises() -> None: + black = np.zeros((3, 16, 16)) + with pytest.raises(StainFittingError, match="non-positive"): + estimate_background_intensity(_da(black, chunked=False)) + + +def test_bad_percentile_raises() -> None: + with pytest.raises(ValueError, match="percentile"): + estimate_background_intensity(_da(np.ones((3, 8, 8)), chunked=False), percentile=0.0) diff --git a/tests/experimental/test_stain_decompose_public.py b/tests/experimental/test_stain_decompose_public.py new file mode 100644 index 000000000..58b461d03 --- /dev/null +++ 
b/tests/experimental/test_stain_decompose_public.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import numpy as np +import pytest +import spatialdata as sd +import xarray as xr +from spatialdata.models import Image2DModel, Labels2DModel +from spatialdata.transformations import get_transformation + +import squidpy as sq +from squidpy.experimental.im import ( + StainReference, + apply_stain_normalization, + decompose_stains, + estimate_background_intensity, + fit_stain_reference, +) +from squidpy.experimental.im._stain._constants import RUIFROK_HE +from squidpy.experimental.im._stain._conversion import sda_to_rgb +from squidpy.experimental.im._stain._validation import complement_third_column, reorder_to_canonical + +_WHITE = np.array([255.0, 255.0, 255.0]) + + +def _synthetic_rgb(seed: int = 0, n_side: int = 48, white: np.ndarray = _WHITE) -> np.ndarray: + w = complement_third_column( + reorder_to_canonical(np.stack([RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]], axis=1)) + ) + rng = np.random.default_rng(seed) + n = n_side * n_side + conc = rng.uniform(0.0, 70.0, size=(n, 2)) + third = n // 3 + conc[:third, 1] = 0.0 + conc[third : 2 * third, 0] = 0.0 + od = (conc @ w[:, :2].T).T.reshape(3, n_side, n_side) + rgb = sda_to_rgb(xr.DataArray(od, dims=("c", "y", "x")), white) + return np.asarray(rgb.data) + + +def _make_sdata(values: np.ndarray, *, with_tissue: bool = True) -> sd.SpatialData: + sdata = sd.SpatialData(images={"img": Image2DModel.parse(values, dims=("c", "y", "x"))}) + if with_tissue: + h, w = values.shape[-2], values.shape[-1] + sdata.labels["img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) + return sdata + + +@pytest.mark.parametrize("method", ["macenko", "vahadane"]) +class TestDecompositionThroughDispatchers: + def test_fit_and_apply_end_to_end(self, method: str) -> None: + sdata = _make_sdata(_synthetic_rgb(seed=1)) + ref = fit_stain_reference(sdata, "img", method=method, background_intensity=_WHITE) 
+ assert ref.method == method + assert ref.stain_matrix.shape == (3, 3) + assert ref.max_concentrations.shape == (2,) + + out = apply_stain_normalization(sdata, "img", ref) + assert isinstance(out, xr.DataArray) + assert out.sizes["c"] == 3 + + def test_apply_writes_back(self, method: str) -> None: + sdata = _make_sdata(_synthetic_rgb(seed=2)) + ref = fit_stain_reference(sdata, "img", method=method, background_intensity=_WHITE) + result = apply_stain_normalization(sdata, "img", ref, image_key_added="norm") + assert result is None + assert get_transformation(sdata.images["norm"], get_all=True).keys() == ( + get_transformation(sdata.images["img"], get_all=True).keys() + ) + + +class TestDecomposeStains: + def test_returns_named_concentration_maps(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + conc = decompose_stains(sdata, "img", "macenko", background_intensity=_WHITE) + assert set(conc) == {"hematoxylin", "eosin", "residual"} + assert all(set(c.dims) == {"y", "x"} for c in conc.values()) # one (y, x) map per stain + + def test_drop_residual(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + conc = decompose_stains(sdata, "img", "macenko", background_intensity=_WHITE, include_residual=False) + assert set(conc) == {"hematoxylin", "eosin"} + + def test_with_reference_writes_separate_images(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img", method="macenko", background_intensity=_WHITE) + out = decompose_stains(sdata, "img", ref, image_key_added="conc") + assert out is None + for stain in ("hematoxylin", "eosin", "residual"): + assert f"conc_{stain}" in sdata.images + assert list(sdata.images[f"conc_{stain}"].coords["c"].values) == [stain] + + def test_reinhard_reference_rejected(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + reinhard_ref = fit_stain_reference(sdata, "img", method="reinhard") + with pytest.raises(ValueError, match="macenko/vahadane reference"): + decompose_stains(sdata, "img", 
reinhard_ref) + + def test_bad_method_rejected(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + with pytest.raises(ValueError, match="method must be"): + decompose_stains(sdata, "img", "reinhard") + + +class TestBackgroundDefault: + def test_fit_defaults_to_white_when_absent(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img", method="macenko") + # default I_0 is a fixed full-white point, not an image-derived estimate + np.testing.assert_array_equal(ref.background_intensity, [255.0, 255.0, 255.0]) + + def test_explicit_background_is_used(self) -> None: + I0 = np.array([240.0, 245.0, 250.0]) + # build the synthetic image against this white point so the fit is consistent + sdata = _make_sdata(_synthetic_rgb(white=I0)) + ref = fit_stain_reference(sdata, "img", method="vahadane", background_intensity=I0) + np.testing.assert_array_equal(ref.background_intensity, I0) + + def test_estimate_background_public(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + bg = estimate_background_intensity(sdata.images["img"]) + assert bg.shape == (3,) + + +class TestUnknownMethod: + def test_fit_unknown_method_raises(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + with pytest.raises(ValueError, match="Unknown method"): + fit_stain_reference(sdata, "img", method="bogus") + + +class TestDecompositionOnHnE: + # Correctness is gated by the synthetic-recovery tests in test_stain_decomposition.py (per + # the arc decision); these are real-data smoke checks that the pipeline fits a valid + # matrix, applies lazily, and decomposes - fast because the source matrix is fit on the + # coarse level, not the full-resolution image. + # Both methods are exercised: once the fit consumes a real detect_tissue + # mask (dropping the fiducial ring + dim background), Macenko fits this + # low-contrast Visium H&E cleanly and agrees with Vahadane. 
+ @pytest.mark.parametrize("method", ["macenko", "vahadane"]) + def test_fit_apply_decompose_smoke(self, sdata_hne, method: str) -> None: + image_key = next(iter(sdata_hne.images)) + sq.experimental.im.detect_tissue(sdata_hne, image_key) + ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key, method=method) + assert isinstance(ref, StainReference) + assert ref.stain_matrix.shape == (3, 3) + normalized = sq.experimental.im.apply_stain_normalization(sdata_hne, image_key, ref) + assert normalized.sizes["c"] == 3 + conc = sq.experimental.im.decompose_stains(sdata_hne, image_key, ref) + assert set(conc) == {"hematoxylin", "eosin", "residual"} diff --git a/tests/experimental/test_stain_decomposition.py b/tests/experimental/test_stain_decomposition.py new file mode 100644 index 000000000..686c0daa2 --- /dev/null +++ b/tests/experimental/test_stain_decomposition.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import dask.array as da +import numpy as np +import pytest +import xarray as xr + +from squidpy.experimental.im._stain._constants import RUIFROK_HE +from squidpy.experimental.im._stain._conversion import sda_to_rgb +from squidpy.experimental.im._stain._decomposition import ( + MacenkoParams, + VahadaneParams, + _resolve_macenko_params, + _resolve_vahadane_params, + apply_decomposition, + fit_decomposition, +) +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + angle_between_deg, + complement_third_column, + reorder_to_canonical, +) + +_WHITE = np.array([255.0, 255.0, 255.0]) + + +def _canonical(h: np.ndarray, e: np.ndarray) -> np.ndarray: + return complement_third_column(reorder_to_canonical(np.stack([h, e], axis=1))) + + +def _synthetic_he(stain_matrix: np.ndarray, *, n_side: int = 48, seed: int = 0, chunked: bool = False) -> xr.DataArray: + """Build an RGB image from known H/E concentrations and a stain matrix.""" + rng = np.random.default_rng(seed) + n = n_side * n_side + conc = rng.uniform(0.0, 70.0, size=(n, 
2)) + # dense pure-H / pure-E populations so the angular extremes are well sampled + # (real H&E slides have many near-pure pixels; a uniform mix under-samples them) + third = n // 3 + conc[:third, 1] = 0.0 # pure-H pixels (one angular extreme) + conc[third : 2 * third, 0] = 0.0 # pure-E pixels (the other extreme) + od = (conc @ stain_matrix[:, :2].T).T.reshape(3, n_side, n_side) + data = da.from_array(od, chunks=(3, 16, 16)) if chunked else od + return sda_to_rgb(xr.DataArray(data, dims=("c", "y", "x")), _WHITE) + + +class TestMacenko: + @pytest.mark.parametrize("chunked", [False, True]) + def test_recovers_planted_matrix(self, chunked: bool) -> None: + truth = _canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]) + img = _synthetic_he(truth, chunked=chunked) + ref = fit_decomposition(img, "macenko", MacenkoParams(), _WHITE) + assert angle_between_deg(ref.stain_matrix[:, 0], truth[:, 0]) < 12.0 + assert angle_between_deg(ref.stain_matrix[:, 1], truth[:, 1]) < 12.0 + assert ref.max_concentrations.shape == (2,) + assert np.all(ref.max_concentrations > 0) + + +class TestVahadane: + def test_recovers_planted_matrix(self) -> None: + truth = _canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]) + img = _synthetic_he(truth) + ref = fit_decomposition(img, "vahadane", VahadaneParams(), _WHITE) + assert angle_between_deg(ref.stain_matrix[:, 0], truth[:, 0]) < 20.0 + assert angle_between_deg(ref.stain_matrix[:, 1], truth[:, 1]) < 20.0 + + +class TestApplyDecomposition: + def test_transfer_matches_reference_matrix(self) -> None: + truth_a = _canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]) + # a slightly rotated source staining + e_shift = RUIFROK_HE["eosin"] + 0.15 * RUIFROK_HE["hematoxylin"] + truth_b = _canonical(RUIFROK_HE["hematoxylin"], e_shift / np.linalg.norm(e_shift)) + + img_a = _synthetic_he(truth_a, seed=1) + img_b = _synthetic_he(truth_b, seed=2) + ref_a = fit_decomposition(img_a, "macenko", MacenkoParams(), _WHITE) + + normalized = 
apply_decomposition(img_b, ref_a, MacenkoParams()) + refit = fit_decomposition(normalized, "macenko", MacenkoParams(), _WHITE) + assert angle_between_deg(refit.stain_matrix[:, 0], ref_a.stain_matrix[:, 0]) < 12.0 + assert angle_between_deg(refit.stain_matrix[:, 1], ref_a.stain_matrix[:, 1]) < 12.0 + + def test_lazy_in_lazy_out(self) -> None: + truth = _canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]) + ref = fit_decomposition(_synthetic_he(truth), "macenko", MacenkoParams(), _WHITE) + out = apply_decomposition(_synthetic_he(truth, chunked=True), ref, MacenkoParams()) + assert isinstance(out.data, da.Array) + + def test_missing_max_concentrations_raises(self) -> None: + from squidpy.experimental.im._stain._reference import StainReference + + ref = StainReference( + method="macenko", + stain_matrix=_canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]), + background_intensity=_WHITE, + ) + img = _synthetic_he(ref.stain_matrix) + with pytest.raises(ValueError, match="max_concentrations"): + apply_decomposition(img, ref, MacenkoParams()) + + +class TestDegenerate: + def test_empty_tissue_raises(self) -> None: + white = xr.DataArray(np.full((3, 16, 16), 255.0), dims=("c", "y", "x")) + with pytest.raises(StainFittingError, match="mask is empty"): + fit_decomposition(white, "macenko", MacenkoParams(), _WHITE) + + +class TestResolvers: + def test_macenko_mapping_and_unknown(self) -> None: + assert _resolve_macenko_params({"alpha": 2.0}).alpha == 2.0 + with pytest.raises(ValueError, match="Unknown"): + _resolve_macenko_params({"nope": 1}) + + def test_vahadane_instance_and_badtype(self) -> None: + p = VahadaneParams(lambda1=0.2) + assert _resolve_vahadane_params(p) is p + with pytest.raises(TypeError, match="VahadaneParams"): + _resolve_vahadane_params(5) + + @pytest.mark.parametrize("bad", [0.0, 50.0, -1.0]) + def test_macenko_alpha_bounds(self, bad: float) -> None: + with pytest.raises(ValueError, match="alpha"): + MacenkoParams(alpha=bad) diff --git 
a/tests/experimental/test_stain_mask.py b/tests/experimental/test_stain_mask.py index 3f87116f0..0e4110b3a 100644 --- a/tests/experimental/test_stain_mask.py +++ b/tests/experimental/test_stain_mask.py @@ -5,7 +5,12 @@ import pytest import xarray as xr -from squidpy.experimental.im._stain._mask import luminosity_foreground_mask +from squidpy.experimental.im._stain._mask import ( + absorbance_foreground_mask, + luminosity_foreground_mask, +) + +_WHITE = np.array([255.0, 255.0, 255.0]) def _rgb_dataarray(values: np.ndarray, *, chunked: bool) -> xr.DataArray: @@ -43,3 +48,23 @@ def test_non_three_channel_raises(self) -> None: values = np.zeros((2, 8, 8)) with pytest.raises(ValueError, match="length 3"): luminosity_foreground_mask(xr.DataArray(values, dims=("c", "y", "x")), 0.8) + + +class TestAbsorbanceForegroundMask: + def test_white_is_background_dark_is_tissue(self) -> None: + values = np.full((3, 8, 16), 255.0) + values[:, :, 8:] = 30.0 # dark right half = high absorbance = tissue + mask = absorbance_foreground_mask(_rgb_dataarray(values, chunked=False), _WHITE) + assert mask.dims == ("y", "x") + assert not bool(mask.values[:, :8].any()) + assert bool(mask.values[:, 8:].all()) + + def test_lazy_in_lazy_out(self) -> None: + values = np.full((3, 16, 16), 50.0) + mask = absorbance_foreground_mask(_rgb_dataarray(values, chunked=True), _WHITE) + assert isinstance(mask.data, da.Array) + + def test_non_three_channel_raises(self) -> None: + values = np.zeros((2, 8, 8)) + with pytest.raises(ValueError, match="length 3"): + absorbance_foreground_mask(xr.DataArray(values, dims=("c", "y", "x")), _WHITE) diff --git a/tests/experimental/test_stain_normalize.py b/tests/experimental/test_stain_normalize.py index e8aaa6e55..27e8d9340 100644 --- a/tests/experimental/test_stain_normalize.py +++ b/tests/experimental/test_stain_normalize.py @@ -7,7 +7,7 @@ import spatialdata as sd import spatialdata_plot as sdp import xarray as xr -from spatialdata.models import Image2DModel +from 
spatialdata.models import Image2DModel, Labels2DModel from spatialdata.transformations import Scale, get_transformation, set_transformation import squidpy as sq @@ -23,9 +23,15 @@ _ = sdp # registers the `.pl` spatialdata accessor -def _make_sdata(values: np.ndarray, *, scale_factors: list[int] | None = None) -> sd.SpatialData: +def _make_sdata( + values: np.ndarray, *, scale_factors: list[int] | None = None, with_tissue: bool = True +) -> sd.SpatialData: img = Image2DModel.parse(values, dims=("c", "y", "x"), scale_factors=scale_factors) - return sd.SpatialData(images={"img": img}) + sdata = sd.SpatialData(images={"img": img}) + if with_tissue: + h, w = values.shape[-2], values.shape[-1] + sdata.labels["img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) + return sdata @pytest.fixture @@ -46,11 +52,6 @@ def test_missing_image_key_raises(self, rgb_values: np.ndarray) -> None: with pytest.raises(ValueError, match="not found, valid keys"): fit_stain_reference(sdata, "nope") - def test_macenko_not_implemented(self, rgb_values: np.ndarray) -> None: - sdata = _make_sdata(rgb_values) - with pytest.raises(NotImplementedError, match="decomposition is not yet implemented"): - fit_stain_reference(sdata, "img", method="macenko") - def test_unknown_method_raises(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) with pytest.raises(ValueError, match="Unknown method"): @@ -93,6 +94,8 @@ def test_preserves_channel_coords_and_nonidentity_transform(self, rgb_values: np img = Image2DModel.parse(rgb_values, dims=("c", "y", "x"), c_coords=["r", "g", "b"]) set_transformation(img, Scale([2.0, 2.0], axes=("y", "x")), to_coordinate_system="global") sdata = sd.SpatialData(images={"img": img}) + h, w = rgb_values.shape[-2], rgb_values.shape[-1] + sdata.labels["img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) ref = fit_stain_reference(sdata, "img") apply_stain_normalization(sdata, "img", ref, 
image_key_added="norm") out = sdata.images["norm"] @@ -105,14 +108,14 @@ def test_existing_key_raises(self, rgb_values: np.ndarray) -> None: with pytest.raises(ValueError, match="already exists"): apply_stain_normalization(sdata, "img", ref, image_key_added="img") - def test_decomposition_reference_not_implemented(self, rgb_values: np.ndarray) -> None: + def test_decomposition_reference_without_max_concentrations_raises(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = StainReference( method="macenko", stain_matrix=np.eye(3), background_intensity=np.array([255.0, 255.0, 255.0]), ) - with pytest.raises(NotImplementedError, match="decomposition is not yet implemented"): + with pytest.raises(ValueError, match="max_concentrations"): apply_stain_normalization(sdata, "img", ref) def test_method_params_mapping(self, rgb_values: np.ndarray) -> None: @@ -122,9 +125,67 @@ def test_method_params_mapping(self, rgb_values: np.ndarray) -> None: assert isinstance(out, xr.DataArray) +class TestTissueMaskMandate: + def test_fit_requires_tissue_mask(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values, with_tissue=False) + with pytest.raises(KeyError, match="detect_tissue"): + fit_stain_reference(sdata, "img") + + def test_apply_requires_tissue_mask(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values) # has a mask -> fit works + ref = fit_stain_reference(sdata, "img") + del sdata.labels["img_tissue"] # ... 
but now the source has none + with pytest.raises(KeyError, match="detect_tissue"): + apply_stain_normalization(sdata, "img", ref) + + def test_explicit_missing_key_raises(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values) + with pytest.raises(KeyError, match="not found in sdata.labels"): + fit_stain_reference(sdata, "img", tissue_mask_key="nope") + + def test_mask_is_used_in_the_fit(self, rgb_values: np.ndarray) -> None: + # A different tissue region yields different channel statistics, proving + # the mask actually drives the fit (not silently ignored). + ref_full = fit_stain_reference(_make_sdata(rgb_values), "img") + + sdata_part = _make_sdata(rgb_values, with_tissue=False) + h, w = rgb_values.shape[-2], rgb_values.shape[-1] + partial = np.zeros((h, w), dtype=np.uint32) + partial[: h // 2] = 1 # only the top half is tissue + sdata_part.labels["img_tissue"] = Labels2DModel.parse(partial, dims=("y", "x")) + ref_part = fit_stain_reference(sdata_part, "img") + + assert not np.allclose(ref_full.mu, ref_part.mu) + + +class TestPreserveBackground: + def test_background_passthrough_vs_full_frame(self, rgb_values: np.ndarray) -> None: + # tissue = top half only; bottom half is background + h, w = rgb_values.shape[-2], rgb_values.shape[-1] + sdata = _make_sdata(rgb_values, with_tissue=False) + partial = np.zeros((h, w), dtype=np.uint32) + partial[: h // 2] = 1 + sdata.labels["img_tissue"] = Labels2DModel.parse(partial, dims=("y", "x")) + + # a differently-coloured reference so the transform is non-trivial + shifted = np.clip(rgb_values * np.array([1.3, 0.8, 1.1])[:, None, None], 0, 255).astype(np.float32) + sdata.images["ref_img"] = Image2DModel.parse(shifted, dims=("c", "y", "x")) + sdata.labels["ref_img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) + ref = fit_stain_reference(sdata, "ref_img") + + original = get_element_data(sdata.images["img"], "auto", "image", "img").values + kept = 
apply_stain_normalization(sdata, "img", ref).values # preserve_background=True (default) + full = apply_stain_normalization(sdata, "img", ref, preserve_background=False).values + + bg = slice(h // 2, None) + np.testing.assert_allclose(kept[:, bg], original[:, bg]) # background untouched + assert not np.allclose(full[:, bg], original[:, bg]) # full-frame recolours it + + class TestStainNormalizationOnHnE: def test_fit_apply_smoke(self, sdata_hne) -> None: image_key = next(iter(sdata_hne.images)) + sq.experimental.im.detect_tissue(sdata_hne, image_key) ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key) assert ref.method == "reinhard" out = sq.experimental.im.apply_stain_normalization(sdata_hne, image_key, ref) @@ -136,6 +197,7 @@ class TestStainNormalizationVisual(PlotTester, metaclass=PlotTesterMeta): def test_plot_reinhard_before_after(self, sdata_hne) -> None: """Visual: a re-stained source (left) normalized back to the H&E reference (right).""" image_key = next(iter(sdata_hne.images)) + sq.experimental.im.detect_tissue(sdata_hne, image_key) reference = fit_stain_reference(sdata_hne, image_key) # Deterministically warm/cool the channels to simulate a different @@ -145,7 +207,10 @@ def test_plot_reinhard_before_after(self, sdata_hne) -> None: shifted = (da_rgb * weights).clip(0, 255) sdata_hne.images["hne_shifted"] = Image2DModel.parse(shifted.data, dims=shifted.dims) - apply_stain_normalization(sdata_hne, "hne_shifted", reference, image_key_added="hne_normalized") + # `hne_shifted` shares geometry with `image_key`; reuse its tissue mask. 
+ apply_stain_normalization( + sdata_hne, "hne_shifted", reference, image_key_added="hne_normalized", tissue_mask_key=f"{image_key}_tissue" + ) _, axes = plt.subplots(1, 2, figsize=(8, 4)) sdata_hne.pl.render_images("hne_shifted").pl.show(ax=axes[0], title="before") diff --git a/tests/experimental/test_stain_reference.py b/tests/experimental/test_stain_reference.py index 5bd5cfd0e..e8f5999d2 100644 --- a/tests/experimental/test_stain_reference.py +++ b/tests/experimental/test_stain_reference.py @@ -117,3 +117,15 @@ def test_rejects_non_finite() -> None: mu=np.array([np.nan, 0.0, 0.0]), sigma=np.ones(3), ) + + +def test_equality_is_array_aware_and_hashable() -> None: + # distinct-but-equal references compare equal (array-aware __eq__), and + # references remain hashable (identity) despite the numpy-array fields. + a = StainReference(method="reinhard", mu=np.array([1.0, 2.0, 3.0]), sigma=np.ones(3)) + b = StainReference(method="reinhard", mu=np.array([1.0, 2.0, 3.0]), sigma=np.ones(3)) + c = StainReference(method="reinhard", mu=np.array([9.0, 2.0, 3.0]), sigma=np.ones(3)) + assert a == b + assert a != c + assert len({a, b, c}) == 3 # identity-hashed, no TypeError + assert a != "not a reference" diff --git a/tests/experimental/test_stain_validation.py b/tests/experimental/test_stain_validation.py new file mode 100644 index 000000000..cb548ab54 --- /dev/null +++ b/tests/experimental/test_stain_validation.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from squidpy.experimental.im._stain._constants import RUIFROK_HE +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + complement_third_column, + reorder_to_canonical, + validate_stain_matrix, +) + + +def _he_matrix() -> np.ndarray: + return np.stack([RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]], axis=1) + + +class TestReorderToCanonical: + def test_swapped_columns_restored(self) -> None: + he = _he_matrix() + swapped = he[:, ::-1] + out = 
reorder_to_canonical(swapped) + np.testing.assert_allclose(out, he, atol=1e-8) + + def test_sign_flips_corrected(self) -> None: + he = _he_matrix() + flipped = he * np.array([-1.0, 1.0]) + out = reorder_to_canonical(flipped) + np.testing.assert_allclose(out, he, atol=1e-8) + + def test_bad_shape_raises(self) -> None: + with pytest.raises(ValueError, match="shape"): + reorder_to_canonical(np.eye(3)) + + +class TestComplementThirdColumn: + def test_unit_orthogonal(self) -> None: + w = complement_third_column(_he_matrix()) + assert w.shape == (3, 3) + np.testing.assert_allclose(np.linalg.norm(w[:, 2]), 1.0, atol=1e-8) + assert abs(w[:, 2] @ w[:, 0]) < 1e-8 + assert abs(w[:, 2] @ w[:, 1]) < 1e-8 + + def test_colinear_raises(self) -> None: + v = RUIFROK_HE["hematoxylin"] + with pytest.raises(StainFittingError, match="colinear"): + complement_third_column(np.stack([v, v], axis=1)) + + +class TestValidateStainMatrix: + def test_canonical_passes(self) -> None: + validate_stain_matrix(complement_third_column(_he_matrix())) + + def test_rank_deficient_raises_with_image_key(self) -> None: + w = complement_third_column(_he_matrix()) + w[:, 1] = w[:, 0] # collapse E onto H + with pytest.raises(StainFittingError) as exc: + validate_stain_matrix(w, image_key="slideA") + assert exc.value.image_key == "slideA" + + def test_rotated_he_raises(self) -> None: + # rotate H far toward an unrelated direction + w = complement_third_column(_he_matrix()) + w[:, 0] = np.array([1.0, 0.0, 0.0]) + with pytest.raises(StainFittingError, match="deviates"): + validate_stain_matrix(w) + + def test_zero_column_raises(self) -> None: + w = complement_third_column(_he_matrix()) + w[:, 0] = 0.0 + with pytest.raises(StainFittingError, match="zero-norm"): + validate_stain_matrix(w) + + def test_bad_shape_raises(self) -> None: + with pytest.raises(StainFittingError, match="shape"): + validate_stain_matrix(np.eye(2)) From a31a65dff31354db22bd712ca09436dfbaefe584 Mon Sep 17 00:00:00 2001 From: anon Date: Mon, 
1 Jun 2026 22:13:09 +0200 Subject: [PATCH 2/8] Stain normalization: UX rename pass (step 1/5) Behaviour-preserving renames from the UX-sharpening pass (plans/stain-pr3-decomposition.md): - apply_stain_normalization -> normalize_stains (verb, not noun-phrase) - background_intensity -> white_point (the I_0 reference, not the image's background region) across params, the StainReference field, rgb_to_sda, and docstrings - estimate_background_intensity -> estimate_white_point; DEFAULT_BACKGROUND_INTENSITY -> DEFAULT_WHITE_POINT - _stain/_background.py -> _stain/_white_point.py (+ its test) No behaviour change; 121 stain tests pass unchanged. Co-Authored-By: Claude Opus 4.8 --- docs/api.md | 4 +-- src/squidpy/experimental/im/__init__.py | 8 +++--- .../experimental/im/_stain/__init__.py | 8 +++--- .../experimental/im/_stain/_conversion.py | 12 ++++---- .../experimental/im/_stain/_decomposition.py | 16 +++++------ src/squidpy/experimental/im/_stain/_mask.py | 6 ++-- .../experimental/im/_stain/_normalize.py | 26 ++++++++--------- .../experimental/im/_stain/_reference.py | 22 +++++++-------- .../{_background.py => _white_point.py} | 10 +++---- .../test_stain_decompose_public.py | 28 +++++++++---------- .../experimental/test_stain_decomposition.py | 2 +- tests/experimental/test_stain_normalize.py | 28 +++++++++---------- tests/experimental/test_stain_reference.py | 28 +++++++++---------- ...ackground.py => test_stain_white_point.py} | 8 +++--- 14 files changed, 101 insertions(+), 105 deletions(-) rename src/squidpy/experimental/im/_stain/{_background.py => _white_point.py} (85%) rename tests/experimental/{test_stain_background.py => test_stain_white_point.py} (77%) diff --git a/docs/api.md b/docs/api.md index a2df60a37..3b582ab60 100644 --- a/docs/api.md +++ b/docs/api.md @@ -151,9 +151,9 @@ See the {doc}`extensibility guide ` for how to implement a custo experimental.tl.TilingQCParams experimental.pl.tiling_qc experimental.im.fit_stain_reference - 
experimental.im.apply_stain_normalization + experimental.im.normalize_stains experimental.im.decompose_stains - experimental.im.estimate_background_intensity + experimental.im.estimate_white_point experimental.im.StainReference experimental.im.ReinhardParams experimental.im.MacenkoParams diff --git a/src/squidpy/experimental/im/__init__.py b/src/squidpy/experimental/im/__init__.py index 12bbb2da5..7adea9d74 100644 --- a/src/squidpy/experimental/im/__init__.py +++ b/src/squidpy/experimental/im/__init__.py @@ -14,10 +14,10 @@ ReinhardParams, StainReference, VahadaneParams, - apply_stain_normalization, decompose_stains, - estimate_background_intensity, + estimate_white_point, fit_stain_reference, + normalize_stains, ) __all__ = [ @@ -29,10 +29,10 @@ "StainReference", "VahadaneParams", "WekaParams", - "apply_stain_normalization", + "normalize_stains", "decompose_stains", "detect_tissue", - "estimate_background_intensity", + "estimate_white_point", "fit_stain_reference", "make_tiles", "make_tiles_from_spots", diff --git a/src/squidpy/experimental/im/_stain/__init__.py b/src/squidpy/experimental/im/_stain/__init__.py index bc608bc0a..ffaad6a33 100644 --- a/src/squidpy/experimental/im/_stain/__init__.py +++ b/src/squidpy/experimental/im/_stain/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -from squidpy.experimental.im._stain._background import estimate_background_intensity from squidpy.experimental.im._stain._constants import ( DEFAULT_LUMINOSITY_THRESHOLD, RUDERMAN_LAB_TO_LMS, @@ -27,9 +26,9 @@ luminosity_foreground_mask, ) from squidpy.experimental.im._stain._normalize import ( - apply_stain_normalization, decompose_stains, fit_stain_reference, + normalize_stains, ) from squidpy.experimental.im._stain._reference import StainMethod, StainReference from squidpy.experimental.im._stain._reinhard import ( @@ -43,6 +42,7 @@ reorder_to_canonical, validate_stain_matrix, ) +from squidpy.experimental.im._stain._white_point import estimate_white_point __all__ = [ 
"DEFAULT_LUMINOSITY_THRESHOLD", @@ -61,10 +61,10 @@ "absorbance_foreground_mask", "apply_decomposition", "apply_reinhard", - "apply_stain_normalization", + "normalize_stains", "complement_third_column", "decompose_stains", - "estimate_background_intensity", + "estimate_white_point", "fit_decomposition", "fit_reinhard", "fit_stain_reference", diff --git a/src/squidpy/experimental/im/_stain/_conversion.py b/src/squidpy/experimental/im/_stain/_conversion.py index b71935ea5..00fda7500 100644 --- a/src/squidpy/experimental/im/_stain/_conversion.py +++ b/src/squidpy/experimental/im/_stain/_conversion.py @@ -94,7 +94,7 @@ def _lab_to_rgb_kernel(x: np.ndarray, *, dtype: np.dtype) -> np.ndarray: def rgb_to_sda( rgb: xr.DataArray, - background_intensity: np.ndarray, + white_point: np.ndarray, ) -> xr.DataArray: """Convert RGB intensities to standard deviation per absorbance (SDA). @@ -112,7 +112,7 @@ def rgb_to_sda( rgb Image with a ``"c"`` dimension of length 3. May be numpy- or dask-backed; the operation is purely elementwise and stays lazy. - background_intensity + white_point Per-channel white-point ``I_0`` as a shape-``(3,)`` numpy array. Required: no scanner produces a pure-white background, so the caller must supply either an estimate (PR 3 will ship the @@ -125,23 +125,23 @@ def rgb_to_sda( """ _check_channel_dim(rgb) dtype = _working_dtype(rgb) - bg = np.asarray(background_intensity, dtype=dtype) + bg = np.asarray(white_point, dtype=dtype) return _apply_along_channel(rgb, _rgb_to_sda_kernel, out_dtype=dtype, bg=bg, dtype=dtype) def sda_to_rgb( sda: xr.DataArray, - background_intensity: np.ndarray, + white_point: np.ndarray, ) -> xr.DataArray: """Convert SDA back to RGB intensities in ``[0, 255]``. - Inverse of :func:`rgb_to_sda`. Pass the same ``background_intensity`` + Inverse of :func:`rgb_to_sda`. Pass the same ``white_point`` used at encode time. The result is clipped to ``[0, 255]`` but kept in float dtype; uint8 conversion is the caller's choice. 
""" _check_channel_dim(sda) dtype = _working_dtype(sda) - bg = np.asarray(background_intensity, dtype=dtype) + bg = np.asarray(white_point, dtype=dtype) return _apply_along_channel(sda, _sda_to_rgb_kernel, out_dtype=dtype, bg=bg, dtype=dtype) diff --git a/src/squidpy/experimental/im/_stain/_decomposition.py b/src/squidpy/experimental/im/_stain/_decomposition.py index 072f9c871..a17b53e73 100644 --- a/src/squidpy/experimental/im/_stain/_decomposition.py +++ b/src/squidpy/experimental/im/_stain/_decomposition.py @@ -113,7 +113,7 @@ def _resolve_vahadane_params(params: VahadaneParams | Mapping[str, Any] | None) def _tissue_od( image_rgb: xr.DataArray, - background_intensity: np.ndarray, + white_point: np.ndarray, beta: float, *, tissue_mask: np.ndarray | None = None, @@ -127,7 +127,7 @@ def _tissue_od( given it selects the tissue pixels; otherwise the absorbance threshold ``beta`` is used. """ - sda = rgb_to_sda(image_rgb, background_intensity) + sda = rgb_to_sda(image_rgb, white_point) mask = as_spatial_mask(tissue_mask, sda) if tissue_mask is not None else foreground_mask_from_sda(sda, beta) od = np.asarray(sda.where(mask).transpose("c", "y", "x").data.reshape(3, -1)).T od = od[np.all(np.isfinite(od), axis=1)] @@ -201,18 +201,18 @@ def fit_decomposition( image_rgb: xr.DataArray, method: StainMethod, params: Any, - background_intensity: np.ndarray, + white_point: np.ndarray, *, tissue_mask: np.ndarray | None = None, image_key: str | None = None, ) -> StainReference: """Fit a decomposition :class:`StainReference` (stain matrix + max concentrations).""" - od = _tissue_od(image_rgb, background_intensity, params.beta, tissue_mask=tissue_mask, image_key=image_key) + od = _tissue_od(image_rgb, white_point, params.beta, tissue_mask=tissue_mask, image_key=image_key) matrix = _stain_matrix(od, method, params, image_key=image_key) return StainReference( method=method, stain_matrix=matrix, - background_intensity=np.asarray(background_intensity, dtype=np.float64), + 
white_point=np.asarray(white_point, dtype=np.float64), max_concentrations=_max_concentrations(_concentrations(od, matrix)), ) @@ -243,7 +243,7 @@ def apply_decomposition( _check_channel_dim(image_rgb) if reference.max_concentrations is None: raise ValueError("reference is missing max_concentrations; refit it with fit_stain_reference.") - bg = reference.background_intensity + bg = reference.white_point od_src = _tissue_od( fit_rgb if fit_rgb is not None else image_rgb, bg, params.beta, tissue_mask=tissue_mask, image_key=None @@ -263,7 +263,7 @@ def apply_decomposition( def decompose_to_concentrations( - image_rgb: xr.DataArray, stain_matrix: np.ndarray, background_intensity: np.ndarray + image_rgb: xr.DataArray, stain_matrix: np.ndarray, white_point: np.ndarray ) -> xr.DataArray: """Project an image onto a stain matrix, returning a 3-channel concentration image. @@ -271,7 +271,7 @@ def decompose_to_concentrations( concentration along the complement vector and is a diagnostic, not a stain. """ _check_channel_dim(image_rgb) - sda = rgb_to_sda(image_rgb, background_intensity) + sda = rgb_to_sda(image_rgb, white_point) dtype = _working_dtype(sda) pinv = np.linalg.pinv(stain_matrix) return _apply_along_channel(sda, _matmul_kernel, out_dtype=dtype, matrix=pinv.astype(dtype), dtype=dtype) diff --git a/src/squidpy/experimental/im/_stain/_mask.py b/src/squidpy/experimental/im/_stain/_mask.py index 3e35cee16..833e3da18 100644 --- a/src/squidpy/experimental/im/_stain/_mask.py +++ b/src/squidpy/experimental/im/_stain/_mask.py @@ -98,7 +98,7 @@ def foreground_mask_from_sda(sda: xr.DataArray, beta: float = 0.15) -> xr.DataAr return sda.mean(dim="c") > beta -def absorbance_foreground_mask(rgb: xr.DataArray, background_intensity: np.ndarray, beta: float = 0.15) -> xr.DataArray: +def absorbance_foreground_mask(rgb: xr.DataArray, white_point: np.ndarray, beta: float = 0.15) -> xr.DataArray: """Boolean tissue mask in optical-density (absorbance) space. 
The convention the Macenko/Vahadane fits expect: a pixel is tissue if its @@ -109,7 +109,7 @@ def absorbance_foreground_mask(rgb: xr.DataArray, background_intensity: np.ndarr ---------- rgb Image with a ``"c"`` dimension of length 3. Numpy- or dask-backed. - background_intensity + white_point Per-channel white point ``I_0`` (shape ``(3,)``), as used by :func:`~squidpy.experimental.im._stain._conversion.rgb_to_sda`. beta @@ -120,4 +120,4 @@ def absorbance_foreground_mask(rgb: xr.DataArray, background_intensity: np.ndarr Boolean ``(y, x)`` DataArray: ``True`` = tissue. Lazy if ``rgb`` was lazy. """ _check_channel_dim(rgb) - return foreground_mask_from_sda(rgb_to_sda(rgb, background_intensity), beta) + return foreground_mask_from_sda(rgb_to_sda(rgb, white_point), beta) diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py index 76c619755..695bb554d 100644 --- a/src/squidpy/experimental/im/_stain/_normalize.py +++ b/src/squidpy/experimental/im/_stain/_normalize.py @@ -22,7 +22,6 @@ from spatialdata.transformations import get_transformation from squidpy._utils import _get_scale_factors -from squidpy.experimental.im._stain._background import DEFAULT_BACKGROUND_INTENSITY from squidpy.experimental.im._stain._conversion import _check_channel_dim from squidpy.experimental.im._stain._decomposition import ( MacenkoParams, @@ -40,6 +39,7 @@ apply_reinhard, fit_reinhard, ) +from squidpy.experimental.im._stain._white_point import DEFAULT_WHITE_POINT from squidpy.experimental.im._utils import ( _choose_label_scale_for_image, get_element_data, @@ -164,7 +164,7 @@ def fit_stain_reference( method: StainMethod = "reinhard", scale: str | Literal["auto"] = "auto", method_params: MethodParams = None, - background_intensity: np.ndarray | None = None, + white_point: np.ndarray | None = None, tissue_mask_key: str | None = None, ) -> StainReference: """Fit a stain reference from an image in a :class:`~spatialdata.SpatialData` object. 
@@ -185,11 +185,11 @@ def fit_stain_reference( A :class:`ReinhardParams`/:class:`MacenkoParams`/:class:`VahadaneParams` instance, a mapping of its fields, or ``None`` for defaults. Must match ``method``. - background_intensity + white_point Per-channel white point ``I_0`` ``(3,)`` for the decomposition methods. If ``None``, a fixed full-white ``[255, 255, 255]`` is used (the HistomicsTK/Macenko convention), so unstained pixels round-trip to - white. Pass :func:`estimate_background_intensity` only for slides with a + white. Pass :func:`estimate_white_point` only for slides with a known non-white background. Ignored by Reinhard. tissue_mask_key Key of a tissue-label element in ``sdata.labels`` (as produced by @@ -209,15 +209,11 @@ def fit_stain_reference( tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key) if method == "reinhard": return fit_reinhard(da, params, tissue_mask=tissue_mask) - bg = ( - DEFAULT_BACKGROUND_INTENSITY.copy() - if background_intensity is None - else np.asarray(background_intensity, np.float64) - ) + bg = DEFAULT_WHITE_POINT.copy() if white_point is None else np.asarray(white_point, np.float64) return fit_decomposition(da, method, params, bg, tissue_mask=tissue_mask, image_key=image_key) -def apply_stain_normalization( +def normalize_stains( sdata: sd.SpatialData, image_key: str, reference: StainReference, @@ -299,7 +295,7 @@ def decompose_stains( *, scale: str | Literal["auto"] = "auto", method_params: MethodParams = None, - background_intensity: np.ndarray | None = None, + white_point: np.ndarray | None = None, image_key_added: str | None = None, tissue_mask_key: str | None = None, include_residual: bool = True, @@ -315,7 +311,7 @@ def decompose_stains( white point are used) or a method name (``"macenko"``/``"vahadane"``) to fit on this image first. The reference is the provenance record of how the maps were produced (method, stain matrix, white point). 
- scale, method_params, background_intensity, tissue_mask_key + scale, method_params, white_point, tissue_mask_key As for :func:`fit_stain_reference` (only used when a method name is given; a reference is projected as-is and needs no tissue mask). image_key_added @@ -342,7 +338,7 @@ def decompose_stains( reference = reference_or_method if reference.method not in _DECOMPOSITION_METHODS or reference.stain_matrix is None: raise ValueError("decompose_stains requires a macenko/vahadane reference with a stain matrix.") - stain_matrix, bg = reference.stain_matrix, reference.background_intensity + stain_matrix, bg = reference.stain_matrix, reference.white_point else: if reference_or_method not in _DECOMPOSITION_METHODS: raise ValueError(f"method must be one of {list(_DECOMPOSITION_METHODS)}; got {reference_or_method!r}.") @@ -352,10 +348,10 @@ def decompose_stains( method=reference_or_method, scale=scale, method_params=method_params, - background_intensity=background_intensity, + white_point=white_point, tissue_mask_key=tissue_mask_key, ) - stain_matrix, bg = reference.stain_matrix, reference.background_intensity + stain_matrix, bg = reference.stain_matrix, reference.white_point concentrations = decompose_to_concentrations(da, stain_matrix, bg).assign_coords(c=_CONCENTRATION_CHANNELS) names = ["hematoxylin", "eosin"] + (["residual"] if include_residual else []) diff --git a/src/squidpy/experimental/im/_stain/_reference.py b/src/squidpy/experimental/im/_stain/_reference.py index 80687fe61..bf4e464b4 100644 --- a/src/squidpy/experimental/im/_stain/_reference.py +++ b/src/squidpy/experimental/im/_stain/_reference.py @@ -43,12 +43,12 @@ class StainReference: sigma Shape ``(3,)`` Ruderman Lab channel standard deviations. Reinhard only. - background_intensity + white_point Shape ``(3,)`` per-channel white-point estimate. Required for decomposition methods (apply consumes it). 
Forbidden for Reinhard because Reinhard's color transfer operates in Ruderman Lab and does not model absorbance. There is no universal default; pass an - estimate from your data (see ``estimate_background_intensity``). + estimate from your data (see ``estimate_white_point``). max_concentrations Shape ``(2,)`` reference per-stain (H, E) maximum concentrations. Decomposition only. Stored because at apply time the reference image @@ -62,7 +62,7 @@ class StainReference: stain_matrix: np.ndarray | None = None mu: np.ndarray | None = None sigma: np.ndarray | None = None - background_intensity: np.ndarray | None = None + white_point: np.ndarray | None = None max_concentrations: np.ndarray | None = None def __eq__(self, other: object) -> bool: @@ -75,7 +75,7 @@ def __eq__(self, other: object) -> bool: return False return all( np.array_equal(getattr(self, name), getattr(other, name)) - for name in ("stain_matrix", "mu", "sigma", "background_intensity", "max_concentrations") + for name in ("stain_matrix", "mu", "sigma", "white_point", "max_concentrations") ) # eq=False keeps the default identity-based __hash__ (the array fields are @@ -92,17 +92,17 @@ def __post_init__(self) -> None: raise ValueError(f"method={self.method!r} requires stain_matrix.") if self.mu is not None or self.sigma is not None: raise ValueError(f"method={self.method!r} forbids mu/sigma; pass them only for Reinhard.") - if self.background_intensity is None: - raise ValueError(f"method={self.method!r} requires background_intensity.") + if self.white_point is None: + raise ValueError(f"method={self.method!r} requires white_point.") object.__setattr__( self, "stain_matrix", _coerce_finite(self.stain_matrix, shape=(3, 3), name="stain_matrix"), ) - bg = _coerce_finite(self.background_intensity, shape=(3,), name="background_intensity") + bg = _coerce_finite(self.white_point, shape=(3,), name="white_point") if np.any(bg <= 0): - raise ValueError("background_intensity must be strictly positive.") - 
object.__setattr__(self, "background_intensity", bg) + raise ValueError("white_point must be strictly positive.") + object.__setattr__(self, "white_point", bg) if self.max_concentrations is not None: maxc = _coerce_finite(self.max_concentrations, shape=(2,), name="max_concentrations") if np.any(maxc <= 0): @@ -113,9 +113,9 @@ def __post_init__(self) -> None: raise ValueError("method='reinhard' requires both mu and sigma.") if self.stain_matrix is not None: raise ValueError("method='reinhard' forbids stain_matrix.") - if self.background_intensity is not None: + if self.white_point is not None: raise ValueError( - "method='reinhard' forbids background_intensity; Reinhard's color " + "method='reinhard' forbids white_point; Reinhard's color " "transfer is in Ruderman Lab and does not use a white point." ) if self.max_concentrations is not None: diff --git a/src/squidpy/experimental/im/_stain/_background.py b/src/squidpy/experimental/im/_stain/_white_point.py similarity index 85% rename from src/squidpy/experimental/im/_stain/_background.py rename to src/squidpy/experimental/im/_stain/_white_point.py index 23ff47e67..d4d97da71 100644 --- a/src/squidpy/experimental/im/_stain/_background.py +++ b/src/squidpy/experimental/im/_stain/_white_point.py @@ -18,12 +18,12 @@ #: literature (240). The absorbance origin must be at least as bright as the #: slide background, otherwise unstained pixels get a non-zero absorbance and #: cannot round-trip back to white. Estimate from the image (see -#: ``estimate_background_intensity``) only when the slide has a genuinely +#: ``estimate_white_point``) only when the slide has a genuinely #: non-white background you want to anchor to. 
-DEFAULT_BACKGROUND_INTENSITY: np.ndarray = np.array([255.0, 255.0, 255.0]) +DEFAULT_WHITE_POINT: np.ndarray = np.array([255.0, 255.0, 255.0]) -def estimate_background_intensity(rgb: xr.DataArray, *, percentile: float = 99.0) -> np.ndarray: +def estimate_white_point(rgb: xr.DataArray, *, percentile: float = 99.0) -> np.ndarray: """Estimate the per-channel white point from the brightest pixels. Parameters @@ -37,7 +37,7 @@ def estimate_background_intensity(rgb: xr.DataArray, *, percentile: float = 99.0 Returns ------- - Shape-``(3,)`` float64 white point, suitable as ``background_intensity`` + Shape-``(3,)`` float64 white point, suitable as ``white_point`` for :func:`~squidpy.experimental.im._stain._conversion.rgb_to_sda`. Notes @@ -62,6 +62,6 @@ def estimate_background_intensity(rgb: xr.DataArray, *, percentile: float = 99.0 if np.any(bg <= 0): raise StainFittingError( "estimated background intensity is non-positive; the image may be blank or all-tissue. " - "Pass an explicit `background_intensity` if this is expected." + "Pass an explicit `white_point` if this is expected." 
) return bg diff --git a/tests/experimental/test_stain_decompose_public.py b/tests/experimental/test_stain_decompose_public.py index 58b461d03..ab0431a70 100644 --- a/tests/experimental/test_stain_decompose_public.py +++ b/tests/experimental/test_stain_decompose_public.py @@ -10,10 +10,10 @@ import squidpy as sq from squidpy.experimental.im import ( StainReference, - apply_stain_normalization, decompose_stains, - estimate_background_intensity, + estimate_white_point, fit_stain_reference, + normalize_stains, ) from squidpy.experimental.im._stain._constants import RUIFROK_HE from squidpy.experimental.im._stain._conversion import sda_to_rgb @@ -49,19 +49,19 @@ def _make_sdata(values: np.ndarray, *, with_tissue: bool = True) -> sd.SpatialDa class TestDecompositionThroughDispatchers: def test_fit_and_apply_end_to_end(self, method: str) -> None: sdata = _make_sdata(_synthetic_rgb(seed=1)) - ref = fit_stain_reference(sdata, "img", method=method, background_intensity=_WHITE) + ref = fit_stain_reference(sdata, "img", method=method, white_point=_WHITE) assert ref.method == method assert ref.stain_matrix.shape == (3, 3) assert ref.max_concentrations.shape == (2,) - out = apply_stain_normalization(sdata, "img", ref) + out = normalize_stains(sdata, "img", ref) assert isinstance(out, xr.DataArray) assert out.sizes["c"] == 3 def test_apply_writes_back(self, method: str) -> None: sdata = _make_sdata(_synthetic_rgb(seed=2)) - ref = fit_stain_reference(sdata, "img", method=method, background_intensity=_WHITE) - result = apply_stain_normalization(sdata, "img", ref, image_key_added="norm") + ref = fit_stain_reference(sdata, "img", method=method, white_point=_WHITE) + result = normalize_stains(sdata, "img", ref, image_key_added="norm") assert result is None assert get_transformation(sdata.images["norm"], get_all=True).keys() == ( get_transformation(sdata.images["img"], get_all=True).keys() @@ -71,18 +71,18 @@ def test_apply_writes_back(self, method: str) -> None: class 
TestDecomposeStains: def test_returns_named_concentration_maps(self) -> None: sdata = _make_sdata(_synthetic_rgb()) - conc = decompose_stains(sdata, "img", "macenko", background_intensity=_WHITE) + conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE) assert set(conc) == {"hematoxylin", "eosin", "residual"} assert all(set(c.dims) == {"y", "x"} for c in conc.values()) # one (y, x) map per stain def test_drop_residual(self) -> None: sdata = _make_sdata(_synthetic_rgb()) - conc = decompose_stains(sdata, "img", "macenko", background_intensity=_WHITE, include_residual=False) + conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE, include_residual=False) assert set(conc) == {"hematoxylin", "eosin"} def test_with_reference_writes_separate_images(self) -> None: sdata = _make_sdata(_synthetic_rgb()) - ref = fit_stain_reference(sdata, "img", method="macenko", background_intensity=_WHITE) + ref = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE) out = decompose_stains(sdata, "img", ref, image_key_added="conc") assert out is None for stain in ("hematoxylin", "eosin", "residual"): @@ -106,18 +106,18 @@ def test_fit_defaults_to_white_when_absent(self) -> None: sdata = _make_sdata(_synthetic_rgb()) ref = fit_stain_reference(sdata, "img", method="macenko") # default I_0 is a fixed full-white point, not an image-derived estimate - np.testing.assert_array_equal(ref.background_intensity, [255.0, 255.0, 255.0]) + np.testing.assert_array_equal(ref.white_point, [255.0, 255.0, 255.0]) def test_explicit_background_is_used(self) -> None: I0 = np.array([240.0, 245.0, 250.0]) # build the synthetic image against this white point so the fit is consistent sdata = _make_sdata(_synthetic_rgb(white=I0)) - ref = fit_stain_reference(sdata, "img", method="vahadane", background_intensity=I0) - np.testing.assert_array_equal(ref.background_intensity, I0) + ref = fit_stain_reference(sdata, "img", method="vahadane", white_point=I0) + 
np.testing.assert_array_equal(ref.white_point, I0) def test_estimate_background_public(self) -> None: sdata = _make_sdata(_synthetic_rgb()) - bg = estimate_background_intensity(sdata.images["img"]) + bg = estimate_white_point(sdata.images["img"]) assert bg.shape == (3,) @@ -143,7 +143,7 @@ def test_fit_apply_decompose_smoke(self, sdata_hne, method: str) -> None: ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key, method=method) assert isinstance(ref, StainReference) assert ref.stain_matrix.shape == (3, 3) - normalized = sq.experimental.im.apply_stain_normalization(sdata_hne, image_key, ref) + normalized = sq.experimental.im.normalize_stains(sdata_hne, image_key, ref) assert normalized.sizes["c"] == 3 conc = sq.experimental.im.decompose_stains(sdata_hne, image_key, ref) assert set(conc) == {"hematoxylin", "eosin", "residual"} diff --git a/tests/experimental/test_stain_decomposition.py b/tests/experimental/test_stain_decomposition.py index 686c0daa2..04f060b51 100644 --- a/tests/experimental/test_stain_decomposition.py +++ b/tests/experimental/test_stain_decomposition.py @@ -93,7 +93,7 @@ def test_missing_max_concentrations_raises(self) -> None: ref = StainReference( method="macenko", stain_matrix=_canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]), - background_intensity=_WHITE, + white_point=_WHITE, ) img = _synthetic_he(ref.stain_matrix) with pytest.raises(ValueError, match="max_concentrations"): diff --git a/tests/experimental/test_stain_normalize.py b/tests/experimental/test_stain_normalize.py index 27e8d9340..d25d60e91 100644 --- a/tests/experimental/test_stain_normalize.py +++ b/tests/experimental/test_stain_normalize.py @@ -14,8 +14,8 @@ from squidpy.experimental.im import ( ReinhardParams, StainReference, - apply_stain_normalization, fit_stain_reference, + normalize_stains, ) from squidpy.experimental.im._utils import get_element_data from tests.conftest import PlotTester, PlotTesterMeta @@ -62,7 +62,7 @@ class 
TestApplyStainNormalization: def test_returns_lazy_and_leaves_sdata_untouched(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = fit_stain_reference(sdata, "img") - out = apply_stain_normalization(sdata, "img", ref) + out = normalize_stains(sdata, "img", ref) assert isinstance(out, xr.DataArray) assert isinstance(out.data, da.Array) assert list(sdata.images.keys()) == ["img"] @@ -70,7 +70,7 @@ def test_returns_lazy_and_leaves_sdata_untouched(self, rgb_values: np.ndarray) - def test_writes_and_preserves_transform_and_dims(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = fit_stain_reference(sdata, "img") - result = apply_stain_normalization(sdata, "img", ref, image_key_added="norm") + result = normalize_stains(sdata, "img", ref, image_key_added="norm") assert result is None assert "norm" in sdata.images out = sdata.images["norm"] @@ -83,7 +83,7 @@ def test_writes_and_preserves_transform_and_dims(self, rgb_values: np.ndarray) - def test_multiscale_rebuilds_pyramid(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values, scale_factors=[2]) ref = fit_stain_reference(sdata, "img") - apply_stain_normalization(sdata, "img", ref, image_key_added="norm") + normalize_stains(sdata, "img", ref, image_key_added="norm") src, out = sdata.images["img"], sdata.images["norm"] assert hasattr(out, "keys") src_shapes = [src[k].image.shape for k in src] @@ -97,7 +97,7 @@ def test_preserves_channel_coords_and_nonidentity_transform(self, rgb_values: np h, w = rgb_values.shape[-2], rgb_values.shape[-1] sdata.labels["img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) ref = fit_stain_reference(sdata, "img") - apply_stain_normalization(sdata, "img", ref, image_key_added="norm") + normalize_stains(sdata, "img", ref, image_key_added="norm") out = sdata.images["norm"] assert list(out.coords["c"].values) == ["r", "g", "b"] assert get_transformation(out, get_all=True) == get_transformation(img, 
get_all=True) @@ -106,22 +106,22 @@ def test_existing_key_raises(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = fit_stain_reference(sdata, "img") with pytest.raises(ValueError, match="already exists"): - apply_stain_normalization(sdata, "img", ref, image_key_added="img") + normalize_stains(sdata, "img", ref, image_key_added="img") def test_decomposition_reference_without_max_concentrations_raises(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = StainReference( method="macenko", stain_matrix=np.eye(3), - background_intensity=np.array([255.0, 255.0, 255.0]), + white_point=np.array([255.0, 255.0, 255.0]), ) with pytest.raises(ValueError, match="max_concentrations"): - apply_stain_normalization(sdata, "img", ref) + normalize_stains(sdata, "img", ref) def test_method_params_mapping(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = fit_stain_reference(sdata, "img", method_params={"mask_background": False}) - out = apply_stain_normalization(sdata, "img", ref, method_params=ReinhardParams(mask_background=False)) + out = normalize_stains(sdata, "img", ref, method_params=ReinhardParams(mask_background=False)) assert isinstance(out, xr.DataArray) @@ -136,7 +136,7 @@ def test_apply_requires_tissue_mask(self, rgb_values: np.ndarray) -> None: ref = fit_stain_reference(sdata, "img") del sdata.labels["img_tissue"] # ... 
but now the source has none with pytest.raises(KeyError, match="detect_tissue"): - apply_stain_normalization(sdata, "img", ref) + normalize_stains(sdata, "img", ref) def test_explicit_missing_key_raises(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) @@ -174,8 +174,8 @@ def test_background_passthrough_vs_full_frame(self, rgb_values: np.ndarray) -> N ref = fit_stain_reference(sdata, "ref_img") original = get_element_data(sdata.images["img"], "auto", "image", "img").values - kept = apply_stain_normalization(sdata, "img", ref).values # preserve_background=True (default) - full = apply_stain_normalization(sdata, "img", ref, preserve_background=False).values + kept = normalize_stains(sdata, "img", ref).values # preserve_background=True (default) + full = normalize_stains(sdata, "img", ref, preserve_background=False).values bg = slice(h // 2, None) np.testing.assert_allclose(kept[:, bg], original[:, bg]) # background untouched @@ -188,7 +188,7 @@ def test_fit_apply_smoke(self, sdata_hne) -> None: sq.experimental.im.detect_tissue(sdata_hne, image_key) ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key) assert ref.method == "reinhard" - out = sq.experimental.im.apply_stain_normalization(sdata_hne, image_key, ref) + out = sq.experimental.im.normalize_stains(sdata_hne, image_key, ref) assert "c" in out.dims assert out.sizes["c"] == 3 @@ -208,7 +208,7 @@ def test_plot_reinhard_before_after(self, sdata_hne) -> None: sdata_hne.images["hne_shifted"] = Image2DModel.parse(shifted.data, dims=shifted.dims) # `hne_shifted` shares geometry with `image_key`; reuse its tissue mask. 
- apply_stain_normalization( + normalize_stains( sdata_hne, "hne_shifted", reference, image_key_added="hne_normalized", tissue_mask_key=f"{image_key}_tissue" ) diff --git a/tests/experimental/test_stain_reference.py b/tests/experimental/test_stain_reference.py index e8f5999d2..93783c276 100644 --- a/tests/experimental/test_stain_reference.py +++ b/tests/experimental/test_stain_reference.py @@ -21,19 +21,19 @@ def test_macenko_basic() -> None: ref = StainReference( method="macenko", stain_matrix=_ruifrok_matrix(), - background_intensity=_TEST_BACKGROUND, + white_point=_TEST_BACKGROUND, ) assert ref.method == "macenko" assert ref.stain_matrix.shape == (3, 3) assert ref.mu is None and ref.sigma is None - np.testing.assert_array_equal(ref.background_intensity, _TEST_BACKGROUND) + np.testing.assert_array_equal(ref.white_point, _TEST_BACKGROUND) def test_reinhard_basic() -> None: ref = StainReference(method="reinhard", mu=np.array([1.0, 0.5, -0.2]), sigma=np.array([0.1, 0.1, 0.1])) assert ref.method == "reinhard" assert ref.stain_matrix is None - assert ref.background_intensity is None + assert ref.white_point is None def test_unknown_method_raises() -> None: @@ -43,11 +43,11 @@ def test_unknown_method_raises() -> None: def test_decomposition_requires_stain_matrix() -> None: with pytest.raises(ValueError, match="requires stain_matrix"): - StainReference(method="macenko", background_intensity=_TEST_BACKGROUND) + StainReference(method="macenko", white_point=_TEST_BACKGROUND) -def test_decomposition_requires_background_intensity() -> None: - with pytest.raises(ValueError, match="requires background_intensity"): +def test_decomposition_requires_white_point() -> None: + with pytest.raises(ValueError, match="requires white_point"): StainReference(method="macenko", stain_matrix=_ruifrok_matrix()) @@ -56,7 +56,7 @@ def test_decomposition_forbids_mu_sigma() -> None: StainReference( method="macenko", stain_matrix=_ruifrok_matrix(), - background_intensity=_TEST_BACKGROUND, + 
white_point=_TEST_BACKGROUND, mu=np.zeros(3), sigma=np.ones(3), ) @@ -82,22 +82,22 @@ def test_reinhard_forbids_stain_matrix() -> None: ) -def test_reinhard_forbids_background_intensity() -> None: - with pytest.raises(ValueError, match="forbids background_intensity"): +def test_reinhard_forbids_white_point() -> None: + with pytest.raises(ValueError, match="forbids white_point"): StainReference( method="reinhard", mu=np.zeros(3), sigma=np.ones(3), - background_intensity=_TEST_BACKGROUND, + white_point=_TEST_BACKGROUND, ) -def test_bad_background_intensity() -> None: - with pytest.raises(ValueError, match="background_intensity"): +def test_bad_white_point() -> None: + with pytest.raises(ValueError, match="white_point"): StainReference( method="macenko", stain_matrix=_ruifrok_matrix(), - background_intensity=np.array([255.0, -1.0, 255.0]), + white_point=np.array([255.0, -1.0, 255.0]), ) @@ -106,7 +106,7 @@ def test_rejects_bad_shape() -> None: StainReference( method="macenko", stain_matrix=np.zeros((2, 3)), - background_intensity=_TEST_BACKGROUND, + white_point=_TEST_BACKGROUND, ) diff --git a/tests/experimental/test_stain_background.py b/tests/experimental/test_stain_white_point.py similarity index 77% rename from tests/experimental/test_stain_background.py rename to tests/experimental/test_stain_white_point.py index 9043cc91a..0cdec8133 100644 --- a/tests/experimental/test_stain_background.py +++ b/tests/experimental/test_stain_white_point.py @@ -5,8 +5,8 @@ import pytest import xarray as xr -from squidpy.experimental.im._stain._background import estimate_background_intensity from squidpy.experimental.im._stain._validation import StainFittingError +from squidpy.experimental.im._stain._white_point import estimate_white_point def _da(values: np.ndarray, *, chunked: bool) -> xr.DataArray: @@ -23,7 +23,7 @@ def test_recovers_white_point(chunked: bool) -> None: values[1] = 245.0 values[2] = 250.0 values[:, :8, :8] = rng.uniform(20.0, 60.0, size=(3, 8, 8)) # tissue - bg = 
estimate_background_intensity(_da(values, chunked=chunked)) + bg = estimate_white_point(_da(values, chunked=chunked)) assert bg.shape == (3,) np.testing.assert_allclose(bg, [240.0, 245.0, 250.0], atol=1.0) @@ -31,9 +31,9 @@ def test_recovers_white_point(chunked: bool) -> None: def test_blank_image_raises() -> None: black = np.zeros((3, 16, 16)) with pytest.raises(StainFittingError, match="non-positive"): - estimate_background_intensity(_da(black, chunked=False)) + estimate_white_point(_da(black, chunked=False)) def test_bad_percentile_raises() -> None: with pytest.raises(ValueError, match="percentile"): - estimate_background_intensity(_da(np.ones((3, 8, 8)), chunked=False), percentile=0.0) + estimate_white_point(_da(np.ones((3, 8, 8)), chunked=False), percentile=0.0) From b1ceec485f22fb8363a69d941755b0aed9de5987 Mon Sep 17 00:00:00 2001 From: anon Date: Mon, 1 Jun 2026 22:41:28 +0200 Subject: [PATCH 3/8] Stain normalization: white point + bit-depth handling (step 2/5) - Default white point is now dtype-aware: dtype_max() gives the full-white value (255 / 65535 / 1.0) and default_white_point() raises with guidance when the data clearly doesn't match its dtype's range (8-bit-in-uint16, 0-255 float). - Bit-depth-agnostic reconstruction: sda_to_rgb / lab_ruderman_to_rgb take an out_dtype and clip to that dtype's valid range (dtype_max) rather than a hardcoded 255 - threaded from the source image dtype through apply. Per review, the dtype (not a derived max_value float) is the threaded parameter. - estimate_white_point is now sdata-level and samples the per-channel MEDIAN over non-tissue (background) pixels via the tissue mask (HistomicsTK semantics), replacing the whole-image percentile that under-estimated on dim scans. - Renamed _background.py docstring/semantics to white-point; test fixtures are uint8 (real H&E) rather than float-0-255 (which the [0,1]-float convention would flag). 125 stain tests pass. 
Co-Authored-By: Claude Opus 4.8 --- .../experimental/im/_stain/__init__.py | 2 +- .../experimental/im/_stain/_conversion.py | 41 +++++--- .../experimental/im/_stain/_decomposition.py | 3 +- .../experimental/im/_stain/_normalize.py | 45 ++++++++- .../experimental/im/_stain/_reinhard.py | 3 +- .../experimental/im/_stain/_white_point.py | 94 ++++++++++--------- .../test_stain_decompose_public.py | 8 +- tests/experimental/test_stain_normalize.py | 2 +- tests/experimental/test_stain_white_point.py | 86 ++++++++++++----- 9 files changed, 188 insertions(+), 96 deletions(-) diff --git a/src/squidpy/experimental/im/_stain/__init__.py b/src/squidpy/experimental/im/_stain/__init__.py index ffaad6a33..df7d354d0 100644 --- a/src/squidpy/experimental/im/_stain/__init__.py +++ b/src/squidpy/experimental/im/_stain/__init__.py @@ -27,6 +27,7 @@ ) from squidpy.experimental.im._stain._normalize import ( decompose_stains, + estimate_white_point, fit_stain_reference, normalize_stains, ) @@ -42,7 +43,6 @@ reorder_to_canonical, validate_stain_matrix, ) -from squidpy.experimental.im._stain._white_point import estimate_white_point __all__ = [ "DEFAULT_LUMINOSITY_THRESHOLD", diff --git a/src/squidpy/experimental/im/_stain/_conversion.py b/src/squidpy/experimental/im/_stain/_conversion.py index 00fda7500..98bca5e5f 100644 --- a/src/squidpy/experimental/im/_stain/_conversion.py +++ b/src/squidpy/experimental/im/_stain/_conversion.py @@ -24,6 +24,16 @@ _CHANNEL_DIM = "c" +def dtype_max(dtype: np.dtype | type) -> float: + """Valid-intensity upper bound for an image dtype (255 / 65535 / 1.0). + + Integer dtypes use their full range; float RGB is assumed to live in + ``[0, 1]``. 
+ """ + dt = np.dtype(dtype) + return float(np.iinfo(dt).max) if np.issubdtype(dt, np.integer) else 1.0 + + def _check_channel_dim(arr: xr.DataArray) -> None: if _CHANNEL_DIM not in arr.dims: raise ValueError(f"Input must have a dimension named {_CHANNEL_DIM!r}; got dims {arr.dims}.") @@ -68,9 +78,9 @@ def _rgb_to_sda_kernel(x: np.ndarray, *, bg: np.ndarray, dtype: np.dtype) -> np. return (-np.log((x + 1.0) / (bg + 1.0)) * SDA_SCALE).astype(dtype, copy=False) -def _sda_to_rgb_kernel(x: np.ndarray, *, bg: np.ndarray, dtype: np.dtype) -> np.ndarray: +def _sda_to_rgb_kernel(x: np.ndarray, *, bg: np.ndarray, max_value: float, dtype: np.dtype) -> np.ndarray: rgb = (bg + 1.0) * np.exp(-x.astype(dtype, copy=False) / SDA_SCALE) - 1.0 - np.clip(rgb, 0.0, 255.0, out=rgb) + np.clip(rgb, 0.0, max_value, out=rgb) return rgb.astype(dtype, copy=False) @@ -81,14 +91,14 @@ def _rgb_to_lab_kernel(x: np.ndarray, *, dtype: np.dtype) -> np.ndarray: return (lms @ RUDERMAN_LMS_TO_LAB.T.astype(dtype, copy=False)).astype(dtype, copy=False) -def _lab_to_rgb_kernel(x: np.ndarray, *, dtype: np.dtype) -> np.ndarray: +def _lab_to_rgb_kernel(x: np.ndarray, *, max_value: float, dtype: np.dtype) -> np.ndarray: x = x.astype(dtype, copy=False) log_lms = x @ RUDERMAN_LAB_TO_LMS.T.astype(dtype, copy=False) # The +1.0 / -1.0 pair is paired with the matching offset in # `_rgb_to_lab_kernel` so the round trip remains exact for valid RGB. lms = np.exp(log_lms) - 1.0 rgb = lms @ RUDERMAN_LMS_TO_RGB.T.astype(dtype, copy=False) - np.clip(rgb, 0.0, 255.0, out=rgb) + np.clip(rgb, 0.0, max_value, out=rgb) return rgb.astype(dtype, copy=False) @@ -132,17 +142,22 @@ def rgb_to_sda( def sda_to_rgb( sda: xr.DataArray, white_point: np.ndarray, + *, + out_dtype: np.dtype | type = np.uint8, ) -> xr.DataArray: - """Convert SDA back to RGB intensities in ``[0, 255]``. + """Convert SDA back to RGB, clipped to ``out_dtype``'s valid range. - Inverse of :func:`rgb_to_sda`. Pass the same ``white_point`` - used at encode time. 
The result is clipped to ``[0, 255]`` but kept in - float dtype; uint8 conversion is the caller's choice. + Inverse of :func:`rgb_to_sda`. Pass the same ``white_point`` used at encode + time. ``out_dtype`` is the eventual image dtype: the reconstruction is + clipped to that dtype's valid range (``dtype_max`` = 255 / 65535 / 1.0) but + kept in float; the final cast to ``out_dtype`` is the caller's choice. """ _check_channel_dim(sda) dtype = _working_dtype(sda) bg = np.asarray(white_point, dtype=dtype) - return _apply_along_channel(sda, _sda_to_rgb_kernel, out_dtype=dtype, bg=bg, dtype=dtype) + return _apply_along_channel( + sda, _sda_to_rgb_kernel, out_dtype=dtype, bg=bg, max_value=dtype_max(out_dtype), dtype=dtype + ) def rgb_to_lab_ruderman(rgb: xr.DataArray) -> xr.DataArray: @@ -161,12 +176,12 @@ def rgb_to_lab_ruderman(rgb: xr.DataArray) -> xr.DataArray: return _apply_along_channel(rgb, _rgb_to_lab_kernel, out_dtype=dtype, dtype=dtype) -def lab_ruderman_to_rgb(lab: xr.DataArray) -> xr.DataArray: +def lab_ruderman_to_rgb(lab: xr.DataArray, *, out_dtype: np.dtype | type = np.uint8) -> xr.DataArray: """Inverse of :func:`rgb_to_lab_ruderman`. - Returns RGB clipped to ``[0, 255]`` in float dtype; uint8 conversion is - the caller's choice. + Clips the reconstruction to ``out_dtype``'s valid range (the eventual image + dtype) but keeps it in float; the final cast is the caller's choice. 
""" _check_channel_dim(lab) dtype = _working_dtype(lab) - return _apply_along_channel(lab, _lab_to_rgb_kernel, out_dtype=dtype, dtype=dtype) + return _apply_along_channel(lab, _lab_to_rgb_kernel, out_dtype=dtype, max_value=dtype_max(out_dtype), dtype=dtype) diff --git a/src/squidpy/experimental/im/_stain/_decomposition.py b/src/squidpy/experimental/im/_stain/_decomposition.py index a17b53e73..c9e1a6edc 100644 --- a/src/squidpy/experimental/im/_stain/_decomposition.py +++ b/src/squidpy/experimental/im/_stain/_decomposition.py @@ -228,6 +228,7 @@ def apply_decomposition( *, fit_rgb: xr.DataArray | None = None, tissue_mask: np.ndarray | None = None, + out_dtype: np.dtype | type = np.uint8, ) -> xr.DataArray: """Normalize a source image to a decomposition reference. @@ -259,7 +260,7 @@ def apply_decomposition( sda = rgb_to_sda(image_rgb, bg) dtype = _working_dtype(sda) sda_out = _apply_along_channel(sda, _matmul_kernel, out_dtype=dtype, matrix=operator.astype(dtype), dtype=dtype) - return sda_to_rgb(sda_out, bg) + return sda_to_rgb(sda_out, bg, out_dtype=out_dtype) def decompose_to_concentrations( diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py index 695bb554d..272e06242 100644 --- a/src/squidpy/experimental/im/_stain/_normalize.py +++ b/src/squidpy/experimental/im/_stain/_normalize.py @@ -39,7 +39,7 @@ apply_reinhard, fit_reinhard, ) -from squidpy.experimental.im._stain._white_point import DEFAULT_WHITE_POINT +from squidpy.experimental.im._stain._white_point import default_white_point, white_point_from_background from squidpy.experimental.im._utils import ( _choose_label_scale_for_image, get_element_data, @@ -157,6 +157,40 @@ def _write_image( ) +def estimate_white_point( + sdata: sd.SpatialData, + image_key: str, + *, + tissue_mask_key: str | None = None, + scale: str | Literal["auto"] = "auto", +) -> np.ndarray: + """Estimate the white point ``I_0`` from a slide's background (non-tissue median). 
+ + Opt-in alternative to the fixed dtype-aware default white point, for a slide + whose unstained background is genuinely not full white. Samples the + per-channel median over **non-tissue** pixels (background = the complement of + the :func:`~squidpy.experimental.im.detect_tissue` mask). + + Parameters + ---------- + sdata, image_key + The SpatialData object and the RGB image key. + tissue_mask_key + Tissue-label element key (defaults to ``f"{image_key}_tissue"``); a + tissue mask is required, as for :func:`fit_stain_reference`. + scale + Scale level to sample on. ``"auto"`` (default) uses the coarsest level. + + Returns + ------- + Shape-``(3,)`` white point; pass it as ``white_point`` to + :func:`fit_stain_reference` / :func:`decompose_stains`. + """ + da = _resolve_image(sdata, image_key, scale, prefer="coarsest") + tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key) + return white_point_from_background(da, ~tissue_mask) + + def fit_stain_reference( sdata: sd.SpatialData, image_key: str, @@ -209,7 +243,7 @@ def fit_stain_reference( tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key) if method == "reinhard": return fit_reinhard(da, params, tissue_mask=tissue_mask) - bg = DEFAULT_WHITE_POINT.copy() if white_point is None else np.asarray(white_point, np.float64) + bg = default_white_point(da) if white_point is None else np.asarray(white_point, np.float64) return fit_decomposition(da, method, params, bg, tissue_mask=tissue_mask, image_key=image_key) @@ -270,10 +304,13 @@ def normalize_stains( # then applied to the full-resolution `da`. 
fit_rgb = _resolve_image(sdata, image_key, scale, prefer="coarsest") tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, fit_rgb, tissue_mask_key) + out_dtype = da.dtype # reconstruct into the source image's dtype (clip + cast happen together) if reference.method == "reinhard": - normalized = apply_reinhard(da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask) + normalized = apply_reinhard(da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask, out_dtype=out_dtype) else: - normalized = apply_decomposition(da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask) + normalized = apply_decomposition( + da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask, out_dtype=out_dtype + ) if preserve_background: # Keep non-tissue pixels byte-identical to the source: the global colour diff --git a/src/squidpy/experimental/im/_stain/_reinhard.py b/src/squidpy/experimental/im/_stain/_reinhard.py index cae89faed..5c0021c5d 100644 --- a/src/squidpy/experimental/im/_stain/_reinhard.py +++ b/src/squidpy/experimental/im/_stain/_reinhard.py @@ -146,6 +146,7 @@ def apply_reinhard( *, fit_rgb: xr.DataArray | None = None, tissue_mask: np.ndarray | None = None, + out_dtype: np.dtype | type = np.uint8, ) -> xr.DataArray: """Apply a Reinhard reference to a source image. @@ -175,4 +176,4 @@ def apply_reinhard( sigma_ref=np.asarray(reference.sigma, dtype=dtype), dtype=dtype, ) - return lab_ruderman_to_rgb(lab_out) + return lab_ruderman_to_rgb(lab_out, out_dtype=out_dtype) diff --git a/src/squidpy/experimental/im/_stain/_white_point.py b/src/squidpy/experimental/im/_stain/_white_point.py index d4d97da71..142e5d3e0 100644 --- a/src/squidpy/experimental/im/_stain/_white_point.py +++ b/src/squidpy/experimental/im/_stain/_white_point.py @@ -1,8 +1,13 @@ -"""Background (white-point) intensity estimation for absorbance methods. +"""White-point (``I_0``) handling for the absorbance methods. 
-The decomposition methods convert RGB to absorbance against a per-channel -white point ``I_0``. Rather than assume pure white (255), estimate it from the -brightest pixels of the slide, which are the unstained background. +The decomposition methods measure absorbance against a per-channel white point +``I_0`` (the intensity that counts as fully unstained). The default is a fixed +full-white reference at the image's bit depth (255 for uint8, 65535 for uint16, +1.0 for float), matching HistomicsTK (255/256) and the Macenko literature (240): +``I_0`` must be at least as bright as the slide background, otherwise unstained +pixels get a non-zero absorbance and cannot round-trip back to white. Use +:func:`estimate_white_point` only for a slide with a genuinely non-white +background you want to anchor to. """ from __future__ import annotations @@ -10,58 +15,59 @@ import numpy as np import xarray as xr -from squidpy.experimental.im._stain._conversion import _check_channel_dim +from squidpy.experimental.im._stain._conversion import _check_channel_dim, dtype_max from squidpy.experimental.im._stain._validation import StainFittingError -#: Default per-channel white point ``I_0`` for the absorbance methods. A fixed -#: full-white reference (8-bit), matching HistomicsTK (255/256) and the Macenko -#: literature (240). The absorbance origin must be at least as bright as the -#: slide background, otherwise unstained pixels get a non-zero absorbance and -#: cannot round-trip back to white. Estimate from the image (see -#: ``estimate_white_point``) only when the slide has a genuinely -#: non-white background you want to anchor to. -DEFAULT_WHITE_POINT: np.ndarray = np.array([255.0, 255.0, 255.0]) +def default_white_point(rgb: xr.DataArray) -> np.ndarray: + """Dtype-aware default white point ``I_0`` (full white), with a range check. 
-def estimate_white_point(rgb: xr.DataArray, *, percentile: float = 99.0) -> np.ndarray: - """Estimate the per-channel white point from the brightest pixels. + Returns ``(3,)`` filled with the dtype's full-white value. Raises with + guidance when the data clearly does not match its dtype's range (e.g. 8-bit + values stored in a uint16 container, or 0-255 values stored as float), since + that would silently mis-scale the absorbance. + """ + m = dtype_max(rgb.dtype) + data_max = float(np.asarray(rgb.max())) + if np.issubdtype(rgb.dtype, np.integer): + if m >= 256 and data_max <= 255: + raise ValueError( + f"{rgb.dtype} image but the maximum value is {data_max:.0f} (<= 255) - this looks like " + f"8-bit data stored in a {rgb.dtype} container. Convert to uint8, or pass `white_point`." + ) + elif data_max > 1.5: + raise ValueError( + f"float image but the maximum value is {data_max:.1f} (> 1) - this looks like 0-255 data " + "stored as float. Rescale to [0, 1], or pass `white_point`." + ) + return np.full(3, m, dtype=np.float64) - Parameters - ---------- - rgb - Image with a ``"c"`` dimension of length 3. Numpy- or dask-backed. - percentile - Per-channel intensity percentile to take as the white point. The - default (99) picks near-saturated background while ignoring the few - truly-saturated outlier pixels. - Returns - ------- - Shape-``(3,)`` float64 white point, suitable as ``white_point`` - for :func:`~squidpy.experimental.im._stain._conversion.rgb_to_sda`. +def white_point_from_background(rgb: xr.DataArray, background_mask: np.ndarray) -> np.ndarray: + """Per-channel median intensity over background pixels -> ``(3,)`` white point. - Notes - ----- - The exact percentile is computed eagerly (the input is materialised), so - the result is identical for numpy- and dask-backed inputs and independent - of chunking - important for reproducible references across a cohort. Pass - a coarse pyramid level for whole-slide images. 
+ ``background_mask`` is a ``(y, x)`` boolean, ``True`` over non-tissue + (background) pixels. Sampling the *median* of true background (rather than a + whole-image percentile) anchors ``I_0`` to the actual unstained intensity, + matching HistomicsTK's ``background_intensity`` semantics. Raises ------ StainFittingError - If the estimate is not strictly positive in every channel (e.g. a - blank/black image with no bright background). + If the mask selects no background pixels, or the median is non-positive + (e.g. a black background). """ - if not 0.0 < percentile <= 100.0: - raise ValueError(f"`percentile` must be in (0, 100], got {percentile}.") _check_channel_dim(rgb) - flat = np.asarray(rgb.transpose("c", "y", "x").data, dtype=np.float64).reshape(3, -1) - bg = np.percentile(flat, percentile, axis=1) - - if np.any(bg <= 0): + flat = np.asarray(rgb.transpose("c", "y", "x").data, dtype=np.float64) # (3, y, x) + bg_pixels = flat[:, np.asarray(background_mask, dtype=bool)] # (3, N_background) + if bg_pixels.shape[1] == 0: + raise StainFittingError( + "no background pixels to estimate the white point; the tissue mask covers the whole image. " + "Pass an explicit `white_point`." + ) + wp = np.median(bg_pixels, axis=1) + if np.any(wp <= 0): raise StainFittingError( - "estimated background intensity is non-positive; the image may be blank or all-tissue. " - "Pass an explicit `white_point` if this is expected." + "estimated white point is non-positive; the background may be black. Pass an explicit `white_point`." 
) - return bg + return wp diff --git a/tests/experimental/test_stain_decompose_public.py b/tests/experimental/test_stain_decompose_public.py index ab0431a70..92b345701 100644 --- a/tests/experimental/test_stain_decompose_public.py +++ b/tests/experimental/test_stain_decompose_public.py @@ -11,7 +11,6 @@ from squidpy.experimental.im import ( StainReference, decompose_stains, - estimate_white_point, fit_stain_reference, normalize_stains, ) @@ -34,7 +33,7 @@ def _synthetic_rgb(seed: int = 0, n_side: int = 48, white: np.ndarray = _WHITE) conc[third : 2 * third, 0] = 0.0 od = (conc @ w[:, :2].T).T.reshape(3, n_side, n_side) rgb = sda_to_rgb(xr.DataArray(od, dims=("c", "y", "x")), white) - return np.asarray(rgb.data) + return np.asarray(rgb.data).astype(np.uint8) def _make_sdata(values: np.ndarray, *, with_tissue: bool = True) -> sd.SpatialData: @@ -115,11 +114,6 @@ def test_explicit_background_is_used(self) -> None: ref = fit_stain_reference(sdata, "img", method="vahadane", white_point=I0) np.testing.assert_array_equal(ref.white_point, I0) - def test_estimate_background_public(self) -> None: - sdata = _make_sdata(_synthetic_rgb()) - bg = estimate_white_point(sdata.images["img"]) - assert bg.shape == (3,) - class TestUnknownMethod: def test_fit_unknown_method_raises(self) -> None: diff --git a/tests/experimental/test_stain_normalize.py b/tests/experimental/test_stain_normalize.py index d25d60e91..c1deb7de3 100644 --- a/tests/experimental/test_stain_normalize.py +++ b/tests/experimental/test_stain_normalize.py @@ -37,7 +37,7 @@ def _make_sdata( @pytest.fixture def rgb_values() -> np.ndarray: rng = np.random.default_rng(3) - return rng.uniform(40.0, 200.0, size=(3, 64, 64)).astype(np.float32) + return rng.uniform(40.0, 200.0, size=(3, 64, 64)).astype(np.uint8) class TestFitStainReference: diff --git a/tests/experimental/test_stain_white_point.py b/tests/experimental/test_stain_white_point.py index 0cdec8133..248059965 100644 --- 
a/tests/experimental/test_stain_white_point.py +++ b/tests/experimental/test_stain_white_point.py @@ -1,39 +1,77 @@ from __future__ import annotations -import dask.array as da import numpy as np import pytest +import spatialdata as sd import xarray as xr +from spatialdata.models import Image2DModel, Labels2DModel +from squidpy.experimental.im import estimate_white_point +from squidpy.experimental.im._stain._conversion import dtype_max from squidpy.experimental.im._stain._validation import StainFittingError -from squidpy.experimental.im._stain._white_point import estimate_white_point +from squidpy.experimental.im._stain._white_point import default_white_point -def _da(values: np.ndarray, *, chunked: bool) -> xr.DataArray: - data = da.from_array(values, chunks=(3, 8, 8)) if chunked else values - return xr.DataArray(data, dims=("c", "y", "x")) +def _rgb(values: np.ndarray) -> xr.DataArray: + return xr.DataArray(values, dims=("c", "y", "x")) -@pytest.mark.parametrize("chunked", [False, True]) -def test_recovers_white_point(chunked: bool) -> None: - rng = np.random.default_rng(0) - # mostly bright background near (240, 245, 250), a darker tissue blob - values = np.empty((3, 32, 32)) - values[0] = 240.0 - values[1] = 245.0 - values[2] = 250.0 - values[:, :8, :8] = rng.uniform(20.0, 60.0, size=(3, 8, 8)) # tissue - bg = estimate_white_point(_da(values, chunked=chunked)) - assert bg.shape == (3,) - np.testing.assert_allclose(bg, [240.0, 245.0, 250.0], atol=1.0) +class TestDtypeMax: + def test_known_dtypes(self) -> None: + assert dtype_max(np.uint8) == 255.0 + assert dtype_max(np.uint16) == 65535.0 + assert dtype_max(np.float32) == 1.0 -def test_blank_image_raises() -> None: - black = np.zeros((3, 16, 16)) - with pytest.raises(StainFittingError, match="non-positive"): - estimate_white_point(_da(black, chunked=False)) +class TestDefaultWhitePoint: + def test_uint8(self) -> None: + rgb = _rgb(np.full((3, 8, 8), 200, dtype=np.uint8)) + 
np.testing.assert_array_equal(default_white_point(rgb), [255.0, 255.0, 255.0]) + def test_uint16(self) -> None: + rgb = _rgb(np.full((3, 8, 8), 5000, dtype=np.uint16)) + np.testing.assert_array_equal(default_white_point(rgb), [65535.0] * 3) -def test_bad_percentile_raises() -> None: - with pytest.raises(ValueError, match="percentile"): - estimate_white_point(_da(np.ones((3, 8, 8)), chunked=False), percentile=0.0) + def test_float_unit_range(self) -> None: + rgb = _rgb(np.full((3, 8, 8), 0.8, dtype=np.float32)) + np.testing.assert_array_equal(default_white_point(rgb), [1.0, 1.0, 1.0]) + + def test_raises_on_8bit_in_uint16(self) -> None: + rgb = _rgb(np.full((3, 8, 8), 200, dtype=np.uint16)) # uint16 container, 8-bit values + with pytest.raises(ValueError, match="8-bit data stored in"): + default_white_point(rgb) + + def test_raises_on_0_255_float(self) -> None: + rgb = _rgb(np.full((3, 8, 8), 200.0, dtype=np.float32)) # float, but 0-255 valued + with pytest.raises(ValueError, match="stored as float"): + default_white_point(rgb) + + +class TestEstimateWhitePoint: + def _sdata(self, *, all_tissue: bool = False) -> sd.SpatialData: + rng = np.random.default_rng(0) + values = np.empty((3, 32, 32), dtype=np.uint8) + values[0], values[1], values[2] = 240, 245, 250 # background + values[:, :8, :8] = rng.integers(20, 60, size=(3, 8, 8)) # a darker tissue blob + mask = np.zeros((32, 32), dtype=np.uint32) + mask[:8, :8] = 1 # tissue = the blob + if all_tissue: + mask[:] = 1 + sdata = sd.SpatialData(images={"img": Image2DModel.parse(values, dims=("c", "y", "x"))}) + sdata.labels["img_tissue"] = Labels2DModel.parse(mask, dims=("y", "x")) + return sdata + + def test_recovers_background_median(self) -> None: + wp = estimate_white_point(self._sdata(), "img") + assert wp.shape == (3,) + np.testing.assert_allclose(wp, [240.0, 245.0, 250.0], atol=1.0) + + def test_raises_when_tissue_covers_all(self) -> None: + with pytest.raises(StainFittingError, match="covers the whole image"): + 
estimate_white_point(self._sdata(all_tissue=True), "img") + + def test_requires_a_tissue_mask(self) -> None: + values = np.full((3, 16, 16), 240, dtype=np.uint8) + sdata = sd.SpatialData(images={"img": Image2DModel.parse(values, dims=("c", "y", "x"))}) + with pytest.raises(KeyError, match="detect_tissue"): + estimate_white_point(sdata, "img") From 1555179e536b32ea893277b66904a3d92bd53a2a Mon Sep 17 00:00:00 2001 From: anon Date: Tue, 2 Jun 2026 03:58:46 +0200 Subject: [PATCH 4/8] Stain normalization: address review of step 2 - FIX (correctness): the bit-depth range check now runs on the APPLY/estimate paths too, not only fit. A float image holding 0-255 values previously slipped through normalize_stains and clipped its reconstruction to [0,1] (dtype_max(float)=1.0), silently destroying the output. Extracted the check into validate_rgb_range() and call it from fit / normalize_stains / estimate_white_point. + regression test. - default_white_point() is now a pure defaulter (no max() reduction, no raise) - validation lives in validate_rgb_range(), separating the two concerns. - Extracted _resolve_mask_key_and_scale() shared by the two tissue-mask consumers (dedup). - Documented that estimate_white_point materialises its level (keep it coarse). 128 stain tests pass. 
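Illustration of the failure mode fixed above, as a minimal numpy-only sketch rather than the module code: the `*_sketch` names are hypothetical stand-ins that mirror `dtype_max` and `validate_rgb_range` from this series but take a plain ndarray instead of an xarray DataArray. It shows why a float image holding 0-255 values has to be rejected before the dtype-clipped reconstruction.

```python
import numpy as np


def dtype_max_sketch(dtype) -> float:
    """Full-white value for a dtype: iinfo max for integers, 1.0 for floats."""
    dt = np.dtype(dtype)
    return float(np.iinfo(dt).max) if np.issubdtype(dt, np.integer) else 1.0


def validate_rgb_range_sketch(values: np.ndarray) -> None:
    """Raise ValueError when the values clearly do not match the dtype's range."""
    data_max = float(values.max())
    if np.issubdtype(values.dtype, np.integer):
        # Wide integer container whose values never exceed the 8-bit range.
        if dtype_max_sketch(values.dtype) >= 256 and data_max <= 255:
            raise ValueError(f"8-bit data stored in a {values.dtype} container")
    elif data_max > 1.5:
        # Float image whose values exceed the unit range.
        raise ValueError("0-255 data stored as float")


# A float32 image that actually holds 0-255 values.
floaty = np.array([40.0, 120.0, 200.0], dtype=np.float32)

# Without the check, clipping to the float range [0, 1] saturates every pixel.
clipped = np.clip(floaty, 0.0, dtype_max_sketch(floaty.dtype))

# With the check, the mis-typed image is rejected before any clipping happens.
try:
    validate_rgb_range_sketch(floaty)
    rejected = False
except ValueError:
    rejected = True
```

Keeping the defaulter and the validator separate, as this commit does, lets the apply and estimate paths run the same check without computing a default white point.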
Co-Authored-By: Claude Opus 4.8 --- .../experimental/im/_stain/_normalize.py | 44 +++++++++++++------ .../experimental/im/_stain/_white_point.py | 21 ++++++--- tests/experimental/test_stain_normalize.py | 17 ++++++- tests/experimental/test_stain_white_point.py | 16 ++++--- 4 files changed, 71 insertions(+), 27 deletions(-) diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py index 272e06242..0dde72737 100644 --- a/src/squidpy/experimental/im/_stain/_normalize.py +++ b/src/squidpy/experimental/im/_stain/_normalize.py @@ -39,7 +39,11 @@ apply_reinhard, fit_reinhard, ) -from squidpy.experimental.im._stain._white_point import default_white_point, white_point_from_background +from squidpy.experimental.im._stain._white_point import ( + default_white_point, + validate_rgb_range, + white_point_from_background, +) from squidpy.experimental.im._utils import ( _choose_label_scale_for_image, get_element_data, @@ -70,20 +74,31 @@ def _resolve_image( return da +def _resolve_mask_key_and_scale( + sdata: sd.SpatialData, image_key: str, target_da: xr.DataArray, tissue_mask_key: str | None +) -> tuple[str, str, tuple[int, int]]: + """Resolve the (mandatory) tissue-mask key and the label scale closest to ``target_da``. + + Shared by the two mask consumers below. Consumes a + :func:`~squidpy.experimental.im.detect_tissue` labels element - raises if + none exists. + """ + mask_key = resolve_tissue_mask(sdata, image_key, "auto", tissue_mask_key, auto_create=False) + target_hw = (int(target_da.sizes["y"]), int(target_da.sizes["x"])) + label_scale = _choose_label_scale_for_image(sdata.labels[mask_key], target_hw) + return mask_key, label_scale, target_hw + + def _resolve_tissue_bool_mask( sdata: sd.SpatialData, image_key: str, fit_da: xr.DataArray, tissue_mask_key: str | None ) -> np.ndarray: - """Return a ``(y, x)`` boolean tissue mask aligned to ``fit_da``. 
+ """Return a materialised ``(y, x)`` boolean tissue mask aligned to ``fit_da``. - Consumes a :func:`~squidpy.experimental.im.detect_tissue` labels element - (mandatory - raises if none exists), picks the label scale closest to - ``fit_da``, materialises it, and nearest-resizes to ``fit_da``'s ``(y, x)`` - when the resolutions differ. The stain fits run on a coarse level, so the - mask stays small. + For the (coarse) fit: nearest-resizes to ``fit_da``'s ``(y, x)`` when the + closest label scale differs. The fits run on a coarse level, so the mask + stays small. """ - mask_key = resolve_tissue_mask(sdata, image_key, "auto", tissue_mask_key, auto_create=False) - target_hw = (int(fit_da.sizes["y"]), int(fit_da.sizes["x"])) - label_scale = _choose_label_scale_for_image(sdata.labels[mask_key], target_hw) + mask_key, label_scale, target_hw = _resolve_mask_key_and_scale(sdata, image_key, fit_da, tissue_mask_key) mask = get_mask_materialized(sdata, mask_key, label_scale) > 0 if mask.shape != target_hw: from skimage.transform import resize @@ -103,9 +118,7 @@ def _resolve_output_tissue_mask( shares the image's scale factors, so the matching level usually lines up exactly; only a residual size mismatch forces a (small) eager resize. """ - mask_key = resolve_tissue_mask(sdata, image_key, "auto", tissue_mask_key, auto_create=False) - target_hw = (int(target_da.sizes["y"]), int(target_da.sizes["x"])) - label_scale = _choose_label_scale_for_image(sdata.labels[mask_key], target_hw) + mask_key, label_scale, target_hw = _resolve_mask_key_and_scale(sdata, image_key, target_da, tissue_mask_key) coords = {d: target_da.coords[d] for d in ("y", "x") if d in target_da.coords} mask = get_element_data(sdata.labels[mask_key], label_scale, "label", mask_key).squeeze() > 0 if (int(mask.sizes["y"]), int(mask.sizes["x"])) == target_hw: @@ -180,6 +193,8 @@ def estimate_white_point( tissue mask is required, as for :func:`fit_stain_reference`. scale Scale level to sample on. 
``"auto"`` (default) uses the coarsest level. + The sampled level is materialised to take the median, so keep this + coarse - do not pass a fine level on a whole-slide image. Returns ------- @@ -187,6 +202,7 @@ def estimate_white_point( :func:`fit_stain_reference` / :func:`decompose_stains`. """ da = _resolve_image(sdata, image_key, scale, prefer="coarsest") + validate_rgb_range(da) tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key) return white_point_from_background(da, ~tissue_mask) @@ -239,6 +255,7 @@ def fit_stain_reference( if method not in _VALID_METHODS: raise ValueError(f"Unknown method {method!r}; expected one of {list(_VALID_METHODS)}.") da = _resolve_image(sdata, image_key, scale, prefer="coarsest") + validate_rgb_range(da) params = _resolve_method_params(method, method_params) tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key) if method == "reinhard": @@ -303,6 +320,7 @@ def normalize_stains( # are reduced on a coarse level with a tissue mask; the lazy transform is # then applied to the full-resolution `da`. fit_rgb = _resolve_image(sdata, image_key, scale, prefer="coarsest") + validate_rgb_range(fit_rgb) # reject mis-typed source (e.g. 0-255 float) before the dtype-clipped reconstruction tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, fit_rgb, tissue_mask_key) out_dtype = da.dtype # reconstruct into the source image's dtype (clip + cast happen together) if reference.method == "reinhard": diff --git a/src/squidpy/experimental/im/_stain/_white_point.py b/src/squidpy/experimental/im/_stain/_white_point.py index 142e5d3e0..5fbfd031f 100644 --- a/src/squidpy/experimental/im/_stain/_white_point.py +++ b/src/squidpy/experimental/im/_stain/_white_point.py @@ -20,12 +20,20 @@ def default_white_point(rgb: xr.DataArray) -> np.ndarray: - """Dtype-aware default white point ``I_0`` (full white), with a range check. + """Dtype-aware default white point ``I_0`` (full white) - pure, no validation. 
- Returns ``(3,)`` filled with the dtype's full-white value. Raises with - guidance when the data clearly does not match its dtype's range (e.g. 8-bit - values stored in a uint16 container, or 0-255 values stored as float), since - that would silently mis-scale the absorbance. + Returns ``(3,)`` filled with the dtype's full-white value (255 / 65535 / 1.0). + Call :func:`validate_rgb_range` separately to reject mis-typed data. + """ + return np.full(3, dtype_max(rgb.dtype), dtype=np.float64) + + +def validate_rgb_range(rgb: xr.DataArray) -> None: + """Raise if the image's values clearly don't match its dtype's range. + + Guards every absorbance entry point (fit / apply / estimate) against silently + mis-scaling or clipping: 8-bit values in a wider integer container, or 0-255 + values stored as float. The escape is to pass an explicit ``white_point``. """ m = dtype_max(rgb.dtype) data_max = float(np.asarray(rgb.max())) @@ -38,9 +46,8 @@ def default_white_point(rgb: xr.DataArray) -> np.ndarray: elif data_max > 1.5: raise ValueError( f"float image but the maximum value is {data_max:.1f} (> 1) - this looks like 0-255 data " - "stored as float. Rescale to [0, 1], or pass `white_point`." + "stored as float. Rescale to [0, 1] or store as uint8." 
) - return np.full(3, m, dtype=np.float64) def white_point_from_background(rgb: xr.DataArray, background_mask: np.ndarray) -> np.ndarray: diff --git a/tests/experimental/test_stain_normalize.py b/tests/experimental/test_stain_normalize.py index c1deb7de3..f6d19bd87 100644 --- a/tests/experimental/test_stain_normalize.py +++ b/tests/experimental/test_stain_normalize.py @@ -143,6 +143,19 @@ def test_explicit_missing_key_raises(self, rgb_values: np.ndarray) -> None: with pytest.raises(KeyError, match="not found in sdata.labels"): fit_stain_reference(sdata, "img", tissue_mask_key="nope") + def test_float_0_255_source_rejected_on_apply(self, rgb_values: np.ndarray) -> None: + # A float image holding 0-255 values would otherwise clip to [0, 1] in the + # reconstruction (dtype_max(float)=1.0); apply must reject it, not silently destroy it. + sdata = _make_sdata(rgb_values) # uint8 + ref = fit_stain_reference(sdata, "img") + floaty = rgb_values.astype(np.float32) + sdata.images["floaty"] = Image2DModel.parse(floaty, dims=("c", "y", "x")) + sdata.labels["floaty_tissue"] = Labels2DModel.parse( + np.ones(floaty.shape[-2:], dtype=np.uint32), dims=("y", "x") + ) + with pytest.raises(ValueError, match="stored as float"): + normalize_stains(sdata, "floaty", ref) + def test_mask_is_used_in_the_fit(self, rgb_values: np.ndarray) -> None: # A different tissue region yields different channel statistics, proving # the mask actually drives the fit (not silently ignored). 
@@ -168,7 +181,7 @@ def test_background_passthrough_vs_full_frame(self, rgb_values: np.ndarray) -> N sdata.labels["img_tissue"] = Labels2DModel.parse(partial, dims=("y", "x")) # a differently-coloured reference so the transform is non-trivial - shifted = np.clip(rgb_values * np.array([1.3, 0.8, 1.1])[:, None, None], 0, 255).astype(np.float32) + shifted = np.clip(rgb_values * np.array([1.3, 0.8, 1.1])[:, None, None], 0, 255).astype(np.uint8) sdata.images["ref_img"] = Image2DModel.parse(shifted, dims=("c", "y", "x")) sdata.labels["ref_img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) ref = fit_stain_reference(sdata, "ref_img") @@ -204,7 +217,7 @@ def test_plot_reinhard_before_after(self, sdata_hne) -> None: # staining batch, so the before/after panels are visibly distinct. da_rgb = get_element_data(sdata_hne.images[image_key], "auto", "image", image_key).astype("float32") weights = xr.DataArray([1.4, 1.0, 0.6], dims="c", coords={"c": da_rgb.coords["c"]}) - shifted = (da_rgb * weights).clip(0, 255) + shifted = (da_rgb * weights).clip(0, 255).astype("uint8") sdata_hne.images["hne_shifted"] = Image2DModel.parse(shifted.data, dims=shifted.dims) # `hne_shifted` shares geometry with `image_key`; reuse its tissue mask. 
diff --git a/tests/experimental/test_stain_white_point.py b/tests/experimental/test_stain_white_point.py index 248059965..da5fb3595 100644 --- a/tests/experimental/test_stain_white_point.py +++ b/tests/experimental/test_stain_white_point.py @@ -9,7 +9,7 @@ from squidpy.experimental.im import estimate_white_point from squidpy.experimental.im._stain._conversion import dtype_max from squidpy.experimental.im._stain._validation import StainFittingError -from squidpy.experimental.im._stain._white_point import default_white_point +from squidpy.experimental.im._stain._white_point import default_white_point, validate_rgb_range def _rgb(values: np.ndarray) -> xr.DataArray: @@ -36,15 +36,21 @@ def test_float_unit_range(self) -> None: rgb = _rgb(np.full((3, 8, 8), 0.8, dtype=np.float32)) np.testing.assert_array_equal(default_white_point(rgb), [1.0, 1.0, 1.0]) + +class TestValidateRgbRange: + def test_passes_on_uint8(self) -> None: + validate_rgb_range(_rgb(np.full((3, 8, 8), 200, dtype=np.uint8))) # no raise + + def test_passes_on_float_unit_range(self) -> None: + validate_rgb_range(_rgb(np.full((3, 8, 8), 0.8, dtype=np.float32))) # no raise + def test_raises_on_8bit_in_uint16(self) -> None: - rgb = _rgb(np.full((3, 8, 8), 200, dtype=np.uint16)) # uint16 container, 8-bit values with pytest.raises(ValueError, match="8-bit data stored in"): - default_white_point(rgb) + validate_rgb_range(_rgb(np.full((3, 8, 8), 200, dtype=np.uint16))) def test_raises_on_0_255_float(self) -> None: - rgb = _rgb(np.full((3, 8, 8), 200.0, dtype=np.float32)) # float, but 0-255 valued with pytest.raises(ValueError, match="stored as float"): - default_white_point(rgb) + validate_rgb_range(_rgb(np.full((3, 8, 8), 200.0, dtype=np.float32))) class TestEstimateWhitePoint: From 2ae8dfc1dde7529903d436fe1ab3829d2a4d1d88 Mon Sep 17 00:00:00 2001 From: anon Date: Tue, 2 Jun 2026 04:22:37 +0200 Subject: [PATCH 5/8] Stain normalization: inplace writes + output_dtype + atomic decompose (Step 3) Align the stain 
entry points with the detect_tissue inplace+key idiom and finish the bit-depth cast that Step 2 deferred. normalize_stains: - inplace=True (default) writes the result to sdata.images[image_key_added], image_key_added defaulting to f"{image_key}_normalized"; inplace=False returns the lazy DataArray and leaves sdata untouched. - output_dtype (default = source dtype) is the clip range and the final cast. - cast-at-boundary: the reconstruction stayed in float (clipped to range); it is now rounded (integer dtypes) and cast at the write boundary, so the stored image is the requested dtype and integer background is byte-identical. decompose_stains: - inplace=True (default) writes each stain as a single-channel image under the image_key_added prefix (default = image_key); the write is atomic - all target keys are validated free before any is written. - output_dtype (default float16; float32 for strict quantification). _conversion.cast_to_image_dtype performs the deferred rounding+cast, kept lazy. Tests updated to the inplace=True default; added coverage for the derived-key defaults, output_dtype overrides, and the atomic-abort path. 
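Illustration of the two behaviours described above, as a minimal sketch under stated assumptions: `cast_to_image_dtype_sketch` mirrors `cast_to_image_dtype` on a plain ndarray instead of a lazy DataArray, and `write_all_or_nothing` is a hypothetical helper on a plain dict standing in for `sdata.images`. The `"<prefix>_<stain>"` key layout is assumed for the example only.

```python
import numpy as np


def cast_to_image_dtype_sketch(arr: np.ndarray, out_dtype) -> np.ndarray:
    """Round for integer targets, then cast; float targets cast directly."""
    dt = np.dtype(out_dtype)
    return np.round(arr).astype(dt) if np.issubdtype(dt, np.integer) else arr.astype(dt)


def write_all_or_nothing(store: dict, new_items: dict) -> None:
    """Validate that every target key is free before writing any of them."""
    taken = [key for key in new_items if key in store]
    if taken:
        raise ValueError(f"keys already exist: {taken}")
    store.update(new_items)


# Working floats, already clipped to the uint8 range [0, 255].
clipped = np.array([0.4, 254.6, 255.0])
rounded = cast_to_image_dtype_sketch(clipped, np.uint8)
truncated = clipped.astype(np.uint8)  # a bare cast drops the fraction instead

# Hypothetical key layout; one target key is already taken, so nothing is written.
store = {"img_eosin": "existing"}
channels = {"img_hematoxylin": "H", "img_eosin": "E", "img_residual": "R"}
try:
    write_all_or_nothing(store, channels)
except ValueError:
    pass
```

Checking all keys before the first write means a failed call never leaves a partial set of stain channels behind.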
Co-Authored-By: Claude Opus 4.8 --- .../experimental/im/_stain/_conversion.py | 12 +++ .../experimental/im/_stain/_normalize.py | 89 +++++++++++++------ .../test_stain_decompose_public.py | 35 ++++++-- tests/experimental/test_stain_normalize.py | 27 ++++-- 4 files changed, 128 insertions(+), 35 deletions(-) diff --git a/src/squidpy/experimental/im/_stain/_conversion.py b/src/squidpy/experimental/im/_stain/_conversion.py index 98bca5e5f..ef16c5b88 100644 --- a/src/squidpy/experimental/im/_stain/_conversion.py +++ b/src/squidpy/experimental/im/_stain/_conversion.py @@ -34,6 +34,18 @@ def dtype_max(dtype: np.dtype | type) -> float: return float(np.iinfo(dt).max) if np.issubdtype(dt, np.integer) else 1.0 +def cast_to_image_dtype(arr: xr.DataArray, out_dtype: np.dtype | type) -> xr.DataArray: + """Cast a clipped working-float image to its final dtype at the write boundary. + + The reconstruction kernels (:func:`sda_to_rgb`, :func:`lab_ruderman_to_rgb`) + clip to ``out_dtype``'s valid range but stay in float; this performs the + deferred cast. Integer targets are **rounded** (so ``254.6 -> 255``, not + ``254``); float targets cast directly. Stays lazy on dask-backed input. 
+ """ + dt = np.dtype(out_dtype) + return arr.round().astype(dt) if np.issubdtype(dt, np.integer) else arr.astype(dt) + + def _check_channel_dim(arr: xr.DataArray) -> None: if _CHANNEL_DIM not in arr.dims: raise ValueError(f"Input must have a dimension named {_CHANNEL_DIM!r}; got dims {arr.dims}.") diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py index 0dde72737..c8f4bfdc2 100644 --- a/src/squidpy/experimental/im/_stain/_normalize.py +++ b/src/squidpy/experimental/im/_stain/_normalize.py @@ -18,11 +18,12 @@ import numpy as np import spatialdata as sd import xarray as xr +from numpy.typing import DTypeLike from spatialdata.models import Image2DModel from spatialdata.transformations import get_transformation from squidpy._utils import _get_scale_factors -from squidpy.experimental.im._stain._conversion import _check_channel_dim +from squidpy.experimental.im._stain._conversion import _check_channel_dim, cast_to_image_dtype from squidpy.experimental.im._stain._decomposition import ( MacenkoParams, VahadaneParams, @@ -272,6 +273,8 @@ def normalize_stains( scale: str | Literal["auto"] = "auto", method_params: MethodParams = None, image_key_added: str | None = None, + inplace: bool = True, + output_dtype: DTypeLike | None = None, tissue_mask_key: str | None = None, preserve_background: bool = True, ) -> xr.DataArray | None: @@ -293,11 +296,18 @@ def normalize_stains( method_params Params matching ``reference.method`` (instance, mapping, or ``None``). image_key_added - If ``None`` (default), return the lazy normalized DataArray and leave - ``sdata`` untouched. If given, write the result to - ``sdata.images[image_key_added]`` (rebuilding the pyramid for - multiscale sources, preserving transforms) and return ``None``. - Raises if the key already exists. + Key for the written image when ``inplace=True``. If ``None`` (default), + ``f"{image_key}_normalized"`` is used. Ignored when ``inplace=False``. 
+ inplace + If ``True`` (default), write the normalized image to + ``sdata.images[image_key_added]`` (rebuilding the pyramid for multiscale + sources, preserving transforms) and return ``None``; raises if the key + already exists. If ``False``, leave ``sdata`` untouched and return the + lazy normalized :class:`~xarray.DataArray`. + output_dtype + Dtype of the result. If ``None`` (default), the source image's dtype is + used. The reconstruction is clipped to that dtype's valid range and + rounded (for integer dtypes) at the write boundary. tissue_mask_key Key of a tissue-label element in ``sdata.labels`` restricting the *source* statistics to tissue pixels. As for @@ -311,10 +321,13 @@ def normalize_stains( Returns ------- - The lazy normalized :class:`xarray.DataArray` if ``image_key_added`` is - ``None``, otherwise ``None``. + ``None`` if ``inplace=True`` (the image is written), otherwise the lazy + normalized :class:`xarray.DataArray`. """ da = _resolve_image(sdata, image_key, scale, prefer="finest") + target_key = image_key_added if image_key_added is not None else f"{image_key}_normalized" + if inplace and target_key in sdata.images: + raise ValueError(f"image_key_added={target_key!r} already exists in sdata.images.") params = _resolve_method_params(reference.method, method_params) # Source statistics (Reinhard mu/sigma or the decomposition source matrix) # are reduced on a coarse level with a tissue mask; the lazy transform is @@ -322,9 +335,11 @@ def normalize_stains( fit_rgb = _resolve_image(sdata, image_key, scale, prefer="coarsest") validate_rgb_range(fit_rgb) # reject mis-typed source (e.g. 
0-255 float) before the dtype-clipped reconstruction tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, fit_rgb, tissue_mask_key) - out_dtype = da.dtype # reconstruct into the source image's dtype (clip + cast happen together) + out_dtype = da.dtype if output_dtype is None else np.dtype(output_dtype) # clip range + final cast if reference.method == "reinhard": - normalized = apply_reinhard(da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask, out_dtype=out_dtype) + normalized = apply_reinhard( + da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask, out_dtype=out_dtype + ) else: normalized = apply_decomposition( da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask, out_dtype=out_dtype @@ -337,9 +352,14 @@ def normalize_stains( keep = _resolve_output_tissue_mask(sdata, image_key, da, tissue_mask_key) normalized = normalized.where(keep, da) - if image_key_added is None: + # Deferred cast at the write boundary: the reconstruction was kept in float + # (clipped to `out_dtype`'s range); round + cast here so the stored image is + # the requested dtype and integer background stays byte-identical. + normalized = cast_to_image_dtype(normalized, out_dtype) + + if not inplace: return normalized - _write_image(sdata, sdata.images[image_key], image_key_added, normalized) + _write_image(sdata, sdata.images[image_key], target_key, normalized) return None @@ -352,6 +372,8 @@ def decompose_stains( method_params: MethodParams = None, white_point: np.ndarray | None = None, image_key_added: str | None = None, + inplace: bool = True, + output_dtype: DTypeLike = np.float16, tissue_mask_key: str | None = None, include_residual: bool = True, ) -> dict[str, xr.DataArray] | None: @@ -370,11 +392,20 @@ def decompose_stains( As for :func:`fit_stain_reference` (only used when a method name is given; a reference is projected as-is and needs no tissue mask). image_key_added - If ``None`` (default), return the concentration maps as a dict. 
If - given, used as a key *prefix*: each stain is written as its own - single-channel image ``sdata.images[f"{image_key_added}_{stain}"]`` - (e.g. ``f"{image_key_added}_hematoxylin"``), and ``None`` is returned. - Raises if any target key already exists. + Key *prefix* for the written images when ``inplace=True``. If ``None`` + (default), ``image_key`` is used, so each stain is written as its own + single-channel image ``sdata.images[f"{image_key}_{stain}"]`` (e.g. + ``f"{image_key}_hematoxylin"``). Ignored when ``inplace=False``. + inplace + If ``True`` (default), write each stain as a separate single-channel + image under the ``image_key_added`` prefix and return ``None``; the + write is atomic (all target keys are validated free before any is + written). If ``False``, leave ``sdata`` untouched and return the maps + as a dict. + output_dtype + Dtype of the concentration maps. Defaults to ``float16`` (half the + storage; ~3 significant figures, adequate for concentrations); pass + ``float32`` for strict quantification. include_residual If ``True`` (default), also produce the ``"residual"`` map. The residual is the absorbance along the complement direction - a diagnostic of @@ -383,10 +414,10 @@ def decompose_stains( Returns ------- - If ``image_key_added`` is ``None``, a ``dict`` mapping each stain name to - its ``(y, x)`` concentration :class:`~xarray.DataArray` - (``"hematoxylin"``, ``"eosin"``, and ``"residual"`` unless dropped). - Otherwise ``None`` (the maps are written as separate images). + ``None`` if ``inplace=True`` (the maps are written as separate images), + otherwise a ``dict`` mapping each stain name to its ``(y, x)`` concentration + :class:`~xarray.DataArray` (``"hematoxylin"``, ``"eosin"``, and + ``"residual"`` unless dropped). 
""" da = _resolve_image(sdata, image_key, scale, prefer="finest") if isinstance(reference_or_method, StainReference): @@ -408,14 +439,22 @@ def decompose_stains( ) stain_matrix, bg = reference.stain_matrix, reference.white_point - concentrations = decompose_to_concentrations(da, stain_matrix, bg).assign_coords(c=_CONCENTRATION_CHANNELS) names = ["hematoxylin", "eosin"] + (["residual"] if include_residual else []) + prefix = image_key_added if image_key_added is not None else image_key + target_keys = [f"{prefix}_{name}" for name in names] + if inplace: # validate all keys free up front, so a partial write can't leave a half-decomposed sdata + clashes = [k for k in target_keys if k in sdata.images] + if clashes: + raise ValueError(f"decompose_stains would overwrite existing image(s): {clashes}.") + + concentrations = decompose_to_concentrations(da, stain_matrix, bg).assign_coords(c=_CONCENTRATION_CHANNELS) + concentrations = concentrations.astype(np.dtype(output_dtype)) - if image_key_added is None: + if not inplace: return {name: concentrations.sel(c=name) for name in names} source = sdata.images[image_key] - for name in names: + for name, key in zip(names, target_keys, strict=True): # keep the c dim (length 1) so Image2DModel.parse accepts it - _write_image(sdata, source, f"{image_key_added}_{name}", concentrations.sel(c=[name]), c_coords=[name]) + _write_image(sdata, source, key, concentrations.sel(c=[name]), c_coords=[name]) return None diff --git a/tests/experimental/test_stain_decompose_public.py b/tests/experimental/test_stain_decompose_public.py index 92b345701..d7da08335 100644 --- a/tests/experimental/test_stain_decompose_public.py +++ b/tests/experimental/test_stain_decompose_public.py @@ -53,7 +53,7 @@ def test_fit_and_apply_end_to_end(self, method: str) -> None: assert ref.stain_matrix.shape == (3, 3) assert ref.max_concentrations.shape == (2,) - out = normalize_stains(sdata, "img", ref) + out = normalize_stains(sdata, "img", ref, inplace=False) assert 
isinstance(out, xr.DataArray) assert out.sizes["c"] == 3 @@ -70,15 +70,29 @@ def test_apply_writes_back(self, method: str) -> None: class TestDecomposeStains: def test_returns_named_concentration_maps(self) -> None: sdata = _make_sdata(_synthetic_rgb()) - conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE) + conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE, inplace=False) assert set(conc) == {"hematoxylin", "eosin", "residual"} assert all(set(c.dims) == {"y", "x"} for c in conc.values()) # one (y, x) map per stain + assert all(c.dtype == np.float16 for c in conc.values()) # default output_dtype def test_drop_residual(self) -> None: sdata = _make_sdata(_synthetic_rgb()) - conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE, include_residual=False) + conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE, include_residual=False, inplace=False) assert set(conc) == {"hematoxylin", "eosin"} + def test_output_dtype_override(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE, output_dtype=np.float32, inplace=False) + assert all(c.dtype == np.float32 for c in conc.values()) + + def test_inplace_default_writes_derived_keys(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE) + out = decompose_stains(sdata, "img", ref) # inplace=True, prefix defaults to image_key + assert out is None + for stain in ("hematoxylin", "eosin", "residual"): + assert f"img_{stain}" in sdata.images + def test_with_reference_writes_separate_images(self) -> None: sdata = _make_sdata(_synthetic_rgb()) ref = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE) @@ -88,6 +102,17 @@ def test_with_reference_writes_separate_images(self) -> None: assert f"conc_{stain}" in sdata.images assert list(sdata.images[f"conc_{stain}"].coords["c"].values) == [stain] + def 
test_atomic_write_aborts_on_any_existing_key(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE) + # pre-occupy only the *eosin* target; the whole write must abort, leaving + # no half-written hematoxylin/residual behind. + sdata.images["conc_eosin"] = sdata.images["img"] + with pytest.raises(ValueError, match="would overwrite"): + decompose_stains(sdata, "img", ref, image_key_added="conc") + assert "conc_hematoxylin" not in sdata.images + assert "conc_residual" not in sdata.images + def test_reinhard_reference_rejected(self) -> None: sdata = _make_sdata(_synthetic_rgb()) reinhard_ref = fit_stain_reference(sdata, "img", method="reinhard") @@ -137,7 +162,7 @@ def test_fit_apply_decompose_smoke(self, sdata_hne, method: str) -> None: ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key, method=method) assert isinstance(ref, StainReference) assert ref.stain_matrix.shape == (3, 3) - normalized = sq.experimental.im.normalize_stains(sdata_hne, image_key, ref) + normalized = sq.experimental.im.normalize_stains(sdata_hne, image_key, ref, inplace=False) assert normalized.sizes["c"] == 3 - conc = sq.experimental.im.decompose_stains(sdata_hne, image_key, ref) + conc = sq.experimental.im.decompose_stains(sdata_hne, image_key, ref, inplace=False) assert set(conc) == {"hematoxylin", "eosin", "residual"} diff --git a/tests/experimental/test_stain_normalize.py b/tests/experimental/test_stain_normalize.py index f6d19bd87..731b280d6 100644 --- a/tests/experimental/test_stain_normalize.py +++ b/tests/experimental/test_stain_normalize.py @@ -62,11 +62,27 @@ class TestApplyStainNormalization: def test_returns_lazy_and_leaves_sdata_untouched(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = fit_stain_reference(sdata, "img") - out = normalize_stains(sdata, "img", ref) + out = normalize_stains(sdata, "img", ref, inplace=False) assert isinstance(out, xr.DataArray) 
assert isinstance(out.data, da.Array) assert list(sdata.images.keys()) == ["img"] + def test_inplace_default_writes_derived_key(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values) + ref = fit_stain_reference(sdata, "img") + result = normalize_stains(sdata, "img", ref) # inplace=True, image_key_added defaults to f"{key}_normalized" + assert result is None + assert "img_normalized" in sdata.images + out = sdata.images["img_normalized"] + assert out.dtype == rgb_values.dtype # cast back to the source dtype at the write boundary + assert out.shape == rgb_values.shape + + def test_output_dtype_override(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values) + ref = fit_stain_reference(sdata, "img") + out = normalize_stains(sdata, "img", ref, inplace=False, output_dtype=np.uint16) + assert out.dtype == np.uint16 + def test_writes_and_preserves_transform_and_dims(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = fit_stain_reference(sdata, "img") @@ -76,6 +92,7 @@ def test_writes_and_preserves_transform_and_dims(self, rgb_values: np.ndarray) - out = sdata.images["norm"] assert out.dims == ("c", "y", "x") assert out.shape == rgb_values.shape + assert out.dtype == rgb_values.dtype assert ( get_transformation(out, get_all=True).keys() == get_transformation(sdata.images["img"], get_all=True).keys() ) @@ -121,7 +138,7 @@ def test_decomposition_reference_without_max_concentrations_raises(self, rgb_val def test_method_params_mapping(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = fit_stain_reference(sdata, "img", method_params={"mask_background": False}) - out = normalize_stains(sdata, "img", ref, method_params=ReinhardParams(mask_background=False)) + out = normalize_stains(sdata, "img", ref, method_params=ReinhardParams(mask_background=False), inplace=False) assert isinstance(out, xr.DataArray) @@ -187,8 +204,8 @@ def test_background_passthrough_vs_full_frame(self, rgb_values: 
np.ndarray) -> N ref = fit_stain_reference(sdata, "ref_img") original = get_element_data(sdata.images["img"], "auto", "image", "img").values - kept = normalize_stains(sdata, "img", ref).values # preserve_background=True (default) - full = normalize_stains(sdata, "img", ref, preserve_background=False).values + kept = normalize_stains(sdata, "img", ref, inplace=False).values # preserve_background=True (default) + full = normalize_stains(sdata, "img", ref, preserve_background=False, inplace=False).values bg = slice(h // 2, None) np.testing.assert_allclose(kept[:, bg], original[:, bg]) # background untouched @@ -201,7 +218,7 @@ def test_fit_apply_smoke(self, sdata_hne) -> None: sq.experimental.im.detect_tissue(sdata_hne, image_key) ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key) assert ref.method == "reinhard" - out = sq.experimental.im.normalize_stains(sdata_hne, image_key, ref) + out = sq.experimental.im.normalize_stains(sdata_hne, image_key, ref, inplace=False) assert "c" in out.dims assert out.sizes["c"] == 3 From e7abe1d0e0295aa5a34f3ad2f89a7680d16f3c0e Mon Sep 17 00:00:00 2001 From: anon Date: Tue, 2 Jun 2026 04:37:56 +0200 Subject: [PATCH 6/8] docs: suppress unresolved detect_tissue cross-reference (fix RTD -W build) The tissue-mask mandate added `:func:` cross-references to detect_tissue in the now-documented fit_stain_reference / estimate_white_point docstrings. detect_tissue is not in docs/api.md (documenting it would cascade into its FelzenszwalbParams / WekaParams / BackgroundDetectionParams / DetectTissueMethod surface - out of scope, deferred to the docs PR), so the references resolved to nothing and the `-W` docs build failed on 3 warnings. Suppress the cross-reference with the `!` prefix; the name still renders, just without a (dead) link. 
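A minimal sketch of the docstring pattern, assuming a hypothetical function `fit_example` (not part of this patch): the `!` prefix inside the role keeps the rendered name but tells Sphinx not to resolve it, so no missing-reference warning is raised under `-W`.

```python
def fit_example():
    """Hypothetical function illustrating the suppressed cross-reference.

    Requires a tissue mask from :func:`!detect_tissue`. The ``!`` prefix makes
    Sphinx render the name without creating (or trying to resolve) a link.
    """


# The docstring itself is unchanged at runtime; only the Sphinx rendering differs.
doc = fit_example.__doc__
```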
Co-Authored-By: Claude Opus 4.8 --- src/squidpy/experimental/im/_stain/_normalize.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py index c8f4bfdc2..b26e9033e 100644 --- a/src/squidpy/experimental/im/_stain/_normalize.py +++ b/src/squidpy/experimental/im/_stain/_normalize.py @@ -81,7 +81,7 @@ def _resolve_mask_key_and_scale( """Resolve the (mandatory) tissue-mask key and the label scale closest to ``target_da``. Shared by the two mask consumers below. Consumes a - :func:`~squidpy.experimental.im.detect_tissue` labels element - raises if + :func:`!detect_tissue` labels element - raises if none exists. """ mask_key = resolve_tissue_mask(sdata, image_key, "auto", tissue_mask_key, auto_create=False) @@ -183,7 +183,7 @@ def estimate_white_point( Opt-in alternative to the fixed dtype-aware default white point, for a slide whose unstained background is genuinely not full white. Samples the per-channel median over **non-tissue** pixels (background = the complement of - the :func:`~squidpy.experimental.im.detect_tissue` mask). + the :func:`!detect_tissue` mask). Parameters ---------- @@ -244,10 +244,10 @@ def fit_stain_reference( known non-white background. Ignored by Reinhard. tissue_mask_key Key of a tissue-label element in ``sdata.labels`` (as produced by - :func:`~squidpy.experimental.im.detect_tissue`) restricting the fit to + :func:`!detect_tissue`) restricting the fit to tissue pixels. If ``None``, ``f"{image_key}_tissue"`` is used. A tissue mask is **required**: if neither exists, a :class:`KeyError` asks you to - run :func:`~squidpy.experimental.im.detect_tissue` first. + run :func:`!detect_tissue` first. 
Returns ------- From 85ca27f063db9776a34a538ea22ce73eb5d8739f Mon Sep 17 00:00:00 2001 From: anon Date: Tue, 2 Jun 2026 15:24:26 +0200 Subject: [PATCH 7/8] Stain normalization: strict RGB guard, H/E gate kwargs, macenko default (Step 4) Three guard/UX changes to fit_stain_reference and the channel check: - Default `method` flips reinhard -> **macenko**: the no-choice path now supports both normalize and decompose, and macenko's one documented weakness (artifact pixels) is exactly what the mandatory tissue mask removes. reinhard stays the explicit fast colour-transfer opt-out. - Expose the H/E sanity gate: `max_angle_deg` (deviation tolerance) and `canonical_reference` (the Ruifrok H/E vectors) are now documented kwargs on fit_stain_reference, threaded into reorder_to_canonical + validate_stain_matrix for the decomposition methods. Defaults unchanged (45 deg / Ruifrok). - Strict 3-channel RGB: the channel-dim check now raises a clear, RGB-specific message naming the RGBA/multi-channel case instead of a generic "length 3". Tests: pin method="reinhard" on the reinhard-oriented cases (random fixture / reinhard smoke+visual), add default-is-macenko, max_angle_deg-too-strict, canonical_reference passthrough, and an RGBA-rejection test; update the two mask/conversion channel-length assertions to the new message. 
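As a rough sketch of what the H/E sanity gate checks, assuming the gate compares plain angles between unit vectors: `angle_deg` and `check_gate` are hypothetical names, and the vectors below are the commonly cited Ruifrok H&E optical-density values, which may differ from the module's RUIFROK_HE constant.

```python
import numpy as np

# Commonly cited Ruifrok H&E optical-density vectors (assumed; the module's
# RUIFROK_HE constant may hold different values).
CANONICAL = {
    "hematoxylin": np.array([0.65, 0.70, 0.29]),
    "eosin": np.array([0.07, 0.99, 0.11]),
}


def angle_deg(u, v):
    """Angle in degrees between two vectors after unit normalisation."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    cos = (u / np.linalg.norm(u)) @ (v / np.linalg.norm(v))
    return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))))


def check_gate(stains, max_angle_deg=45.0, reference=CANONICAL):
    """Raise if any recovered stain vector strays too far from its reference."""
    for name, ref_vec in reference.items():
        dev = angle_deg(stains[name], ref_vec)
        if dev > max_angle_deg:
            raise ValueError(f"{name} deviates {dev:.1f} deg from its reference (> {max_angle_deg}).")


# A slightly perturbed fit passes the default 45 deg gate.
fitted = {"hematoxylin": np.array([0.70, 0.70, 0.29]), "eosin": CANONICAL["eosin"]}
check_gate(fitted)
```

With these assumed vectors the two canonical stains are themselves under 45 degrees apart, which is why the column ordering and the gate are driven by the same reference rather than by the tolerance alone.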
Co-Authored-By: Claude Opus 4.8 --- .../experimental/im/_stain/_conversion.py | 6 ++- .../experimental/im/_stain/_decomposition.py | 25 ++++++++--- .../experimental/im/_stain/_normalize.py | 35 ++++++++++++++-- tests/experimental/test_stain_conversion.py | 2 +- .../test_stain_decompose_public.py | 26 +++++++++++- tests/experimental/test_stain_mask.py | 4 +- tests/experimental/test_stain_normalize.py | 41 +++++++++++-------- 7 files changed, 108 insertions(+), 31 deletions(-) diff --git a/src/squidpy/experimental/im/_stain/_conversion.py b/src/squidpy/experimental/im/_stain/_conversion.py index ef16c5b88..cacabe46b 100644 --- a/src/squidpy/experimental/im/_stain/_conversion.py +++ b/src/squidpy/experimental/im/_stain/_conversion.py @@ -50,7 +50,11 @@ def _check_channel_dim(arr: xr.DataArray) -> None: if _CHANNEL_DIM not in arr.dims: raise ValueError(f"Input must have a dimension named {_CHANNEL_DIM!r}; got dims {arr.dims}.") if arr.sizes[_CHANNEL_DIM] != 3: - raise ValueError(f"Channel dimension {_CHANNEL_DIM!r} must have length 3; got {arr.sizes[_CHANNEL_DIM]}.") + raise ValueError( + f"stain normalization expects a 3-channel RGB image, but the {_CHANNEL_DIM!r} dimension has " + f"length {arr.sizes[_CHANNEL_DIM]}. RGBA (4-channel) and multi-channel images are not supported - " + "drop the alpha or extra channels first (e.g. keep the first 3 channels)." 
+ ) def _working_dtype(arr: xr.DataArray) -> np.dtype: diff --git a/src/squidpy/experimental/im/_stain/_decomposition.py b/src/squidpy/experimental/im/_stain/_decomposition.py index c9e1a6edc..b1893c510 100644 --- a/src/squidpy/experimental/im/_stain/_decomposition.py +++ b/src/squidpy/experimental/im/_stain/_decomposition.py @@ -14,6 +14,7 @@ import numpy as np import xarray as xr +from squidpy.experimental.im._stain._constants import RUIFROK_HE from squidpy.experimental.im._stain._conversion import ( _apply_along_channel, _check_channel_dim, @@ -179,11 +180,23 @@ def _vahadane_stain_matrix(od: np.ndarray, params: VahadaneParams) -> np.ndarray return _unit_columns(stains) -def _stain_matrix(od: np.ndarray, method: StainMethod, params: Any, *, image_key: str | None) -> np.ndarray: - """Fit, canonicalise, complete and validate a ``(3, 3)`` stain matrix.""" +def _stain_matrix( + od: np.ndarray, + method: StainMethod, + params: Any, + *, + image_key: str | None, + reference: dict[str, np.ndarray] = RUIFROK_HE, + max_angle_deg: float = 45.0, +) -> np.ndarray: + """Fit, canonicalise, complete and validate a ``(3, 3)`` stain matrix. + + ``reference`` (the canonical H/E vectors) drives both the column ordering and + the deviation gate; ``max_angle_deg`` is the gate tolerance. 
+ """ raw = _macenko_stain_matrix(od, params.alpha) if method == "macenko" else _vahadane_stain_matrix(od, params) - matrix = complement_third_column(reorder_to_canonical(raw)) - validate_stain_matrix(matrix, image_key=image_key) + matrix = complement_third_column(reorder_to_canonical(raw, reference)) + validate_stain_matrix(matrix, reference=reference, max_angle_deg=max_angle_deg, image_key=image_key) return matrix @@ -205,10 +218,12 @@ def fit_decomposition( *, tissue_mask: np.ndarray | None = None, image_key: str | None = None, + reference: dict[str, np.ndarray] = RUIFROK_HE, + max_angle_deg: float = 45.0, ) -> StainReference: """Fit a decomposition :class:`StainReference` (stain matrix + max concentrations).""" od = _tissue_od(image_rgb, white_point, params.beta, tissue_mask=tissue_mask, image_key=image_key) - matrix = _stain_matrix(od, method, params, image_key=image_key) + matrix = _stain_matrix(od, method, params, image_key=image_key, reference=reference, max_angle_deg=max_angle_deg) return StainReference( method=method, stain_matrix=matrix, diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py index b26e9033e..f63a17e88 100644 --- a/src/squidpy/experimental/im/_stain/_normalize.py +++ b/src/squidpy/experimental/im/_stain/_normalize.py @@ -23,6 +23,7 @@ from spatialdata.transformations import get_transformation from squidpy._utils import _get_scale_factors +from squidpy.experimental.im._stain._constants import RUIFROK_HE from squidpy.experimental.im._stain._conversion import _check_channel_dim, cast_to_image_dtype from squidpy.experimental.im._stain._decomposition import ( MacenkoParams, @@ -212,11 +213,13 @@ def fit_stain_reference( sdata: sd.SpatialData, image_key: str, *, - method: StainMethod = "reinhard", + method: StainMethod = "macenko", scale: str | Literal["auto"] = "auto", method_params: MethodParams = None, white_point: np.ndarray | None = None, tissue_mask_key: str | None = None, + 
max_angle_deg: float = 45.0, + canonical_reference: Mapping[str, np.ndarray] | None = None, ) -> StainReference: """Fit a stain reference from an image in a :class:`~spatialdata.SpatialData` object. @@ -227,8 +230,12 @@ def fit_stain_reference( image_key Key of the RGB image in ``sdata.images`` to fit on. method - Fitting method: ``"reinhard"`` (colour transfer), ``"macenko"`` or - ``"vahadane"`` (stain-matrix decomposition). + Fitting method: ``"macenko"`` (default) or ``"vahadane"`` (physical + stain-matrix decomposition, usable by both :func:`normalize_stains` and + :func:`decompose_stains`), or ``"reinhard"`` (faster statistical colour + transfer, no stain separation). Macenko is the default because its one + documented weakness - artifact pixels contaminating the fit - is removed + by the mandatory tissue mask. scale Scale level to fit on. ``"auto"`` (default) uses the coarsest level, which is cheap and sufficient for colour statistics. @@ -248,6 +255,16 @@ def fit_stain_reference( tissue pixels. If ``None``, ``f"{image_key}_tissue"`` is used. A tissue mask is **required**: if neither exists, a :class:`KeyError` asks you to run :func:`!detect_tissue` first. + max_angle_deg + Tolerance of the H/E sanity gate for the decomposition methods: the fit + raises :class:`StainFittingError` if either recovered stain vector + deviates more than this many degrees from its canonical reference. + Default ``45``. Ignored by Reinhard. + canonical_reference + Canonical H/E reference for the decomposition methods, a mapping with + ``"hematoxylin"`` and ``"eosin"`` keys to ``(3,)`` RGB optical-density + unit vectors. Drives both the H/E column ordering and the deviation + gate. If ``None``, the Ruifrok H&E vectors are used. Ignored by Reinhard. 
Returns ------- @@ -262,7 +279,17 @@ def fit_stain_reference( if method == "reinhard": return fit_reinhard(da, params, tissue_mask=tissue_mask) bg = default_white_point(da) if white_point is None else np.asarray(white_point, np.float64) - return fit_decomposition(da, method, params, bg, tissue_mask=tissue_mask, image_key=image_key) + reference = RUIFROK_HE if canonical_reference is None else dict(canonical_reference) + return fit_decomposition( + da, + method, + params, + bg, + tissue_mask=tissue_mask, + image_key=image_key, + reference=reference, + max_angle_deg=max_angle_deg, + ) def normalize_stains( diff --git a/tests/experimental/test_stain_conversion.py b/tests/experimental/test_stain_conversion.py index bd8f68d63..99733c585 100644 --- a/tests/experimental/test_stain_conversion.py +++ b/tests/experimental/test_stain_conversion.py @@ -103,5 +103,5 @@ def test_missing_channel_dim_raises(self) -> None: def test_wrong_channel_length_raises(self) -> None: arr = xr.DataArray(np.zeros((4, 4, 4)), dims=("y", "x", "c")) - with pytest.raises(ValueError, match="length 3"): + with pytest.raises(ValueError, match="3-channel RGB"): rgb_to_sda(arr, _TEST_WHITE) diff --git a/tests/experimental/test_stain_decompose_public.py b/tests/experimental/test_stain_decompose_public.py index d7da08335..abe43199a 100644 --- a/tests/experimental/test_stain_decompose_public.py +++ b/tests/experimental/test_stain_decompose_public.py @@ -16,7 +16,11 @@ ) from squidpy.experimental.im._stain._constants import RUIFROK_HE from squidpy.experimental.im._stain._conversion import sda_to_rgb -from squidpy.experimental.im._stain._validation import complement_third_column, reorder_to_canonical +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + complement_third_column, + reorder_to_canonical, +) _WHITE = np.array([255.0, 255.0, 255.0]) @@ -147,6 +151,26 @@ def test_fit_unknown_method_raises(self) -> None: fit_stain_reference(sdata, "img", method="bogus") +class 
TestDefaultMethodAndGate: + def test_default_method_is_macenko(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img") # no method -> default + assert ref.method == "macenko" + + def test_max_angle_deg_gate_too_strict_raises(self) -> None: + # an impossibly tight tolerance trips the H/E sanity gate + sdata = _make_sdata(_synthetic_rgb()) + with pytest.raises(StainFittingError, match="deviates"): + fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE, max_angle_deg=0.01) + + def test_canonical_reference_passthrough(self) -> None: + # passing the Ruifrok canonical explicitly reproduces the default fit + sdata = _make_sdata(_synthetic_rgb()) + default = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE) + custom = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE, canonical_reference=RUIFROK_HE) + np.testing.assert_allclose(default.stain_matrix, custom.stain_matrix) + + class TestDecompositionOnHnE: # Correctness is gated by the synthetic-recovery tests above (per the arc # decision); these are real-data smoke checks that the pipeline fits a diff --git a/tests/experimental/test_stain_mask.py b/tests/experimental/test_stain_mask.py index 0e4110b3a..c0e70f23b 100644 --- a/tests/experimental/test_stain_mask.py +++ b/tests/experimental/test_stain_mask.py @@ -46,7 +46,7 @@ def test_lazy_in_lazy_out(self) -> None: def test_non_three_channel_raises(self) -> None: values = np.zeros((2, 8, 8)) - with pytest.raises(ValueError, match="length 3"): + with pytest.raises(ValueError, match="3-channel RGB"): luminosity_foreground_mask(xr.DataArray(values, dims=("c", "y", "x")), 0.8) @@ -66,5 +66,5 @@ def test_lazy_in_lazy_out(self) -> None: def test_non_three_channel_raises(self) -> None: values = np.zeros((2, 8, 8)) - with pytest.raises(ValueError, match="length 3"): + with pytest.raises(ValueError, match="3-channel RGB"): absorbance_foreground_mask(xr.DataArray(values, dims=("c", 
"y", "x")), _WHITE) diff --git a/tests/experimental/test_stain_normalize.py b/tests/experimental/test_stain_normalize.py index 731b280d6..478459782 100644 --- a/tests/experimental/test_stain_normalize.py +++ b/tests/experimental/test_stain_normalize.py @@ -43,7 +43,7 @@ def rgb_values() -> np.ndarray: class TestFitStainReference: def test_end_to_end(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) - ref = fit_stain_reference(sdata, "img") + ref = fit_stain_reference(sdata, "img", method="reinhard") assert isinstance(ref, StainReference) assert ref.method == "reinhard" @@ -57,11 +57,18 @@ def test_unknown_method_raises(self, rgb_values: np.ndarray) -> None: with pytest.raises(ValueError, match="Unknown method"): fit_stain_reference(sdata, "img", method="bogus") + def test_rgba_image_rejected(self) -> None: + rng = np.random.default_rng(0) + rgba = rng.integers(0, 256, size=(4, 16, 16), dtype=np.uint8) # 4-channel, not RGB + sdata = sd.SpatialData(images={"img": Image2DModel.parse(rgba, dims=("c", "y", "x"))}) + with pytest.raises(ValueError, match="3-channel RGB"): + fit_stain_reference(sdata, "img", method="reinhard") + class TestApplyStainNormalization: def test_returns_lazy_and_leaves_sdata_untouched(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) - ref = fit_stain_reference(sdata, "img") + ref = fit_stain_reference(sdata, "img", method="reinhard") out = normalize_stains(sdata, "img", ref, inplace=False) assert isinstance(out, xr.DataArray) assert isinstance(out.data, da.Array) @@ -69,7 +76,7 @@ def test_returns_lazy_and_leaves_sdata_untouched(self, rgb_values: np.ndarray) - def test_inplace_default_writes_derived_key(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) - ref = fit_stain_reference(sdata, "img") + ref = fit_stain_reference(sdata, "img", method="reinhard") result = normalize_stains(sdata, "img", ref) # inplace=True, image_key_added defaults to f"{key}_normalized" assert result is None 
         assert "img_normalized" in sdata.images
@@ -79,13 +86,13 @@ def test_inplace_default_writes_derived_key(self, rgb_values: np.ndarray) -> Non
 
     def test_output_dtype_override(self, rgb_values: np.ndarray) -> None:
         sdata = _make_sdata(rgb_values)
-        ref = fit_stain_reference(sdata, "img")
+        ref = fit_stain_reference(sdata, "img", method="reinhard")
         out = normalize_stains(sdata, "img", ref, inplace=False, output_dtype=np.uint16)
         assert out.dtype == np.uint16
 
     def test_writes_and_preserves_transform_and_dims(self, rgb_values: np.ndarray) -> None:
         sdata = _make_sdata(rgb_values)
-        ref = fit_stain_reference(sdata, "img")
+        ref = fit_stain_reference(sdata, "img", method="reinhard")
         result = normalize_stains(sdata, "img", ref, image_key_added="norm")
         assert result is None
         assert "norm" in sdata.images
@@ -99,7 +106,7 @@ def test_writes_and_preserves_transform_and_dims(self, rgb_values: np.ndarray) -
 
     def test_multiscale_rebuilds_pyramid(self, rgb_values: np.ndarray) -> None:
         sdata = _make_sdata(rgb_values, scale_factors=[2])
-        ref = fit_stain_reference(sdata, "img")
+        ref = fit_stain_reference(sdata, "img", method="reinhard")
         normalize_stains(sdata, "img", ref, image_key_added="norm")
         src, out = sdata.images["img"], sdata.images["norm"]
         assert hasattr(out, "keys")
@@ -113,7 +120,7 @@ def test_preserves_channel_coords_and_nonidentity_transform(self, rgb_values: np
         sdata = sd.SpatialData(images={"img": img})
         h, w = rgb_values.shape[-2], rgb_values.shape[-1]
         sdata.labels["img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x"))
-        ref = fit_stain_reference(sdata, "img")
+        ref = fit_stain_reference(sdata, "img", method="reinhard")
         normalize_stains(sdata, "img", ref, image_key_added="norm")
         out = sdata.images["norm"]
         assert list(out.coords["c"].values) == ["r", "g", "b"]
@@ -121,7 +128,7 @@ def test_preserves_channel_coords_and_nonidentity_transform(self, rgb_values: np
 
     def test_existing_key_raises(self, rgb_values: np.ndarray) -> None:
         sdata = _make_sdata(rgb_values)
-        ref = fit_stain_reference(sdata, "img")
+        ref = fit_stain_reference(sdata, "img", method="reinhard")
         with pytest.raises(ValueError, match="already exists"):
             normalize_stains(sdata, "img", ref, image_key_added="img")
 
@@ -137,7 +144,7 @@ def test_decomposition_reference_without_max_concentrations_raises(self, rgb_val
 
     def test_method_params_mapping(self, rgb_values: np.ndarray) -> None:
         sdata = _make_sdata(rgb_values)
-        ref = fit_stain_reference(sdata, "img", method_params={"mask_background": False})
+        ref = fit_stain_reference(sdata, "img", method="reinhard", method_params={"mask_background": False})
         out = normalize_stains(sdata, "img", ref, method_params=ReinhardParams(mask_background=False), inplace=False)
         assert isinstance(out, xr.DataArray)
 
@@ -146,11 +153,11 @@ class TestTissueMaskMandate:
     def test_fit_requires_tissue_mask(self, rgb_values: np.ndarray) -> None:
         sdata = _make_sdata(rgb_values, with_tissue=False)
         with pytest.raises(KeyError, match="detect_tissue"):
-            fit_stain_reference(sdata, "img")
+            fit_stain_reference(sdata, "img", method="reinhard")
 
     def test_apply_requires_tissue_mask(self, rgb_values: np.ndarray) -> None:
         sdata = _make_sdata(rgb_values)  # has a mask -> fit works
-        ref = fit_stain_reference(sdata, "img")
+        ref = fit_stain_reference(sdata, "img", method="reinhard")
         del sdata.labels["img_tissue"]  # ... but now the source has none
         with pytest.raises(KeyError, match="detect_tissue"):
             normalize_stains(sdata, "img", ref)
@@ -164,7 +171,7 @@ def test_float_0_255_source_rejected_on_apply(self, rgb_values: np.ndarray) -> N
         # A float image holding 0-255 values would otherwise clip to [0, 1] in the
         # reconstruction (dtype_max(float)=1.0); apply must reject it, not silently destroy it.
         sdata = _make_sdata(rgb_values)  # uint8
-        ref = fit_stain_reference(sdata, "img")
+        ref = fit_stain_reference(sdata, "img", method="reinhard")
         floaty = rgb_values.astype(np.float32)
         sdata.images["floaty"] = Image2DModel.parse(floaty, dims=("c", "y", "x"))
         sdata.labels["floaty_tissue"] = Labels2DModel.parse(
@@ -176,14 +183,14 @@ def test_float_0_255_source_rejected_on_apply(self, rgb_values: np.ndarray) -> N
     def test_mask_is_used_in_the_fit(self, rgb_values: np.ndarray) -> None:
         # A different tissue region yields different channel statistics, proving
         # the mask actually drives the fit (not silently ignored).
-        ref_full = fit_stain_reference(_make_sdata(rgb_values), "img")
+        ref_full = fit_stain_reference(_make_sdata(rgb_values), "img", method="reinhard")
 
         sdata_part = _make_sdata(rgb_values, with_tissue=False)
         h, w = rgb_values.shape[-2], rgb_values.shape[-1]
         partial = np.zeros((h, w), dtype=np.uint32)
         partial[: h // 2] = 1  # only the top half is tissue
         sdata_part.labels["img_tissue"] = Labels2DModel.parse(partial, dims=("y", "x"))
-        ref_part = fit_stain_reference(sdata_part, "img")
+        ref_part = fit_stain_reference(sdata_part, "img", method="reinhard")
         assert not np.allclose(ref_full.mu, ref_part.mu)
 
 
@@ -201,7 +208,7 @@ def test_background_passthrough_vs_full_frame(self, rgb_values: np.ndarray) -> N
         shifted = np.clip(rgb_values * np.array([1.3, 0.8, 1.1])[:, None, None], 0, 255).astype(np.uint8)
         sdata.images["ref_img"] = Image2DModel.parse(shifted, dims=("c", "y", "x"))
         sdata.labels["ref_img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x"))
-        ref = fit_stain_reference(sdata, "ref_img")
+        ref = fit_stain_reference(sdata, "ref_img", method="reinhard")
 
         original = get_element_data(sdata.images["img"], "auto", "image", "img").values
         kept = normalize_stains(sdata, "img", ref, inplace=False).values  # preserve_background=True (default)
@@ -216,7 +223,7 @@ class TestStainNormalizationOnHnE:
     def test_fit_apply_smoke(self, sdata_hne) -> None:
         image_key = next(iter(sdata_hne.images))
         sq.experimental.im.detect_tissue(sdata_hne, image_key)
-        ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key)
+        ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key, method="reinhard")
         assert ref.method == "reinhard"
         out = sq.experimental.im.normalize_stains(sdata_hne, image_key, ref, inplace=False)
         assert "c" in out.dims
@@ -228,7 +235,7 @@ def test_plot_reinhard_before_after(self, sdata_hne) -> None:
         """Visual: a re-stained source (left) normalized back to the H&E reference (right)."""
         image_key = next(iter(sdata_hne.images))
         sq.experimental.im.detect_tissue(sdata_hne, image_key)
-        reference = fit_stain_reference(sdata_hne, image_key)
+        reference = fit_stain_reference(sdata_hne, image_key, method="reinhard")
 
         # Deterministically warm/cool the channels to simulate a different
         # staining batch, so the before/after panels are visibly distinct.
From 1f59452d95c87609be9e518d759cd11df0764768 Mon Sep 17 00:00:00 2001
From: anon
Date: Tue, 2 Jun 2026 15:48:23 +0200
Subject: [PATCH 8/8] docs: suppress unresolved StainFittingError cross-reference (fix RTD -W build)

Step 4 added a `:class:`StainFittingError`` reference to the
fit_stain_reference docstring. StainFittingError is exported from the
_stain package but not surfaced at the public squidpy.experimental.im
level, so it has no autosummary target and the `-W` docs build failed.
Surfacing it publicly is a deliberate API decision for the docs PR; for
now suppress the cross-reference with `!`, consistent with the
detect_tissue refs.

Co-Authored-By: Claude Opus 4.8
---
 src/squidpy/experimental/im/_stain/_normalize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py
index f63a17e88..b1f15950e 100644
--- a/src/squidpy/experimental/im/_stain/_normalize.py
+++ b/src/squidpy/experimental/im/_stain/_normalize.py
@@ -257,7 +257,7 @@ def fit_stain_reference(
         run :func:`!detect_tissue` first.
     max_angle_deg
         Tolerance of the H/E sanity gate for the decomposition methods: the fit
-        raises :class:`StainFittingError` if either recovered stain vector
+        raises :class:`!StainFittingError` if either recovered stain vector
         deviates more than this many degrees from its canonical reference.
         Default ``45``. Ignored by Reinhard.
     canonical_reference