diff --git a/docs/api.md b/docs/api.md index 1bf4fdf92..3b582ab60 100644 --- a/docs/api.md +++ b/docs/api.md @@ -151,7 +151,11 @@ See the {doc}`extensibility guide ` for how to implement a custo experimental.tl.TilingQCParams experimental.pl.tiling_qc experimental.im.fit_stain_reference - experimental.im.apply_stain_normalization + experimental.im.normalize_stains + experimental.im.decompose_stains + experimental.im.estimate_white_point experimental.im.StainReference experimental.im.ReinhardParams + experimental.im.MacenkoParams + experimental.im.VahadaneParams ``` diff --git a/src/squidpy/experimental/im/__init__.py b/src/squidpy/experimental/im/__init__.py index 1a661b53a..7adea9d74 100644 --- a/src/squidpy/experimental/im/__init__.py +++ b/src/squidpy/experimental/im/__init__.py @@ -10,21 +10,29 @@ from ._qc_image import qc_image from ._qc_metrics import QCMetric from ._stain import ( + MacenkoParams, ReinhardParams, StainReference, - apply_stain_normalization, + VahadaneParams, + decompose_stains, + estimate_white_point, fit_stain_reference, + normalize_stains, ) __all__ = [ "BackgroundDetectionParams", "FelzenszwalbParams", + "MacenkoParams", "QCMetric", "ReinhardParams", "StainReference", + "VahadaneParams", "WekaParams", - "apply_stain_normalization", + "normalize_stains", + "decompose_stains", "detect_tissue", + "estimate_white_point", "fit_stain_reference", "make_tiles", "make_tiles_from_spots", diff --git a/src/squidpy/experimental/im/_make_tiles.py b/src/squidpy/experimental/im/_make_tiles.py index 6323c48ad..87f93d33b 100644 --- a/src/squidpy/experimental/im/_make_tiles.py +++ b/src/squidpy/experimental/im/_make_tiles.py @@ -8,14 +8,14 @@ from dask.base import is_dask_collection from shapely.geometry import Polygon from spatialdata._logging import logger -from spatialdata.models import Labels2DModel, ShapesModel +from spatialdata.models import ShapesModel from spatialdata.models._utils import SpatialElement from spatialdata.transformations import 
get_transformation, set_transformation -from squidpy._utils import _yx_from_shape from squidpy._validators import assert_in_range, assert_key_in_sdata, assert_positive from squidpy.experimental.im._utils import ( TileGrid, + _choose_label_scale_for_image, get_element_data, get_mask_materialized, save_tile_grid_to_shapes, @@ -99,24 +99,6 @@ def _get_largest_scale_dimensions( return int(img_da.shape[-2]), int(img_da.shape[-1]) -def _choose_label_scale_for_image(label_node: Labels2DModel, target_hw: tuple[int, int]) -> str: - """Pick the label scale closest to the target image height/width.""" - if not hasattr(label_node, "keys"): - return "scale0" # single-scale labels default to their only scale - target_h, target_w = target_hw - best = None - best_diff = float("inf") - for k in label_node.keys(): - y, x = _yx_from_shape(label_node[k].image.shape) - diff = abs(y - target_h) + abs(x - target_w) - if diff == 0: - return k - if diff < best_diff: - best_diff = diff - best = k - return best or "scale0" - - def _save_tiles_to_shapes( sdata: sd.SpatialData, tg: TileGrid, diff --git a/src/squidpy/experimental/im/_stain/__init__.py b/src/squidpy/experimental/im/_stain/__init__.py index 86024e4f8..df7d354d0 100644 --- a/src/squidpy/experimental/im/_stain/__init__.py +++ b/src/squidpy/experimental/im/_stain/__init__.py @@ -15,10 +15,21 @@ rgb_to_sda, sda_to_rgb, ) -from squidpy.experimental.im._stain._mask import luminosity_foreground_mask +from squidpy.experimental.im._stain._decomposition import ( + MacenkoParams, + VahadaneParams, + apply_decomposition, + fit_decomposition, +) +from squidpy.experimental.im._stain._mask import ( + absorbance_foreground_mask, + luminosity_foreground_mask, +) from squidpy.experimental.im._stain._normalize import ( - apply_stain_normalization, + decompose_stains, + estimate_white_point, fit_stain_reference, + normalize_stains, ) from squidpy.experimental.im._stain._reference import StainMethod, StainReference from 
squidpy.experimental.im._stain._reinhard import ( @@ -26,6 +37,12 @@ apply_reinhard, fit_reinhard, ) +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + complement_third_column, + reorder_to_canonical, + validate_stain_matrix, +) __all__ = [ "DEFAULT_LUMINOSITY_THRESHOLD", @@ -35,16 +52,27 @@ "RUDERMAN_RGB_TO_LMS", "RUIFROK_HE", "SDA_SCALE", + "MacenkoParams", "ReinhardParams", + "StainFittingError", "StainMethod", "StainReference", + "VahadaneParams", + "absorbance_foreground_mask", + "apply_decomposition", "apply_reinhard", - "apply_stain_normalization", + "normalize_stains", + "complement_third_column", + "decompose_stains", + "estimate_white_point", + "fit_decomposition", "fit_reinhard", "fit_stain_reference", "lab_ruderman_to_rgb", "luminosity_foreground_mask", + "reorder_to_canonical", "rgb_to_lab_ruderman", "rgb_to_sda", "sda_to_rgb", + "validate_stain_matrix", ] diff --git a/src/squidpy/experimental/im/_stain/_conversion.py b/src/squidpy/experimental/im/_stain/_conversion.py index b71935ea5..cacabe46b 100644 --- a/src/squidpy/experimental/im/_stain/_conversion.py +++ b/src/squidpy/experimental/im/_stain/_conversion.py @@ -24,11 +24,37 @@ _CHANNEL_DIM = "c" +def dtype_max(dtype: np.dtype | type) -> float: + """Valid-intensity upper bound for an image dtype (255 / 65535 / 1.0). + + Integer dtypes use their full range; float RGB is assumed to live in + ``[0, 1]``. + """ + dt = np.dtype(dtype) + return float(np.iinfo(dt).max) if np.issubdtype(dt, np.integer) else 1.0 + + +def cast_to_image_dtype(arr: xr.DataArray, out_dtype: np.dtype | type) -> xr.DataArray: + """Cast a clipped working-float image to its final dtype at the write boundary. + + The reconstruction kernels (:func:`sda_to_rgb`, :func:`lab_ruderman_to_rgb`) + clip to ``out_dtype``'s valid range but stay in float; this performs the + deferred cast. Integer targets are **rounded** (so ``254.6 -> 255``, not + ``254``); float targets cast directly. 
Stays lazy on dask-backed input. + """ + dt = np.dtype(out_dtype) + return arr.round().astype(dt) if np.issubdtype(dt, np.integer) else arr.astype(dt) + + def _check_channel_dim(arr: xr.DataArray) -> None: if _CHANNEL_DIM not in arr.dims: raise ValueError(f"Input must have a dimension named {_CHANNEL_DIM!r}; got dims {arr.dims}.") if arr.sizes[_CHANNEL_DIM] != 3: - raise ValueError(f"Channel dimension {_CHANNEL_DIM!r} must have length 3; got {arr.sizes[_CHANNEL_DIM]}.") + raise ValueError( + f"stain normalization expects a 3-channel RGB image, but the {_CHANNEL_DIM!r} dimension has " + f"length {arr.sizes[_CHANNEL_DIM]}. RGBA (4-channel) and multi-channel images are not supported - " + "drop the alpha or extra channels first (e.g. keep the first 3 channels)." + ) def _working_dtype(arr: xr.DataArray) -> np.dtype: @@ -68,9 +94,9 @@ def _rgb_to_sda_kernel(x: np.ndarray, *, bg: np.ndarray, dtype: np.dtype) -> np. return (-np.log((x + 1.0) / (bg + 1.0)) * SDA_SCALE).astype(dtype, copy=False) -def _sda_to_rgb_kernel(x: np.ndarray, *, bg: np.ndarray, dtype: np.dtype) -> np.ndarray: +def _sda_to_rgb_kernel(x: np.ndarray, *, bg: np.ndarray, max_value: float, dtype: np.dtype) -> np.ndarray: rgb = (bg + 1.0) * np.exp(-x.astype(dtype, copy=False) / SDA_SCALE) - 1.0 - np.clip(rgb, 0.0, 255.0, out=rgb) + np.clip(rgb, 0.0, max_value, out=rgb) return rgb.astype(dtype, copy=False) @@ -81,20 +107,20 @@ def _rgb_to_lab_kernel(x: np.ndarray, *, dtype: np.dtype) -> np.ndarray: return (lms @ RUDERMAN_LMS_TO_LAB.T.astype(dtype, copy=False)).astype(dtype, copy=False) -def _lab_to_rgb_kernel(x: np.ndarray, *, dtype: np.dtype) -> np.ndarray: +def _lab_to_rgb_kernel(x: np.ndarray, *, max_value: float, dtype: np.dtype) -> np.ndarray: x = x.astype(dtype, copy=False) log_lms = x @ RUDERMAN_LAB_TO_LMS.T.astype(dtype, copy=False) # The +1.0 / -1.0 pair is paired with the matching offset in # `_rgb_to_lab_kernel` so the round trip remains exact for valid RGB. 
lms = np.exp(log_lms) - 1.0 rgb = lms @ RUDERMAN_LMS_TO_RGB.T.astype(dtype, copy=False) - np.clip(rgb, 0.0, 255.0, out=rgb) + np.clip(rgb, 0.0, max_value, out=rgb) return rgb.astype(dtype, copy=False) def rgb_to_sda( rgb: xr.DataArray, - background_intensity: np.ndarray, + white_point: np.ndarray, ) -> xr.DataArray: """Convert RGB intensities to standard deviation per absorbance (SDA). @@ -112,7 +138,7 @@ def rgb_to_sda( rgb Image with a ``"c"`` dimension of length 3. May be numpy- or dask-backed; the operation is purely elementwise and stays lazy. - background_intensity + white_point Per-channel white-point ``I_0`` as a shape-``(3,)`` numpy array. Required: no scanner produces a pure-white background, so the caller must supply either an estimate (PR 3 will ship the @@ -125,24 +151,29 @@ def rgb_to_sda( """ _check_channel_dim(rgb) dtype = _working_dtype(rgb) - bg = np.asarray(background_intensity, dtype=dtype) + bg = np.asarray(white_point, dtype=dtype) return _apply_along_channel(rgb, _rgb_to_sda_kernel, out_dtype=dtype, bg=bg, dtype=dtype) def sda_to_rgb( sda: xr.DataArray, - background_intensity: np.ndarray, + white_point: np.ndarray, + *, + out_dtype: np.dtype | type = np.uint8, ) -> xr.DataArray: - """Convert SDA back to RGB intensities in ``[0, 255]``. + """Convert SDA back to RGB, clipped to ``out_dtype``'s valid range. - Inverse of :func:`rgb_to_sda`. Pass the same ``background_intensity`` - used at encode time. The result is clipped to ``[0, 255]`` but kept in - float dtype; uint8 conversion is the caller's choice. + Inverse of :func:`rgb_to_sda`. Pass the same ``white_point`` used at encode + time. ``out_dtype`` is the eventual image dtype: the reconstruction is + clipped to that dtype's valid range (``dtype_max`` = 255 / 65535 / 1.0) but + kept in float; the final cast to ``out_dtype`` is the caller's choice. 
""" _check_channel_dim(sda) dtype = _working_dtype(sda) - bg = np.asarray(background_intensity, dtype=dtype) - return _apply_along_channel(sda, _sda_to_rgb_kernel, out_dtype=dtype, bg=bg, dtype=dtype) + bg = np.asarray(white_point, dtype=dtype) + return _apply_along_channel( + sda, _sda_to_rgb_kernel, out_dtype=dtype, bg=bg, max_value=dtype_max(out_dtype), dtype=dtype + ) def rgb_to_lab_ruderman(rgb: xr.DataArray) -> xr.DataArray: @@ -161,12 +192,12 @@ def rgb_to_lab_ruderman(rgb: xr.DataArray) -> xr.DataArray: return _apply_along_channel(rgb, _rgb_to_lab_kernel, out_dtype=dtype, dtype=dtype) -def lab_ruderman_to_rgb(lab: xr.DataArray) -> xr.DataArray: +def lab_ruderman_to_rgb(lab: xr.DataArray, *, out_dtype: np.dtype | type = np.uint8) -> xr.DataArray: """Inverse of :func:`rgb_to_lab_ruderman`. - Returns RGB clipped to ``[0, 255]`` in float dtype; uint8 conversion is - the caller's choice. + Clips the reconstruction to ``out_dtype``'s valid range (the eventual image + dtype) but keeps it in float; the final cast is the caller's choice. """ _check_channel_dim(lab) dtype = _working_dtype(lab) - return _apply_along_channel(lab, _lab_to_rgb_kernel, out_dtype=dtype, dtype=dtype) + return _apply_along_channel(lab, _lab_to_rgb_kernel, out_dtype=dtype, max_value=dtype_max(out_dtype), dtype=dtype) diff --git a/src/squidpy/experimental/im/_stain/_decomposition.py b/src/squidpy/experimental/im/_stain/_decomposition.py new file mode 100644 index 000000000..b1893c510 --- /dev/null +++ b/src/squidpy/experimental/im/_stain/_decomposition.py @@ -0,0 +1,293 @@ +"""Macenko and Vahadane stain decomposition (fit + apply). + +Pure DataArray/numpy layer: no ``sdata``, no public export. The stain-matrix +fits run on tissue pixels (a bounded reduction at the chosen scale); the apply +transform is a single per-pixel matmul and stays lazy. 
+""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, fields +from typing import Any + +import numpy as np +import xarray as xr + +from squidpy.experimental.im._stain._constants import RUIFROK_HE +from squidpy.experimental.im._stain._conversion import ( + _apply_along_channel, + _check_channel_dim, + _working_dtype, + rgb_to_sda, + sda_to_rgb, +) +from squidpy.experimental.im._stain._mask import as_spatial_mask, foreground_mask_from_sda +from squidpy.experimental.im._stain._reference import StainMethod, StainReference +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + _unit_columns, + complement_third_column, + reorder_to_canonical, + validate_stain_matrix, +) + +_MAXC_PERCENTILE = 99.0 +_MAXC_FLOOR = 1e-6 + + +@dataclass(slots=True, frozen=True) +class MacenkoParams: + """Tuning knobs for Macenko stain-matrix fitting.""" + + alpha: float = 1.0 + """Angular percentile (deg) for the two stain directions; the extremes are taken at ``alpha`` / ``100 - alpha``.""" + + beta: float = 0.15 + """Mean-absorbance cutoff selecting tissue pixels (optical-density space).""" + + def __post_init__(self) -> None: + object.__setattr__(self, "alpha", float(self.alpha)) + object.__setattr__(self, "beta", float(self.beta)) + if not 0.0 < self.alpha < 50.0: + raise ValueError(f"`alpha` must be in (0, 50), got {self.alpha}.") + if self.beta < 0.0: + raise ValueError(f"`beta` must be >= 0, got {self.beta}.") + + +@dataclass(slots=True, frozen=True) +class VahadaneParams: + """Tuning knobs for Vahadane (sparse-NMF) stain-matrix fitting.""" + + beta: float = 0.15 + """Mean-absorbance cutoff selecting tissue pixels (optical-density space).""" + + lambda1: float = 0.1 + """L1 sparsity regularisation on the concentration factor of the NMF.""" + + n_iter: int = 200 + """Maximum NMF iterations.""" + + random_state: int | None = 0 + """Seed for NMF initialisation tie-breaking; fixed for reproducible 
fits.""" + + def __post_init__(self) -> None: + object.__setattr__(self, "beta", float(self.beta)) + object.__setattr__(self, "lambda1", float(self.lambda1)) + object.__setattr__(self, "n_iter", int(self.n_iter)) + if self.beta < 0.0: + raise ValueError(f"`beta` must be >= 0, got {self.beta}.") + if self.lambda1 < 0.0: + raise ValueError(f"`lambda1` must be >= 0, got {self.lambda1}.") + if self.n_iter < 1: + raise ValueError(f"`n_iter` must be >= 1, got {self.n_iter}.") + + +_MACENKO_DEFAULTS = MacenkoParams() +_VAHADANE_DEFAULTS = VahadaneParams() +_MACENKO_FIELDS = frozenset(f.name for f in fields(MacenkoParams)) +_VAHADANE_FIELDS = frozenset(f.name for f in fields(VahadaneParams)) + + +def _resolve_params(params: Any, cls: type, defaults: Any, valid_fields: frozenset[str]) -> Any: + if params is None: + return defaults + if isinstance(params, cls): + return params + if isinstance(params, Mapping): + unknown = set(params) - valid_fields + if unknown: + raise ValueError( + f"Unknown `method_params` field(s): {sorted(unknown)}; expected from {sorted(valid_fields)}." + ) + return cls(**params) + raise TypeError(f"`method_params` must be {cls.__name__}, Mapping, or None; got {type(params).__name__}.") + + +def _resolve_macenko_params(params: MacenkoParams | Mapping[str, Any] | None) -> MacenkoParams: + return _resolve_params(params, MacenkoParams, _MACENKO_DEFAULTS, _MACENKO_FIELDS) + + +def _resolve_vahadane_params(params: VahadaneParams | Mapping[str, Any] | None) -> VahadaneParams: + return _resolve_params(params, VahadaneParams, _VAHADANE_DEFAULTS, _VAHADANE_FIELDS) + + +def _tissue_od( + image_rgb: xr.DataArray, + white_point: np.ndarray, + beta: float, + *, + tissue_mask: np.ndarray | None = None, + image_key: str | None, +) -> np.ndarray: + """Flatten tissue pixels to an ``(N, 3)`` optical-density matrix. + + Reduces over the chosen scale only (bounded); the stain fits need the + tissue pixels resident for SVD/NMF, so this is the one materialising step. 
+ When ``tissue_mask`` (a ``(y, x)`` boolean aligned to ``image_rgb``) is + given it selects the tissue pixels; otherwise the absorbance threshold + ``beta`` is used. + """ + sda = rgb_to_sda(image_rgb, white_point) + mask = as_spatial_mask(tissue_mask, sda) if tissue_mask is not None else foreground_mask_from_sda(sda, beta) + od = np.asarray(sda.where(mask).transpose("c", "y", "x").data.reshape(3, -1)).T + od = od[np.all(np.isfinite(od), axis=1)] + if od.shape[0] == 0: + raise StainFittingError("no tissue pixels for stain fitting; the mask is empty.", image_key=image_key) + # Keep signed OD: pixels brighter than the estimated background carry + # negative absorbance that Macenko's SVD legitimately uses. Only Vahadane's + # NMF requires non-negativity, and clips locally. + return od + + +def _macenko_stain_matrix(od: np.ndarray, alpha: float) -> np.ndarray: + """Recover a ``(3, 2)`` H/E matrix via Macenko's angular-extreme method.""" + # right singular vectors of OD = principal absorbance directions through 0 + _, _, vh = np.linalg.svd(od, full_matrices=False) + plane = vh[:2].T # (3, 2) + # SVD sign is arbitrary; orient the basis into the data so the projected + # angles cluster around 0 instead of straddling the atan2 branch cut at + # +-180 deg (which would collapse the angular percentiles). 
+ signs = np.sign(od.mean(axis=0) @ plane) + signs[signs == 0] = 1.0 + plane = plane * signs + proj = od @ plane # (N, 2) + phi = np.arctan2(proj[:, 1], proj[:, 0]) + lo, hi = np.percentile(phi, [alpha, 100.0 - alpha]) + extremes = np.stack( + [plane @ np.array([np.cos(lo), np.sin(lo)]), plane @ np.array([np.cos(hi), np.sin(hi)])], + axis=1, + ) + return _unit_columns(extremes) + + +def _vahadane_stain_matrix(od: np.ndarray, params: VahadaneParams) -> np.ndarray: + """Recover a ``(3, 2)`` H/E matrix via sparse NMF (Vahadane).""" + from sklearn.decomposition import NMF + + nmf = NMF( + n_components=2, + init="nndsvda", + random_state=params.random_state, + alpha_W=params.lambda1, + l1_ratio=1.0, + max_iter=params.n_iter, + ) + nmf.fit(np.clip(od, 0.0, None)) # NMF requires non-negative absorbance + stains = nmf.components_.T # (3, 2) + if np.any(np.linalg.norm(stains, axis=0) < 1e-8): + raise StainFittingError("Vahadane NMF produced a zero-norm stain vector.") + return _unit_columns(stains) + + +def _stain_matrix( + od: np.ndarray, + method: StainMethod, + params: Any, + *, + image_key: str | None, + reference: dict[str, np.ndarray] = RUIFROK_HE, + max_angle_deg: float = 45.0, +) -> np.ndarray: + """Fit, canonicalise, complete and validate a ``(3, 3)`` stain matrix. + + ``reference`` (the canonical H/E vectors) drives both the column ordering and + the deviation gate; ``max_angle_deg`` is the gate tolerance. 
+ """ + raw = _macenko_stain_matrix(od, params.alpha) if method == "macenko" else _vahadane_stain_matrix(od, params) + matrix = complement_third_column(reorder_to_canonical(raw, reference)) + validate_stain_matrix(matrix, reference=reference, max_angle_deg=max_angle_deg, image_key=image_key) + return matrix + + +def _concentrations(od: np.ndarray, stain_matrix: np.ndarray) -> np.ndarray: + """Per-pixel stain concentrations ``(N, 3)`` from optical density.""" + return od @ np.linalg.pinv(stain_matrix).T + + +def _max_concentrations(concentrations: np.ndarray) -> np.ndarray: + """Robust per-stain (H, E) maximum concentrations ``(2,)`` from an ``(N, 3)`` array.""" + return np.maximum(np.percentile(concentrations[:, :2], _MAXC_PERCENTILE, axis=0), _MAXC_FLOOR) + + +def fit_decomposition( + image_rgb: xr.DataArray, + method: StainMethod, + params: Any, + white_point: np.ndarray, + *, + tissue_mask: np.ndarray | None = None, + image_key: str | None = None, + reference: dict[str, np.ndarray] = RUIFROK_HE, + max_angle_deg: float = 45.0, +) -> StainReference: + """Fit a decomposition :class:`StainReference` (stain matrix + max concentrations).""" + od = _tissue_od(image_rgb, white_point, params.beta, tissue_mask=tissue_mask, image_key=image_key) + matrix = _stain_matrix(od, method, params, image_key=image_key, reference=reference, max_angle_deg=max_angle_deg) + return StainReference( + method=method, + stain_matrix=matrix, + white_point=np.asarray(white_point, dtype=np.float64), + max_concentrations=_max_concentrations(_concentrations(od, matrix)), + ) + + +def _matmul_kernel(x: np.ndarray, *, matrix: np.ndarray, dtype: np.dtype) -> np.ndarray: + return (x.astype(dtype, copy=False) @ matrix.T).astype(dtype, copy=False) + + +def apply_decomposition( + image_rgb: xr.DataArray, + reference: StainReference, + params: Any, + *, + fit_rgb: xr.DataArray | None = None, + tissue_mask: np.ndarray | None = None, + out_dtype: np.dtype | type = np.uint8, +) -> xr.DataArray: + 
"""Normalize a source image to a decomposition reference. + + Fits the *source's* own stain matrix and concentration scale, then maps + source absorbance to reference absorbance via a single ``(3, 3)`` linear + operator so the per-pixel transform stays lazy. + + The source matrix is a colour property, so it is fit on ``fit_rgb`` (a + coarse level) when given, while ``image_rgb`` (which may be full + resolution) is only ever touched by the lazy operator - never + materialised to fit a matrix. + """ + _check_channel_dim(image_rgb) + if reference.max_concentrations is None: + raise ValueError("reference is missing max_concentrations; refit it with fit_stain_reference.") + bg = reference.white_point + + od_src = _tissue_od( + fit_rgb if fit_rgb is not None else image_rgb, bg, params.beta, tissue_mask=tissue_mask, image_key=None + ) + w_src = _stain_matrix(od_src, reference.method, params, image_key=None) + pinv_src = np.linalg.pinv(w_src) # reused for the source concentrations and the operator + maxc_src = _max_concentrations(od_src @ pinv_src.T) + + scale = np.ones(3) + scale[:2] = reference.max_concentrations / maxc_src + operator = reference.stain_matrix @ np.diag(scale) @ pinv_src + + sda = rgb_to_sda(image_rgb, bg) + dtype = _working_dtype(sda) + sda_out = _apply_along_channel(sda, _matmul_kernel, out_dtype=dtype, matrix=operator.astype(dtype), dtype=dtype) + return sda_to_rgb(sda_out, bg, out_dtype=out_dtype) + + +def decompose_to_concentrations( + image_rgb: xr.DataArray, stain_matrix: np.ndarray, white_point: np.ndarray +) -> xr.DataArray: + """Project an image onto a stain matrix, returning a 3-channel concentration image. + + Channels are ``(hematoxylin, eosin, residual)``; the residual is the + concentration along the complement vector and is a diagnostic, not a stain. 
+ """ + _check_channel_dim(image_rgb) + sda = rgb_to_sda(image_rgb, white_point) + dtype = _working_dtype(sda) + pinv = np.linalg.pinv(stain_matrix) + return _apply_along_channel(sda, _matmul_kernel, out_dtype=dtype, matrix=pinv.astype(dtype), dtype=dtype) diff --git a/src/squidpy/experimental/im/_stain/_mask.py b/src/squidpy/experimental/im/_stain/_mask.py index 87b12e960..833e3da18 100644 --- a/src/squidpy/experimental/im/_stain/_mask.py +++ b/src/squidpy/experimental/im/_stain/_mask.py @@ -1,11 +1,10 @@ """Foreground (tissue) masking for stain fitting. Method-agnostic on purpose: Reinhard fits its channel statistics over tissue -pixels only, and the Macenko/Vahadane fits added later need the same kind of -mask. The luminosity variant lives here; the absorbance variant -(``absorbance_foreground_mask``) is added beside it when decomposition lands, -both returning the same ``(y, x)`` boolean contract so downstream statistics -code stays mask-source-agnostic. +pixels only, and the Macenko/Vahadane fits need the same kind of mask. Two +variants live here - luminosity (Reinhard, intensity space) and absorbance +(decomposition, optical-density space) - both returning the same ``(y, x)`` +boolean contract so downstream statistics code stays mask-source-agnostic. """ from __future__ import annotations @@ -20,6 +19,7 @@ from squidpy.experimental.im._stain._conversion import ( _check_channel_dim, rgb_to_lab_ruderman, + rgb_to_sda, ) @@ -74,3 +74,50 @@ def luminosity_foreground_mask(rgb: xr.DataArray, threshold: float) -> xr.DataAr """ _check_channel_dim(rgb) return foreground_mask_from_lab(rgb_to_lab_ruderman(rgb), threshold) + + +def as_spatial_mask(mask: np.ndarray, like: xr.DataArray) -> xr.DataArray: + """Wrap a ``(y, x)`` boolean array as a DataArray aligned to ``like``'s y/x. + + Copies ``like``'s ``y``/``x`` coords (when present) so ``like.where(...)`` + aligns by coordinate rather than silently broadcasting. ``mask`` must match + ``like`` in the spatial dims. 
+ """ + coords = {d: like.coords[d] for d in ("y", "x") if d in like.coords} + return xr.DataArray(np.asarray(mask, dtype=bool), dims=("y", "x"), coords=coords) + + +def foreground_mask_from_sda(sda: xr.DataArray, beta: float = 0.15) -> xr.DataArray: + """Tissue mask from an already-computed optical-density (SDA) image. + + The absorbance-space sibling of :func:`foreground_mask_from_lab`: lets the + decomposition fit derive the mask from the same lazy ``sda`` graph it + already needs for the optical densities, so dask materialises the + RGB->SDA conversion once. ``True`` = tissue (mean absorbance ``> beta``). + """ + return sda.mean(dim="c") > beta + + +def absorbance_foreground_mask(rgb: xr.DataArray, white_point: np.ndarray, beta: float = 0.15) -> xr.DataArray: + """Boolean tissue mask in optical-density (absorbance) space. + + The convention the Macenko/Vahadane fits expect: a pixel is tissue if its + mean absorbance across channels exceeds ``beta``. Near-white background + has near-zero absorbance and is excluded. + + Parameters + ---------- + rgb + Image with a ``"c"`` dimension of length 3. Numpy- or dask-backed. + white_point + Per-channel white point ``I_0`` (shape ``(3,)``), as used by + :func:`~squidpy.experimental.im._stain._conversion.rgb_to_sda`. + beta + Mean-absorbance cutoff. Pixels with mean SDA ``> beta`` are tissue. + + Returns + ------- + Boolean ``(y, x)`` DataArray: ``True`` = tissue. Lazy if ``rgb`` was lazy. + """ + _check_channel_dim(rgb) + return foreground_mask_from_sda(rgb_to_sda(rgb, white_point), beta) diff --git a/src/squidpy/experimental/im/_stain/_normalize.py b/src/squidpy/experimental/im/_stain/_normalize.py index 3178985b6..b1f15950e 100644 --- a/src/squidpy/experimental/im/_stain/_normalize.py +++ b/src/squidpy/experimental/im/_stain/_normalize.py @@ -5,9 +5,9 @@ re-exported publicly. Everything it calls is a pure DataArray-layer primitive (:mod:`._reinhard`, :mod:`._mask`, :mod:`._conversion`). 
-Both entry points dispatch on the fitting ``method``. Only ``"reinhard"`` is -implemented here; ``"macenko"``/``"vahadane"`` raise ``NotImplementedError`` -and are filled in without changing these signatures. +Both entry points dispatch on the fitting ``method`` (``"reinhard"`` colour +transfer, or ``"macenko"``/``"vahadane"`` absorbance decomposition); a third +entry, :func:`decompose_stains`, projects an image onto its stain matrix. """ from __future__ import annotations @@ -15,13 +15,25 @@ from collections.abc import Mapping from typing import Any, Literal +import numpy as np import spatialdata as sd import xarray as xr +from numpy.typing import DTypeLike from spatialdata.models import Image2DModel from spatialdata.transformations import get_transformation from squidpy._utils import _get_scale_factors -from squidpy.experimental.im._stain._conversion import _check_channel_dim +from squidpy.experimental.im._stain._constants import RUIFROK_HE +from squidpy.experimental.im._stain._conversion import _check_channel_dim, cast_to_image_dtype +from squidpy.experimental.im._stain._decomposition import ( + MacenkoParams, + VahadaneParams, + _resolve_macenko_params, + _resolve_vahadane_params, + apply_decomposition, + decompose_to_concentrations, + fit_decomposition, +) from squidpy.experimental.im._stain._reference import StainMethod, StainReference from squidpy.experimental.im._stain._reinhard import ( ReinhardParams, @@ -29,9 +41,24 @@ apply_reinhard, fit_reinhard, ) -from squidpy.experimental.im._utils import get_element_data +from squidpy.experimental.im._stain._white_point import ( + default_white_point, + validate_rgb_range, + white_point_from_background, +) +from squidpy.experimental.im._utils import ( + _choose_label_scale_for_image, + get_element_data, + get_mask_materialized, + resolve_tissue_mask, +) -_DECOMPOSITION_NOT_IMPLEMENTED = "macenko/vahadane decomposition is not yet implemented" +_VALID_METHODS = ("reinhard", "macenko", "vahadane") 
+_DECOMPOSITION_METHODS = ("macenko", "vahadane") +_CONCENTRATION_CHANNELS = ["hematoxylin", "eosin", "residual"] + +# Public union accepted by the method_params argument of the dispatchers. +MethodParams = ReinhardParams | MacenkoParams | VahadaneParams | Mapping[str, Any] | None def _resolve_image( @@ -49,13 +76,150 @@ def _resolve_image( return da +def _resolve_mask_key_and_scale( + sdata: sd.SpatialData, image_key: str, target_da: xr.DataArray, tissue_mask_key: str | None +) -> tuple[str, str, tuple[int, int]]: + """Resolve the (mandatory) tissue-mask key and the label scale closest to ``target_da``. + + Shared by the two mask consumers below. Consumes a + :func:`!detect_tissue` labels element - raises if + none exists. + """ + mask_key = resolve_tissue_mask(sdata, image_key, "auto", tissue_mask_key, auto_create=False) + target_hw = (int(target_da.sizes["y"]), int(target_da.sizes["x"])) + label_scale = _choose_label_scale_for_image(sdata.labels[mask_key], target_hw) + return mask_key, label_scale, target_hw + + +def _resolve_tissue_bool_mask( + sdata: sd.SpatialData, image_key: str, fit_da: xr.DataArray, tissue_mask_key: str | None +) -> np.ndarray: + """Return a materialised ``(y, x)`` boolean tissue mask aligned to ``fit_da``. + + For the (coarse) fit: nearest-resizes to ``fit_da``'s ``(y, x)`` when the + closest label scale differs. The fits run on a coarse level, so the mask + stays small. + """ + mask_key, label_scale, target_hw = _resolve_mask_key_and_scale(sdata, image_key, fit_da, tissue_mask_key) + mask = get_mask_materialized(sdata, mask_key, label_scale) > 0 + if mask.shape != target_hw: + from skimage.transform import resize + + mask = resize(mask, target_hw, order=0, preserve_range=True) > 0.5 + return mask + + +def _resolve_output_tissue_mask( + sdata: sd.SpatialData, image_key: str, target_da: xr.DataArray, tissue_mask_key: str | None +) -> xr.DataArray: + """Return a lazy ``(y, x)`` boolean tissue mask aligned to ``target_da``. 
+ + Like :func:`_resolve_tissue_bool_mask` but kept lazy and at the (full-res) + output resolution, for compositing the original background back into the + normalized image without materialising the full frame. The label pyramid + shares the image's scale factors, so the matching level usually lines up + exactly; only a residual size mismatch forces a (small) eager resize. + """ + mask_key, label_scale, target_hw = _resolve_mask_key_and_scale(sdata, image_key, target_da, tissue_mask_key) + coords = {d: target_da.coords[d] for d in ("y", "x") if d in target_da.coords} + mask = get_element_data(sdata.labels[mask_key], label_scale, "label", mask_key).squeeze() > 0 + if (int(mask.sizes["y"]), int(mask.sizes["x"])) == target_hw: + return mask.assign_coords(coords) + from skimage.transform import resize + + resized = resize(np.asarray(mask.data) > 0, target_hw, order=0, preserve_range=True) > 0.5 + return xr.DataArray(resized, dims=("y", "x"), coords=coords) + + +def _resolve_method_params(method: str, method_params: MethodParams) -> Any: + """Pick the right Params dataclass for ``method`` and resolve a mapping/instance/None.""" + if method == "reinhard": + return _resolve_reinhard_params(method_params) + if method == "macenko": + return _resolve_macenko_params(method_params) + if method == "vahadane": + return _resolve_vahadane_params(method_params) + raise ValueError(f"Unknown method {method!r}; expected one of {list(_VALID_METHODS)}.") + + +def _write_image( + sdata: sd.SpatialData, + source_node: Any, + image_key_added: str, + data_array: xr.DataArray, + *, + c_coords: list[Any] | None = None, +) -> None: + """Write a derived image element, preserving the source's transforms/pyramid. + + Reconstructs the element from the bare array (a derived DataArray would + carry the source's ``transform`` attr and collide with the transformations + we pass) plus the dims/channel-coords/transforms to preserve. The same + idiom as detect_tissue. 
``_get_scale_factors`` returns ``[]`` for a + single-scale source; parse needs ``None`` there (an empty list builds a + degenerate single-level pyramid). + """ + if image_key_added in sdata.images: + raise ValueError(f"image_key_added={image_key_added!r} already exists in sdata.images.") + if c_coords is None: + c_coords = data_array.coords["c"].values.tolist() if "c" in data_array.coords else None + sdata.images[image_key_added] = Image2DModel.parse( + data_array.data, + dims=data_array.dims, + c_coords=c_coords, + transformations=get_transformation(source_node, get_all=True), + scale_factors=_get_scale_factors(source_node) or None, + ) + + +def estimate_white_point( + sdata: sd.SpatialData, + image_key: str, + *, + tissue_mask_key: str | None = None, + scale: str | Literal["auto"] = "auto", +) -> np.ndarray: + """Estimate the white point ``I_0`` from a slide's background (non-tissue median). + + Opt-in alternative to the fixed dtype-aware default white point, for a slide + whose unstained background is genuinely not full white. Samples the + per-channel median over **non-tissue** pixels (background = the complement of + the :func:`!detect_tissue` mask). + + Parameters + ---------- + sdata, image_key + The SpatialData object and the RGB image key. + tissue_mask_key + Tissue-label element key (defaults to ``f"{image_key}_tissue"``); a + tissue mask is required, as for :func:`fit_stain_reference`. + scale + Scale level to sample on. ``"auto"`` (default) uses the coarsest level. + The sampled level is materialised to take the median, so keep this + coarse - do not pass a fine level on a whole-slide image. + + Returns + ------- + Shape-``(3,)`` white point; pass it as ``white_point`` to + :func:`fit_stain_reference` / :func:`decompose_stains`. 
+ """ + da = _resolve_image(sdata, image_key, scale, prefer="coarsest") + validate_rgb_range(da) + tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key) + return white_point_from_background(da, ~tissue_mask) + + def fit_stain_reference( sdata: sd.SpatialData, image_key: str, *, - method: StainMethod = "reinhard", + method: StainMethod = "macenko", scale: str | Literal["auto"] = "auto", - method_params: ReinhardParams | Mapping[str, Any] | None = None, + method_params: MethodParams = None, + white_point: np.ndarray | None = None, + tissue_mask_key: str | None = None, + max_angle_deg: float = 45.0, + canonical_reference: Mapping[str, np.ndarray] | None = None, ) -> StainReference: """Fit a stain reference from an image in a :class:`~spatialdata.SpatialData` object. @@ -66,35 +230,80 @@ def fit_stain_reference( image_key Key of the RGB image in ``sdata.images`` to fit on. method - Fitting method. Only ``"reinhard"`` is implemented; ``"macenko"`` and - ``"vahadane"`` raise :class:`NotImplementedError`. + Fitting method: ``"macenko"`` (default) or ``"vahadane"`` (physical + stain-matrix decomposition, usable by both :func:`normalize_stains` and + :func:`decompose_stains`), or ``"reinhard"`` (faster statistical colour + transfer, no stain separation). Macenko is the default because its one + documented weakness - artifact pixels contaminating the fit - is removed + by the mandatory tissue mask. scale Scale level to fit on. ``"auto"`` (default) uses the coarsest level, which is cheap and sufficient for colour statistics. method_params - :class:`ReinhardParams` instance, a mapping of its fields, or ``None`` - for defaults. + A :class:`ReinhardParams`/:class:`MacenkoParams`/:class:`VahadaneParams` + instance, a mapping of its fields, or ``None`` for defaults. Must match + ``method``. + white_point + Per-channel white point ``I_0`` ``(3,)`` for the decomposition methods. 
+ If ``None``, a dtype-aware full-white point is used (``255`` for uint8, + ``65535`` for uint16, ``1.0`` for float; the HistomicsTK/Macenko + convention), so unstained pixels round-trip to white. Pass the result of + :func:`estimate_white_point` only for slides with a known non-white + background. Ignored by Reinhard. + tissue_mask_key + Key of a tissue-label element in ``sdata.labels`` (as produced by + :func:`!detect_tissue`) restricting the fit to + tissue pixels. If ``None``, ``f"{image_key}_tissue"`` is used. A tissue + mask is **required**: if neither exists, a :class:`KeyError` asks you to + run :func:`!detect_tissue` first. + max_angle_deg + Tolerance of the H/E sanity gate for the decomposition methods: the fit + raises :class:`!StainFittingError` if either recovered stain vector + deviates more than this many degrees from its canonical reference. + Default ``45``. Ignored by Reinhard. + canonical_reference + Canonical H/E reference for the decomposition methods, a mapping with + ``"hematoxylin"`` and ``"eosin"`` keys to ``(3,)`` RGB optical-density + unit vectors. Drives both the H/E column ordering and the deviation + gate. If ``None``, the Ruifrok H&E vectors are used. Ignored by Reinhard. Returns ------- The fitted :class:`StainReference`. Nothing is written to ``sdata``. 
""" + if method not in _VALID_METHODS: + raise ValueError(f"Unknown method {method!r}; expected one of {list(_VALID_METHODS)}.") da = _resolve_image(sdata, image_key, scale, prefer="coarsest") + validate_rgb_range(da) + params = _resolve_method_params(method, method_params) + tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, da, tissue_mask_key) if method == "reinhard": - return fit_reinhard(da, _resolve_reinhard_params(method_params)) - if method in {"macenko", "vahadane"}: - raise NotImplementedError(_DECOMPOSITION_NOT_IMPLEMENTED) - raise ValueError(f"Unknown method {method!r}; expected one of ['macenko', 'reinhard', 'vahadane'].") + return fit_reinhard(da, params, tissue_mask=tissue_mask) + bg = default_white_point(da) if white_point is None else np.asarray(white_point, np.float64) + reference = RUIFROK_HE if canonical_reference is None else dict(canonical_reference) + return fit_decomposition( + da, + method, + params, + bg, + tissue_mask=tissue_mask, + image_key=image_key, + reference=reference, + max_angle_deg=max_angle_deg, + ) -def apply_stain_normalization( +def normalize_stains( sdata: sd.SpatialData, image_key: str, reference: StainReference, *, scale: str | Literal["auto"] = "auto", - method_params: ReinhardParams | Mapping[str, Any] | None = None, + method_params: MethodParams = None, image_key_added: str | None = None, + inplace: bool = True, + output_dtype: DTypeLike | None = None, + tissue_mask_key: str | None = None, + preserve_background: bool = True, ) -> xr.DataArray | None: """Normalize an image to a fitted stain reference. @@ -112,47 +321,167 @@ def apply_stain_normalization( so the result is not downsampled; source statistics are reduced lazily so memory stays bounded. method_params - :class:`ReinhardParams` instance, a mapping of its fields, or ``None`` - for defaults. + Params matching ``reference.method`` (instance, mapping, or ``None``). 
image_key_added - If ``None`` (default), return the lazy normalized DataArray and leave - ``sdata`` untouched. If given, write the result to - ``sdata.images[image_key_added]`` (rebuilding the pyramid for - multiscale sources, preserving transforms) and return ``None``. - Raises if the key already exists. + Key for the written image when ``inplace=True``. If ``None`` (default), + ``f"{image_key}_normalized"`` is used. Ignored when ``inplace=False``. + inplace + If ``True`` (default), write the normalized image to + ``sdata.images[image_key_added]`` (rebuilding the pyramid for multiscale + sources, preserving transforms) and return ``None``; raises if the key + already exists. If ``False``, leave ``sdata`` untouched and return the + lazy normalized :class:`~xarray.DataArray`. + output_dtype + Dtype of the result. If ``None`` (default), the source image's dtype is + used. The reconstruction is clipped to that dtype's valid range and + rounded (for integer dtypes) at the write boundary. + tissue_mask_key + Key of a tissue-label element in ``sdata.labels`` restricting the + *source* statistics to tissue pixels. As for + :func:`fit_stain_reference`, a tissue mask is required (defaults to + ``f"{image_key}_tissue"``; raises if missing). + preserve_background + If ``True`` (default), non-tissue (background) pixels are passed through + unchanged from the source image, so the normalization recolours only + tissue. The colour map is a global linear transform that would otherwise + tint background/white pixels. Set ``False`` for full-frame normalization. Returns ------- - The lazy normalized :class:`xarray.DataArray` if ``image_key_added`` is - ``None``, otherwise ``None``. + ``None`` if ``inplace=True`` (the image is written), otherwise the lazy + normalized :class:`xarray.DataArray`. 
""" da = _resolve_image(sdata, image_key, scale, prefer="finest") + target_key = image_key_added if image_key_added is not None else f"{image_key}_normalized" + if inplace and target_key in sdata.images: + raise ValueError(f"image_key_added={target_key!r} already exists in sdata.images.") + params = _resolve_method_params(reference.method, method_params) + # Source statistics (Reinhard mu/sigma or the decomposition source matrix) + # are reduced on a coarse level with a tissue mask; the lazy transform is + # then applied to the full-resolution `da`. + fit_rgb = _resolve_image(sdata, image_key, scale, prefer="coarsest") + validate_rgb_range(fit_rgb) # reject mis-typed source (e.g. 0-255 float) before the dtype-clipped reconstruction + tissue_mask = _resolve_tissue_bool_mask(sdata, image_key, fit_rgb, tissue_mask_key) + out_dtype = da.dtype if output_dtype is None else np.dtype(output_dtype) # clip range + final cast if reference.method == "reinhard": - normalized = apply_reinhard(da, reference, _resolve_reinhard_params(method_params)) - elif reference.method in {"macenko", "vahadane"}: - raise NotImplementedError(_DECOMPOSITION_NOT_IMPLEMENTED) - else: # pragma: no cover - StainReference validates method on construction - raise ValueError(f"Unknown reference method {reference.method!r}.") + normalized = apply_reinhard( + da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask, out_dtype=out_dtype + ) + else: + normalized = apply_decomposition( + da, reference, params, fit_rgb=fit_rgb, tissue_mask=tissue_mask, out_dtype=out_dtype + ) + + if preserve_background: + # Keep non-tissue pixels byte-identical to the source: the global colour + # map would otherwise recolour background/white pixels (HistomicsTK's + # `mask_out`). Stays lazy - the mask aligns to `da` without materialising. 
+ keep = _resolve_output_tissue_mask(sdata, image_key, da, tissue_mask_key) + normalized = normalized.where(keep, da) + + # Deferred cast at the write boundary: the reconstruction was kept in float + # (clipped to `out_dtype`'s range); round + cast here so the stored image is + # the requested dtype and integer background stays byte-identical. + normalized = cast_to_image_dtype(normalized, out_dtype) - if image_key_added is None: + if not inplace: return normalized - if image_key_added in sdata.images: - raise ValueError(f"image_key_added={image_key_added!r} already exists in sdata.images.") + _write_image(sdata, sdata.images[image_key], target_key, normalized) + return None - node = sdata.images[image_key] - # Reconstruct the element explicitly from the underlying array: parse a - # DataArray would carry over the source's transform attr and collide with - # the transformations we pass, so we hand it the bare array plus the - # dims/channel coords/transforms we want to preserve (the same idiom as - # detect_tissue). `_get_scale_factors` returns [] for a single-scale - # source; parse needs None there (an empty list builds a degenerate - # single-level pyramid). 
- c_coords = normalized.coords["c"].values.tolist() if "c" in normalized.coords else None - sdata.images[image_key_added] = Image2DModel.parse( - normalized.data, - dims=normalized.dims, - c_coords=c_coords, - transformations=get_transformation(node, get_all=True), - scale_factors=_get_scale_factors(node) or None, - ) + +def decompose_stains( + sdata: sd.SpatialData, + image_key: str, + reference_or_method: StainReference | Literal["macenko", "vahadane"], + *, + scale: str | Literal["auto"] = "auto", + method_params: MethodParams = None, + white_point: np.ndarray | None = None, + image_key_added: str | None = None, + inplace: bool = True, + output_dtype: DTypeLike = np.float16, + tissue_mask_key: str | None = None, + include_residual: bool = True, +) -> dict[str, xr.DataArray] | None: + """Decompose an image into separate per-stain concentration maps. + + Parameters + ---------- + sdata, image_key + The SpatialData object and the RGB image key to decompose. + reference_or_method + Either a decomposition :class:`StainReference` (its stain matrix and + white point are used) or a method name (``"macenko"``/``"vahadane"``) + to fit on this image first. The reference is the provenance record of + how the maps were produced (method, stain matrix, white point). + scale, method_params, white_point, tissue_mask_key + As for :func:`fit_stain_reference` (only used when a method name is + given; a reference is projected as-is and needs no tissue mask). + image_key_added + Key *prefix* for the written images when ``inplace=True``. If ``None`` + (default), ``image_key`` is used, so each stain is written as its own + single-channel image ``sdata.images[f"{image_key}_{stain}"]`` (e.g. + ``f"{image_key}_hematoxylin"``). Ignored when ``inplace=False``. + inplace + If ``True`` (default), write each stain as a separate single-channel + image under the ``image_key_added`` prefix and return ``None``; the + write is atomic (all target keys are validated free before any is + written). 
If ``False``, leave ``sdata`` untouched and return the maps + as a dict. + output_dtype + Dtype of the concentration maps. Defaults to ``float16`` (half the + storage; ~3 significant figures, adequate for concentrations); pass + ``float32`` for strict quantification. + include_residual + If ``True`` (default), also produce the ``"residual"`` map. The residual + is the absorbance along the complement direction - a diagnostic of + decomposition quality (extra chromogen, artifacts, or a poor fit), not a + biological stain. Set ``False`` to keep only ``hematoxylin``/``eosin``. + + Returns + ------- + ``None`` if ``inplace=True`` (the maps are written as separate images), + otherwise a ``dict`` mapping each stain name to its ``(y, x)`` concentration + :class:`~xarray.DataArray` (``"hematoxylin"``, ``"eosin"``, and + ``"residual"`` unless dropped). + """ + da = _resolve_image(sdata, image_key, scale, prefer="finest") + if isinstance(reference_or_method, StainReference): + reference = reference_or_method + if reference.method not in _DECOMPOSITION_METHODS or reference.stain_matrix is None: + raise ValueError("decompose_stains requires a macenko/vahadane reference with a stain matrix.") + stain_matrix, bg = reference.stain_matrix, reference.white_point + else: + if reference_or_method not in _DECOMPOSITION_METHODS: + raise ValueError(f"method must be one of {list(_DECOMPOSITION_METHODS)}; got {reference_or_method!r}.") + reference = fit_stain_reference( + sdata, + image_key, + method=reference_or_method, + scale=scale, + method_params=method_params, + white_point=white_point, + tissue_mask_key=tissue_mask_key, + ) + stain_matrix, bg = reference.stain_matrix, reference.white_point + + names = ["hematoxylin", "eosin"] + (["residual"] if include_residual else []) + prefix = image_key_added if image_key_added is not None else image_key + target_keys = [f"{prefix}_{name}" for name in names] + if inplace: # validate all keys free up front, so a partial write can't leave a 
half-decomposed sdata + clashes = [k for k in target_keys if k in sdata.images] + if clashes: + raise ValueError(f"decompose_stains would overwrite existing image(s): {clashes}.") + + concentrations = decompose_to_concentrations(da, stain_matrix, bg).assign_coords(c=_CONCENTRATION_CHANNELS) + concentrations = concentrations.astype(np.dtype(output_dtype)) + + if not inplace: + return {name: concentrations.sel(c=name) for name in names} + + source = sdata.images[image_key] + for name, key in zip(names, target_keys, strict=True): + # keep the c dim (length 1) so Image2DModel.parse accepts it + _write_image(sdata, source, key, concentrations.sel(c=[name]), c_coords=[name]) return None diff --git a/src/squidpy/experimental/im/_stain/_reference.py b/src/squidpy/experimental/im/_stain/_reference.py index 4b1a7fa06..bf4e464b4 100644 --- a/src/squidpy/experimental/im/_stain/_reference.py +++ b/src/squidpy/experimental/im/_stain/_reference.py @@ -27,7 +27,7 @@ def _coerce_finite(arr: Any, *, shape: tuple[int, ...], name: str) -> np.ndarray return out -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class StainReference: """Container for a fitted stain reference. @@ -43,19 +43,45 @@ class StainReference: sigma Shape ``(3,)`` Ruderman Lab channel standard deviations. Reinhard only. - background_intensity + white_point Shape ``(3,)`` per-channel white-point estimate. Required for decomposition methods (apply consumes it). Forbidden for Reinhard because Reinhard's color transfer operates in Ruderman Lab and does not model absorbance. There is no universal default; pass an - estimate from your data (PR 3 ships the estimator). + estimate from your data (see ``estimate_white_point``). + max_concentrations + Shape ``(2,)`` reference per-stain (H, E) maximum concentrations. + Decomposition only. Stored because at apply time the reference image + is gone, so the target concentration scale must travel with the + reference. 
Optional (Reinhard references and externally-built + decomposition references without it remain valid); forbidden for + Reinhard. """ method: StainMethod stain_matrix: np.ndarray | None = None mu: np.ndarray | None = None sigma: np.ndarray | None = None - background_intensity: np.ndarray | None = None + white_point: np.ndarray | None = None + max_concentrations: np.ndarray | None = None + + def __eq__(self, other: object) -> bool: + # The numpy-array fields make the dataclass-generated __eq__ raise + # ("truth value of an array is ambiguous"), so compare explicitly: + # equal method plus element-wise-equal arrays. + if not isinstance(other, StainReference): + return NotImplemented + if self.method != other.method: + return False + return all( + np.array_equal(getattr(self, name), getattr(other, name)) + for name in ("stain_matrix", "mu", "sigma", "white_point", "max_concentrations") + ) + + # eq=False keeps the default identity-based __hash__ (the array fields are + # unhashable, so a value-based hash is impossible); references remain usable + # as set members / dict keys by identity. 
+ __hash__ = object.__hash__ def __post_init__(self) -> None: if self.method not in _VALID_METHODS: @@ -66,27 +92,34 @@ def __post_init__(self) -> None: raise ValueError(f"method={self.method!r} requires stain_matrix.") if self.mu is not None or self.sigma is not None: raise ValueError(f"method={self.method!r} forbids mu/sigma; pass them only for Reinhard.") - if self.background_intensity is None: - raise ValueError(f"method={self.method!r} requires background_intensity.") + if self.white_point is None: + raise ValueError(f"method={self.method!r} requires white_point.") object.__setattr__( self, "stain_matrix", _coerce_finite(self.stain_matrix, shape=(3, 3), name="stain_matrix"), ) - bg = _coerce_finite(self.background_intensity, shape=(3,), name="background_intensity") + bg = _coerce_finite(self.white_point, shape=(3,), name="white_point") if np.any(bg <= 0): - raise ValueError("background_intensity must be strictly positive.") - object.__setattr__(self, "background_intensity", bg) + raise ValueError("white_point must be strictly positive.") + object.__setattr__(self, "white_point", bg) + if self.max_concentrations is not None: + maxc = _coerce_finite(self.max_concentrations, shape=(2,), name="max_concentrations") + if np.any(maxc <= 0): + raise ValueError("max_concentrations must be strictly positive.") + object.__setattr__(self, "max_concentrations", maxc) else: if self.mu is None or self.sigma is None: raise ValueError("method='reinhard' requires both mu and sigma.") if self.stain_matrix is not None: raise ValueError("method='reinhard' forbids stain_matrix.") - if self.background_intensity is not None: + if self.white_point is not None: raise ValueError( - "method='reinhard' forbids background_intensity; Reinhard's color " + "method='reinhard' forbids white_point; Reinhard's color " "transfer is in Ruderman Lab and does not use a white point." 
) + if self.max_concentrations is not None: + raise ValueError("method='reinhard' forbids max_concentrations.") mu = _coerce_finite(self.mu, shape=(3,), name="mu") sigma = _coerce_finite(self.sigma, shape=(3,), name="sigma") if np.any(sigma <= 0): diff --git a/src/squidpy/experimental/im/_stain/_reinhard.py b/src/squidpy/experimental/im/_stain/_reinhard.py index a3f425c4a..5c0021c5d 100644 --- a/src/squidpy/experimental/im/_stain/_reinhard.py +++ b/src/squidpy/experimental/im/_stain/_reinhard.py @@ -22,7 +22,7 @@ lab_ruderman_to_rgb, rgb_to_lab_ruderman, ) -from squidpy.experimental.im._stain._mask import foreground_mask_from_lab +from squidpy.experimental.im._stain._mask import as_spatial_mask, foreground_mask_from_lab from squidpy.experimental.im._stain._reference import StainReference # Numerical safeguard against divide-by-zero on flat (constant-colour) @@ -112,34 +112,59 @@ def _transfer_kernel( return ((x - mu_src) / sigma_src * sigma_ref + mu_ref).astype(dtype, copy=False) -def fit_reinhard(image_rgb: xr.DataArray, params: ReinhardParams) -> StainReference: +def _reinhard_mask(lab: xr.DataArray, params: ReinhardParams, tissue_mask: np.ndarray | None) -> xr.DataArray | None: + """Resolve the tissue mask for the Reinhard stats: external mask wins, else + the param-driven luminosity mask (or ``None`` for vanilla Reinhard).""" + if tissue_mask is not None: + return as_spatial_mask(tissue_mask, lab) + if params.mask_background: + return foreground_mask_from_lab(lab, params.luminosity_threshold) + return None + + +def fit_reinhard( + image_rgb: xr.DataArray, params: ReinhardParams, *, tissue_mask: np.ndarray | None = None +) -> StainReference: """Fit Reinhard channel statistics on a reference image. Converts to Ruderman Lab, computes per-channel ``mu``/``sigma`` over - tissue pixels (or all pixels when ``mask_background=False``), and packs - them into a ``StainReference(method="reinhard")``. 
+ tissue pixels, and packs them into a ``StainReference(method="reinhard")``. + ``tissue_mask`` (a ``(y, x)`` boolean aligned to ``image_rgb``) selects the + tissue pixels when given; otherwise the ``mask_background`` / + ``luminosity_threshold`` params drive the mask. """ _check_channel_dim(image_rgb) lab = rgb_to_lab_ruderman(image_rgb) - mask = foreground_mask_from_lab(lab, params.luminosity_threshold) if params.mask_background else None - mu, sigma = _masked_channel_stats(lab, mask) + mu, sigma = _masked_channel_stats(lab, _reinhard_mask(lab, params, tissue_mask)) return StainReference(method="reinhard", mu=mu, sigma=sigma) -def apply_reinhard(image_rgb: xr.DataArray, reference: StainReference, params: ReinhardParams) -> xr.DataArray: +def apply_reinhard( + image_rgb: xr.DataArray, + reference: StainReference, + params: ReinhardParams, + *, + fit_rgb: xr.DataArray | None = None, + tissue_mask: np.ndarray | None = None, + out_dtype: np.dtype | type = np.uint8, +) -> xr.DataArray: """Apply a Reinhard reference to a source image. Standardises by the source's own tissue statistics, rescales to the reference statistics, and converts back to RGB. The transform is applied - to every pixel (the map is global); only the statistics that define it - are tissue-only. Lazy if and only if the input is lazy. + to every pixel of ``image_rgb`` (the map is global); the defining + statistics are reduced on ``fit_rgb`` (a coarse level) when given, so the + full-resolution image is never materialised to compute them. + ``tissue_mask`` (aligned to ``fit_rgb``) selects the source tissue pixels. + Lazy if and only if ``image_rgb`` is lazy. 
""" _check_channel_dim(image_rgb) - lab = rgb_to_lab_ruderman(image_rgb) - mask = foreground_mask_from_lab(lab, params.luminosity_threshold) if params.mask_background else None - mu_src, sigma_src = _masked_channel_stats(lab, mask) + fit_lab = rgb_to_lab_ruderman(fit_rgb if fit_rgb is not None else image_rgb) + mu_src, sigma_src = _masked_channel_stats(fit_lab, _reinhard_mask(fit_lab, params, tissue_mask)) sigma_src = np.maximum(sigma_src, _SIGMA_FLOOR) + lab = rgb_to_lab_ruderman(image_rgb) + dtype = _working_dtype(lab) lab_out = _apply_along_channel( lab, @@ -151,4 +176,4 @@ def apply_reinhard(image_rgb: xr.DataArray, reference: StainReference, params: R sigma_ref=np.asarray(reference.sigma, dtype=dtype), dtype=dtype, ) - return lab_ruderman_to_rgb(lab_out) + return lab_ruderman_to_rgb(lab_out, out_dtype=out_dtype) diff --git a/src/squidpy/experimental/im/_stain/_validation.py b/src/squidpy/experimental/im/_stain/_validation.py new file mode 100644 index 000000000..3c2cd8554 --- /dev/null +++ b/src/squidpy/experimental/im/_stain/_validation.py @@ -0,0 +1,124 @@ +"""Stain-matrix validation and canonicalisation primitives. + +Pure numpy, no ``sdata``, no public export. Shared by the Macenko and +Vahadane fits so both produce a canonical ``(H, E, complement)`` matrix that +downstream apply/decompose code can treat method-agnostically. +""" + +from __future__ import annotations + +import numpy as np + +from squidpy.experimental.im._stain._constants import RUIFROK_HE + + +class StainFittingError(RuntimeError): + """A stain-matrix fit produced an invalid or degenerate result. + + Carries ``image_key`` so cohort fitting (a later PR) can attribute a + failure to a specific slide and skip or flag it by name. 
+ """ + + def __init__(self, reason: str, *, image_key: str | None = None) -> None: + self.reason = reason + self.image_key = image_key + prefix = f"[{image_key}] " if image_key is not None else "" + super().__init__(f"{prefix}{reason}") + + +def _canonical_he(reference: dict[str, np.ndarray]) -> np.ndarray: + """Stack the reference H and E unit vectors as columns of a ``(3, 2)``.""" + return np.stack([reference["hematoxylin"], reference["eosin"]], axis=1) + + +def angle_between_deg(u: np.ndarray, v: np.ndarray) -> float: + """Unsigned angle in degrees between two vectors (sign-agnostic).""" + cos = abs(float(u @ v)) / (np.linalg.norm(u) * np.linalg.norm(v)) + return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))) + + +def _unit_columns(matrix: np.ndarray) -> np.ndarray: + """Scale each column of ``matrix`` to unit L2 norm.""" + return matrix / np.linalg.norm(matrix, axis=0, keepdims=True) + + +def reorder_to_canonical(matrix: np.ndarray, reference: dict[str, np.ndarray] = RUIFROK_HE) -> np.ndarray: + """Order a ``(3, 2)`` stain matrix to ``(H, E)`` and fix column signs. + + Macenko's SVD and Vahadane's NMF recover the two stain directions in an + arbitrary order and sign. We assign each recovered column to whichever of + the canonical Ruifrok H/E vectors it is most colinear with, then flip its + sign so it points the same way as that reference (absorbance is positive). 
+ """ + w = np.asarray(matrix, dtype=np.float64) + if w.shape != (3, 2): + raise ValueError(f"stain matrix to reorder must have shape (3, 2); got {w.shape}.") + canonical = _canonical_he(reference) # (3, 2): [H, E] + cols = _unit_columns(w) + + # cosine of each recovered column against each canonical vector + sim = cols.T @ canonical # (2 recovered, 2 canonical) + # assign recovered column 0/1 to H if it favours H more than column 1 does + h_idx = int(np.argmax(np.abs(sim[:, 0]))) + e_idx = 1 - h_idx + ordered = np.stack([w[:, h_idx], w[:, e_idx]], axis=1) + + # flip signs so each column points along its canonical reference + for j in range(2): + if ordered[:, j] @ canonical[:, j] < 0: + ordered[:, j] = -ordered[:, j] + return ordered + + +def complement_third_column(matrix: np.ndarray) -> np.ndarray: + """Extend a ``(3, 2)`` H/E matrix to ``(3, 3)`` with a complement column. + + The third column is the unit cross product of the H and E columns: the + residual direction orthogonal to both, used to capture absorbance not + explained by either stain. + """ + w = np.asarray(matrix, dtype=np.float64) + if w.shape != (3, 2): + raise ValueError(f"stain matrix to complement must have shape (3, 2); got {w.shape}.") + third = np.cross(w[:, 0], w[:, 1]) + norm = np.linalg.norm(third) + if norm < 1e-8: + raise StainFittingError("H and E stain vectors are colinear; cannot form a complement.") + third = third / norm + return np.column_stack([w, third]) + + +def validate_stain_matrix( + matrix: np.ndarray, + *, + reference: dict[str, np.ndarray] = RUIFROK_HE, + max_angle_deg: float = 45.0, + image_key: str | None = None, +) -> None: + """Raise :class:`StainFittingError` if a ``(3, 3)`` matrix is implausible. + + Guards against the failure modes of an unsupervised stain fit: a column + collapsed to zero, a rank-deficient (single-stain) matrix, or an H/E + direction rotated far from its Ruifrok canonical (a sign the fit latched + onto noise or a non-H&E chromogen). 
+ """ + w = np.asarray(matrix, dtype=np.float64) + if w.shape != (3, 3): + raise StainFittingError(f"stain matrix must have shape (3, 3); got {w.shape}.", image_key=image_key) + if not np.all(np.isfinite(w)): + raise StainFittingError("stain matrix contains non-finite values.", image_key=image_key) + + norms = np.linalg.norm(w, axis=0) + if np.any(norms < 1e-8): + raise StainFittingError("stain matrix has a zero-norm column.", image_key=image_key) + if np.linalg.matrix_rank(w, tol=1e-6) < 3: + raise StainFittingError("stain matrix is rank-deficient (stains are not separable).", image_key=image_key) + + canonical = _canonical_he(reference) + for name, j in (("hematoxylin", 0), ("eosin", 1)): + angle = angle_between_deg(w[:, j], canonical[:, j]) + if angle > max_angle_deg: + raise StainFittingError( + f"{name} stain vector deviates {angle:.1f} deg from its canonical (max {max_angle_deg}).", + image_key=image_key, + ) diff --git a/src/squidpy/experimental/im/_stain/_white_point.py b/src/squidpy/experimental/im/_stain/_white_point.py new file mode 100644 index 000000000..5fbfd031f --- /dev/null +++ b/src/squidpy/experimental/im/_stain/_white_point.py @@ -0,0 +1,80 @@ +"""White-point (``I_0``) handling for the absorbance methods. + +The decomposition methods measure absorbance against a per-channel white point +``I_0`` (the intensity that counts as fully unstained). The default is a fixed +full-white reference at the image's bit depth (255 for uint8, 65535 for uint16, +1.0 for float), matching HistomicsTK (255/256) and the Macenko literature (240): +``I_0`` must be at least as bright as the slide background, otherwise unstained +pixels get a non-zero absorbance and cannot round-trip back to white. Use +:func:`estimate_white_point` only for a slide with a genuinely non-white +background you want to anchor to. 
+""" + +from __future__ import annotations + +import numpy as np +import xarray as xr + +from squidpy.experimental.im._stain._conversion import _check_channel_dim, dtype_max +from squidpy.experimental.im._stain._validation import StainFittingError + + +def default_white_point(rgb: xr.DataArray) -> np.ndarray: + """Dtype-aware default white point ``I_0`` (full white) - pure, no validation. + + Returns ``(3,)`` filled with the dtype's full-white value (255 / 65535 / 1.0). + Call :func:`validate_rgb_range` separately to reject mis-typed data. + """ + return np.full(3, dtype_max(rgb.dtype), dtype=np.float64) + + +def validate_rgb_range(rgb: xr.DataArray) -> None: + """Raise if the image's values clearly don't match its dtype's range. + + Guards every absorbance entry point (fit / apply / estimate) against silently + mis-scaling or clipping: 8-bit values in a wider integer container, or 0-255 + values stored as float. The escape is to pass an explicit ``white_point``. + """ + m = dtype_max(rgb.dtype) + data_max = float(np.asarray(rgb.max())) + if np.issubdtype(rgb.dtype, np.integer): + if m >= 256 and data_max <= 255: + raise ValueError( + f"{rgb.dtype} image but the maximum value is {data_max:.0f} (<= 255) - this looks like " + f"8-bit data stored in a {rgb.dtype} container. Convert to uint8, or pass `white_point`." + ) + elif data_max > 1.5: + raise ValueError( + f"float image but the maximum value is {data_max:.1f} (> 1) - this looks like 0-255 data " + "stored as float. Rescale to [0, 1] or store as uint8." + ) + + +def white_point_from_background(rgb: xr.DataArray, background_mask: np.ndarray) -> np.ndarray: + """Per-channel median intensity over background pixels -> ``(3,)`` white point. + + ``background_mask`` is a ``(y, x)`` boolean, ``True`` over non-tissue + (background) pixels. 
Sampling the *median* of true background (rather than a + whole-image percentile) anchors ``I_0`` to the actual unstained intensity, + matching HistomicsTK's ``background_intensity`` semantics. + + Raises + ------ + StainFittingError + If the mask selects no background pixels, or the median is non-positive + (e.g. a black background). + """ + _check_channel_dim(rgb) + flat = np.asarray(rgb.transpose("c", "y", "x").data, dtype=np.float64) # (3, y, x) + bg_pixels = flat[:, np.asarray(background_mask, dtype=bool)] # (3, N_background) + if bg_pixels.shape[1] == 0: + raise StainFittingError( + "no background pixels to estimate the white point; the tissue mask covers the whole image. " + "Pass an explicit `white_point`." + ) + wp = np.median(bg_pixels, axis=1) + if np.any(wp <= 0): + raise StainFittingError( + "estimated white point is non-positive; the background may be black. Pass an explicit `white_point`." + ) + return wp diff --git a/src/squidpy/experimental/im/_utils.py b/src/squidpy/experimental/im/_utils.py index b58dea1bf..a18770741 100644 --- a/src/squidpy/experimental/im/_utils.py +++ b/src/squidpy/experimental/im/_utils.py @@ -9,9 +9,11 @@ from shapely import box from spatialdata import SpatialData from spatialdata._logging import logger -from spatialdata.models import ShapesModel +from spatialdata.models import Labels2DModel, ShapesModel from spatialdata.transformations import get_transformation, set_transformation +from squidpy._utils import _yx_from_shape + class TileGrid: """Immutable tile grid definition with cached bounds and centroids.""" @@ -252,22 +254,45 @@ def get_mask_materialized(sdata: SpatialData, mask_key: str, scale: str) -> np.n return np.asarray(arr.compute()) +def _choose_label_scale_for_image(label_node: Labels2DModel, target_hw: tuple[int, int]) -> str: + """Pick the label scale closest to the target image height/width.""" + if not hasattr(label_node, "keys"): + return "scale0" # single-scale labels default to their only scale + 
target_h, target_w = target_hw + best = None + best_diff = float("inf") + for k in label_node.keys(): + y, x = _yx_from_shape(label_node[k].image.shape) + diff = abs(y - target_h) + abs(x - target_w) + if diff == 0: + return k + if diff < best_diff: + best_diff = diff + best = k + return best or "scale0" + + def resolve_tissue_mask( sdata: SpatialData, image_key: str, scale: str, tissue_mask_key: str | None = None, + *, + auto_create: bool = True, ) -> str: """Return the key of a tissue mask in ``sdata.labels``, creating one if needed. If *tissue_mask_key* is given and exists, it is returned as-is. - Otherwise falls back to ``f"{image_key}_tissue"``, running ``detect_tissue`` - to create it when missing. + Otherwise falls back to ``f"{image_key}_tissue"``. When that key is missing, + the behaviour depends on *auto_create*: if ``True`` (default) ``detect_tissue`` + is run to create it; if ``False`` a :class:`KeyError` is raised asking the + caller to run ``detect_tissue`` first. Raises ------ KeyError - If *tissue_mask_key* is given but not found in ``sdata.labels``. + If *tissue_mask_key* is given but not found in ``sdata.labels``, or if + *auto_create* is ``False`` and no tissue mask exists. Exception Any exception raised by ``detect_tissue`` if auto-creation fails. Callers needing graceful fallback should wrap this call in try/except. @@ -279,6 +304,12 @@ def resolve_tissue_mask( mask_key = f"{image_key}_tissue" if mask_key not in sdata.labels: + if not auto_create: + raise KeyError( + f"No tissue mask found in sdata.labels (looked for {mask_key!r}). Run " + f"`squidpy.experimental.im.detect_tissue(sdata, {image_key!r})` first, " + "or pass an explicit `tissue_mask_key`." 
+ ) from squidpy.experimental.im._detect_tissue import detect_tissue detect_tissue(sdata=sdata, image_key=image_key, scale=scale, inplace=True, new_labels_key=mask_key) diff --git a/tests/experimental/test_stain_conversion.py b/tests/experimental/test_stain_conversion.py index bd8f68d63..99733c585 100644 --- a/tests/experimental/test_stain_conversion.py +++ b/tests/experimental/test_stain_conversion.py @@ -103,5 +103,5 @@ def test_missing_channel_dim_raises(self) -> None: def test_wrong_channel_length_raises(self) -> None: arr = xr.DataArray(np.zeros((4, 4, 4)), dims=("y", "x", "c")) - with pytest.raises(ValueError, match="length 3"): + with pytest.raises(ValueError, match="3-channel RGB"): rgb_to_sda(arr, _TEST_WHITE) diff --git a/tests/experimental/test_stain_decompose_public.py b/tests/experimental/test_stain_decompose_public.py new file mode 100644 index 000000000..abe43199a --- /dev/null +++ b/tests/experimental/test_stain_decompose_public.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +import numpy as np +import pytest +import spatialdata as sd +import xarray as xr +from spatialdata.models import Image2DModel, Labels2DModel +from spatialdata.transformations import get_transformation + +import squidpy as sq +from squidpy.experimental.im import ( + StainReference, + decompose_stains, + fit_stain_reference, + normalize_stains, +) +from squidpy.experimental.im._stain._constants import RUIFROK_HE +from squidpy.experimental.im._stain._conversion import sda_to_rgb +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + complement_third_column, + reorder_to_canonical, +) + +_WHITE = np.array([255.0, 255.0, 255.0]) + + +def _synthetic_rgb(seed: int = 0, n_side: int = 48, white: np.ndarray = _WHITE) -> np.ndarray: + w = complement_third_column( + reorder_to_canonical(np.stack([RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]], axis=1)) + ) + rng = np.random.default_rng(seed) + n = n_side * n_side + conc = rng.uniform(0.0, 70.0, 
size=(n, 2)) + third = n // 3 + conc[:third, 1] = 0.0 + conc[third : 2 * third, 0] = 0.0 + od = (conc @ w[:, :2].T).T.reshape(3, n_side, n_side) + rgb = sda_to_rgb(xr.DataArray(od, dims=("c", "y", "x")), white) + return np.asarray(rgb.data).astype(np.uint8) + + +def _make_sdata(values: np.ndarray, *, with_tissue: bool = True) -> sd.SpatialData: + sdata = sd.SpatialData(images={"img": Image2DModel.parse(values, dims=("c", "y", "x"))}) + if with_tissue: + h, w = values.shape[-2], values.shape[-1] + sdata.labels["img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) + return sdata + + +@pytest.mark.parametrize("method", ["macenko", "vahadane"]) +class TestDecompositionThroughDispatchers: + def test_fit_and_apply_end_to_end(self, method: str) -> None: + sdata = _make_sdata(_synthetic_rgb(seed=1)) + ref = fit_stain_reference(sdata, "img", method=method, white_point=_WHITE) + assert ref.method == method + assert ref.stain_matrix.shape == (3, 3) + assert ref.max_concentrations.shape == (2,) + + out = normalize_stains(sdata, "img", ref, inplace=False) + assert isinstance(out, xr.DataArray) + assert out.sizes["c"] == 3 + + def test_apply_writes_back(self, method: str) -> None: + sdata = _make_sdata(_synthetic_rgb(seed=2)) + ref = fit_stain_reference(sdata, "img", method=method, white_point=_WHITE) + result = normalize_stains(sdata, "img", ref, image_key_added="norm") + assert result is None + assert get_transformation(sdata.images["norm"], get_all=True).keys() == ( + get_transformation(sdata.images["img"], get_all=True).keys() + ) + + +class TestDecomposeStains: + def test_returns_named_concentration_maps(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE, inplace=False) + assert set(conc) == {"hematoxylin", "eosin", "residual"} + assert all(set(c.dims) == {"y", "x"} for c in conc.values()) # one (y, x) map per stain + assert all(c.dtype == np.float16 for c in 
conc.values()) # default output_dtype + + def test_drop_residual(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE, include_residual=False, inplace=False) + assert set(conc) == {"hematoxylin", "eosin"} + + def test_output_dtype_override(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + conc = decompose_stains(sdata, "img", "macenko", white_point=_WHITE, output_dtype=np.float32, inplace=False) + assert all(c.dtype == np.float32 for c in conc.values()) + + def test_inplace_default_writes_derived_keys(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE) + out = decompose_stains(sdata, "img", ref) # inplace=True, prefix defaults to image_key + assert out is None + for stain in ("hematoxylin", "eosin", "residual"): + assert f"img_{stain}" in sdata.images + + def test_with_reference_writes_separate_images(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE) + out = decompose_stains(sdata, "img", ref, image_key_added="conc") + assert out is None + for stain in ("hematoxylin", "eosin", "residual"): + assert f"conc_{stain}" in sdata.images + assert list(sdata.images[f"conc_{stain}"].coords["c"].values) == [stain] + + def test_atomic_write_aborts_on_any_existing_key(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE) + # pre-occupy only the *eosin* target; the whole write must abort, leaving + # no half-written hematoxylin/residual behind. 
+ sdata.images["conc_eosin"] = sdata.images["img"] + with pytest.raises(ValueError, match="would overwrite"): + decompose_stains(sdata, "img", ref, image_key_added="conc") + assert "conc_hematoxylin" not in sdata.images + assert "conc_residual" not in sdata.images + + def test_reinhard_reference_rejected(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + reinhard_ref = fit_stain_reference(sdata, "img", method="reinhard") + with pytest.raises(ValueError, match="macenko/vahadane reference"): + decompose_stains(sdata, "img", reinhard_ref) + + def test_bad_method_rejected(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + with pytest.raises(ValueError, match="method must be"): + decompose_stains(sdata, "img", "reinhard") + + +class TestBackgroundDefault: + def test_fit_defaults_to_white_when_absent(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img", method="macenko") + # default I_0 is a fixed full-white point, not an image-derived estimate + np.testing.assert_array_equal(ref.white_point, [255.0, 255.0, 255.0]) + + def test_explicit_background_is_used(self) -> None: + I0 = np.array([240.0, 245.0, 250.0]) + # build the synthetic image against this white point so the fit is consistent + sdata = _make_sdata(_synthetic_rgb(white=I0)) + ref = fit_stain_reference(sdata, "img", method="vahadane", white_point=I0) + np.testing.assert_array_equal(ref.white_point, I0) + + +class TestUnknownMethod: + def test_fit_unknown_method_raises(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + with pytest.raises(ValueError, match="Unknown method"): + fit_stain_reference(sdata, "img", method="bogus") + + +class TestDefaultMethodAndGate: + def test_default_method_is_macenko(self) -> None: + sdata = _make_sdata(_synthetic_rgb()) + ref = fit_stain_reference(sdata, "img") # no method -> default + assert ref.method == "macenko" + + def test_max_angle_deg_gate_too_strict_raises(self) -> None: + # an impossibly tight tolerance trips 
the H/E sanity gate + sdata = _make_sdata(_synthetic_rgb()) + with pytest.raises(StainFittingError, match="deviates"): + fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE, max_angle_deg=0.01) + + def test_canonical_reference_passthrough(self) -> None: + # passing the Ruifrok canonical explicitly reproduces the default fit + sdata = _make_sdata(_synthetic_rgb()) + default = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE) + custom = fit_stain_reference(sdata, "img", method="macenko", white_point=_WHITE, canonical_reference=RUIFROK_HE) + np.testing.assert_allclose(default.stain_matrix, custom.stain_matrix) + + +class TestDecompositionOnHnE: + # Correctness is gated by the synthetic-recovery tests above (per the arc + # decision); these are real-data smoke checks that the pipeline fits a + # valid matrix, applies lazily, and decomposes - fast because the source + # matrix is fit on the coarse level, not the full-resolution image. + # Both methods are exercised: once the fit consumes a real detect_tissue + # mask (dropping the fiducial ring + dim background), Macenko fits this + # low-contrast Visium H&E cleanly and agrees with Vahadane. 
+ @pytest.mark.parametrize("method", ["macenko", "vahadane"]) + def test_fit_apply_decompose_smoke(self, sdata_hne, method: str) -> None: + image_key = next(iter(sdata_hne.images)) + sq.experimental.im.detect_tissue(sdata_hne, image_key) + ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key, method=method) + assert isinstance(ref, StainReference) + assert ref.stain_matrix.shape == (3, 3) + normalized = sq.experimental.im.normalize_stains(sdata_hne, image_key, ref, inplace=False) + assert normalized.sizes["c"] == 3 + conc = sq.experimental.im.decompose_stains(sdata_hne, image_key, ref, inplace=False) + assert set(conc) == {"hematoxylin", "eosin", "residual"} diff --git a/tests/experimental/test_stain_decomposition.py b/tests/experimental/test_stain_decomposition.py new file mode 100644 index 000000000..04f060b51 --- /dev/null +++ b/tests/experimental/test_stain_decomposition.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import dask.array as da +import numpy as np +import pytest +import xarray as xr + +from squidpy.experimental.im._stain._constants import RUIFROK_HE +from squidpy.experimental.im._stain._conversion import sda_to_rgb +from squidpy.experimental.im._stain._decomposition import ( + MacenkoParams, + VahadaneParams, + _resolve_macenko_params, + _resolve_vahadane_params, + apply_decomposition, + fit_decomposition, +) +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + angle_between_deg, + complement_third_column, + reorder_to_canonical, +) + +_WHITE = np.array([255.0, 255.0, 255.0]) + + +def _canonical(h: np.ndarray, e: np.ndarray) -> np.ndarray: + return complement_third_column(reorder_to_canonical(np.stack([h, e], axis=1))) + + +def _synthetic_he(stain_matrix: np.ndarray, *, n_side: int = 48, seed: int = 0, chunked: bool = False) -> xr.DataArray: + """Build an RGB image from known H/E concentrations and a stain matrix.""" + rng = np.random.default_rng(seed) + n = n_side * n_side + conc = 
rng.uniform(0.0, 70.0, size=(n, 2)) + # dense pure-H / pure-E populations so the angular extremes are well sampled + # (real H&E slides have many near-pure pixels; a uniform mix under-samples them) + third = n // 3 + conc[:third, 1] = 0.0 # pure-H pixels (one angular extreme) + conc[third : 2 * third, 0] = 0.0 # pure-E pixels (the other extreme) + od = (conc @ stain_matrix[:, :2].T).T.reshape(3, n_side, n_side) + data = da.from_array(od, chunks=(3, 16, 16)) if chunked else od + return sda_to_rgb(xr.DataArray(data, dims=("c", "y", "x")), _WHITE) + + +class TestMacenko: + @pytest.mark.parametrize("chunked", [False, True]) + def test_recovers_planted_matrix(self, chunked: bool) -> None: + truth = _canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]) + img = _synthetic_he(truth, chunked=chunked) + ref = fit_decomposition(img, "macenko", MacenkoParams(), _WHITE) + assert angle_between_deg(ref.stain_matrix[:, 0], truth[:, 0]) < 12.0 + assert angle_between_deg(ref.stain_matrix[:, 1], truth[:, 1]) < 12.0 + assert ref.max_concentrations.shape == (2,) + assert np.all(ref.max_concentrations > 0) + + +class TestVahadane: + def test_recovers_planted_matrix(self) -> None: + truth = _canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]) + img = _synthetic_he(truth) + ref = fit_decomposition(img, "vahadane", VahadaneParams(), _WHITE) + assert angle_between_deg(ref.stain_matrix[:, 0], truth[:, 0]) < 20.0 + assert angle_between_deg(ref.stain_matrix[:, 1], truth[:, 1]) < 20.0 + + +class TestApplyDecomposition: + def test_transfer_matches_reference_matrix(self) -> None: + truth_a = _canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]) + # a slightly rotated source staining + e_shift = RUIFROK_HE["eosin"] + 0.15 * RUIFROK_HE["hematoxylin"] + truth_b = _canonical(RUIFROK_HE["hematoxylin"], e_shift / np.linalg.norm(e_shift)) + + img_a = _synthetic_he(truth_a, seed=1) + img_b = _synthetic_he(truth_b, seed=2) + ref_a = fit_decomposition(img_a, "macenko", MacenkoParams(), 
_WHITE) + + normalized = apply_decomposition(img_b, ref_a, MacenkoParams()) + refit = fit_decomposition(normalized, "macenko", MacenkoParams(), _WHITE) + assert angle_between_deg(refit.stain_matrix[:, 0], ref_a.stain_matrix[:, 0]) < 12.0 + assert angle_between_deg(refit.stain_matrix[:, 1], ref_a.stain_matrix[:, 1]) < 12.0 + + def test_lazy_in_lazy_out(self) -> None: + truth = _canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]) + ref = fit_decomposition(_synthetic_he(truth), "macenko", MacenkoParams(), _WHITE) + out = apply_decomposition(_synthetic_he(truth, chunked=True), ref, MacenkoParams()) + assert isinstance(out.data, da.Array) + + def test_missing_max_concentrations_raises(self) -> None: + from squidpy.experimental.im._stain._reference import StainReference + + ref = StainReference( + method="macenko", + stain_matrix=_canonical(RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]), + white_point=_WHITE, + ) + img = _synthetic_he(ref.stain_matrix) + with pytest.raises(ValueError, match="max_concentrations"): + apply_decomposition(img, ref, MacenkoParams()) + + +class TestDegenerate: + def test_empty_tissue_raises(self) -> None: + white = xr.DataArray(np.full((3, 16, 16), 255.0), dims=("c", "y", "x")) + with pytest.raises(StainFittingError, match="mask is empty"): + fit_decomposition(white, "macenko", MacenkoParams(), _WHITE) + + +class TestResolvers: + def test_macenko_mapping_and_unknown(self) -> None: + assert _resolve_macenko_params({"alpha": 2.0}).alpha == 2.0 + with pytest.raises(ValueError, match="Unknown"): + _resolve_macenko_params({"nope": 1}) + + def test_vahadane_instance_and_badtype(self) -> None: + p = VahadaneParams(lambda1=0.2) + assert _resolve_vahadane_params(p) is p + with pytest.raises(TypeError, match="VahadaneParams"): + _resolve_vahadane_params(5) + + @pytest.mark.parametrize("bad", [0.0, 50.0, -1.0]) + def test_macenko_alpha_bounds(self, bad: float) -> None: + with pytest.raises(ValueError, match="alpha"): + MacenkoParams(alpha=bad) 
diff --git a/tests/experimental/test_stain_mask.py b/tests/experimental/test_stain_mask.py index 3f87116f0..c0e70f23b 100644 --- a/tests/experimental/test_stain_mask.py +++ b/tests/experimental/test_stain_mask.py @@ -5,7 +5,12 @@ import pytest import xarray as xr -from squidpy.experimental.im._stain._mask import luminosity_foreground_mask +from squidpy.experimental.im._stain._mask import ( + absorbance_foreground_mask, + luminosity_foreground_mask, +) + +_WHITE = np.array([255.0, 255.0, 255.0]) def _rgb_dataarray(values: np.ndarray, *, chunked: bool) -> xr.DataArray: @@ -41,5 +46,25 @@ def test_lazy_in_lazy_out(self) -> None: def test_non_three_channel_raises(self) -> None: values = np.zeros((2, 8, 8)) - with pytest.raises(ValueError, match="length 3"): + with pytest.raises(ValueError, match="3-channel RGB"): luminosity_foreground_mask(xr.DataArray(values, dims=("c", "y", "x")), 0.8) + + +class TestAbsorbanceForegroundMask: + def test_white_is_background_dark_is_tissue(self) -> None: + values = np.full((3, 8, 16), 255.0) + values[:, :, 8:] = 30.0 # dark right half = high absorbance = tissue + mask = absorbance_foreground_mask(_rgb_dataarray(values, chunked=False), _WHITE) + assert mask.dims == ("y", "x") + assert not bool(mask.values[:, :8].any()) + assert bool(mask.values[:, 8:].all()) + + def test_lazy_in_lazy_out(self) -> None: + values = np.full((3, 16, 16), 50.0) + mask = absorbance_foreground_mask(_rgb_dataarray(values, chunked=True), _WHITE) + assert isinstance(mask.data, da.Array) + + def test_non_three_channel_raises(self) -> None: + values = np.zeros((2, 8, 8)) + with pytest.raises(ValueError, match="3-channel RGB"): + absorbance_foreground_mask(xr.DataArray(values, dims=("c", "y", "x")), _WHITE) diff --git a/tests/experimental/test_stain_normalize.py b/tests/experimental/test_stain_normalize.py index e8aaa6e55..478459782 100644 --- a/tests/experimental/test_stain_normalize.py +++ b/tests/experimental/test_stain_normalize.py @@ -7,15 +7,15 @@ import 
spatialdata as sd import spatialdata_plot as sdp import xarray as xr -from spatialdata.models import Image2DModel +from spatialdata.models import Image2DModel, Labels2DModel from spatialdata.transformations import Scale, get_transformation, set_transformation import squidpy as sq from squidpy.experimental.im import ( ReinhardParams, StainReference, - apply_stain_normalization, fit_stain_reference, + normalize_stains, ) from squidpy.experimental.im._utils import get_element_data from tests.conftest import PlotTester, PlotTesterMeta @@ -23,21 +23,27 @@ _ = sdp # registers the `.pl` spatialdata accessor -def _make_sdata(values: np.ndarray, *, scale_factors: list[int] | None = None) -> sd.SpatialData: +def _make_sdata( + values: np.ndarray, *, scale_factors: list[int] | None = None, with_tissue: bool = True +) -> sd.SpatialData: img = Image2DModel.parse(values, dims=("c", "y", "x"), scale_factors=scale_factors) - return sd.SpatialData(images={"img": img}) + sdata = sd.SpatialData(images={"img": img}) + if with_tissue: + h, w = values.shape[-2], values.shape[-1] + sdata.labels["img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) + return sdata @pytest.fixture def rgb_values() -> np.ndarray: rng = np.random.default_rng(3) - return rng.uniform(40.0, 200.0, size=(3, 64, 64)).astype(np.float32) + return rng.uniform(40.0, 200.0, size=(3, 64, 64)).astype(np.uint8) class TestFitStainReference: def test_end_to_end(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) - ref = fit_stain_reference(sdata, "img") + ref = fit_stain_reference(sdata, "img", method="reinhard") assert isinstance(ref, StainReference) assert ref.method == "reinhard" @@ -46,43 +52,62 @@ def test_missing_image_key_raises(self, rgb_values: np.ndarray) -> None: with pytest.raises(ValueError, match="not found, valid keys"): fit_stain_reference(sdata, "nope") - def test_macenko_not_implemented(self, rgb_values: np.ndarray) -> None: - sdata = 
_make_sdata(rgb_values) - with pytest.raises(NotImplementedError, match="decomposition is not yet implemented"): - fit_stain_reference(sdata, "img", method="macenko") - def test_unknown_method_raises(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) with pytest.raises(ValueError, match="Unknown method"): fit_stain_reference(sdata, "img", method="bogus") + def test_rgba_image_rejected(self) -> None: + rng = np.random.default_rng(0) + rgba = rng.integers(0, 256, size=(4, 16, 16), dtype=np.uint8) # 4-channel, not RGB + sdata = sd.SpatialData(images={"img": Image2DModel.parse(rgba, dims=("c", "y", "x"))}) + with pytest.raises(ValueError, match="3-channel RGB"): + fit_stain_reference(sdata, "img", method="reinhard") + class TestApplyStainNormalization: def test_returns_lazy_and_leaves_sdata_untouched(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) - ref = fit_stain_reference(sdata, "img") - out = apply_stain_normalization(sdata, "img", ref) + ref = fit_stain_reference(sdata, "img", method="reinhard") + out = normalize_stains(sdata, "img", ref, inplace=False) assert isinstance(out, xr.DataArray) assert isinstance(out.data, da.Array) assert list(sdata.images.keys()) == ["img"] + def test_inplace_default_writes_derived_key(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values) + ref = fit_stain_reference(sdata, "img", method="reinhard") + result = normalize_stains(sdata, "img", ref) # inplace=True, image_key_added defaults to f"{key}_normalized" + assert result is None + assert "img_normalized" in sdata.images + out = sdata.images["img_normalized"] + assert out.dtype == rgb_values.dtype # cast back to the source dtype at the write boundary + assert out.shape == rgb_values.shape + + def test_output_dtype_override(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values) + ref = fit_stain_reference(sdata, "img", method="reinhard") + out = normalize_stains(sdata, "img", ref, inplace=False, 
output_dtype=np.uint16) + assert out.dtype == np.uint16 + def test_writes_and_preserves_transform_and_dims(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) - ref = fit_stain_reference(sdata, "img") - result = apply_stain_normalization(sdata, "img", ref, image_key_added="norm") + ref = fit_stain_reference(sdata, "img", method="reinhard") + result = normalize_stains(sdata, "img", ref, image_key_added="norm") assert result is None assert "norm" in sdata.images out = sdata.images["norm"] assert out.dims == ("c", "y", "x") assert out.shape == rgb_values.shape + assert out.dtype == rgb_values.dtype assert ( get_transformation(out, get_all=True).keys() == get_transformation(sdata.images["img"], get_all=True).keys() ) def test_multiscale_rebuilds_pyramid(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values, scale_factors=[2]) - ref = fit_stain_reference(sdata, "img") - apply_stain_normalization(sdata, "img", ref, image_key_added="norm") + ref = fit_stain_reference(sdata, "img", method="reinhard") + normalize_stains(sdata, "img", ref, image_key_added="norm") src, out = sdata.images["img"], sdata.images["norm"] assert hasattr(out, "keys") src_shapes = [src[k].image.shape for k in src] @@ -93,41 +118,114 @@ def test_preserves_channel_coords_and_nonidentity_transform(self, rgb_values: np img = Image2DModel.parse(rgb_values, dims=("c", "y", "x"), c_coords=["r", "g", "b"]) set_transformation(img, Scale([2.0, 2.0], axes=("y", "x")), to_coordinate_system="global") sdata = sd.SpatialData(images={"img": img}) - ref = fit_stain_reference(sdata, "img") - apply_stain_normalization(sdata, "img", ref, image_key_added="norm") + h, w = rgb_values.shape[-2], rgb_values.shape[-1] + sdata.labels["img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) + ref = fit_stain_reference(sdata, "img", method="reinhard") + normalize_stains(sdata, "img", ref, image_key_added="norm") out = sdata.images["norm"] assert 
list(out.coords["c"].values) == ["r", "g", "b"] assert get_transformation(out, get_all=True) == get_transformation(img, get_all=True) def test_existing_key_raises(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) - ref = fit_stain_reference(sdata, "img") + ref = fit_stain_reference(sdata, "img", method="reinhard") with pytest.raises(ValueError, match="already exists"): - apply_stain_normalization(sdata, "img", ref, image_key_added="img") + normalize_stains(sdata, "img", ref, image_key_added="img") - def test_decomposition_reference_not_implemented(self, rgb_values: np.ndarray) -> None: + def test_decomposition_reference_without_max_concentrations_raises(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) ref = StainReference( method="macenko", stain_matrix=np.eye(3), - background_intensity=np.array([255.0, 255.0, 255.0]), + white_point=np.array([255.0, 255.0, 255.0]), ) - with pytest.raises(NotImplementedError, match="decomposition is not yet implemented"): - apply_stain_normalization(sdata, "img", ref) + with pytest.raises(ValueError, match="max_concentrations"): + normalize_stains(sdata, "img", ref) def test_method_params_mapping(self, rgb_values: np.ndarray) -> None: sdata = _make_sdata(rgb_values) - ref = fit_stain_reference(sdata, "img", method_params={"mask_background": False}) - out = apply_stain_normalization(sdata, "img", ref, method_params=ReinhardParams(mask_background=False)) + ref = fit_stain_reference(sdata, "img", method="reinhard", method_params={"mask_background": False}) + out = normalize_stains(sdata, "img", ref, method_params=ReinhardParams(mask_background=False), inplace=False) assert isinstance(out, xr.DataArray) +class TestTissueMaskMandate: + def test_fit_requires_tissue_mask(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values, with_tissue=False) + with pytest.raises(KeyError, match="detect_tissue"): + fit_stain_reference(sdata, "img", method="reinhard") + + def 
test_apply_requires_tissue_mask(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values) # has a mask -> fit works + ref = fit_stain_reference(sdata, "img", method="reinhard") + del sdata.labels["img_tissue"] # ... but now the source has none + with pytest.raises(KeyError, match="detect_tissue"): + normalize_stains(sdata, "img", ref) + + def test_explicit_missing_key_raises(self, rgb_values: np.ndarray) -> None: + sdata = _make_sdata(rgb_values) + with pytest.raises(KeyError, match="not found in sdata.labels"): + fit_stain_reference(sdata, "img", tissue_mask_key="nope") + + def test_float_0_255_source_rejected_on_apply(self, rgb_values: np.ndarray) -> None: + # A float image holding 0-255 values would otherwise clip to [0, 1] in the + # reconstruction (dtype_max(float)=1.0); apply must reject it, not silently destroy it. + sdata = _make_sdata(rgb_values) # uint8 + ref = fit_stain_reference(sdata, "img", method="reinhard") + floaty = rgb_values.astype(np.float32) + sdata.images["floaty"] = Image2DModel.parse(floaty, dims=("c", "y", "x")) + sdata.labels["floaty_tissue"] = Labels2DModel.parse( + np.ones(floaty.shape[-2:], dtype=np.uint32), dims=("y", "x") + ) + with pytest.raises(ValueError, match="stored as float"): + normalize_stains(sdata, "floaty", ref) + + def test_mask_is_used_in_the_fit(self, rgb_values: np.ndarray) -> None: + # A different tissue region yields different channel statistics, proving + # the mask actually drives the fit (not silently ignored). 
+ ref_full = fit_stain_reference(_make_sdata(rgb_values), "img", method="reinhard") + + sdata_part = _make_sdata(rgb_values, with_tissue=False) + h, w = rgb_values.shape[-2], rgb_values.shape[-1] + partial = np.zeros((h, w), dtype=np.uint32) + partial[: h // 2] = 1 # only the top half is tissue + sdata_part.labels["img_tissue"] = Labels2DModel.parse(partial, dims=("y", "x")) + ref_part = fit_stain_reference(sdata_part, "img", method="reinhard") + + assert not np.allclose(ref_full.mu, ref_part.mu) + + +class TestPreserveBackground: + def test_background_passthrough_vs_full_frame(self, rgb_values: np.ndarray) -> None: + # tissue = top half only; bottom half is background + h, w = rgb_values.shape[-2], rgb_values.shape[-1] + sdata = _make_sdata(rgb_values, with_tissue=False) + partial = np.zeros((h, w), dtype=np.uint32) + partial[: h // 2] = 1 + sdata.labels["img_tissue"] = Labels2DModel.parse(partial, dims=("y", "x")) + + # a differently-coloured reference so the transform is non-trivial + shifted = np.clip(rgb_values * np.array([1.3, 0.8, 1.1])[:, None, None], 0, 255).astype(np.uint8) + sdata.images["ref_img"] = Image2DModel.parse(shifted, dims=("c", "y", "x")) + sdata.labels["ref_img_tissue"] = Labels2DModel.parse(np.ones((h, w), dtype=np.uint32), dims=("y", "x")) + ref = fit_stain_reference(sdata, "ref_img", method="reinhard") + + original = get_element_data(sdata.images["img"], "auto", "image", "img").values + kept = normalize_stains(sdata, "img", ref, inplace=False).values # preserve_background=True (default) + full = normalize_stains(sdata, "img", ref, preserve_background=False, inplace=False).values + + bg = slice(h // 2, None) + np.testing.assert_allclose(kept[:, bg], original[:, bg]) # background untouched + assert not np.allclose(full[:, bg], original[:, bg]) # full-frame recolours it + + class TestStainNormalizationOnHnE: def test_fit_apply_smoke(self, sdata_hne) -> None: image_key = next(iter(sdata_hne.images)) - ref = 
sq.experimental.im.fit_stain_reference(sdata_hne, image_key) + sq.experimental.im.detect_tissue(sdata_hne, image_key) + ref = sq.experimental.im.fit_stain_reference(sdata_hne, image_key, method="reinhard") assert ref.method == "reinhard" - out = sq.experimental.im.apply_stain_normalization(sdata_hne, image_key, ref) + out = sq.experimental.im.normalize_stains(sdata_hne, image_key, ref, inplace=False) assert "c" in out.dims assert out.sizes["c"] == 3 @@ -136,16 +234,20 @@ class TestStainNormalizationVisual(PlotTester, metaclass=PlotTesterMeta): def test_plot_reinhard_before_after(self, sdata_hne) -> None: """Visual: a re-stained source (left) normalized back to the H&E reference (right).""" image_key = next(iter(sdata_hne.images)) - reference = fit_stain_reference(sdata_hne, image_key) + sq.experimental.im.detect_tissue(sdata_hne, image_key) + reference = fit_stain_reference(sdata_hne, image_key, method="reinhard") # Deterministically warm/cool the channels to simulate a different # staining batch, so the before/after panels are visibly distinct. da_rgb = get_element_data(sdata_hne.images[image_key], "auto", "image", image_key).astype("float32") weights = xr.DataArray([1.4, 1.0, 0.6], dims="c", coords={"c": da_rgb.coords["c"]}) - shifted = (da_rgb * weights).clip(0, 255) + shifted = (da_rgb * weights).clip(0, 255).astype("uint8") sdata_hne.images["hne_shifted"] = Image2DModel.parse(shifted.data, dims=shifted.dims) - apply_stain_normalization(sdata_hne, "hne_shifted", reference, image_key_added="hne_normalized") + # `hne_shifted` shares geometry with `image_key`; reuse its tissue mask. 
+ normalize_stains( + sdata_hne, "hne_shifted", reference, image_key_added="hne_normalized", tissue_mask_key=f"{image_key}_tissue" + ) _, axes = plt.subplots(1, 2, figsize=(8, 4)) sdata_hne.pl.render_images("hne_shifted").pl.show(ax=axes[0], title="before") diff --git a/tests/experimental/test_stain_reference.py b/tests/experimental/test_stain_reference.py index 5bd5cfd0e..93783c276 100644 --- a/tests/experimental/test_stain_reference.py +++ b/tests/experimental/test_stain_reference.py @@ -21,19 +21,19 @@ def test_macenko_basic() -> None: ref = StainReference( method="macenko", stain_matrix=_ruifrok_matrix(), - background_intensity=_TEST_BACKGROUND, + white_point=_TEST_BACKGROUND, ) assert ref.method == "macenko" assert ref.stain_matrix.shape == (3, 3) assert ref.mu is None and ref.sigma is None - np.testing.assert_array_equal(ref.background_intensity, _TEST_BACKGROUND) + np.testing.assert_array_equal(ref.white_point, _TEST_BACKGROUND) def test_reinhard_basic() -> None: ref = StainReference(method="reinhard", mu=np.array([1.0, 0.5, -0.2]), sigma=np.array([0.1, 0.1, 0.1])) assert ref.method == "reinhard" assert ref.stain_matrix is None - assert ref.background_intensity is None + assert ref.white_point is None def test_unknown_method_raises() -> None: @@ -43,11 +43,11 @@ def test_unknown_method_raises() -> None: def test_decomposition_requires_stain_matrix() -> None: with pytest.raises(ValueError, match="requires stain_matrix"): - StainReference(method="macenko", background_intensity=_TEST_BACKGROUND) + StainReference(method="macenko", white_point=_TEST_BACKGROUND) -def test_decomposition_requires_background_intensity() -> None: - with pytest.raises(ValueError, match="requires background_intensity"): +def test_decomposition_requires_white_point() -> None: + with pytest.raises(ValueError, match="requires white_point"): StainReference(method="macenko", stain_matrix=_ruifrok_matrix()) @@ -56,7 +56,7 @@ def test_decomposition_forbids_mu_sigma() -> None: StainReference( 
method="macenko", stain_matrix=_ruifrok_matrix(), - background_intensity=_TEST_BACKGROUND, + white_point=_TEST_BACKGROUND, mu=np.zeros(3), sigma=np.ones(3), ) @@ -82,22 +82,22 @@ def test_reinhard_forbids_stain_matrix() -> None: ) -def test_reinhard_forbids_background_intensity() -> None: - with pytest.raises(ValueError, match="forbids background_intensity"): +def test_reinhard_forbids_white_point() -> None: + with pytest.raises(ValueError, match="forbids white_point"): StainReference( method="reinhard", mu=np.zeros(3), sigma=np.ones(3), - background_intensity=_TEST_BACKGROUND, + white_point=_TEST_BACKGROUND, ) -def test_bad_background_intensity() -> None: - with pytest.raises(ValueError, match="background_intensity"): +def test_bad_white_point() -> None: + with pytest.raises(ValueError, match="white_point"): StainReference( method="macenko", stain_matrix=_ruifrok_matrix(), - background_intensity=np.array([255.0, -1.0, 255.0]), + white_point=np.array([255.0, -1.0, 255.0]), ) @@ -106,7 +106,7 @@ def test_rejects_bad_shape() -> None: StainReference( method="macenko", stain_matrix=np.zeros((2, 3)), - background_intensity=_TEST_BACKGROUND, + white_point=_TEST_BACKGROUND, ) @@ -117,3 +117,15 @@ def test_rejects_non_finite() -> None: mu=np.array([np.nan, 0.0, 0.0]), sigma=np.ones(3), ) + + +def test_equality_is_array_aware_and_hashable() -> None: + # distinct-but-equal references compare equal (array-aware __eq__), and + # references remain hashable (identity) despite the numpy-array fields. 
+ a = StainReference(method="reinhard", mu=np.array([1.0, 2.0, 3.0]), sigma=np.ones(3)) + b = StainReference(method="reinhard", mu=np.array([1.0, 2.0, 3.0]), sigma=np.ones(3)) + c = StainReference(method="reinhard", mu=np.array([9.0, 2.0, 3.0]), sigma=np.ones(3)) + assert a == b + assert a != c + assert len({a, b, c}) == 3 # identity-hashed, no TypeError + assert a != "not a reference" diff --git a/tests/experimental/test_stain_validation.py b/tests/experimental/test_stain_validation.py new file mode 100644 index 000000000..cb548ab54 --- /dev/null +++ b/tests/experimental/test_stain_validation.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import numpy as np +import pytest + +from squidpy.experimental.im._stain._constants import RUIFROK_HE +from squidpy.experimental.im._stain._validation import ( + StainFittingError, + complement_third_column, + reorder_to_canonical, + validate_stain_matrix, +) + + +def _he_matrix() -> np.ndarray: + return np.stack([RUIFROK_HE["hematoxylin"], RUIFROK_HE["eosin"]], axis=1) + + +class TestReorderToCanonical: + def test_swapped_columns_restored(self) -> None: + he = _he_matrix() + swapped = he[:, ::-1] + out = reorder_to_canonical(swapped) + np.testing.assert_allclose(out, he, atol=1e-8) + + def test_sign_flips_corrected(self) -> None: + he = _he_matrix() + flipped = he * np.array([-1.0, 1.0]) + out = reorder_to_canonical(flipped) + np.testing.assert_allclose(out, he, atol=1e-8) + + def test_bad_shape_raises(self) -> None: + with pytest.raises(ValueError, match="shape"): + reorder_to_canonical(np.eye(3)) + + +class TestComplementThirdColumn: + def test_unit_orthogonal(self) -> None: + w = complement_third_column(_he_matrix()) + assert w.shape == (3, 3) + np.testing.assert_allclose(np.linalg.norm(w[:, 2]), 1.0, atol=1e-8) + assert abs(w[:, 2] @ w[:, 0]) < 1e-8 + assert abs(w[:, 2] @ w[:, 1]) < 1e-8 + + def test_colinear_raises(self) -> None: + v = RUIFROK_HE["hematoxylin"] + with pytest.raises(StainFittingError, 
match="colinear"): + complement_third_column(np.stack([v, v], axis=1)) + + +class TestValidateStainMatrix: + def test_canonical_passes(self) -> None: + validate_stain_matrix(complement_third_column(_he_matrix())) + + def test_rank_deficient_raises_with_image_key(self) -> None: + w = complement_third_column(_he_matrix()) + w[:, 1] = w[:, 0] # collapse E onto H + with pytest.raises(StainFittingError) as exc: + validate_stain_matrix(w, image_key="slideA") + assert exc.value.image_key == "slideA" + + def test_rotated_he_raises(self) -> None: + # rotate H far toward an unrelated direction + w = complement_third_column(_he_matrix()) + w[:, 0] = np.array([1.0, 0.0, 0.0]) + with pytest.raises(StainFittingError, match="deviates"): + validate_stain_matrix(w) + + def test_zero_column_raises(self) -> None: + w = complement_third_column(_he_matrix()) + w[:, 0] = 0.0 + with pytest.raises(StainFittingError, match="zero-norm"): + validate_stain_matrix(w) + + def test_bad_shape_raises(self) -> None: + with pytest.raises(StainFittingError, match="shape"): + validate_stain_matrix(np.eye(2)) diff --git a/tests/experimental/test_stain_white_point.py b/tests/experimental/test_stain_white_point.py new file mode 100644 index 000000000..da5fb3595 --- /dev/null +++ b/tests/experimental/test_stain_white_point.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import numpy as np +import pytest +import spatialdata as sd +import xarray as xr +from spatialdata.models import Image2DModel, Labels2DModel + +from squidpy.experimental.im import estimate_white_point +from squidpy.experimental.im._stain._conversion import dtype_max +from squidpy.experimental.im._stain._validation import StainFittingError +from squidpy.experimental.im._stain._white_point import default_white_point, validate_rgb_range + + +def _rgb(values: np.ndarray) -> xr.DataArray: + return xr.DataArray(values, dims=("c", "y", "x")) + + +class TestDtypeMax: + def test_known_dtypes(self) -> None: + assert dtype_max(np.uint8) == 
255.0 + assert dtype_max(np.uint16) == 65535.0 + assert dtype_max(np.float32) == 1.0 + + +class TestDefaultWhitePoint: + def test_uint8(self) -> None: + rgb = _rgb(np.full((3, 8, 8), 200, dtype=np.uint8)) + np.testing.assert_array_equal(default_white_point(rgb), [255.0, 255.0, 255.0]) + + def test_uint16(self) -> None: + rgb = _rgb(np.full((3, 8, 8), 5000, dtype=np.uint16)) + np.testing.assert_array_equal(default_white_point(rgb), [65535.0] * 3) + + def test_float_unit_range(self) -> None: + rgb = _rgb(np.full((3, 8, 8), 0.8, dtype=np.float32)) + np.testing.assert_array_equal(default_white_point(rgb), [1.0, 1.0, 1.0]) + + +class TestValidateRgbRange: + def test_passes_on_uint8(self) -> None: + validate_rgb_range(_rgb(np.full((3, 8, 8), 200, dtype=np.uint8))) # no raise + + def test_passes_on_float_unit_range(self) -> None: + validate_rgb_range(_rgb(np.full((3, 8, 8), 0.8, dtype=np.float32))) # no raise + + def test_raises_on_8bit_in_uint16(self) -> None: + with pytest.raises(ValueError, match="8-bit data stored in"): + validate_rgb_range(_rgb(np.full((3, 8, 8), 200, dtype=np.uint16))) + + def test_raises_on_0_255_float(self) -> None: + with pytest.raises(ValueError, match="stored as float"): + validate_rgb_range(_rgb(np.full((3, 8, 8), 200.0, dtype=np.float32))) + + +class TestEstimateWhitePoint: + def _sdata(self, *, all_tissue: bool = False) -> sd.SpatialData: + rng = np.random.default_rng(0) + values = np.empty((3, 32, 32), dtype=np.uint8) + values[0], values[1], values[2] = 240, 245, 250 # background + values[:, :8, :8] = rng.integers(20, 60, size=(3, 8, 8)) # a darker tissue blob + mask = np.zeros((32, 32), dtype=np.uint32) + mask[:8, :8] = 1 # tissue = the blob + if all_tissue: + mask[:] = 1 + sdata = sd.SpatialData(images={"img": Image2DModel.parse(values, dims=("c", "y", "x"))}) + sdata.labels["img_tissue"] = Labels2DModel.parse(mask, dims=("y", "x")) + return sdata + + def test_recovers_background_median(self) -> None: + wp = 
estimate_white_point(self._sdata(), "img") + assert wp.shape == (3,) + np.testing.assert_allclose(wp, [240.0, 245.0, 250.0], atol=1.0) + + def test_raises_when_tissue_covers_all(self) -> None: + with pytest.raises(StainFittingError, match="covers the whole image"): + estimate_white_point(self._sdata(all_tissue=True), "img") + + def test_requires_a_tissue_mask(self) -> None: + values = np.full((3, 16, 16), 240, dtype=np.uint8) + sdata = sd.SpatialData(images={"img": Image2DModel.parse(values, dims=("c", "y", "x"))}) + with pytest.raises(KeyError, match="detect_tissue"): + estimate_white_point(sdata, "img")