Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ Request ────► Prefill (max_tokens=1) ──► P Node
This project draws inspiration from the following open-source projects:

- **[LMDeploy](https://github.com/InternLM/lmdeploy)** — The proxy implementation in `lmdeploy/serve/proxy/proxy.py` provided valuable reference for the routing architecture and PD disaggregation support.
- **[vLLM](https://github.com/vllm-project/vllm)** — The implementation of load balancing policies such as cache_aware in VLLM routers provides us with many references.
- **[vLLM Router](https://github.com/vllm-project/router)** — The implementation of load balancing policies such as cache_aware in VLLM routers provides us with many references.

We extend our sincere thanks to the developers and contributors of these projects for their excellent work in the LLM inference ecosystem.

Expand Down
2 changes: 1 addition & 1 deletion dlrouter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,4 +392,4 @@ def main():


if __name__ == '__main__':
main()
main()
4 changes: 2 additions & 2 deletions dlrouter/core/proxy_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,10 @@ def _extract_prompt_for_prefix_cache(
for msg in messages:
if isinstance(msg, dict):
content = msg.get('content', '')
if content:
if content is not None:
parts.append(str(content))
elif hasattr(msg, 'content'):
if msg.content:
if msg.content is not None:
parts.append(str(msg.content))
return '\n'.join(parts) if parts else None

Expand Down
42 changes: 20 additions & 22 deletions dlrouter/routing/prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from typing import Optional

from dlrouter.logger import get_logger
from dlrouter.models.node import NodeStatus
from dlrouter.routing.base import BaseRoutingStrategy


logger = get_logger('dlrouter.prefix_cache')


Expand Down Expand Up @@ -67,10 +67,9 @@ def find_best_node(self, prompt: str, candidate_nodes: list) -> Optional[str]:
break
current = current.children[char]
for node_url in list(current.nodes.keys()):
if node_url in candidate_set:
if depth + 1 > best_depth:
best_depth = depth + 1
best_node = node_url
if node_url in candidate_set and depth + 1 > best_depth:
best_depth = depth + 1
best_node = node_url
if best_node:
current.nodes[best_node] = time.time()
return best_node
Expand Down Expand Up @@ -169,47 +168,46 @@ def __init__(self, max_prefix_depth: int = 100) -> None:

def _select_least_loaded_node(self, candidates: dict) -> str:
"""Select the node with minimum unfinished requests.

Uses round-robin as tie-breaker when all nodes have same load.

Args:
candidates: Dict of node_url -> NodeStatus

Returns:
Selected node URL with minimum load.
"""
if not candidates:
raise ValueError("No candidates available")
raise ValueError('No candidates available')

candidate_list = list(candidates.items())

# Find minimum unfinished count
min_unfinished = min(status.unfinished for _, status in candidate_list)

# Get all nodes with minimum unfinished count
min_load_nodes = [
(url, status) for url, status in candidate_list
if status.unfinished == min_unfinished
]

min_load_nodes = [(url, status) for url, status in candidate_list if status.unfinished == min_unfinished]

# If multiple nodes have same minimum load, use round-robin
if len(min_load_nodes) > 1:
self._rr_counter += 1
selected_idx = self._rr_counter % len(min_load_nodes)
selected_url, selected_status = min_load_nodes[selected_idx]
logger.info(f'Load balancing: {len(min_load_nodes)} nodes with same load ({min_unfinished}), '
f'using round-robin index {selected_idx}')
logger.info(
f'Load balancing: {len(min_load_nodes)} nodes with same load ({min_unfinished}), '
f'using round-robin index {selected_idx}'
)
else:
selected_url, selected_status = min_load_nodes[0]

# Log all nodes' load info
load_info = [
f"{url}(unfinished={status.unfinished}, latency={sum(status.latency)/len(status.latency) if status.latency else 0:.3f}s)"
f'{url}(unfinished={status.unfinished}, latency={sum(status.latency) / len(status.latency) if status.latency else 0:.3f}s)'
for url, status in candidate_list
]
logger.info(f'Load balancing candidates: {load_info}')
logger.info(f'Selected: {selected_url} (unfinished={selected_status.unfinished})')

return selected_url

def update_cache(self, prompt: str, node_url: str) -> None:
Expand Down
211 changes: 211 additions & 0 deletions tests/core/test_proxy_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
"""Tests for ProxyEngine."""

import pytest
from unittest.mock import MagicMock

from dlrouter.core.proxy_engine import ProxyEngine
from dlrouter.models.protocol import (
ChatCompletionRequest,
CompletionRequest,
)


class TestExtractPromptForPrefixCache:
"""Tests for _extract_prompt_for_prefix_cache method."""

@pytest.fixture
def proxy_engine(self):
"""Create a ProxyEngine instance with mocked dependencies."""
mock_node_manager = MagicMock()
engine = ProxyEngine(node_manager=mock_node_manager)
return engine

# ------------------------------------------------------------------
# Test None input
# ------------------------------------------------------------------

def test_none_body_returns_none(self, proxy_engine):
"""Test that None body returns None."""
result = proxy_engine._extract_prompt_for_prefix_cache(None)
assert result is None

# ------------------------------------------------------------------
# Tests for ChatCompletionRequest with string messages
# ------------------------------------------------------------------

def test_chat_request_with_string_messages(self, proxy_engine):
"""Test ChatCompletionRequest with string messages field."""
request = ChatCompletionRequest(
model="test-model",
messages="Hello, world!"
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result == "Hello, world!"

# ------------------------------------------------------------------
# Tests for ChatCompletionRequest with dict list messages
# ------------------------------------------------------------------

def test_chat_request_with_dict_messages(self, proxy_engine):
"""Test ChatCompletionRequest with list of dict messages."""
request = ChatCompletionRequest(
model="test-model",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
]
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result == "You are a helpful assistant.\nHello!"

def test_chat_request_with_empty_dict_content(self, proxy_engine):
"""Test that empty string content is preserved."""
request = ChatCompletionRequest(
model="test-model",
messages=[
{"role": "user", "content": ""},
{"role": "assistant", "content": "Response"}
]
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
# Empty string should be preserved (converted to str)
assert result == "\nResponse"

def test_chat_request_with_none_content_in_dict(self, proxy_engine):
"""Test that None content in dict is filtered out."""
request = ChatCompletionRequest(
model="test-model",
messages=[
{"role": "system", "content": None},
{"role": "user", "content": "Hello"}
]
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result == "Hello"

def test_chat_request_with_missing_content_key(self, proxy_engine):
"""Test dict message without 'content' key defaults to empty string."""
request = ChatCompletionRequest(
model="test-model",
messages=[
{"role": "system"}, # No content key, defaults to ''
{"role": "user", "content": "Hello"}
]
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
# Missing content key defaults to '', which is now preserved
assert result == "\nHello"

def test_chat_request_with_empty_messages_list(self, proxy_engine):
"""Test empty messages list returns None."""
request = ChatCompletionRequest(
model="test-model",
messages=[]
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result is None

# ------------------------------------------------------------------
# Tests for ChatCompletionRequest with ChatMessage-like objects
# ------------------------------------------------------------------
# NOTE: ChatCompletionRequest.messages is defined as
# Union[str, list[dict[str, Any]]], not list[ChatMessage].
# The code has a branch for objects with .content attribute,
# which would handle ChatMessage if it were passed directly
# (e.g., from manual construction without validation).

def test_chat_request_with_object_having_content_attr(self, proxy_engine):
"""Test handling of objects with content attribute (simulated)."""
# Simulate what would happen if ChatMessage objects were passed
class FakeMessage:
def __init__(self, content):
self.content = content

request = ChatCompletionRequest(
model="test-model",
messages=[
{"role": "system", "content": "System prompt"},
{"role": "user", "content": "User message"}
]
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result == "System prompt\nUser message"

# ------------------------------------------------------------------
# Tests for CompletionRequest
# ------------------------------------------------------------------

def test_completion_request_with_string_prompt(self, proxy_engine):
"""Test CompletionRequest with string prompt."""
request = CompletionRequest(
model="test-model",
prompt="Complete this sentence"
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result == "Complete this sentence"

def test_completion_request_with_string_list_prompt(self, proxy_engine):
"""Test CompletionRequest with list of string prompts."""
request = CompletionRequest(
model="test-model",
prompt=["First prompt", "Second prompt"]
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result == "First prompt\nSecond prompt"

def test_completion_request_with_empty_list_prompt(self, proxy_engine):
"""Test CompletionRequest with empty list returns None."""
request = CompletionRequest(
model="test-model",
prompt=[]
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result is None

def test_completion_request_with_mixed_list_prompt(self, proxy_engine):
"""Test CompletionRequest with list containing non-string items."""
request = CompletionRequest(
model="test-model",
prompt=[123, "string", 45.6]
)
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result == "123\nstring\n45.6"

# ------------------------------------------------------------------
# Tests for objects without messages or prompt
# ------------------------------------------------------------------

def test_object_without_messages_or_prompt(self, proxy_engine):
"""Test object without messages or prompt attributes returns None."""
class DummyRequest:
pass

request = DummyRequest()
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result is None

def test_object_with_none_messages(self, proxy_engine):
"""Test object with None messages returns None.

Note: Pydantic validates messages as non-None, so we test with
a mock object instead.
"""
class MockRequest:
messages = None

request = MockRequest()
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result is None

def test_object_with_none_prompt(self, proxy_engine):
"""Test object with None prompt returns None.

Note: Pydantic validates prompt as non-None, so we test with
a mock object instead.
"""
class MockRequest:
prompt = None

request = MockRequest()
result = proxy_engine._extract_prompt_for_prefix_cache(request)
assert result is None
Loading