1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
__pycache__/
.pyc
.codex
build
dist
*.egg-info
@@ -25,6 +25,7 @@
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.envs_utils import get_env_start_args, enable_dynamic_mtp_verify
from lightllm.server.router.model_infer.mode_backend.dynamic_mtp_planner import DynamicMTPPlanner
from .control_state import ControlState

logger = init_logger(__name__)
@@ -45,6 +46,9 @@ def __init__(self) -> None:
self.num_mtp_models = 1 if self.is_mtp_eagle else get_env_start_args().mtp_step
self._draft_decode_func = self._draft_decode_eagle if self.is_mtp_eagle else self._draft_decode_vanilla
self.enable_dynamic_mtp = enable_dynamic_mtp_verify()
self.dynamic_mtp_planner = (
DynamicMTPPlanner(max_mtp_step=get_env_start_args().mtp_step) if self.enable_dynamic_mtp else None
)
else:
self.prefill = self.prefill_normal
self.decode = self.decode_normal
@@ -233,7 +237,15 @@ def decode_mtp(
"""
Common flow for MTP decode, consolidating the logic shared by the eagle and vanilla paths.
"""
model_input, run_reqs = prepare_decode_inputs(decode_reqs)
mtp_plan = None
if self.enable_dynamic_mtp:
mtp_plan = self.dynamic_mtp_planner.build_plan(decode_reqs)
model_input, run_reqs = prepare_decode_inputs(
decode_reqs,
mtp_decode_indexes=mtp_plan.selected_mtp_indexes,
)
else:
model_input, run_reqs = prepare_decode_inputs(decode_reqs)

with torch.cuda.stream(g_infer_context.get_overlap_stream()):
b_mtp_index_cpu = model_input.b_mtp_index
@@ -260,32 +272,12 @@ def decode_mtp(
verify_event = torch.cuda.Event()
verify_event.record()

if self.enable_dynamic_mtp:
all_next_token_ids, additional_mem_indexes_cpu, draft_probs_list = self._draft_decode_func(
main_model_input=model_input,
main_model_output=model_output,
next_token_ids=next_token_ids,
b_req_mtp_start_loc=b_req_mtp_start_loc,
)
else:
all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func(
main_model_input=model_input,
main_model_output=model_output,
next_token_ids=next_token_ids,
b_req_mtp_start_loc=b_req_mtp_start_loc,
)

# dynamic_sizes_gpu is used in stage two to update each req's mtp_size
if self.enable_dynamic_mtp:
draft_probs_tensor = torch.cat(draft_probs_list, dim=-1).view(self.mtp_step, b_mtp_index_cpu.shape[0])
dynamic_sizes_gpu = self._compute_dynamic_mtp_size_gpu_part(draft_probs_tensor=draft_probs_tensor)
# Asynchronously copy back to CPU pinned memory
dynamic_sizes_cpu = g_pin_mem_manager.async_copy_from_gpu_tensor(
key="dynamic_mtp_sizes", gpu_tensor=dynamic_sizes_gpu
)

dynamic_mtp_event = torch.cuda.Event()
dynamic_mtp_event.record()
all_next_token_ids, additional_mem_indexes_cpu = self._draft_decode_func(
main_model_input=model_input,
main_model_output=model_output,
next_token_ids=next_token_ids,
b_req_mtp_start_loc=b_req_mtp_start_loc,
)

mtp_scatter_next_token_ids(
req_to_next_token_ids=self.model.req_manager.req_sampling_params_manager.req_to_next_token_ids,
@@ -320,11 +312,6 @@ def decode_mtp(
verify_event.synchronize()
accepted_index_cpu_numpy = accepted_index_cpu.numpy()
verify_ok_reqs = [run_reqs[i] for i in range(len(run_reqs)) if accepted_index_cpu_numpy[i] == 1]
if self.enable_dynamic_mtp:
dynamic_mtp_event.synchronize()
self._update_dynamic_mtp_size_cpu_part(
run_reqs=run_reqs, dynamic_sizes_cpu=dynamic_sizes_cpu, accepted_index_cpu=accepted_index_cpu
)
update_packs = self._pre_post_handle(verify_ok_reqs, is_chuncked_mode=False)

# Stage three
@@ -337,6 +324,9 @@ def decode_mtp(
need_free_mem_indexes = torch.cat([need_free_mem_indexes, additional_mem_indexes_cpu], dim=0)

self._update_mtp_accept_ratio(decode_reqs=decode_reqs, mtp_accept_len_cpu=mtp_accept_len_cpu)
if self.enable_dynamic_mtp:
self.dynamic_mtp_planner.update(decode_reqs, mtp_accept_len_cpu)

select_mask = torch.tensor(accepted_index_cpu, dtype=torch.bool, device="cpu")
self._post_handle(
run_reqs=verify_ok_reqs,
@@ -355,28 +345,6 @@ def decode_mtp(
event_pack.notify_pre_post_handle()
return

def _compute_dynamic_mtp_size_gpu_part(
self,
draft_probs_tensor: torch.Tensor,
) -> torch.Tensor:
rand_vals = torch.rand_like(draft_probs_tensor)
accepted_mask = draft_probs_tensor > rand_vals
valid_steps = torch.cumprod(accepted_mask.to(torch.int32), dim=0)
dynamic_mtp_sizes = valid_steps.sum(dim=0)
return dynamic_mtp_sizes

def _update_dynamic_mtp_size_cpu_part(
self,
run_reqs: List[InferReq],
dynamic_sizes_cpu: torch.Tensor,
accepted_index_cpu: torch.Tensor,
):
assert len(run_reqs) == dynamic_sizes_cpu.shape[0] == accepted_index_cpu.shape[0]
for req, new_size, accepted in zip(run_reqs, dynamic_sizes_cpu.numpy(), accepted_index_cpu.numpy()):
if int(accepted) == 1:
req.current_mtp_step = int(new_size)
assert req.current_mtp_step <= req.mtp_step

def _draft_prefill_forward(self, model_input: ModelInput, model_output: ModelOutput, next_token_ids: torch.Tensor):
# spec prefill: MTP. This pass only fills the draft model's kv cache; the generated token_id is not used.
draft_model_input = model_input
@@ -442,9 +410,6 @@ def _draft_decode_eagle(
all_next_token_ids = []
all_next_token_ids.append(next_token_ids)

# Collect the probs produced at each step
draft_probs_list = [] if self.enable_dynamic_mtp else None

# process the draft model output
for _step in range(self.mtp_step):
draft_model_input.input_ids = draft_next_token_ids
@@ -453,12 +418,7 @@
draft_model_idx = _step % self.num_mtp_models
draft_model_output: ModelOutput = self.draft_models[draft_model_idx].forward(draft_model_input)

# Collect probs (if needed)
if self.enable_dynamic_mtp:
draft_next_token_ids, draft_probs = self._gen_argmax_token_ids_and_prob(draft_model_output)
draft_probs_list.append(draft_probs)
else:
draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output)
draft_next_token_ids = self._gen_argmax_token_ids(draft_model_output)
draft_model_input.b_seq_len += 1
draft_model_input.max_kv_seq_len += 1
eagle_mem_indexes_i = eagle_mem_indexes[_step * num_reqs : (_step + 1) * num_reqs]
@@ -478,7 +438,4 @@

all_next_token_ids = torch.stack(all_next_token_ids, dim=1) # [batch_size, mtp_step + 1]

if self.enable_dynamic_mtp:
return all_next_token_ids, eagle_mem_indexes_cpu, draft_probs_list
else:
return all_next_token_ids, eagle_mem_indexes_cpu
return all_next_token_ids, eagle_mem_indexes_cpu
112 changes: 112 additions & 0 deletions lightllm/server/router/model_infer/mode_backend/dynamic_mtp_planner.py
@@ -0,0 +1,112 @@
from __future__ import annotations

import dataclasses
import math
import threading
from typing import List, Optional, Sequence, TYPE_CHECKING

if TYPE_CHECKING:
from lightllm.server.router.model_infer.infer_batch import InferReq


# Development-time knobs. Keep these local while the dynamic MTP planner is being
# tuned; move the stable subset to StartArgs once the policy settles.
EMA_ALPHA = 0.2
BUDGET_SCALE = 1.0
MIN_STEP = 1
MAX_STEP = None


@dataclasses.dataclass
class MTPPlan:
planned_steps: List[int]
selected_mtp_indexes: List[List[int]]
budget: int
estimated_step: int
b_req_mtp_start_loc: List[int]


class DynamicMTPPlanner:
"""
Plans a uniform dynamic MTP verification length from historical acceptance.

The plan is intentionally based on already available history so decode
preprocessing does not have to wait for the current draft pass to finish.
"""

def __init__(
self,
max_mtp_step: int,
ema_alpha: float = EMA_ALPHA,
budget_scale: float = BUDGET_SCALE,
min_step: int = MIN_STEP,
max_step: Optional[int] = None,
) -> None:
assert max_mtp_step >= 0
assert 0.0 < ema_alpha <= 1.0
assert budget_scale > 0.0
self.max_mtp_step = max_mtp_step
self.ema_alpha = ema_alpha
self.budget_scale = budget_scale
self.min_step = max(0, min(min_step, max_mtp_step))
if max_step is None:
max_step = max_mtp_step if MAX_STEP is None else MAX_STEP
self.max_step = max(self.min_step, min(max_step, max_mtp_step))
self._lock = threading.Lock()
self._ema_max_accept_step = float(self.max_step)

def build_plan(self, reqs: Sequence[InferReq]) -> MTPPlan:
req_num = len(reqs)
if req_num == 0:
return MTPPlan(
planned_steps=[],
selected_mtp_indexes=[],
budget=0,
estimated_step=0,
b_req_mtp_start_loc=[],
)

with self._lock:
slot_limit = int(math.ceil(self._ema_max_accept_step * self.budget_scale))

slot_limit = min(max(slot_limit, self.min_step), self.max_step)
planned_steps = [slot_limit for _ in reqs]

selected_mtp_indexes = [list(range(1, step + 1)) for step in planned_steps]

start_locs = []
cur_loc = 0
for selected_indexes in selected_mtp_indexes:
start_locs.append(cur_loc)
cur_loc += 1 + len(selected_indexes)

for req, step in zip(reqs, planned_steps):
req.current_mtp_step = step

return MTPPlan(
planned_steps=planned_steps,
selected_mtp_indexes=selected_mtp_indexes,
budget=sum(planned_steps),
estimated_step=slot_limit,
b_req_mtp_start_loc=start_locs,
)

def update(
self,
reqs: Sequence[InferReq],
mtp_accept_len_cpu,
) -> None:
if not reqs:
return

accept_len_np = mtp_accept_len_cpu.numpy()
max_accept_step = 0
for req_index in range(len(reqs)):
accept_len = int(accept_len_np[req_index])
accept_step = max(0, accept_len - 1)
max_accept_step = max(max_accept_step, min(accept_step, self.max_step))

with self._lock:
self._ema_max_accept_step = (
self.ema_alpha * max_accept_step + (1.0 - self.ema_alpha) * self._ema_max_accept_step
)
@@ -1,6 +1,6 @@
import torch
import numpy as np
from typing import List, Tuple
from typing import List, Optional, Sequence, Tuple
from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
from lightllm.common.basemodel.infer_lock import g_infer_state_lock
from lightllm.common.basemodel.batch_objs import ModelInput
@@ -94,15 +94,25 @@ def prepare_prefill_inputs(req_objs: List[InferReq], is_chuncked_mode: bool) ->
return model_input, run_reqs


def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]:
def prepare_decode_inputs(
req_objs: List[InferReq],
mtp_decode_steps: Optional[Sequence[int]] = None,
mtp_decode_indexes: Optional[Sequence[Sequence[int]]] = None,
) -> Tuple[ModelInput, List[InferReq]]:
if mtp_decode_steps is not None:
assert len(mtp_decode_steps) == len(req_objs)
if mtp_decode_indexes is not None:
assert mtp_decode_steps is None
assert len(mtp_decode_indexes) == len(req_objs)

run_reqs: List[InferReq] = []
total_token_num = 0
b_req_idx = []
b_mtp_index = []
b_seq_len = []
b_q_seq_len = []
multimodal_params = []
for req in req_objs:
for req_index, req in enumerate(req_objs):
run_reqs.append(req)
b_req_idx.append(req.req_idx)
seq_len = req.get_cur_total_len()
@@ -113,15 +123,25 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
b_mtp_index.append(0)
multimodal_params.append(req.multimodal_params)
# process the draft tokens.
# Dynamic MTP mode: build the batch from the per-request dynamic current_mtp_step
# Non-dynamic MTP mode: current_mtp_step equals the fixed mtp_step
for step in range(req.current_mtp_step):
# The dynamic MTP planner can supply, ahead of time, the draft indexes to fill into this round's verify slots.
# The current planner emits contiguous prefix indexes; non-contiguous selection can later be handled by chaining a compact kernel behind this interface.
if mtp_decode_indexes is not None:
decode_indexes = [int(index) for index in mtp_decode_indexes[req_index]]
assert decode_indexes == list(range(1, len(decode_indexes) + 1)), (
"Current MTP verify path requires contiguous prefix draft indexes. "
"Non-prefix indexes need a compact/remap kernel before decode."
)
else:
decode_step = req.current_mtp_step if mtp_decode_steps is None else int(mtp_decode_steps[req_index])
decode_indexes = range(1, decode_step + 1)

for mtp_index in decode_indexes:
run_reqs.append(req)
b_req_idx.append(req.req_idx)
seq_len += 1
b_seq_len.append(seq_len)
total_token_num += seq_len
b_mtp_index.append(step + 1)
mtp_seq_len = seq_len + int(mtp_index)
b_seq_len.append(mtp_seq_len)
total_token_num += mtp_seq_len
b_mtp_index.append(int(mtp_index))
multimodal_params.append(req.multimodal_params)
b_q_seq_len.append(1)
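
To make the batch layout above concrete, here is a toy standalone sketch (assumed values, not part of the PR) of how one request with a current sequence length of 10 and two verify slots expands into per-row metadata; a non-prefix list such as [1, 3] would trip the assertion shown above.

# Toy illustration of the per-row expansion: row 0 is the main-model token,
# rows 1..N are the MTP verify slots.
seq_len = 10
decode_indexes = [1, 2]  # contiguous prefix, as the verify path requires
b_mtp_index = [0] + [int(i) for i in decode_indexes]                 # -> [0, 1, 2]
b_seq_len = [seq_len] + [seq_len + int(i) for i in decode_indexes]   # -> [10, 11, 12]
b_q_seq_len = [1] * len(b_mtp_index)                                 # one query token per row
total_token_num = sum(b_seq_len)                                     # -> 33
print(b_mtp_index, b_seq_len, b_q_seq_len, total_token_num)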
