diff --git a/docs/api.md b/docs/api.md index 1bf4fdf92..c701c4058 100644 --- a/docs/api.md +++ b/docs/api.md @@ -149,6 +149,8 @@ See the {doc}`extensibility guide ` for how to implement a custo experimental.tl.calculate_tiling_qc experimental.tl.TilingQCParams + experimental.tl.assign_stitch_groups + experimental.tl.StitchParams experimental.pl.tiling_qc experimental.im.fit_stain_reference experimental.im.apply_stain_normalization diff --git a/src/squidpy/experimental/im/_tiling.py b/src/squidpy/experimental/im/_tiling.py index 4f6424ee3..5969f083e 100644 --- a/src/squidpy/experimental/im/_tiling.py +++ b/src/squidpy/experimental/im/_tiling.py @@ -53,9 +53,7 @@ class TileSpec: owned_ids: frozenset[int] -# --------------------------------------------------------------------------- # Centroid computation -# --------------------------------------------------------------------------- def compute_cell_info(labels: np.ndarray) -> dict[int, CellInfo]: @@ -194,9 +192,7 @@ def compute_cell_info_tiled( return result -# --------------------------------------------------------------------------- # Tile spec building -# --------------------------------------------------------------------------- def _auto_margin(cell_info: dict[int, CellInfo]) -> int: @@ -281,9 +277,7 @@ def build_tile_specs( return specs -# --------------------------------------------------------------------------- # Tile extraction -# --------------------------------------------------------------------------- def extract_tile( @@ -405,9 +399,7 @@ def _zero_non_owned(tile_labels: np.ndarray, owned_ids: frozenset[int]) -> None: tile_labels[~np.isin(tile_labels, owned_arr)] = 0 -# --------------------------------------------------------------------------- # Coverage verification -# --------------------------------------------------------------------------- def verify_coverage( diff --git a/src/squidpy/experimental/tl/__init__.py b/src/squidpy/experimental/tl/__init__.py index 1c2f97ece..7122bd3cd 100644 --- 
a/src/squidpy/experimental/tl/__init__.py +++ b/src/squidpy/experimental/tl/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations from ._tiling_qc import TilingQCParams, calculate_tiling_qc +from ._tiling_stitch import StitchParams, assign_stitch_groups -__all__ = ["TilingQCParams", "calculate_tiling_qc"] +__all__ = ["StitchParams", "TilingQCParams", "assign_stitch_groups", "calculate_tiling_qc"] diff --git a/src/squidpy/experimental/tl/_tiling_qc.py b/src/squidpy/experimental/tl/_tiling_qc.py index db757fd0e..287610a45 100644 --- a/src/squidpy/experimental/tl/_tiling_qc.py +++ b/src/squidpy/experimental/tl/_tiling_qc.py @@ -52,6 +52,7 @@ compute_cell_info_tiled, extract_labels_tile_lazy, ) +from squidpy.experimental.tl._tiling_stitch import _STITCH_COLUMNS, _STITCH_PARAM_KEYS, StitchParams from squidpy.experimental.utils._labels import resolve_labels_array __all__ = ["TilingQCParams", "calculate_tiling_qc"] @@ -137,9 +138,7 @@ def _has_distributed_client() -> bool: return True -# --------------------------------------------------------------------------- # Core geometry -# --------------------------------------------------------------------------- @njit(cache=True, nogil=True) @@ -355,9 +354,7 @@ def _straight_edge_metrics( return float(straight_ratio), float(cardinal), float(cut_score) -# --------------------------------------------------------------------------- # Per-tile scoring -# --------------------------------------------------------------------------- def _score_tile( @@ -430,9 +427,7 @@ def _score_tile( return pd.DataFrame.from_dict(rows, orient="index") -# --------------------------------------------------------------------------- # Centroid computation (shared logic with _feature.py) -# --------------------------------------------------------------------------- def _compute_centroids_for_labels( @@ -457,9 +452,7 @@ def _compute_centroids_for_labels( return compute_cell_info_tiled(labels_da) -# 
def _to_builtin_scalar(value: object) -> object:
    """Return ``value`` as a plain Python scalar if it is a NumPy scalar / 0-d array.

    Values stored in ``.uns`` may be NumPy scalars. Under NumPy >= 2 their
    ``repr`` is e.g. ``np.float64(0.7)``, which cannot be pasted into a call
    unless ``np`` happens to be in scope. Anything that is not a zero-dim
    object exposing ``.item()`` is returned unchanged.
    """
    item = getattr(value, "item", None)
    if callable(item) and getattr(value, "shape", None) == ():
        return item()
    return value


def _warn_if_dropping_stitch_columns(sdata: sd.SpatialData, table_key: str, labels_key: str) -> None:
    """Warn if re-running QC would drop downstream stitch results.

    ``calculate_tiling_qc`` replaces the QC table wholesale, so any columns
    added by :func:`~squidpy.experimental.tl.assign_stitch_groups` to a previous
    version of this table are about to disappear. We emit an actionable warning
    listing the previous stitch parameters (from ``.uns["tiling_stitch"]``) and a
    copy-pasteable invocation to restore them.

    Parameters
    ----------
    sdata
        Object whose ``tables`` mapping is about to receive the new QC table.
    table_key
        Key under which the QC table is (re)written.
    labels_key
        Labels key the QC run was computed for; echoed into the suggested call.

    Returns
    -------
    ``None``. The function only logs; it is a no-op when ``table_key`` does not
    exist yet or the existing table carries none of the stitch columns.
    """
    # Function-scope import: keeps this helper independent of whichever
    # ``dataclasses`` names the module happens to import at the top.
    from dataclasses import asdict

    if table_key not in sdata.tables:
        return
    existing = sdata.tables[table_key]
    present = [c for c in _STITCH_COLUMNS if c in existing.obs.columns]
    if not present:
        return

    prev_params = getattr(existing, "uns", {}).get("tiling_stitch", {})
    if not hasattr(prev_params, "items"):
        # A malformed / foreign ``uns`` entry must not turn a warning into a crash.
        prev_params = {}

    parts = [f"labels_key={labels_key!r}"]
    # assign_stitch_groups looks the table up under "{labels_key}_qc" unless told
    # otherwise, so a non-default table key has to be part of the suggested call.
    if table_key != f"{labels_key}_qc":
        parts.append(f"qc_table_key={table_key!r}")
    parts.extend(f"{k}={_to_builtin_scalar(v)!r}" for k, v in prev_params.items() if k in _STITCH_PARAM_KEYS)

    nested = prev_params.get("stitch_params")
    if isinstance(nested, dict) and nested:
        defaults = asdict(StitchParams())
        plain = {k: _to_builtin_scalar(v) for k, v in nested.items()}
        # Only echo the knobs that differ from the defaults to keep the hint short.
        diff = {k: v for k, v in plain.items() if k in defaults and defaults[k] != v}
        if diff:
            parts.append(f"stitch_params={diff!r}")

    rerun = f"sq.experimental.tl.assign_stitch_groups(sdata, {', '.join(parts)})"
    logg.warning(
        f"Re-running calculate_tiling_qc dropped previous stitch columns "
        f"({', '.join(present)}) from sdata.tables[{table_key!r}]. "
        f"To restore them, run: {rerun}"
    )
+""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any + +import numpy as np +import spatialdata as sd +import xarray as xr +from scipy.ndimage import binary_closing +from scipy.sparse import csr_matrix +from scipy.sparse.csgraph import connected_components +from skimage.measure import label as cc_label +from skimage.measure import regionprops +from skimage.morphology import disk as morph_disk +from spatialdata._logging import logger as logg + +from squidpy.experimental.utils._geometry import equivalent_diameter, largest_contour +from squidpy.experimental.utils._labels import iter_chunked_regionprops, resolve_labels_array +from squidpy.experimental.utils._params import resolve_params + +if TYPE_CHECKING: + from collections.abc import Iterable + + import anndata as ad + +__all__ = ["StitchParams", "assign_stitch_groups"] + +# The geometric features whose flat mean is the stitch score. +_SCORE_FEATURES: tuple[str, ...] = ("iou", "endpoint_match", "merge_compactness", "merge_solidity", "gap_proximity") +# The subset computed by the expensive merge-union step; the rest are cheap +# geometry features known before it, which drives the scoring early-prune. +_SHAPE_FEATURES: tuple[str, ...] = ("merge_compactness", "merge_solidity") + + +@dataclass(slots=True) +class StitchParams: + """Advanced tuning knobs for :func:`~squidpy.experimental.tl.assign_stitch_groups`. + + Defaults work for typical 2D segmentation tiles produced by + cellpose-like pipelines. Pass an instance (or a ``Mapping`` of + field names to values) as ``stitch_params`` to override. These are + advanced knobs -- the defaults rarely need changing. 
+ """ + + distance_tol: float = 0.75 + """Sub-pixel tolerance for "lies on a bbox edge".""" + + min_edge_length: float = 5.0 + """Absolute floor on cut-edge length (pixels).""" + + min_edge_length_ratio: float = 0.4 + """Minimum cut-edge length relative to the cell's equivalent diameter.""" + + min_edge_coverage: float = 0.5 + """Minimum fraction of parallel-axis positions covered by near-edge contour points.""" + + candidate_min_iou: float = 0.2 + """Loose 1-D IoU floor at candidate enumeration.""" + + close_radius: int = 3 + """Morphological closing disk radius for the union mask. Also the + length scale for ``gap_proximity`` (normalised by ``2 * close_radius``).""" + + def __post_init__(self) -> None: + # Coerce numeric types (accept numpy scalars cleanly) and bounds-check. + self.distance_tol = float(self.distance_tol) + self.min_edge_length = float(self.min_edge_length) + self.min_edge_length_ratio = float(self.min_edge_length_ratio) + self.min_edge_coverage = float(self.min_edge_coverage) + self.candidate_min_iou = float(self.candidate_min_iou) + self.close_radius = int(self.close_radius) + if self.distance_tol < 0: + raise ValueError(f"distance_tol must be >= 0, got {self.distance_tol}.") + if self.min_edge_length < 0: + raise ValueError(f"min_edge_length must be >= 0, got {self.min_edge_length}.") + if not 0.0 <= self.min_edge_length_ratio <= 1.0: + raise ValueError(f"min_edge_length_ratio must be in [0, 1], got {self.min_edge_length_ratio}.") + if not 0.0 <= self.min_edge_coverage <= 1.0: + raise ValueError(f"min_edge_coverage must be in [0, 1], got {self.min_edge_coverage}.") + if not 0.0 <= self.candidate_min_iou <= 1.0: + raise ValueError(f"candidate_min_iou must be in [0, 1], got {self.candidate_min_iou}.") + if self.close_radius < 0: + raise ValueError(f"close_radius must be >= 0, got {self.close_radius}.") + + +def _resolve_stitch_params(stitch_params: StitchParams | Mapping[str, Any] | None) -> StitchParams: + """Normalise the ``stitch_params`` 
argument to a :class:`StitchParams` instance.""" + return resolve_params(stitch_params, StitchParams, label="stitch_params") + + +_METHOD_KEY = "tiling_stitch" +_STITCH_DEFAULTS = StitchParams() + +# Contract between calculate_tiling_qc and assign_stitch_groups. _STITCH_COLUMNS +# is the obs columns stitch writes back into the QC table; _STITCH_PARAM_KEYS +# is the subset of top-level kwargs valid for re-running assign_stitch_groups +# (the advanced tuning lives in a nested ``stitch_params`` dict). +_STITCH_COLUMNS = ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence") +_STITCH_PARAM_KEYS = frozenset({"min_confidence", "max_gap", "max_group_size"}) + + +# Dataclasses + + +@dataclass(frozen=True) +class _CutEdge: + """A candidate cut edge on a single cell's bbox. + + Attributes + ---------- + cell_id + Label ID of the piece carrying this edge. + axis + ``"h"`` (horizontal cut: edge is a horizontal line, cell sits above + or below it) or ``"v"`` (vertical cut). + coord + Position of the cut line: y-coord for ``"h"``, x-coord for ``"v"``. + extent + ``(min, max)`` along the parallel axis -- the chord at the cut line. + normal_dir + ``+1`` if the cell's centroid sits at greater coord than the cut + line, ``-1`` otherwise. Used to enforce facing pairs. + length + Euclidean length of the run (``extent[1] - extent[0]``). + """ + + cell_id: int + axis: str + coord: float + extent: tuple[float, float] + normal_dir: int + length: float + + +@dataclass(frozen=True) +class _StitchPair: + """A scored candidate pairing of two cut edges across a tile boundary. + + ``confidence`` is the flat mean of the geometric features (see + :data:`_SCORE_FEATURES`); the individual feature components are kept for + diagnostics and for the ``min``-based group-confidence aggregation. 
+ """ + + cell_a: int + cell_b: int + axis: str + confidence: float + iou: float + endpoint_match: float + gap_proximity: float + merge_solidity: float + merge_compactness: float + edge_a: _CutEdge | None = field(default=None, repr=False) + edge_b: _CutEdge | None = field(default=None, repr=False) + + +# Cut-edge extraction + + +def _read_bbox_slice(labels_da: xr.DataArray | np.ndarray, y0: int, y1: int, x0: int, x1: int) -> np.ndarray: + """Read a 2-D bbox slice from numpy or xarray, squeezing singleton dims.""" + if isinstance(labels_da, np.ndarray): + return labels_da[y0:y1, x0:x1] + arr = labels_da.isel(y=slice(y0, y1), x=slice(x0, x1)).values + while arr.ndim > 2: + arr = arr.squeeze(0) + return arr + + +def _compute_outlier_bboxes( + labels_da: xr.DataArray | np.ndarray, + outlier_ids: Iterable[int], + chunk_size: int = 4096, +) -> dict[int, tuple[int, int, int, int]]: + """Compute global bboxes for the outlier subset in a single chunked pass. + + Returns mapping ``label_id -> (min_row, min_col, max_row, max_col)``. + Works on numpy or dask-backed xarray; for xarray the array is read in + ``chunk_size`` x ``chunk_size`` tiles so memory is bounded. + """ + outlier_set = {int(x) for x in outlier_ids} + bboxes: dict[int, tuple[int, int, int, int]] = {} + # Single chunked pass (shared with the QC reader); only outlier labels are + # accumulated, merging bboxes across chunk boundaries for cells that span them. + # TODO: faster path -- pre-mask each chunk with np.where(np.isin(chunk, + # outlier_set), chunk, 0) before regionprops, so non-outlier cells are + # skipped instead of scanned. Worth doing if outlier fraction is < ~5%. 
def _bbox_edge_run(
    contour: np.ndarray,
    perp_axis: int,
    target: float,
    distance_tol: float = _STITCH_DEFAULTS.distance_tol,
    min_coverage: float = _STITCH_DEFAULTS.min_edge_coverage,
) -> tuple[float, float, float] | None:
    """Measure the run of contour points hugging one bbox edge line.

    Points whose ``perp_axis`` coordinate lies within ``distance_tol`` of
    ``target`` are treated as "on the edge". The run is accepted only if

    * at least three points are on the edge,
    * they span a strictly positive range along the other axis, and
    * the integer positions they occupy within that range cover at least
      ``min_coverage`` of it.

    Returns
    -------
    ``(ext_lo, ext_hi, length)`` along the parallel axis, or ``None`` when any
    of the checks above fails.
    """
    along_axis = 1 - perp_axis
    on_edge = np.abs(contour[:, perp_axis] - target) <= distance_tol
    if on_edge.sum() < 3:
        return None

    positions = contour[on_edge, along_axis]
    start = float(positions.min())
    stop = float(positions.max())
    span = stop - start
    if span <= 0:
        return None

    # One bin per integer offset from ``start``; an offset is "covered" when
    # at least one on-edge contour point falls into it.
    n_steps = max(int(np.ceil(span)), 1)
    occupied = np.zeros(n_steps + 1, dtype=bool)
    offsets = np.clip((positions - start).astype(int), 0, n_steps)
    occupied[offsets] = True

    covered_fraction = float(occupied.sum()) / (n_steps + 1)
    if covered_fraction < min_coverage:
        return None
    return start, stop, span
+ """ + outlier_list = [int(x) for x in outlier_ids] + if bboxes is None: + bboxes = _compute_outlier_bboxes(labels_da, outlier_list) + + edges: list[_CutEdge] = [] + outlier_crops: dict[int, np.ndarray] = {} + for lid in outlier_list: + bbox = bboxes.get(lid) + if bbox is None: + continue + min_r, min_c, max_r, max_c = bbox + + crop_arr = _read_bbox_slice(labels_da, min_r, max_r, min_c, max_c) + cell_mask = crop_arr == lid # boolean bbox mask; reused by the scoring pass + if not cell_mask.any(): + continue + outlier_crops[lid] = cell_mask + mask = np.pad(cell_mask.astype(np.float32), 1, mode="constant", constant_values=0) + contour = largest_contour(mask) + if contour is None: + continue + contour_global = contour.copy() + contour_global[:, 0] += min_r - 1 + contour_global[:, 1] += min_c - 1 + + # Local centroid from the mask (avoids a second regionprops call). + ys, xs = np.where(mask) + cy = float(ys.mean()) + min_r - 1 + cx = float(xs.mean()) + min_c - 1 + area = float(mask.sum()) + eq_diameter = equivalent_diameter(area) + min_len = max(min_edge_length, min_edge_length_ratio * eq_diameter) + + # find_contours places level set 0.5 outside the integer pixel boundary. 
+ bbox_targets = [ + ("h", float(min_r) - 0.5), + ("h", float(max_r) - 0.5), + ("v", float(min_c) - 0.5), + ("v", float(max_c) - 0.5), + ] + for axis, target in bbox_targets: + perp_axis = 0 if axis == "h" else 1 + run = _bbox_edge_run(contour_global, perp_axis, target, distance_tol, min_edge_coverage) + if run is None: + continue + ext_lo, ext_hi, length = run + if length < min_len: + continue + cell_coord = cy if axis == "h" else cx + normal = 1 if cell_coord > target else -1 + edges.append( + _CutEdge( + cell_id=lid, + axis=axis, + coord=target, + extent=(ext_lo, ext_hi), + normal_dir=normal, + length=float(length), + ) + ) + + return edges, outlier_crops + + +# Pair candidate enumeration + features + + +def _extent_overlap(a: tuple[float, float], b: tuple[float, float]) -> float: + return max(0.0, min(a[1], b[1]) - max(a[0], b[0])) + + +def _merge_shape_features( + cell_a: int, + cell_b: int, + bboxes: dict[int, tuple[int, int, int, int]], + outlier_crops: dict[int, np.ndarray], + close_radius: int = _STITCH_DEFAULTS.close_radius, + *, + H: int, + W: int, +) -> dict[str, float]: + """Reconstruct the union of two pieces, close the gap, and return shape stats. + + Solidity (area / convex_hull_area) and compactness (4*pi*A / P^2) drop + sharply when two unrelated cells are joined -- the union is concave at the + join. ``merge_compactness`` is typically the strongest single + discriminator between true cuts and false merges. + + The union mask is assembled in memory from the per-cell boolean crops + already collected by :func:`_extract_cut_edges`, so this never re-reads the + (possibly dask-backed) labels array -- which was the hot-loop cost, as the + old version fetched a crop once per candidate pair. 
def _pair_geometry_features(
    e: _CutEdge,
    c: _CutEdge,
    max_gap: float,
    candidate_min_iou: float = _STITCH_DEFAULTS.candidate_min_iou,
) -> dict[str, float] | None:
    """Derive the cheap, geometry-only features for one candidate edge pair.

    The pair is discarded (``None``) as soon as one of the filters fails, in
    this order: the edges must have opposite normals, must face each other,
    must overlap along the cut line, must reach ``candidate_min_iou`` in 1-D
    IoU, and must lie within ``max_gap`` of each other perpendicular to the cut.

    Returns
    -------
    ``{"iou", "endpoint_match", "gap"}`` for a surviving pair. ``gap`` is the
    raw perpendicular distance; it is turned into ``gap_proximity`` later,
    against the closing reach rather than against ``max_gap``.
    """
    # Both cells on the same side of their cut lines can never be two halves.
    if c.normal_dir == e.normal_dir:
        return None

    # Facing test: the edge whose cell lies at greater coord (+1 normal) must
    # itself sit at the greater coord, up to a small numerical tolerance.
    signed_separation = (e.coord - c.coord) * e.normal_dir
    if signed_separation < -1e-6:
        return None

    shared = _extent_overlap(e.extent, c.extent)
    if shared <= 0:
        return None

    combined = e.length + c.length - shared
    iou = shared / combined if combined > 0 else 0.0
    if iou < candidate_min_iou:
        return None

    gap = abs(e.coord - c.coord)
    if gap > max_gap:
        return None

    # How far the two chords' endpoints are from coinciding, relative to the
    # longer chord.
    lo_shift = abs(e.extent[0] - c.extent[0])
    hi_shift = abs(e.extent[1] - c.extent[1])
    endpoint_dist = lo_shift + hi_shift
    longest = max(e.length, c.length)
    if longest > 0:
        endpoint_match = max(0.0, 1.0 - endpoint_dist / longest)
    else:
        endpoint_match = 0.0

    return {
        "iou": float(iou),
        "endpoint_match": float(endpoint_match),
        "gap": float(gap),
    }
+ """ + out: list[tuple[_CutEdge, _CutEdge, dict[str, float]]] = [] + by_axis: dict[str, list[_CutEdge]] = {"h": [], "v": []} + for e in edges: + by_axis[e.axis].append(e) + + for axis_edges in by_axis.values(): + axis_edges.sort(key=lambda e: e.coord) + coords = np.array([e.coord for e in axis_edges]) + for i, e in enumerate(axis_edges): + lo = int(np.searchsorted(coords, e.coord - max_gap, side="left")) + hi = int(np.searchsorted(coords, e.coord + max_gap, side="right")) + for j in range(lo, hi): + if j <= i: + continue # symmetry: emit each unordered pair once + c = axis_edges[j] + if c.cell_id == e.cell_id: + continue + feats = _pair_geometry_features(e, c, max_gap, candidate_min_iou=candidate_min_iou) + if feats is None: + continue + out.append((e, c, feats)) + return out + + +# Scoring + + +def _gap_proximity(gap: float, close_radius: int) -> float: + """Map the raw perpendicular gap to [0, 1] against the closing reach. + + Normalised by ``2 * close_radius`` -- the scale at which morphological + closing could actually bridge the seam -- so the feature is independent of + the ``max_gap`` search radius and only reaches 0 when the gap genuinely + exceeds what closing can join. When closing is disabled (``close_radius=0``) + the feature is inactive and returns ``1.0`` rather than collapsing the score. + """ + reach = 2 * close_radius + # gap<=0 (touching/overlapping) or reach<=0 (closing disabled, close_radius=0) + # -> the feature is inactive (neutral 1.0), never a silent score cliff. + if gap <= 0 or reach <= 0: + return 1.0 + return max(0.0, 1.0 - gap / reach) + + +def _score_pair_features(features: dict[str, float]) -> float: + """Return the heuristic stitch score in [0, 1]. + + Flat (unweighted) mean of the five features in :data:`_SCORE_FEATURES`. + The score is dataset-independent and not a calibrated probability -- users + pick ``min_confidence`` based on their false-merge tolerance. 
def _max_achievable_score(known_features: dict[str, float]) -> float:
    """Best score a pair could still reach given only its cheap features.

    Every feature named in :data:`_SHAPE_FEATURES` is set to ``1.0`` and the
    result is passed through :func:`_score_pair_features`, so this bound is
    computed by the very same formula as the final score.
    """
    optimistic = dict(known_features)
    for shape_feature in _SHAPE_FEATURES:
        optimistic[shape_feature] = 1.0
    return _score_pair_features(optimistic)
+ if e.cell_id < c.cell_id: + ea, eb = e, c + else: + ea, eb = c, e + scored.append( + _StitchPair( + cell_a=ea.cell_id, + cell_b=eb.cell_id, + axis=e.axis, + confidence=confidence, + iou=feats["iou"], + endpoint_match=feats["endpoint_match"], + gap_proximity=feats["gap_proximity"], + merge_solidity=feats["merge_solidity"], + merge_compactness=feats["merge_compactness"], + edge_a=ea, + edge_b=eb, + ) + ) + + # Deduplicate to one entry per (cell_a, cell_b, axis), keeping max confidence. + by_pair: dict[tuple[int, int, str], _StitchPair] = {} + for p in scored: + k = (p.cell_a, p.cell_b, p.axis) + if k not in by_pair or by_pair[k].confidence < p.confidence: + by_pair[k] = p + return sorted(by_pair.values(), key=lambda p: (-p.confidence, p.cell_a, p.cell_b)) + + +# Group assembly (union-find + validation) + + +def _validate_group_geometry( + pairs_in_group: list[_StitchPair], + size: int, + max_gap: float, +) -> bool: + """Geometric sanity check for groups of size >= 3. + + Two cases: + + - **Corner group** (size 4, both axes present): the cut edges' endpoints + must converge near a single junction point (one ``h`` cut crossing one + ``v`` cut defines the junction). If the spread of edge extents from + the junction is greater than ``max_gap``, the group is implausible. + + - **Chain group** (size 3 or 4, all pairs share one axis): legitimate + same-axis chains (e.g., a cell split by 3 horizontal seams into 4 + vertically-stacked pieces) have pairs at N-1 *distinct* seam + coordinates. Multiple pairs at the same seam coord would imply + geometrically impossible "two cuts at the same seam" pairings -- a + signature of a false-positive cluster -- so we reject. + """ + h_pairs = [p for p in pairs_in_group if p.axis == "h"] + v_pairs = [p for p in pairs_in_group if p.axis == "v"] + + # Chain case: only one axis present and size >= 3. 
+ if not h_pairs or not v_pairs: + if size < 3: + return True # 2-piece groups are trivially valid on one axis + # Each pair's seam coord is roughly midway between its two edges. + seam_coords = [round((p.edge_a.coord + p.edge_b.coord) / 2.0, 1) for p in pairs_in_group] + # Allow a max_gap-sized tolerance for "distinct" seams. + sorted_coords = sorted(seam_coords) + for prev, cur in zip(sorted_coords, sorted_coords[1:], strict=False): + if cur - prev <= max_gap: + return False + return True + + # Mixed-axis case: only validate the 4-piece corner pattern. 3-piece + # L-shapes (one h pair + one v pair sharing a corner cell) are + # geometrically valid and don't have a junction to converge on. + if size != 4: + return True + + # Corner case: both axes present, size 4. Junction y/x is the mean of edge coords. + h_edges = [p.edge_a for p in h_pairs] + [p.edge_b for p in h_pairs] + v_edges = [p.edge_a for p in v_pairs] + [p.edge_b for p in v_pairs] + junction_y = float(np.mean([e.coord for e in h_edges])) + junction_x = float(np.mean([e.coord for e in v_edges])) + for e in h_edges: + if min(abs(e.extent[0] - junction_x), abs(e.extent[1] - junction_x)) > max_gap: + return False + for e in v_edges: + if min(abs(e.extent[0] - junction_y), abs(e.extent[1] - junction_y)) > max_gap: + return False + return True + + +def _assemble_groups( + pairs: list[_StitchPair], + candidate_ids: Iterable[int], + max_group_size: int, + max_gap: float, +) -> tuple[dict[int, int], dict[int, float]]: + """Build stitch groups via union-find with size + corner validation. + + Returns + ------- + groups + ``cell_id -> group_id`` (group_id == own cell_id for unstitched). + confidences + ``cell_id -> stitch_confidence`` -- min over pairwise confidences in + the cell's group; ``1.0`` for confirmed-solo (no surviving pair). + """ + # Build undirected connected components via scipy. Cells map to a + # contiguous [0, n) index space; pairs become symmetric edges in a CSR + # adjacency matrix. 
We then re-key components by the smallest cell_id + # they contain so the group root is deterministic. + candidate_list = sorted({int(c) for c in candidate_ids}) + if not candidate_list: + return {}, {} + id_to_idx = {cid: i for i, cid in enumerate(candidate_list)} + n = len(candidate_list) + + valid_pairs = [p for p in pairs if p.cell_a in id_to_idx and p.cell_b in id_to_idx] + if valid_pairs: + rows = [id_to_idx[p.cell_a] for p in valid_pairs] + cols = [id_to_idx[p.cell_b] for p in valid_pairs] + adj = csr_matrix((np.ones(len(rows), dtype=np.int8), (rows, cols)), shape=(n, n)) + _, comp_labels = connected_components(adj, directed=False) + else: + comp_labels = np.arange(n) + + cells_by_comp: dict[int, list[int]] = {} + for i, comp in enumerate(comp_labels): + cells_by_comp.setdefault(int(comp), []).append(candidate_list[i]) + + members: dict[int, list[int]] = {} + root_of_cell: dict[int, int] = {} + for comp_members in cells_by_comp.values(): + comp_members.sort() + root = comp_members[0] + members[root] = comp_members + for cid in comp_members: + root_of_cell[cid] = root + + pairs_by_group: dict[int, list[_StitchPair]] = {} + for p in valid_pairs: + pairs_by_group.setdefault(root_of_cell[p.cell_a], []).append(p) + + groups: dict[int, int] = {} + confidences: dict[int, float] = {} + + for root, mem in members.items(): + size = len(mem) + group_pairs = pairs_by_group.get(root, []) + + # Size cap: collapse oversized groups back to singletons. + if size > max_group_size: + for m in mem: + groups[m] = m + confidences[m] = 1.0 + continue + + # Geometric validation for 3+ piece groups: corner-junction for + # mixed-axis 4-groups, chain (distinct seam coords) for same-axis 3+. + if size >= 3 and not _validate_group_geometry(group_pairs, size, max_gap): + for m in mem: + groups[m] = m + confidences[m] = 1.0 + continue + + if size == 1: + groups[mem[0]] = mem[0] + confidences[mem[0]] = 1.0 + continue + + # Group confidence = min over pairwise confidences (weakest link). 
+ group_conf = float(min(p.confidence for p in group_pairs)) + for m in mem: + groups[m] = root + confidences[m] = group_conf + + return groups, confidences + + +# Public entry point + + +def assign_stitch_groups( + sdata: sd.SpatialData, + labels_key: str, + qc_table_key: str | None = None, + min_confidence: float = 0.7, + max_gap: float = 3.0, + max_group_size: int = 4, + stitch_params: StitchParams | Mapping[str, Any] | None = None, + inplace: bool = True, +) -> ad.AnnData | None: + """Assign tile-cut cell pieces to stitch groups. + + Reads ``is_outlier=True`` cells flagged by + :func:`~squidpy.experimental.tl.calculate_tiling_qc`, pairs facing cut + edges across tile boundaries, scores each pair via a transparent geometric + composite, and assembles high-confidence pairs into stitch groups via + union-find. This only *annotates* which pieces belong together -- it does + **not** modify the labels element. Materialising a stitched labels element + is opt-in via :func:`!make_stitched_labels`. + + The score per pair is the flat (unweighted) mean of five geometric features + in [0, 1]: ``iou`` (1-D extent overlap), ``endpoint_match`` (chord endpoints + coincide), ``merge_compactness`` (``4*pi*A / P^2`` of the closed union mask), + ``merge_solidity`` (union area / convex hull area), and ``gap_proximity`` + (seam gap relative to the morphological closing reach). No coefficients are + fitted or shipped; the features are recorded in ``.uns["tiling_stitch"]``. + + Parameters + ---------- + sdata + :class:`~spatialdata.SpatialData` with a labels element and a QC + table from :func:`~squidpy.experimental.tl.calculate_tiling_qc`. + labels_key + Key in ``sdata.labels``. + qc_table_key + Key of the QC table. Defaults to ``"{labels_key}_qc"``. + min_confidence + Threshold on ``stitch_confidence``. ``0.7`` (default) is a starting + point; raise it for stricter precision, lower for recall. Tune for + your data -- the score is heuristic, not a calibrated probability. 
+ max_gap + Maximum perpendicular distance (px) between facing cut edges for a pair + to be *considered* a candidate. This is a search radius only; it does + not scale the score. + max_group_size + Cap on group size; oversized groups (likely false merges) collapse + to singletons. + stitch_params + Advanced tuning knobs as a :class:`StitchParams` instance or a + ``Mapping`` of its field names to values. See :class:`StitchParams` + for each field's meaning and default. ``None`` (default) uses + all defaults. + inplace + If ``True``, write back into ``sdata.tables[qc_table_key]``. + Otherwise return the modified AnnData. + + Returns + ------- + The QC :class:`~anndata.AnnData` with four new ``.obs`` columns when + ``inplace=False``, otherwise ``None``. + """ + if labels_key not in sdata.labels: + raise ValueError(f"Labels key '{labels_key}' not found in sdata.labels.") + if min_confidence < 0 or min_confidence > 1: + raise ValueError(f"min_confidence must be in [0, 1], got {min_confidence}.") + if max_gap < 0: + raise ValueError(f"max_gap must be non-negative, got {max_gap}.") + if max_group_size < 1: + raise ValueError(f"max_group_size must be >= 1, got {max_group_size}.") + params = _resolve_stitch_params(stitch_params) + + table_key = qc_table_key if qc_table_key is not None else f"{labels_key}_qc" + if table_key not in sdata.tables: + raise ValueError(f"QC table '{table_key}' not found. 
Run calculate_tiling_qc first.") + adata = sdata.tables[table_key].copy() + + if "is_outlier" not in adata.obs.columns: + raise ValueError(f"QC table '{table_key}' is missing 'is_outlier'; re-run calculate_tiling_qc.") + if "label_id" not in adata.obs.columns: + raise ValueError(f"QC table '{table_key}' is missing 'label_id'.") + + existing = [c for c in _STITCH_COLUMNS if c in adata.obs.columns] + if existing: + logg.warning(f"Overwriting existing stitch columns: {existing}.") + adata.obs.drop(columns=existing, inplace=True) + + # Resolve which labels DataArray was used at QC time (multi-scale aware). + qc_params = adata.uns.get("tiling_qc", {}) + scale = qc_params.get("scale") + labels_da = resolve_labels_array(sdata, labels_key, scale) + + label_ids = adata.obs["label_id"].astype(int).to_numpy() + is_outlier = adata.obs["is_outlier"].to_numpy(dtype=bool) + outlier_ids = label_ids[is_outlier].tolist() + + n_outliers = len(outlier_ids) + logg.info(f"Stitching {n_outliers} outlier cells (out of {len(label_ids)} total).") + + if n_outliers == 0: + logg.warning("No outliers flagged; nothing to stitch.") + groups: dict[int, int] = {} + confidences: dict[int, float] = {} + edges: list[_CutEdge] = [] + pairs: list[_StitchPair] = [] + else: + bboxes = _compute_outlier_bboxes(labels_da, outlier_ids) + missing = [lid for lid in outlier_ids if lid not in bboxes] + if missing: + logg.warning( + f"{len(missing)} outlier label_id(s) flagged in the QC table do not appear " + f"in '{labels_key}' (e.g. {missing[:5]}); they will not be stitched." 
+ ) + edges, outlier_crops = _extract_cut_edges( + labels_da, + outlier_ids, + bboxes=bboxes, + distance_tol=params.distance_tol, + min_edge_length=params.min_edge_length, + min_edge_length_ratio=params.min_edge_length_ratio, + min_edge_coverage=params.min_edge_coverage, + ) + H, W = labels_da.shape[-2], labels_da.shape[-1] + candidates = _enumerate_pair_candidates(edges, max_gap=max_gap, candidate_min_iou=params.candidate_min_iou) + pairs = _score_pairs( + candidates, bboxes, outlier_crops, min_confidence, close_radius=params.close_radius, H=H, W=W + ) + groups, confidences = _assemble_groups(pairs, outlier_ids, max_group_size=max_group_size, max_gap=max_gap) + + # Write .obs columns with three states distinguished by stitch_confidence: + # - non-outlier cell -> own label_id, False, 1, NaN (not evaluated) + # - outlier solo -> own label_id, False, 1, 1.0 (checked, no partner) + # - outlier stitched -> shared root, True, n, composite score + n = len(label_ids) + stitch_group_id = label_ids.copy() + is_stitched = np.zeros(n, dtype=bool) + n_pieces = np.ones(n, dtype=np.int32) + stitch_confidence = np.full(n, np.nan, dtype=np.float64) + + group_sizes: dict[int, int] = {} + if outlier_ids: + for root in groups.values(): + group_sizes[root] = group_sizes.get(root, 0) + 1 + + id_to_idx = {int(lid): i for i, lid in enumerate(label_ids)} + for cid, root in groups.items(): + i = id_to_idx[int(cid)] + stitch_group_id[i] = int(root) + size = group_sizes[root] + n_pieces[i] = size + is_stitched[i] = size > 1 + stitch_confidence[i] = float(confidences.get(cid, 1.0)) + + adata.obs["stitch_group_id"] = stitch_group_id + adata.obs["is_stitched"] = is_stitched + adata.obs["n_pieces"] = n_pieces + adata.obs["stitch_confidence"] = stitch_confidence + + n_groups = sum(1 for s in group_sizes.values() if s > 1) + n_stitched = int(is_stitched.sum()) + # Use string keys so the dict round-trips through zarr-backed .uns cleanly. 
+ pieces_dist: dict[str, int] = {} + for s in group_sizes.values(): + if s > 1: + key = str(int(s)) + pieces_dist[key] = pieces_dist.get(key, 0) + 1 + + adata.uns[_METHOD_KEY] = { + "min_confidence": float(min_confidence), + "max_gap": float(max_gap), + "max_group_size": int(max_group_size), + "stitch_params": asdict(params), + "n_outliers": int(n_outliers), + "n_candidate_pairs": int(len(pairs)), + "n_stitched_groups": int(n_groups), + "n_stitched_cells": int(n_stitched), + "n_pieces_distribution": pieces_dist, + "score_features": list(_SCORE_FEATURES), + } + + if not inplace: + return adata + sdata.tables[table_key] = adata + return None diff --git a/src/squidpy/experimental/utils/_geometry.py b/src/squidpy/experimental/utils/_geometry.py new file mode 100644 index 000000000..4f6c152c2 --- /dev/null +++ b/src/squidpy/experimental/utils/_geometry.py @@ -0,0 +1,30 @@ +"""Shared internal geometry helpers for mask/contour analysis. + +Not part of the public API - symbols here are private and may change +without notice. +""" + +from __future__ import annotations + +import numpy as np +from skimage.measure import find_contours + + +def equivalent_diameter(area: float) -> float: + """Diameter of the circle with the given area: ``sqrt(4 * area / pi)``.""" + return float(np.sqrt(4 * area / np.pi)) + + +def largest_contour(padded_mask: np.ndarray, level: float = 0.5) -> np.ndarray | None: + """Return the longest :func:`skimage.measure.find_contours` contour, or ``None``. + + The mask must be **already 1px zero-padded** by the caller so that cells + touching the crop edge (e.g. filling their bbox) are traced closed. Padding + is left to the caller because its placement relative to other steps (e.g. + downsampling) is order-sensitive and differs between call sites. Returned + coordinates are in the padded mask's frame. 
+ """ + contours = find_contours(padded_mask, level) + if not contours: + return None + return max(contours, key=len) diff --git a/src/squidpy/experimental/utils/_labels.py b/src/squidpy/experimental/utils/_labels.py index 5d18e8370..616f64358 100644 --- a/src/squidpy/experimental/utils/_labels.py +++ b/src/squidpy/experimental/utils/_labels.py @@ -6,11 +6,56 @@ from __future__ import annotations +from collections.abc import Iterable, Iterator +from typing import Any + +import numpy as np import spatialdata as sd import xarray as xr +from skimage.measure import regionprops from spatialdata._logging import logger as logg +def iter_chunked_regionprops( + labels: xr.DataArray | np.ndarray, + chunk_size: int = 4096, + label_subset: Iterable[int] | None = None, +) -> Iterator[tuple[int, Any, int, int]]: + """Yield ``(label_id, region, y0, x0)`` over chunked ``regionprops`` of a labels array. + + Works on a plain :class:`numpy.ndarray` (a single chunk) or a possibly + dask-backed 2-D :class:`xarray.DataArray`, reading at most ``chunk_size`` x + ``chunk_size`` at a time so memory stays bounded for very large images. + + ``region`` is a :class:`skimage.measure.RegionProperties` whose coordinates + are LOCAL to the chunk; add ``y0`` / ``x0`` for global coordinates. When + ``label_subset`` is given, only regions with those label ids are yielded. + Background (label 0) is never yielded (``regionprops`` skips it). 
def resolve_params(value: _T | Mapping[str, Any] | None, cls: type[_T], *, label: str) -> _T:
    """Normalise a params argument (``None`` / instance / ``Mapping``) to a ``cls`` instance.

    Parameters
    ----------
    value
        ``None`` (use defaults), an instance of ``cls`` (passed through by
        identity), or a ``Mapping`` of field names to values.
    cls
        The params dataclass to construct.
    label
        The user-facing argument name used verbatim in error messages. Include
        backticks if the caller's convention uses them (e.g. ``"`tiling_qc_params`"``).

    Returns
    -------
    ``cls()`` for ``None``, ``value`` itself when it already is a ``cls``
    instance, otherwise ``cls(**value)``.

    Raises
    ------
    ValueError
        If a ``Mapping`` contains a key that is not a constructor field of
        ``cls`` (this includes fields declared with ``init=False``).
    TypeError
        If ``value`` is neither ``None``, a ``cls`` instance, nor a ``Mapping``.
    """
    if value is None:
        return cls()
    if isinstance(value, cls):
        return value
    if not isinstance(value, Mapping):
        raise TypeError(f"{label} must be {cls.__name__}, Mapping, or None; got {type(value).__name__}.")

    # Only fields accepted by the generated ``__init__`` can be set from a
    # mapping; an ``init=False`` field would otherwise pass validation and then
    # blow up inside ``cls(**value)`` with an unrelated-looking TypeError.
    valid = {f.name for f in fields(cls) if f.init}
    unknown = set(value) - valid
    if unknown:
        # ``key=str`` leaves the ordering of ordinary string keys unchanged but
        # keeps the error path from crashing on mixed, non-comparable key types.
        raise ValueError(
            f"Unknown {label} field(s): {sorted(unknown, key=str)}; expected from {sorted(valid)}."
        )
    return cls(**value)
--------------------------------------------------------------------------- # Fixtures -# --------------------------------------------------------------------------- @pytest.fixture(params=[10, 0], ids=["gap=10", "gap=0"]) @@ -137,9 +133,7 @@ def brick_image(): return _make_image() -# --------------------------------------------------------------------------- # build_tile_specs - deterministic checks -# --------------------------------------------------------------------------- class TestBuildTileSpecs: @@ -264,9 +258,7 @@ def test_tile_size_larger_than_image(self): assert len(specs) == 1 -# --------------------------------------------------------------------------- # extract_tile -# --------------------------------------------------------------------------- class TestExtractTile: @@ -314,9 +306,7 @@ def test_image_crop_shape(self, brick_labels, brick_image): assert tile_lbl.shape == (cy1 - cy0, cx1 - cx0) -# --------------------------------------------------------------------------- # End-to-end roundtrip -# --------------------------------------------------------------------------- class TestEndToEnd: @@ -341,9 +331,7 @@ def test_roundtrip_no_cells_lost(self, brick_labels, brick_image): # test_roundtrip_no_cells_lost via the brick_labels fixture's parametrisation. 
-# --------------------------------------------------------------------------- # Visual test - tile assignment plot -# --------------------------------------------------------------------------- # Tile colors: one distinct color per tile quadrant _TILE_COLORS = [ @@ -388,9 +376,7 @@ def _plot_tile_assignment(labels, specs, title=""): ax.set_ylabel("y") -# --------------------------------------------------------------------------- # Lazy / multiscale helpers -# --------------------------------------------------------------------------- def _make_multiscale_tree(labels: np.ndarray, n_scales: int = 3) -> xr.DataTree: diff --git a/tests/experimental/test_tiling_qc.py b/tests/experimental/test_tiling_qc.py index 5413bb2ac..f33a96e59 100644 --- a/tests/experimental/test_tiling_qc.py +++ b/tests/experimental/test_tiling_qc.py @@ -11,9 +11,7 @@ from squidpy.experimental.im._tiling import compute_cell_info, compute_cell_info_tiled from tests.conftest import PlotTester, PlotTesterMeta -# --------------------------------------------------------------------------- # Core behavioural tests -# --------------------------------------------------------------------------- class TestCalculateTilingQC: @@ -222,9 +220,7 @@ def test_smoothed_only_gate(self, sdata_tile_boundary): assert adata.obs["is_outlier"].dtype == bool -# --------------------------------------------------------------------------- # Params resolution -# --------------------------------------------------------------------------- class TestTilingQCParamsResolution: @@ -280,9 +276,7 @@ def test_wrong_type_raises_type_error(self): _resolve_qc_params(42) -# --------------------------------------------------------------------------- # resolve_labels_array helper -# --------------------------------------------------------------------------- class TestResolveLabelsArray: @@ -310,9 +304,7 @@ def test_multi_scale_without_scale_raises(self): resolve_labels_array(sdata, "labels", scale=None) -# 
def _run_qc_and_stitch(sdata, **stitch_kwargs):
    """Run tiling QC followed by stitch-group assignment and return the QC table.

    QC is run on the ``"labels"`` element with ``tile_size=200``,
    ``nmads_cut=1.0`` and ``nmads_smoothed=1.5``; ``stitch_kwargs`` are
    forwarded to ``assign_stitch_groups``. Both calls write into ``sdata``,
    and the table stored under ``"labels_qc"`` is returned.
    """
    sq.experimental.tl.calculate_tiling_qc(sdata, labels_key="labels", tile_size=200, nmads_cut=1.0, nmads_smoothed=1.5)
    sq.experimental.tl.assign_stitch_groups(sdata, labels_key="labels", **stitch_kwargs)
    return sdata.tables["labels_qc"]


class TestAssignStitchGroups:
    """Tests for sq.experimental.tl.assign_stitch_groups using the tile-boundary fixture."""

    def test_columns_present(self, sdata_tile_boundary):
        """All four stitch columns are added to ``.obs`` of the QC table."""
        sdata, _ = sdata_tile_boundary
        adata = _run_qc_and_stitch(sdata)
        for col in ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence"):
            assert col in adata.obs.columns

    def test_confidence_convention(self, sdata_tile_boundary):
        """The three cell states are distinguishable through ``stitch_confidence``."""
        # NaN = not evaluated (non-outlier), 1.0 = solo outlier, composite = stitched.
        sdata, _ = sdata_tile_boundary
        obs = _run_qc_and_stitch(sdata, min_confidence=0.5).obs

        # Non-outliers are never evaluated: NaN confidence, own id as group, one piece.
        non_outliers = ~obs["is_outlier"].astype(bool)
        assert non_outliers.sum() > 0
        assert obs.loc[non_outliers, "stitch_confidence"].isna().all()
        assert (obs.loc[non_outliers, "stitch_group_id"] == obs.loc[non_outliers, "label_id"]).all()
        assert (obs.loc[non_outliers, "n_pieces"] == 1).all()

        # Outliers that found no partner; only checked when the fixture produces any.
        solo = obs["is_outlier"].astype(bool) & ~obs["is_stitched"].astype(bool)
        if solo.sum() > 0:
            assert (obs.loc[solo, "stitch_confidence"] == 1.0).all()

        # Stitched cells carry a score at or above the threshold used for this
        # run (0.5); group sizes are expected to lie in 2..4.
        stitched = obs["is_stitched"].astype(bool)
        if stitched.sum() > 0:
            confs = obs.loc[stitched, "stitch_confidence"]
            assert ((confs >= 0.5) & (confs <= 1.0)).all()
            assert obs.loc[stitched, "n_pieces"].between(2, 4).all()

    def test_group_id_shared_within_group(self, sdata_tile_boundary):
        """The number of rows sharing a ``stitch_group_id`` equals the group's ``n_pieces``."""
        sdata, _ = sdata_tile_boundary
        adata = _run_qc_and_stitch(sdata, min_confidence=0.5)
        stitched = adata.obs[adata.obs["is_stitched"].astype(bool)]
        for _gid, members in stitched.groupby("stitch_group_id"):
            assert len(members) == members["n_pieces"].iloc[0]

    def test_stitched_group_is_made_of_cut_pieces(self, sdata_tile_boundary):
        """At least one stitched group has 2+ distinct labels, all drawn from ``gt.cut_cell_ids``."""
        sdata, gt = sdata_tile_boundary
        adata = _run_qc_and_stitch(sdata, min_confidence=0.5)
        stitched = adata.obs[adata.obs["is_stitched"].astype(bool)]
        found = any(
            len(set(m["label_id"].astype(int))) >= 2 and set(m["label_id"].astype(int)) <= set(gt.cut_cell_ids)
            for _gid, m in stitched.groupby("stitch_group_id")
        )
        assert found

    def test_no_intact_cells_stitched_at_high_threshold(self, sdata_tile_boundary):
        """With ``min_confidence=0.9`` few cells from ``gt.intact_cell_ids`` end up stitched.

        NOTE(review): despite the test name, up to 5 false positives are tolerated.
        """
        sdata, gt = sdata_tile_boundary
        adata = _run_qc_and_stitch(sdata, min_confidence=0.9)
        intact = adata.obs["label_id"].isin(gt.intact_cell_ids)
        n_false = int((intact & adata.obs["is_stitched"].astype(bool)).sum())
        assert n_false <= 5

    def test_uns_records_params_and_features(self, sdata_tile_boundary):
        """``.uns["tiling_stitch"]`` echoes the call parameters and names the five score features."""
        sdata, _ = sdata_tile_boundary
        meta = _run_qc_and_stitch(sdata, min_confidence=0.7, max_gap=4.0).uns["tiling_stitch"]
        assert meta["min_confidence"] == 0.7
        assert meta["max_gap"] == 4.0
        assert isinstance(meta["stitch_params"], dict)
        # No fitted-model artefacts are expected in the metadata.
        assert "model_coefficients" not in meta and "model_intercept" not in meta
        assert set(meta["score_features"]) == {
            "iou",
            "endpoint_match",
            "merge_compactness",
            "merge_solidity",
            "gap_proximity",
        }

    @pytest.mark.parametrize(
        ("kwargs", "match"),
        [
            ({"labels_key": "labels"}, "QC table"),
            ({"labels_key": "bogus"}, "not found in sdata.labels"),
            ({"labels_key": "labels", "min_confidence": 1.5}, "min_confidence"),
        ],
        ids=["missing_qc_table", "missing_labels_key", "invalid_min_confidence"],
    )
    def test_invalid_input_raises(self, sdata_tile_boundary, kwargs, match):
        """Invalid arguments raise ``ValueError`` with a message matching ``match``.

        QC is intentionally not run first, so the ``"labels"``-only case fails
        on the missing QC table.
        """
        sdata, _ = sdata_tile_boundary
        with pytest.raises(ValueError, match=match):
            sq.experimental.tl.assign_stitch_groups(sdata, **kwargs)

    def test_rerun_overwrites_without_growing_columns(self, sdata_tile_boundary):
        """Calling ``assign_stitch_groups`` a second time leaves the ``.obs`` column count unchanged."""
        sdata, _ = sdata_tile_boundary
        _run_qc_and_stitch(sdata)
        n_before = len(sdata.tables["labels_qc"].obs.columns)
        sq.experimental.tl.assign_stitch_groups(sdata, labels_key="labels")
        assert len(sdata.tables["labels_qc"].obs.columns) == n_before

    def test_inplace_false_returns_without_writing(self, sdata_tile_boundary):
        """``inplace=False`` returns an annotated table and adds no columns to the stored one."""
        sdata, _ = sdata_tile_boundary
        sq.experimental.tl.calculate_tiling_qc(sdata, labels_key="labels", tile_size=200)
        n_before = len(sdata.tables["labels_qc"].obs.columns)
        result = sq.experimental.tl.assign_stitch_groups(sdata, labels_key="labels", inplace=False)
        assert result is not None and "stitch_group_id" in result.obs.columns
        assert len(sdata.tables["labels_qc"].obs.columns) == n_before

    def test_qc_rerun_removes_stitch_columns(self, sdata_tile_boundary):
        """Re-running ``calculate_tiling_qc`` after stitching leaves none of the stitch columns behind."""
        sdata, _ = sdata_tile_boundary
        _run_qc_and_stitch(sdata)
        sq.experimental.tl.calculate_tiling_qc(sdata, labels_key="labels", tile_size=200)
        for col in ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence"):
            assert col not in sdata.tables["labels_qc"].obs.columns

    def test_runs_on_multiscale(self):
        """Stitching works on a multi-scale, dask-backed labels element.

        QC is run with ``scale="scale0"``; ``assign_stitch_groups`` is then
        called without any scale argument.
        """
        from tests.experimental.conftest import make_tile_boundary_sdata

        base, _ = make_tile_boundary_sdata()
        arr = np.asarray(base.labels["labels"].values)
        # Rebuild the fixture labels as a chunked, two-level multiscale element.
        ms = Labels2DModel.parse(
            xr.DataArray(da.from_array(arr, chunks=(200, 200)), dims=("y", "x")), scale_factors=[2]
        )
        sdata = SpatialData(images={"image": base.images["image"]}, labels={"labels": ms})
        sq.experimental.tl.calculate_tiling_qc(
            sdata, labels_key="labels", scale="scale0", tile_size=200, nmads_cut=1.0, nmads_smoothed=1.5
        )
        sq.experimental.tl.assign_stitch_groups(sdata, labels_key="labels")
        for col in ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence"):
            assert col in sdata.tables["labels_qc"].obs.columns

    def test_obs_and_uns_survive_zarr_roundtrip(self, sdata_tile_boundary, tmp_path):
        """Stitch columns and the ``"tiling_stitch"`` metadata are still present after write + ``read_zarr``."""
        from spatialdata import read_zarr

        sdata, _ = sdata_tile_boundary
        _run_qc_and_stitch(sdata, min_confidence=0.5)
        sdata.write(tmp_path / "roundtrip.zarr")
        a2 = read_zarr(tmp_path / "roundtrip.zarr").tables["labels_qc"]
        for col in ("stitch_group_id", "is_stitched", "n_pieces", "stitch_confidence"):
            assert col in a2.obs.columns
        assert "tiling_stitch" in a2.uns


class TestStitchVisual(PlotTester, metaclass=PlotTesterMeta):
    """Visual regression test: labels recoloured by stitch group around one seam.

    NOTE(review): the figure is presumably compared by ``PlotTester`` against
    ``tests/_images/StitchVisual_seam_group_recolor.png`` -- confirm against
    ``tests/conftest.py``.
    """

    # Crop window, unpacked below as (y0, y1, x0, x1).
    _ZOOM = (150, 250, 250, 350)
    # Row of the seam marked in both panels, in full-image coordinates.
    _SEAM_Y = 200

    def test_plot_seam_group_recolor(self, sdata_tile_boundary):
        """Render a before/after pair: coloured by ``label_id`` vs. by ``stitch_group_id``."""
        sdata, _ = sdata_tile_boundary
        sq.experimental.tl.calculate_tiling_qc(
            sdata, labels_key="labels", tile_size=200, nmads_cut=1.0, nmads_smoothed=1.5
        )
        sq.experimental.tl.assign_stitch_groups(sdata, labels_key="labels", min_confidence=0.5)
        adata = sdata.tables["labels_qc"]

        labels = np.asarray(sdata.labels["labels"].values)
        # Identity lookup table, then redirect every label in the QC table to its group id.
        lut = np.arange(int(labels.max()) + 1)
        lut[adata.obs["label_id"].astype(int).to_numpy()] = adata.obs["stitch_group_id"].astype(int).to_numpy()
        regrouped = lut[labels]

        # Fixed seed so the per-label colours are reproducible; index 0 is drawn black.
        rng = np.random.default_rng(0)
        colors = rng.random((int(labels.max()) + 1, 3))
        colors[0] = 0.0

        y0, y1, x0, x1 = self._ZOOM
        before = colors[labels][y0:y1, x0:x1]  # coloured by label_id (cut pieces differ)
        after = colors[regrouped][y0:y1, x0:x1]  # coloured by stitch_group_id (pieces share a colour)
        seam = self._SEAM_Y - y0  # seam row translated into crop coordinates
        for panel in (before, after):
            panel[seam, ::4] = 1.0  # dashed seam marker, drawn into the array (no mpl line AA)
        sep = np.ones((before.shape[0], 4, 3))  # white column between the two panels
        combined = np.concatenate([before, sep, after], axis=1)

        # Render 1:1 (figsize * DPI == array shape) on a full-figure axis. No
        # upscaling -> no nearest-neighbour resampling, no text, no line AA, so the
        # PNG is pixel-identical across platforms/matplotlib versions (the earlier
        # tight_layout + upscaled imshow drifted by RMS ~53/28 between Linux/macOS).
        h, w = combined.shape[:2]
        fig = plt.figure(figsize=(w / DPI, h / DPI))
        ax = fig.add_axes((0, 0, 1, 1))
        ax.imshow(combined, interpolation="nearest")
        ax.set_axis_off()
--------------------------------------------------------------------------- # assert_in_range -# --------------------------------------------------------------------------- class TestAssertInRange: def test_in_range(self): assert_in_range(0.5, 0, 1, name="x") @@ -66,9 +60,7 @@ def test_out_of_range(self): assert_in_range(-0.1, 0, 1, name="x") -# --------------------------------------------------------------------------- # assert_non_empty_sequence -# --------------------------------------------------------------------------- class TestAssertNonEmptySequence: def test_list(self): assert assert_non_empty_sequence(["a", "b"], name="items") == ["a", "b"] @@ -85,9 +77,7 @@ def test_empty_raises(self): assert_non_empty_sequence([], name="items") -# --------------------------------------------------------------------------- # get_valid_values -# --------------------------------------------------------------------------- class TestGetValidValues: def test_valid(self): assert get_valid_values(["a", "b"], ["a", "b", "c"]) == ["a", "b"] @@ -100,9 +90,7 @@ def test_none_valid(self): get_valid_values(["z"], ["a", "b"]) -# --------------------------------------------------------------------------- # check_tuple_needles -# --------------------------------------------------------------------------- class TestCheckTupleNeedles: def test_valid_needles(self): result = check_tuple_needles([("a", "b")], ["a", "b", "c"], "Value `{}` not found.") @@ -125,9 +113,7 @@ def test_not_sequence(self): check_tuple_needles([42], ["a"], "msg {}") -# --------------------------------------------------------------------------- # assert_isinstance -# --------------------------------------------------------------------------- class TestAssertIsinstance: def test_correct_type(self): assert_isinstance("hello", str, name="x") @@ -146,9 +132,7 @@ def test_wrong_type_tuple(self): assert_isinstance(3.14, (str, int), name="x") -# --------------------------------------------------------------------------- # 
assert_one_of -# --------------------------------------------------------------------------- class TestAssertOneOf: def test_valid(self): assert_one_of("a", ["a", "b", "c"], name="x") @@ -158,9 +142,7 @@ def test_invalid(self): assert_one_of("z", ["a", "b"], name="x") -# --------------------------------------------------------------------------- # assert_key_in_adata -# --------------------------------------------------------------------------- class TestAssertKeyInAdata: def test_key_present(self): adata = MagicMock() @@ -193,9 +175,7 @@ def test_container_without_keys_method(self): assert_key_in_adata(adata, "X_spatial", attr="obsm") -# --------------------------------------------------------------------------- # assert_key_in_sdata -# --------------------------------------------------------------------------- class TestAssertKeyInSdata: def test_key_present(self): sdata = MagicMock() @@ -221,9 +201,7 @@ def test_lists_available_keys(self): assert_key_in_sdata(sdata, "missing", attr="images") -# --------------------------------------------------------------------------- # assert_isinstance edge cases -# --------------------------------------------------------------------------- class TestAssertIsinstanceEdgeCases: def test_bool_is_subclass_of_int(self): """bool is a subclass of int — assert_isinstance(True, int) passes."""