diff --git a/5-large-language-models/1-openai-compatible-endpoint/cerebrium.toml b/5-large-language-models/1-openai-compatible-endpoint/cerebrium.toml index 674dc42f..792bd28b 100644 --- a/5-large-language-models/1-openai-compatible-endpoint/cerebrium.toml +++ b/5-large-language-models/1-openai-compatible-endpoint/cerebrium.toml @@ -1,14 +1,18 @@ [cerebrium.deployment] name = "1-openai-compatible-endpoint" python_version = "3.11" -docker_base_image_url = "debian:bookworm-slim" +docker_base_image_url = "nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04" +disable_auth = true +deployment_initialization_timeout = 830 include = ["./*", "main.py", "cerebrium.toml"] exclude = [".*"] [cerebrium.hardware] -cpu = 2 -memory = 12.0 -compute = "AMPERE_A10" +cpu = 8 +memory = 60.0 +compute = "ADA_L40" +provider = "aws" +region = "us-east-1" [cerebrium.scaling] min_replicas = 0 @@ -18,3 +22,7 @@ cooldown = 30 [cerebrium.dependencies.pip] vllm = "latest" pydantic = "latest" +huggingface_hub = "latest" + +[cerebrium.experimental] +checkpointing = true diff --git a/5-large-language-models/1-openai-compatible-endpoint/main.py b/5-large-language-models/1-openai-compatible-endpoint/main.py index a199d2aa..f0fc7974 100644 --- a/5-large-language-models/1-openai-compatible-endpoint/main.py +++ b/5-large-language-models/1-openai-compatible-endpoint/main.py @@ -1,21 +1,46 @@ +import http.client import json import os import time - +import urllib.request from huggingface_hub import login from pydantic import BaseModel from vllm import SamplingParams, AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs +CHECKPOINT_URL = "http://169.254.169.253:8234/checkpoint" + + +def _trigger_snapshot() -> None: + print("[init] requesting GPU snapshot", flush=True) + try: + req = urllib.request.Request(CHECKPOINT_URL, method="POST") + urllib.request.urlopen(req, timeout=300) + print("[init] snapshot complete", flush=True) + except http.client.RemoteDisconnected: + # TCP connections disconnect on restore and throw remote + print("[init] snapshot complete (RemoteDisconnected)", flush=True) + except Exception as exc: + print(f"[init] snapshot failed: {type(exc).__name__}: {exc}", flush=True) + + +print("[init] starting", flush=True) login(token=os.environ.get("HF_TOKEN")) engine_args = AsyncEngineArgs( model="meta-llama/Meta-Llama-3.1-8B-Instruct", gpu_memory_utilization=0.9, max_model_len=8192, + async_scheduling=False, ) + +print("[init] building vLLM engine", flush=True) engine = AsyncLLMEngine.from_engine_args(engine_args) +print("[init] vLLM engine ready", flush=True) + +_trigger_snapshot() +print("[init] handler ready", flush=True) class Message(BaseModel): @@ -35,7 +60,9 @@ async def run( prompt = " ".join( [f"{Message(**msg).role}: {Message(**msg).content}" for msg in messages] ) - sampling_params = SamplingParams(temperature=temperature, top_p=top_p) + sampling_params = SamplingParams( + temperature=temperature, top_p=top_p, max_tokens=max_tokens + ) results_generator = engine.generate(prompt, sampling_params, run_id) previous_text = "" @@ -54,21 +81,17 @@ async def run( "choices": [{"index": 0, "delta": {}, "finish_reason": None}], } - # Include role in the first chunk if first_chunk: chunk["choices"][0]["delta"]["role"] = "assistant" first_chunk = False - # Add new text to delta if there is any if new_text: chunk["choices"][0]["delta"]["content"] = new_text - # Check for a finish_reason finish_reason = prompt_output[0].finish_reason if finish_reason and finish_reason != "none": chunk["choices"][0]["finish_reason"] = finish_reason yield f"data: {json.dumps(chunk)}\n\n" - # After all chunks, send [DONE] yield "data: [DONE]\n\n" diff --git a/5-large-language-models/4-llama-openai-chat-compatible-endpoint/cerebrium.toml b/5-large-language-models/4-llama-openai-chat-compatible-endpoint/cerebrium.toml index c12b9953..a2bd337a 100644 --- a/5-large-language-models/4-llama-openai-chat-compatible-endpoint/cerebrium.toml +++ b/5-large-language-models/4-llama-openai-chat-compatible-endpoint/cerebrium.toml @@ -1,20 +1,28 @@ [cerebrium.deployment] -name = "llm-run-3u" +name = "4-llama-openai-compatible" python_version = "3.11" -docker_base_image_url = "debian:bookworm-slim" +docker_base_image_url = "nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04" +disable_auth = true +deployment_initialization_timeout = 830 include = ["./*", "main.py", "cerebrium.toml"] exclude = [".*"] [cerebrium.hardware] -cpu = 2 -memory = 12.0 -compute = "AMPERE_A10" +cpu = 8 +memory = 60.0 +compute = "ADA_L40" +provider = "aws" +region = "us-east-1" [cerebrium.scaling] min_replicas = 0 max_replicas = 5 -cooldown = 60 +cooldown = 30 [cerebrium.dependencies.pip] vllm = "latest" pydantic = "latest" +huggingface_hub = "latest" + +[cerebrium.experimental] +checkpointing = true diff --git a/5-large-language-models/4-llama-openai-chat-compatible-endpoint/main.py b/5-large-language-models/4-llama-openai-chat-compatible-endpoint/main.py index a2fdae5b..55c7d494 100644 --- a/5-large-language-models/4-llama-openai-chat-compatible-endpoint/main.py +++ b/5-large-language-models/4-llama-openai-chat-compatible-endpoint/main.py @@ -1,20 +1,46 @@ +import http.client +import json import os import time -import json +import urllib.request from huggingface_hub import login from pydantic import BaseModel from vllm import SamplingParams, AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs +CHECKPOINT_URL = "http://169.254.169.253:8234/checkpoint" + + +def _trigger_snapshot() -> None: + print("[init] requesting GPU snapshot", flush=True) + try: + req = urllib.request.Request(CHECKPOINT_URL, method="POST") + urllib.request.urlopen(req, timeout=300) + print("[init] snapshot complete", flush=True) + except http.client.RemoteDisconnected: + # TCP connections disconnect on restore and throw remote + print("[init] snapshot complete (RemoteDisconnected)", flush=True) + except Exception as exc: + print(f"[init] snapshot failed: {type(exc).__name__}: {exc}", flush=True) + + +print("[init] starting", flush=True) login(token=os.environ.get("HF_TOKEN")) engine_args = AsyncEngineArgs( model="meta-llama/Meta-Llama-3.1-8B-Instruct", gpu_memory_utilization=0.9, max_model_len=8192, + async_scheduling=False, ) + +print("[init] building vLLM engine", flush=True) engine = AsyncLLMEngine.from_engine_args(engine_args) +print("[init] vLLM engine ready", flush=True) + +_trigger_snapshot() +print("[init] handler ready", flush=True) class Message(BaseModel): @@ -45,7 +71,6 @@ async def run( top_p: float = 0.95, max_tokens: int = 4096, ): - # Format your prompt for llama-friendly usage: prompt = format_chat_prompt(messages) sampling_params = SamplingParams( @@ -61,36 +86,25 @@ async def run( new_text = prompt_output[0].text[len(previous_text) :] previous_text = prompt_output[0].text - # Construct OpenAI-compatible chunk chunk = { "id": run_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model, - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": None, - } - ], + "choices": [{"index": 0, "delta": {}, "finish_reason": None}], } - # Include the role in the first chunk if first_chunk: chunk["choices"][0]["delta"]["role"] = "assistant" first_chunk = False - # Add new text to the delta if any if new_text: chunk["choices"][0]["delta"]["content"] = new_text - # Capture a finish reason if it's provided - finish_reason = prompt_output[0].finish_reason or None + finish_reason = prompt_output[0].finish_reason if finish_reason and finish_reason != "none": chunk["choices"][0]["finish_reason"] = finish_reason yield f"data: {json.dumps(chunk)}\n\n" - # Send the final [DONE] message yield "data: [DONE]\n\n" diff --git a/5-large-language-models/6-deepseek-sglang/cerebrium.toml b/5-large-language-models/6-deepseek-sglang/cerebrium.toml index fdb9b470..84710322 100644 --- a/5-large-language-models/6-deepseek-sglang/cerebrium.toml +++ b/5-large-language-models/6-deepseek-sglang/cerebrium.toml @@ -1,18 +1,21 @@ [cerebrium.deployment] name = "6-deepseek-sglang" python_version = "3.11" -docker_base_image_url = "debian:bookworm-slim" -disable_auth = false -include = ['./*', 'main.py', 'cerebrium.toml'] -exclude = ['.*'] +docker_base_image_url = "nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04" +disable_auth = true +deployment_initialization_timeout = 830 +include = ["./*", "main.py", "cerebrium.toml"] +exclude = [".*"] pre_build_commands = [ "pip install sglang[all]>=0.4.2.post2 --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer" ] [cerebrium.hardware] -cpu = 4 -memory = 12.0 -compute = "AMPERE_A10" +cpu = 8 +memory = 60.0 +compute = "ADA_L40" +provider = "aws" +region = "us-east-1" [cerebrium.scaling] min_replicas = 0 @@ -25,4 +28,10 @@ scaling_target = 100 [cerebrium.dependencies.pip] huggingface_hub = "latest" -pydantic = "latest" \ No newline at end of file +pydantic = "latest" + +[cerebrium.dependencies.apt] +libnuma-dev = "latest" + +[cerebrium.experimental] +checkpointing = true diff --git a/5-large-language-models/6-deepseek-sglang/main.py b/5-large-language-models/6-deepseek-sglang/main.py index 43e50f34..9a466f84 100644 --- a/5-large-language-models/6-deepseek-sglang/main.py +++ b/5-large-language-models/6-deepseek-sglang/main.py @@ -1,23 +1,47 @@ -# launch the offline engine -from huggingface_hub import login -from sglang import Runtime -from pydantic import BaseModel +import http.client import json -from typing import List, Dict, Any -import time import os +import time +import urllib.request + +from huggingface_hub import login +from sglang import Runtime + +CHECKPOINT_URL = "http://169.254.169.253:8234/checkpoint" os.environ["HF_TRANSFER"] = "1" os.environ["HF_HUB_VERBOSE"] = "1" os.environ["HF_HUB_ENABLE_PROGRESS_BARS"] = "1" + +def _trigger_snapshot() -> None: + print("[init] requesting GPU snapshot", flush=True) + try: + req = urllib.request.Request(CHECKPOINT_URL, method="POST") + urllib.request.urlopen(req, timeout=300) + print("[init] snapshot complete", flush=True) + except http.client.RemoteDisconnected: + # TCP connections disconnect on restore and throw remote + print("[init] snapshot complete (RemoteDisconnected)", flush=True) + except Exception as exc: + print(f"[init] snapshot failed: {type(exc).__name__}: {exc}", flush=True) + + +print("[init] starting", flush=True) login(token=os.environ.get("HF_TOKEN")) -# model_id = "deepseek-ai/DeepSeek-R1" ##uncomment for R1 +# model_id = "deepseek-ai/DeepSeek-R1" # uncomment for R1 model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" + +print("[init] building SGLang runtime", flush=True) runtime = Runtime( - model_path=model_id, tp_size=1 -) # change tp_size=8 if serving R1 on H200 + model_path=model_id, + tp_size=1, # change tp_size=8 if serving R1 on H200 +) +print("[init] SGLang runtime ready", flush=True) + +_trigger_snapshot() +print("[init] handler ready", flush=True) async def run( @@ -29,7 +53,6 @@ async def run( top_p: float = 0.95, max_tokens: int = 4096, ): - sampling_params = {"temperature": temperature, "top_p": top_p} tokenizer = runtime.get_tokenizer() @@ -38,24 +61,15 @@ async def run( ) stream = runtime.add_request(prompt, sampling_params) - full_text = "" first_chunk = True async for output in stream: - full_text += output - chunk = { "id": run_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model, - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": None, - } - ], + "choices": [{"index": 0, "delta": {}, "finish_reason": None}], } if first_chunk: diff --git a/5-large-language-models/7-vision-language-sglang/cerebrium.toml b/5-large-language-models/7-vision-language-sglang/cerebrium.toml index 9ed2903b..5ca4f60b 100644 --- a/5-large-language-models/7-vision-language-sglang/cerebrium.toml +++ b/5-large-language-models/7-vision-language-sglang/cerebrium.toml @@ -1,20 +1,26 @@ [cerebrium.deployment] name = "7-vision-language-sglang" python_version = "3.11" -docker_base_image_url = "nvidia/cuda:12.8.0-devel-ubuntu22.04" +docker_base_image_url = "nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04" +disable_auth = true deployment_initialization_timeout = 830 +include = ["./*", "main.py", "cerebrium.toml"] +exclude = [".*"] +pre_build_commands = [ + "pip install sglang[all]>=0.4.2.post2 --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer" +] [cerebrium.hardware] cpu = 10.0 memory = 60.0 compute = "ADA_L40" +provider = "aws" +region = "us-east-1" [cerebrium.scaling] min_replicas = 0 max_replicas = 2 - -[cerebrium.build] -use_uv = true +cooldown = 30 [cerebrium.dependencies.pip] transformers = "latest" @@ -22,14 +28,14 @@ huggingface_hub = "latest" pydantic = "latest" pillow = "latest" requests = "latest" -torch = "latest" -"sglang[all]" = "latest" -"sgl-kernel" = "latest" -"flashinfer-python" = "latest" [cerebrium.dependencies.apt] libnuma-dev = "latest" +[cerebrium.experimental] +checkpointing = true + [cerebrium.runtime.custom] port = 8000 -entrypoint = ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file +entrypoint = ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] +healthcheck_endpoint = "/health" diff --git a/5-large-language-models/7-vision-language-sglang/main.py b/5-large-language-models/7-vision-language-sglang/main.py index 035fec94..5d23d965 100644 --- a/5-large-language-models/7-vision-language-sglang/main.py +++ b/5-large-language-models/7-vision-language-sglang/main.py @@ -1,27 +1,60 @@ -import sglang as sgl -from sglang import function -from PIL import Image -from fastapi import FastAPI, HTTPException -from transformers import AutoProcessor -from pydantic import BaseModel +import asyncio import base64 +import http.client import io import json +import os +import urllib.request +from contextlib import asynccontextmanager + +import sglang as sgl +from fastapi import FastAPI, HTTPException +from PIL import Image +from pydantic import BaseModel +from sglang import function + +os.environ.setdefault("HF_HOME", "/persistent-storage/hf") +os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1") -app = FastAPI(title="Vision Language SGLang API") model_path = "Qwen/Qwen3-VL-30B-A3B-Instruct-FP8" -processor = AutoProcessor.from_pretrained(model_path) +MODEL_DIR = f"/persistent-storage/models/{model_path}" +CHECKPOINT_URL = "http://169.254.169.253:8234/checkpoint" -class AnalyzeRequest(BaseModel): - image_base64: str - ad_description: str - dimensions: list -@app.on_event("startup") -def _startup_warmup(): - # Initialize engine on main thread during app startup +def _ensure_model_downloaded() -> str: + from pathlib import Path + from huggingface_hub import login, snapshot_download + + model_dir = Path(MODEL_DIR) + if model_dir.exists() and any(model_dir.iterdir()): + print(f"[init] using cached weights at {model_dir}", flush=True) + return str(model_dir) + + print(f"[init] downloading {model_path} to {model_dir}", flush=True) + login(token=os.environ.get("HF_TOKEN")) + snapshot_download(model_path, local_dir=str(model_dir)) + return str(model_dir) + + +def _trigger_snapshot() -> None: + print("[init] requesting GPU snapshot", flush=True) + try: + req = urllib.request.Request(CHECKPOINT_URL, method="POST") + urllib.request.urlopen(req, timeout=300) + print("[init] snapshot complete", flush=True) + except http.client.RemoteDisconnected: + # TCP connections disconnect on restore and throw remote + print("[init] snapshot complete (RemoteDisconnected)", flush=True) + except Exception as exc: + print(f"[init] snapshot failed: {type(exc).__name__}: {exc}", flush=True) + + +def _initialize_runtime() -> None: + print("[init] starting", flush=True) + print("[init] building SGLang runtime", flush=True) + local_model_path = _ensure_model_downloaded() runtime = sgl.Runtime( - model_path=model_path, + model_path=local_model_path, enable_multimodal=True, mem_fraction_static=0.8, tp_size=1, @@ -31,6 +64,24 @@ def _startup_warmup(): "qwen2-vl" ) sgl.set_default_backend(runtime) + print("[init] SGLang runtime ready", flush=True) + _trigger_snapshot() + print("[init] handler ready", flush=True) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await asyncio.to_thread(_initialize_runtime) + yield + + +app = FastAPI(title="Vision Language SGLang API", lifespan=lifespan) + + +class AnalyzeRequest(BaseModel): + image_base64: str + ad_description: str + dimensions: list @app.get("/health") @@ -39,6 +90,7 @@ def health(): "status": "healthy", } + def process_image(image_base64: str) -> Image.Image: image_data = base64.b64decode(image_base64) return Image.open(io.BytesIO(image_data)) @@ -96,4 +148,3 @@ def analyze_advertisement(req: AnalyzeRequest): } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) - diff --git a/5-large-language-models/8-faster-inference-with-triton-tensorrt/Dockerfile b/5-large-language-models/8-faster-inference-with-triton-tensorrt/Dockerfile index 9365a425..552ade03 100644 --- a/5-large-language-models/8-faster-inference-with-triton-tensorrt/Dockerfile +++ b/5-large-language-models/8-faster-inference-with-triton-tensorrt/Dockerfile @@ -1,6 +1,7 @@ FROM nvcr.io/nvidia/tritonserver:25.10-trtllm-python-py3 # Environment variables +ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility ENV PYTHONPATH=/usr/local/lib/python3.12/dist-packages:$PYTHONPATH ENV PYTHONDONTWRITEBYTECODE=1 ENV DEBIAN_FRONTEND=noninteractive diff --git a/5-large-language-models/8-faster-inference-with-triton-tensorrt/cerebrium.toml b/5-large-language-models/8-faster-inference-with-triton-tensorrt/cerebrium.toml index 58155a8e..9da9e880 100644 --- a/5-large-language-models/8-faster-inference-with-triton-tensorrt/cerebrium.toml +++ b/5-large-language-models/8-faster-inference-with-triton-tensorrt/cerebrium.toml @@ -2,23 +2,23 @@ name = "tensorrt-triton-demo" python_version = "3.12" disable_auth = true -include = ['./*', 'cerebrium.toml'] -exclude = ['.*'] -deployment_initialization_timeout = 830 +deployment_initialization_timeout = 830 +include = ["./*", "cerebrium.toml"] +exclude = [".*"] [cerebrium.hardware] -cpu = 4.0 -memory = 40.0 -compute = "AMPERE_A10" +cpu = 8.0 +memory = 60.0 +compute = "ADA_L40" gpu_count = 1 provider = "aws" region = "us-east-1" [cerebrium.scaling] -min_replicas = 1 -max_replicas = 5 -cooldown = 300 -replica_concurrency = 128 +min_replicas = 0 +max_replicas = 5 +cooldown = 30 +replica_concurrency = 128 scaling_metric = "concurrency_utilization" [cerebrium.dependencies.pip] @@ -29,4 +29,7 @@ transformers = "latest" port = 8000 healthcheck_endpoint = "/v2/health/live" readycheck_endpoint = "/v2/health/ready" -dockerfile_path = "./Dockerfile" \ No newline at end of file +dockerfile_path = "./Dockerfile" + +[cerebrium.experimental] +checkpointing = true diff --git a/5-large-language-models/8-faster-inference-with-triton-tensorrt/download_model.py b/5-large-language-models/8-faster-inference-with-triton-tensorrt/download_model.py index 2aded4a4..0f70c1d3 100644 --- a/5-large-language-models/8-faster-inference-with-triton-tensorrt/download_model.py +++ b/5-large-language-models/8-faster-inference-with-triton-tensorrt/download_model.py @@ -14,7 +14,7 @@ def download_model(): """Download model from HuggingFace if not already present.""" - hf_token = os.environ.get("HF_AUTH_TOKEN") + hf_token = os.environ.get("HF_AUTH_TOKEN") or os.environ.get("HF_TOKEN") if not hf_token: print("WARNING: HF_AUTH_TOKEN not set, model download may fail") diff --git a/5-large-language-models/8-faster-inference-with-triton-tensorrt/model.py b/5-large-language-models/8-faster-inference-with-triton-tensorrt/model.py index baba21c9..1fd56017 100644 --- a/5-large-language-models/8-faster-inference-with-triton-tensorrt/model.py +++ b/5-large-language-models/8-faster-inference-with-triton-tensorrt/model.py @@ -5,6 +5,10 @@ TensorRT-LLM's PyTorch backend for optimized LLM inference. """ +import http.client +import os +import urllib.request + import numpy as np import triton_python_backend_utils as pb_utils import torch @@ -16,24 +20,49 @@ # Model configuration MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct" MODEL_DIR = f"/persistent-storage/models/{MODEL_ID}" +CHECKPOINT_URL = "http://169.254.169.253:8234/checkpoint" + + +def _trigger_snapshot() -> None: + print("[init] requesting GPU snapshot", flush=True) + try: + req = urllib.request.Request(CHECKPOINT_URL, method="POST") + urllib.request.urlopen(req, timeout=300) + print("[init] snapshot complete", flush=True) + except http.client.RemoteDisconnected: + # TCP connections disconnect on restore and throw remote + print("[init] snapshot complete (RemoteDisconnected)", flush=True) + except Exception as exc: + print(f"[init] snapshot failed: {type(exc).__name__}: {exc}", flush=True) + + +def _has_model_weights(model_path: Path) -> bool: + return any(model_path.glob("*.safetensors")) or any(model_path.glob("*.bin")) def ensure_model_downloaded(): """Check if model exists, download if not available.""" model_path = Path(MODEL_DIR) - - # Check if model directory exists and has content - if not model_path.exists() or not any(model_path.iterdir()): - print("Model not found, downloading...") + + if model_path.exists() and _has_model_weights(model_path): try: - # Import download function from download_model - from download_model import download_model - download_model() - except Exception as e: - print(f"Error downloading model: {e}") - raise - else: - print("✓ Model already exists") + AutoTokenizer.from_pretrained(MODEL_DIR) + print("✓ Model already exists") + return + except Exception as exc: + print(f"Model cache corrupt or incomplete, re-downloading: {exc}") + import shutil + + shutil.rmtree(model_path) + + print("Model not found, downloading...") + try: + from download_model import download_model + + download_model() + except Exception as e: + print(f"Error downloading model: {e}") + raise class TritonPythonModel: @@ -49,14 +78,14 @@ def initialize(self, args): Loads tokenizer and initializes TensorRT-LLM with PyTorch backend. """ + print("[init] starting", flush=True) + # Ensure model is downloaded before loading ensure_model_downloaded() - print("Loading tokenizer...") + print("[init] building TensorRT-LLM engine", flush=True) self.tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR) - print("Initializing TensorRT-LLM...") - plugin_config = PluginConfig.from_dict({ "paged_kv_cache": True, # Efficient memory usage for KV cache }) @@ -72,7 +101,15 @@ def initialize(self, args): build_config=build_config, tensor_parallel_size=torch.cuda.device_count(), ) - print("✓ Model ready") + print("[init] TensorRT-LLM engine ready", flush=True) + + # Warm CUDA so the snapshot includes a populated context. + self.llm.generate( + ["Hello"], + SamplingParams(max_tokens=1), + ) + _trigger_snapshot() + print("[init] handler ready", flush=True) def execute(self, requests): """