-
Notifications
You must be signed in to change notification settings - Fork 1
feat: 7-class tissue classifier on Virchow2 embeddings #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vojtech-cifka
wants to merge
2
commits into
main
Choose a base branch
from
feature/tissue-linear-virchow2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| - name: tissue-linear | ||
| import_path: models.tissue_linear:app | ||
| route_prefix: /tissue-linear | ||
| runtime_env: | ||
| config: | ||
| setup_timeout_seconds: 1800 | ||
| working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip | ||
| deployments: | ||
| - name: TissueLinear | ||
| max_ongoing_requests: 512 | ||
| max_queued_requests: 4096 | ||
| autoscaling_config: | ||
| min_replicas: 0 | ||
| max_replicas: 2 | ||
| target_ongoing_requests: 128 | ||
| ray_actor_options: | ||
| num_cpus: 2 | ||
| num_gpus: 1 | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
| memory: 12884901888 | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
| runtime_env: | ||
| env_vars: | ||
| MLFLOW_TRACKING_URI: http://mlflow.rationai-mlflow:5000 | ||
| HF_HOME: /mnt/huggingface_cache | ||
| user_config: | ||
| tile_size: 224 | ||
| output_tile_size: 1 | ||
| n_channels: 7 | ||
| mpp: 0.5 | ||
| max_batch_size: 64 | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
| batch_wait_timeout_s: 0.1 | ||
| foundation_model_id: virchow2 | ||
| model: | ||
| _target_: providers.model_provider:mlflow | ||
| artifact_uri: mlflow-artifacts:/104/0e2230c722134ce0985e09a18ccadf75/artifacts/onnx/linear_head.onnx | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,4 +7,6 @@ applications: | |
| - heatmap-builder | ||
| - prostate-classifier-1 | ||
| - prov-gigapath | ||
| - tissue-linear | ||
| - virchow2 | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import importlib | ||
| from typing import TYPE_CHECKING, Any, TypedDict | ||
|
|
||
| import numpy as np | ||
| from fastapi import FastAPI, Request | ||
| from numpy.typing import NDArray | ||
| from ray import serve | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| import torch | ||
|
|
||
|
|
||
| class Config(TypedDict): | ||
| tile_size: int | ||
| output_tile_size: int | ||
| n_channels: int | ||
| mpp: float | ||
| max_batch_size: int | ||
| batch_wait_timeout_s: float | ||
| foundation_model_id: str | ||
| model: dict[str, Any] | ||
|
|
||
|
|
||
| fastapi = FastAPI() | ||
|
|
||
|
|
||
| @serve.deployment(num_replicas="auto") | ||
| @serve.ingress(fastapi) | ||
| class TissueLinear: | ||
| """7-class tissue classifier: linear head over Virchow2 embeddings. | ||
|
|
||
| Per tile: apply Virchow2's transform, fetch the ViT token sequence from | ||
| the deployed Virchow2 service, pool tokens (class token + mean of patch | ||
| tokens) into a 2560-d embedding, run the ONNX linear head, and return a | ||
| 7-channel softmax probability map of shape (n_classes, 1, 1). Softmax | ||
| (rather than a hard class index) is used so HeatmapBuilder's resize to | ||
| source resolution interpolates well-defined probabilities; the hard class | ||
| map is recoverable via argmax over channels at full resolution. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| import lz4.frame | ||
|
|
||
| self.lz4 = lz4.frame | ||
|
|
||
| def reconfigure(self, config: Config) -> None: | ||
| import onnxruntime as ort | ||
| from timm.data.transforms_factory import create_transform | ||
|
|
||
| self.tile_size = config["tile_size"] | ||
| self.output_tile_size = config["output_tile_size"] | ||
| self.n_channels = config["n_channels"] | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
| self.mpp = config["mpp"] | ||
|
|
||
| self.foundation_model = serve.get_app_handle(config["foundation_model_id"]) | ||
|
|
||
| # Virchow2's eval transform, built from its pretrained_cfg (ImageNet | ||
| # mean/std, bicubic, crop_pct 1.0) instead of loading the ~600M-param | ||
| # model just to read it: saves ~2.4 GB RAM/replica and avoids HF Hub | ||
| # access at init (gated repo). Embeddings come from the Virchow2 service. | ||
| self.foundation_transform = create_transform( | ||
| input_size=(3, self.tile_size, self.tile_size), | ||
| is_training=False, | ||
| mean=(0.485, 0.456, 0.406), | ||
| std=(0.229, 0.224, 0.225), | ||
| crop_pct=1.0, | ||
| crop_mode="center", | ||
| interpolation="bicubic", | ||
| ) | ||
|
|
||
| model_config = dict(config["model"]) | ||
| module_path, attr_name = model_config.pop("_target_").split(":") | ||
| provider = getattr(importlib.import_module(module_path), attr_name) | ||
|
|
||
| model_path = provider(**model_config) | ||
|
|
||
| # Head runs on CPU: a single 2560->n_classes matmul, so GPU launch + | ||
| # transfer overhead would exceed it, and embeddings arrive as CPU numpy. | ||
| # num_gpus: 1 only pins the actor to a worker with torch/timm. | ||
| self.session = ort.InferenceSession( | ||
| str(model_path), | ||
| providers=["CPUExecutionProvider"], | ||
| ) | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
| self.input_name = self.session.get_inputs()[0].name | ||
| self.output_name = self.session.get_outputs()[0].name | ||
| self._num_classes = int(self.session.get_outputs()[0].shape[-1]) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| self.predict.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] | ||
| self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] | ||
|
|
||
| async def get_config(self) -> dict[str, Any]: | ||
| return { | ||
| "tile_size": self.tile_size, | ||
| "output_tile_size": self.output_tile_size, | ||
| "n_channels": self.n_channels, | ||
| "mpp": self.mpp, | ||
| } | ||
|
|
||
| def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tensor: | ||
| from PIL import Image | ||
|
|
||
| tile_hwc = tile_chw.transpose(1, 2, 0) | ||
| image = Image.fromarray(tile_hwc) | ||
| # Return [3, 224, 224], not [1, 3, 224, 224]. | ||
| return self.foundation_transform(image) | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
|
|
||
| async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: | ||
| import torch | ||
|
|
||
| tile_tensor = await asyncio.to_thread(self._prepare_tile_for_virchow2, tile) | ||
|
|
||
| virchow2_output = await self.foundation_model.predict.remote(tile_tensor) | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
|
|
||
| if isinstance(virchow2_output, np.ndarray): | ||
| virchow2_output = torch.from_numpy(virchow2_output) | ||
|
|
||
| # Virchow2 predict returns one tensor per tile, shape [tokens, dim]. | ||
| # Make it [1, tokens, dim] so pooling is batch-compatible. | ||
| if virchow2_output.ndim == 2: | ||
| virchow2_output = virchow2_output.unsqueeze(0) | ||
|
|
||
| class_token = virchow2_output[:, 0] | ||
| patch_tokens = virchow2_output[:, 5:] | ||
| embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1) | ||
|
|
||
| return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) | ||
|
|
||
| @serve.batch | ||
| async def predict( | ||
| self, | ||
| tiles: list[NDArray[np.uint8]], | ||
| ) -> list[NDArray[np.float32]]: | ||
| embeddings = await asyncio.gather( | ||
| *(self._create_embedding(tile) for tile in tiles) | ||
| ) | ||
| batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) | ||
|
|
||
| # ONNX graph ends in Softmax -> (batch, n_classes) probabilities. | ||
| # Reshape each row to (n_classes, 1, 1) for HeatmapBuilder. | ||
| probs = self.session.run( | ||
| [self.output_name], | ||
| {self.input_name: batch}, | ||
| )[0] | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
|
|
||
| return [row.reshape(self._num_classes, 1, 1) for row in probs] | ||
|
|
||
| @fastapi.post("/") | ||
| async def root(self, request: Request) -> list[Any]: | ||
| data = await asyncio.to_thread(self.lz4.decompress, await request.body()) | ||
|
|
||
| tile = np.frombuffer(data, dtype=np.uint8).reshape( | ||
| self.tile_size, | ||
| self.tile_size, | ||
| 3, | ||
| ) | ||
| tile_chw = tile.transpose(2, 0, 1) | ||
|
|
||
| result = await self.predict(tile_chw) | ||
| return result.tolist() | ||
|
vojtech-cifka marked this conversation as resolved.
|
||
|
|
||
|
|
||
| app = TissueLinear.bind() # type: ignore[attr-defined] | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.