From 64725d215dcc4be3668afb698a405540361a3ceb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Thu, 4 Jun 2026 18:54:33 +0200 Subject: [PATCH 1/2] Add tissue-linear: 7-class tissue classifier on Virchow2 embeddings Deploys a 7-class tissue classifier as a linear head over the Virchow2 foundation model. Per tile: apply Virchow2's image transform, fetch the ViT token sequence from the deployed `virchow2` service via a Ray Serve handle, pool tokens (class token + mean of patch tokens) into a 2560-d embedding, run the ONNX linear head, and emit a 7-channel softmax probability map for HeatmapBuilder. The hard class map is recoverable via argmax over channels at full resolution. The ONNX linear head is exported from the Virchow2 + LBFGS final linear classifier (MLflow run 0e2230c722134ce0985e09a18ccadf75, artifacts/onnx/linear_head.onnx). Files: - models/tissue_linear.py: Serve deployment. torch/PIL/timm imports are lazy (the head node builds the app graph without them; the replica runs on GPU workers that carry them). ONNX runs on CPUExecutionProvider. - helm/rayservice/applications/tissue-linear.yaml: app definition (num_gpus: 1 to land on the mig20 GPU workers for torch/timm). - helm/rayservice/values.yaml: register tissue-linear. Validated on the full WSI 07 Leiomyosarkom.svs via HeatmapBuilder, producing a (52224, 36864, 7) BigTIFF; argmax over channels yields Other (neoplastic) and Connective-Tissue (stroma) dominant, consistent with a leiomyosarcoma. Co-Authored-By: Claude Opus 4.7 --- .../applications/tissue-linear.yaml | 34 +++ helm/rayservice/values.yaml | 2 + models/tissue_linear.py | 218 ++++++++++++++++++ 3 files changed, 254 insertions(+) create mode 100644 helm/rayservice/applications/tissue-linear.yaml create mode 100644 models/tissue_linear.py diff --git a/helm/rayservice/applications/tissue-linear.yaml b/helm/rayservice/applications/tissue-linear.yaml new file mode 100644 index 0000000..475302b --- /dev/null +++ b/helm/rayservice/applications/tissue-linear.yaml @@ -0,0 +1,34 @@ +- name: tissue-linear + import_path: models.tissue_linear:app + route_prefix: /tissue-linear + runtime_env: + config: + setup_timeout_seconds: 1800 + working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip + deployments: + - name: TissueLinear + max_ongoing_requests: 512 + max_queued_requests: 4096 + autoscaling_config: + min_replicas: 0 + max_replicas: 2 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 2 + num_gpus: 1 + memory: 12884901888 + runtime_env: + env_vars: + MLFLOW_TRACKING_URI: http://mlflow.rationai-mlflow:5000 + HF_HOME: /mnt/huggingface_cache + user_config: + tile_size: 224 + output_tile_size: 1 + n_channels: 7 + mpp: 0.5 + max_batch_size: 64 + batch_wait_timeout_s: 0.1 + foundation_model_id: virchow2 + model: + _target_: providers.model_provider:mlflow + artifact_uri: mlflow-artifacts:/104/0e2230c722134ce0985e09a18ccadf75/artifacts/onnx/linear_head.onnx diff --git a/helm/rayservice/values.yaml b/helm/rayservice/values.yaml index 6e62751..e436e62 100644 --- a/helm/rayservice/values.yaml +++ b/helm/rayservice/values.yaml @@ -7,4 +7,6 @@ applications: - heatmap-builder - prostate-classifier-1 - prov-gigapath + - tissue-linear - virchow2 + diff --git a/models/tissue_linear.py b/models/tissue_linear.py new file mode 100644 index 0000000..1a580a6 --- /dev/null +++ b/models/tissue_linear.py @@ -0,0 +1,218 @@ +from __future__ import annotations + +import asyncio +import importlib +from pathlib import Path +from typing import TYPE_CHECKING, Any, TypedDict + +import numpy as np +from fastapi import FastAPI, Request +from numpy.typing import NDArray +from ray import serve + + +if TYPE_CHECKING: + import torch + + +class Config(TypedDict): + tile_size: int + output_tile_size: int + n_channels: int + mpp: float + max_batch_size: int + batch_wait_timeout_s: float + foundation_model_id: str + model: dict[str, Any] + + +fastapi = FastAPI() + + +@serve.deployment(num_replicas="auto") +@serve.ingress(fastapi) +class TissueLinear: + """7-class tissue classifier: linear head over Virchow2 embeddings. + + Per tile: apply Virchow2's transform, fetch the ViT token sequence from + the deployed Virchow2 service, pool tokens (class token + mean of patch + tokens) into a 2560-d embedding, run the ONNX linear head, and return a + 7-channel softmax probability map of shape (n_classes, 1, 1). Softmax + (rather than a hard class index) is used so HeatmapBuilder's resize to + source resolution interpolates well-defined probabilities; the hard class + map is recoverable via argmax over channels at full resolution. + """ + + def __init__(self) -> None: + import lz4.frame + + self.lz4 = lz4.frame + + def reconfigure(self, config: Config) -> None: + import onnxruntime as ort + from timm.data.transforms_factory import create_transform + + self.tile_size = config["tile_size"] + self.output_tile_size = config["output_tile_size"] + self.n_channels = config["n_channels"] + self.mpp = config["mpp"] + + self.foundation_model = serve.get_app_handle(config["foundation_model_id"]) + + # Build Virchow2's eval transform directly from its known pretrained_cfg + # (verified against the model's config.json: ImageNet mean/std, bicubic, + # crop_pct 1.0). This avoids instantiating the full ~600M-param model + # just to read its transform config, saving ~2.4 GB RAM per replica and + # removing any Hugging Face Hub access at init (the repo is gated). The + # embeddings themselves are produced by the deployed Virchow2 service. + self.foundation_transform = create_transform( + input_size=(3, self.tile_size, self.tile_size), + is_training=False, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + crop_pct=1.0, + crop_mode="center", + interpolation="bicubic", + ) + + model_config = dict(config["model"]) + module_path, attr_name = model_config.pop("_target_").split(":") + provider = getattr(importlib.import_module(module_path), attr_name) + + # Resolve the .onnx file from the MLflow download. The provider may + # return the file directly, a directory containing it, or a sibling + # path. Resolved inline (no module-level helper) because Ray's + # by-value deployment serialization does not reliably carry module + # globals into the worker. + downloaded_path = Path(provider(**model_config)) + if downloaded_path.is_file() and downloaded_path.suffix == ".onnx": + model_path = downloaded_path + else: + search_root = ( + downloaded_path if downloaded_path.is_dir() else downloaded_path.parent + ) + candidates = list(search_root.rglob("*.onnx")) + if not candidates: + raise FileNotFoundError( + f"No .onnx file found at or near downloaded path: {downloaded_path}" + ) + model_path = candidates[0] + print(f"Using ONNX model path: {model_path}") + + # Run the head on CPU. It is a single 2560->n_classes linear, so the + # GPU kernel-launch and host<->device transfer overhead would exceed + # the matmul itself, and the embeddings already arrive as CPU numpy. + # The num_gpus: 1 reservation is only to land the actor on a worker + # image that carries torch/timm, not for ONNX compute. + self.session = ort.InferenceSession( + str(model_path), + providers=["CPUExecutionProvider"], + ) + self.input_name = self.session.get_inputs()[0].name + self.output_name = self.session.get_outputs()[0].name + self._num_classes = int(self.session.get_outputs()[0].shape[-1]) + + # Fail fast on config that contradicts the model's output contract: + # one softmax probability per class per tile, shape (n_classes, 1, 1). + # A mismatch (e.g. a stale n_channels) would silently corrupt the + # HeatmapBuilder output instead of erroring. + if self.n_channels != self._num_classes: + raise ValueError( + f"n_channels ({self.n_channels}) must equal the ONNX head's " + f"number of classes ({self._num_classes})" + ) + if self.output_tile_size != 1: + raise ValueError( + f"output_tile_size must be 1 for per-tile classification, " + f"got {self.output_tile_size}" + ) + # Virchow2's pooled embedding is 2560-d (class token + patch-token mean, + # 1280 each). Guard against an artifact_uri pointing to a head trained + # for a different foundation model, which would otherwise fail with a + # cryptic shape error on the first session.run mid-slide. + expected_embedding_dim = 2560 + onnx_input_dim = int(self.session.get_inputs()[0].shape[-1]) + if onnx_input_dim != expected_embedding_dim: + raise ValueError( + f"ONNX head expects input width {onnx_input_dim}, but the " + f"Virchow2 embedding is {expected_embedding_dim}-d; the " + f"artifact_uri likely points to a head for a different " + f"foundation model" + ) + + self.predict.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] + self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] + + async def get_config(self) -> dict[str, Any]: + return { + "tile_size": self.tile_size, + "output_tile_size": self.output_tile_size, + "n_channels": self.n_channels, + "mpp": self.mpp, + } + + def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tensor: + from PIL import Image + + tile_hwc = tile_chw.transpose(1, 2, 0) + image = Image.fromarray(tile_hwc) + # Return [3, 224, 224], not [1, 3, 224, 224]. + return self.foundation_transform(image) + + async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: + import torch + + tile_tensor = await asyncio.to_thread(self._prepare_tile_for_virchow2, tile) + + virchow2_output = await self.foundation_model.predict.remote(tile_tensor) + + if isinstance(virchow2_output, np.ndarray): + virchow2_output = torch.from_numpy(virchow2_output) + + # Virchow2 predict returns one tensor per tile, shape [tokens, dim]. + # Make it [1, tokens, dim] so pooling is batch-compatible. + if virchow2_output.ndim == 2: + virchow2_output = virchow2_output.unsqueeze(0) + + class_token = virchow2_output[:, 0] + patch_tokens = virchow2_output[:, 5:] + embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1) + + return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) + + @serve.batch + async def predict( + self, + tiles: list[NDArray[np.uint8]], + ) -> list[NDArray[np.float32]]: + embeddings = await asyncio.gather( + *(self._create_embedding(tile) for tile in tiles) + ) + batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) + + # The ONNX graph ends in a Softmax, so this already returns per-class + # probabilities of shape (batch, n_classes). Reshape each row to a + # (n_classes, 1, 1) map for HeatmapBuilder. + probs = self.session.run( + [self.output_name], + {self.input_name: batch}, + )[0] + + return [row.reshape(self._num_classes, 1, 1) for row in probs] + + @fastapi.post("/") + async def root(self, request: Request) -> list[Any]: + data = await asyncio.to_thread(self.lz4.decompress, await request.body()) + + tile = np.frombuffer(data, dtype=np.uint8).reshape( + self.tile_size, + self.tile_size, + 3, + ) + tile_chw = tile.transpose(2, 0, 1) + + result = await self.predict(tile_chw) + return result.tolist() + + +app = TissueLinear.bind() # type: ignore[attr-defined] From 67e5e2243b7679366bd816fe1b7c3a99c0b51ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20C=C3=ADfka?= <550433@mail.muni.cz> Date: Sun, 14 Jun 2026 23:01:27 +0200 Subject: [PATCH 2/2] Simplify tissue_linear: drop dead path-resolution and config guards - Remove unreachable .onnx rglob fallback; artifact_uri points directly at the file, so provider() returns it as-is - Drop n_channels / output_tile_size / embedding-dim validation guards per review - Shorten verbose comments Co-Authored-By: Claude Opus 4.8 --- models/tissue_linear.py | 74 ++++++----------------------------------- 1 file changed, 11 insertions(+), 63 deletions(-) diff --git a/models/tissue_linear.py b/models/tissue_linear.py index 1a580a6..0178355 100644 --- a/models/tissue_linear.py +++ b/models/tissue_linear.py @@ -2,7 +2,6 @@ import asyncio import importlib -from pathlib import Path from typing import TYPE_CHECKING, Any, TypedDict import numpy as np @@ -59,12 +58,10 @@ def reconfigure(self, config: Config) -> None: self.foundation_model = serve.get_app_handle(config["foundation_model_id"]) - # Build Virchow2's eval transform directly from its known pretrained_cfg - # (verified against the model's config.json: ImageNet mean/std, bicubic, - # crop_pct 1.0). This avoids instantiating the full ~600M-param model - # just to read its transform config, saving ~2.4 GB RAM per replica and - # removing any Hugging Face Hub access at init (the repo is gated). The - # embeddings themselves are produced by the deployed Virchow2 service. + # Virchow2's eval transform, built from its pretrained_cfg (ImageNet + # mean/std, bicubic, crop_pct 1.0) instead of loading the ~600M-param + # model just to read it: saves ~2.4 GB RAM/replica and avoids HF Hub + # access at init (gated repo). Embeddings come from the Virchow2 service. self.foundation_transform = create_transform( input_size=(3, self.tile_size, self.tile_size), is_training=False, @@ -79,31 +76,11 @@ def reconfigure(self, config: Config) -> None: module_path, attr_name = model_config.pop("_target_").split(":") provider = getattr(importlib.import_module(module_path), attr_name) - # Resolve the .onnx file from the MLflow download. The provider may - # return the file directly, a directory containing it, or a sibling - # path. Resolved inline (no module-level helper) because Ray's - # by-value deployment serialization does not reliably carry module - # globals into the worker. - downloaded_path = Path(provider(**model_config)) - if downloaded_path.is_file() and downloaded_path.suffix == ".onnx": - model_path = downloaded_path - else: - search_root = ( - downloaded_path if downloaded_path.is_dir() else downloaded_path.parent - ) - candidates = list(search_root.rglob("*.onnx")) - if not candidates: - raise FileNotFoundError( - f"No .onnx file found at or near downloaded path: {downloaded_path}" - ) - model_path = candidates[0] - print(f"Using ONNX model path: {model_path}") - - # Run the head on CPU. It is a single 2560->n_classes linear, so the - # GPU kernel-launch and host<->device transfer overhead would exceed - # the matmul itself, and the embeddings already arrive as CPU numpy. - # The num_gpus: 1 reservation is only to land the actor on a worker - # image that carries torch/timm, not for ONNX compute. + model_path = provider(**model_config) + + # Head runs on CPU: a single 2560->n_classes matmul, so GPU launch + + # transfer overhead would exceed it, and embeddings arrive as CPU numpy. + # num_gpus: 1 only pins the actor to a worker with torch/timm. self.session = ort.InferenceSession( str(model_path), providers=["CPUExecutionProvider"], @@ -112,34 +89,6 @@ def reconfigure(self, config: Config) -> None: self.output_name = self.session.get_outputs()[0].name self._num_classes = int(self.session.get_outputs()[0].shape[-1]) - # Fail fast on config that contradicts the model's output contract: - # one softmax probability per class per tile, shape (n_classes, 1, 1). - # A mismatch (e.g. a stale n_channels) would silently corrupt the - # HeatmapBuilder output instead of erroring. - if self.n_channels != self._num_classes: - raise ValueError( - f"n_channels ({self.n_channels}) must equal the ONNX head's " - f"number of classes ({self._num_classes})" - ) - if self.output_tile_size != 1: - raise ValueError( - f"output_tile_size must be 1 for per-tile classification, " - f"got {self.output_tile_size}" - ) - # Virchow2's pooled embedding is 2560-d (class token + patch-token mean, - # 1280 each). Guard against an artifact_uri pointing to a head trained - # for a different foundation model, which would otherwise fail with a - # cryptic shape error on the first session.run mid-slide. - expected_embedding_dim = 2560 - onnx_input_dim = int(self.session.get_inputs()[0].shape[-1]) - if onnx_input_dim != expected_embedding_dim: - raise ValueError( - f"ONNX head expects input width {onnx_input_dim}, but the " - f"Virchow2 embedding is {expected_embedding_dim}-d; the " - f"artifact_uri likely points to a head for a different " - f"foundation model" - ) - self.predict.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] @@ -190,9 +139,8 @@ async def predict( ) batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) - # The ONNX graph ends in a Softmax, so this already returns per-class - # probabilities of shape (batch, n_classes). Reshape each row to a - # (n_classes, 1, 1) map for HeatmapBuilder. + # ONNX graph ends in Softmax -> (batch, n_classes) probabilities. + # Reshape each row to (n_classes, 1, 1) for HeatmapBuilder. probs = self.session.run( [self.output_name], {self.input_name: batch},