From 0246465c41d46e128ddb36db15aab92396d18c62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Tue, 2 Jun 2026 08:46:40 +0000 Subject: [PATCH 01/21] feat: breast cancer virchow model added --- .../applications/breast-cancer-virchow2.yaml | 38 +++++ helm/rayservice/values.yaml | 1 + models/breast_cancer_virchow2.py | 145 ++++++++++++++++++ 3 files changed, 184 insertions(+) create mode 100644 helm/rayservice/applications/breast-cancer-virchow2.yaml create mode 100644 models/breast_cancer_virchow2.py diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml new file mode 100644 index 0000000..2678a84 --- /dev/null +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -0,0 +1,38 @@ +- name: breast-cancer-virchow2 + import_path: models.breast_cancer_virchow2:app + route_prefix: /breast-cancer-virchow2 + runtime_env: + working_dir: https://github.com/RationAI/model-service/archive/refs/heads/feature/breast-cancer-virchow.zip + deployments: + - name: BreastCancerVirchow2 + max_ongoing_requests: 512 + max_queued_requests: 4096 + autoscaling_config: + min_replicas: 0 + max_replicas: 4 + target_ongoing_requests: 128 + ray_actor_options: + num_cpus: 4 + num_gpus: 1 + memory: 12884901888 + runtime_env: + env_vars: + MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow + HF_HOME: /mnt/huggingface_cache + pip: + - onnxruntime-gpu + - mlflow<3.0 + - lz4 + - timm + - pillow + user_config: + tile_size: 224 + output_tile_size: 1 + n_channels: 1 + mpp: 0.46 + max_batch_size: 128 + batch_wait_timeout_s: 0.05 + foundation_model_id: virchow2 + model: + _target_: providers.model_provider:mlflow + artifact_uri: mlflow-artifacts:/2/bc79482fb539496ea9a2a43479150956/artifacts/model/model.onnx \ No newline at end of file diff --git a/helm/rayservice/values.yaml b/helm/rayservice/values.yaml index 6e62751..cca36ce 100644 --- a/helm/rayservice/values.yaml +++ b/helm/rayservice/values.yaml @@ -8,3 +8,4 @@ applications: - prostate-classifier-1 - prov-gigapath - virchow2 + - breast-cancer-virchow2.yaml diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py new file mode 100644 index 0000000..b8fedc8 --- /dev/null +++ b/models/breast_cancer_virchow2.py @@ -0,0 +1,145 @@ +import asyncio +import importlib +from typing import Any, TypedDict + +import numpy as np +import onnxruntime as ort +import torch +from fastapi import FastAPI, Request +from numpy.typing import NDArray +from PIL import Image +from ray import serve + + +class Config(TypedDict): + tile_size: int + output_tile_size: int + n_channels: int + mpp: float + max_batch_size: int + batch_wait_timeout_s: float + foundation_model_id: str + model: dict[str, Any] + + +fastapi = FastAPI() + + +@serve.deployment(num_replicas="auto") +@serve.ingress(fastapi) +class BreastCancerVirchow2: + def __init__(self) -> None: + import lz4.frame + + self.lz4 = lz4.frame + + def reconfigure(self, config: Config) -> None: + import timm + from timm.data.config import resolve_data_config + from timm.data.transforms_factory import create_transform + from timm.layers import SwiGLUPacked + + self.tile_size = config["tile_size"] + self.output_tile_size = config["output_tile_size"] + self.n_channels = config["n_channels"] + self.mpp = config["mpp"] + + self.foundation_model = serve.get_app_handle(config["foundation_model_id"]) + + virchow2 = timm.create_model( + "hf-hub:paige-ai/Virchow2", + pretrained=False, + num_classes=0, + mlp_layer=SwiGLUPacked, + act_layer=torch.nn.SiLU, + ) + + self.foundation_transform = create_transform( + **resolve_data_config(virchow2.pretrained_cfg, model=virchow2) + ) + + model_config = dict(config["model"]) + module_path, attr_name = model_config.pop("_target_").split(":") + provider = getattr(importlib.import_module(module_path), attr_name) + + self.session = ort.InferenceSession( + provider(**model_config), + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) + + self.input_name = self.session.get_inputs()[0].name + self.output_name = self.session.get_outputs()[0].name + + self.predict.set_max_batch_size(config["max_batch_size"]) + self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) + + async def get_config(self) -> dict[str, Any]: + return { + "tile_size": self.tile_size, + "output_tile_size": self.output_tile_size, + "n_channels": self.n_channels, + "mpp": self.mpp, + } + + def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tensor: + tile_hwc = tile_chw.transpose(1, 2, 0) + image = Image.fromarray(tile_hwc) + tensor = self.foundation_transform(image) + + if tensor.ndim == 3: + tensor = tensor.unsqueeze(0) + + return tensor + + async def _create_embedding(self, tile_chw: NDArray[np.uint8]) -> np.ndarray: + tile_tensor = self._prepare_tile_for_virchow2(tile_chw) + + virchow2_output = await self.foundation_model.predict.remote(tile_tensor) + + if isinstance(virchow2_output, np.ndarray): + virchow2_output = torch.from_numpy(virchow2_output) + + if virchow2_output.ndim == 2: + virchow2_output = virchow2_output.unsqueeze(0) + + class_token = virchow2_output[:, 0] + patch_tokens = virchow2_output[:, 5:] + + embedding = torch.cat( + [class_token, patch_tokens.mean(dim=1)], + dim=-1, + ) + + return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) + + @serve.batch + async def predict(self, tiles: list[NDArray[np.uint8]]) -> list[float]: + embeddings = await asyncio.gather( + *(self._create_embedding(tile) for tile in tiles) + ) + + batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) + + probabilities = self.session.run( + [self.output_name], + {self.input_name: batch}, + )[0] + + return probabilities.reshape(-1).astype(float).tolist() + + @fastapi.post("/") + async def root(self, request: Request) -> float: + data = await asyncio.to_thread(self.lz4.decompress, await request.body()) + + tile = np.frombuffer(data, dtype=np.uint8).reshape( + self.tile_size, + self.tile_size, + 3, + ) + + tile_chw = tile.transpose(2, 0, 1) + + return await self.predict(tile_chw) + + +app = BreastCancerVirchow2.bind() \ No newline at end of file From 5b5e286e47a439e6bc2fe0d87f16483d4a3d0624 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Tue, 2 Jun 2026 16:02:53 +0000 Subject: [PATCH 02/21] feat: update bc model output from predict --- models/breast_cancer_virchow2.py | 34 +++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py index b8fedc8..5b06941 100644 --- a/models/breast_cancer_virchow2.py +++ b/models/breast_cancer_virchow2.py @@ -37,7 +37,7 @@ def reconfigure(self, config: Config) -> None: import timm from timm.data.config import resolve_data_config from timm.data.transforms_factory import create_transform - from timm.layers import SwiGLUPacked + from timm.layers.mlp import SwiGLUPacked self.tile_size = config["tile_size"] self.output_tile_size = config["output_tile_size"] @@ -70,8 +70,8 @@ def reconfigure(self, config: Config) -> None: self.input_name = self.session.get_inputs()[0].name self.output_name = self.session.get_outputs()[0].name - self.predict.set_max_batch_size(config["max_batch_size"]) - self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) + self.predict.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] + self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] async def get_config(self) -> dict[str, Any]: return { @@ -84,12 +84,9 @@ async def get_config(self) -> dict[str, Any]: def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tensor: tile_hwc = tile_chw.transpose(1, 2, 0) image = Image.fromarray(tile_hwc) - tensor = self.foundation_transform(image) - if tensor.ndim == 3: - tensor = tensor.unsqueeze(0) - - return tensor + # Important: return [3, 224, 224], not [1, 3, 224, 224]. + return self.foundation_transform(image) async def _create_embedding(self, tile_chw: NDArray[np.uint8]) -> np.ndarray: tile_tensor = self._prepare_tile_for_virchow2(tile_chw) @@ -99,6 +96,8 @@ async def _create_embedding(self, tile_chw: NDArray[np.uint8]) -> np.ndarray: if isinstance(virchow2_output, np.ndarray): virchow2_output = torch.from_numpy(virchow2_output) + # Public Virchow2 predict returns one tensor per tile, shape [tokens, dim]. + # Make it [1, tokens, dim] so the pooling code is batch-compatible. if virchow2_output.ndim == 2: virchow2_output = virchow2_output.unsqueeze(0) @@ -113,7 +112,10 @@ async def _create_embedding(self, tile_chw: NDArray[np.uint8]) -> np.ndarray: return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) @serve.batch - async def predict(self, tiles: list[NDArray[np.uint8]]) -> list[float]: + async def predict( + self, + tiles: list[NDArray[np.uint8]], + ) -> list[NDArray[np.float32]]: embeddings = await asyncio.gather( *(self._create_embedding(tile) for tile in tiles) ) @@ -125,10 +127,16 @@ async def predict(self, tiles: list[NDArray[np.uint8]]) -> list[float]: {self.input_name: batch}, )[0] - return probabilities.reshape(-1).astype(float).tolist() + # Important for heatmap-builder: + # each tile prediction must be 2D or 3D. + # For one scalar probability per tile, return a 1x1 map. + return [ + np.asarray([[float(prob)]], dtype=np.float32) + for prob in probabilities.reshape(-1) + ] @fastapi.post("/") - async def root(self, request: Request) -> float: + async def root(self, request: Request) -> list[list[float]]: data = await asyncio.to_thread(self.lz4.decompress, await request.body()) tile = np.frombuffer(data, dtype=np.uint8).reshape( @@ -139,7 +147,9 @@ async def root(self, request: Request) -> float: tile_chw = tile.transpose(2, 0, 1) - return await self.predict(tile_chw) + result = await self.predict(tile_chw) + + return result.tolist() app = BreastCancerVirchow2.bind() \ No newline at end of file From b97d7ef29e0eba64a4e47f29d332b467b27fd899 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Tue, 2 Jun 2026 16:29:50 +0000 Subject: [PATCH 03/21] fix: values format --- helm/rayservice/values.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/rayservice/values.yaml b/helm/rayservice/values.yaml index cca36ce..190a4c8 100644 --- a/helm/rayservice/values.yaml +++ b/helm/rayservice/values.yaml @@ -8,4 +8,4 @@ applications: - prostate-classifier-1 - prov-gigapath - virchow2 - - breast-cancer-virchow2.yaml + - breast-cancer-virchow2 From 8b57b683f72c1a381d1aab914cf92c6824d151a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Tue, 2 Jun 2026 16:46:42 +0000 Subject: [PATCH 04/21] fix: pip dependencies --- .../applications/breast-cancer-virchow2.yaml | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index 2678a84..a07124f 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -3,6 +3,15 @@ route_prefix: /breast-cancer-virchow2 runtime_env: working_dir: https://github.com/RationAI/model-service/archive/refs/heads/feature/breast-cancer-virchow.zip + env_vars: + MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow + HF_HOME: /mnt/huggingface_cache + pip: + - onnxruntime-gpu + - mlflow<3.0 + - lz4 + - timm + - pillow deployments: - name: BreastCancerVirchow2 max_ongoing_requests: 512 @@ -15,16 +24,6 @@ num_cpus: 4 num_gpus: 1 memory: 12884901888 - runtime_env: - env_vars: - MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow - HF_HOME: /mnt/huggingface_cache - pip: - - onnxruntime-gpu - - mlflow<3.0 - - lz4 - - timm - - pillow user_config: tile_size: 224 output_tile_size: 1 From bf284201367c94cc86954d62d8fe7fcc73961fbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Tue, 2 Jun 2026 17:03:43 +0000 Subject: [PATCH 05/21] fix: load onnx as a whole folder --- .../applications/breast-cancer-virchow2.yaml | 2 +- models/breast_cancer_virchow2.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index a07124f..62078dc 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -34,4 +34,4 @@ foundation_model_id: virchow2 model: _target_: providers.model_provider:mlflow - artifact_uri: mlflow-artifacts:/2/bc79482fb539496ea9a2a43479150956/artifacts/model/model.onnx \ No newline at end of file + artifact_uri: mlflow-artifacts:/2/bc79482fb539496ea9a2a43479150956/artifacts/model \ No newline at end of file diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py index 5b06941..9ee4077 100644 --- a/models/breast_cancer_virchow2.py +++ b/models/breast_cancer_virchow2.py @@ -62,9 +62,16 @@ def reconfigure(self, config: Config) -> None: module_path, attr_name = model_config.pop("_target_").split(":") provider = getattr(importlib.import_module(module_path), attr_name) + from pathlib import Path + + model_path = Path(provider(**model_config)) + + if model_path.is_dir(): + model_path = model_path / "model.onnx" + self.session = ort.InferenceSession( - provider(**model_config), - providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + str(model_path), + providers=["CPUExecutionProvider"], ) self.input_name = self.session.get_inputs()[0].name From 064373fb2d50b5e8ea16ef1edc3a96426b48677d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Tue, 2 Jun 2026 17:09:24 +0000 Subject: [PATCH 06/21] fix: imports placement --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 2 +- models/breast_cancer_virchow2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index 62078dc..8cafa97 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -7,7 +7,7 @@ MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow HF_HOME: /mnt/huggingface_cache pip: - - onnxruntime-gpu + - onnxruntime - mlflow<3.0 - lz4 - timm diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py index 9ee4077..0345d03 100644 --- a/models/breast_cancer_virchow2.py +++ b/models/breast_cancer_virchow2.py @@ -3,7 +3,6 @@ from typing import Any, TypedDict import numpy as np -import onnxruntime as ort import torch from fastapi import FastAPI, Request from numpy.typing import NDArray @@ -35,6 +34,7 @@ def __init__(self) -> None: def reconfigure(self, config: Config) -> None: import timm + import onnxruntime as ort from timm.data.config import resolve_data_config from timm.data.transforms_factory import create_transform from timm.layers.mlp import SwiGLUPacked From f3375e16fd27c8e860d431431a765f09fefcaba8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Tue, 2 Jun 2026 17:15:30 +0000 Subject: [PATCH 07/21] fix: model path --- models/breast_cancer_virchow2.py | 42 +++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py index 0345d03..3b0ff86 100644 --- a/models/breast_cancer_virchow2.py +++ b/models/breast_cancer_virchow2.py @@ -1,5 +1,6 @@ import asyncio import importlib +from pathlib import Path from typing import Any, TypedDict import numpy as np @@ -33,8 +34,8 @@ def __init__(self) -> None: self.lz4 = lz4.frame def reconfigure(self, config: Config) -> None: - import timm import onnxruntime as ort + import timm from timm.data.config import resolve_data_config from timm.data.transforms_factory import create_transform from timm.layers.mlp import SwiGLUPacked @@ -46,6 +47,8 @@ def reconfigure(self, config: Config) -> None: self.foundation_model = serve.get_app_handle(config["foundation_model_id"]) + # Only used to construct the correct Virchow2 transform. + # The real embeddings are produced by the deployed Virchow2 service. virchow2 = timm.create_model( "hf-hub:paige-ai/Virchow2", pretrained=False, @@ -62,12 +65,39 @@ def reconfigure(self, config: Config) -> None: module_path, attr_name = model_config.pop("_target_").split(":") provider = getattr(importlib.import_module(module_path), attr_name) - from pathlib import Path + downloaded_path = Path(provider(**model_config)) + + if downloaded_path.is_dir(): + candidates = list(downloaded_path.rglob("model.onnx")) + + if not candidates: + raise FileNotFoundError( + "Downloaded MLflow artifact path is a directory, " + f"but no model.onnx was found under: {downloaded_path}" + ) + + model_path = candidates[0] + + elif downloaded_path.name == "model.onnx": + model_path = downloaded_path + + else: + candidates = list(downloaded_path.parent.rglob("model.onnx")) + + if not candidates: + raise FileNotFoundError( + "Downloaded MLflow artifact path is not model.onnx and no " + f"model.onnx was found nearby. Downloaded path: {downloaded_path}" + ) + + model_path = candidates[0] - model_path = Path(provider(**model_config)) + if not model_path.exists(): + raise FileNotFoundError(f"ONNX model file not found: {model_path}") - if model_path.is_dir(): - model_path = model_path / "model.onnx" + print(f"Using ONNX model path: {model_path}") + print(f"ONNX model exists: {model_path.exists()}") + print(f"ONNX model is file: {model_path.is_file()}") self.session = ort.InferenceSession( str(model_path), @@ -103,7 +133,7 @@ async def _create_embedding(self, tile_chw: NDArray[np.uint8]) -> np.ndarray: if isinstance(virchow2_output, np.ndarray): virchow2_output = torch.from_numpy(virchow2_output) - # Public Virchow2 predict returns one tensor per tile, shape [tokens, dim]. + # Virchow2 predict returns one tensor per tile, shape [tokens, dim]. # Make it [1, tokens, dim] so the pooling code is batch-compatible. if virchow2_output.ndim == 2: virchow2_output = virchow2_output.unsqueeze(0) From 5102177ae353010954a2cc0809e87d9f6a2e1a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Tue, 2 Jun 2026 17:27:23 +0000 Subject: [PATCH 08/21] feat: full sha --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index 8cafa97..48d3874 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -2,7 +2,7 @@ import_path: models.breast_cancer_virchow2:app route_prefix: /breast-cancer-virchow2 runtime_env: - working_dir: https://github.com/RationAI/model-service/archive/refs/heads/feature/breast-cancer-virchow.zip + working_dir: https://github.com/RationAI/model-service/archive/f3375e16fd27c8e860d431431a765f09fefcaba8.zip env_vars: MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow HF_HOME: /mnt/huggingface_cache From 71df25955413467231209067ed2882d1d97a3ce8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Tue, 2 Jun 2026 18:49:00 +0000 Subject: [PATCH 09/21] fix: formatting and typing --- models/breast_cancer_virchow2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py index 3b0ff86..9fc42d5 100644 --- a/models/breast_cancer_virchow2.py +++ b/models/breast_cancer_virchow2.py @@ -189,4 +189,4 @@ async def root(self, request: Request) -> list[list[float]]: return result.tolist() -app = BreastCancerVirchow2.bind() \ No newline at end of file +app = BreastCancerVirchow2.bind() # type: ignore[attr-defined] From 9beddced0d578e7939d8ad53a38703d704527e1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Wed, 3 Jun 2026 06:13:18 +0000 Subject: [PATCH 10/21] fix: num gpus set to 0 --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 2 +- models/breast_cancer_virchow2.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index 48d3874..a7dc36c 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -22,7 +22,7 @@ target_ongoing_requests: 128 ray_actor_options: num_cpus: 4 - num_gpus: 1 + num_gpus: 0 memory: 12884901888 user_config: tile_size: 224 diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py index 9fc42d5..6e448e4 100644 --- a/models/breast_cancer_virchow2.py +++ b/models/breast_cancer_virchow2.py @@ -125,8 +125,8 @@ def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tenso # Important: return [3, 224, 224], not [1, 3, 224, 224]. return self.foundation_transform(image) - async def _create_embedding(self, tile_chw: NDArray[np.uint8]) -> np.ndarray: - tile_tensor = self._prepare_tile_for_virchow2(tile_chw) + async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: + tile_tensor = await asyncio.to_thread(self._prepare_tile_for_virchow2, tile) virchow2_output = await self.foundation_model.predict.remote(tile_tensor) From d4604cc7789488e6c74c91afb4f1230a588091a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Wed, 3 Jun 2026 06:15:20 +0000 Subject: [PATCH 11/21] feat: updated sha --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index a7dc36c..7ea68c6 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -2,7 +2,7 @@ import_path: models.breast_cancer_virchow2:app route_prefix: /breast-cancer-virchow2 runtime_env: - working_dir: https://github.com/RationAI/model-service/archive/f3375e16fd27c8e860d431431a765f09fefcaba8.zip + working_dir: https://github.com/RationAI/model-service/archive/9beddced0d578e7939d8ad53a38703d704527e1f.zip env_vars: MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow HF_HOME: /mnt/huggingface_cache From cc4092fe8b1ecc1af24abd4963329b5a3bd4e3b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Wed, 3 Jun 2026 06:27:10 +0000 Subject: [PATCH 12/21] feat: return num gpu to 1 --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index 7ea68c6..a2272aa 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -22,7 +22,7 @@ target_ongoing_requests: 128 ray_actor_options: num_cpus: 4 - num_gpus: 0 + num_gpus: 1 memory: 12884901888 user_config: tile_size: 224 From 378f6057a2ffe8cb3c409c4561d83740ef526877 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Wed, 3 Jun 2026 06:27:48 +0000 Subject: [PATCH 13/21] fix: update sha --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index a2272aa..c4bdacd 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -2,7 +2,7 @@ import_path: models.breast_cancer_virchow2:app route_prefix: /breast-cancer-virchow2 runtime_env: - working_dir: https://github.com/RationAI/model-service/archive/9beddced0d578e7939d8ad53a38703d704527e1f.zip + working_dir: https://github.com/RationAI/model-service/archive/cc4092fe8b1ecc1af24abd4963329b5a3bd4e3b6.zip env_vars: MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow HF_HOME: /mnt/huggingface_cache From 3ea57f1efca9be7eee8e1e7a61fd88a21dc324b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Wed, 3 Jun 2026 07:00:43 +0000 Subject: [PATCH 14/21] feat: try to fasten heatmap builder after faster single tile prediction --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 8 ++++---- helm/rayservice/applications/heatmap-builder.yaml | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index c4bdacd..5096655 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -17,11 +17,11 @@ max_ongoing_requests: 512 max_queued_requests: 4096 autoscaling_config: - min_replicas: 0 - max_replicas: 4 - target_ongoing_requests: 128 + min_replicas: 1 + max_replicas: 2 + target_ongoing_requests: 64 ray_actor_options: - num_cpus: 4 + num_cpus: 2 num_gpus: 1 memory: 12884901888 user_config: diff --git a/helm/rayservice/applications/heatmap-builder.yaml b/helm/rayservice/applications/heatmap-builder.yaml index f39bf49..c06c39f 100644 --- a/helm/rayservice/applications/heatmap-builder.yaml +++ b/helm/rayservice/applications/heatmap-builder.yaml @@ -15,5 +15,5 @@ num_cpus: 8 memory: 12884901888 user_config: - num_threads: 8 - max_concurrent_tasks: 16 + num_threads: 16 + max_concurrent_tasks: 64 From a1ee087cdeaa129f0b11175dc25dbbe1948c25c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Wed, 3 Jun 2026 07:02:32 +0000 Subject: [PATCH 15/21] feat: update sha --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index 5096655..5e31e65 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -2,7 +2,7 @@ import_path: models.breast_cancer_virchow2:app route_prefix: /breast-cancer-virchow2 runtime_env: - working_dir: https://github.com/RationAI/model-service/archive/cc4092fe8b1ecc1af24abd4963329b5a3bd4e3b6.zip + working_dir: https://github.com/RationAI/model-service/archive/3ea57f1efca9be7eee8e1e7a61fd88a21dc324b1.zip env_vars: MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow HF_HOME: /mnt/huggingface_cache From d35bd5baec133823bb30bdf84e0a4f1f7faed4d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Wed, 3 Jun 2026 08:02:07 +0000 Subject: [PATCH 16/21] feat: point heatmap and virchow to branch --- helm/rayservice/applications/heatmap-builder.yaml | 2 +- helm/rayservice/applications/virchow2.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/helm/rayservice/applications/heatmap-builder.yaml b/helm/rayservice/applications/heatmap-builder.yaml index c06c39f..e0b76ef 100644 --- a/helm/rayservice/applications/heatmap-builder.yaml +++ b/helm/rayservice/applications/heatmap-builder.yaml @@ -2,7 +2,7 @@ import_path: builders.heatmap_builder:app route_prefix: /heatmap-builder runtime_env: - working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip + working_dir: https://github.com/RationAI/model-service/archive/3ea57f1efca9be7eee8e1e7a61fd88a21dc324b1.zip deployments: - name: HeatmapBuilder max_ongoing_requests: 16 diff --git a/helm/rayservice/applications/virchow2.yaml b/helm/rayservice/applications/virchow2.yaml index eaac069..b1fc127 100644 --- a/helm/rayservice/applications/virchow2.yaml +++ b/helm/rayservice/applications/virchow2.yaml @@ -4,7 +4,7 @@ runtime_env: config: setup_timeout_seconds: 1800 - working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip + working_dir: https://github.com/RationAI/model-service/archive/3ea57f1efca9be7eee8e1e7a61fd88a21dc324b1.zip deployments: - name: Virchow2 max_ongoing_requests: 1024 From 6b5b6c0399e02a06eada9743b43322d49005c116 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Thu, 4 Jun 2026 11:33:45 +0000 Subject: [PATCH 17/21] feat: point working dir to main --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 5 +---- helm/rayservice/applications/heatmap-builder.yaml | 2 +- helm/rayservice/applications/virchow2.yaml | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index 5e31e65..06dd9fa 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -2,14 +2,11 @@ import_path: models.breast_cancer_virchow2:app route_prefix: /breast-cancer-virchow2 runtime_env: - working_dir: https://github.com/RationAI/model-service/archive/3ea57f1efca9be7eee8e1e7a61fd88a21dc324b1.zip + working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip env_vars: MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow HF_HOME: /mnt/huggingface_cache pip: - - onnxruntime - - mlflow<3.0 - - lz4 - timm - pillow deployments: diff --git a/helm/rayservice/applications/heatmap-builder.yaml b/helm/rayservice/applications/heatmap-builder.yaml index e0b76ef..c06c39f 100644 --- a/helm/rayservice/applications/heatmap-builder.yaml +++ b/helm/rayservice/applications/heatmap-builder.yaml @@ -2,7 +2,7 @@ import_path: builders.heatmap_builder:app route_prefix: /heatmap-builder runtime_env: - working_dir: https://github.com/RationAI/model-service/archive/3ea57f1efca9be7eee8e1e7a61fd88a21dc324b1.zip + working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip deployments: - name: HeatmapBuilder max_ongoing_requests: 16 diff --git a/helm/rayservice/applications/virchow2.yaml b/helm/rayservice/applications/virchow2.yaml index b1fc127..eaac069 100644 --- a/helm/rayservice/applications/virchow2.yaml +++ b/helm/rayservice/applications/virchow2.yaml @@ -4,7 +4,7 @@ runtime_env: config: setup_timeout_seconds: 1800 - working_dir: https://github.com/RationAI/model-service/archive/3ea57f1efca9be7eee8e1e7a61fd88a21dc324b1.zip + working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip deployments: - name: Virchow2 max_ongoing_requests: 1024 From 9a8835273e34d6a19c6ad0d3114864cb52f2e807 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Thu, 4 Jun 2026 11:34:33 +0000 Subject: [PATCH 18/21] fix: revert changes in heatmap builder --- helm/rayservice/applications/heatmap-builder.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/helm/rayservice/applications/heatmap-builder.yaml b/helm/rayservice/applications/heatmap-builder.yaml index c06c39f..f39bf49 100644 --- a/helm/rayservice/applications/heatmap-builder.yaml +++ b/helm/rayservice/applications/heatmap-builder.yaml @@ -15,5 +15,5 @@ num_cpus: 8 memory: 12884901888 user_config: - num_threads: 16 - max_concurrent_tasks: 64 + num_threads: 8 + max_concurrent_tasks: 16 From b384791d9e734b2bff2bbce8434fda246f5874dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Thu, 4 Jun 2026 16:29:00 +0000 Subject: [PATCH 19/21] fix: review changes, move to gpu and remove min replicas --- helm/rayservice/applications/breast-cancer-virchow2.yaml | 2 +- models/breast_cancer_virchow2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/helm/rayservice/applications/breast-cancer-virchow2.yaml b/helm/rayservice/applications/breast-cancer-virchow2.yaml index 06dd9fa..ca61377 100644 --- a/helm/rayservice/applications/breast-cancer-virchow2.yaml +++ b/helm/rayservice/applications/breast-cancer-virchow2.yaml @@ -14,7 +14,7 @@ max_ongoing_requests: 512 max_queued_requests: 4096 autoscaling_config: - min_replicas: 1 + min_replicas: 0 max_replicas: 2 target_ongoing_requests: 64 ray_actor_options: diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py index 6e448e4..ddc40e4 100644 --- a/models/breast_cancer_virchow2.py +++ b/models/breast_cancer_virchow2.py @@ -101,7 +101,7 @@ def reconfigure(self, config: Config) -> None: self.session = ort.InferenceSession( str(model_path), - providers=["CPUExecutionProvider"], + providers=["GPUExecutionProvider"], ) self.input_name = self.session.get_inputs()[0].name From 3f8913351d6618defbe3648a6dbca424bbcc6681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Thu, 4 Jun 2026 17:12:06 +0000 Subject: [PATCH 20/21] fix: cuda instead of gpu naming inconsistency --- models/breast_cancer_virchow2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py index ddc40e4..46b3a28 100644 --- a/models/breast_cancer_virchow2.py +++ b/models/breast_cancer_virchow2.py @@ -101,7 +101,7 @@ def reconfigure(self, config: Config) -> None: self.session = ort.InferenceSession( str(model_path), - providers=["GPUExecutionProvider"], + providers=["CUDAExecutionProvider"], ) self.input_name = self.session.get_inputs()[0].name From ba2eb95445f539740826779121a2a4c7bdfd8f02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bc=2E=20Katar=C3=ADna=20Hudcovicov=C3=A1?= <514657@mail.muni.cz> Date: Fri, 5 Jun 2026 07:44:09 +0000 Subject: [PATCH 21/21] feat: batching update --- models/breast_cancer_virchow2.py | 56 +++++++++++++------------------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/models/breast_cancer_virchow2.py b/models/breast_cancer_virchow2.py index 46b3a28..081e460 100644 --- a/models/breast_cancer_virchow2.py +++ b/models/breast_cancer_virchow2.py @@ -67,38 +67,19 @@ def reconfigure(self, config: Config) -> None: downloaded_path = Path(provider(**model_config)) - if downloaded_path.is_dir(): - candidates = list(downloaded_path.rglob("model.onnx")) + candidates = list(downloaded_path.rglob("model.onnx")) - if not candidates: - raise FileNotFoundError( - "Downloaded MLflow artifact path is a directory, " - f"but no model.onnx was found under: {downloaded_path}" - ) + if not candidates: + raise FileNotFoundError( + "Downloaded MLflow artifact path is a directory, " + f"but no model.onnx was found under: {downloaded_path}" + ) - model_path = candidates[0] - - elif downloaded_path.name == "model.onnx": - model_path = downloaded_path - - else: - candidates = list(downloaded_path.parent.rglob("model.onnx")) - - if not candidates: - raise FileNotFoundError( - "Downloaded MLflow artifact path is not model.onnx and no " - f"model.onnx was found nearby. Downloaded path: {downloaded_path}" - ) - - model_path = candidates[0] + model_path = candidates[0] if not model_path.exists(): raise FileNotFoundError(f"ONNX model file not found: {model_path}") - print(f"Using ONNX model path: {model_path}") - print(f"ONNX model exists: {model_path.exists()}") - print(f"ONNX model is file: {model_path.is_file()}") - self.session = ort.InferenceSession( str(model_path), providers=["CUDAExecutionProvider"], @@ -107,8 +88,10 @@ def reconfigure(self, config: Config) -> None: self.input_name = self.session.get_inputs()[0].name self.output_name = self.session.get_outputs()[0].name - self.predict.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] - self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] + # Batching should happen only for the ONNX head, after Virchow2 embeddings + # have already been produced for individual tiles. + self._predict_head.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] + self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] async def get_config(self) -> dict[str, Any]: return { @@ -128,6 +111,8 @@ def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tenso async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: tile_tensor = await asyncio.to_thread(self._prepare_tile_for_virchow2, tile) + # Intentionally send a single tile to the foundation model. + # Batching is handled inside the Virchow2 service, not here. virchow2_output = await self.foundation_model.predict.remote(tile_tensor) if isinstance(virchow2_output, np.ndarray): @@ -149,14 +134,10 @@ async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) @serve.batch - async def predict( + async def _predict_head( self, - tiles: list[NDArray[np.uint8]], + embeddings: list[NDArray[np.float32]], ) -> list[NDArray[np.float32]]: - embeddings = await asyncio.gather( - *(self._create_embedding(tile) for tile in tiles) - ) - batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) probabilities = self.session.run( @@ -172,6 +153,13 @@ async def predict( for prob in probabilities.reshape(-1) ] + async def predict( + self, + tile: NDArray[np.uint8], + ) -> NDArray[np.float32]: + embedding = await self._create_embedding(tile) + return await self._predict_head(embedding) + @fastapi.post("/") async def root(self, request: Request) -> list[list[float]]: data = await asyncio.to_thread(self.lz4.decompress, await request.body())