feat: 7-class tissue classifier on Virchow2 embeddings #13
vojtech-cifka wants to merge 2 commits into
📝 Walkthrough
Adds a TissueLinear Ray Serve application and Helm registration plus a FastAPI-backed Ray deployment that LZ4-decompresses tiles, obtains Virchow2 embeddings remotely, runs an ONNX linear head in batches, and returns per-class probability maps.
Changes
TissueLinear Tissue Classification Service
Sequence Diagram
sequenceDiagram
participant Client
participant FastAPI_root
participant TissueLinear_predict
participant ThreadPool
participant Virchow2Service
participant ONNXRuntime
Client->>FastAPI_root: POST /tissue-linear (LZ4 tile)
FastAPI_root->>FastAPI_root: decompress & reshape to CHW
FastAPI_root->>TissueLinear_predict: predict(tiles)
TissueLinear_predict->>ThreadPool: _prepare_tile_for_virchow2(tile)
ThreadPool-->>TissueLinear_predict: transformed tensor
TissueLinear_predict->>Virchow2Service: get_app_handle (token embeddings)
Virchow2Service-->>TissueLinear_predict: token embeddings
TissueLinear_predict->>ONNXRuntime: run(linear head on batch)
ONNXRuntime-->>TissueLinear_predict: logits
TissueLinear_predict-->>FastAPI_root: probability maps
FastAPI_root-->>Client: JSON response
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Code Review
This pull request introduces a new Ray Serve application, tissue-linear, which implements a 7-class tissue classifier using a linear head over Virchow2 embeddings. The feedback focuses on several key performance and resource optimizations: releasing the unused GPU resource in the Helm configuration, constructing the image transform directly instead of instantiating the heavy Virchow2 model to save memory, eliminating redundant array transpositions between CHW and HWC formats, and offloading the CPU-bound ONNX inference to a thread pool to avoid blocking the async event loop.
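The last point in the summary above (moving blocking inference off the async event loop) can be sketched as follows. This is a minimal illustration, not the deployment's code: `run_linear_head` is a hypothetical stand-in for the blocking ONNX call, and the `heartbeat` task exists only to show that the loop keeps running while the worker thread is busy.

```python
import asyncio
import time


def run_linear_head(batch: list[list[float]]) -> list[float]:
    """Hypothetical stand-in for a blocking inference call."""
    time.sleep(0.05)  # simulate blocking work
    return [sum(row) for row in batch]


async def predict(batch: list[list[float]]) -> list[float]:
    # asyncio.to_thread hands the blocking call to a worker thread, so the
    # event loop stays free to serve other coroutines in the meantime.
    return await asyncio.to_thread(run_linear_head, batch)


async def main() -> tuple[list[float], int]:
    ticks = 0

    async def heartbeat() -> None:
        # Counts how often the event loop gets a turn during inference.
        nonlocal ticks
        while True:
            await asyncio.sleep(0.01)
            ticks += 1

    task = asyncio.create_task(heartbeat())
    result = await predict([[1.0, 2.0], [3.0, 4.0]])
    task.cancel()
    return result, ticks


if __name__ == "__main__":
    print(asyncio.run(main()))
```

If `predict` called `run_linear_head` directly, the heartbeat would not advance until the call returned.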
26ae1d7 to 8aca5aa
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@helm/rayservice/applications/tissue-linear.yaml`:
- Around line 16-18: The ray_actor_options currently reserves a GPU (num_gpus:
1) even though this replica is CPU-only; remove or set num_gpus to 0 in the
ray_actor_options block to avoid claiming GPU MIG slots, and instead use a node
label/custom resource or nodeAffinity for placement if you need GPU-hosted
nodes; note that models/tissue_linear.py explicitly uses CPUExecutionProvider
(lines ~106-109) and embeddings are fetched from the remote virchow2 app, so no
local CUDA is required.
- Line 7: The working_dir currently points to a moving ref
(https://github.com/RationAI/model-service/archive/refs/heads/main.zip); update
the working_dir value in the tissue-linear.yaml to an immutable archive URL (a
release tag or commit SHA zip, e.g.
https://github.com/RationAI/model-service/archive/<COMMIT_SHA>.zip) so the chart
pulls a fixed revision; change the value for the working_dir key and regenerate
any chart lock or documentation that records the pinned revision.
In `@models/tissue_linear.py`:
- Around line 179-186: In root(Request) (models/tissue_linear.py) validate the
decompressed payload length before reshaping: compute expected_size =
self.tile_size * self.tile_size * 3, attempt decompression but constrain or
check output size from self.lz4.decompress, and if the decompressed length !=
expected_size return a 400 (e.g., raise HTTPException(status_code=400) or return
a 400 Response) instead of reshaping; only call np.frombuffer(...).reshape(...)
when the length matches exactly to avoid oversized allocations and malformed
payloads.
- Around line 59-61: Validate the advertised output contract in reconfigure():
check that config["output_tile_size"] == 1 and config["n_channels"] ==
self._num_classes (use the instance attributes output_tile_size and n_channels
and the model property self._num_classes) and raise a clear ValueError if either
check fails so callers (and HeatmapBuilder) fail fast; add the same validation
where the config is parsed/assigned (the other block around the second
assignment of tile_size/output_tile_size/n_channels) to ensure consistency
across reconfigure() and any alternate config path.
- Line 138: Wrap the direct RPC await to
self.foundation_model.predict.remote(tile_tensor) in an async timeout (e.g.,
asyncio.wait_for) with a configurable timeout value and handle
asyncio.TimeoutError: on timeout, cancel the Ray object ref
(ray.cancel(virchow2_output, force=True)), log the timeout including the model
and tile context, and surface a clear error or fallback so the caller (methods
in models/tissue_linear.py that call foundation_model.predict.remote) can fail
fast instead of hanging; make the timeout value configurable via existing
config/constants and ensure the cancel + logging occurs before re-raising or
returning an error response.
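The timeout suggestion in the last comment above could look roughly like the sketch below. It uses only `asyncio`; `slow_remote` is a hypothetical stand-in for the remote embedding call, `EmbeddingTimeout` is an illustrative error type, and the Ray-specific cancellation mentioned in the comment is deliberately left out.

```python
import asyncio
from typing import Awaitable


class EmbeddingTimeout(RuntimeError):
    """Illustrative error type: the embedding call took too long."""


async def fetch_embedding(call: Awaitable[list[float]], timeout_s: float) -> list[float]:
    try:
        # wait_for cancels the awaited call if the deadline passes.
        return await asyncio.wait_for(call, timeout=timeout_s)
    except asyncio.TimeoutError as exc:
        # Surface a clear, specific error so the caller can fail fast.
        raise EmbeddingTimeout(f"no embedding within {timeout_s}s") from exc


async def slow_remote(delay_s: float) -> list[float]:
    """Hypothetical stand-in for the remote foundation-model call."""
    await asyncio.sleep(delay_s)
    return [0.0, 0.0, 0.0, 0.0]


if __name__ == "__main__":
    print(asyncio.run(fetch_embedding(slow_remote(0.0), timeout_s=1.0)))
```

The timeout value is a plain parameter here; in a deployment it would presumably come from the existing configuration, as the comment proposes.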
📒 Files selected for processing (3)
- helm/rayservice/applications/tissue-linear.yaml
- helm/rayservice/values.yaml
- models/tissue_linear.py
8aca5aa to 8be7cab
matejpekar left a comment
Check the comments on the other PR #12
8be7cab to 473042b
Actionable comments posted: 1
Inline comments:
In `@models/tissue_linear.py`:
- Around line 175-183: The current predict() call uses embeddings = await
asyncio.gather(*(self._create_embedding(tile) for tile in tiles)) which will
abort the whole batch if any _create_embedding(tile) raises; change to use
asyncio.gather(..., return_exceptions=True), iterate the results to detect
exceptions per index, log/record the failing tile index and exception, and
either skip that tile or substitute a default embedding (e.g., zeros of the same
shape) before creating batch = np.stack(...). Ensure you only pass the resulting
valid/substituted embeddings into session.run({self.input_name: batch}) and
preserve mapping from output logits to the original tile indices (so output_name
results align with non-failed tiles). Do not change _create_embedding()’s
no-timeout call to self.foundation_model.predict.remote(tile_tensor).
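The per-tile failure handling described in the comment above might be sketched like this. `create_embedding` is a hypothetical stand-in for `_create_embedding()`, and this version skips failed tiles rather than substituting a default embedding; the returned index list is what keeps outputs aligned with the original tile positions.

```python
import asyncio


async def create_embedding(tile: int) -> list[float]:
    """Hypothetical stand-in: negative tile ids simulate a failing call."""
    await asyncio.sleep(0)
    if tile < 0:
        raise ValueError(f"bad tile {tile}")
    return [float(tile), float(tile)]


async def embed_batch(tiles: list[int]):
    # return_exceptions=True makes gather return exceptions as results
    # instead of aborting the whole batch on the first failure.
    results = await asyncio.gather(
        *(create_embedding(t) for t in tiles), return_exceptions=True
    )
    ok_indices: list[int] = []
    embeddings: list[list[float]] = []
    failures: dict[int, BaseException] = {}
    for index, result in enumerate(results):
        if isinstance(result, BaseException):
            failures[index] = result  # record which tile failed and why
        else:
            ok_indices.append(index)
            embeddings.append(result)
    return ok_indices, embeddings, failures


if __name__ == "__main__":
    print(asyncio.run(embed_batch([1, -1, 3])))
```

Only `embeddings` would be stacked into the batch for the linear head; `ok_indices[i]` then gives the original tile position for output row `i`.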
✅ Files skipped from review due to trivial changes (1)
- helm/rayservice/values.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- helm/rayservice/applications/tissue-linear.yaml
5de47b1 to bacc3c1
matejpekar left a comment
You haven't tackled the biggest issue
bacc3c1 to b3f3ab4
Actionable comments posted: 1
Inline comments:
In `@models/tissue_linear.py`:
- Around line 107-113: After creating the ONNX session in the TissueLinear
initializer (the block that sets self.session, self.input_name, self.output_name
and self._num_classes), validate the model's input dimension matches the
expected 2560 Virchow2 embedding width by reading
self.session.get_inputs()[0].shape[-1], converting to int, and raising a clear
exception (or failing fast) if it does not equal 2560; this prevents silent
initialization of a model that emits 7 classes but accepts the wrong input width
and avoids later runtime failures in session.run().
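A fail-fast width check along the lines of the comment above could be sketched as a small helper. It takes the shape list that `self.session.get_inputs()[0].shape` would supply, so it can be exercised without loading a model; the helper name is hypothetical, and treating a non-integer last dimension as an error is an assumption about how dynamic axes would show up.

```python
from typing import Optional, Sequence, Union

EXPECTED_EMBEDDING_DIM = 2560  # Virchow2 embedding width named in the review

Dim = Union[int, str, None]


def check_input_width(
    shape: Optional[Sequence[Dim]],
    expected: int = EXPECTED_EMBEDDING_DIM,
) -> int:
    """Return the model's input width, or raise if it is not `expected`."""
    if not shape:
        raise ValueError("model input has no shape information")
    width = shape[-1]
    # Assumption: a symbolic or unknown axis is reported as a non-integer.
    if not isinstance(width, int):
        raise ValueError(f"input width is not static: {width!r}")
    if width != expected:
        raise ValueError(f"expected input width {expected}, got {width}")
    return width


if __name__ == "__main__":
    print(check_input_width(["batch", 2560]))
```

Calling this right after the session is created would turn a mismatched head into an initialization error instead of a later failure in `session.run()`.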
🚧 Files skipped from review as they are similar to previous changes (2)
- helm/rayservice/values.yaml
- helm/rayservice/applications/tissue-linear.yaml
Deploys a 7-class tissue classifier as a linear head over the Virchow2 foundation model. Per tile: apply Virchow2's image transform, fetch the ViT token sequence from the deployed `virchow2` service via a Ray Serve handle, pool tokens (class token + mean of patch tokens) into a 2560-d embedding, run the ONNX linear head, and emit a 7-channel softmax probability map for HeatmapBuilder. The hard class map is recoverable via argmax over channels at full resolution.

The ONNX linear head is exported from the Virchow2 + LBFGS final linear classifier (MLflow run 0e2230c722134ce0985e09a18ccadf75, artifacts/onnx/linear_head.onnx).

Files:
- models/tissue_linear.py: Serve deployment. torch/PIL/timm imports are lazy (the head node builds the app graph without them; the replica runs on GPU workers that carry them). ONNX runs on CPUExecutionProvider.
- helm/rayservice/applications/tissue-linear.yaml: app definition (num_gpus: 1 to land on the mig20 GPU workers for torch/timm).
- helm/rayservice/values.yaml: register tissue-linear.

Validated on the full WSI 07 Leiomyosarkom.svs via HeatmapBuilder, producing a (52224, 36864, 7) BigTIFF; argmax over channels yields Other (neoplastic) and Connective-Tissue (stroma) dominant, consistent with a leiomyosarcoma.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
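The last two steps named in the commit message above (softmax over the 7 class logits, then argmax to recover a hard class map) can be illustrated with NumPy. The logit values are made up for the example and the class indices carry no meaning here.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtracting the per-row maximum keeps exp() numerically stable
    # without changing the result.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Two tiles, 7 class logits each (illustrative values only).
logits = np.array(
    [
        [2.0, 0.1, 0.0, -1.0, 0.3, 0.2, 0.0],
        [0.0, 0.1, 0.2, 3.0, 0.3, 0.2, 0.0],
    ]
)

probs = softmax(logits)            # per-class probabilities, rows sum to 1
hard_classes = probs.argmax(axis=-1)  # hard class index per tile
```

Because softmax is monotonic, taking argmax over the probabilities gives the same class as argmax over the raw logits.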
b3f3ab4 to 64725d2
- Remove unreachable .onnx rglob fallback; artifact_uri points directly at the file, so provider() returns it as-is
- Drop n_channels / output_tile_size / embedding-dim validation guards per review
- Shorten verbose comments

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
♻️ Duplicate comments (1)
models/tissue_linear.py (1)
153-159: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Validate decompressed payload size before reshaping.
The handler decompresses arbitrary request bodies and reshapes without verifying the expected size. A malformed payload triggers a 500 at reshape; a highly compressible body can force a larger allocation than intended. Since the exact raw size is known (`tile_size * tile_size * 3`), validate before proceeding:
Suggested fix
```diff
 @fastapi.post("/")
 async def root(self, request: Request) -> list[Any]:
     data = await asyncio.to_thread(self.lz4.decompress, await request.body())
+    expected_size = self.tile_size * self.tile_size * 3
+    if len(data) != expected_size:
+        from fastapi import HTTPException
+        raise HTTPException(
+            status_code=400,
+            detail=f"Expected {expected_size} bytes after decompression, got {len(data)}",
+        )
     tile = np.frombuffer(data, dtype=np.uint8).reshape(
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/tissue_linear.py` around lines 153 - 159, The code decompresses request body data and immediately reshapes it without validating the payload size, which can cause 500 errors on malformed payloads or force unexpected memory allocations. After decompressing the data via self.lz4.decompress, calculate the expected byte size as self.tile_size * self.tile_size * 3 and validate that the decompressed data length matches this expected size before calling np.frombuffer and reshape. If the size validation fails, raise an appropriate error (like ValueError) with a descriptive message indicating the size mismatch.
🧹 Nitpick comments (1)
models/tissue_linear.py (1)
126-128: 💤 Low value
Consider documenting the Virchow2 token layout.
The magic number `5` in `patch_tokens = virchow2_output[:, 5:]` assumes Virchow2's specific token layout: `[CLS, reg1, reg2, reg3, reg4, patches...]`. A brief inline comment noting this (e.g., "skip class token + 4 register tokens") would help future maintainers understand the offset without consulting Virchow2 documentation.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@models/tissue_linear.py` around lines 126 - 128, Add an inline comment above or on the line containing patch_tokens = virchow2_output[:, 5:] to document the magic number 5, explaining that it skips Virchow2's class token (CLS) and 4 register tokens (reg1, reg2, reg3, reg4) to extract only the patch tokens. This will help future maintainers understand the token layout offset without needing to consult external Virchow2 documentation.
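For reference, the pooling step that this offset belongs to might look like the NumPy sketch below. It assumes the token layout described in the nitpick and assumes that "class token + mean of patch tokens" means concatenating the two parts; the array sizes in the demo are toy values, and `NUM_PREFIX_TOKENS` and `pool_tokens` are illustrative names.

```python
import numpy as np

# Skip the class token plus 4 register tokens, per the layout described above.
NUM_PREFIX_TOKENS = 5


def pool_tokens(tokens: np.ndarray) -> np.ndarray:
    """Pool (batch, n_tokens, dim) tokens into (batch, 2 * dim) embeddings."""
    class_token = tokens[:, 0]                    # (batch, dim)
    patch_tokens = tokens[:, NUM_PREFIX_TOKENS:]  # (batch, n_patches, dim)
    # Assumption: class token and mean patch token are concatenated.
    return np.concatenate([class_token, patch_tokens.mean(axis=1)], axis=-1)


# Toy input: 2 tiles, 5 prefix tokens + 4 patch tokens, width 8.
tokens = np.zeros((2, NUM_PREFIX_TOKENS + 4, 8), dtype=np.float32)
tokens[:, 0] = 1.0     # class token
tokens[:, 1:5] = 99.0  # register tokens, which must not leak into the result
tokens[:, 5:] = 3.0    # patch tokens

embedding = pool_tokens(tokens)
```

Under that concatenation assumption, a token width of 1280 yields the 2560-d embedding mentioned in the PR description.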
✅ Files skipped from review due to trivial changes (1)
- helm/rayservice/values.yaml
🚧 Files skipped from review as they are similar to previous changes (1)
- helm/rayservice/applications/tissue-linear.yaml
Deploys a 7-class tissue classifier as a linear head over the Virchow2 foundation model. Per tile: apply Virchow2's image transform, fetch the ViT token sequence from the deployed `virchow2` service via a Ray Serve handle, pool tokens (class token + mean of patch tokens) into a 2560-d embedding, run the ONNX linear head, and emit a 7-channel softmax probability map for HeatmapBuilder. The hard class map is recoverable via argmax over channels at full resolution.
The ONNX linear head is exported from the Virchow2 + LBFGS final linear classifier (MLflow run 0e2230c722134ce0985e09a18ccadf75, artifacts/onnx/linear_head.onnx).
Files:
Validated on the full WSI 07 Leiomyosarkom.svs via HeatmapBuilder, producing a (52224, 36864, 7) BigTIFF; argmax over channels yields Other (neoplastic) and Connective-Tissue (stroma) dominant, consistent with a leiomyosarcoma. You can check the resulting segmentation map.
Summary by CodeRabbit
`tissue-linear` inference app exposed at /tissue-linear.