32 changes: 27 additions & 5 deletions lightllm/common/basemodel/attention/triton/fp.py
@@ -25,12 +25,12 @@ def prefill_att(
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
) -> torch.Tensor:
assert att_control.use_sliding_window is False and att_control.use_att_sink is False
if att_control.use_alibi:
assert att_control.use_sliding_window is False, "alibi + sliding_window not supported"
assert att_control.tp_alibi is not None
return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)
else:
return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func)
return self._nomarl_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)

def _alibi_prefill_att(
self,
@@ -59,9 +59,21 @@ def _alibi_prefill_att(
)
return out

def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty):
def _nomarl_prefill_att(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
):
from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd

if att_control.use_sliding_window:
sliding_window = int(att_control.sliding_window[0])
else:
sliding_window = -1

out = alloc_func(q.shape, q.dtype)
context_attention_fwd(
q,
@@ -74,6 +86,7 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
self.infer_state.b_ready_cache_len,
self.infer_state.max_q_seq_len,
self.infer_state.req_manager.req_to_token_indexs,
sliding_window=sliding_window,
)
return out

@@ -94,8 +107,8 @@ def decode_att(
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
):
assert att_control.use_sliding_window is False and att_control.use_att_sink is False
if att_control.use_alibi:
assert att_control.use_sliding_window is False, "alibi + sliding_window not supported"
assert att_control.tp_alibi is not None
return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func)
else:
@@ -104,7 +117,9 @@
if q_head_num == k_head_num:
return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
elif q_head_num > k_head_num:
return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func)
return self._normal_decode_gqa_flash_decoding_att(
q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func
)
else:
raise NotImplementedError("error")

@@ -163,12 +178,18 @@ def _normal_decode_gqa_flash_decoding_att(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
att_control: AttControl = AttControl(),
alloc_func=torch.empty,
):
from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import (
gqa_token_decode_attention_flash_decoding,
)

if att_control.use_sliding_window:
sliding_window = int(att_control.sliding_window[0])
else:
sliding_window = -1

out = alloc_func(q.shape, q.dtype)

gqa_token_decode_attention_flash_decoding(
@@ -178,6 +199,7 @@
cache_v=v,
out=out,
alloc_tensor_func=alloc_func,
sliding_window=sliding_window,
)

return out
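
For reference, the sliding_window value plumbed into both kernel calls above uses -1 as a "disabled" sentinel; a positive value means each query token may only attend to the most recent sliding_window keys on top of the causal mask. A minimal plain-PyTorch sketch of that masking convention follows, as a standalone illustration rather than the Triton kernels' actual implementation:

import torch

def sliding_window_mask(q_len: int, k_len: int, sliding_window: int = -1) -> torch.Tensor:
    # Boolean mask (True = may attend) for causal attention with an optional window.
    # sliding_window = -1 keeps full causal attention, matching the sentinel above.
    q_pos = torch.arange(q_len).unsqueeze(1) + (k_len - q_len)  # absolute query positions
    k_pos = torch.arange(k_len).unsqueeze(0)
    mask = k_pos <= q_pos  # causal: keys up to and including the query position
    if sliding_window > 0:
        mask &= k_pos > (q_pos - sliding_window)  # keep only the last `sliding_window` keys
    return mask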
@@ -103,11 +103,13 @@ def _context_attention_wrapper_run(
) -> torch.Tensor:
if torch.cuda.is_current_stream_capturing():
q = q.contiguous()
cache_kv = cache_kv.contiguous()
_q, _cache_kv = (
tensor_to_no_ref_tensor(q),
tensor_to_no_ref_tensor(cache_kv),
)
# cache_kv is None for layers that own no K/V slot (e.g. gemma4
# KV-shared layers, which read K/V from a prior layer's cache and
# ignore this arg in _context_attention_kernel). Skip the
# graph-input plumbing for it instead of crashing on None.
cache_kv = cache_kv.contiguous() if cache_kv is not None else None
_q = tensor_to_no_ref_tensor(q)
_cache_kv = tensor_to_no_ref_tensor(cache_kv) if cache_kv is not None else None
pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph()
pre_capture_graph.__exit__(None, None, None)

@@ -33,12 +33,14 @@ def __init__(
num_fused_shared_experts: int = 0,
layer_num: int = 0,
network_config: Dict[str, Any] = None,
per_expert_scale_name: str = "",
) -> None:
super().__init__(data_type=data_type)
self.w1_weight_name = gate_proj_name
self.w2_weight_name = down_proj_name
self.w3_weight_name = up_proj_name
self.e_score_correction_bias_name = e_score_correction_bias_name
self.per_expert_scale_name = per_expert_scale_name
self.weight_prefix = weight_prefix
self.layer_num_ = layer_num
self.global_rank_ = get_global_rank()
@@ -130,6 +132,8 @@ def experts(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
per_expert_scale: Optional[torch.Tensor] = None,
use_gelu: bool = False,
) -> torch.Tensor:
"""Backward compatible method that routes to platform-specific implementation."""
return self.fuse_moe_impl(
@@ -145,6 +149,8 @@
topk_group=topk_group,
num_expert_group=num_expert_group,
is_prefill=is_prefill,
per_expert_scale=per_expert_scale,
use_gelu=use_gelu,
)

def low_latency_dispatch(
@@ -263,23 +269,36 @@ def load_hf_weights(self, weights):
# Load bias
if self.e_score_correction_bias_name in weights:
self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name])
self._load_per_expert_scale(weights)
self._load_weight(self.expert_idx_to_local_idx, weights)
if self.redundancy_expert_num > 0:
self._load_weight(self.redundancy_expert_idx_to_local_idx, weights)

def verify_load(self):
return all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list)
weight_load_ok = all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list)
per_expert_scale_load_ok = (
True if self.per_expert_scale is None else getattr(self.per_expert_scale, "load_ok", False)
)
return weight_load_ok and per_expert_scale_load_ok

def _create_weight(self):
intermediate_size = self.split_inter_size
self.e_score_correction_bias = None
self.per_expert_scale = None
# Create e_score_correction_bias
if self.e_score_correction_bias_name:
self.e_score_correction_bias = torch.empty(
(self.n_routed_experts,),
dtype=self.data_type_,
device=f"cuda:{self.device_id_}",
)
if self.per_expert_scale_name:
self.per_expert_scale = torch.empty(
(self.n_routed_experts,),
dtype=torch.float32,
device=f"cuda:{self.device_id_}",
)
self.per_expert_scale.load_ok = False

self.w13, w13_param_list = self.quant_method.create_moe_weight(
out_dims=[intermediate_size, intermediate_size],
@@ -299,6 +318,11 @@ def _create_weight(self):
self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1])
self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2)

def _load_per_expert_scale(self, weights: Dict[str, torch.Tensor]):
if self.per_expert_scale_name and self.per_expert_scale_name in weights:
self.per_expert_scale.copy_(weights[self.per_expert_scale_name].to(self.per_expert_scale.dtype))
self.per_expert_scale.load_ok = True

def _get_expert_weight_list(self, weight_pack: WeightPack):
weight_list = []
for idx in range(self.local_n_routed_experts):
@@ -307,7 +331,6 @@ def _get_expert_weight_list(self, weight_pack: WeightPack):
return weight_list

def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]):

# Load each expert with TP slicing
for expert_idx, local_expert_idx in expert_idx_to_local_idx.items():
with self.lock:
@@ -0,0 +1,34 @@
from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight


class Gemma4PackedFusedMoeWeight(FusedMoeWeight):
def load_hf_weights(self, weights):
gate_up_name = f"{self.weight_prefix}.gate_up_proj"
down_name = f"{self.weight_prefix}.down_proj"
if gate_up_name not in weights and down_name not in weights and self.per_expert_scale_name not in weights:
return super().load_hf_weights(weights)

assert self.quant_method.method_name == "none", "Gemma-4 packed MoE currently supports bf16/no-quant weights."
assert not self.enable_ep_moe, "Gemma-4 packed MoE currently supports TP mode only."

start = self.split_inter_size * self.tp_rank_
end = self.split_inter_size * (self.tp_rank_ + 1)
moe_intermediate_size = self.moe_intermediate_size

if gate_up_name in weights:
gate_up_weight = weights[gate_up_name]
for expert_idx, local_expert_idx in self.expert_idx_to_local_idx.items():
gate_weight = gate_up_weight[expert_idx, start:end, :].contiguous()
up_weight = gate_up_weight[
expert_idx, moe_intermediate_size + start : moe_intermediate_size + end, :
].contiguous()
Comment on lines +21 to +24 (Contributor, severity: medium):
Calling .contiguous() on slices of gate_up_weight creates additional copies of the expert weights in memory during the loading process. If the model has a large number of experts or a high intermediate dimension, this could significantly increase the peak memory usage of the loader. If self.quant_method.load_weight can handle non-contiguous tensors, these calls should be removed.

self.quant_method.load_weight(gate_weight, self.w1_list[local_expert_idx])
self.quant_method.load_weight(up_weight, self.w3_list[local_expert_idx])

if down_name in weights:
down_weight = weights[down_name]
for expert_idx, local_expert_idx in self.expert_idx_to_local_idx.items():
down_weight_slice = down_weight[expert_idx, :, start:end].contiguous()
self.quant_method.load_weight(down_weight_slice, self.w2_list[local_expert_idx])

self._load_per_expert_scale(weights)
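
A minimal sketch of the change the reviewer appears to suggest: dropping the .contiguous() calls so the loader copies from strided views of the packed tensor instead of materializing per-expert copies first. This assumes quant_method.load_weight accepts non-contiguous tensors, which is not verified here.

if gate_up_name in weights:
    gate_up_weight = weights[gate_up_name]
    for expert_idx, local_expert_idx in self.expert_idx_to_local_idx.items():
        # Strided views instead of contiguous copies: peak loader memory stays at
        # roughly the size of the packed tensor, at the cost of strided reads.
        gate_weight = gate_up_weight[expert_idx, start:end, :]
        up_weight = gate_up_weight[expert_idx, moe_intermediate_size + start : moe_intermediate_size + end, :]
        self.quant_method.load_weight(gate_weight, self.w1_list[local_expert_idx])
        self.quant_method.load_weight(up_weight, self.w3_list[local_expert_idx])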
@@ -62,5 +62,7 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
per_expert_scale: Optional[torch.Tensor] = None,
use_gelu: bool = False,
) -> torch.Tensor:
pass
@@ -31,6 +31,7 @@ def _select_experts(
topk_group: int,
num_expert_group: int,
scoring_func: str,
per_expert_scale: Optional[torch.Tensor] = None,
):
"""Select experts and return topk weights and ids."""
from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts
@@ -48,6 +49,8 @@
)
if self.routed_scaling_factor != 1.0:
topk_weights.mul_(self.routed_scaling_factor)
if per_expert_scale is not None:
topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype)
if self.redundancy_expert_num > 0:
redundancy_topk_ids_repair(
topk_ids=topk_ids,
@@ -68,8 +71,8 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
use_gelu: bool = False,
):

w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
use_fp8_w8a8 = self.quant_method.method_name != "none"
@@ -88,6 +91,7 @@
w1_scale=w13_scale,
w2_scale=w2_scale,
previous_event=None, # for overlap
use_gelu=use_gelu,
)
return output

@@ -210,11 +214,20 @@ def masked_group_gemm(
masked_m: torch.Tensor,
dtype: torch.dtype,
expected_m: int,
use_gelu: bool = False,
):
w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
return masked_group_gemm(
recv_x, masked_m, dtype, w13_weight, w13_scale, w2_weight, w2_scale, expected_m=expected_m
recv_x,
masked_m,
dtype,
w13_weight,
w13_scale,
w2_weight,
w2_scale,
expected_m=expected_m,
use_gelu=use_gelu,
)

def prefilled_group_gemm(
Expand All @@ -226,6 +239,7 @@ def prefilled_group_gemm(
w13: WeightPack,
w2: WeightPack,
hidden_dtype=torch.bfloat16,
use_gelu: bool = False,
):
device = recv_x[0].device
w13_weight, w13_scale = w13.weight, w13.weight_scale
@@ -278,7 +292,7 @@ def prefilled_group_gemm(
# TODO fused kernel
silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype)

silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out)
silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out, use_gelu=use_gelu)
qsilu_out, qsilu_out_scale = per_token_group_quant_fp8(
silu_out, block_size, dtype=w13_weight.dtype, column_major_scales=True, scale_tma_aligned=True
)
@@ -298,7 +312,7 @@
if Autotuner.is_autotune_warmup():
_gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
_silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)
silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out, use_gelu=use_gelu)
_gemm_out_a, _silu_out = None, None

return gather_out
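
For context on the use_gelu flag threaded into silu_and_mul_fwd above, here is a plain-PyTorch sketch of the expected gated-activation semantics: the gate half of the GEMM output is activated and multiplied by the up-projection half. The first-half/second-half layout and the tanh-approximate GELU are assumptions based on Gemma-style MLPs, not taken from the kernel itself.

import torch
import torch.nn.functional as F

def act_and_mul_reference(x: torch.Tensor, use_gelu: bool = False) -> torch.Tensor:
    # x has shape (tokens, 2*d): gate output in the first half, up-projection in the second.
    gate, up = x.chunk(2, dim=-1)
    act = F.gelu(gate, approximate="tanh") if use_gelu else F.silu(gate)
    return act * up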
@@ -29,7 +29,9 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = None,
use_gelu: bool = False,
):
assert not use_gelu, "FuseMoeMarlin does not support GELU expert activation."

w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point
w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point
@@ -42,6 +42,7 @@ def _select_experts(
topk_group: int,
num_expert_group: int,
scoring_func: str,
per_expert_scale: Optional[torch.Tensor] = None,
):
"""Select experts and return topk weights and ids."""
from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts
@@ -59,6 +60,8 @@
)
if self.routed_scaling_factor != 1.0:
topk_weights.mul_(self.routed_scaling_factor)
if per_expert_scale is not None:
topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype)
if self.num_fused_shared_experts > 0:
pad_topk_ids = (
torch.arange(
@@ -91,6 +94,7 @@ def _fused_experts(
topk_ids: torch.Tensor,
router_logits: Optional[torch.Tensor] = None,
is_prefill: bool = False,
use_gelu: bool = False,
):
w13_weight, w13_scale = w13.weight, w13.weight_scale
w2_weight, w2_scale = w2.weight, w2.weight_scale
@@ -108,6 +112,7 @@
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w13_scale,
w2_scale=w2_scale,
use_gelu=use_gelu,
)
return input_tensor

@@ -125,6 +130,8 @@ def __call__(
topk_group: int,
num_expert_group: int,
is_prefill: Optional[bool] = None,
per_expert_scale: Optional[torch.Tensor] = None,
use_gelu: bool = False,
):
topk_weights, topk_ids = self._select_experts(
input_tensor=input_tensor,
@@ -136,6 +143,7 @@
topk_group=topk_group,
num_expert_group=num_expert_group,
scoring_func=scoring_func,
per_expert_scale=per_expert_scale,
)
output = self._fused_experts(
input_tensor=input_tensor,
@@ -145,5 +153,6 @@
topk_ids=topk_ids,
router_logits=router_logits,
is_prefill=is_prefill,
use_gelu=use_gelu,
)
return output
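
As a small self-contained illustration of the per_expert_scale handling added to _select_experts in both backends above (tensor shapes here are invented for the example): each routed expert carries one scalar scale, which is gathered by expert id and multiplied into the top-k routing weights.

import torch

per_expert_scale = torch.tensor([1.0, 0.5, 2.0, 1.0])   # one scale per routed expert
topk_ids = torch.tensor([[2, 0], [1, 3]])                # (num_tokens, topk) selected expert ids
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4]])    # routing weights from softmax/sigmoid

scaled = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype)
# scaled == tensor([[1.4, 0.3], [0.3, 0.4]])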