diff --git a/.gitignore b/.gitignore
index d572eac42..b1717ce67 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 __pycache__/
 .pyc
+.codex
 build
 dist
 *.egg-info
diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
index c41dbb6d9..6b1578ae7 100644
--- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
+++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py
@@ -25,6 +25,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify
+from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner
 from .control_state import ControlState
 
 logger = init_logger(__name__)
@@ -45,6 +46,9 @@ def __init__(self) -> None:
             self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step
             self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla
             self.enable_dynamic_mtp = enable_dynamic_mtp_verify()
+            self.dynamic_mtp_planner = (
+                DynamicMTPPlanner(max_mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None
+            )
         else:
             self.prefill = self.prefill_normal
             self.decode = self.decode_normal
@@ -233,7 +237,15 @@ def decode_mtp(
         """
         Common MTP decode flow that consolidates the logic shared by the eagle and vanilla paths.
         """
-        model_input, run_reqs = prepare_decode_inputs(decode_reqs)
+        mtp_plan = None
+        if self.enable_dynamic_mtp:
+            mtp_plan = self.dynamic_mtp_planner.build_plan(decode_reqs)
+            model_input, run_reqs = prepare_decode_inputs(
+                decode_reqs,
+                mtp_decode_indexes=mtp_plan.selected_mtp_indexes,
+            )
+        else:
+            model_input, run_reqs = prepare_decode_inputs(decode_reqs)
 
         with torch.cuda.stream(g_infer_context.get_overlap_stream()):
             b_mtp_index_cpu = model_input.b_mtp_index
@@ -260,32 +272,12 @@ def decode_mtp(
             verify_event = torch.cuda.Event()
             verify_event.record()
 
-            if self.enable_dynamic_mtp:
-                all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func(
-                    main_model_input=model_input,
-                    main_model_output=model_output,
-                    next_token_ids=next_token_ids,
-                    b_req_mtp_start_loc=b_req_mtp_start_loc,
-                )
-            else:
-                all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func(
-                    main_model_input=model_input,
-                    main_model_output=model_output,
-                    next_token_ids=next_token_ids,
-                    b_req_mtp_start_loc=b_req_mtp_start_loc,
-                )
-
-            # dynamic_sizes_gpu is used in stage two to update each req's mtp_size
-            if self.enable_dynamic_mtp:
-                draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0])
-                dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor)
-                # asynchronously copy back to CPU pinned memory
-                dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
-                    key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu
-                )
-
-                dynamic_mtp_event = torch.cuda.Event()
-                dynamic_mtp_event.record()
+            all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func(
+                main_model_input=model_input,
+                main_model_output=model_output,
+                next_token_ids=next_token_ids,
+                b_req_mtp_start_loc=b_req_mtp_start_loc,
+            )
 
             mtp_scatter_next_token_ids(
                 req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
@@ -320,11 +312,6 @@ def decode_mtp(
         verify_event.synchronize()
         accepted_index_cpu_numpy = accepted_index_cpu.numpy()
         verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1]
-        if self.enable_dynamic_mtp:
-            dynamic_mtp_event.synchronize()
-            self._update_dynamic_mtp_size_cpu_part(
-                run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu
-            )
         update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)
 
         # stage three
@@ -337,6 +324,9 @@ def decode_mtp(
             need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0)
 
         self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu)
+        if self.enable_dynamic_mtp:
+            self.dynamic_mtp_planner.update(decode_reqs, mtp_accept_len_cpu)
+
         select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu")
         self._post_handle(
             run_reqs=verify_ok_reqs,
@@ -355,28 +345,6 @@ def decode_mtp(
             event_pack.notify_pre_post_handle()
             return
 
-    def _compute_dynamic_mtp_size_gpu_part(
-        self,
-        draft_probs_tensor: torch.Tensor,
-    ) -> torch.Tensor:
-        rand_vals = torch.rand_like(draft_probs_tensor)
-        accepted_mask = draft_probs_tensor > rand_vals
-        valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0)
-        dynamic_mtp_sizes = valid_steps.sum(dim=0)
-        return dynamic_mtp_sizes
-
-    def _update_dynamic_mtp_size_cpu_part(
-        self,
-        run_reqs: List[InferReq],
-        dynamic_sizes_cpu: torch.Tensor,
-        accepted_index_cpu: torch.Tensor,
-    ):
-        assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0]
-        for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()):
-            if int(accepted) == 1:
-                req.current_mtp_step = int(new_size)
-                assert req.current_mtp_step <= req.mtp_step
-
     def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor):
         # spec prefill: MTP, this only fills the draft model's kv cache; the generated token_id is not used.
         draft_model_input = model_input
@@ -442,9 +410,6 @@ def _draft_decode_eagle(
         all_next_token_ids = []
         all_next_token_ids.append(next_token_ids)
 
-        # used to collect the probs of every step
-        draft_probs_list = [] if self.enable_dynamic_mtp else None
-
         # process the draft model output
         for _step in range(self.mtp_step):
             draft_model_input.input_ids = draft_next_token_ids
@@ -453,12 +418,7 @@ def _draft_decode_eagle(
             draft_model_idx = _step % self.num_mtp_models
             draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input)
 
-            # collect probs (if needed)
-            if self.enable_dynamic_mtp:
-                draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output)
-                draft_probs_list.append(draft_probs)
-            else:
-                draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output)
+            draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output)
             draft_model_input.b_seq_len += 1
             draft_model_input.max_kv_seq_len += 1
             eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs]
@@ -478,7 +438,4 @@ def _draft_decode_eagle(
 
         all_next_token_ids = torch.stack(all_next_token_ids, dim=1)  # [batch_size, mtp_step + 1]
 
-        if self.enable_dynamic_mtp:
-            return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list
-        else:
-            return all_next_token_ids, eagle_mem_indexes_cpu
+        return all_next_token_ids, eagle_mem_indexes_cpu
diff --git a/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py
new file mode 100644
index 000000000..eb2a51162
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+import dataclasses
+import math
+import threading
+from typing import List, Sequence, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from lightllm.server.router.model_infer.infer_batch import InferReq
+
+
+# Development-time knobs. Keep these local while the dynamic MTP planner is being
+# tuned; move the stable subset to StartArgs once the policy settles.
+EMA_ALPHA = 0.2
+BUDGET_SCALE = 1.0
+MIN_STEP = 1
+MAX_STEP = None
+
+
+@dataclasses.dataclass
+class MTPPlan:
+    planned_steps: List[int]
+    selected_mtp_indexes: List[List[int]]
+    budget: int
+    estimated_step: int
+    b_req_mtp_start_loc: List[int]
+
+
+class DynamicMTPPlanner:
+    """
+    Plans a uniform dynamic MTP verification length from historical acceptance.
+
+    The plan is intentionally based on already available history so decode
+    preprocessing does not have to wait for the current draft pass to finish.
+    """
+
+    def __init__(
+        self,
+        max_mtp_step: int,
+        ema_alpha: float = EMA_ALPHA,
+        budget_scale: float = BUDGET_SCALE,
+        min_step: int = MIN_STEP,
+        max_step: int = None,
+    ) -> None:
+        assert max_mtp_step >= 0
+        assert 0.0 < ema_alpha <= 1.0
+        assert budget_scale > 0.0
+        self.max_mtp_step = max_mtp_step
+        self.ema_alpha = ema_alpha
+        self.budget_scale = budget_scale
+        self.min_step = max(0, min(min_step, max_mtp_step))
+        if max_step is None:
+            max_step = max_mtp_step if MAX_STEP is None else MAX_STEP
+        self.max_step = max(self.min_step, min(max_step, max_mtp_step))
+        self._lock = threading.Lock()
+        self._ema_max_accept_step = float(self.max_step)
+
+    def build_plan(self, reqs: Sequence[InferReq]) -> MTPPlan:
+        req_num = len(reqs)
+        if req_num == 0:
+            return MTPPlan(
+                planned_steps=[],
+                selected_mtp_indexes=[],
+                budget=0,
+                estimated_step=0,
+                b_req_mtp_start_loc=[],
+            )
+
+        with self._lock:
+            slot_limit = int(math.ceil(self._ema_max_accept_step * self.budget_scale))
+
+        slot_limit = min(max(slot_limit, self.min_step), self.max_step)
+        planned_steps = [slot_limit for _ in reqs]
+
+        selected_mtp_indexes = [list(range(1, step + 1)) for step in planned_steps]
+
+        start_locs = []
+        cur_loc = 0
+        for selected_indexes in selected_mtp_indexes:
+            start_locs.append(cur_loc)
+            cur_loc += 1 + len(selected_indexes)
+
+        for req, step in zip(reqs, planned_steps):
+            req.current_mtp_step = step
+
+        return MTPPlan(
+            planned_steps=planned_steps,
+            selected_mtp_indexes=selected_mtp_indexes,
+            budget=sum(planned_steps),
+            estimated_step=slot_limit,
+            b_req_mtp_start_loc=start_locs,
+        )
+
+    def update(
+        self,
+        reqs: Sequence[InferReq],
+        mtp_accept_len_cpu,
+    ) -> None:
+        if not reqs:
+            return
+
+        accept_len_np = mtp_accept_len_cpu.numpy()
+        max_accept_step = 0
+        for req_index in range(len(reqs)):
+            accept_len = int(accept_len_np[req_index])
+            accept_step = max(0, accept_len - 1)
+            max_accept_step = max(max_accept_step, min(accept_step, self.max_step))
+
+        with self._lock:
+            self._ema_max_accept_step = (
+                self.ema_alpha * max_accept_step + (1.0 - self.ema_alpha) * self._ema_max_accept_step
+            )
diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
index 3d9d8815e..c4ae24faf 100644
--- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
+++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import List, Tuple
+from typing import List, Optional, Sequence, Tuple
 from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock
 from lightllm.common.basemodel.batch_objs import ModelInput
@@ -94,7 +94,17 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) ->
     return model_input, run_reqs
 
 
-def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]:
+def prepare_decode_inputs(
+    req_objs: List[InferReq],
+    mtp_decode_steps: Optional[Sequence[int]] = None,
+    mtp_decode_indexes: Optional[Sequence[Sequence[int]]] = None,
+) -> Tuple[ModelInput, List[InferReq]]:
+    if mtp_decode_steps is not None:
+        assert len(mtp_decode_steps) == len(req_objs)
+    if mtp_decode_indexes is not None:
+        assert mtp_decode_steps is None
+        assert len(mtp_decode_indexes) == len(req_objs)
+
     run_reqs: List[InferReq] = []
     total_token_num = 0
     b_req_idx = []
@@ -102,7 +112,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
     b_seq_len = []
     b_q_seq_len = []
     multimodal_params = []
-    for req in req_objs:
+    for req_index, req in enumerate(req_objs):
         run_reqs.append(req)
         b_req_idx.append(req.req_idx)
         seq_len = req.get_cur_total_len()
@@ -113,15 +123,25 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
         b_mtp_index.append(0)
         multimodal_params.append(req.multimodal_params)
         # process the draft tokens.
-        # Dynamic MTP mode: build the batch with the dynamic current_mtp_step
-        # Non-dynamic MTP mode: current_mtp_step is the fixed mtp_step
-        for step in range(req.current_mtp_step):
+        # The dynamic MTP planner can hand out, ahead of time, the draft indexes that fill this round's verify slots.
+        # The current planner uses contiguous prefix indexes; non-contiguous selection can later hook a compact kernel behind this interface.
+        if mtp_decode_indexes is not None:
+            decode_indexes = [int(index) for index in mtp_decode_indexes[req_index]]
+            assert decode_indexes == list(range(1, len(decode_indexes) + 1)), (
+                "Current MTP verify path requires contiguous prefix draft indexes. "
+                "Non-prefix indexes need a compact/remap kernel before decode."
+            )
+        else:
+            decode_step = req.current_mtp_step if mtp_decode_steps is None else int(mtp_decode_steps[req_index])
+            decode_indexes = range(1, decode_step + 1)
+
+        for mtp_index in decode_indexes:
             run_reqs.append(req)
             b_req_idx.append(req.req_idx)
-            seq_len += 1
-            b_seq_len.append(seq_len)
-            total_token_num += seq_len
-            b_mtp_index.append(step + 1)
+            mtp_seq_len = seq_len + int(mtp_index)
+            b_seq_len.append(mtp_seq_len)
+            total_token_num += mtp_seq_len
+            b_mtp_index.append(int(mtp_index))
             multimodal_params.append(req.multimodal_params)
 
         b_q_seq_len.append(1)
diff --git a/test/benchmark/service/benchmark_sharegpt.py b/test/benchmark/service/benchmark_sharegpt.py
index b056e69bd..9337e6527 100644
--- a/test/benchmark/service/benchmark_sharegpt.py
+++ b/test/benchmark/service/benchmark_sharegpt.py
@@ -215,7 +215,7 @@ async def send_request(
         "top_k": 1,
         "top_p": 1.0,
         "temperature": 0,
-        "stream": True,
+        # "stream": True,
         "ignore_eos": True,
         "max_tokens": output_len,
     }
@@ -224,20 +224,41 @@ async def send_request(
 
     async with aiohttp.ClientSession(timeout=timeout) as session:
         async with session.post(url, headers=headers, json=data) as response:
+            response.raise_for_status()
             chunks = []
             text = ""
             start_time = time.time()
             is_first = True
+            sse_buffer = ""
            async for chunk, _ in response.content.iter_chunks():
                 now_time = time.time()
                 delta_time = now_time - start_time
                 if is_first:
                     is_first = False
                     ttft = delta_time
-                text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "")
-                if delta_time < 0.005:
-                    receive_n += 1
                 chunks.append(delta_time)
+                # OpenAI-compatible stream is SSE; one TCP chunk may contain
+                # partial/multiple events. Parse by complete lines safely.
+                sse_buffer += chunk.decode("utf-8", errors="ignore")
+                while "\n" in sse_buffer:
+                    line, sse_buffer = sse_buffer.split("\n", 1)
+                    line = line.strip()
+                    if not line or not line.startswith("data:"):
+                        continue
+                    payload = line[5:].strip()
+                    if payload == "[DONE]":
+                        break
+                    if not payload:
+                        continue
+                    try:
+                        event = json.loads(payload)
+                    except json.JSONDecodeError:
+                        # In rare cases malformed/partial payload slips in;
+                        # skip and continue to keep benchmark running.
+                        continue
+                    text += event.get("choices", [{}])[0].get("delta", {}).get("content", "")
+                    if delta_time < 0.005:
+                        receive_n += 1
                 start_time = now_time
             # print("messages", messages)
             # print("text", text)
diff --git a/test/speculative/bench_throughput.sh b/test/speculative/bench_throughput.sh
index 8e14f8189..4cfa90bcb 100644
--- a/test/speculative/bench_throughput.sh
+++ b/test/speculative/bench_throughput.sh
@@ -2,7 +2,7 @@
 # Default values
 PORT=8088
 NUM_PROMPTS=1000
-TOKENIZER="/mtc/models/qwen3-8b"
+TOKENIZER="/mtc/models/qwen3-32b"
 DATASET="/data/nvme0/chenjunyi/project/lightllm/datasets/gsm8k.json"
 HISTORY_TURNS=1
 CONCURRENCY=128
diff --git a/test/speculative/qwen3-32b/dynamic_triton.sh b/test/speculative/qwen3-32b/dynamic_triton.sh
index 39145e5f5..ca70bf15e 100644
--- a/test/speculative/qwen3-32b/dynamic_triton.sh
+++ b/test/speculative/qwen3-32b/dynamic_triton.sh
@@ -16,8 +16,10 @@ done
 MODEL_DIR=/mtc/models/qwen3-32b
 DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3
 
-LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
---tp 4 --max_total_token_num 200000 \
+PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH
+
+LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \
+--tp 2 \
 --model_dir ${MODEL_DIR} \
 --mtp_mode eagle3 \
 --disable_dynamic_prompt_cache \
diff --git a/test/speculative/qwen3-32b/no_mtp_fa3.sh b/test/speculative/qwen3-32b/no_mtp_fa3.sh
new file mode 100644
index 000000000..c17562721
--- /dev/null
+++ b/test/speculative/qwen3-32b/no_mtp_fa3.sh
@@ -0,0 +1,11 @@
+MODEL_DIR=/mtc/models/qwen3-32b
+DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3
+
+PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH
+
+LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \
+--tp 2 \
+--model_dir ${MODEL_DIR} \
+--disable_dynamic_prompt_cache \
+--graph_grow_step_size 1 \
+--llm_decode_att_backend triton
\ No newline at end of file
diff --git a/test/speculative/qwen3-32b/static_fa3.sh b/test/speculative/qwen3-32b/static_fa3.sh
index c9712116e..44c67e03b 100644
--- a/test/speculative/qwen3-32b/static_fa3.sh
+++ b/test/speculative/qwen3-32b/static_fa3.sh
@@ -16,8 +16,10 @@ done
 MODEL_DIR=/mtc/models/qwen3-32b
 DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3
 
-LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
---tp 4 --max_total_token_num 200000 \
+PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH
+
+LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \
+--tp 2 \
 --model_dir ${MODEL_DIR} \
 --mtp_mode eagle3 \
 --mtp_draft_model_dir ${DRAFT_MODEL_DIR} \
diff --git a/test/speculative/qwen3-32b/static_triton.sh b/test/speculative/qwen3-32b/static_triton.sh
index 453c5678e..71964c9af 100644
--- a/test/speculative/qwen3-32b/static_triton.sh
+++ b/test/speculative/qwen3-32b/static_triton.sh
@@ -16,8 +16,10 @@ done
 MODEL_DIR=/mtc/models/qwen3-32b
 DRAFT_MODEL_DIR=/mtc/models/qwen3-32b-eagle3
 
-LOADWORKER=18 python -m lightllm.server.api_server --port 8088 \
---tp 4 --max_total_token_num 200000 \
+PATH=/data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin:$PATH
+
+LOADWORKER=18 /data/nvme0/chenjunyi/miniconda3/envs/lightllm/bin/python -m lightllm.server.api_server --port 8088 \
+--tp 2 \
 --model_dir ${MODEL_DIR} \
 --mtp_mode eagle3 \
 --disable_dynamic_prompt_cache \
diff --git a/test/speculative/run_vllm_speculative_baseline.sh b/test/speculative/run_vllm_speculative_baseline.sh
new file mode 100755
index 000000000..f2027f20c
--- /dev/null
+++ b/test/speculative/run_vllm_speculative_baseline.sh
@@ -0,0 +1,298 @@
+#!/bin/bash
+
+# =============================================================================
+# vLLM Speculative Decoding Baseline Experiment Script
+# Function: Run vLLM default draft-model speculative decoding baseline for
+#           different mtp steps (mapped to num_speculative_tokens), and collect
+#           throughput/latency metrics with the same benchmark script.
+# =============================================================================
+
+set -euo pipefail
+
+# Keep default GPU visibility aligned with existing LightLLM experiment scripts.
+export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-1,2,3,4,6}"
+# Reduce allocator fragmentation risk during model warmup.
+export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}"
+
+# =============================================================================
+# Configurable Parameters
+# =============================================================================
+PROJECT_DIR="/data/nvme0/chenjunyi/project/lightllm"
+BENCH_PY_SCRIPT="${PROJECT_DIR}/test/benchmark/service/benchmark_sharegpt.py"
+DATASET="${PROJECT_DIR}/datasets/gsm8k.json"
+
+# Keep defaults close to existing LightLLM qwen3-32b setup.
+MODEL_DIR="/mtc/models/qwen3-32b"
+DRAFT_MODEL_DIR="/mtc/models/qwen3-32b-eagle3"
+TOKENIZER="/mtc/models/qwen3-32b"
+
+SAMPLES=1000
+CONCURRENCY=256
+PORT=8088
+TP=4
+MAX_MODEL_LEN=16384
+MAX_NUM_BATCHED_TOKENS=200000
+MAX_NUM_SEQS=256
+GPU_MEMORY_UTILIZATION=0.6
+MAX_CUDAGRAPH_CAPTURE_SIZE=256
+ATTENTION_BACKEND="FLASH_ATTN"
+DISABLE_CUSTOM_ALL_REDUCE=1
+MTP_STEPS=(5)
+
+RESULTS_DIR="${PROJECT_DIR}/experiment_results"
+TIMESTAMP=$(date +%Y%m%d_%H%M%S)
+DATASET_NAME=$(basename "${DATASET}" .json)
+EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default"
+RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv"
+
+usage() {
+    echo "Usage: $0 [options]"
+    echo ""
+    echo "Options:"
+    echo "  --model-dir PATH                  Main model path (default: ${MODEL_DIR})"
+    echo "  --draft-model-dir PATH            Draft model path (default: ${DRAFT_MODEL_DIR})"
+    echo "  --dataset PATH                    Dataset path (default: ${DATASET})"
+    echo "  --tokenizer PATH                  Tokenizer path (default: ${TOKENIZER})"
+    echo "  --samples NUM                     Number of prompts (default: ${SAMPLES})"
+    echo "  --concurrency NUM                 Concurrency (default: ${CONCURRENCY})"
+    echo "  --port PORT                       Service port (default: ${PORT})"
+    echo "  --tp NUM                          Tensor parallel size (default: ${TP})"
+    echo "  --mtp-steps LIST                  Comma-separated mtp steps (default: 5)"
+    echo "  --num-speculative-tokens NUM      Backward-compatible alias, equals one mtp step"
+    echo "  --max-model-len NUM               vLLM max model len (default: ${MAX_MODEL_LEN})"
+    echo "  --max-num-batched-tokens NUM      vLLM max batched tokens (default: ${MAX_NUM_BATCHED_TOKENS})"
+    echo "  --max-num-seqs NUM                vLLM max number of concurrent seqs (default: ${MAX_NUM_SEQS})"
+    echo "  --max-cudagraph-capture-size NUM  vLLM max cudagraph capture size (default: ${MAX_CUDAGRAPH_CAPTURE_SIZE})"
+    echo "  --gpu-memory-utilization F        GPU memory utilization (default: ${GPU_MEMORY_UTILIZATION})"
+    echo "  --attention-backend NAME          vLLM attention backend (default: ${ATTENTION_BACKEND})"
+    echo "  --enable-custom-all-reduce        Enable custom all-reduce (default: disabled)"
+    echo "  --results-dir DIR                 Results base dir (default: ${RESULTS_DIR})"
+    echo "  --help                            Show this help"
+    exit 1
+}
+
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+        --model-dir)
+            MODEL_DIR="$2"
+            shift 2
+            ;;
+        --draft-model-dir)
+            DRAFT_MODEL_DIR="$2"
+            shift 2
+            ;;
+        --dataset)
+            DATASET="$2"
+            shift 2
+            ;;
+        --tokenizer)
+            TOKENIZER="$2"
+            shift 2
+            ;;
+        --samples)
+            SAMPLES="$2"
+            shift 2
+            ;;
+        --concurrency)
+            CONCURRENCY="$2"
+            shift 2
+            ;;
+        --port)
+            PORT="$2"
+            shift 2
+            ;;
+        --tp)
+            TP="$2"
+            shift 2
+            ;;
+        --mtp-steps)
+            IFS=',' read -ra MTP_STEPS <<< "$2"
+            shift 2
+            ;;
+        --num-speculative-tokens)
+            MTP_STEPS=("$2")
+            shift 2
+            ;;
+        --max-model-len)
+            MAX_MODEL_LEN="$2"
+            shift 2
+            ;;
+        --max-num-batched-tokens)
+            MAX_NUM_BATCHED_TOKENS="$2"
+            shift 2
+            ;;
+        --max-num-seqs)
+            MAX_NUM_SEQS="$2"
+            shift 2
+            ;;
+        --max-cudagraph-capture-size)
+            MAX_CUDAGRAPH_CAPTURE_SIZE="$2"
+            shift 2
+            ;;
+        --gpu-memory-utilization)
+            GPU_MEMORY_UTILIZATION="$2"
+            shift 2
+            ;;
+        --attention-backend)
+            ATTENTION_BACKEND="$2"
+            shift 2
+            ;;
+        --enable-custom-all-reduce)
+            DISABLE_CUSTOM_ALL_REDUCE=0
+            shift 1
+            ;;
+        --results-dir)
+            RESULTS_DIR="$2"
+            shift 2
+            ;;
+        --help)
+            usage
+            ;;
+        *)
+            echo "Unknown argument: $1"
+            usage
+            ;;
+    esac
+done
+
+# Recompute result paths in case dataset/results-dir was overridden.
+DATASET_NAME=$(basename "${DATASET}" .json)
+EXPERIMENT_SUBDIR="${RESULTS_DIR}/${DATASET_NAME}_${TIMESTAMP}_vllm_spec_default"
+RESULTS_FILE="${EXPERIMENT_SUBDIR}/results.csv"
+
+mkdir -p "${EXPERIMENT_SUBDIR}"
+
+echo "timestamp,engine,mode,mtp_step,dataset,samples,concurrency,throughput,avg_latency,avg_ttft,avg_inter_token_latency" > "${RESULTS_FILE}"
+
+wait_for_server() {
+    local max_attempts=600
+    local attempt=0
+    echo "Waiting for vLLM server to start..."
+    while [[ ${attempt} -lt ${max_attempts} ]]; do
+        if curl -s "http://localhost:${PORT}/health" > /dev/null 2>&1; then
+            echo "vLLM server started"
+            return 0
+        fi
+        sleep 2
+        attempt=$((attempt + 1))
+    done
+    echo "vLLM server startup timeout"
+    return 1
+}
+
+extract_benchmark_metrics() {
+    local log_file="$1"
+    local throughput=""
+    local avg_latency=""
+    local avg_ttft=""
+    local avg_inter_token_latency=""
+
+    throughput=$(grep -oP 'Throughput: \K[\d.]+' "$log_file" | tail -1)
+    avg_latency=$(grep -oP 'Average latency: \K[\d.]+' "$log_file" | tail -1)
+    avg_ttft=$(grep -oP 'Average time to first token: \K[\d.]+' "$log_file" | tail -1)
+    avg_inter_token_latency=$(grep -oP 'Average inter-token latency: \K[\d.]+' "$log_file" | tail -1)
+
+    echo "${throughput:-NA},${avg_latency:-NA},${avg_ttft:-NA},${avg_inter_token_latency:-NA}"
+}
+
+kill_vllm() {
+    echo "Stopping vLLM server..."
+    pkill -9 -f "vllm serve" 2>/dev/null || true
+    pkill -9 -f "vllm.entrypoints.openai.api_server" 2>/dev/null || true
+    sleep 1
+    echo "vLLM server stopped"
+}
+
+trap 'kill_vllm' EXIT
+
+echo "=============================================="
+echo "vLLM Speculative Baseline Started"
+echo "=============================================="
+echo "Model: ${MODEL_DIR}"
+echo "Draft model: ${DRAFT_MODEL_DIR}"
+echo "Tokenizer: ${TOKENIZER}"
+echo "Dataset: ${DATASET}"
+echo "Samples: ${SAMPLES}"
+echo "Concurrency: ${CONCURRENCY}"
+echo "TP: ${TP}"
+echo "Port: ${PORT}"
+echo "Max model len: ${MAX_MODEL_LEN}"
+echo "Max batched tokens: ${MAX_NUM_BATCHED_TOKENS}"
+echo "Max num seqs: ${MAX_NUM_SEQS}"
+echo "Max cudagraph capture size: ${MAX_CUDAGRAPH_CAPTURE_SIZE}"
+echo "GPU memory utilization: ${GPU_MEMORY_UTILIZATION}"
+echo "Attention backend: ${ATTENTION_BACKEND}"
+echo "Disable custom all reduce: ${DISABLE_CUSTOM_ALL_REDUCE}"
+echo "MTP steps: ${MTP_STEPS[*]}"
+echo "Results directory: ${EXPERIMENT_SUBDIR}"
+echo "=============================================="
+
+for MTP_STEP in "${MTP_STEPS[@]}"; do
+    echo ""
+    echo "--- Running mtp step: ${MTP_STEP} ---"
+
+    LOG_FILE="${EXPERIMENT_SUBDIR}/log_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt"
+    BENCH_LOG="${EXPERIMENT_SUBDIR}/bench_vllm_spec_default_step${MTP_STEP}_${TIMESTAMP}.txt"
+
+    SPECULATIVE_CONFIG=$(printf '{"model": "%s", "num_speculative_tokens": %s, "method": "draft_model"}' \
+        "${DRAFT_MODEL_DIR}" "${MTP_STEP}")
+    CUSTOM_ALL_REDUCE_FLAG=""
+    if [[ "${DISABLE_CUSTOM_ALL_REDUCE}" == "1" ]]; then
+        CUSTOM_ALL_REDUCE_FLAG="--disable-custom-all-reduce"
+    fi
+
+    kill_vllm
+
+    echo "Starting vLLM server with speculative_config=${SPECULATIVE_CONFIG}"
+    (
+        vllm serve "${MODEL_DIR}" \
+            --host 0.0.0.0 \
+            --port "${PORT}" \
+            --served-model-name DeepSeek-R1 \
+            -tp "${TP}" \
+            --max_model_len "${MAX_MODEL_LEN}" \
+            --max_num_batched_tokens "${MAX_NUM_BATCHED_TOKENS}" \
+            --max_num_seqs "${MAX_NUM_SEQS}" \
+            --max-cudagraph-capture-size "${MAX_CUDAGRAPH_CAPTURE_SIZE}" \
+            --attention-backend "${ATTENTION_BACKEND}" \
+            ${CUSTOM_ALL_REDUCE_FLAG} \
+            --speculative_config "${SPECULATIVE_CONFIG}"
+    ) > "${LOG_FILE}" 2>&1 &
+
+    SERVER_PID=$!
+    echo "vLLM PID: ${SERVER_PID}"
+
+    if ! wait_for_server; then
+        echo "vLLM server failed to start for mtp step ${MTP_STEP}. Check log: ${LOG_FILE}"
+        RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},NA,NA,NA,NA"
+        echo "${RESULT_LINE}" >> "${RESULTS_FILE}"
+        continue
+    fi
+
+    sleep 5
+
+    echo "Running benchmark with benchmark_sharegpt.py (OpenAI API mode)..."
+    python "${BENCH_PY_SCRIPT}" \
+        --use_openai_api \
+        --port "${PORT}" \
+        --num-prompts "${SAMPLES}" \
+        --tokenizer "${TOKENIZER}" \
+        --dataset "${DATASET}" \
+        --history-turns 1 \
+        --concurrency "${CONCURRENCY}" 2>&1 | tee "${BENCH_LOG}"
+
+    cat "${BENCH_LOG}" >> "${LOG_FILE}"
+
+    BENCH_METRICS=$(extract_benchmark_metrics "${LOG_FILE}")
+    RESULT_LINE="${TIMESTAMP},vllm,speculative_draft_model_default,${MTP_STEP},${DATASET},${SAMPLES},${CONCURRENCY},${BENCH_METRICS}"
+    echo "${RESULT_LINE}" >> "${RESULTS_FILE}"
+
+    echo "Completed mtp step ${MTP_STEP}: ${RESULT_LINE}"
+done
+
+echo ""
+echo "=============================================="
+echo "All Experiments Completed"
+echo "=============================================="
+echo "Results file: ${RESULTS_FILE}"
+cat "${RESULTS_FILE}"
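
Reviewer note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the control flow that decode_mtp now follows when dynamic MTP verification is enabled: DynamicMTPPlanner.build_plan() is consulted before prepare_decode_inputs(), and planner.update() is fed the per-request accepted lengths after verification. The _FakeReq class and the hard-coded acceptance tensor are hypothetical stand-ins for InferReq objects and mtp_accept_len_cpu; it assumes the patch is applied, LightLLM is importable, and torch is installed.

import torch

from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner


class _FakeReq:
    """Hypothetical stand-in for InferReq; only the field the planner writes."""

    def __init__(self):
        self.current_mtp_step = 0


def demo():
    planner = DynamicMTPPlanner(max_mtp_step=5)
    reqs = [_FakeReq() for _ in range(4)]

    # 1) Ask the planner how many draft slots to verify this round. The plan is
    #    uniform across the batch and comes from the EMA of past accepted steps.
    plan = planner.build_plan(reqs)
    print("planned steps:", plan.planned_steps)            # starts at max_mtp_step: [5, 5, 5, 5]
    print("verify slots :", plan.selected_mtp_indexes[0])  # contiguous prefix [1, 2, 3, 4, 5]

    # decode_mtp would now call:
    #   prepare_decode_inputs(reqs, mtp_decode_indexes=plan.selected_mtp_indexes)

    # 2) After verification, report how many tokens each request accepted.
    #    accept_len counts the main token, so accepted draft step = accept_len - 1.
    mtp_accept_len_cpu = torch.tensor([1, 2, 3, 1])  # hypothetical example values
    for _ in range(5):
        planner.update(reqs, mtp_accept_len_cpu)

    # 3) Subsequent plans shrink toward the observed maximum accepted step (2 here):
    #    after five identical rounds the EMA is about 2.98, so the ceil is 3.
    next_plan = planner.build_plan(reqs)
    print("planned steps after updates:", next_plan.planned_steps)  # [3, 3, 3, 3]


if __name__ == "__main__":
    demo()

Because build_plan only reads the EMA history, decode preprocessing never waits on the current draft pass; EMA_ALPHA controls how quickly the planned step count tracks recent acceptance.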