Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions helm/rayservice/applications/breast-grading-virchow2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Ray Serve application entry for the breast-grading head that runs on top of
# the shared Virchow2 foundation-model application.
- name: breast-grading-virchow2
  import_path: models.breast_grading_virchow2:app
  route_prefix: /breast-grading-virchow2
  runtime_env:
    working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip
    env_vars:
      MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow
      HF_HOME: /mnt/huggingface_cache
    pip:
      # timm is intentionally not listed here: per PR review it is already
      # included in Dockerfile.gpu.
      - pillow
  deployments:
    - name: BreastCancerGradingVirchow2
      max_ongoing_requests: 512
      max_queued_requests: 4096
      autoscaling_config:
        min_replicas: 0  # allows scaling down to zero replicas when idle
        max_replicas: 2
        target_ongoing_requests: 64
      ray_actor_options:
        num_cpus: 2
        num_gpus: 1
        memory: 12884901888  # 12 GiB, in bytes
      # Passed to BreastCancerGradingVirchow2.reconfigure() as its `config`.
      user_config:
        tile_size: 224
        output_tile_size: 1
        n_channels: 4
        mpp: 0.46
        max_batch_size: 128
        batch_wait_timeout_s: 0.05
        # Name of the Serve application that provides Virchow2 tokens.
        foundation_model_id: virchow2
        model:
          # "<module>:<attribute>" of the provider callable; the remaining keys
          # are passed to it as keyword arguments.
          _target_: providers.model_provider:mlflow
          artifact_uri: mlflow-artifacts:/37/484fdb26a5394af4bc76e387e7c93c89/artifacts/grading_head
1 change: 1 addition & 0 deletions helm/rayservice/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ applications:
- prostate-classifier-1
- prov-gigapath
- virchow2
- breast-grading-virchow2
188 changes: 188 additions & 0 deletions models/breast_grading_virchow2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import asyncio
import importlib
from pathlib import Path
from typing import Any, TypedDict

import numpy as np
import torch
from fastapi import FastAPI, HTTPException, Request
from numpy.typing import NDArray
from PIL import Image
from ray import serve


class Config(TypedDict):
    """User configuration consumed by ``BreastCancerGradingVirchow2.reconfigure``."""

    # Side length of the square RGB input tile; a request body must decompress
    # to exactly tile_size * tile_size * 3 bytes.
    tile_size: int
    # Reported to clients through get_config(); not used for computation in
    # this module.
    output_tile_size: int
    # Number of values per tile; each prediction is reshaped to
    # (n_channels, 1, 1).
    n_channels: int
    # Reported to clients through get_config(); presumably microns per pixel
    # of the expected input tiles -- TODO confirm.
    mpp: float
    # Forwarded to the batched ONNX head via set_max_batch_size().
    max_batch_size: int
    # Forwarded to the batched ONNX head via set_batch_wait_timeout_s().
    batch_wait_timeout_s: float
    # Name of the Serve application resolved with serve.get_app_handle().
    foundation_model_id: str
    # Provider spec: "_target_" is "<module>:<attribute>" of a callable, every
    # other key is passed to that callable as a keyword argument.
    model: dict[str, Any]


# FastAPI application used as the HTTP ingress of the deployment class below
# (attached with ``@serve.ingress(fastapi)``; routes use ``@fastapi.post``).
fastapi = FastAPI()


@serve.deployment(num_replicas="auto")
@serve.ingress(fastapi)
class BreastCancerGradingVirchow2:
def __init__(self) -> None:
    """Bind the LZ4 frame module used to decompress request bodies.

    The import is done here rather than at module level, so ``lz4`` is only
    required once the deployment class is actually instantiated.
    """
    from lz4 import frame as lz4_frame

    self.lz4 = lz4_frame

def reconfigure(self, config: Config) -> None:
    """Apply a user config: tile metadata, Virchow2 handle/transform and ONNX head.

    Args:
        config: Deployment user config. ``config["model"]`` must contain a
            ``_target_`` entry of the form ``"<module>:<attribute>"`` naming a
            provider callable; all remaining entries are passed to that
            callable as keyword arguments. The callable's return value is
            treated as a local filesystem path.

    Raises:
        FileNotFoundError: If the provider result is not a file and no
            ``model.onnx`` exists anywhere beneath it.
    """
    import onnxruntime as ort
    import timm
    from timm.data.config import resolve_data_config
    from timm.data.transforms_factory import create_transform
    from timm.layers.mlp import SwiGLUPacked

    # Tile geometry and resolution metadata, echoed back by get_config().
    self.tile_size = config["tile_size"]
    self.output_tile_size = config["output_tile_size"]
    self.n_channels = config["n_channels"]
    self.mpp = config["mpp"]

    # Handle to the separately deployed foundation-model application.
    self.foundation_model = serve.get_app_handle(config["foundation_model_id"])

    # Weights are not loaded (pretrained=False); the model object is only
    # used to derive the matching preprocessing configuration below.
    virchow2 = timm.create_model(
        "hf-hub:paige-ai/Virchow2",
        pretrained=False,
        num_classes=0,
        mlp_layer=SwiGLUPacked,
        act_layer=torch.nn.SiLU,
    )

    self.foundation_transform = create_transform(
        **resolve_data_config(virchow2.pretrained_cfg, model=virchow2)
    )

    # Resolve the provider callable from "<module>:<attribute>" and let it
    # fetch the trained head; a copy is used so `config` is not mutated.
    model_config = dict(config["model"])
    module_path, attr_name = model_config.pop("_target_").split(":")
    provider = getattr(importlib.import_module(module_path), attr_name)

    downloaded_path = Path(provider(**model_config))

    if downloaded_path.is_file():
        # The provider already returned the model file itself.
        model_path = downloaded_path
    else:
        # Sort so the chosen file does not depend on filesystem iteration
        # order when several model.onnx files exist.
        candidates = sorted(downloaded_path.rglob("model.onnx"))

        if not candidates:
            raise FileNotFoundError(
                "Downloaded MLflow artifact path is a directory, "
                f"but no model.onnx was found under: {downloaded_path}"
            )

        model_path = candidates[0]

    # Providers are listed in preference order: CUDA first, CPU as fallback.
    self.session = ort.InferenceSession(
        str(model_path),
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

    # The head is assumed to have a single input and a single output; only
    # the first of each is used.
    self.input_name = self.session.get_inputs()[0].name
    self.output_name = self.session.get_outputs()[0].name

    # Batching parameters of the @serve.batch-decorated head are set at
    # reconfigure time so they follow the user config.
    self._predict_head.set_max_batch_size(config["max_batch_size"])  # type: ignore[attr-defined]
    self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"])  # type: ignore[attr-defined]

async def get_config(self) -> dict[str, Any]:
    """Return the tile metadata stored by ``reconfigure`` as a plain dict."""
    exposed = ("tile_size", "output_tile_size", "n_channels", "mpp")
    return {field: getattr(self, field) for field in exposed}

def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tensor:
    """Run the foundation-model preprocessing on one channel-first tile.

    Args:
        tile_chw: Tile with axes ordered (channel, height, width).

    Returns:
        Whatever ``self.foundation_transform`` produces for the tile
        (annotated as a ``torch.Tensor``).
    """
    # PIL expects (height, width, channel), so move the channel axis last.
    pil_tile = Image.fromarray(np.transpose(tile_chw, (1, 2, 0)))
    return self.foundation_transform(pil_tile)

async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray:
tile_tensor = await asyncio.to_thread(self._prepare_tile_for_virchow2, tile)

# Execute remote pipeline call to the shared Virchow2 microservice
virchow2_output = await self.foundation_model.predict.remote(tile_tensor)

if isinstance(virchow2_output, np.ndarray):
virchow2_output = torch.from_numpy(virchow2_output)

# Virchow2 predict returns one tensor per tile, shape [tokens, dim].
# Make it [1, tokens, dim] so pooling is batch-compatible.
if virchow2_output.ndim == 2:
virchow2_output = virchow2_output.unsqueeze(0)

class_token = virchow2_output[:, 0]
patch_tokens = virchow2_output[:, 5:]
embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1)

return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False)

@serve.batch
async def _predict_head(
    self,
    embeddings: list[NDArray[np.float32]],
) -> list[NDArray[np.float32]]:
    """Run the ONNX head on a batch of embeddings collected by ``serve.batch``.

    Args:
        embeddings: One embedding per pending request; they are stacked along
            a new leading axis, so all must share the same shape.

    Returns:
        One float32 array of shape ``(n_channels, 1, 1)`` per input embedding,
        in the same order.
    """
    stacked = np.stack(embeddings, axis=0).astype(np.float32, copy=False)

    # Only the first (and only requested) output of the session is used.
    outputs = self.session.run([self.output_name], {self.input_name: stacked})
    per_tile_shape = (self.n_channels, 1, 1)

    return [
        np.reshape(row, per_tile_shape).astype(np.float32) for row in outputs[0]
    ]

# Handles exactly one tile per call; batching happens inside _predict_head.
async def predict(
    self,
    tile: NDArray[np.uint8],
) -> NDArray[np.float32]:
    """Embed a single tile and return the head's raw output for it."""
    return await self._predict_head(await self._create_embedding(tile))

@fastapi.post("/")
async def root(self, request: Request) -> list[list[list[float]]]:
    """Serve one tile: LZ4-frame-compressed RGB bytes in, nested list out.

    The request body must decompress to exactly
    ``tile_size * tile_size * 3`` uint8 values laid out as
    (height, width, channel).

    Args:
        request: Incoming HTTP request whose raw body is the compressed tile.

    Returns:
        The prediction converted with ``tolist()``.

    Raises:
        HTTPException: Status 400 when decompression fails or the
            decompressed size does not match the configured tile size.
    """
    # NOTE(review): on the pull request, reviewers proposed annotating the
    # return type as `Response` and dropping the try/except in favour of a
    # bare decompress + reshape. Neither is applied here; the validation and
    # the HTTP 400 mapping are kept as originally written.
    compressed = await request.body()

    try:
        # Decompression runs in a worker thread instead of on the event loop.
        raw = await asyncio.to_thread(self.lz4.decompress, compressed)

        side = self.tile_size
        expected_bytes = side * side * 3
        if len(raw) != expected_bytes:
            raise ValueError(
                f"Decompressed payload byte length mismatch. "
                f"Expected exactly {expected_bytes} bytes, but got {len(raw)}."
            )

        tile = np.frombuffer(raw, dtype=np.uint8).reshape(side, side, 3)
    except (RuntimeError, ValueError) as err:
        # Bad payloads are reported to the client as HTTP 400.
        raise HTTPException(
            status_code=400,
            detail=f"Malformed or invalid compressed tile image payload: {err!s}",
        ) from err

    # predict() takes channel-first tiles.
    logits = await self.predict(np.transpose(tile, (2, 0, 1)))

    return logits.tolist()


# Ray Serve application object; the Helm application config references it as
# ``import_path: models.breast_grading_virchow2:app``.
app = BreastCancerGradingVirchow2.bind()  # type: ignore[attr-defined]