From 2cbf684566a56449596a73e9adccc47b4ebe80be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Fri, 5 Jun 2026 15:36:17 +0000 Subject: [PATCH 01/21] feat: initial breast grading --- .../applications/breast-grading-virchow2.yaml | 34 ++++ models/breast_grading_virchow2.py | 180 ++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 helm/rayservice/applications/breast-grading-virchow2.yaml create mode 100644 models/breast_grading_virchow2.py diff --git a/helm/rayservice/applications/breast-grading-virchow2.yaml b/helm/rayservice/applications/breast-grading-virchow2.yaml new file mode 100644 index 0000000..a6c8eff --- /dev/null +++ b/helm/rayservice/applications/breast-grading-virchow2.yaml @@ -0,0 +1,34 @@ +- name: breast-grading-virchow2 + import_path: models.breast_grading_virchow2:app + route_prefix: /breast-grading-virchow2 + runtime_env: + working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip + env_vars: + MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow + HF_HOME: /mnt/huggingface_cache + pip: + - timm + - pillow + deployments: + - name: BreastCancerGradingVirchow2 + max_ongoing_requests: 512 + max_queued_requests: 4096 + autoscaling_config: + min_replicas: 0 + max_replicas: 2 + target_ongoing_requests: 64 + ray_actor_options: + num_cpus: 2 + num_gpus: 1 + memory: 12884901888 + user_config: + tile_size: 224 + output_tile_size: 1 + n_channels: 4 + mpp: 0.46 + max_batch_size: 128 + batch_wait_timeout_s: 0.05 + foundation_model_id: virchow2 + model: + _target_: providers.model_provider:mlflow + artifact_uri: mlflow-artifacts:/2/# TODO /artifacts/model \ No newline at end of file diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py new file mode 100644 index 0000000..196071a --- /dev/null +++ b/models/breast_grading_virchow2.py @@ -0,0 +1,180 @@ +import asyncio +import importlib +from pathlib import Path +from typing import Any, TypedDict + +import numpy as np +import torch +from fastapi import FastAPI, Request +from numpy.typing import NDArray +from PIL import Image +from ray import serve + + +class Config(TypedDict): + tile_size: int + output_tile_size: int + n_channels: int + mpp: float + max_batch_size: int + batch_wait_timeout_s: float + foundation_model_id: str + model: dict[str, Any] + + +fastapi = FastAPI() + + +@serve.deployment(num_replicas="auto") +@serve.ingress(fastapi) +class BreastCancerGradingVirchow2: + def __init__(self) -> None: + import lz4.frame + self.lz4 = lz4.frame + + def reconfigure(self, config: Config) -> None: + import onnxruntime as ort + import timm + from timm.data.config import resolve_data_config + from timm.data.transforms_factory import create_transform + from timm.layers.mlp import SwiGLUPacked + + # Grid and slide resolution metadata needed by universal builders + self.tile_size = config["tile_size"] + self.output_tile_size = config["output_tile_size"] + self.n_channels = config["n_channels"] + self.mpp = config["mpp"] + + # Connect this deployment to the cluster's running Virchow2 service + self.foundation_model = serve.get_app_handle(config["foundation_model_id"]) + + # Instantiates an offline token skeleton to match the exact Virchow2 transform logic + virchow2 = timm.create_model( + "hf-hub:paige-ai/Virchow2", + pretrained=False, + num_classes=0, + mlp_layer=SwiGLUPacked, + act_layer=torch.nn.SiLU, + ) + + self.foundation_transform = create_transform( + **resolve_data_config(virchow2.pretrained_cfg, model=virchow2) + ) + + # Parse and fetch your trained 4-class linear head ONNX file from MLflow + model_config = dict(config["model"]) + module_path, attr_name = model_config.pop("_target_").split(":") + provider = getattr(importlib.import_module(module_path), attr_name) + + downloaded_path = Path(provider(**model_config)) + candidates = list(downloaded_path.rglob("model.onnx")) + + if not candidates: + raise FileNotFoundError( + "Downloaded MLflow artifact path is a directory, " + f"but no model.onnx was found under: {downloaded_path}" + ) + + model_path = candidates[0] + + # Spin up your linear head ONNX session + self.session = ort.InferenceSession( + str(model_path), + providers=["CUDAExecutionProvider"], + ) + + self.input_name = self.session.get_inputs()[0].name + self.output_name = self.session.get_outputs()[0].name + + # Enforce micro-batching limits for your 4-class ONNX head evaluation pass + self._predict_head.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] + self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] + + async def get_config(self) -> dict[str, Any]: + return { + "tile_size": self.tile_size, + "output_tile_size": self.output_tile_size, + "n_channels": self.n_channels, + "mpp": self.mpp, + } + + def _prepare_tile_for_virchow2(self, tile_chw: NDArray[np.uint8]) -> torch.Tensor: + # Flip layouts from Channel-Height-Width back to standard Image arrays + tile_hwc = tile_chw.transpose(1, 2, 0) + image = Image.fromarray(tile_hwc) + + # Returns the normalized [3, 224, 224] patch structure + return self.foundation_transform(image) + + async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: + tile_tensor = await asyncio.to_thread(self._prepare_tile_for_virchow2, tile) + + # Execute remote pipeline call to the shared Virchow2 microservice + virchow2_output = await self.foundation_model.predict.remote(tile_tensor) + + if isinstance(virchow2_output, np.ndarray): + virchow2_output = torch.from_numpy(virchow2_output) + + if virchow2_output.ndim == 2: + virchow2_output = virchow2_output.unsqueeze(0) + + # Pool patch tokens matching the baseline foundation extraction layout + class_token = virchow2_output[:, 0] + patch_tokens = virchow2_output[:, 5:] + + embedding = torch.cat( + [class_token, patch_tokens.mean(dim=1)], + dim=-1, + ) + + return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) + + @serve.batch + async def _predict_head( + self, + embeddings: list[NDArray[np.float32]], + ) -> list[NDArray[np.float32]]: + batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) + + # Evaluates the batched tensors through your 4-class linear network layer + probabilities = self.session.run( + [self.output_name], + {self.input_name: batch}, + )[0] + + # Modified to match 4-class heatmap dimensions: + # Reshapes predictions to [1, 1, 4] so the universal system-level + # HeatmapBuilder maps tissue grades over 4 channels instead of a binary scalar. + return [ + prob.reshape(1, 1, 4).astype(np.float32) + for prob in probabilities + ] + + async def predict( + self, + tile: NDArray[np.uint8], + ) -> NDArray[np.float32]: + embedding = await self._create_embedding(tile) + return await self._predict_head(embedding) + + @fastapi.post("/") + async def root(self, request: Request) -> list[list[list[list[float]]]]: + # 1. Unzip raw compressed image tile bytes coming from network traffic + data = await asyncio.to_thread(self.lz4.decompress, await request.body()) + + # 2. Reconstruct the raw pixel array + tile = np.frombuffer(data, dtype=np.uint8).reshape( + self.tile_size, + self.tile_size, + 3, + ) + + tile_chw = tile.transpose(2, 0, 1) + + # 3. Fire pipeline (Raw tile -> Virchow2 embedding -> Your 4-class Head) + result = await self.predict(tile_chw) + + return result.tolist() + + +app = BreastCancerGradingVirchow2.bind() # type: ignore[attr-defined] \ No newline at end of file From 0f4957c78b7a81ef0d63e89d73cecab20f2e13a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Fri, 5 Jun 2026 16:57:23 +0000 Subject: [PATCH 02/21] fix: add model uri --- .../applications/breast-grading-virchow2.yaml | 6 +++--- helm/rayservice/values.yaml | 1 + models/breast_grading_virchow2.py | 10 ++++------ 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/helm/rayservice/applications/breast-grading-virchow2.yaml b/helm/rayservice/applications/breast-grading-virchow2.yaml index a6c8eff..226a529 100644 --- a/helm/rayservice/applications/breast-grading-virchow2.yaml +++ b/helm/rayservice/applications/breast-grading-virchow2.yaml @@ -2,7 +2,7 @@ import_path: models.breast_grading_virchow2:app route_prefix: /breast-grading-virchow2 runtime_env: - working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip + working_dir: https://github.com/RationAI/model-service/archive/refs/heads/feature/breast-cancer-grading.zip #after debug set to this: https://github.com/RationAI/model-service/archive/refs/heads/main.zip env_vars: MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow HF_HOME: /mnt/huggingface_cache @@ -24,11 +24,11 @@ user_config: tile_size: 224 output_tile_size: 1 - n_channels: 4 + n_channels: 1 mpp: 0.46 max_batch_size: 128 batch_wait_timeout_s: 0.05 foundation_model_id: virchow2 model: _target_: providers.model_provider:mlflow - artifact_uri: mlflow-artifacts:/2/# TODO /artifacts/model \ No newline at end of file + artifact_uri: mlflow-artifacts:/2/ad34158eedb945c2895cca84ddee7d2a/artifacts/model \ No newline at end of file diff --git a/helm/rayservice/values.yaml b/helm/rayservice/values.yaml index 6e62751..b4ac860 100644 --- a/helm/rayservice/values.yaml +++ b/helm/rayservice/values.yaml @@ -8,3 +8,4 @@ applications: - prostate-classifier-1 - prov-gigapath - virchow2 + - breast-grading-virchow2 diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index 196071a..f9ecd0c 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -30,6 +30,7 @@ class Config(TypedDict): class BreastCancerGradingVirchow2: def __init__(self) -> None: import lz4.frame + self.lz4 = lz4.frame def reconfigure(self, config: Config) -> None: @@ -143,12 +144,9 @@ async def _predict_head( )[0] # Modified to match 4-class heatmap dimensions: - # Reshapes predictions to [1, 1, 4] so the universal system-level + # Reshapes predictions to [1, 1, 4] so the universal system-level # HeatmapBuilder maps tissue grades over 4 channels instead of a binary scalar. - return [ - prob.reshape(1, 1, 4).astype(np.float32) - for prob in probabilities - ] + return [prob.reshape(1, 1, 4).astype(np.float32) for prob in probabilities] async def predict( self, @@ -177,4 +175,4 @@ async def root(self, request: Request) -> list[list[list[list[float]]]]: return result.tolist() -app = BreastCancerGradingVirchow2.bind() # type: ignore[attr-defined] \ No newline at end of file +app = BreastCancerGradingVirchow2.bind() # type: ignore[attr-defined] From 8ab5b08a53362288a55c0bc56b8844ad8edb38b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= Date: Fri, 5 Jun 2026 19:42:55 +0200 Subject: [PATCH 03/21] Update models/breast_grading_virchow2.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- models/breast_grading_virchow2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index f9ecd0c..181d1c0 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -156,7 +156,7 @@ async def predict( return await self._predict_head(embedding) @fastapi.post("/") - async def root(self, request: Request) -> list[list[list[list[float]]]]: + async def root(self, request: Request) -> list[list[list[float]]]: # 1. Unzip raw compressed image tile bytes coming from network traffic data = await asyncio.to_thread(self.lz4.decompress, await request.body()) From eb394f13682db03a39c0f428ec6ea7c0bcd595da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= Date: Fri, 5 Jun 2026 19:43:12 +0200 Subject: [PATCH 04/21] Update models/breast_grading_virchow2.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- models/breast_grading_virchow2.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index 181d1c0..e10f2de 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -138,10 +138,12 @@ async def _predict_head( batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) # Evaluates the batched tensors through your 4-class linear network layer - probabilities = self.session.run( + probabilities = await asyncio.to_thread( + self.session.run, [self.output_name], {self.input_name: batch}, - )[0] + ) + probabilities = probabilities[0] # Modified to match 4-class heatmap dimensions: # Reshapes predictions to [1, 1, 4] so the universal system-level From cb8fcd7fad339fba979f0f6ad830d4ddd0359ed9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= Date: Fri, 5 Jun 2026 19:43:25 +0200 Subject: [PATCH 05/21] Update models/breast_grading_virchow2.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- models/breast_grading_virchow2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index e10f2de..1ccf033 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -81,7 +81,7 @@ def reconfigure(self, config: Config) -> None: # Spin up your linear head ONNX session self.session = ort.InferenceSession( str(model_path), - providers=["CUDAExecutionProvider"], + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) self.input_name = self.session.get_inputs()[0].name From 5dd26d99872b9cf0e9e2f2c81c42439c13b8df4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= Date: Fri, 5 Jun 2026 19:43:56 +0200 Subject: [PATCH 06/21] Update models/breast_grading_virchow2.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- models/breast_grading_virchow2.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index 1ccf033..03ffa67 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -113,22 +113,22 @@ async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: # Execute remote pipeline call to the shared Virchow2 microservice virchow2_output = await self.foundation_model.predict.remote(tile_tensor) - if isinstance(virchow2_output, np.ndarray): - virchow2_output = torch.from_numpy(virchow2_output) + if isinstance(virchow2_output, torch.Tensor): + virchow2_output = virchow2_output.cpu().numpy() if virchow2_output.ndim == 2: - virchow2_output = virchow2_output.unsqueeze(0) + virchow2_output = np.expand_dims(virchow2_output, axis=0) # Pool patch tokens matching the baseline foundation extraction layout class_token = virchow2_output[:, 0] patch_tokens = virchow2_output[:, 5:] - embedding = torch.cat( - [class_token, patch_tokens.mean(dim=1)], - dim=-1, + embedding = np.concatenate( + [class_token, patch_tokens.mean(axis=1)], + axis=-1, ) - return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) + return np.squeeze(embedding, axis=0).astype(np.float32, copy=False) @serve.batch async def _predict_head( From 92a4e078101be2e5b368c93b8bf3075f31c7a2e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Fri, 5 Jun 2026 20:36:42 +0000 Subject: [PATCH 07/21] fix: n_channels set to 4 --- helm/rayservice/applications/breast-grading-virchow2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/rayservice/applications/breast-grading-virchow2.yaml b/helm/rayservice/applications/breast-grading-virchow2.yaml index 226a529..5671171 100644 --- a/helm/rayservice/applications/breast-grading-virchow2.yaml +++ b/helm/rayservice/applications/breast-grading-virchow2.yaml @@ -24,7 +24,7 @@ user_config: tile_size: 224 output_tile_size: 1 - n_channels: 1 + n_channels: 4 mpp: 0.46 max_batch_size: 128 batch_wait_timeout_s: 0.05 From dac55ab88b91634e5edd4d1787fcd4b21af1d39d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Fri, 5 Jun 2026 22:25:41 +0000 Subject: [PATCH 08/21] fix: correct uri --- helm/rayservice/applications/breast-grading-virchow2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/rayservice/applications/breast-grading-virchow2.yaml b/helm/rayservice/applications/breast-grading-virchow2.yaml index 5671171..7053e8e 100644 --- a/helm/rayservice/applications/breast-grading-virchow2.yaml +++ b/helm/rayservice/applications/breast-grading-virchow2.yaml @@ -31,4 +31,4 @@ foundation_model_id: virchow2 model: _target_: providers.model_provider:mlflow - artifact_uri: mlflow-artifacts:/2/ad34158eedb945c2895cca84ddee7d2a/artifacts/model \ No newline at end of file + artifact_uri: mlflow-artifacts:/37/ad34158eedb945c2895cca84ddee7d2a/artifacts/model \ No newline at end of file From 61efdf2024c18ea7dd8a3768bfc8f0df29c6941b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sat, 6 Jun 2026 00:41:42 +0000 Subject: [PATCH 09/21] fix: new onnx and updated model script --- .../applications/breast-grading-virchow2.yaml | 2 +- models/breast_grading_virchow2.py | 59 ++++++++----------- 2 files changed, 26 insertions(+), 35 deletions(-) diff --git a/helm/rayservice/applications/breast-grading-virchow2.yaml b/helm/rayservice/applications/breast-grading-virchow2.yaml index 7053e8e..0c646dc 100644 --- a/helm/rayservice/applications/breast-grading-virchow2.yaml +++ b/helm/rayservice/applications/breast-grading-virchow2.yaml @@ -31,4 +31,4 @@ foundation_model_id: virchow2 model: _target_: providers.model_provider:mlflow - artifact_uri: mlflow-artifacts:/37/ad34158eedb945c2895cca84ddee7d2a/artifacts/model \ No newline at end of file + artifact_uri: mlflow-artifacts:/37/484fdb26a5394af4bc76e387e7c93c89/artifacts/grading_head \ No newline at end of file diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index 03ffa67..f67700b 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -81,15 +81,15 @@ def reconfigure(self, config: Config) -> None: # Spin up your linear head ONNX session self.session = ort.InferenceSession( str(model_path), - providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + providers=["CPUExecutionProvider"], ) self.input_name = self.session.get_inputs()[0].name self.output_name = self.session.get_outputs()[0].name - # Enforce micro-batching limits for your 4-class ONNX head evaluation pass - self._predict_head.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] - self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] + # Enforce micro-batching configurations on the collective predict entry-point instead of _predict_head + self.predict.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] + self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] async def get_config(self) -> dict[str, Any]: return { @@ -113,49 +113,40 @@ async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: # Execute remote pipeline call to the shared Virchow2 microservice virchow2_output = await self.foundation_model.predict.remote(tile_tensor) - if isinstance(virchow2_output, torch.Tensor): - virchow2_output = virchow2_output.cpu().numpy() + if isinstance(virchow2_output, np.ndarray): + virchow2_output = torch.from_numpy(virchow2_output) + # Virchow2 predict returns one tensor per tile, shape [tokens, dim]. + # Make it [1, tokens, dim] so pooling is batch-compatible. if virchow2_output.ndim == 2: - virchow2_output = np.expand_dims(virchow2_output, axis=0) + virchow2_output = virchow2_output.unsqueeze(0) - # Pool patch tokens matching the baseline foundation extraction layout class_token = virchow2_output[:, 0] patch_tokens = virchow2_output[:, 5:] + embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1) - embedding = np.concatenate( - [class_token, patch_tokens.mean(axis=1)], - axis=-1, - ) - - return np.squeeze(embedding, axis=0).astype(np.float32, copy=False) + # 🟢 Safe squeeze that leaves multi-tile production batch axes untouched! + return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) @serve.batch - async def _predict_head( + async def predict( self, - embeddings: list[NDArray[np.float32]], + tiles: list[NDArray[np.uint8]], ) -> list[NDArray[np.float32]]: + + embeddings = await asyncio.gather( + *(self._create_embedding(tile) for tile in tiles) + ) batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) - # Evaluates the batched tensors through your 4-class linear network layer - probabilities = await asyncio.to_thread( - self.session.run, + # Evaluate raw un-softmaxed logit outputs via ONNX Runtime session + logits = self.session.run( [self.output_name], {self.input_name: batch}, - ) - probabilities = probabilities[0] - - # Modified to match 4-class heatmap dimensions: - # Reshapes predictions to [1, 1, 4] so the universal system-level - # HeatmapBuilder maps tissue grades over 4 channels instead of a binary scalar. - return [prob.reshape(1, 1, 4).astype(np.float32) for prob in probabilities] + )[0] - async def predict( - self, - tile: NDArray[np.uint8], - ) -> NDArray[np.float32]: - embedding = await self._create_embedding(tile) - return await self._predict_head(embedding) + # Reshape output elements into the channel structure expected by HeatmapBuilder + return [row.reshape(self.n_channels, 1, 1).astype(np.float32) for row in logits] @fastapi.post("/") async def root(self, request: Request) -> list[list[list[float]]]: @@ -172,9 +163,9 @@ async def root(self, request: Request) -> list[list[list[float]]]: tile_chw = tile.transpose(2, 0, 1) # 3. Fire pipeline (Raw tile -> Virchow2 embedding -> Your 4-class Head) - result = await self.predict(tile_chw) + result = await self.predict([tile_chw]) - return result.tolist() + return result[0].tolist() app = BreastCancerGradingVirchow2.bind() # type: ignore[attr-defined] From bfa25bdf99e2df814cfa806c201f4d0f0e65340b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sat, 6 Jun 2026 01:56:05 +0000 Subject: [PATCH 10/21] fix: lint --- models/breast_grading_virchow2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index f67700b..811b0ab 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -163,7 +163,7 @@ async def root(self, request: Request) -> list[list[list[float]]]: tile_chw = tile.transpose(2, 0, 1) # 3. Fire pipeline (Raw tile -> Virchow2 embedding -> Your 4-class Head) - result = await self.predict([tile_chw]) + result = await self.predict(tile_chw) return result[0].tolist() From 1f3224d2d7addec5219626b06e6061ce128202d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sat, 6 Jun 2026 22:00:55 +0000 Subject: [PATCH 11/21] test: correct dimensions --- models/breast_grading_virchow2.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index 811b0ab..05521c4 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -81,15 +81,14 @@ def reconfigure(self, config: Config) -> None: # Spin up your linear head ONNX session self.session = ort.InferenceSession( str(model_path), - providers=["CPUExecutionProvider"], + providers=["CUDAExecutionProvider"], ) self.input_name = self.session.get_inputs()[0].name self.output_name = self.session.get_outputs()[0].name - # Enforce micro-batching configurations on the collective predict entry-point instead of _predict_head - self.predict.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] - self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] + self._predict_head.set_max_batch_size(config["max_batch_size"]) + self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) async def get_config(self) -> dict[str, Any]: return { @@ -129,25 +128,29 @@ async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) @serve.batch - async def predict( + async def _predict_head( self, - tiles: list[NDArray[np.uint8]], + embeddings: list[NDArray[np.float32]], ) -> list[NDArray[np.float32]]: - embeddings = await asyncio.gather( - *(self._create_embedding(tile) for tile in tiles) - ) batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) - # Evaluate raw un-softmaxed logit outputs via ONNX Runtime session + # Evaluate the [B, 4] output matrix out of your exported ONNX linear model logits = self.session.run( [self.output_name], {self.input_name: batch}, )[0] - # Reshape output elements into the channel structure expected by HeatmapBuilder return [row.reshape(self.n_channels, 1, 1).astype(np.float32) for row in logits] + # Entry point takes exactly ONE tile at a time from root + async def predict( + self, + tile: NDArray[np.uint8], + ) -> NDArray[np.float32]: + embedding = await self._create_embedding(tile) + return await self._predict_head(embedding) + @fastapi.post("/") async def root(self, request: Request) -> list[list[list[float]]]: # 1. Unzip raw compressed image tile bytes coming from network traffic @@ -165,7 +168,7 @@ async def root(self, request: Request) -> list[list[list[float]]]: # 3. Fire pipeline (Raw tile -> Virchow2 embedding -> Your 4-class Head) result = await self.predict(tile_chw) - return result[0].tolist() + return result.tolist() app = BreastCancerGradingVirchow2.bind() # type: ignore[attr-defined] From 2ce6034be1977272db8a0f9f4edb4771a8273d61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sat, 6 Jun 2026 22:06:24 +0000 Subject: [PATCH 12/21] fix: lint --- models/breast_grading_virchow2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index 05521c4..34df8ac 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -87,8 +87,8 @@ def reconfigure(self, config: Config) -> None: self.input_name = self.session.get_inputs()[0].name self.output_name = self.session.get_outputs()[0].name - self._predict_head.set_max_batch_size(config["max_batch_size"]) - self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) + self._predict_head.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] + self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] async def get_config(self) -> dict[str, Any]: return { @@ -171,4 +171,4 @@ async def root(self, request: Request) -> list[list[list[float]]]: return result.tolist() -app = BreastCancerGradingVirchow2.bind() # type: ignore[attr-defined] +app = BreastCancerGradingVirchow2.bind() # type: ignore[attr-defined \ No newline at end of file From 75a89cf6aca69291929ccee6e1d3ceeffdebc3c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sat, 6 Jun 2026 22:06:50 +0000 Subject: [PATCH 13/21] fix: lint --- models/breast_grading_virchow2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index 34df8ac..f9380b7 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -171,4 +171,4 @@ async def root(self, request: Request) -> list[list[list[float]]]: return result.tolist() -app = BreastCancerGradingVirchow2.bind() # type: ignore[attr-defined \ No newline at end of file +app = BreastCancerGradingVirchow2.bind() # type: ignore[attr-defined] From 46df233f4f6799f73df6948859936476efa1126c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sun, 7 Jun 2026 10:08:12 +0000 Subject: [PATCH 14/21] test: add guard and add channels for heatmap builder --- .../applications/breast-grading-virchow2.yaml | 2 +- models/breast_grading_virchow2.py | 71 ++++++++++++++----- 2 files changed, 53 insertions(+), 20 deletions(-) diff --git a/helm/rayservice/applications/breast-grading-virchow2.yaml b/helm/rayservice/applications/breast-grading-virchow2.yaml index 0c646dc..8ee44b7 100644 --- a/helm/rayservice/applications/breast-grading-virchow2.yaml +++ b/helm/rayservice/applications/breast-grading-virchow2.yaml @@ -24,7 +24,7 @@ user_config: tile_size: 224 output_tile_size: 1 - n_channels: 4 + n_channels: 8 mpp: 0.46 max_batch_size: 128 batch_wait_timeout_s: 0.05 diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index f9380b7..086bef0 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -5,7 +5,7 @@ import numpy as np import torch -from fastapi import FastAPI, Request +from fastapi import FastAPI, HTTPException, Request from numpy.typing import NDArray from PIL import Image from ray import serve @@ -73,19 +73,30 @@ def reconfigure(self, config: Config) -> None: if not candidates: raise FileNotFoundError( "Downloaded MLflow artifact path is a directory, " - f"but no model.onnx was found under: {downloaded_path}" + "but no model.onnx was found under: " + f"{downloaded_path}" ) model_path = candidates[0] - # Spin up your linear head ONNX session + # Spin up your linear head ONNX session using CPU Execution to prevent host<->device lag self.session = ort.InferenceSession( str(model_path), - providers=["CUDAExecutionProvider"], + providers=["CPUExecutionProvider", "CUDAExecutionProvider"], ) self.input_name = self.session.get_inputs()[0].name self.output_name = self.session.get_outputs()[0].name + self._num_classes = int(self.session.get_outputs()[0].shape[-1]) + + # Fail-fast validation guard: Ensure config allows the 8-channel dual representation + # (4 raw logit channels + 4 normalized softmax channels) + if self.n_channels != 8: + raise ValueError( + f"n_channels config is set to {self.n_channels}, but must be exactly 8 " + f"to support dual representation (4 raw logits + 4 softmax probabilities) " + f"for the underlying {self._num_classes}-class model." + ) self._predict_head.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] @@ -124,7 +135,7 @@ async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: patch_tokens = virchow2_output[:, 5:] embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1) - # 🟢 Safe squeeze that leaves multi-tile production batch axes untouched! + # Safe squeeze that leaves multi-tile production batch axes untouched return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) @serve.batch @@ -132,16 +143,25 @@ async def _predict_head( self, embeddings: list[NDArray[np.float32]], ) -> list[NDArray[np.float32]]: - batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) - # Evaluate the [B, 4] output matrix out of your exported ONNX linear model - logits = self.session.run( + # 1. Evaluate the [B, 4] raw logit matrix from your exported ONNX linear model + logits_np = self.session.run( [self.output_name], {self.input_name: batch}, )[0] - return [row.reshape(self.n_channels, 1, 1).astype(np.float32) for row in logits] + # 2. Compute Softmax dynamically using PyTorch over the final dimension + logits_tensor = torch.from_numpy(logits_np) + softmax_tensor = torch.nn.functional.softmax(logits_tensor, dim=-1) + softmax_np = softmax_tensor.cpu().numpy().astype(np.float32, copy=False) + + # 3. Concatenate along channel axis: shape transitions from (B, 4) + (B, 4) to (B, 8) + combined_outputs = np.concatenate([logits_np, softmax_np], axis=-1) + + # 4. Pack row into a 3D block (8, 1, 1) to satisfy HeatmapBuilder layout + # Channels 0-3: Raw Logits | Channels 4-7: Softmax Probabilities + return [row.reshape(8, 1, 1).astype(np.float32) for row in combined_outputs] # Entry point takes exactly ONE tile at a time from root async def predict( @@ -153,19 +173,32 @@ async def predict( @fastapi.post("/") async def root(self, request: Request) -> list[list[list[float]]]: - # 1. Unzip raw compressed image tile bytes coming from network traffic - data = await asyncio.to_thread(self.lz4.decompress, await request.body()) - - # 2. Reconstruct the raw pixel array - tile = np.frombuffer(data, dtype=np.uint8).reshape( - self.tile_size, - self.tile_size, - 3, - ) + body_bytes = await request.body() + + try: + data = await asyncio.to_thread(self.lz4.decompress, body_bytes) + + expected_bytes = self.tile_size * self.tile_size * 3 + if len(data) != expected_bytes: + raise ValueError( + f"Decompressed payload byte length mismatch. " + f"Expected exactly {expected_bytes} bytes, but got {len(data)}." + ) + + # Reconstruct the raw pixel array + tile = np.frombuffer(data, dtype=np.uint8).reshape( + self.tile_size, + self.tile_size, + 3, + ) + except (RuntimeError, ValueError) as err: + raise HTTPException( + status_code=400, + detail=f"Malformed or invalid compressed tile image payload: {err!s}", + ) from err tile_chw = tile.transpose(2, 0, 1) - # 3. Fire pipeline (Raw tile -> Virchow2 embedding -> Your 4-class Head) result = await self.predict(tile_chw) return result.tolist() From 8496c32036996eb9289978e3802e87694e33c0a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sun, 7 Jun 2026 10:36:31 +0000 Subject: [PATCH 15/21] test: switch softmax and raw logits in output --- helm/rayservice/applications/breast-grading-virchow2.yaml | 2 +- models/breast_grading_virchow2.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/helm/rayservice/applications/breast-grading-virchow2.yaml b/helm/rayservice/applications/breast-grading-virchow2.yaml index 8ee44b7..0c646dc 100644 --- a/helm/rayservice/applications/breast-grading-virchow2.yaml +++ b/helm/rayservice/applications/breast-grading-virchow2.yaml @@ -24,7 +24,7 @@ user_config: tile_size: 224 output_tile_size: 1 - n_channels: 8 + n_channels: 4 mpp: 0.46 max_batch_size: 128 batch_wait_timeout_s: 0.05 diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index 086bef0..bf26aaa 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -157,10 +157,11 @@ async def _predict_head( softmax_np = softmax_tensor.cpu().numpy().astype(np.float32, copy=False) # 3. Concatenate along channel axis: shape transitions from (B, 4) + (B, 4) to (B, 8) - combined_outputs = np.concatenate([logits_np, softmax_np], axis=-1) + # Put Softmax FIRST so HeatmapBuilder reads it when config is sliced to 4 + combined_outputs = np.concatenate([softmax_np, logits_np], axis=-1) - # 4. Pack row into a 3D block (8, 1, 1) to satisfy HeatmapBuilder layout - # Channels 0-3: Raw Logits | Channels 4-7: Softmax Probabilities + # Channels 0-3: Softmax Probabilities (Safe for pyvips) + # Channels 4-7: Raw Logits (For your aggregation scripts) return [row.reshape(8, 1, 1).astype(np.float32) for row in combined_outputs] # Entry point takes exactly ONE tile at a time from root From 4f8ef11581be84ecba3da00e444be1c7b5cf8f23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sun, 7 Jun 2026 10:46:12 +0000 Subject: [PATCH 16/21] test: fix raise error --- models/breast_grading_virchow2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index bf26aaa..d1e5f24 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -91,11 +91,11 @@ def reconfigure(self, config: Config) -> None: # Fail-fast validation guard: Ensure config allows the 8-channel dual representation # (4 raw logit channels + 4 normalized softmax channels) - if self.n_channels != 8: + if self.n_channels != self._num_classes: raise ValueError( - f"n_channels config is set to {self.n_channels}, but must be exactly 8 " - f"to support dual representation (4 raw logits + 4 softmax probabilities) " - f"for the underlying {self._num_classes}-class model." + f"n_channels ({self.n_channels}) must equal the ONNX model's " + f"native class count ({self._num_classes}) so that HeatmapBuilder " + f"slices the correct number of layers." ) self._predict_head.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] From 50909c85e5621c15af30ae070ca64490983a071f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sun, 7 Jun 2026 11:18:10 +0000 Subject: [PATCH 17/21] test: num channels 4 or 8 --- models/breast_grading_virchow2.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index d1e5f24..dfe947b 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -89,13 +89,10 @@ def reconfigure(self, config: Config) -> None: self.output_name = self.session.get_outputs()[0].name self._num_classes = int(self.session.get_outputs()[0].shape[-1]) - # Fail-fast validation guard: Ensure config allows the 8-channel dual representation - # (4 raw logit channels + 4 normalized softmax channels) - if self.n_channels != self._num_classes: + if self.n_channels not in (4, 8): raise ValueError( - f"n_channels ({self.n_channels}) must equal the ONNX model's " - f"native class count ({self._num_classes}) so that HeatmapBuilder " - f"slices the correct number of layers." + f"n_channels config is set to {self.n_channels}, but must be " + f"either 4 (Softmax only) or 8 (Softmax + Logits)." ) self._predict_head.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] @@ -160,9 +157,13 @@ async def _predict_head( # Put Softmax FIRST so HeatmapBuilder reads it when config is sliced to 4 combined_outputs = np.concatenate([softmax_np, logits_np], axis=-1) - # Channels 0-3: Softmax Probabilities (Safe for pyvips) - # Channels 4-7: Raw Logits (For your aggregation scripts) - return [row.reshape(8, 1, 1).astype(np.float32) for row in combined_outputs] + # 4. DYNAMIC SLICE: Slice the array to match exactly what the platform requested + # If config is 4, rows become shape (4, 1, 1) -> Safe for HeatmapBuilder + # If config is 8, rows become shape (8, 1, 1) -> Full data, both softmax and logits + return [ + row[: self.n_channels].reshape(self.n_channels, 1, 1).astype(np.float32) + for row in combined_outputs + ] # Entry point takes exactly ONE tile at a time from root async def predict( From f45d29d35863e5ba435426c11dc9cb8340c07b1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sun, 7 Jun 2026 11:51:14 +0000 Subject: [PATCH 18/21] test: num channels 4 or 8 --- models/breast_grading_virchow2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index dfe947b..c52596b 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -171,7 +171,8 @@ async def predict( tile: NDArray[np.uint8], ) -> NDArray[np.float32]: embedding = await self._create_embedding(tile) - return await self._predict_head(embedding) + results = await self._predict_head(embedding) + return results[0] @fastapi.post("/") async def root(self, request: Request) -> list[list[list[float]]]: From c9dd0a254cd29a3a085760d4313482e39e738c81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sun, 7 Jun 2026 11:56:00 +0000 Subject: [PATCH 19/21] test: num channels 4 or 8 --- models/breast_grading_virchow2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index c52596b..dfe947b 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -171,8 +171,7 @@ async def predict( tile: NDArray[np.uint8], ) -> NDArray[np.float32]: embedding = await self._create_embedding(tile) - results = await self._predict_head(embedding) - return results[0] + return await self._predict_head(embedding) @fastapi.post("/") async def root(self, request: Request) -> list[list[list[float]]]: From dfe47bbe6279ae150ba8352de4b95ba569eada20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sun, 7 Jun 2026 12:25:20 +0000 Subject: [PATCH 20/21] fix: revert to 4 logits output --- models/breast_grading_virchow2.py | 45 +++++++++---------------------- 1 file changed, 12 insertions(+), 33 deletions(-) diff --git a/models/breast_grading_virchow2.py b/models/breast_grading_virchow2.py index dfe947b..b0a3fb9 100644 --- a/models/breast_grading_virchow2.py +++ b/models/breast_grading_virchow2.py @@ -73,27 +73,18 @@ def reconfigure(self, config: Config) -> None: if not candidates: raise FileNotFoundError( "Downloaded MLflow artifact path is a directory, " - "but no model.onnx was found under: " - f"{downloaded_path}" + f"but no model.onnx was found under: {downloaded_path}" ) model_path = candidates[0] - # Spin up your linear head ONNX session using CPU Execution to prevent host<->device lag self.session = ort.InferenceSession( str(model_path), - providers=["CPUExecutionProvider", "CUDAExecutionProvider"], + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], ) self.input_name = self.session.get_inputs()[0].name self.output_name = self.session.get_outputs()[0].name - self._num_classes = int(self.session.get_outputs()[0].shape[-1]) - - if self.n_channels not in (4, 8): - raise ValueError( - f"n_channels config is set to {self.n_channels}, but must be " - f"either 4 (Softmax only) or 8 (Softmax + Logits)." - ) self._predict_head.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] self._predict_head.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] @@ -132,7 +123,6 @@ async def _create_embedding(self, tile: NDArray[np.uint8]) -> np.ndarray: patch_tokens = virchow2_output[:, 5:] embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1) - # Safe squeeze that leaves multi-tile production batch axes untouched return embedding.squeeze(0).cpu().numpy().astype(np.float32, copy=False) @serve.batch @@ -140,30 +130,16 @@ async def _predict_head( self, embeddings: list[NDArray[np.float32]], ) -> list[NDArray[np.float32]]: + batch = np.stack(embeddings, axis=0).astype(np.float32, copy=False) - # 1. Evaluate the [B, 4] raw logit matrix from your exported ONNX linear model - logits_np = self.session.run( + # Evaluate the [B, 4] output matrix out of your exported ONNX linear model + logits = self.session.run( [self.output_name], {self.input_name: batch}, )[0] - # 2. Compute Softmax dynamically using PyTorch over the final dimension - logits_tensor = torch.from_numpy(logits_np) - softmax_tensor = torch.nn.functional.softmax(logits_tensor, dim=-1) - softmax_np = softmax_tensor.cpu().numpy().astype(np.float32, copy=False) - - # 3. Concatenate along channel axis: shape transitions from (B, 4) + (B, 4) to (B, 8) - # Put Softmax FIRST so HeatmapBuilder reads it when config is sliced to 4 - combined_outputs = np.concatenate([softmax_np, logits_np], axis=-1) - - # 4. DYNAMIC SLICE: Slice the array to match exactly what the platform requested - # If config is 4, rows become shape (4, 1, 1) -> Safe for HeatmapBuilder - # If config is 8, rows become shape (8, 1, 1) -> Full data, both softmax and logits - return [ - row[: self.n_channels].reshape(self.n_channels, 1, 1).astype(np.float32) - for row in combined_outputs - ] + return [row.reshape(self.n_channels, 1, 1).astype(np.float32) for row in logits] # Entry point takes exactly ONE tile at a time from root async def predict( @@ -171,15 +147,18 @@ async def predict( tile: NDArray[np.uint8], ) -> NDArray[np.float32]: embedding = await self._create_embedding(tile) - return await self._predict_head(embedding) + return await self._predict_head(embedding) # returns 4 raw logits per tile @fastapi.post("/") async def root(self, request: Request) -> list[list[list[float]]]: + # Capture raw incoming network request body bytes body_bytes = await request.body() try: + # 1. Unzip raw compressed image tile bytes asynchronously in a thread worker data = await asyncio.to_thread(self.lz4.decompress, body_bytes) + # 2. Size check expected_bytes = self.tile_size * self.tile_size * 3 if len(data) != expected_bytes: raise ValueError( @@ -187,20 +166,20 @@ async def root(self, request: Request) -> list[list[list[float]]]: f"Expected exactly {expected_bytes} bytes, but got {len(data)}." ) - # Reconstruct the raw pixel array + # 3. Reconstruct the raw pixel array tile = np.frombuffer(data, dtype=np.uint8).reshape( self.tile_size, self.tile_size, 3, ) except (RuntimeError, ValueError) as err: + # 4. Gracefully map decompression or reshape shape errors to a clean HTTP 400 raise HTTPException( status_code=400, detail=f"Malformed or invalid compressed tile image payload: {err!s}", ) from err tile_chw = tile.transpose(2, 0, 1) - result = await self.predict(tile_chw) return result.tolist() From 2f0cf7917ab0a414dc00f16566f6b03399b9aa76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Karol=C3=ADna=20Melovsk=C3=A1?= <569385@mail.muni.cz> Date: Sun, 7 Jun 2026 12:32:39 +0000 Subject: [PATCH 21/21] fix: working dir set to master in config --- helm/rayservice/applications/breast-grading-virchow2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/rayservice/applications/breast-grading-virchow2.yaml b/helm/rayservice/applications/breast-grading-virchow2.yaml index 0c646dc..335a3a6 100644 --- a/helm/rayservice/applications/breast-grading-virchow2.yaml +++ b/helm/rayservice/applications/breast-grading-virchow2.yaml @@ -2,7 +2,7 @@ import_path: models.breast_grading_virchow2:app route_prefix: /breast-grading-virchow2 runtime_env: - working_dir: https://github.com/RationAI/model-service/archive/refs/heads/feature/breast-cancer-grading.zip #after debug set to this: https://github.com/RationAI/model-service/archive/refs/heads/main.zip + working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip env_vars: MLFLOW_TRACKING_URI: http://mlflow-s3.rationai-mlflow HF_HOME: /mnt/huggingface_cache