Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions helm/rayservice/applications/breast-cancer-virchow2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
- name: breast-cancer-virchow2
import_path: models.breast_cancer_virchow2:app
route_prefix: /breast-cancer-virchow2
runtime_env:
working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip
Comment thread
kahood23 marked this conversation as resolved.
env_vars:
MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow
HF_HOME: /mnt/huggingface_cache
pip:
- timm
- pillow
Comment thread
kahood23 marked this conversation as resolved.
deployments:
- name: BreastCancerVirchow2
max_ongoing_requests: 512
max_queued_requests: 4096
autoscaling_config:
min_replicas: 0
max_replicas: 2
target_ongoing_requests: 64
ray_actor_options:
num_cpus: 2
num_gpus: 1
memory: 12884901888
user_config:
tile_size: 224
output_tile_size: 1
n_channels: 1
mpp: 0.46
max_batch_size: 128
batch_wait_timeout_s: 0.05
foundation_model_id: virchow2
model:
_target_: providers.model_provider:mlflow
artifact_uri: mlflow-artifacts:/2/bc79482fb539496ea9a2a43479150956/artifacts/model
1 change: 1 addition & 0 deletions helm/rayservice/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ applications:
- prostate-classifier-1
- prov-gigapath
- virchow2
- breast-cancer-virchow2
180 changes: 180 additions & 0 deletions models/breast_cancer_virchow2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import asyncio
import importlib
from pathlib import Path
from typing import Any, TypedDict

import numpy as np
import torch
from fastapi import FastAPI, Request
from numpy.typing import NDArray
from PIL import Image
from ray import serve


class Config(TypedDict):
    """Schema of the ``user_config`` mapping handed to ``reconfigure``."""

    # Edge length, in pixels, of the square RGB tile accepted by the HTTP
    # endpoint (the request body is reshaped to tile_size x tile_size x 3).
    tile_size: int
    # The next three values are only stored and echoed back by ``get_config``;
    # this module does not otherwise interpret them.
    output_tile_size: int
    n_channels: int
    mpp: float  # presumably microns per pixel -- confirm with the caller
    # Batching limits applied to the ONNX head (``_predict_head``).
    max_batch_size: int
    batch_wait_timeout_s: float
    # Name of the Serve application that produces foundation-model embeddings.
    foundation_model_id: str
    # Model provider spec: ``_target_`` is a ``"module:attribute"`` string and
    # every remaining key is passed to that callable as a keyword argument.
    model: dict[str, Any]


# FastAPI application used as the HTTP ingress of the Serve deployment below.
fastapi = FastAPI()


@serve.deployment(num_replicas="auto")
@serve.ingress(fastapi)
class BreastCancerVirchow2:
def __init__(self) -> None:
    """Bind the LZ4 frame codec used to decompress request bodies.

    The import happens at construction time rather than at module import
    (presumably so it resolves inside the replica's runtime environment --
    confirm).
    """
    from lz4 import frame as lz4_frame

    self.lz4 = lz4_frame

def reconfigure(self, config: Config) -> None:
    """Apply the deployment's user config and (re)load the ONNX head.

    Stores the tile metadata, obtains a handle to the foundation-model
    application, builds the tile preprocessing transform, resolves the model
    artifact through the configured provider, opens an ONNX Runtime session
    on it and configures batching for ``_predict_head``.

    Args:
        config: Mapping with the keys declared by ``Config``.

    Raises:
        FileNotFoundError: If the provider's path does not exist, or it is a
            directory that contains no ``model.onnx``.
    """
    import onnxruntime as ort
    import timm
    from timm.data.config import resolve_data_config
    from timm.data.transforms_factory import create_transform
    from timm.layers.mlp import SwiGLUPacked

    self.tile_size = config["tile_size"]
    self.output_tile_size = config["output_tile_size"]
    self.n_channels = config["n_channels"]
    self.mpp = config["mpp"]

    self.foundation_model = serve.get_app_handle(config["foundation_model_id"])

    # This (non-pretrained) model is only used to construct the correct
    # Virchow2 transform; the real embeddings are produced by the deployed
    # Virchow2 service. The service's internal ``predict`` expects an already
    # preprocessed torch.Tensor because it applies the transform only in its
    # HTTP root endpoint, so this deployment must apply the same transform
    # before calling ``predict.remote`` directly.
    virchow2 = timm.create_model(
        "hf-hub:paige-ai/Virchow2",
        pretrained=False,
        num_classes=0,
        mlp_layer=SwiGLUPacked,
        act_layer=torch.nn.SiLU,
    )
    self.foundation_transform = create_transform(
        **resolve_data_config(virchow2.pretrained_cfg, model=virchow2)
    )

    # "_target_" is "module:attribute"; the remaining keys are the provider's
    # keyword arguments. Work on a copy so the caller's config is not mutated.
    model_config = dict(config["model"])
    module_path, attr_name = model_config.pop("_target_").split(":")
    provider = getattr(importlib.import_module(module_path), attr_name)

    downloaded_path = Path(provider(**model_config))

    if downloaded_path.is_file():
        # The provider returned the model file itself rather than a directory.
        model_path = downloaded_path
    elif downloaded_path.is_dir():
        # Sort so the choice is deterministic if several files match.
        candidates = sorted(downloaded_path.rglob("model.onnx"))
        if not candidates:
            raise FileNotFoundError(
                "Downloaded MLflow artifact path is a directory, "
                f"but no model.onnx was found under: {downloaded_path}"
            )
        model_path = candidates[0]
    else:
        raise FileNotFoundError(f"ONNX model file not found: {downloaded_path}")

    # The head runs on the GPU execution provider.
    self.session = ort.InferenceSession(
        str(model_path),
        providers=["CUDAExecutionProvider"],
    )

    self.input_name = self.session.get_inputs()[0].name
    self.output_name = self.session.get_outputs()[0].name

    # Batching should happen only for the ONNX head, after Virchow2 embeddings
    # have already been produced for individual tiles.
    self._predict_head.set_max_batch_size(config["max_batch_size"])  # type: ignore[attr-defined]
    self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"])  # type: ignore[attr-defined]

async def get_config(self) -> dict[str, Any]:
    """Return the tile metadata stored by the most recent ``reconfigure``."""
    exposed = ("tile_size", "output_tile_size", "n_channels", "mpp")
    return {field: getattr(self, field) for field in exposed}

def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tensor:
    """Apply the foundation-model transform to one channel-first tile.

    Args:
        tile_chw: Tile laid out as (channels, height, width).

    Returns:
        Whatever ``self.foundation_transform`` yields for the tile. No batch
        axis is added here: the result is meant to be a single-tile tensor
        such as [3, 224, 224], not [1, 3, 224, 224].
    """
    # PIL expects channel-last data, so move the channel axis to the end.
    pil_tile = Image.fromarray(np.transpose(tile_chw, (1, 2, 0)))
    return self.foundation_transform(pil_tile)
Comment thread
kahood23 marked this conversation as resolved.

async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray:
tile_tensor = await asyncio.to_thread(self._prepare_tile_for_virchow2, tile)

# Intentionally send a single tile to the foundation model.
# Batching is handled inside the Virchow2 service, not here.
virchow2_output = await self.foundation_model.predict.remote(tile_tensor)

if isinstance(virchow2_output, np.ndarray):
virchow2_output = torch.from_numpy(virchow2_output)

# Virchow2 predict returns one tensor per tile, shape [tokens, dim].
# Make it [1, tokens, dim] so the pooling code is batch-compatible.
if virchow2_output.ndim == 2:
virchow2_output = virchow2_output.unsqueeze(0)

class_token = virchow2_output[:, 0]
patch_tokens = virchow2_output[:, 5:]

embedding = torch.cat(
[class_token, patch_tokens.mean(dim=1)],
dim=-1,
)

return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False)

@serve.batch
async def _predict_head(
    self,
    embeddings: list[NDArray[np.float32]],
) -> list[NDArray[np.float32]]:
    """Run the ONNX head on a batch of pooled tile embeddings.

    Args:
        embeddings: One embedding per queued request, collected by
            ``serve.batch``.

    Returns:
        One 1x1 float32 array per input embedding, in the same order.

    Raises:
        ValueError: If the head does not return exactly one value per
            embedding.
    """
    batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False)

    # session.run blocks until inference finishes; run it in a worker thread
    # (like the other blocking calls in this class) so the event loop can keep
    # serving other requests in the meantime.
    outputs = await asyncio.to_thread(
        self.session.run,
        [self.output_name],
        {self.input_name: batch},
    )
    scores = np.asarray(outputs[0]).reshape(-1)

    # A batched handler must return exactly one result per request.
    if scores.size != len(embeddings):
        raise ValueError(
            f"ONNX head returned {scores.size} values for "
            f"{len(embeddings)} embeddings; expected one value per embedding"
        )

    # Important for heatmap-builder: each tile prediction must be 2D or 3D,
    # so a single scalar probability is returned as a 1x1 map.
    return [np.full((1, 1), float(score), dtype=np.float32) for score in scores]

async def predict(
    self,
    tile: NDArray[np.uint8],
) -> NDArray[np.float32]:
    """Embed one tile and return the batched head's prediction for it."""
    return await self._predict_head(await self._create_embedding(tile))

@fastapi.post("/")
async def root(self, request: Request) -> list[list[float]]:
    """Predict for one LZ4-compressed RGB tile sent as the request body.

    The body must decompress to exactly ``tile_size * tile_size * 3`` bytes of
    uint8 data in (height, width, channel) order.

    Args:
        request: Incoming HTTP request carrying the compressed tile.

    Returns:
        The prediction converted to nested Python lists.

    Raises:
        HTTPException: With status 400 when the decompressed payload does not
            have the expected size.
    """
    # Imported locally: the module-level name ``fastapi`` is the app instance.
    from fastapi import HTTPException

    data = await asyncio.to_thread(self.lz4.decompress, await request.body())

    # Reject wrongly sized payloads up front; otherwise the reshape below
    # raises ValueError, which surfaces as an internal server error.
    expected_bytes = self.tile_size * self.tile_size * 3
    if len(data) != expected_bytes:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Expected {expected_bytes} bytes for a "
                f"{self.tile_size}x{self.tile_size} RGB tile, got {len(data)}"
            ),
        )

    tile = np.frombuffer(data, dtype=np.uint8).reshape(
        self.tile_size,
        self.tile_size,
        3,
    )

    # ``predict`` works on channel-first tiles.
    tile_chw = tile.transpose(2, 0, 1)

    result = await self.predict(tile_chw)

    return result.tolist()


# Bound Serve application; the Helm config's import_path points at this name
# (models.breast_cancer_virchow2:app).
app = BreastCancerVirchow2.bind() # type: ignore[attr-defined]