diff --git a/helm/rayservice/applications/breast-grading-virchow2.yaml b/helm/rayservice/applications/breast-grading-virchow2.yaml
new file mode 100644
index 0000000..335a3a6
--- /dev/null
+++ b/helm/rayservice/applications/breast-grading-virchow2.yaml
@@ -0,0 +1,35 @@
+- name: breast-grading-virchow2
+  import_path: models.breast_grading_virchow2:app
+  route_prefix: /breast-grading-virchow2
+  runtime_env:
+    working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip
+    env_vars:
+      MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow
+      HF_HOME: /mnt/huggingface_cache
+    # NOTE(review): the module also imports lz4 and onnxruntime - confirm the image ships them.
+    pip:
+      - timm
+      - pillow
+  deployments:
+    - name: BreastCancerGradingVirchow2
+      max_ongoing_requests: 512
+      max_queued_requests: 4096
+      autoscaling_config:
+        min_replicas: 0
+        max_replicas: 2
+        target_ongoing_requests: 64
+      ray_actor_options:
+        num_cpus: 2
+        num_gpus: 1
+        memory: 12884901888
+      user_config:
+        tile_size: 224
+        output_tile_size: 1
+        n_channels: 4
+        mpp: 0.46
+        max_batch_size: 128
+        batch_wait_timeout_s: 0.05
+        foundation_model_id: virchow2
+        model:
+          _target_: providers.model_provider:mlflow
+          artifact_uri: mlflow-artifacts:/37/484fdb26a5394af4bc76e387e7c93c89/artifacts/grading_head
diff --git a/helm/rayservice/values.yaml b/helm/rayservice/values.yaml
index 6e62751..b4ac860 100644
--- a/helm/rayservice/values.yaml
+++ b/helm/rayservice/values.yaml
@@ -8,3 +8,4 @@ applications:
   - prostate-classifier-1
   - prov-gigapath
   - virchow2
+  - breast-grading-virchow2
diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py
new file mode 100644
index 0000000..b0a3fb9
--- /dev/null
+++ b/models/breast_grading_virchow2.py
@@ -0,0 +1,226 @@
+import asyncio
+import importlib
+from pathlib import Path
+from typing import Any, TypedDict
+
+import numpy as np
+import torch
+from fastapi import FastAPI, HTTPException, Request
+from numpy.typing import NDArray
+from PIL import Image
+from ray import serve
+
+
+class Config(TypedDict):
+    """Schema of the ``user_config`` mapping consumed by ``reconfigure``."""
+
+    tile_size: int
+    output_tile_size: int
+    n_channels: int
+    mpp: float
+    max_batch_size: int
+    batch_wait_timeout_s: float
+    foundation_model_id: str
+    model: dict[str, Any]
+
+
+fastapi = FastAPI()
+
+
+def _resolve_onnx_path(downloaded_path: Path) -> Path:
+    """Return the ONNX model file located at or below *downloaded_path*.
+
+    Raises:
+        FileNotFoundError: If *downloaded_path* is not a file and no
+            ``model.onnx`` exists anywhere beneath it.
+    """
+    # Accept a provider that hands back the model file itself; ``rglob`` on a
+    # plain file yields nothing and would otherwise be reported as "missing".
+    if downloaded_path.is_file():
+        return downloaded_path
+
+    # Sort so the pick is deterministic if several copies were downloaded.
+    candidates = sorted(downloaded_path.rglob("model.onnx"))
+    if not candidates:
+        raise FileNotFoundError(
+            f"No model.onnx was found under MLflow artifact path: {downloaded_path}"
+        )
+    return candidates[0]
+
+
+@serve.deployment(num_replicas="auto")
+@serve.ingress(fastapi)
+class BreastCancerGradingVirchow2:
+    """Ray Serve deployment scoring one tile with an ONNX head on Virchow2 tokens.
+
+    A request tile is transformed, sent to a separately deployed foundation-model
+    application, pooled into a single embedding, and scored by the ONNX session
+    loaded in ``reconfigure``.
+    """
+
+    def __init__(self) -> None:
+        # Function-scope import; the module is kept on the instance for ``root``.
+        import lz4.frame
+
+        self.lz4 = lz4.frame
+
+    def reconfigure(self, config: Config) -> None:
+        """Apply *config*: store tile metadata, build the transform, load the head.
+
+        Raises:
+            FileNotFoundError: If the downloaded artifact contains no ONNX model.
+        """
+        import onnxruntime as ort
+        import timm
+        from timm.data.config import resolve_data_config
+        from timm.data.transforms_factory import create_transform
+        from timm.layers.mlp import SwiGLUPacked
+
+        # Tile geometry and resolution metadata, echoed back by ``get_config``.
+        self.tile_size = config["tile_size"]
+        self.output_tile_size = config["output_tile_size"]
+        self.n_channels = config["n_channels"]
+        self.mpp = config["mpp"]
+
+        # Handle to the Serve application that produces the foundation-model output.
+        self.foundation_model = serve.get_app_handle(config["foundation_model_id"])
+
+        # Build the architecture with ``pretrained=False``; only its
+        # ``pretrained_cfg`` is read, to derive the matching input transform.
+        virchow2 = timm.create_model(
+            "hf-hub:paige-ai/Virchow2",
+            pretrained=False,
+            num_classes=0,
+            mlp_layer=SwiGLUPacked,
+            act_layer=torch.nn.SiLU,
+        )
+        self.foundation_transform = create_transform(
+            **resolve_data_config(virchow2.pretrained_cfg, model=virchow2)
+        )
+        # The skeleton is not used again; release it before loading the head.
+        del virchow2
+
+        # ``model`` holds a "module:attribute" provider target plus its kwargs.
+        model_config = dict(config["model"])
+        module_path, attr_name = model_config.pop("_target_").split(":")
+        provider = getattr(importlib.import_module(module_path), attr_name)
+
+        model_path = _resolve_onnx_path(Path(provider(**model_config)))
+
+        self.session = ort.InferenceSession(
+            str(model_path),
+            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
+        )
+        self.input_name = self.session.get_inputs()[0].name
+        self.output_name = self.session.get_outputs()[0].name
+
+        self._predict_head.set_max_batch_size(config["max_batch_size"])  # type: ignore[attr-defined]
+        self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"])  # type: ignore[attr-defined]
+
+    async def get_config(self) -> dict[str, Any]:
+        """Return the tile geometry and resolution stored by ``reconfigure``."""
+        return {
+            "tile_size": self.tile_size,
+            "output_tile_size": self.output_tile_size,
+            "n_channels": self.n_channels,
+            "mpp": self.mpp,
+        }
+
+    def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tensor:
+        """Convert a channel-first uint8 tile into the foundation-model input."""
+        # ``Image.fromarray`` is given height-width-channel data, so undo the layout.
+        tile_hwc = tile_chw.transpose(1, 2, 0)
+        image = Image.fromarray(tile_hwc)
+
+        return self.foundation_transform(image)
+
+    async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray:
+        """Return a float32 embedding: class token joined with the mean patch token."""
+        # Preprocessing is synchronous; run it in a worker thread.
+        tile_tensor = await asyncio.to_thread(self._prepare_tile_for_virchow2, tile)
+
+        virchow2_output = await self.foundation_model.predict.remote(tile_tensor)
+
+        if isinstance(virchow2_output, np.ndarray):
+            virchow2_output = torch.from_numpy(virchow2_output)
+
+        # A 2-D result is treated as [tokens, dim] for one tile; add a batch
+        # axis so the slicing below works on [batch, tokens, dim].
+        if virchow2_output.ndim == 2:
+            virchow2_output = virchow2_output.unsqueeze(0)
+
+        class_token = virchow2_output[:, 0]
+        # Tokens 1-4 are skipped - presumably Virchow2 register tokens; confirm.
+        patch_tokens = virchow2_output[:, 5:]
+        embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1)
+
+        return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False)
+
+    @serve.batch
+    async def _predict_head(
+        self,
+        embeddings: list[NDArray[np.float32]],
+    ) -> list[NDArray[np.float32]]:
+        """Run the ONNX head on stacked embeddings, returning one array per input."""
+        batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False)
+
+        logits = self.session.run(
+            [self.output_name],
+            {self.input_name: batch},
+        )[0]
+
+        # Each output row becomes an (n_channels, 1, 1) float32 array.
+        return [row.reshape(self.n_channels, 1, 1).astype(np.float32) for row in logits]
+
+    async def predict(
+        self,
+        tile: NDArray[np.uint8],
+    ) -> NDArray[np.float32]:
+        """Embed a single channel-first tile and score it with the head."""
+        embedding = await self._create_embedding(tile)
+        # ``_predict_head`` is wrapped by ``serve.batch``; pass one embedding.
+        return await self._predict_head(embedding)
+
+    @fastapi.post("/")
+    async def root(self, request: Request) -> list[list[list[float]]]:
+        """Decode an LZ4-compressed 3-channel tile from the request body and score it.
+
+        Raises:
+            HTTPException: 400 when the body cannot be decompressed or does not
+                hold exactly ``tile_size * tile_size * 3`` bytes.
+        """
+        body_bytes = await request.body()
+
+        try:
+            # NOTE(review): the payload is fully decompressed before its size
+            # is checked, so a small body can expand without bound - consider
+            # a bounded decompressor if this endpoint is exposed to untrusted
+            # clients.
+            data = await asyncio.to_thread(self.lz4.decompress, body_bytes)
+
+            expected_bytes = self.tile_size * self.tile_size * 3
+            if len(data) != expected_bytes:
+                raise ValueError(
+                    f"Decompressed payload byte length mismatch. "
+                    f"Expected exactly {expected_bytes} bytes, but got {len(data)}."
+                )
+
+            tile = np.frombuffer(data, dtype=np.uint8).reshape(
+                self.tile_size,
+                self.tile_size,
+                3,
+            )
+        except (RuntimeError, ValueError) as err:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Malformed or invalid compressed tile image payload: {err!s}",
+            ) from err
+
+        # ``predict`` is handed the tile in channel-first order.
+        tile_chw = tile.transpose(2, 0, 1)
+        result = await self.predict(tile_chw)
+
+        return result.tolist()
+
+
+app = BreastCancerGradingVirchow2.bind()  # type: ignore[attr-defined]