-
Notifications
You must be signed in to change notification settings - Fork 1
Feature/breast cancer virchow #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0246465
5b5e286
b97d7ef
8b57b68
bf28420
064373f
f3375e1
5102177
71df259
9beddce
d4604cc
cc4092f
378f605
3ea57f1
a1ee087
d35bd5b
6b5b6c0
9a88352
b384791
3f89133
ba2eb95
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
# Ray Serve application entry for the breast-cancer head running on top of
# Virchow2 embeddings.
# NOTE(review): the nesting below was reconstructed from a whitespace-stripped
# diff view using the Ray Serve config schema — confirm against the real file.
- name: breast-cancer-virchow2
  # "<module>:<attribute>" of the bound Serve application object (`app`).
  import_path: models.breast_cancer_virchow2:app
  route_prefix: /breast-cancer-virchow2
  runtime_env:
    # NOTE(review): this tracks the `main` branch archive, so the module is only
    # importable from here once the change is merged — confirm this is intended.
    working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip
    env_vars:
      MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow
      HF_HOME: /mnt/huggingface_cache
    # NOTE(review): the module also imports `lz4` and `onnxruntime` at runtime;
    # presumably the base image provides them — verify.
    pip:
      - timm
      - pillow

  deployments:
    # Must match the deployment class name in the Python module.
    - name: BreastCancerVirchow2
      max_ongoing_requests: 512
      max_queued_requests: 4096
      autoscaling_config:
        # min_replicas: 0 allows scaling down to no replicas.
        min_replicas: 0
        max_replicas: 2
        target_ongoing_requests: 64
      ray_actor_options:
        num_cpus: 2
        num_gpus: 1
        # 12 * 1024**3, i.e. 12 GiB if the unit is bytes — confirm against Ray's resource spec.
        memory: 12884901888
      # Handed to BreastCancerVirchow2.reconfigure(); keys mirror the Config TypedDict.
      user_config:
        # root() reshapes each request payload to (tile_size, tile_size, 3).
        tile_size: 224
        # output_tile_size, n_channels and mpp are only echoed back by get_config().
        output_tile_size: 1
        n_channels: 1
        # presumably microns per pixel — confirm with the consumer of get_config().
        mpp: 0.46
        # Batching parameters applied to the ONNX head (_predict_head).
        max_batch_size: 128
        batch_wait_timeout_s: 0.05
        # Name of the Serve application resolved through serve.get_app_handle().
        foundation_model_id: virchow2
        model:
          # "<module>:<callable>"; the remaining keys are passed to it as kwargs.
          _target_: providers.model_provider:mlflow
          artifact_uri: mlflow-artifacts:/2/bc79482fb539496ea9a2a43479150956/artifacts/model
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,3 +8,4 @@ applications: | |
| - prostate-classifier-1 | ||
| - prov-gigapath | ||
| - virchow2 | ||
| - breast-cancer-virchow2 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| import asyncio | ||
| import importlib | ||
| from pathlib import Path | ||
| from typing import Any, TypedDict | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from fastapi import FastAPI, Request | ||
| from numpy.typing import NDArray | ||
| from PIL import Image | ||
| from ray import serve | ||
|
|
||
|
|
||
class Config(TypedDict):
    """Shape of the ``user_config`` mapping consumed by ``reconfigure``."""

    # Edge length of an input tile; request payloads are reshaped to
    # (tile_size, tile_size, 3).
    tile_size: int
    # Reported back verbatim by ``get_config``.
    output_tile_size: int
    # Reported back verbatim by ``get_config``.
    n_channels: int
    # Reported back verbatim by ``get_config``; presumably microns per pixel
    # — TODO confirm.
    mpp: float
    # Maximum batch size for the ONNX head.
    max_batch_size: int
    # How long the ONNX head waits to fill a batch, in seconds.
    batch_wait_timeout_s: float
    # Name of the Serve application that produces the foundation-model output.
    foundation_model_id: str
    # Provider spec: "_target_" is "<module>:<callable>", the remaining keys are
    # passed to that callable as keyword arguments.
    model: dict[str, Any]
|
|
||
|
|
||
# FastAPI application used as the HTTP ingress of the Serve deployment; it is
# attached to the deployment class through ``@serve.ingress(fastapi)``.
fastapi = FastAPI()
|
|
||
|
|
||
@serve.deployment(num_replicas="auto")
@serve.ingress(fastapi)
class BreastCancerVirchow2:
    """Serve deployment that scores tiles with an ONNX head on Virchow2 embeddings.

    Each tile is preprocessed locally, sent to a separately deployed
    foundation-model application to obtain token embeddings, pooled into a
    single vector, and finally scored by an ONNX model in batches.
    """

    def __init__(self) -> None:
        """Import ``lz4.frame`` lazily and keep it for request decompression."""
        import lz4.frame

        self.lz4 = lz4.frame

    def reconfigure(self, config: Config) -> None:
        """Apply a ``user_config``: build the transform, load the head, set batching.

        Args:
            config: Deployment settings; see ``Config`` for the expected keys.

        Raises:
            FileNotFoundError: If no ONNX model file can be located at the path
                returned by the configured model provider.
        """
        import onnxruntime as ort
        import timm
        from timm.data.config import resolve_data_config
        from timm.data.transforms_factory import create_transform
        from timm.layers.mlp import SwiGLUPacked

        self.tile_size = config["tile_size"]
        self.output_tile_size = config["output_tile_size"]
        self.n_channels = config["n_channels"]
        self.mpp = config["mpp"]

        self.foundation_model = serve.get_app_handle(config["foundation_model_id"])

        # Only used to construct the correct Virchow2 transform.
        # The real embeddings are produced by the deployed Virchow2 service.
        # NOTE(review): a reviewer asked why the full Virchow2 model is created
        # here. It is built with pretrained=False, but the architecture is still
        # instantiated solely to read its data config.
        virchow2 = timm.create_model(
            "hf-hub:paige-ai/Virchow2",
            pretrained=False,
            num_classes=0,
            mlp_layer=SwiGLUPacked,
            act_layer=torch.nn.SiLU,
        )

        self.foundation_transform = create_transform(
            **resolve_data_config(virchow2.pretrained_cfg, model=virchow2)
        )

        # "_target_" names the provider as "<module>:<callable>"; every other
        # key is forwarded to it. Copy first so the caller's config is not mutated.
        model_config = dict(config["model"])
        module_path, attr_name = model_config.pop("_target_").split(":")
        provider = getattr(importlib.import_module(module_path), attr_name)

        model_path = self._find_onnx_model(Path(provider(**model_config)))

        # NOTE(review): a reviewer asked whether CPU-only execution was intended;
        # this revision requests the CUDA execution provider.
        self.session = ort.InferenceSession(
            str(model_path),
            providers=["CUDAExecutionProvider"],
        )

        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        # Batching should happen only for the ONNX head, after Virchow2 embeddings
        # have already been produced for individual tiles.
        self._predict_head.set_max_batch_size(config["max_batch_size"])  # type: ignore[attr-defined]
        self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"])  # type: ignore[attr-defined]

    @staticmethod
    def _find_onnx_model(downloaded_path: Path) -> Path:
        """Return the ONNX model file located at or below *downloaded_path*.

        Args:
            downloaded_path: Path returned by the model provider; either the
                model file itself or a directory containing ``model.onnx``.

        Raises:
            FileNotFoundError: If the path does not exist, or it is a directory
                without any ``model.onnx`` beneath it.
        """
        if downloaded_path.is_file():
            # The provider handed back the model file itself.
            return downloaded_path

        if not downloaded_path.is_dir():
            raise FileNotFoundError(f"ONNX model file not found: {downloaded_path}")

        # Sorted so the choice is deterministic when several files match.
        candidates = sorted(downloaded_path.rglob("model.onnx"))
        if not candidates:
            raise FileNotFoundError(
                "Downloaded MLflow artifact path is a directory, "
                f"but no model.onnx was found under: {downloaded_path}"
            )
        return candidates[0]

    async def get_config(self) -> dict[str, Any]:
        """Return the tile geometry settings stored by ``reconfigure``."""
        return {
            "tile_size": self.tile_size,
            "output_tile_size": self.output_tile_size,
            "n_channels": self.n_channels,
            "mpp": self.mpp,
        }

    def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tensor:
        """Convert a channel-first tile to a PIL image and apply the transform.

        Args:
            tile_chw: Tile with the channel axis first.

        Returns:
            The output of the foundation-model transform for this single tile.
        """
        tile_hwc = tile_chw.transpose(1, 2, 0)
        image = Image.fromarray(tile_hwc)

        # Important: return [3, 224, 224], not [1, 3, 224, 224].
        return self.foundation_transform(image)

    async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray:
        """Obtain foundation-model tokens for one tile and pool them into a vector.

        Args:
            tile: Channel-first tile, as accepted by ``_prepare_tile_for_virchow2``.

        Returns:
            A float32 array: the class token concatenated with the mean of the
            patch tokens.
        """
        # Preprocessing is synchronous; run it off the event loop.
        tile_tensor = await asyncio.to_thread(self._prepare_tile_for_virchow2, tile)

        # Intentionally send a single tile to the foundation model.
        # Batching is handled inside the Virchow2 service, not here.
        virchow2_output = await self.foundation_model.predict.remote(tile_tensor)

        if isinstance(virchow2_output, np.ndarray):
            virchow2_output = torch.from_numpy(virchow2_output)

        # Virchow2 predict returns one tensor per tile, shape [tokens, dim].
        # Make it [1, tokens, dim] so the pooling code is batch-compatible.
        if virchow2_output.ndim == 2:
            virchow2_output = virchow2_output.unsqueeze(0)

        class_token = virchow2_output[:, 0]
        # Tokens 1-4 are skipped; per the Virchow2 model card these are register
        # tokens — TODO confirm against the deployed service's output layout.
        patch_tokens = virchow2_output[:, 5:]

        embedding = torch.cat(
            [class_token, patch_tokens.mean(dim=1)],
            dim=-1,
        )

        return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False)

    @serve.batch
    async def _predict_head(
        self,
        embeddings: list[NDArray[np.float32]],
    ) -> list[NDArray[np.float32]]:
        """Run the ONNX head on a batch of embeddings.

        Args:
            embeddings: One pooled embedding per request in the batch.

        Returns:
            One ``1x1`` float32 array per flattened model output value.
        """
        batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False)

        probabilities = self.session.run(
            [self.output_name],
            {self.input_name: batch},
        )[0]

        # Important for heatmap-builder:
        # each tile prediction must be 2D or 3D.
        # For one scalar probability per tile, return a 1x1 map.
        # NOTE(review): flattening assumes the head emits exactly one value per
        # embedding; otherwise the result count will not match the batch size.
        return [
            np.asarray([[float(prob)]], dtype=np.float32)
            for prob in probabilities.reshape(-1)
        ]

    async def predict(
        self,
        tile: NDArray[np.uint8],
    ) -> NDArray[np.float32]:
        """Embed one channel-first tile and score it with the ONNX head."""
        embedding = await self._create_embedding(tile)
        return await self._predict_head(embedding)

    @fastapi.post("/")
    async def root(self, request: Request) -> list[list[float]]:
        """Handle a POST whose body is an lz4-frame-compressed uint8 tile.

        The decompressed bytes are interpreted as a
        ``(tile_size, tile_size, 3)`` array.

        Raises:
            ValueError: If the decompressed payload cannot be reshaped to
                ``(tile_size, tile_size, 3)``.
        """
        # Decompression is synchronous; run it off the event loop.
        data = await asyncio.to_thread(self.lz4.decompress, await request.body())

        tile = np.frombuffer(data, dtype=np.uint8).reshape(
            self.tile_size,
            self.tile_size,
            3,
        )

        # ``predict`` takes channel-first input.
        tile_chw = tile.transpose(2, 0, 1)

        result = await self.predict(tile_chw)

        return result.tolist()
|
|
||
|
|
||
# Bound Serve application object; the YAML ``import_path``
# (models.breast_cancer_virchow2:app) refers to this name.
app = BreastCancerVirchow2.bind()  # type: ignore[attr-defined]
Uh oh!
There was an error while loading. Please reload this page.