diff --git a/README.md b/README.md index 9dcba85..dc45faf 100644 --- a/README.md +++ b/README.md @@ -288,7 +288,7 @@ Request ────► Prefill (max_tokens=1) ──► P Node This project draws inspiration from the following open-source projects: - **[LMDeploy](https://github.com/InternLM/lmdeploy)** — The proxy implementation in `lmdeploy/serve/proxy/proxy.py` provided valuable reference for the routing architecture and PD disaggregation support. -- **[vLLM](https://github.com/vllm-project/vllm)** — The implementation of load balancing policies such as cache_aware in VLLM routers provides us with many references. +- **[vLLM Router](https://github.com/vllm-project/router)** — The implementation of load balancing policies such as cache_aware in VLLM routers provides us with many references. We extend our sincere thanks to the developers and contributors of these projects for their excellent work in the LLM inference ecosystem. diff --git a/dlrouter/__main__.py b/dlrouter/__main__.py index 48ff9e1..5f65335 100644 --- a/dlrouter/__main__.py +++ b/dlrouter/__main__.py @@ -392,4 +392,4 @@ def main(): if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/dlrouter/core/proxy_engine.py b/dlrouter/core/proxy_engine.py index 15de28f..e184e41 100644 --- a/dlrouter/core/proxy_engine.py +++ b/dlrouter/core/proxy_engine.py @@ -326,10 +326,10 @@ def _extract_prompt_for_prefix_cache( for msg in messages: if isinstance(msg, dict): content = msg.get('content', '') - if content: + if content is not None: parts.append(str(content)) elif hasattr(msg, 'content'): - if msg.content: + if msg.content is not None: parts.append(str(msg.content)) return '\n'.join(parts) if parts else None diff --git a/dlrouter/routing/prefix_cache.py b/dlrouter/routing/prefix_cache.py index 030c4e4..a8a3267 100644 --- a/dlrouter/routing/prefix_cache.py +++ b/dlrouter/routing/prefix_cache.py @@ -10,9 +10,9 @@ from typing import Optional from dlrouter.logger import get_logger -from dlrouter.models.node import NodeStatus from dlrouter.routing.base import BaseRoutingStrategy + logger = get_logger('dlrouter.prefix_cache') @@ -67,10 +67,9 @@ def find_best_node(self, prompt: str, candidate_nodes: list) -> Optional[str]: break current = current.children[char] for node_url in list(current.nodes.keys()): - if node_url in candidate_set: - if depth + 1 > best_depth: - best_depth = depth + 1 - best_node = node_url + if node_url in candidate_set and depth + 1 > best_depth: + best_depth = depth + 1 + best_node = node_url if best_node: current.nodes[best_node] = time.time() return best_node @@ -169,47 +168,46 @@ def __init__(self, max_prefix_depth: int = 100) -> None: def _select_least_loaded_node(self, candidates: dict) -> str: """Select the node with minimum unfinished requests. - + Uses round-robin as tie-breaker when all nodes have same load. - + Args: candidates: Dict of node_url -> NodeStatus - + Returns: Selected node URL with minimum load. """ if not candidates: - raise ValueError("No candidates available") - + raise ValueError('No candidates available') + candidate_list = list(candidates.items()) - + # Find minimum unfinished count min_unfinished = min(status.unfinished for _, status in candidate_list) - + # Get all nodes with minimum unfinished count - min_load_nodes = [ - (url, status) for url, status in candidate_list - if status.unfinished == min_unfinished - ] - + min_load_nodes = [(url, status) for url, status in candidate_list if status.unfinished == min_unfinished] + # If multiple nodes have same minimum load, use round-robin if len(min_load_nodes) > 1: self._rr_counter += 1 selected_idx = self._rr_counter % len(min_load_nodes) selected_url, selected_status = min_load_nodes[selected_idx] - logger.info(f'Load balancing: {len(min_load_nodes)} nodes with same load ({min_unfinished}), ' - f'using round-robin index {selected_idx}') + logger.info( + f'Load balancing: {len(min_load_nodes)} nodes with same load ({min_unfinished}), ' + f'using round-robin index {selected_idx}' + ) else: selected_url, selected_status = min_load_nodes[0] - + # Log all nodes' load info load_info = [ - f"{url}(unfinished={status.unfinished}, latency={sum(status.latency)/len(status.latency) if status.latency else 0:.3f}s)" + f'{url}(unfinished={status.unfinished}, latency={sum(status.latency) / len(status.latency) if status.latency else 0:.3f}s)' for url, status in candidate_list ] logger.info(f'Load balancing candidates: {load_info}') logger.info(f'Selected: {selected_url} (unfinished={selected_status.unfinished})') - + return selected_url def update_cache(self, prompt: str, node_url: str) -> None: diff --git a/tests/core/test_proxy_engine.py b/tests/core/test_proxy_engine.py new file mode 100644 index 0000000..c57adde --- /dev/null +++ b/tests/core/test_proxy_engine.py @@ -0,0 +1,211 @@ +"""Tests for ProxyEngine.""" + +import pytest +from unittest.mock import MagicMock + +from dlrouter.core.proxy_engine import ProxyEngine +from dlrouter.models.protocol import ( + ChatCompletionRequest, + CompletionRequest, +) + + +class TestExtractPromptForPrefixCache: + """Tests for _extract_prompt_for_prefix_cache method.""" + + @pytest.fixture + def proxy_engine(self): + """Create a ProxyEngine instance with mocked dependencies.""" + mock_node_manager = MagicMock() + engine = ProxyEngine(node_manager=mock_node_manager) + return engine + + # ------------------------------------------------------------------ + # Test None input + # ------------------------------------------------------------------ + + def test_none_body_returns_none(self, proxy_engine): + """Test that None body returns None.""" + result = proxy_engine._extract_prompt_for_prefix_cache(None) + assert result is None + + # ------------------------------------------------------------------ + # Tests for ChatCompletionRequest with string messages + # ------------------------------------------------------------------ + + def test_chat_request_with_string_messages(self, proxy_engine): + """Test ChatCompletionRequest with string messages field.""" + request = ChatCompletionRequest( + model="test-model", + messages="Hello, world!" + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result == "Hello, world!" + + # ------------------------------------------------------------------ + # Tests for ChatCompletionRequest with dict list messages + # ------------------------------------------------------------------ + + def test_chat_request_with_dict_messages(self, proxy_engine): + """Test ChatCompletionRequest with list of dict messages.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ] + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result == "You are a helpful assistant.\nHello!" + + def test_chat_request_with_empty_dict_content(self, proxy_engine): + """Test that empty string content is preserved.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + {"role": "user", "content": ""}, + {"role": "assistant", "content": "Response"} + ] + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + # Empty string should be preserved (converted to str) + assert result == "\nResponse" + + def test_chat_request_with_none_content_in_dict(self, proxy_engine): + """Test that None content in dict is filtered out.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + {"role": "system", "content": None}, + {"role": "user", "content": "Hello"} + ] + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result == "Hello" + + def test_chat_request_with_missing_content_key(self, proxy_engine): + """Test dict message without 'content' key defaults to empty string.""" + request = ChatCompletionRequest( + model="test-model", + messages=[ + {"role": "system"}, # No content key, defaults to '' + {"role": "user", "content": "Hello"} + ] + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + # Missing content key defaults to '', which is now preserved + assert result == "\nHello" + + def test_chat_request_with_empty_messages_list(self, proxy_engine): + """Test empty messages list returns None.""" + request = ChatCompletionRequest( + model="test-model", + messages=[] + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result is None + + # ------------------------------------------------------------------ + # Tests for ChatCompletionRequest with ChatMessage-like objects + # ------------------------------------------------------------------ + # NOTE: ChatCompletionRequest.messages is defined as + # Union[str, list[dict[str, Any]]], not list[ChatMessage]. + # The code has a branch for objects with .content attribute, + # which would handle ChatMessage if it were passed directly + # (e.g., from manual construction without validation). + + def test_chat_request_with_object_having_content_attr(self, proxy_engine): + """Test handling of objects with content attribute (simulated).""" + # Simulate what would happen if ChatMessage objects were passed + class FakeMessage: + def __init__(self, content): + self.content = content + + request = ChatCompletionRequest( + model="test-model", + messages=[ + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "User message"} + ] + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result == "System prompt\nUser message" + + # ------------------------------------------------------------------ + # Tests for CompletionRequest + # ------------------------------------------------------------------ + + def test_completion_request_with_string_prompt(self, proxy_engine): + """Test CompletionRequest with string prompt.""" + request = CompletionRequest( + model="test-model", + prompt="Complete this sentence" + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result == "Complete this sentence" + + def test_completion_request_with_string_list_prompt(self, proxy_engine): + """Test CompletionRequest with list of string prompts.""" + request = CompletionRequest( + model="test-model", + prompt=["First prompt", "Second prompt"] + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result == "First prompt\nSecond prompt" + + def test_completion_request_with_empty_list_prompt(self, proxy_engine): + """Test CompletionRequest with empty list returns None.""" + request = CompletionRequest( + model="test-model", + prompt=[] + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result is None + + def test_completion_request_with_mixed_list_prompt(self, proxy_engine): + """Test CompletionRequest with list containing non-string items.""" + request = CompletionRequest( + model="test-model", + prompt=[123, "string", 45.6] + ) + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result == "123\nstring\n45.6" + + # ------------------------------------------------------------------ + # Tests for objects without messages or prompt + # ------------------------------------------------------------------ + + def test_object_without_messages_or_prompt(self, proxy_engine): + """Test object without messages or prompt attributes returns None.""" + class DummyRequest: + pass + + request = DummyRequest() + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result is None + + def test_object_with_none_messages(self, proxy_engine): + """Test object with None messages returns None. + + Note: Pydantic validates messages as non-None, so we test with + a mock object instead. + """ + class MockRequest: + messages = None + + request = MockRequest() + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result is None + + def test_object_with_none_prompt(self, proxy_engine): + """Test object with None prompt returns None. + + Note: Pydantic validates prompt as non-None, so we test with + a mock object instead. + """ + class MockRequest: + prompt = None + + request = MockRequest() + result = proxy_engine._extract_prompt_for_prefix_cache(request) + assert result is None diff --git a/tests/routing/test_prefix_cache.py b/tests/routing/test_prefix_cache.py new file mode 100644 index 0000000..c28d9fd --- /dev/null +++ b/tests/routing/test_prefix_cache.py @@ -0,0 +1,407 @@ +"""Tests for prefix cache routing strategy.""" + +from dlrouter.constants import RoutingStrategy +from dlrouter.models.node import NodeStatus +from dlrouter.routing.factory import create_routing_strategy +from dlrouter.routing.prefix_cache import PrefixCacheStrategy, PrefixCacheTrie + + +def _make_candidates(): + """Create test candidate nodes.""" + return { + 'http://node1:8000': NodeStatus( + models=['model-a', 'model-b'], + speed=10.0, + unfinished=2, + ), + 'http://node2:8000': NodeStatus( + models=['model-a'], + speed=20.0, + unfinished=1, + ), + 'http://node3:8000': NodeStatus( + models=['model-b'], + speed=5.0, + unfinished=0, + ), + } + + +class TestPrefixCacheTrie: + """Tests for PrefixCacheTrie data structure.""" + + def test_add_and_find_prefix(self): + """Test adding prefix and finding best node.""" + trie = PrefixCacheTrie() + prompt = 'Hello world this is a test' + node_url = 'http://node1:8000' + + trie.add_prefix(prompt, node_url) + + # Should find the node for the same prompt + found = trie.find_best_node(prompt, [node_url]) + assert found == node_url + + def test_find_best_node_with_shared_prefix(self): + """Test finding node with shared prefix.""" + trie = PrefixCacheTrie() + node1 = 'http://node1:8000' + node2 = 'http://node2:8000' + + # Add two prompts with shared prefix + trie.add_prefix('Hello world from node1', node1) + trie.add_prefix('Hello world from node2', node2) + + # Query with shared prefix should find the one with longest match + found = trie.find_best_node('Hello world new request', [node1, node2]) + # Both have same prefix length, either could be returned + assert found in [node1, node2] + + def test_find_best_node_longest_match(self): + """Test that longest prefix match is selected.""" + trie = PrefixCacheTrie() + node1 = 'http://node1:8000' + node2 = 'http://node2:8000' + + # node1 has longer shared prefix + trie.add_prefix('Hello world from node1', node1) + trie.add_prefix('Hello', node2) + + # Should prefer node1 for this prompt + found = trie.find_best_node('Hello world', [node1, node2]) + assert found == node1 + + def test_find_best_node_not_in_candidates(self): + """Test that nodes not in candidates are ignored.""" + trie = PrefixCacheTrie() + node1 = 'http://node1:8000' + node2 = 'http://node2:8000' + + trie.add_prefix('Hello world', node1) + + # node1 not in candidates, should return None + found = trie.find_best_node('Hello world', [node2]) + assert found is None + + def test_remove_node(self): + """Test removing a node from trie.""" + trie = PrefixCacheTrie() + node1 = 'http://node1:8000' + node2 = 'http://node2:8000' + + trie.add_prefix('Hello world', node1) + trie.add_prefix('Hello world', node2) + + # Remove node1 + trie.remove_node(node1) + + # Should only find node2 + found = trie.find_best_node('Hello world', [node1, node2]) + assert found == node2 + + def test_cleanup_expired(self): + """Test cleaning up expired entries.""" + trie = PrefixCacheTrie() + node1 = 'http://node1:8000' + + trie.add_prefix('Hello world', node1) + + # Immediately cleanup with 0 max age should remove everything + removed = trie.cleanup_expired(max_age_seconds=0) + assert removed >= 1 + + # Should not find the node anymore + found = trie.find_best_node('Hello world', [node1]) + assert found is None + + def test_max_depth_limit(self): + """Test that max_depth limits prefix storage.""" + trie = PrefixCacheTrie(max_depth=5) + node1 = 'http://node1:8000' + + long_prompt = 'a' * 1000 + trie.add_prefix(long_prompt, node1) + + # Should still find the node + found = trie.find_best_node(long_prompt, [node1]) + assert found == node1 + + def test_normalize_prompt(self): + """Test prompt normalization (whitespace handling).""" + trie = PrefixCacheTrie() + node1 = 'http://node1:8000' + + # These should be normalized to the same prompt + trie.add_prefix('Hello world', node1) + + found = trie.find_best_node('Hello world', [node1]) + assert found == node1 + + def test_get_stats(self): + """Test getting trie statistics.""" + trie = PrefixCacheTrie(max_depth=50) + node1 = 'http://node1:8000' + + trie.add_prefix('Hello', node1) + trie.add_prefix('World', node1) + + stats = trie.get_stats() + assert stats['max_depth'] == 50 + assert stats['total_nodes'] > 0 + + +class TestPrefixCacheStrategy: + """Tests for PrefixCacheStrategy routing.""" + + def test_cache_hit_same_prompt(self): + """Test that same prompt hits cache and routes to same node.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = _make_candidates() + prompt = 'This is a test prompt for caching' + + # First request - cache miss + url1 = strategy.select_node('model-a', cands, prompt) + assert url1 is not None + + # Second request with same prompt - cache hit + url2 = strategy.select_node('model-a', cands, prompt) + assert url2 == url1 + + def test_cache_hit_shared_prefix(self): + """Test that prompts with shared prefix route to same node.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = _make_candidates() + + # First request + prompt1 = 'Hello world this is a long prompt' + url1 = strategy.select_node('model-a', cands, prompt1) + + # Second request with shared prefix + prompt2 = 'Hello world this is another prompt' + url2 = strategy.select_node('model-a', cands, prompt2) + + # Should route to same node due to shared prefix + assert url2 == url1 + + def test_no_request_key_fallback(self): + """Test fallback when no request_key provided.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = _make_candidates() + + # No request_key - should return first candidate + url = strategy.select_node('model-a', cands, None) + assert url in ['http://node1:8000', 'http://node2:8000'] + + def test_cache_miss_load_balancing(self): + """Test load balancing on cache miss.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = { + 'http://node1:8000': NodeStatus( + models=['model-a'], + speed=10.0, + unfinished=5, # High load + ), + 'http://node2:8000': NodeStatus( + models=['model-a'], + speed=20.0, + unfinished=1, # Low load - should be selected + ), + } + + prompt = "Unique prompt that won't be cached" + url = strategy.select_node('model-a', cands, prompt) + + # Should select node with minimum unfinished + assert url == 'http://node2:8000' + + def test_no_matching_model(self): + """Test when no nodes serve the requested model.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = _make_candidates() + + url = strategy.select_node('nonexistent-model', cands, 'test prompt') + assert url is None + + def test_remove_node_from_strategy(self): + """Test removing a node from strategy.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = _make_candidates() + prompt = 'Test prompt for node removal' + + # First request + url1 = strategy.select_node('model-a', cands, prompt) + + # Remove that node + strategy.remove_node(url1) + + # Next request should not route to removed node + cands2 = {k: v for k, v in cands.items() if k != url1} + if cands2: + url2 = strategy.select_node('model-a', cands2, prompt) + assert url2 != url1 + + def test_update_cache_manually(self): + """Test manually updating cache.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + + prompt = 'Manual cache update test' + node_url = 'http://node1:8000' + + strategy.update_cache(prompt, node_url) + + # Should find the manually cached node + cands = {'http://node1:8000': NodeStatus(models=['model-a'])} + found = strategy.select_node('model-a', cands, prompt) + assert found == node_url + + def test_cleanup_expired_entries(self): + """Test cleanup of expired cache entries.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = _make_candidates() + prompt = 'Test prompt for cleanup' + + # Add to cache + strategy.select_node('model-a', cands, prompt) + + # Cleanup with 0 age should remove entries + removed = strategy.cleanup(max_age_seconds=0) + assert removed >= 0 + + def test_get_stats(self): + """Test getting strategy statistics.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = _make_candidates() + + # Add some entries + strategy.select_node('model-a', cands, 'Test prompt 1') + strategy.select_node('model-a', cands, 'Test prompt 2') + + stats = strategy.get_stats() + assert 'total_nodes' in stats + assert 'max_depth' in stats + + def test_load_balancing_tie_breaker(self): + """Test round-robin tie breaker when multiple nodes have same load on cache miss.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + + # First request - cache miss, should select one node (round-robin tie breaker) + cands1 = { + 'http://node1:8000': NodeStatus( + models=['model-a'], + speed=10.0, + unfinished=1, # Same load + ), + 'http://node2:8000': NodeStatus( + models=['model-a'], + speed=20.0, + unfinished=1, # Same load + ), + } + prompt1 = 'First unique prompt for load balancing' + url1 = strategy.select_node('model-a', cands1, prompt1) + + # Second request with different prompt - also cache miss + # Create new strategy to ensure clean cache state + strategy2 = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands2 = { + 'http://node1:8000': NodeStatus( + models=['model-a'], + speed=10.0, + unfinished=1, # Same load + ), + 'http://node2:8000': NodeStatus( + models=['model-a'], + speed=20.0, + unfinished=1, # Same load + ), + } + prompt2 = 'Second unique prompt for load balancing' + url2 = strategy2.select_node('model-a', cands2, prompt2) + + # Both nodes should be selectable on cache miss with round-robin + # Note: Due to round-robin counter, consecutive requests may hit different nodes + + # At least one of the two requests should demonstrate load balancing + # (both could be same due to counter state, but typically they differ) + assert url1 is not None + assert url2 is not None + assert url1 in ['http://node1:8000', 'http://node2:8000'] + assert url2 in ['http://node1:8000', 'http://node2:8000'] + + def test_zero_unfinished_priority(self): + """Test that nodes with zero unfinished are prioritized on cache miss.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = { + 'http://node1:8000': NodeStatus( + models=['model-a'], + speed=10.0, + unfinished=5, + ), + 'http://node2:8000': NodeStatus( + models=['model-a'], + speed=20.0, + unfinished=0, # Should be selected + ), + } + + prompt = 'Test prompt for zero unfinished priority' + url = strategy.select_node('model-a', cands, prompt) + + assert url == 'http://node2:8000' + + +class TestPrefixCacheIntegration: + """Integration tests for prefix cache routing.""" + + def test_multiple_models_isolation(self): + """Test that different models don't interfere with each other.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + + cands_model_a = { + 'http://node1:8000': NodeStatus(models=['model-a']), + } + cands_model_b = { + 'http://node2:8000': NodeStatus(models=['model-b']), + } + + prompt = 'Shared prompt text' + + url_a = strategy.select_node('model-a', cands_model_a, prompt) + url_b = strategy.select_node('model-b', cands_model_b, prompt) + + # Should route to different nodes for different models + assert url_a == 'http://node1:8000' + assert url_b == 'http://node2:8000' + + def test_concurrent_prompts_cache_independence(self): + """Test that different prompts are cached independently.""" + strategy = create_routing_strategy(RoutingStrategy.PREFIX_CACHE) + cands = _make_candidates() + + prompt1 = 'First unique prompt for caching' + prompt2 = 'Second unique prompt for caching' + + url1 = strategy.select_node('model-a', cands, prompt1) + url2 = strategy.select_node('model-a', cands, prompt2) + + # Both should be cached and return their respective nodes + url1_again = strategy.select_node('model-a', cands, prompt1) + url2_again = strategy.select_node('model-a', cands, prompt2) + + assert url1_again == url1 + assert url2_again == url2 + + def test_prefix_depth_limit(self): + """Test that very long prompts respect max_depth.""" + strategy = PrefixCacheStrategy(max_prefix_depth=10) + cands = _make_candidates() + + # Very long prompt + long_prompt = 'a' * 1000 + + url = strategy.select_node('model-a', cands, long_prompt) + assert url is not None + + # Stats should show limited depth + stats = strategy.get_stats() + assert stats['max_depth'] == 10