From 21c3eebecf47d071be75fe1c0646617ba23fd403 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 30 Apr 2026 02:47:40 +0000 Subject: [PATCH 01/20] support 31B --- .../gqa_flash_decoding_stage1.py | 2 +- .../gqa_flash_decoding_stage2.py | 2 +- .../context_flashattention_nopad.py | 27 +- lightllm/models/__init__.py | 1 + lightllm/models/gemma4/__init__.py | 0 lightllm/models/gemma4/infer_struct.py | 33 ++ .../models/gemma4/layer_infer/__init__.py | 0 .../gemma4/layer_infer/post_layer_infer.py | 22 ++ .../gemma4/layer_infer/pre_layer_infer.py | 24 ++ .../layer_infer/transformer_layer_infer.py | 309 ++++++++++++++++++ .../models/gemma4/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 36 ++ .../layer_weights/transformer_layer_weight.py | 182 +++++++++++ lightllm/models/gemma4/model.py | 235 +++++++++++++ lightllm/server/tokenizer.py | 3 + 15 files changed, 871 insertions(+), 5 deletions(-) create mode 100644 lightllm/models/gemma4/__init__.py create mode 100644 lightllm/models/gemma4/infer_struct.py create mode 100644 lightllm/models/gemma4/layer_infer/__init__.py create mode 100644 lightllm/models/gemma4/layer_infer/post_layer_infer.py create mode 100644 lightllm/models/gemma4/layer_infer/pre_layer_infer.py create mode 100644 lightllm/models/gemma4/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/gemma4/layer_weights/__init__.py create mode 100644 lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/gemma4/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/gemma4/model.py diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index 339088e753..eab25f9757 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -185,7 +185,7 @@ def flash_decode_stage1( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256, 512} sm_scale = 1.0 / (Lk ** 0.5) batch, kv_head_num = B_req_idx.shape[0], k.shape[1] block_num = mid_out.shape[2] diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py index 4eff53c3ac..3abc7dc93b 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py @@ -56,7 +56,7 @@ def _fwd_kernel_flash_decode_stage2( @torch.no_grad() def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128} + assert Lk in {16, 32, 64, 128, 256, 512} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) block_num = mid_out.shape[2] diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index 5ba6d0beb6..dab01ddf18 100644 --- 
a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -127,7 +127,14 @@ def context_attention_fwd( # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} + # Larger head_dim needs smaller tiles to fit in SM shared memory. + # H100/H200 has ~228KB shared memory per SM; a 128x512 bf16 tile already + # consumes 128KB, leaving no room for K/V/scores buffers. + if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634, # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。 @@ -291,7 +298,14 @@ def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, ma # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} + # Larger head_dim needs smaller tiles to fit in SM shared memory. + # H100/H200 has ~228KB shared memory per SM; a 128x512 bf16 tile already + # consumes 128KB, leaving no room for K/V/scores buffers. + if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634, # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。 @@ -463,7 +477,14 @@ def context_attention_fwd_contiguous_kv( # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} + # Larger head_dim needs smaller tiles to fit in SM shared memory. + # H100/H200 has ~228KB shared memory per SM; a 128x512 bf16 tile already + # consumes 128KB, leaving no room for K/V/scores buffers. 
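+    # Rough tile-size arithmetic (illustrative): a BLOCK_M x Lk bf16 Q tile is
+    # BLOCK_M * Lk * 2 bytes, so 128 x 512 is ~128 KiB while 32 x 512 is ~32 KiB,
+    # which leaves shared memory for the K/V tiles and the score block.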
+ if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634, # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。 diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 2caee91709..f619b1d88f 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -33,6 +33,7 @@ from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel from lightllm.models.gemma3.model import Gemma3TpPartModel +from lightllm.models.gemma4.model import Gemma4TpPartModel from lightllm.models.tarsier2.model import ( Tarsier2Qwen2TpPartModel, Tarsier2Qwen2VLTpPartModel, diff --git a/lightllm/models/gemma4/__init__.py b/lightllm/models/gemma4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/gemma4/infer_struct.py b/lightllm/models/gemma4/infer_struct.py new file mode 100644 index 0000000000..686118346d --- /dev/null +++ b/lightllm/models/gemma4/infer_struct.py @@ -0,0 +1,33 @@ +import torch +from lightllm.common.basemodel import InferStateInfo + + +class Gemma4InferStateInfo(InferStateInfo): + def __init__(self): + super().__init__() + # Gemma-4 uses two RoPE frequency tables (one per layer type): + # * sliding_attention layers: theta=10000, full rotation over head_dim=256 + # * full_attention layers: theta=1_000_000, partial rotation (first 25% of head_dim=512) + self.position_cos_sliding = None + self.position_sin_sliding = None + self.position_cos_full = None + self.position_sin_full = None + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) + position_ids = self.position_ids + self.position_cos_sliding = torch.index_select(model._cos_cached_sliding, 0, position_ids).view( + position_ids.shape[0], -1 + ) + self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, position_ids).view( + position_ids.shape[0], -1 + ) + self.position_cos_full = torch.index_select(model._cos_cached_full, 0, position_ids).view( + position_ids.shape[0], -1 + ) + self.position_sin_full = torch.index_select(model._sin_cached_full, 0, position_ids).view( + position_ids.shape[0], -1 + ) + if self.is_prefill: + self.max_seq_len = self.max_kv_seq_len + return diff --git a/lightllm/models/gemma4/layer_infer/__init__.py b/lightllm/models/gemma4/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/gemma4/layer_infer/post_layer_infer.py b/lightllm/models/gemma4/layer_infer/post_layer_infer.py new file mode 100644 index 0000000000..3b3423d645 --- /dev/null +++ b/lightllm/models/gemma4/layer_infer/post_layer_infer.py @@ -0,0 +1,22 @@ +import torch +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer + + +class Gemma4PostLayerInfer(LlamaPostLayerInfer): + """ + Same final RMSNorm + tied lm_head path as Llama, with an extra tanh-based + logit softcap at the end: logits = softcap * tanh(logits / softcap). 
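+    For illustration (softcap value assumed, not taken from the checkpoint):
+    with softcap = 30 a raw logit of 100 becomes 30 * tanh(100 / 30) ~= 29.9,
+    so logits are smoothly bounded to (-softcap, softcap).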
+ """ + + def __init__(self, network_config): + super().__init__(network_config) + self.eps_ = 1e-6 + self.final_logit_softcapping = network_config.get("final_logit_softcapping", None) + + def token_forward(self, input_embdings, infer_state, layer_weight): + logits = super().token_forward(input_embdings, infer_state, layer_weight) + if self.final_logit_softcapping is not None and self.final_logit_softcapping > 0: + cap = float(self.final_logit_softcapping) + # logits are fp32 already (LlamaPostLayerInfer allocates the output in fp32) + logits = torch.tanh(logits / cap) * cap + return logits diff --git a/lightllm/models/gemma4/layer_infer/pre_layer_infer.py b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py new file mode 100644 index 0000000000..4771e4b1e1 --- /dev/null +++ b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py @@ -0,0 +1,24 @@ +import torch +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer + + +class Gemma4PreLayerInfer(LlamaPreLayerInfer): + """ + Text-only pre-layer for Gemma-4 (Phase A). Applies the Gemma embedding + scale (sqrt(hidden_size)) to the token embeddings. Multimodal embed-scatter + handling will be added alongside the vision tower port. + """ + + def __init__(self, network_config): + super().__init__(network_config) + self.embed_scale = float(network_config["hidden_size"]) ** 0.5 + + def context_forward(self, input_ids, infer_state, layer_weight): + input_embdings = super().context_forward(input_ids, infer_state, layer_weight) + input_dtype = input_embdings.dtype + return (input_embdings.float() * self.embed_scale).to(input_dtype) + + def token_forward(self, input_ids, infer_state, layer_weight): + input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + input_dtype = input_embdings.dtype + return (input_embdings.float() * self.embed_scale).to(input_dtype) diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..679e3e5a5d --- /dev/null +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -0,0 +1,309 @@ +import math +import torch +import torch.nn as nn + +from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd + + +class Gemma4TransformerLayerInfer(LlamaTransformerLayerInfer): + """ + Gemma-4 decoder block. Per-layer heterogeneity (sliding vs full attention) + is handled by switching shape / RoPE table / sliding-window flag at init + time. The KV cache layout is uniform (sliding shape: num_kv_heads=16, + head_dim=256); full-attention layers pack their (4, 512) tensor into the + first 8 heads of the 16-head slot at cache-write time, then reshape on + read. See Gemma4TpPartModel._init_mem_manager for context. 
+ """ + + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.eps_ = 1e-6 + self.embed_dim_ = network_config["hidden_size"] + + layer_type = network_config["layer_types"][layer_num] + self.is_sliding = layer_type == "sliding_attention" + + if self.is_sliding: + self.layer_head_dim_ = network_config["head_dim"] + total_kv_heads = network_config["num_key_value_heads"] + self.k_eq_v = False + else: + self.layer_head_dim_ = network_config["global_head_dim"] + total_kv_heads = network_config["num_global_key_value_heads"] + self.k_eq_v = network_config.get("attention_k_eq_v", True) + + # TP shard counts for this layer + self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = max(total_kv_heads // self.tp_world_size_, 1) + self.tp_v_head_num_ = self.tp_k_head_num_ + self.tp_o_head_num_ = self.tp_q_head_num_ + + # Uniform mem-manager layout (sliding shape per rank) + self.mm_head_dim_ = network_config["head_dim"] + self.mm_kv_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_ + + # Sliding window (None on full-attn layers) + if self.is_sliding: + sw = network_config.get("sliding_window", 0) + self.sliding_window_ = int(sw) if sw else 0 + else: + self.sliding_window_ = 0 + + # Partial rotary factor for the RoPE kernel. The sliding table is sized + # (seq, head_dim/2) so full rotation over head_dim is the default. + # The full table is sized (seq, global_head_dim/2) with zero-padded + # frequencies (proportional RoPE) — we still pass partial_rotary_factor=1 + # to the kernel so it walks every pair, applying identity for the zeroed + # frequencies. + self.rotary_partial_factor_ = 1.0 + + def _bind_func(self): + # Skip LlamaTransformerLayerInfer._bind_norm (it rebinds to Llama _att_norm / _ffn_norm); + # we want our own gemma-style norm implementations below. + return + + # ----- norms --------------------------------------------------------- + + def _att_norm( + self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ) -> torch.Tensor: + return layer_weight.att_norm_weight_( + input=input, eps=self.eps_, alloc_func=self.alloc_tensor + ) + + def _ffn_norm( + self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ) -> torch.Tensor: + # NOTE: gemma packs post_attention_layernorm under `ffn_norm_weight_` + return layer_weight.ffn_norm_weight_( + input=input, eps=self.eps_, alloc_func=self.alloc_tensor + ) + + # ----- QKV + attention --------------------------------------------- + + def _get_qkv( + self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ) -> torch.Tensor: + input = self._tpsp_allgather(input=input, infer_state=infer_state) + + head_dim = self.layer_head_dim_ + q_heads = self.tp_q_head_num_ + kv_heads = self.tp_k_head_num_ + + q = layer_weight.q_proj.mm(input).view(-1, q_heads, head_dim) + k = layer_weight.k_proj.mm(input).view(-1, kv_heads, head_dim) + if self.k_eq_v: + # Full-attn layers share K weights for V. + v = k.clone() + else: + v = layer_weight.v_proj.mm(input).view(-1, kv_heads, head_dim) + + # QK RMSNorm (learnable weight, Gemma-style `(1+w)` applied in fp32). + # Reshape to 2D (N*heads, head_dim) so NoTpGEMMANormWeight accepts it. 
+ q_flat = q.reshape(-1, head_dim).float() + k_flat = k.reshape(-1, head_dim).float() + q_flat = layer_weight.q_norm_weight_(input=q_flat, eps=self.eps_, alloc_func=self.alloc_tensor) + k_flat = layer_weight.k_norm_weight_(input=k_flat, eps=self.eps_, alloc_func=self.alloc_tensor) + q = q_flat.view(-1, q_heads, head_dim).to(input.dtype) + k = k_flat.view(-1, kv_heads, head_dim).to(input.dtype) + + # V-norm: unweighted RMSNorm over head_dim (matches vllm's Gemma4 has_weight=False). + v_fp = v.float() + v_fp = v_fp * torch.rsqrt(v_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps_) + v = v_fp.to(input.dtype) + + # Per-layer RoPE + if self.is_sliding: + cos = infer_state.position_cos_sliding.to(q.dtype) + sin = infer_state.position_sin_sliding.to(q.dtype) + else: + cos = infer_state.position_cos_full.to(q.dtype) + sin = infer_state.position_sin_full.to(q.dtype) + rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=self.rotary_partial_factor_) + + # Gemma-4 uses scaling=1.0 in attention. The attention kernel hardcodes + # sm_scale = 1/sqrt(head_dim); pre-scale Q by sqrt(head_dim) so the + # kernel's division cancels out, yielding scores = Q @ K^T. + q = q * math.sqrt(head_dim) + + # Pack into the uniform mem-manager layout. + mm_heads = self.mm_kv_head_num_ + mm_dim = self.mm_head_dim_ + if self.is_sliding: + # (N, 2*mm_heads, mm_dim) with [:mm_heads]=K, [mm_heads:]=V + cache_kv = torch.cat([k, v], dim=1) + else: + # K,V shape (N, kv_heads, layer_head_dim) e.g. (N, 2, 512) on tp=2. + # Reshape each half to (N, kv_heads*layer_head_dim // mm_dim, mm_dim) e.g. (N, 4, 256) on tp=2. + # The mem-manager layout has (N, 2*mm_heads, mm_dim) = (N, 16, 256) on tp=2 for this + # checkpoint — pad to that shape with zeros on unused head slots. + N = k.shape[0] + k_packed = k.reshape(N, -1, mm_dim) # (N, kv_heads * layer_head_dim // mm_dim, mm_dim) + v_packed = v.reshape(N, -1, mm_dim) + cache_kv = self.alloc_tensor((N, 2 * mm_heads, mm_dim), dtype=k.dtype) + cache_kv.zero_() + k_slots = k_packed.shape[1] + cache_kv[:, :k_slots, :] = k_packed + cache_kv[:, mm_heads : mm_heads + k_slots, :] = v_packed + + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) + + return q, cache_kv + + def _get_o( + self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ) -> torch.Tensor: + if infer_state.need_dp_prefill_balance: + input = infer_state._all_to_all_balance_get(data=input) + input = input.view(-1, self.tp_o_head_num_ * self.layer_head_dim_) + o_tensor = layer_weight.o_proj.mm(input) + o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) + return o_tensor + + # ----- Attention kernels (sliding window + per-layer KV reshape) --- + + def _att_control(self): + # SWA is only safe with FA3 (it consumes window_size per-call). Triton + # backend asserts use_sliding_window is False; lightllm's flashinfer + # wrapper plans once and ignores per-call windows. The flag is set + # by Gemma4TpPartModel._init_att_backend after backend selection. 
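+        # Illustrative note on the window convention: AttControl carries a
+        # (left, right) pair; with causal masking only `left` matters, so
+        # (w, w) with w = sliding_window_ - 1 lets each query attend to itself
+        # plus the previous sliding_window_ - 1 tokens.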
+ if self.is_sliding and self.sliding_window_ > 0 and self.network_config_.get("_gemma4_use_swa", False): + w = self.sliding_window_ - 1 + return AttControl(use_sliding_window=True, sliding_window=(w, w)) + return AttControl(use_sliding_window=False, sliding_window=(-1, -1)) + + def _get_layer_kv(self, infer_state: InferStateInfo): + _k_raw, _v_raw = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + # _k_raw / _v_raw shape (S, mm_heads, mm_dim) + if self.is_sliding: + # sliding K is stored in the full (mm_heads, mm_dim) slot; head count matches. + return _k_raw, _v_raw + # full layer: the real K/V live in the first `kv_heads * layer_head_dim // mm_dim` + # head slots. Reshape to (S, kv_heads, layer_head_dim). + kv_heads = self.tp_k_head_num_ + head_dim = self.layer_head_dim_ + mm_dim = self.mm_head_dim_ + k_slots = kv_heads * head_dim // mm_dim + _k = _k_raw[:, :k_slots, :].reshape(-1, kv_heads, head_dim) + _v = _v_raw[:, :k_slots, :].reshape(-1, kv_heads, head_dim) + return _k, _v + + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: InferStateInfo, + layer_weight: Gemma4TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + _k, _v = self._get_layer_kv(infer_state) + _q = q.view(-1, self.tp_q_head_num_, self.layer_head_dim_) + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor + ) + return o_tensor.view(q.shape) + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: InferStateInfo, + layer_weight: Gemma4TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + _k, _v = self._get_layer_kv(infer_state) + _q = q.view(-1, self.tp_q_head_num_, self.layer_head_dim_) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor + ) + return o_tensor.view(q.shape) + + # ----- FFN (Gemma gelu-tanh, separate gate/up/down) ---------------- + + def _ffn( + self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + gate = layer_weight.gate_proj.mm(input) + up = layer_weight.up_proj.mm(input) + ffn1 = nn.functional.gelu(gate, approximate="tanh") * up + gate = None + up = None + ffn2 = layer_weight.down_proj.mm(ffn1) + ffn1 = None + ffn2 = self._tpsp_reduce(input=ffn2, infer_state=infer_state) + return ffn2 + + # ----- block-level forwards (add layer_scalar at the end) ---------- + + def _apply_layer_scalar(self, hidden_states, layer_weight): + hidden_states.mul_(layer_weight.layer_scalar_.weight) + return hidden_states + + def context_forward( + self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ): + input_embdings = input_embdings.to(torch.bfloat16) + + # attn sub-block + input1 = self._att_norm( + input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight + ).to(torch.bfloat16) + q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) + input1 = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) + q = None + o = self._get_o(o, infer_state, layer_weight) + o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + # ffn sub-block + input1 = 
layer_weight.pre_feedforward_layernorm_weight_( + input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + ffn_out = layer_weight.post_feedforward_layernorm_weight_( + input=ffn_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + + return self._apply_layer_scalar(input_embdings, layer_weight) + + def token_forward( + self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ): + input_embdings = input_embdings.to(torch.bfloat16) + + input1 = self._att_norm( + input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight + ).to(torch.bfloat16) + q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) + input1 = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._token_attention_kernel(q, infer_state, layer_weight) + q = None + o = self._get_o(o, infer_state, layer_weight) + o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input1 = layer_weight.pre_feedforward_layernorm_weight_( + input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + ffn_out = self._ffn(input1, infer_state, layer_weight) + input1 = None + ffn_out = layer_weight.post_feedforward_layernorm_weight_( + input=ffn_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + + return self._apply_layer_scalar(input_embdings, layer_weight) diff --git a/lightllm/models/gemma4/layer_weights/__init__.py b/lightllm/models/gemma4/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..a767e70c10 --- /dev/null +++ b/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,36 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + RMSNormWeight, +) + + +class Gemma4PreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + + self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.language_model.embed_tokens.weight", + data_type=self.data_type_, + ) + # lm_head is tied to input embedding for Gemma-4 (no separate lm_head.weight). + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_, + ) + + # Gemma-4 uses standard RMSNorm (not the gemma2/3 (1+w) variant). 
+ self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name="model.language_model.norm.weight", + data_type=self.data_type_, + ) + return diff --git a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..f9a59f7424 --- /dev/null +++ b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py @@ -0,0 +1,182 @@ +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight, COLMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights import RMSNormWeight, ParameterWeight +from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight + + +class Gemma4TransformerLayerWeight(LlamaTransformerLayerWeight): + def __init__( + self, + layer_num, + data_type, + network_config, + quant_cfg=None, + ): + self._pre_parse_layer_shape(layer_num, network_config) + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _pre_parse_layer_shape(self, layer_num, network_config): + layer_type = network_config["layer_types"][layer_num] + self._is_sliding = layer_type == "sliding_attention" + if self._is_sliding: + self._layer_head_dim = network_config["head_dim"] + self._layer_kv_head_num = network_config["num_key_value_heads"] + self._layer_k_eq_v = False + else: + self._layer_head_dim = network_config["global_head_dim"] + self._layer_kv_head_num = network_config["num_global_key_value_heads"] + self._layer_k_eq_v = network_config.get("attention_k_eq_v", True) + + def _parse_config(self): + self.n_head = self.network_config_["num_attention_heads"] + self.q_head_num_ = self.network_config_["num_attention_heads"] + self.k_head_num_ = self._layer_kv_head_num + self.v_head_num_ = self._layer_kv_head_num + self.o_head_num_ = self.q_head_num_ + self.head_dim = self._layer_head_dim + self.n_embed = self.network_config_["hidden_size"] + self.n_inter = self.network_config_["intermediate_size"] + + def _init_weight_names(self): + prefix = f"model.language_model.layers.{self.layer_num_}" + self._q_weight_name = f"{prefix}.self_attn.q_proj.weight" + self._q_bias_name = None + self._k_weight_name = f"{prefix}.self_attn.k_proj.weight" + self._k_bias_name = None + self._v_weight_name = f"{prefix}.self_attn.v_proj.weight" + self._v_bias_name = None + self._o_weight_name = f"{prefix}.self_attn.o_proj.weight" + self._o_bias_name = None + + self._q_norm_weight_name = f"{prefix}.self_attn.q_norm.weight" + self._k_norm_weight_name = f"{prefix}.self_attn.k_norm.weight" + + self._gate_weight_name = f"{prefix}.mlp.gate_proj.weight" + self._up_weight_name = f"{prefix}.mlp.up_proj.weight" + self._down_weight_name = f"{prefix}.mlp.down_proj.weight" + + self._att_norm_weight_name = f"{prefix}.input_layernorm.weight" + self._ffn_norm_weight_name = f"{prefix}.post_attention_layernorm.weight" + self._pre_feedforward_layernorm_name = f"{prefix}.pre_feedforward_layernorm.weight" + self._post_feedforward_layernorm_name = f"{prefix}.post_feedforward_layernorm.weight" + + self._layer_scalar_name = f"{prefix}.layer_scalar" + + def _init_weight(self): + self._init_qkv() + self._init_o() + self._init_ffn() + self._init_norm() + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + kv_out_dim = self.k_head_num_ * self.head_dim + + self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=self._q_weight_name, + data_type=self.data_type_, + 
bias_names=self._q_bias_name, + quant_method=self.get_quant_method("q_proj"), + ) + self.k_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[kv_out_dim], + weight_names=self._k_weight_name, + data_type=self.data_type_, + bias_names=self._k_bias_name, + quant_method=self.get_quant_method("k_proj"), + ) + if not self._layer_k_eq_v: + self.v_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[kv_out_dim], + weight_names=self._v_weight_name, + data_type=self.data_type_, + bias_names=self._v_bias_name, + quant_method=self.get_quant_method("v_proj"), + ) + # For k_eq_v layers HF checkpoint has no v_proj weight; the inference + # code aliases v = k at compute time, so no weight object is created. + + def _init_o(self): + in_dim = self.o_head_num_ * self.head_dim + out_dim = self.n_embed + self.o_proj = COLMMWeight( + in_dim=in_dim, + out_dims=[out_dim], + weight_names=self._o_weight_name, + data_type=self.data_type_, + bias_names=self._o_bias_name, + quant_method=self.get_quant_method("o_proj"), + ) + + def _init_ffn(self): + self.gate_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], + weight_names=self._gate_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("gate_proj"), + ) + self.up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter], + weight_names=self._up_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("up_proj"), + ) + self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], + weight_names=self._down_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("down_proj"), + ) + + def _init_norm(self): + hidden_size = self.network_config_["hidden_size"] + # Gemma-4 uses *standard* RMSNorm (x * rsqrt(var+eps) * w), NOT the + # gemma2/3 (1+w) variant. Using NoTpGEMMANormWeight here produces + # nothing but high-frequency-token gibberish ("de la de..."). 
+ self.q_norm_weight_ = RMSNormWeight( + dim=self._layer_head_dim, + weight_name=self._q_norm_weight_name, + data_type=self.data_type_, + ) + self.k_norm_weight_ = RMSNormWeight( + dim=self._layer_head_dim, + weight_name=self._k_norm_weight_name, + data_type=self.data_type_, + ) + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + self.pre_feedforward_layernorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._pre_feedforward_layernorm_name, + data_type=self.data_type_, + ) + self.post_feedforward_layernorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_name, + data_type=self.data_type_, + ) + # scalar multiplier applied to the attention output + self.layer_scalar_ = ParameterWeight( + weight_name=self._layer_scalar_name, + data_type=self.data_type_, + weight_shape=(1,), + ) diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py new file mode 100644 index 0000000000..7b80d51218 --- /dev/null +++ b/lightllm/models/gemma4/model.py @@ -0,0 +1,235 @@ +import os +import json +import torch +from transformers import AutoConfig +from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.build_utils import repair_config +from lightllm.models.llama.model import LlamaTpPartModel +from lightllm.models.gemma4.infer_struct import Gemma4InferStateInfo +from lightllm.models.gemma4.layer_infer.pre_layer_infer import Gemma4PreLayerInfer +from lightllm.models.gemma4.layer_infer.post_layer_infer import Gemma4PostLayerInfer +from lightllm.models.gemma4.layer_infer.transformer_layer_infer import Gemma4TransformerLayerInfer +from lightllm.models.gemma4.layer_weights.pre_and_post_layer_weight import Gemma4PreAndPostLayerWeight +from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num, get_env_start_args +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Gemma4Tokenizer(BaseMultiModalTokenizer): + """ + Thin wrapper; Phase-A milestone only exercises the text path. Multimodal + splice logic will be added alongside the vision tower port (Phase B). + """ + + def __init__(self, tokenizer, model_cfg): + super().__init__(tokenizer) + self.image_token_index = model_cfg.get("image_token_id", 258880) + self.boi_token_index = model_cfg.get("boi_token_id", 255999) + self.eoi_token_index = model_cfg.get("eoi_token_id", 258882) + self.image_length = model_cfg.get("vision_soft_tokens_per_image", 280) + # Gemma-4's tokenizer ships with `add_bos_token=False`, and even + # `add_special_tokens=True` doesn't prepend ``. The model generates + # garbage without it, so we always prepend manually. 
+ self.bos_token_id = tokenizer.bos_token_id + + def init_imageitem_extral_params(self, img, multi_params, sampling_params): + return + + def init_audioitem_extral_params(self, audio, multi_params, sampling_params): + raise NotImplementedError + + def get_image_token_length(self, img): + return self.image_length + + def get_audio_token_length(self, audio): + raise NotImplementedError + + def encode(self, prompt, multimodal_params=None, add_special_tokens=False): + # Text-only path for Phase A — reject image/audio input loudly so users + # know multimodal isn't wired yet. + if multimodal_params is not None and ( + getattr(multimodal_params, "images", None) or getattr(multimodal_params, "audios", None) + ): + raise NotImplementedError( + "Gemma-4 multimodal (image/audio) inference is not yet implemented in LightLLM; " + "only text prompts are supported for now." + ) + input_ids = self.tokenizer(prompt).input_ids + # Auto-prepend for prompts (Gemma-4 generates garbage without it), + # but honour `add_special_tokens=False` so callers like stop-sequence + # encoding can opt out — otherwise stop strings get a leading BOS that + # never appears in generated output and never matches. + if ( + add_special_tokens + and self.bos_token_id is not None + and (len(input_ids) == 0 or input_ids[0] != self.bos_token_id) + ): + input_ids = [self.bos_token_id] + input_ids + return input_ids + + +@ModelRegistry("gemma4", is_multimodal=True) +class Gemma4TpPartModel(LlamaTpPartModel): + pre_and_post_weight_class = Gemma4PreAndPostLayerWeight + transformer_weight_class = Gemma4TransformerLayerWeight + + pre_layer_infer_class = Gemma4PreLayerInfer + transformer_layer_infer_class = Gemma4TransformerLayerInfer + post_layer_infer_class = Gemma4PostLayerInfer + + infer_state_class = Gemma4InferStateInfo + + def __init__(self, kvargs): + # head_dim_ is used by the default _init_to_get_rotary which we + # override; still set it to the sliding-layer head_dim for consistency + # with the mem manager and any generic helpers. + self.head_dim_ = 256 + super().__init__(kvargs) + return + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + self.config = json.load(json_file) + # The shipped checkpoint is a multimodal config wrapping a Gemma4TextConfig + # under text_config; flatten it so downstream code sees text-model fields + # at the top level (mirrors the gemma3 approach). + if "text_config" in self.config: + hf_config = AutoConfig.from_pretrained(self.weight_dir_, trust_remote_code=True) + self.config = hf_config.text_config.to_dict() + + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + return + + def _verify_params(self): + assert self.load_way == "HF", "Gemma-4 only supports HF format." + assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + assert self.config["num_key_value_heads"] % self.tp_world_size_ == 0 + num_global_kv = self.config.get("num_global_key_value_heads", self.config["num_key_value_heads"]) + assert num_global_kv % self.tp_world_size_ == 0, ( + f"num_global_key_value_heads={num_global_kv} must be divisible by tp={self.tp_world_size_}" + ) + return + + def _init_mem_manager(self): + # Uniform per-layer KV cache layout keyed to the *sliding* attention shape + # (num_kv_heads=16, head_dim=256). 
Full-attention layers (num_kv_heads=4, + # head_dim=512, k_eq_v) reuse the same byte budget at <=50% utilization; + # the transformer-layer infer code handles the reshape when reading back. + head_num_per_rank = self.config["num_key_value_heads"] // self.tp_world_size_ + head_dim = self.config["head_dim"] + self.mem_manager = select_mem_manager_class()( + self.max_total_token_num, + dtype=self.data_type, + head_num=head_num_per_rank, + head_dim=head_dim, + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), + mem_fraction=self.mem_fraction, + ) + return + + def _init_att_backend(self): + # Gemma-4 has per-layer heterogeneous attention shape (sliding layers + # use head_dim=256/16 KV heads, full-attn layers use head_dim=512/4). + # The flashinfer backend in this repo plans once per infer_state with + # a single (head_dim, num_kv_heads), so it crashes / silently produces + # wrong results on the layer where the shape doesn't match. FA3 reads + # head_dim and num_kv_heads from the per-call tensor shapes, so it + # supports the heterogeneous layout AND honours per-call sliding window + # — which is what we want on sliding layers. + from lightllm.common.basemodel.attention.fa3.fp import Fa3AttBackend + from lightllm.utils.sgl_utils import flash_attn_with_kvcache + + fa3_loadable = flash_attn_with_kvcache is not None + args = get_env_start_args() + backends = set(args.llm_prefill_att_backend + args.llm_decode_att_backend) + for backend_name in backends: + assert backend_name in ("auto", "triton", "fa3"), ( + "Gemma-4 requires triton or fa3 (per-layer dynamic head_dim / " + "num_kv_heads); flashinfer is not wired for the heterogeneous " + f"layout. Got --llm_*_att_backend={backend_name!r}." + ) + if "fa3" in backends: + assert fa3_loadable, ( + "Requested --llm_*_att_backend=fa3 but neither sgl_kernel nor " + "flash_attn_3 (flash_attn_interface) imported successfully. " + "Build flash-attention/hopper from source against the current torch." + ) + # Default policy: prefer FA3 if available (gets us real sliding-window + # attention on sliding layers); fall back to triton otherwise. + prefer_fa3 = fa3_loadable and (backends <= {"auto", "fa3"}) + if prefer_fa3: + self.prefill_att_backend = Fa3AttBackend(model=self) + self.decode_att_backend = Fa3AttBackend(model=self) + self.config["_gemma4_use_swa"] = True + else: + self.prefill_att_backend = TritonAttBackend(model=self) + self.decode_att_backend = TritonAttBackend(model=self) + self.config["_gemma4_use_swa"] = False + + def _init_custom(self): + self._init_to_get_rotary_gemma4() + + def _init_to_get_rotary_gemma4(self): + rope_params = self.config["rope_parameters"] + + # Cap the rotary table at something we can fit in memory — Gemma-4's + # advertised max_position_embeddings is 262144 which would require + # ~200MB per table in fp32. Rely on the server's max_seq_length instead. + max_seq_len = max(self.max_seq_length + 1024, 16384) + + t = torch.arange(max_seq_len, dtype=torch.float32, device="cpu") + + # Sliding layers: default RoPE, theta=10000, full rotation over head_dim=256. 
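+        # For reference, the table construction used below: inv_freq[i] =
+        # theta ** (-2*i / rot_dim) for i in [0, rot_dim/2), so the cached
+        # cos/sin tables have shape (max_seq_len, rot_dim/2).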
+ sliding_params = rope_params["sliding_attention"] + sliding_head_dim = self.config["head_dim"] + sliding_theta = sliding_params["rope_theta"] + sliding_partial = sliding_params.get("partial_rotary_factor", 1.0) + sliding_rot_dim = int(sliding_head_dim * sliding_partial) + inv_freq_sliding = 1.0 / ( + sliding_theta ** (torch.arange(0, sliding_rot_dim, 2, dtype=torch.float32) / sliding_rot_dim) + ) + freqs_s = torch.outer(t, inv_freq_sliding) + self._cos_cached_sliding = torch.cos(freqs_s).to(self.data_type).cuda() + self._sin_cached_sliding = torch.sin(freqs_s).to(self.data_type).cuda() + + # Full-attention layers: proportional RoPE, theta=1_000_000, + # partial_rotary_factor=0.25 over global_head_dim=512. + # Proportional semantics (HF transformers): + # rope_angles = int(partial * head_dim // 2) -> 64 + # inv_freq[0:rope_angles] = 1 / base ** (arange(0, 2*rope_angles, 2) / head_dim) + # inv_freq[rope_angles:head_dim//2] = 0 (identity rotation for "no-pe" dims) + full_params = rope_params["full_attention"] + full_head_dim = self.config["global_head_dim"] + full_theta = full_params["rope_theta"] + full_partial = full_params.get("partial_rotary_factor", 1.0) + rope_type = full_params.get("rope_type", "default") + if rope_type == "proportional": + rope_angles = int(full_partial * full_head_dim // 2) + inv_freq_rot = 1.0 / ( + full_theta + ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.float32) / full_head_dim) + ) + nope_angles = full_head_dim // 2 - rope_angles + if nope_angles > 0: + inv_freq_full = torch.cat( + [inv_freq_rot, torch.zeros(nope_angles, dtype=torch.float32)] + ) + else: + inv_freq_full = inv_freq_rot + else: + full_rot_dim = int(full_head_dim * full_partial) + inv_freq_full = 1.0 / ( + full_theta ** (torch.arange(0, full_rot_dim, 2, dtype=torch.float32) / full_rot_dim) + ) + + freqs_f = torch.outer(t, inv_freq_full) + self._cos_cached_full = torch.cos(freqs_f).to(self.data_type).cuda() + self._sin_cached_full = torch.sin(freqs_f).to(self.data_type).cuda() + return diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 25726b2578..9c9f19e73f 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -31,6 +31,7 @@ from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer from ..models.gemma3.model import Gemma3Tokenizer +from ..models.gemma4.model import Gemma4Tokenizer from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. 
@@ -130,5 +131,7 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + elif model_type == "gemma4": + tokenizer = Gemma4Tokenizer(tokenizer, model_cfg) return tokenizer From 99b790c223742fee03ecdba7f64f9a6d46f0dea5 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 6 May 2026 05:06:25 +0000 Subject: [PATCH 02/20] fix --- .../common/basemodel/attention/triton/fp.py | 39 +++++++++-- .../gqa/flash_decoding/gqa_flash_decoding.py | 9 ++- .../gqa_flash_decoding_stage1.py | 34 +++++++--- .../context_flashattention_nopad.py | 32 +++++++++- .../layer_infer/transformer_layer_infer.py | 17 +++-- lightllm/models/gemma4/model.py | 64 ++++++++++++------- 6 files changed, 149 insertions(+), 46 deletions(-) diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index d29f15ec3b..a8f3c4414b 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -25,12 +25,13 @@ def prefill_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_sliding_window is False and att_control.use_att_sink is False + assert att_control.use_att_sink is False if att_control.use_alibi: + assert att_control.use_sliding_window is False, "alibi + sliding_window not supported" assert att_control.tp_alibi is not None return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: - return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) + return self._nomarl_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) def _alibi_prefill_att( self, @@ -59,9 +60,24 @@ def _alibi_prefill_att( ) return out - def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): + def _nomarl_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd + # Convert AttControl's (left, right) tuple to a single window length: + # we treat the window as `left + 1` (each query attends to itself plus + # the previous `left` tokens — same convention FA3 uses with causal=True). + if att_control.use_sliding_window: + sliding_window = int(att_control.sliding_window[0]) + 1 + else: + sliding_window = 0 + out = alloc_func(q.shape, q.dtype) context_attention_fwd( q, @@ -74,6 +90,7 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, self.infer_state.b_ready_cache_len, self.infer_state.max_q_seq_len, self.infer_state.req_manager.req_to_token_indexs, + sliding_window=sliding_window, ) return out @@ -94,17 +111,22 @@ def decode_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ): - assert att_control.use_sliding_window is False and att_control.use_att_sink is False + assert att_control.use_att_sink is False if att_control.use_alibi: + assert att_control.use_sliding_window is False, "alibi + sliding_window not supported" assert att_control.tp_alibi is not None return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: q_head_num = q.shape[1] k_head_num = k.shape[1] if q_head_num == k_head_num: + # MHA decode path: SWA not yet wired here (only GQA path used by Gemma-4). 
+ assert att_control.use_sliding_window is False, "SWA in MHA triton decode not implemented" return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) elif q_head_num > k_head_num: - return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) + return self._normal_decode_gqa_flash_decoding_att( + q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func + ) else: raise NotImplementedError("error") @@ -163,12 +185,18 @@ def _normal_decode_gqa_flash_decoding_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + att_control: AttControl = AttControl(), alloc_func=torch.empty, ): from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import ( gqa_token_decode_attention_flash_decoding, ) + if att_control.use_sliding_window: + sliding_window = int(att_control.sliding_window[0]) + 1 + else: + sliding_window = 0 + out = alloc_func(q.shape, q.dtype) gqa_token_decode_attention_flash_decoding( @@ -178,6 +206,7 @@ def _normal_decode_gqa_flash_decoding_att( cache_v=v, out=out, alloc_tensor_func=alloc_func, + sliding_window=sliding_window, ) return out diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py index e549298e3b..979e54272c 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py @@ -2,7 +2,13 @@ def gqa_token_decode_attention_flash_decoding( - q: torch.Tensor, infer_state, cache_k: torch.Tensor, cache_v: torch.Tensor, out=None, alloc_tensor_func=torch.empty + q: torch.Tensor, + infer_state, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + out=None, + alloc_tensor_func=torch.empty, + sliding_window: int = 0, ): batch_size = infer_state.batch_size q_head_num, head_dim = q.shape[1], q.shape[2] @@ -39,6 +45,7 @@ def gqa_token_decode_attention_flash_decoding( mid_out=mid_o, mid_out_logsumexp=mid_o_logexpsum, block_seq=BLOCK_SEQ, + sliding_window=sliding_window, ) flash_decode_stage2( mid_out=mid_o, diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index eab25f9757..c5ceb9100b 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -39,6 +39,7 @@ def _fwd_kernel_flash_decode_stage1( BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -50,6 +51,13 @@ def _fwd_kernel_flash_decode_stage1( if block_index >= req_total_block_num: return + # Decode: q is at position cur_batch_seq_len - 1; SWA keeps K at positions + # >= cur_batch_seq_len - SLIDING_WINDOW. win_threshold below. 
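+    # e.g. cur_batch_seq_len=10, SLIDING_WINDOW=4: the query at position 9
+    # attends to K positions 6..9, so win_threshold = 10 - 4 = 6.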
+ if SLIDING_WINDOW > 0: + win_threshold = cur_batch_seq_len - SLIDING_WINDOW + else: + win_threshold = 0 + cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs @@ -76,6 +84,10 @@ def _fwd_kernel_flash_decode_stage1( for start_n in range(0, block_n_size, 1): offs_n_new = start_n * BLOCK_N + offs_n n_mask = offs_n_new < cur_batch_end_index + if SLIDING_WINDOW > 0: + # Drop K positions that fall outside the sliding window from + # the current query (last token in the sequence). + n_mask = n_mask & (offs_n_new >= win_threshold) k_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, mask=n_mask, @@ -110,14 +122,18 @@ def _fwd_kernel_flash_decode_stage1( + offs_d[None, :] ) off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + block_index - tl.store( - Mid_O + off_mid_o, - acc / sum_exp[:, None], - ) - tl.store( - Mid_O_LogExpSum + off_mid_o_logexpsum, - max_logic + tl.log(sum_exp), - ) + if SLIDING_WINDOW > 0: + # When SWA masks out every K this program saw, sum_exp stays 0 and + # acc/sum_exp would be NaN. Store zeros + log_exp_sum=-inf so stage2 + # naturally weights this slot to 0 in the final reduction. + safe_sum = tl.where(sum_exp > 0, sum_exp, 1.0) + out_acc = tl.where(sum_exp[:, None] > 0, acc / safe_sum[:, None], 0.0) + out_log = tl.where(sum_exp > 0, max_logic + tl.log(safe_sum), -float("inf")) + else: + out_acc = acc / sum_exp[:, None] + out_log = max_logic + tl.log(sum_exp) + tl.store(Mid_O + off_mid_o, out_acc) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, out_log) return @@ -170,6 +186,7 @@ def flash_decode_stage1( mid_out, mid_out_logsumexp, block_seq, + sliding_window: int = 0, run_config: Optional[dict] = None, ): """ """ @@ -225,6 +242,7 @@ def flash_decode_stage1( BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK_N, + SLIDING_WINDOW=int(sliding_window), num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index dab01ddf18..d8396c3bd6 100644 --- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -41,6 +41,7 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, ): start_m = tl.program_id(0) cur_bh = tl.program_id(1) @@ -76,8 +77,18 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) - # causal mask - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + # When SLIDING_WINDOW > 0, the earliest k_pos relevant to any q in this + # block is `(block_start_loc + prompt_cache_len) - (SLIDING_WINDOW - 1)`. + # Round down to BLOCK_N to avoid loading blocks fully outside the window. 
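+    # e.g. block_start_loc=256, prompt_cache_len=0, SLIDING_WINDOW=128, BLOCK_N=64:
+    # the earliest relevant k_pos is 256 - 127 = 129, rounded down to win_start=128.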
+ if SLIDING_WINDOW > 0: + win_start = block_start_loc + prompt_cache_len - (SLIDING_WINDOW - 1) + win_start = tl.maximum(win_start, 0) + win_start = (win_start // BLOCK_N) * BLOCK_N + else: + win_start = 0 + + # causal (+ sliding-window) mask + for start_n in range(win_start, block_mask * block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- kv_loc = tl.load( @@ -89,7 +100,11 @@ def _fwd_kernel( k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) qk = tl.dot(q, k) + # causal: q_pos >= k_pos. q_pos = offs_m + prompt_cache_len, k_pos = start_n + offs_n. mask = offs_m[:, None] + prompt_cache_len >= (start_n + offs_n[None, :]) + if SLIDING_WINDOW > 0: + # SWA: q_pos - k_pos < SLIDING_WINDOW + mask = mask & ((offs_m[:, None] + prompt_cache_len) - (start_n + offs_n[None, :]) < SLIDING_WINDOW) qk = tl.where(mask, qk * sm_scale, -1.0e8) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -121,7 +136,17 @@ def _fwd_kernel( @torch.no_grad() def context_attention_fwd( - q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, + sliding_window: int = 0, ): BLOCK_M = 128 if not is_tesla() else 64 # shape constraints @@ -178,6 +203,7 @@ def context_attention_fwd( BLOCK_DMODEL=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + SLIDING_WINDOW=int(sliding_window), num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 679e3e5a5d..22f6638101 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -169,10 +169,10 @@ def _get_o( # ----- Attention kernels (sliding window + per-layer KV reshape) --- def _att_control(self): - # SWA is only safe with FA3 (it consumes window_size per-call). Triton - # backend asserts use_sliding_window is False; lightllm's flashinfer - # wrapper plans once and ignores per-call windows. The flag is set - # by Gemma4TpPartModel._init_att_backend after backend selection. + # FA3 consumes window_size per-call; the triton prefill/decode kernels + # mask out-of-window positions when SLIDING_WINDOW > 0 (see + # context_flashattention_nopad.py / gqa_flash_decoding_stage1.py). + # `_gemma4_use_swa` is set by Gemma4TpPartModel._init_att_backend. if self.is_sliding and self.sliding_window_ > 0 and self.network_config_.get("_gemma4_use_swa", False): w = self.sliding_window_ - 1 return AttControl(use_sliding_window=True, sliding_window=(w, w)) @@ -204,7 +204,11 @@ def _context_attention_kernel( ) -> torch.Tensor: _k, _v = self._get_layer_kv(infer_state) _q = q.view(-1, self.tp_q_head_num_, self.layer_head_dim_) - o_tensor = infer_state.prefill_att_state.prefill_att( + # Sliding layers go through the secondary backend (FA3 with SWA when + # available, else triton-with-SWA from path B). Full-attn layers go + # through the primary triton backend (head_dim=512). 
+ att_state = infer_state.prefill_att_state1 if self.is_sliding else infer_state.prefill_att_state + o_tensor = att_state.prefill_att( q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor ) return o_tensor.view(q.shape) @@ -218,7 +222,8 @@ def _token_attention_kernel( ) -> torch.Tensor: _k, _v = self._get_layer_kv(infer_state) _q = q.view(-1, self.tp_q_head_num_, self.layer_head_dim_) - o_tensor = infer_state.decode_att_state.decode_att( + att_state = infer_state.decode_att_state1 if self.is_sliding else infer_state.decode_att_state + o_tensor = att_state.decode_att( q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor ) return o_tensor.view(q.shape) diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index 7b80d51218..8e4b1d1325 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -135,18 +135,19 @@ def _init_mem_manager(self): return def _init_att_backend(self): - # Gemma-4 has per-layer heterogeneous attention shape (sliding layers - # use head_dim=256/16 KV heads, full-attn layers use head_dim=512/4). - # The flashinfer backend in this repo plans once per infer_state with - # a single (head_dim, num_kv_heads), so it crashes / silently produces - # wrong results on the layer where the shape doesn't match. FA3 reads - # head_dim and num_kv_heads from the per-call tensor shapes, so it - # supports the heterogeneous layout AND honours per-call sliding window - # — which is what we want on sliding layers. - from lightllm.common.basemodel.attention.fa3.fp import Fa3AttBackend + # Gemma-4 has per-layer heterogeneous attention: sliding layers use + # (head_dim=256, kv_heads=16); full-attn layers use (head_dim=512, + # kv_heads=4, k_eq_v). No single backend covers both: + # - FA3 caps head_dim at 256 -> can't run full-attn layers. + # - Triton handles head_dim=512 (kernels widened to Lk=512) but + # historically refused sliding_window. + # - Flashinfer plans once per infer_state on a single shape -> can't + # accommodate heterogeneous layout at all. + # Strategy: run full-attn layers on triton (primary backend, this + # method) and sliding layers on a separate backend wired via + # _init_att_backend1 (FA3 when available; triton-with-SWA otherwise). from lightllm.utils.sgl_utils import flash_attn_with_kvcache - fa3_loadable = flash_attn_with_kvcache is not None args = get_env_start_args() backends = set(args.llm_prefill_att_backend + args.llm_decode_att_backend) for backend_name in backends: @@ -155,23 +156,40 @@ def _init_att_backend(self): "num_kv_heads); flashinfer is not wired for the heterogeneous " f"layout. Got --llm_*_att_backend={backend_name!r}." ) + fa3_loadable = flash_attn_with_kvcache is not None if "fa3" in backends: assert fa3_loadable, ( - "Requested --llm_*_att_backend=fa3 but neither sgl_kernel nor " - "flash_attn_3 (flash_attn_interface) imported successfully. " - "Build flash-attention/hopper from source against the current torch." + "Requested --llm_*_att_backend=fa3 but flash_attn_with_kvcache " + "did not import (sgl_kernel missing or wrong arch)." ) - # Default policy: prefer FA3 if available (gets us real sliding-window - # attention on sliding layers); fall back to triton otherwise. - prefer_fa3 = fa3_loadable and (backends <= {"auto", "fa3"}) - if prefer_fa3: - self.prefill_att_backend = Fa3AttBackend(model=self) - self.decode_att_backend = Fa3AttBackend(model=self) - self.config["_gemma4_use_swa"] = True + + # Full-attn layers always go through triton. 
+ self.prefill_att_backend = TritonAttBackend(model=self) + self.decode_att_backend = TritonAttBackend(model=self) + + # Decide sliding-layer backend kind here so _init_att_backend1 can + # honour it. User can force triton with --llm_*_att_backend triton; + # otherwise prefer FA3 when loadable. + user_forced_triton = backends == {"triton"} + self._gemma4_sliding_backend_kind = ( + "fa3" if (fa3_loadable and not user_forced_triton) else "triton" + ) + # SWA is on regardless of which sliding backend was picked: FA3 + # honours window_size per call, and the triton kernels in + # context_flashattention_nopad.py / gqa_flash_decoding_stage1.py mask + # out-of-window positions when SLIDING_WINDOW > 0. + self.config["_gemma4_use_swa"] = True + + def _init_att_backend1(self): + # Sliding layers run on a dedicated backend so the head-dim/SWA + # mismatch with full-attn layers doesn't force a single compromise. + if self._gemma4_sliding_backend_kind == "fa3": + from lightllm.common.basemodel.attention.fa3.fp import Fa3AttBackend + self.prefill_att_backend1 = Fa3AttBackend(model=self) + self.decode_att_backend1 = Fa3AttBackend(model=self) else: - self.prefill_att_backend = TritonAttBackend(model=self) - self.decode_att_backend = TritonAttBackend(model=self) - self.config["_gemma4_use_swa"] = False + self.prefill_att_backend1 = TritonAttBackend(model=self) + self.decode_att_backend1 = TritonAttBackend(model=self) def _init_custom(self): self._init_to_get_rotary_gemma4() From 15a53793156529a80aee9d08ff285010f2cd670f Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 7 May 2026 08:50:54 +0000 Subject: [PATCH 03/20] support moe --- .../fused_moe/fused_moe_weight.py | 27 +++- .../gemma4_packed_fused_moe_weight.py | 34 +++++ .../meta_weights/fused_moe/impl/base_impl.py | 2 + .../fused_moe/impl/deepgemm_impl.py | 22 ++- .../fused_moe/impl/marlin_impl.py | 2 + .../fused_moe/impl/triton_impl.py | 9 ++ .../fused_moe/grouped_fused_moe.py | 13 +- .../fused_moe/grouped_fused_moe_ep.py | 12 +- .../fused_moe/moe_silu_and_mul.py | 19 ++- .../moe_silu_and_mul_mix_quant_ep.py | 17 ++- .../layer_infer/transformer_layer_infer.py | 135 +++++++++++------- .../layer_weights/transformer_layer_weight.py | 64 +++++++++ lightllm/models/gemma4/model.py | 32 +++-- 13 files changed, 303 insertions(+), 85 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8f54e14a72..dd99616b6b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -33,12 +33,14 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + per_expert_scale_name: str = "", ) -> None: super().__init__(data_type=data_type) self.w1_weight_name = gate_proj_name self.w2_weight_name = down_proj_name self.w3_weight_name = up_proj_name self.e_score_correction_bias_name = e_score_correction_bias_name + self.per_expert_scale_name = per_expert_scale_name self.weight_prefix = weight_prefix self.layer_num_ = layer_num self.global_rank_ = get_global_rank() @@ -130,6 +132,8 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + 
use_gelu: bool = False, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,6 +149,8 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + per_expert_scale=per_expert_scale, + use_gelu=use_gelu, ) def low_latency_dispatch( @@ -263,16 +269,22 @@ def load_hf_weights(self, weights): # Load bias if self.e_score_correction_bias_name in weights: self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + self._load_per_expert_scale(weights) self._load_weight(self.expert_idx_to_local_idx, weights) if self.redundancy_expert_num > 0: self._load_weight(self.redundancy_expert_idx_to_local_idx, weights) def verify_load(self): - return all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list) + weight_load_ok = all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list) + per_expert_scale_load_ok = ( + True if self.per_expert_scale is None else getattr(self.per_expert_scale, "load_ok", False) + ) + return weight_load_ok and per_expert_scale_load_ok def _create_weight(self): intermediate_size = self.split_inter_size self.e_score_correction_bias = None + self.per_expert_scale = None # Create e_score_correction_bias if self.e_score_correction_bias_name: self.e_score_correction_bias = torch.empty( @@ -280,6 +292,13 @@ def _create_weight(self): dtype=self.data_type_, device=f"cuda:{self.device_id_}", ) + if self.per_expert_scale_name: + self.per_expert_scale = torch.empty( + (self.n_routed_experts,), + dtype=torch.float32, + device=f"cuda:{self.device_id_}", + ) + self.per_expert_scale.load_ok = False self.w13, w13_param_list = self.quant_method.create_moe_weight( out_dims=[intermediate_size, intermediate_size], @@ -299,6 +318,11 @@ def _create_weight(self): self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) + def _load_per_expert_scale(self, weights: Dict[str, torch.Tensor]): + if self.per_expert_scale_name and self.per_expert_scale_name in weights: + self.per_expert_scale.copy_(weights[self.per_expert_scale_name].to(self.per_expert_scale.dtype)) + self.per_expert_scale.load_ok = True + def _get_expert_weight_list(self, weight_pack: WeightPack): weight_list = [] for idx in range(self.local_n_routed_experts): @@ -307,7 +331,6 @@ def _get_expert_weight_list(self, weight_pack: WeightPack): return weight_list def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]): - # Load each expert with TP slicing for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py new file mode 100644 index 0000000000..1df39993ec --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py @@ -0,0 +1,34 @@ +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight + + +class Gemma4PackedFusedMoeWeight(FusedMoeWeight): + def load_hf_weights(self, weights): + gate_up_name = f"{self.weight_prefix}.gate_up_proj" + down_name = f"{self.weight_prefix}.down_proj" + if gate_up_name not in weights and down_name not in weights and 
self.per_expert_scale_name not in weights: + return super().load_hf_weights(weights) + + assert self.quant_method.method_name == "none", "Gemma-4 packed MoE currently supports bf16/no-quant weights." + assert not self.enable_ep_moe, "Gemma-4 packed MoE currently supports TP mode only." + + start = self.split_inter_size * self.tp_rank_ + end = self.split_inter_size * (self.tp_rank_ + 1) + moe_intermediate_size = self.moe_intermediate_size + + if gate_up_name in weights: + gate_up_weight = weights[gate_up_name] + for expert_idx, local_expert_idx in self.expert_idx_to_local_idx.items(): + gate_weight = gate_up_weight[expert_idx, start:end, :].contiguous() + up_weight = gate_up_weight[ + expert_idx, moe_intermediate_size + start : moe_intermediate_size + end, : + ].contiguous() + self.quant_method.load_weight(gate_weight, self.w1_list[local_expert_idx]) + self.quant_method.load_weight(up_weight, self.w3_list[local_expert_idx]) + + if down_name in weights: + down_weight = weights[down_name] + for expert_idx, local_expert_idx in self.expert_idx_to_local_idx.items(): + down_weight_slice = down_weight[expert_idx, :, start:end].contiguous() + self.quant_method.load_weight(down_weight_slice, self.w2_list[local_expert_idx]) + + self._load_per_expert_scale(weights) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 00587ac185..3e6ab8accf 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -62,5 +62,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + use_gelu: bool = False, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index bdd86eb51e..bf0c350138 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -31,6 +31,7 @@ def _select_experts( topk_group: int, num_expert_group: int, scoring_func: str, + per_expert_scale: Optional[torch.Tensor] = None, ): """Select experts and return topk weights and ids.""" from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts @@ -48,6 +49,8 @@ def _select_experts( ) if self.routed_scaling_factor != 1.0: topk_weights.mul_(self.routed_scaling_factor) + if per_expert_scale is not None: + topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) if self.redundancy_expert_num > 0: redundancy_topk_ids_repair( topk_ids=topk_ids, @@ -68,8 +71,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + use_gelu: bool = False, ): - w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale use_fp8_w8a8 = self.quant_method.method_name != "none" @@ -88,6 +91,7 @@ def _fused_experts( w1_scale=w13_scale, w2_scale=w2_scale, previous_event=None, # for overlap + use_gelu=use_gelu, ) return output @@ -210,11 +214,20 @@ def masked_group_gemm( masked_m: torch.Tensor, dtype: torch.dtype, expected_m: int, + use_gelu: bool = False, ): w13_weight, w13_scale = 
w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale return masked_group_gemm( - recv_x, masked_m, dtype, w13_weight, w13_scale, w2_weight, w2_scale, expected_m=expected_m + recv_x, + masked_m, + dtype, + w13_weight, + w13_scale, + w2_weight, + w2_scale, + expected_m=expected_m, + use_gelu=use_gelu, ) def prefilled_group_gemm( @@ -226,6 +239,7 @@ def prefilled_group_gemm( w13: WeightPack, w2: WeightPack, hidden_dtype=torch.bfloat16, + use_gelu: bool = False, ): device = recv_x[0].device w13_weight, w13_scale = w13.weight, w13.weight_scale @@ -278,7 +292,7 @@ def prefilled_group_gemm( # TODO fused kernel silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype) - silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out) + silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out, use_gelu=use_gelu) qsilu_out, qsilu_out_scale = per_token_group_quant_fp8( silu_out, block_size, dtype=w13_weight.dtype, column_major_scales=True, scale_tma_aligned=True ) @@ -298,7 +312,7 @@ def prefilled_group_gemm( if Autotuner.is_autotune_warmup(): _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype) _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype) - silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out) + silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out, use_gelu=use_gelu) _gemm_out_a, _silu_out = None, None return gather_out diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 6391a10800..67087d2151 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -29,7 +29,9 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + use_gelu: bool = False, ): + assert not use_gelu, "FuseMoeMarlin does not support GELU expert activation." 
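For reference, the per_expert_scale hook threaded through the _select_experts implementations above amounts to an element-wise rescale of the routing weights by each chosen expert's scalar; a standalone sketch with illustrative tensors, not lightllm objects:

import torch

num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)
per_expert_scale = torch.rand(num_experts)          # one scalar per routed expert, kept in fp32 by the weight class

topk_weights, topk_ids = torch.softmax(router_logits, dim=-1).topk(top_k, dim=-1)
# Same operation as the new hook: gather the selected experts' scales and rescale their weights.
topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype)
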
w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index d6e923a115..c634ed59ad 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -42,6 +42,7 @@ def _select_experts( topk_group: int, num_expert_group: int, scoring_func: str, + per_expert_scale: Optional[torch.Tensor] = None, ): """Select experts and return topk weights and ids.""" from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts @@ -59,6 +60,8 @@ def _select_experts( ) if self.routed_scaling_factor != 1.0: topk_weights.mul_(self.routed_scaling_factor) + if per_expert_scale is not None: + topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) if self.num_fused_shared_experts > 0: pad_topk_ids = ( torch.arange( @@ -91,6 +94,7 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + use_gelu: bool = False, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -108,6 +112,7 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + use_gelu=use_gelu, ) return input_tensor @@ -125,6 +130,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + use_gelu: bool = False, ): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -136,6 +143,7 @@ def __call__( topk_group=topk_group, num_expert_group=num_expert_group, scoring_func=scoring_func, + per_expert_scale=per_expert_scale, ) output = self._fused_experts( input_tensor=input_tensor, @@ -145,5 +153,6 @@ def __call__( topk_ids=topk_ids, router_logits=router_logits, is_prefill=is_prefill, + use_gelu=use_gelu, ) return output diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 638abbd6ca..bed3754960 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -114,7 +114,6 @@ def moe_align1_kernel( TOKEN_BLOCK_SIZE: tl.constexpr, NUM_STAGE: tl.constexpr, ): - expert_id = tl.program_id(axis=0) off_n = tl.arange(0, TOKEN_BLOCK_SIZE) @@ -308,7 +307,6 @@ def moe_align2_kernel( BLOCK_M: tl.constexpr, BLOCK_EXPERT: tl.constexpr, ): - expert_id = tl.program_id(axis=0) off_expert = tl.arange(0, BLOCK_EXPERT) expert_to_token_num = tl.load(experts_token_num_ptr + off_expert, mask=off_expert < expert_num, other=0) @@ -911,6 +909,7 @@ def fused_experts_impl( layout="blocked", limit=None, alpha=None, + use_gelu: bool = False, ): # Check constraints. 
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -990,6 +989,7 @@ def fused_experts_impl( limit=limit, alpha=alpha, layout=layout, + use_gelu=use_gelu, ) grouped_matmul( @@ -1035,6 +1035,7 @@ def inplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ) -> None: fused_experts_impl( hidden_states, @@ -1054,6 +1055,7 @@ def inplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + use_gelu=use_gelu, ) @@ -1075,6 +1077,7 @@ def inplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ) -> None: pass @@ -1105,6 +1108,7 @@ def outplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ) -> None: return fused_experts_impl( hidden_states, @@ -1124,6 +1128,7 @@ def outplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + use_gelu=use_gelu, ) @@ -1145,6 +1150,7 @@ def outplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ) -> None: return torch.empty_like(hidden_states) @@ -1176,6 +1182,7 @@ def fused_experts( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ): if inplace: torch.ops.lightllm.inplace_fused_experts_impl( @@ -1195,6 +1202,7 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + use_gelu=use_gelu, ) return hidden_states else: @@ -1215,4 +1223,5 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + use_gelu=use_gelu, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 2c6d013bd5..8f31bde57a 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -40,6 +40,7 @@ def masked_group_gemm( w2: torch.Tensor, w2_scale: torch.Tensor, expected_m: int, + use_gelu: bool = False, ): padded_m = recv_x[0].shape[1] E, N, _ = w1.shape @@ -54,7 +55,7 @@ def masked_group_gemm( _deepgemm_grouped_fp8_nt_masked(recv_x, (w1, w1_scale), gemm_out_a, masked_m, expected_m) - silu_and_mul_masked_post_quant_fwd(gemm_out_a, qsilu_out, qsilu_out_scale, block_size, masked_m) + silu_and_mul_masked_post_quant_fwd(gemm_out_a, qsilu_out, qsilu_out_scale, block_size, masked_m, use_gelu=use_gelu) _deepgemm_grouped_fp8_nt_masked((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, masked_m, expected_m) return gemm_out_b @@ -74,6 +75,7 @@ def fused_experts_impl( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, previous_event: Optional["EventOverlap"] = None, + use_gelu: bool = False, ): # Check constraints. 
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -175,7 +177,7 @@ def fused_experts_impl( # TODO fused kernel silu_out = torch.empty((all_tokens, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out) + silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out, use_gelu=use_gelu) qsilu_out, qsilu_out_scale = per_token_group_quant_fp8( silu_out, block_size_k, dtype=w1.dtype, column_major_scales=True, scale_tma_aligned=True ) @@ -194,7 +196,7 @@ def fused_experts_impl( if Autotuner.is_autotune_warmup(): _gemm_out_a = torch.zeros((1, N), device=hidden_states.device, dtype=hidden_states.dtype) _silu_out = torch.zeros((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out) + silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out, use_gelu=use_gelu) _gemm_out_a, _silu_out = None, None # normal combine @@ -220,7 +222,9 @@ def fused_experts_impl( return_recv_hook=False, ) # deepgemm - gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m) + gemm_out_b = masked_group_gemm( + recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m, use_gelu=use_gelu + ) # low latency combine combined_x, event_overlap, hook = buffer.low_latency_combine( gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=False diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py index d7bcc17743..2b7a9d30b4 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py @@ -23,6 +23,7 @@ def _silu_and_mul_kernel_fast( NEED_MASK: tl.constexpr, layout: tl.constexpr = "blocked", # "blocked" or "interleaved" USE_LIMIT_AND_ALPHA: tl.constexpr = False, + USE_GELU: tl.constexpr = False, ): stride_input_m = tl.cast(stride_input_m, dtype=tl.int64) stride_output_m = tl.cast(stride_output_m, dtype=tl.int64) @@ -74,7 +75,14 @@ def _silu_and_mul_kernel_fast( mask=mask, ) else: - gate = gate / (1 + tl.exp(-gate)) + if USE_GELU: + # tanh-approx GELU, matching Gemma's gelu_pytorch_tanh MLP. 
+ gate_cubed = gate * gate * gate + tanh_arg = 0.7978845608028654 * (gate + 0.044715 * gate_cubed) + tanh_val = 2.0 / (1.0 + tl.exp(-2.0 * tanh_arg)) - 1.0 + gate = 0.5 * gate * (1.0 + tanh_val) + else: + gate = gate / (1 + tl.exp(-gate)) gate = gate.to(input_ptr.dtype.element_ty) tl.store( @@ -106,7 +114,13 @@ def _get_silu_and_mul_static_key(input: torch.Tensor, output: torch.Tensor): mutates_args=["output"], ) def silu_and_mul_fwd( - input: torch.Tensor, output: torch.Tensor, layout="blocked", limit=None, alpha=None, run_config=None + input: torch.Tensor, + output: torch.Tensor, + layout="blocked", + limit=None, + alpha=None, + run_config=None, + use_gelu: bool = False, ): assert input.is_contiguous() assert output.is_contiguous() @@ -157,5 +171,6 @@ def silu_and_mul_fwd( num_warps=num_warps, layout=layout, USE_LIMIT_AND_ALPHA=USE_LIMIT_AND_ALPHA, + USE_GELU=use_gelu, ) return diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py index d2c44b2953..30124cc2b2 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py @@ -24,6 +24,7 @@ def _silu_and_mul_post_quant_kernel( fp8_min, BLOCK_N: tl.constexpr, NUM_STAGE: tl.constexpr, + USE_GELU: tl.constexpr = False, ): expert_id = tl.program_id(2) token_id = tl.program_id(1) @@ -48,7 +49,13 @@ def _silu_and_mul_post_quant_kernel( for token_index in tl.range(token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE): gate = tl.load(input_ptr_offs + token_index * stride_input_1, mask=offs_in_d < size_n, other=0.0).to(tl.float32) up = tl.load(input_ptr_offs + token_index * stride_input_1 + size_n, mask=offs_in_d < size_n, other=0.0) - gate = gate / (1 + tl.exp(-gate)) + if USE_GELU: + gate_cubed = gate * gate * gate + tanh_arg = 0.7978845608028654 * (gate + 0.044715 * gate_cubed) + tanh_val = 2.0 / (1.0 + tl.exp(-2.0 * tanh_arg)) - 1.0 + gate = 0.5 * gate * (1.0 + tanh_val) + else: + gate = gate / (1 + tl.exp(-gate)) gate = gate.to(input_ptr.dtype.element_ty) gate_up = up * gate _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) @@ -66,7 +73,12 @@ def _silu_and_mul_post_quant_kernel( def silu_and_mul_masked_post_quant_fwd( - input: torch.Tensor, output: torch.Tensor, output_scale: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + quant_group_size: int, + masked_m: torch.Tensor, + use_gelu: bool = False, ): """ input shape [expert_num, token_num_padded, hidden_dim] @@ -122,6 +134,7 @@ def silu_and_mul_masked_post_quant_fwd( fp8_min, BLOCK_N=BLOCK_N, NUM_STAGE=NUM_STAGES, + USE_GELU=use_gelu, num_warps=num_warps, ) return diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 22f6638101..0d4ed932e1 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -23,6 +23,10 @@ def __init__(self, layer_num, network_config): super().__init__(layer_num, network_config) self.eps_ = 1e-6 self.embed_dim_ = network_config["hidden_size"] + self.is_moe = bool(network_config.get("enable_moe_block", False)) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", network_config.get("top_k_experts", 0)) + 
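For reference, the USE_GELU branch added to the silu_and_mul kernels above evaluates tanh through the sigmoid identity tanh(z) = 2*sigmoid(2z) - 1; a quick PyTorch check (reference formula only, not lightllm code) that those constants reproduce gelu_pytorch_tanh:

import torch

x = torch.randn(1024, dtype=torch.float32)
# Kernel constants: sqrt(2/pi) = 0.7978845608028654 and the 0.044715 cubic term.
tanh_arg = 0.7978845608028654 * (x + 0.044715 * x.pow(3))
tanh_val = 2.0 / (1.0 + torch.exp(-2.0 * tanh_arg)) - 1.0        # tanh via the sigmoid identity
gelu_ref = 0.5 * x * (1.0 + tanh_val)
assert torch.allclose(gelu_ref, torch.nn.functional.gelu(x, approximate="tanh"), atol=1e-5)
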
self.norm_topk_prob = network_config.get("norm_topk_prob", True) + self.router_root_scale = self.embed_dim_ ** -0.5 layer_type = network_config["layer_types"][layer_num] self.is_sliding = layer_type == "sliding_attention" @@ -68,26 +72,16 @@ def _bind_func(self): # ----- norms --------------------------------------------------------- - def _att_norm( - self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight - ) -> torch.Tensor: - return layer_weight.att_norm_weight_( - input=input, eps=self.eps_, alloc_func=self.alloc_tensor - ) + def _att_norm(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: + return layer_weight.att_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) - def _ffn_norm( - self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight - ) -> torch.Tensor: + def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: # NOTE: gemma packs post_attention_layernorm under `ffn_norm_weight_` - return layer_weight.ffn_norm_weight_( - input=input, eps=self.eps_, alloc_func=self.alloc_tensor - ) + return layer_weight.ffn_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) # ----- QKV + attention --------------------------------------------- - def _get_qkv( - self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight - ) -> torch.Tensor: + def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: input = self._tpsp_allgather(input=input, infer_state=infer_state) head_dim = self.layer_head_dim_ @@ -156,9 +150,7 @@ def _get_qkv( return q, cache_kv - def _get_o( - self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight - ) -> torch.Tensor: + def _get_o(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) input = input.view(-1, self.tp_o_head_num_ * self.layer_head_dim_) @@ -223,16 +215,12 @@ def _token_attention_kernel( _k, _v = self._get_layer_kv(infer_state) _q = q.view(-1, self.tp_q_head_num_, self.layer_head_dim_) att_state = infer_state.decode_att_state1 if self.is_sliding else infer_state.decode_att_state - o_tensor = att_state.decode_att( - q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor - ) + o_tensor = att_state.decode_att(q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor) return o_tensor.view(q.shape) # ----- FFN (Gemma gelu-tanh, separate gate/up/down) ---------------- - def _ffn( - self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight - ) -> torch.Tensor: + def _ffn(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) gate = layer_weight.gate_proj.mm(input) @@ -245,21 +233,77 @@ def _ffn( ffn2 = self._tpsp_reduce(input=ffn2, infer_state=infer_state) return ffn2 + def _router_logits(self, residual, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: + router_input = residual.view(-1, self.embed_dim_).float() + router_input = router_input * torch.rsqrt(router_input.pow(2).mean(dim=-1, keepdim=True) + self.eps_) + router_input = router_input * self.router_root_scale + 
router_input = (router_input * layer_weight.router_input_scale_.weight.float()).to(torch.bfloat16) + return layer_weight.moe_gate.mm(router_input, use_custom_tensor_mananger=False).float() + + def _moe_ffn(self, input, router_logits, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + moe_out = layer_weight.experts.experts( + input, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + per_expert_scale=layer_weight.experts.per_expert_scale, + use_gelu=True, + ) + moe_out = self._tpsp_reduce(input=moe_out, infer_state=infer_state) + return moe_out + + def _ffn_block(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + residual = input_embdings + dense_input = layer_weight.pre_feedforward_layernorm_weight_( + input=residual.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + dense_out = self._ffn(dense_input, infer_state, layer_weight) + dense_input = None + + if self.is_moe: + dense_out = layer_weight.post_feedforward_layernorm_1_weight_( + input=dense_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + + router_logits = self._router_logits(residual, layer_weight) + moe_input = layer_weight.pre_feedforward_layernorm_2_weight_( + input=residual.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + moe_out = self._moe_ffn(moe_input, router_logits, infer_state, layer_weight) + moe_input = None + router_logits = None + moe_out = layer_weight.post_feedforward_layernorm_2_weight_( + input=moe_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + dense_out.add_(moe_out) + moe_out = None + + ffn_out = layer_weight.post_feedforward_layernorm_weight_( + input=dense_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor + ).to(torch.bfloat16) + dense_out = None + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + # ----- block-level forwards (add layer_scalar at the end) ---------- def _apply_layer_scalar(self, hidden_states, layer_weight): hidden_states.mul_(layer_weight.layer_scalar_.weight) return hidden_states - def context_forward( - self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight - ): + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): input_embdings = input_embdings.to(torch.bfloat16) # attn sub-block - input1 = self._att_norm( - input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight - ).to(torch.bfloat16) + input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to( + torch.bfloat16 + ) q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) input1 = None self._post_cache_kv(cache_kv, infer_state, layer_weight) @@ -270,27 +314,16 @@ def context_forward( input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - # ffn sub-block - input1 = layer_weight.pre_feedforward_layernorm_weight_( - input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(torch.bfloat16) - ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None - ffn_out = layer_weight.post_feedforward_layernorm_weight_( - input=ffn_out.float(), 
eps=self.eps_, alloc_func=self.alloc_tensor - ).to(torch.bfloat16) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + input_embdings = self._ffn_block(input_embdings, infer_state, layer_weight) return self._apply_layer_scalar(input_embdings, layer_weight) - def token_forward( - self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight - ): + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): input_embdings = input_embdings.to(torch.bfloat16) - input1 = self._att_norm( - input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight - ).to(torch.bfloat16) + input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to( + torch.bfloat16 + ) q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) input1 = None self._post_cache_kv(cache_kv, infer_state, layer_weight) @@ -301,14 +334,6 @@ def token_forward( input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input1 = layer_weight.pre_feedforward_layernorm_weight_( - input=input_embdings.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(torch.bfloat16) - ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None - ffn_out = layer_weight.post_feedforward_layernorm_weight_( - input=ffn_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(torch.bfloat16) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + input_embdings = self._ffn_block(input_embdings, infer_state, layer_weight) return self._apply_layer_scalar(input_embdings, layer_weight) diff --git a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py index f9a59f7424..3ea706ef50 100644 --- a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py @@ -1,6 +1,10 @@ from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight, COLMMWeight from lightllm.common.basemodel.layer_weights.meta_weights import RMSNormWeight, ParameterWeight +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.gemma4_packed_fused_moe_weight import ( + Gemma4PackedFusedMoeWeight, +) from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight +from lightllm.utils.envs_utils import get_env_start_args class Gemma4TransformerLayerWeight(LlamaTransformerLayerWeight): @@ -16,6 +20,7 @@ def __init__( return def _pre_parse_layer_shape(self, layer_num, network_config): + self._is_moe = bool(network_config.get("enable_moe_block", False)) layer_type = network_config["layer_types"][layer_num] self._is_sliding = layer_type == "sliding_attention" if self._is_sliding: @@ -59,6 +64,12 @@ def _init_weight_names(self): self._ffn_norm_weight_name = f"{prefix}.post_attention_layernorm.weight" self._pre_feedforward_layernorm_name = f"{prefix}.pre_feedforward_layernorm.weight" self._post_feedforward_layernorm_name = f"{prefix}.post_feedforward_layernorm.weight" + self._post_feedforward_layernorm_1_name = f"{prefix}.post_feedforward_layernorm_1.weight" + self._pre_feedforward_layernorm_2_name = f"{prefix}.pre_feedforward_layernorm_2.weight" + self._post_feedforward_layernorm_2_name = f"{prefix}.post_feedforward_layernorm_2.weight" + + self._router_input_scale_name = f"{prefix}.router.scale" + self._router_weight_name = f"{prefix}.router.proj.weight" self._layer_scalar_name = 
f"{prefix}.layer_scalar" @@ -66,6 +77,8 @@ def _init_weight(self): self._init_qkv() self._init_o() self._init_ffn() + if self._is_moe: + self._init_moe() self._init_norm() def _init_qkv(self): @@ -139,6 +152,41 @@ def _init_ffn(self): quant_method=self.get_quant_method("down_proj"), ) + def _init_moe(self): + enable_ep_moe = get_env_start_args().enable_ep_moe + assert not enable_ep_moe, "Gemma-4 MoE packed expert weights currently support TP mode only." + + self.router_input_scale_ = ParameterWeight( + weight_name=self._router_input_scale_name, + data_type=self.data_type_, + weight_shape=(self.n_embed,), + ) + self.moe_gate = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.network_config_["num_experts"]], + weight_names=self._router_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("moe_gate"), + tp_rank=0, + tp_world_size=1, + ) + self.experts = Gemma4PackedFusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"model.language_model.layers.{self.layer_num_}.experts", + n_routed_experts=self.network_config_["num_experts"], + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=self.network_config_["moe_intermediate_size"], + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + per_expert_scale_name=f"model.language_model.layers.{self.layer_num_}.router.per_expert_scale", + ) + def _init_norm(self): hidden_size = self.network_config_["hidden_size"] # Gemma-4 uses *standard* RMSNorm (x * rsqrt(var+eps) * w), NOT the @@ -174,6 +222,22 @@ def _init_norm(self): weight_name=self._post_feedforward_layernorm_name, data_type=self.data_type_, ) + if self._is_moe: + self.post_feedforward_layernorm_1_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_1_name, + data_type=self.data_type_, + ) + self.pre_feedforward_layernorm_2_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._pre_feedforward_layernorm_2_name, + data_type=self.data_type_, + ) + self.post_feedforward_layernorm_2_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_2_name, + data_type=self.data_type_, + ) # scalar multiplier applied to the attention output self.layer_scalar_ = ParameterWeight( weight_name=self._layer_scalar_name, diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index 8e4b1d1325..e91cdccdd2 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -16,6 +16,7 @@ from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num, get_env_start_args from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager logger = init_logger(__name__) @@ -105,6 +106,13 @@ def _init_config(self): repair_config(self.config, same_names=["num_attention_heads", "n_head"]) repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.config.get("enable_moe_block", False): + # LightLLM's MoE helpers use Qwen/DeepSeek-style field names. + # Gemma-4 checkpoints expose equivalent values as top_k_experts + # and moe_intermediate_size. 
+ self.config.setdefault("num_experts_per_tok", self.config["top_k_experts"]) + self.config.setdefault("norm_topk_prob", True) + self.config.setdefault("scoring_func", "softmax") return def _verify_params(self): @@ -112,9 +120,9 @@ def _verify_params(self): assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 assert self.config["num_key_value_heads"] % self.tp_world_size_ == 0 num_global_kv = self.config.get("num_global_key_value_heads", self.config["num_key_value_heads"]) - assert num_global_kv % self.tp_world_size_ == 0, ( - f"num_global_key_value_heads={num_global_kv} must be divisible by tp={self.tp_world_size_}" - ) + assert ( + num_global_kv % self.tp_world_size_ == 0 + ), f"num_global_key_value_heads={num_global_kv} must be divisible by tp={self.tp_world_size_}" return def _init_mem_manager(self): @@ -171,9 +179,7 @@ def _init_att_backend(self): # honour it. User can force triton with --llm_*_att_backend triton; # otherwise prefer FA3 when loadable. user_forced_triton = backends == {"triton"} - self._gemma4_sliding_backend_kind = ( - "fa3" if (fa3_loadable and not user_forced_triton) else "triton" - ) + self._gemma4_sliding_backend_kind = "fa3" if (fa3_loadable and not user_forced_triton) else "triton" # SWA is on regardless of which sliding backend was picked: FA3 # honours window_size per call, and the triton kernels in # context_flashattention_nopad.py / gqa_flash_decoding_stage1.py mask @@ -185,6 +191,7 @@ def _init_att_backend1(self): # mismatch with full-attn layers doesn't force a single compromise. if self._gemma4_sliding_backend_kind == "fa3": from lightllm.common.basemodel.attention.fa3.fp import Fa3AttBackend + self.prefill_att_backend1 = Fa3AttBackend(model=self) self.decode_att_backend1 = Fa3AttBackend(model=self) else: @@ -193,6 +200,8 @@ def _init_att_backend1(self): def _init_custom(self): self._init_to_get_rotary_gemma4() + if self.config.get("enable_moe_block", False): + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) def _init_to_get_rotary_gemma4(self): rope_params = self.config["rope_parameters"] @@ -231,21 +240,16 @@ def _init_to_get_rotary_gemma4(self): if rope_type == "proportional": rope_angles = int(full_partial * full_head_dim // 2) inv_freq_rot = 1.0 / ( - full_theta - ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.float32) / full_head_dim) + full_theta ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.float32) / full_head_dim) ) nope_angles = full_head_dim // 2 - rope_angles if nope_angles > 0: - inv_freq_full = torch.cat( - [inv_freq_rot, torch.zeros(nope_angles, dtype=torch.float32)] - ) + inv_freq_full = torch.cat([inv_freq_rot, torch.zeros(nope_angles, dtype=torch.float32)]) else: inv_freq_full = inv_freq_rot else: full_rot_dim = int(full_head_dim * full_partial) - inv_freq_full = 1.0 / ( - full_theta ** (torch.arange(0, full_rot_dim, 2, dtype=torch.float32) / full_rot_dim) - ) + inv_freq_full = 1.0 / (full_theta ** (torch.arange(0, full_rot_dim, 2, dtype=torch.float32) / full_rot_dim)) freqs_f = torch.outer(t, inv_freq_full) self._cos_cached_full = torch.cos(freqs_f).to(self.data_type).cuda() From 83f498309a0fdea305bdb916190dbe1ea4a1c538 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Sat, 9 May 2026 06:49:10 +0000 Subject: [PATCH 04/20] support e4b (PLE and shared_kv) --- lightllm/models/gemma4/infer_struct.py | 3 + .../gemma4/layer_infer/pre_layer_infer.py | 48 +++- .../layer_infer/transformer_layer_infer.py | 259 +++++++++++------- 
.../pre_and_post_layer_weight.py | 28 ++ .../layer_weights/transformer_layer_weight.py | 37 ++- lightllm/models/gemma4/model.py | 30 +- 6 files changed, 291 insertions(+), 114 deletions(-) diff --git a/lightllm/models/gemma4/infer_struct.py b/lightllm/models/gemma4/infer_struct.py index 686118346d..703ee6c68a 100644 --- a/lightllm/models/gemma4/infer_struct.py +++ b/lightllm/models/gemma4/infer_struct.py @@ -12,6 +12,9 @@ def __init__(self): self.position_sin_sliding = None self.position_cos_full = None self.position_sin_full = None + # E-series only: per-layer embeddings (PLE), shape (N, num_layers, hidden_size_per_layer_input). + # Computed once in Gemma4PreLayerInfer; sliced per layer in the transformer block. + self.per_layer_embeds = None def init_some_extra_state(self, model): super().init_some_extra_state(model) diff --git a/lightllm/models/gemma4/layer_infer/pre_layer_infer.py b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py index 4771e4b1e1..e5e1507c96 100644 --- a/lightllm/models/gemma4/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py @@ -1,24 +1,52 @@ +import math import torch +import torch.distributed as dist +from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy +from lightllm.distributed.communication_op import all_reduce from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.utils.envs_utils import get_env_start_args class Gemma4PreLayerInfer(LlamaPreLayerInfer): - """ - Text-only pre-layer for Gemma-4 (Phase A). Applies the Gemma embedding - scale (sqrt(hidden_size)) to the token embeddings. Multimodal embed-scatter - handling will be added alongside the vision tower port. - """ - def __init__(self, network_config): super().__init__(network_config) self.embed_scale = float(network_config["hidden_size"]) ** 0.5 + self.has_ple = bool(network_config.get("hidden_size_per_layer_input")) + if self.has_ple: + self.num_layers_ = network_config["num_hidden_layers"] + self.ple_dim_ = network_config["hidden_size_per_layer_input"] + self.ple_embed_scale_ = math.sqrt(self.ple_dim_) + self.ple_proj_scale_ = float(network_config["hidden_size"]) ** -0.5 + self.ple_combine_scale_ = 2.0 ** -0.5 + self.rms_norm_eps_ = network_config.get("rms_norm_eps", 1e-6) + + def _compute_per_layer_embeds(self, input_ids, input_embdings, infer_state, layer_weight): + ple_embeds = layer_weight.embed_tokens_per_layer_weight_(input_ids) + if self.tp_world_size_ > 1: + all_reduce(ple_embeds, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + ple_embeds = ple_embeds * self.ple_embed_scale_ + + ple_proj = layer_weight.per_layer_model_projection_weight_.mm(input_embdings) + ple_proj = ple_proj * self.ple_proj_scale_ + ple_proj = ple_proj.reshape(*ple_proj.shape[:-1], self.num_layers_, self.ple_dim_) + ple_proj = layer_weight.per_layer_projection_norm_weight_( + input=ple_proj, eps=self.rms_norm_eps_, alloc_func=self.alloc_tensor + ) + + ple_embeds = ple_embeds.reshape(*ple_embeds.shape[:-1], self.num_layers_, self.ple_dim_) + infer_state.per_layer_embeds = (ple_proj + ple_embeds) * self.ple_combine_scale_ + def context_forward(self, input_ids, infer_state, layer_weight): input_embdings = super().context_forward(input_ids, infer_state, layer_weight) - input_dtype = input_embdings.dtype - return (input_embdings.float() * self.embed_scale).to(input_dtype) + input_embdings = input_embdings * self.embed_scale + if self.has_ple: + self._compute_per_layer_embeds(input_ids, input_embdings, 
infer_state, layer_weight) + return input_embdings def token_forward(self, input_ids, infer_state, layer_weight): input_embdings = super().token_forward(input_ids, infer_state, layer_weight) - input_dtype = input_embdings.dtype - return (input_embdings.float() * self.embed_scale).to(input_dtype) + input_embdings = input_embdings * self.embed_scale + if self.has_ple: + self._compute_per_layer_embeds(input_ids, input_embdings, infer_state, layer_weight) + return input_embdings diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 0d4ed932e1..3274c87b7f 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -31,13 +31,17 @@ def __init__(self, layer_num, network_config): layer_type = network_config["layer_types"][layer_num] self.is_sliding = layer_type == "sliding_attention" + # Some E-series checkpoints leave num_global_key_value_heads = null; + # HF treats that as "fall back to num_key_value_heads". + num_global_kv = network_config.get("num_global_key_value_heads") or network_config["num_key_value_heads"] + if self.is_sliding: self.layer_head_dim_ = network_config["head_dim"] total_kv_heads = network_config["num_key_value_heads"] self.k_eq_v = False else: self.layer_head_dim_ = network_config["global_head_dim"] - total_kv_heads = network_config["num_global_key_value_heads"] + total_kv_heads = num_global_kv self.k_eq_v = network_config.get("attention_k_eq_v", True) # TP shard counts for this layer @@ -46,9 +50,14 @@ def __init__(self, layer_num, network_config): self.tp_v_head_num_ = self.tp_k_head_num_ self.tp_o_head_num_ = self.tp_q_head_num_ - # Uniform mem-manager layout (sliding shape per rank) - self.mm_head_dim_ = network_config["head_dim"] - self.mm_kv_head_num_ = network_config["num_key_value_heads"] // self.tp_world_size_ + self.kv_cache_slot_dim_ = network_config["head_dim"] + sliding_total = network_config["num_key_value_heads"] * network_config["head_dim"] + full_total = num_global_kv * network_config["global_head_dim"] + per_token_k_width = max(sliding_total, full_total) + assert ( + per_token_k_width % self.kv_cache_slot_dim_ == 0 + ), f"per-token K width {per_token_k_width} not aligned to kv_cache_slot_dim {self.kv_cache_slot_dim_}" + self.kv_cache_slot_num_ = (per_token_k_width // self.kv_cache_slot_dim_) // self.tp_world_size_ # Sliding window (None on full-attn layers) if self.is_sliding: @@ -57,30 +66,43 @@ def __init__(self, layer_num, network_config): else: self.sliding_window_ = 0 - # Partial rotary factor for the RoPE kernel. The sliding table is sized - # (seq, head_dim/2) so full rotation over head_dim is the default. - # The full table is sized (seq, global_head_dim/2) with zero-padded - # frequencies (proportional RoPE) — we still pass partial_rotary_factor=1 - # to the kernel so it walks every pair, applying identity for the zeroed - # frequencies. - self.rotary_partial_factor_ = 1.0 - - def _bind_func(self): - # Skip LlamaTransformerLayerInfer._bind_norm (it rebinds to Llama _att_norm / _ffn_norm); - # we want our own gemma-style norm implementations below. 
- return - - # ----- norms --------------------------------------------------------- - - def _att_norm(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: - return layer_weight.att_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) - - def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: - # NOTE: gemma packs post_attention_layernorm under `ffn_norm_weight_` - return layer_weight.ffn_norm_weight_(input=input, eps=self.eps_, alloc_func=self.alloc_tensor) + # E-series Per-Layer Embeddings gate (HF: config.hidden_size_per_layer_input, + # absent or 0 on 31B). + self.has_ple_ = bool(network_config.get("hidden_size_per_layer_input")) + if self.has_ple_: + self.ple_dim_ = network_config["hidden_size_per_layer_input"] + + # HF: config.num_kv_shared_layers (may be missing or null on non-E + # checkpoints — treat as 0). + kv_shared_count = network_config.get("num_kv_shared_layers") or 0 + total_layers = network_config["num_hidden_layers"] + self.is_kv_shared_ = kv_shared_count > 0 and layer_num >= total_layers - kv_shared_count + self.kv_share_target_layer_ = None + if self.is_kv_shared_: + cutoff = total_layers - kv_shared_count + for j in range(layer_num - 1, -1, -1): + if j < cutoff and network_config["layer_types"][j] == layer_type: + self.kv_share_target_layer_ = j + break + assert self.kv_share_target_layer_ is not None, ( + f"layer {layer_num} ({layer_type}) is KV-shared but no earlier non-shared " + f"layer of the same type found below cutoff={cutoff}" + ) + + # Always 1.0: NoPE dims for full-attn layers are zero-padded into + # cos/sin (cos=1, sin=0 → identity), so the kernel walks the whole + # head_dim. Don't change to 0.25 — that double-counts with the table. + self.partial_rotary_factor_ = 1.0 # ----- QKV + attention --------------------------------------------- + def _rope_cos_sin(self, infer_state): + # Tables are built in the model dtype (Gemma4TpPartModel._init_to_get_rotary_gemma4), + # so they already match q/k dtype — no cast needed. + if self.is_sliding: + return infer_state.position_cos_sliding, infer_state.position_sin_sliding + return infer_state.position_cos_full, infer_state.position_sin_full + def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: input = self._tpsp_allgather(input=input, infer_state=infer_state) @@ -88,61 +110,65 @@ def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4Trans q_heads = self.tp_q_head_num_ kv_heads = self.tp_k_head_num_ + # Q is always computed (even on KV-shared layers). RMSNormWeight's + # Triton kernel accepts 3D input (it views to 2D internally) and + # promotes to fp32 for the variance reduction, so feed bf16 (N, heads, + # head_dim) straight in — no Python-side reshape or dtype round-trip. q = layer_weight.q_proj.mm(input).view(-1, q_heads, head_dim) + q = layer_weight.q_norm_weight_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) + + cos, sin = self._rope_cos_sin(infer_state) + + if self.is_kv_shared_: + # K/V come from target layer's already-rotated, already-normed cache. + # Only rotate Q here. rotary_emb_fwd writes to k in place, so pass + # a 1-head throwaway tensor we can discard. 
+ dummy_k = torch.empty((q.shape[0], 1, head_dim), dtype=q.dtype, device=q.device) + rotary_emb_fwd(q, dummy_k, cos, sin, partial_rotary_factor=self.partial_rotary_factor_) + q = q * math.sqrt(head_dim) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + return q, None + + # ---- non-shared: full K/V path ---- k = layer_weight.k_proj.mm(input).view(-1, kv_heads, head_dim) if self.k_eq_v: - # Full-attn layers share K weights for V. + # Full-attn k_eq_v variant (e.g. 31B): K weights serve as V. v = k.clone() else: v = layer_weight.v_proj.mm(input).view(-1, kv_heads, head_dim) - # QK RMSNorm (learnable weight, Gemma-style `(1+w)` applied in fp32). - # Reshape to 2D (N*heads, head_dim) so NoTpGEMMANormWeight accepts it. - q_flat = q.reshape(-1, head_dim).float() - k_flat = k.reshape(-1, head_dim).float() - q_flat = layer_weight.q_norm_weight_(input=q_flat, eps=self.eps_, alloc_func=self.alloc_tensor) - k_flat = layer_weight.k_norm_weight_(input=k_flat, eps=self.eps_, alloc_func=self.alloc_tensor) - q = q_flat.view(-1, q_heads, head_dim).to(input.dtype) - k = k_flat.view(-1, kv_heads, head_dim).to(input.dtype) + k = layer_weight.k_norm_weight_(input=k, eps=self.eps_, alloc_func=self.alloc_tensor) # V-norm: unweighted RMSNorm over head_dim (matches vllm's Gemma4 has_weight=False). v_fp = v.float() v_fp = v_fp * torch.rsqrt(v_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps_) v = v_fp.to(input.dtype) - # Per-layer RoPE - if self.is_sliding: - cos = infer_state.position_cos_sliding.to(q.dtype) - sin = infer_state.position_sin_sliding.to(q.dtype) - else: - cos = infer_state.position_cos_full.to(q.dtype) - sin = infer_state.position_sin_full.to(q.dtype) - rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=self.rotary_partial_factor_) + rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=self.partial_rotary_factor_) # Gemma-4 uses scaling=1.0 in attention. The attention kernel hardcodes # sm_scale = 1/sqrt(head_dim); pre-scale Q by sqrt(head_dim) so the # kernel's division cancels out, yielding scores = Q @ K^T. q = q * math.sqrt(head_dim) - # Pack into the uniform mem-manager layout. - mm_heads = self.mm_kv_head_num_ - mm_dim = self.mm_head_dim_ - if self.is_sliding: - # (N, 2*mm_heads, mm_dim) with [:mm_heads]=K, [mm_heads:]=V - cache_kv = torch.cat([k, v], dim=1) + # Pack into the uniform KV-cache layout (N, 2*slot_num, slot_dim). + # K occupies slots [0, used_slots); V occupies + # [slot_num, slot_num + used_slots). If this layer's K/V width is + # smaller than the allocated cache slot width, pad with zeros. + cache_slot_num = self.kv_cache_slot_num_ + cache_slot_dim = self.kv_cache_slot_dim_ + N = k.shape[0] + k_packed = k.reshape(N, -1, cache_slot_dim) + v_packed = v.reshape(N, -1, cache_slot_dim) + used_cache_slots = k_packed.shape[1] + if used_cache_slots == cache_slot_num: + cache_kv = torch.cat([k_packed, v_packed], dim=1) else: - # K,V shape (N, kv_heads, layer_head_dim) e.g. (N, 2, 512) on tp=2. - # Reshape each half to (N, kv_heads*layer_head_dim // mm_dim, mm_dim) e.g. (N, 4, 256) on tp=2. - # The mem-manager layout has (N, 2*mm_heads, mm_dim) = (N, 16, 256) on tp=2 for this - # checkpoint — pad to that shape with zeros on unused head slots. 
- N = k.shape[0] - k_packed = k.reshape(N, -1, mm_dim) # (N, kv_heads * layer_head_dim // mm_dim, mm_dim) - v_packed = v.reshape(N, -1, mm_dim) - cache_kv = self.alloc_tensor((N, 2 * mm_heads, mm_dim), dtype=k.dtype) + cache_kv = self.alloc_tensor((N, 2 * cache_slot_num, cache_slot_dim), dtype=k.dtype) cache_kv.zero_() - k_slots = k_packed.shape[1] - cache_kv[:, :k_slots, :] = k_packed - cache_kv[:, mm_heads : mm_heads + k_slots, :] = v_packed + cache_kv[:, :used_cache_slots, :] = k_packed + cache_kv[:, cache_slot_num : cache_slot_num + used_cache_slots, :] = v_packed if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) @@ -150,6 +176,11 @@ def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4Trans return q, cache_kv + def _post_cache_kv(self, cache_kv, infer_state, layer_weight): + if self.is_kv_shared_ or cache_kv is None: + return + return super()._post_cache_kv(cache_kv, infer_state, layer_weight) + def _get_o(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: if infer_state.need_dp_prefill_balance: input = infer_state._all_to_all_balance_get(data=input) @@ -171,19 +202,20 @@ def _att_control(self): return AttControl(use_sliding_window=False, sliding_window=(-1, -1)) def _get_layer_kv(self, infer_state: InferStateInfo): - _k_raw, _v_raw = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) - # _k_raw / _v_raw shape (S, mm_heads, mm_dim) - if self.is_sliding: - # sliding K is stored in the full (mm_heads, mm_dim) slot; head count matches. - return _k_raw, _v_raw - # full layer: the real K/V live in the first `kv_heads * layer_head_dim // mm_dim` - # head slots. Reshape to (S, kv_heads, layer_head_dim). + # KV-shared layers read from the target layer's cache slot. + layer_idx = self.kv_share_target_layer_ if self.is_kv_shared_ else self.layer_num_ + _k_raw, _v_raw = infer_state.mem_manager.get_att_input_params(layer_index=layer_idx) + # _k_raw / _v_raw shape (S, cache_slot_num, cache_slot_dim). kv_heads = self.tp_k_head_num_ head_dim = self.layer_head_dim_ - mm_dim = self.mm_head_dim_ - k_slots = kv_heads * head_dim // mm_dim - _k = _k_raw[:, :k_slots, :].reshape(-1, kv_heads, head_dim) - _v = _v_raw[:, :k_slots, :].reshape(-1, kv_heads, head_dim) + cache_slot_dim = self.kv_cache_slot_dim_ + used_cache_slots = kv_heads * head_dim // cache_slot_dim + if used_cache_slots == _k_raw.shape[1]: + # Layout already matches this layer's natural shape. + return _k_raw.reshape(-1, kv_heads, head_dim), _v_raw.reshape(-1, kv_heads, head_dim) + # Otherwise the K/V live in the first used_cache_slots; the rest is zero pad. + _k = _k_raw[:, :used_cache_slots, :].reshape(-1, kv_heads, head_dim) + _v = _v_raw[:, :used_cache_slots, :].reshape(-1, kv_heads, head_dim) return _k, _v def _context_attention_kernel( @@ -234,10 +266,16 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: Gemma4Transform return ffn2 def _router_logits(self, residual, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: + # Manual unweighted RMSNorm — lightllm's RMSNormWeight has no + # has_weight=False mode, and bf16 variance over hidden_size loses too + # much precision. Keep the fp32 accumulation explicit. 
router_input = residual.view(-1, self.embed_dim_).float() router_input = router_input * torch.rsqrt(router_input.pow(2).mean(dim=-1, keepdim=True) + self.eps_) router_input = router_input * self.router_root_scale - router_input = (router_input * layer_weight.router_input_scale_.weight.float()).to(torch.bfloat16) + # bf16 weight auto-promotes against fp32 router_input; cast back to + # bf16 to feed moe_gate.mm. + router_input = (router_input * layer_weight.router_input_scale_.weight).to(torch.bfloat16) + # gate logits stay fp32 for top-k / softmax precision. return layer_weight.moe_gate.mm(router_input, use_custom_tensor_mananger=False).float() def _moe_ffn(self, input, router_logits, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): @@ -261,79 +299,106 @@ def _moe_ffn(self, input, router_logits, infer_state: InferStateInfo, layer_weig def _ffn_block(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): residual = input_embdings dense_input = layer_weight.pre_feedforward_layernorm_weight_( - input=residual.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(torch.bfloat16) + input=residual, eps=self.eps_, alloc_func=self.alloc_tensor + ) dense_out = self._ffn(dense_input, infer_state, layer_weight) dense_input = None if self.is_moe: dense_out = layer_weight.post_feedforward_layernorm_1_weight_( - input=dense_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(torch.bfloat16) + input=dense_out, eps=self.eps_, alloc_func=self.alloc_tensor + ) router_logits = self._router_logits(residual, layer_weight) moe_input = layer_weight.pre_feedforward_layernorm_2_weight_( - input=residual.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(torch.bfloat16) + input=residual, eps=self.eps_, alloc_func=self.alloc_tensor + ) moe_out = self._moe_ffn(moe_input, router_logits, infer_state, layer_weight) moe_input = None router_logits = None moe_out = layer_weight.post_feedforward_layernorm_2_weight_( - input=moe_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(torch.bfloat16) + input=moe_out, eps=self.eps_, alloc_func=self.alloc_tensor + ) dense_out.add_(moe_out) moe_out = None ffn_out = layer_weight.post_feedforward_layernorm_weight_( - input=dense_out.float(), eps=self.eps_, alloc_func=self.alloc_tensor - ).to(torch.bfloat16) + input=dense_out, eps=self.eps_, alloc_func=self.alloc_tensor + ) dense_out = None input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - # ----- block-level forwards (add layer_scalar at the end) ---------- + # ----- block-level forwards (PLE fusion + layer_scalar at the end) ---- + + def _apply_per_layer_embed(self, hidden_states, infer_state, layer_weight): + """E-series: gate hidden_states through per_layer_embed slice and add + the projected contribution back as a residual. Matches HF + Gemma4TextDecoderLayer.forward (lines 1401–1408 in transformers 5.5.4) + and vllm Gemma4DecoderLayer.forward (gemma4.py:744–752) — bf16 the + whole way, RMSNorm Triton kernel handles fp32 promotion internally. + + gate / projection weights are ROWMMWeight(tp_world_size=1) — replicated + across TP ranks — so we drive them through `.mm()` and never need an + intra-block all-reduce. In TPSP mix mode, per_layer_embeds has already + been token-split alongside hidden_states by Gemma4PreLayerInfer's + _tpsp_sp_split override, so rows line up element-wise here. + """ + # per_layer_embeds is (N, num_layers, ple_dim); slice this layer. 
+ ple_slice = infer_state.per_layer_embeds[..., self.layer_num_, :] + flat = hidden_states.view(-1, self.embed_dim_) + gate = layer_weight.per_layer_input_gate_.mm(flat) # (N, ple_dim) + gate = nn.functional.gelu(gate, approximate="tanh") + gated = gate * ple_slice.view(-1, self.ple_dim_) + contrib = layer_weight.per_layer_projection_.mm(gated) # (N, hidden_size) + contrib = layer_weight.post_per_layer_input_norm_weight_( + input=contrib, eps=self.eps_, alloc_func=self.alloc_tensor + ) + flat.add_(contrib) + return hidden_states def _apply_layer_scalar(self, hidden_states, layer_weight): hidden_states.mul_(layer_weight.layer_scalar_.weight) return hidden_states - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): - input_embdings = input_embdings.to(torch.bfloat16) + def _block_epilogue(self, hidden_states, infer_state, layer_weight): + """Shared tail for prefill/decode: PLE fusion (E-series only) then + layer_scalar.""" + if self.has_ple_: + hidden_states = self._apply_per_layer_embed(hidden_states, infer_state, layer_weight) + return self._apply_layer_scalar(hidden_states, layer_weight) - # attn sub-block - input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to( - torch.bfloat16 - ) + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + # input_embdings is bf16 from the pre-layer / previous block; RMSNorm + # (att_norm, ffn_norm) handles fp32 promotion in its Triton kernel, + # so the entire residual stream stays in bf16. + input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_), infer_state, layer_weight) q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) input1 = None self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) + o = self._ffn_norm(o, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None input_embdings = self._ffn_block(input_embdings, infer_state, layer_weight) - return self._apply_layer_scalar(input_embdings, layer_weight) + return self._block_epilogue(input_embdings, infer_state, layer_weight) def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): - input_embdings = input_embdings.to(torch.bfloat16) - - input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_).float(), infer_state, layer_weight).to( - torch.bfloat16 - ) + input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_), infer_state, layer_weight) q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) input1 = None self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) - o = self._ffn_norm(o.float(), infer_state, layer_weight).to(torch.bfloat16) + o = self._ffn_norm(o, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None input_embdings = self._ffn_block(input_embdings, infer_state, layer_weight) - return self._apply_layer_scalar(input_embdings, layer_weight) + return self._block_epilogue(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py 
b/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py index a767e70c10..22a2fc4dc7 100644 --- a/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py @@ -2,6 +2,7 @@ from lightllm.common.basemodel.layer_weights.meta_weights import ( EmbeddingWeight, LMHeadWeight, + ROWMMWeight, RMSNormWeight, ) @@ -33,4 +34,31 @@ def __init__(self, data_type, network_config): weight_name="model.language_model.norm.weight", data_type=self.data_type_, ) + + if network_config.get("hidden_size_per_layer_input"): + num_layers = network_config["num_hidden_layers"] + ple_dim = network_config["hidden_size_per_layer_input"] + ple_vocab = network_config.get("vocab_size_per_layer_input", vocab_size) + self.embed_tokens_per_layer_weight_ = EmbeddingWeight( + dim=num_layers * ple_dim, + vocab_size=ple_vocab, + weight_name="model.language_model.embed_tokens_per_layer.weight", + data_type=self.data_type_, + ) + # nn.Linear(in=hidden_size, out=num_layers*ple_dim); HF storage is + # (out, in). Replicated across TP ranks. + self.per_layer_model_projection_weight_ = ROWMMWeight( + in_dim=hidden_size, + out_dims=[num_layers * ple_dim], + weight_names="model.language_model.per_layer_model_projection.weight", + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + # RMSNorm over the ple_dim of the projection output. + self.per_layer_projection_norm_weight_ = RMSNormWeight( + dim=ple_dim, + weight_name="model.language_model.per_layer_projection_norm.weight", + data_type=self.data_type_, + ) return diff --git a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py index 3ea706ef50..0e1dc46b0a 100644 --- a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py @@ -23,13 +23,16 @@ def _pre_parse_layer_shape(self, layer_num, network_config): self._is_moe = bool(network_config.get("enable_moe_block", False)) layer_type = network_config["layer_types"][layer_num] self._is_sliding = layer_type == "sliding_attention" + # Some E-series checkpoints leave num_global_key_value_heads = null; + # HF treats that as "fall back to num_key_value_heads". + num_global_kv = network_config.get("num_global_key_value_heads") or network_config["num_key_value_heads"] if self._is_sliding: self._layer_head_dim = network_config["head_dim"] self._layer_kv_head_num = network_config["num_key_value_heads"] self._layer_k_eq_v = False else: self._layer_head_dim = network_config["global_head_dim"] - self._layer_kv_head_num = network_config["num_global_key_value_heads"] + self._layer_kv_head_num = num_global_kv self._layer_k_eq_v = network_config.get("attention_k_eq_v", True) def _parse_config(self): @@ -73,6 +76,11 @@ def _init_weight_names(self): self._layer_scalar_name = f"{prefix}.layer_scalar" + # E-series Per-Layer Embeddings names (only loaded when PLE enabled). 
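# A quick standalone illustration of the num_global_key_value_heads fallback in
# _pre_parse_layer_shape above (assumed values): dict.get's default argument
# only applies when the key is missing, not when the config ships an explicit
# null, hence the `or`.
cfg = {"num_global_key_value_heads": None, "num_key_value_heads": 8}
assert cfg.get("num_global_key_value_heads", cfg["num_key_value_heads"]) is None
assert (cfg.get("num_global_key_value_heads") or cfg["num_key_value_heads"]) == 8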
+ self._per_layer_input_gate_name = f"{prefix}.per_layer_input_gate.weight" + self._per_layer_projection_name = f"{prefix}.per_layer_projection.weight" + self._post_per_layer_input_norm_name = f"{prefix}.post_per_layer_input_norm.weight" + def _init_weight(self): self._init_qkv() self._init_o() @@ -80,6 +88,33 @@ def _init_weight(self): if self._is_moe: self._init_moe() self._init_norm() + if self.network_config_.get("hidden_size_per_layer_input"): + self._init_ple() + + def _init_ple(self): + ple_dim = self.network_config_["hidden_size_per_layer_input"] + hidden_size = self.network_config_["hidden_size"] + self.per_layer_input_gate_ = ROWMMWeight( + in_dim=hidden_size, + out_dims=[ple_dim], + weight_names=self._per_layer_input_gate_name, + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + self.per_layer_projection_ = ROWMMWeight( + in_dim=ple_dim, + out_dims=[hidden_size], + weight_names=self._per_layer_projection_name, + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + self.post_per_layer_input_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_per_layer_input_norm_name, + data_type=self.data_type_, + ) def _init_qkv(self): in_dim = self.n_embed diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index e91cdccdd2..c900261c35 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -106,6 +106,7 @@ def _init_config(self): repair_config(self.config, same_names=["num_attention_heads", "n_head"]) repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.config.get("enable_moe_block", False): # LightLLM's MoE helpers use Qwen/DeepSeek-style field names. # Gemma-4 checkpoints expose equivalent values as top_k_experts @@ -119,19 +120,36 @@ def _verify_params(self): assert self.load_way == "HF", "Gemma-4 only supports HF format." assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 assert self.config["num_key_value_heads"] % self.tp_world_size_ == 0 - num_global_kv = self.config.get("num_global_key_value_heads", self.config["num_key_value_heads"]) + # Use `or` rather than the dict.get default: E4B-style configs ship + # `num_global_key_value_heads: null`, which the default form would + # leave as None. + num_global_kv = self.config.get("num_global_key_value_heads") or self.config["num_key_value_heads"] assert ( num_global_kv % self.tp_world_size_ == 0 ), f"num_global_key_value_heads={num_global_kv} must be divisible by tp={self.tp_world_size_}" + kv_shared = self.config.get("num_kv_shared_layers") or 0 + assert 0 <= kv_shared < self.config["num_hidden_layers"], ( + f"num_kv_shared_layers={kv_shared} out of range for " + f"num_hidden_layers={self.config['num_hidden_layers']}" + ) return def _init_mem_manager(self): - # Uniform per-layer KV cache layout keyed to the *sliding* attention shape - # (num_kv_heads=16, head_dim=256). Full-attention layers (num_kv_heads=4, - # head_dim=512, k_eq_v) reuse the same byte budget at <=50% utilization; - # the transformer-layer infer code handles the reshape when reading back. - head_num_per_rank = self.config["num_key_value_heads"] // self.tp_world_size_ + # Uniform per-layer KV cache layout. The per-layer cache slot must fit + # whichever layer type has the largest per-token K/V width: sliding + # (num_key_value_heads * head_dim) or full + # (num_global_kv * global_head_dim). 
Keep cache_slot_dim = head_dim + # and pick cache_slot_num = max-width / head_dim. For 31B this + # collapses to num_key_value_heads; for E4B the full-attn shape wins + # (2*512 > 2*256), so it uses 4 storage slots of 256 dims. + # Gemma4TransformerLayerInfer.__init__ computes the same value and + # uses it to pack/unpack K/V at write/read time. head_dim = self.config["head_dim"] + num_global_kv = self.config.get("num_global_key_value_heads") or self.config["num_key_value_heads"] + sliding_total = self.config["num_key_value_heads"] * self.config["head_dim"] + full_total = num_global_kv * self.config["global_head_dim"] + per_token_k_width = max(sliding_total, full_total) + head_num_per_rank = (per_token_k_width // head_dim) // self.tp_world_size_ self.mem_manager = select_mem_manager_class()( self.max_total_token_num, dtype=self.data_type, From d969a5feaf7ad81fe2c016958f26fad24c6799cb Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 11 May 2026 08:52:32 +0000 Subject: [PATCH 05/20] support visual module --- .../basemodel/triton_kernel/multimodal_emb.py | 9 + lightllm/models/gemma4/gemma4_visual.py | 146 ++++++ lightllm/models/gemma4/infer_struct.py | 47 ++ .../gemma4/layer_infer/pre_layer_infer.py | 40 +- .../layer_infer/transformer_layer_infer.py | 29 ++ lightllm/models/gemma4/model.py | 84 +++- .../models/gemma4/triton_kernel/__init__.py | 0 .../triton_kernel/build_b_image_token_end.py | 172 +++++++ .../context_attention_fwd_gemma4_mm.py | 460 ++++++++++++++++++ .../qwen_vl/layer_infer/pre_layer_infer.py | 1 + lightllm/server/tokenizer.py | 8 +- .../visualserver/model_infer/model_rpc.py | 3 + lightllm/utils/config_utils.py | 3 + 13 files changed, 968 insertions(+), 34 deletions(-) create mode 100644 lightllm/models/gemma4/gemma4_visual.py create mode 100644 lightllm/models/gemma4/triton_kernel/__init__.py create mode 100644 lightllm/models/gemma4/triton_kernel/build_b_image_token_end.py create mode 100644 lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py diff --git a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py index e2d4aea587..05d678e41b 100644 --- a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py +++ b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py @@ -23,6 +23,8 @@ def _fwd_kernel( tp_text_end_token_id, hidden_size, tp_world_size, + APPLY_TEXT_EMBED_SCALE: tl.constexpr, + TEXT_EMBED_SCALE: tl.constexpr, BLOCK_HIDDEN_DIM: tl.constexpr, ): @@ -43,6 +45,8 @@ def _fwd_kernel( mask=off_d < hidden_size, other=0, ) + if APPLY_TEXT_EMBED_SCALE: + load_emb *= TEXT_EMBED_SCALE tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, load_emb, mask=off_d < hidden_size) img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0) @@ -84,9 +88,12 @@ def multimodal_emb( tp_text_start_token_id: int, tp_text_end_token_id: int, tp_world_size: int, + text_embed_scale: float = 1.0, ): total_len = prompt_ids.shape[0] BLOCK = triton.next_power_of_2(out.shape[1]) + text_embed_scale = float(text_embed_scale) + apply_text_embed_scale = text_embed_scale != 1.0 # print(len(img_token_lens)) grid = (total_len, len(img_token_lens) + 1) num_warps = 1 @@ -109,6 +116,8 @@ def multimodal_emb( tp_text_end_token_id=tp_text_end_token_id, hidden_size=out.shape[1], tp_world_size=float(tp_world_size), + APPLY_TEXT_EMBED_SCALE=apply_text_embed_scale, + TEXT_EMBED_SCALE=text_embed_scale, BLOCK_HIDDEN_DIM=BLOCK, num_warps=num_warps, 
num_stages=1, diff --git a/lightllm/models/gemma4/gemma4_visual.py b/lightllm/models/gemma4/gemma4_visual.py new file mode 100644 index 0000000000..7ed64108b3 --- /dev/null +++ b/lightllm/models/gemma4/gemma4_visual.py @@ -0,0 +1,146 @@ +import json +import os +from io import BytesIO +from typing import List + +import torch +from PIL import Image +from safetensors import safe_open +from transformers import AutoConfig, AutoProcessor + +from lightllm.server.embed_cache.utils import get_shm_name_data, read_shm +from lightllm.server.multimodal_params import ImageItem +from lightllm.utils.log_utils import init_logger +from lightllm.utils.torch_dtype_utils import get_torch_dtype + + +logger = init_logger(__name__) + + +class Gemma4VisionModel: + def __init__(self, data_type="bfloat16"): + self.vision_tower = None + self.embed_vision = None + self.image_processor = None + self.data_type = data_type if isinstance(data_type, torch.dtype) else get_torch_dtype(data_type) + self.device = torch.device("cpu") + + def _weight_files(self, weight_dir): + index_path = os.path.join(weight_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + weight_map = json.load(f)["weight_map"] + return sorted(set(weight_map.values())) + return sorted(f for f in os.listdir(weight_dir) if f.endswith(".safetensors")) + + def _load_prefix_state_dict(self, weight_dir, prefix): + state_dict = {} + for file_name in self._weight_files(weight_dir): + file_path = os.path.join(weight_dir, file_name) + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith(prefix): + state_dict[key[len(prefix) :]] = f.get_tensor(key) + return state_dict + + def load_model(self, weight_dir): + try: + from transformers.models.gemma4.modeling_gemma4 import ( + Gemma4MultimodalEmbedder, + Gemma4VisionModel as HFGemma4VisionModel, + ) + except ImportError as e: + raise ImportError("Gemma-4 vision requires a transformers build with Gemma4 support.") from e + + config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True) + if config.vision_config is None: + raise ValueError("Gemma-4 checkpoint does not contain vision_config") + + processor = AutoProcessor.from_pretrained(weight_dir) + self.image_processor = processor.image_processor + self.vision_tower = HFGemma4VisionModel(config.vision_config).eval() + self.embed_vision = Gemma4MultimodalEmbedder(config.vision_config, config.text_config).eval() + + vision_state = self._load_prefix_state_dict(weight_dir, "model.vision_tower.") + embed_state = self._load_prefix_state_dict(weight_dir, "model.embed_vision.") + missing, unexpected = self.vision_tower.load_state_dict(vision_state, strict=False) + if missing or unexpected: + raise RuntimeError(f"Gemma-4 vision_tower weight mismatch: missing={missing}, unexpected={unexpected}") + missing, unexpected = self.embed_vision.load_state_dict(embed_state, strict=False) + if missing or unexpected: + raise RuntimeError(f"Gemma-4 embed_vision weight mismatch: missing={missing}, unexpected={unexpected}") + + return self + + def cuda(self): + self.device = torch.device("cuda") + self.vision_tower = self.vision_tower.cuda() + self.embed_vision = self.embed_vision.cuda() + return self + + def forward(self, pixel_values, image_position_ids): + pixel_values = pixel_values.to(self.device, non_blocking=True) + image_position_ids = image_position_ids.to(self.device, non_blocking=True) + pooling_k = self.vision_tower.config.pooling_kernel_size + pooling_k2 = pooling_k * 
pooling_k + + # Per-image vision-tower call. `output_length` MUST match the per-image + # num_soft_tokens the image processor declared; otherwise HF's pooler + # falls back to config.image_seq_length and silently emits a different + # token count than what `valid_ids` expects. + per_image_hidden = [] + for i in range(pixel_values.shape[0]): + pv = pixel_values[i : i + 1] + pp = image_position_ids[i : i + 1] + output_length = pv.shape[1] // pooling_k2 + per_image_hidden.append( + self.vision_tower( + pixel_values=pv, + pixel_position_ids=pp, + output_length=output_length, + ).last_hidden_state + ) + + # embed_vision is token-independent (RMSNorm + Linear); cat once and + # project once instead of looping like vllm — same numerics, fewer + # Python launches, lines up naturally with our flat embed-cache output. + flat_hidden = torch.cat(per_image_hidden, dim=0) + target_dtype = self.embed_vision.embedding_projection.weight.dtype + image_features = self.embed_vision(inputs_embeds=flat_hidden.unsqueeze(0).to(target_dtype)).squeeze(0) + return image_features.to(self.data_type) + + @torch.inference_mode() + def encode(self, images: List[ImageItem]): + pil_images = [] + uuids = [] + for img in images: + if not isinstance(img, ImageItem): + raise TypeError(f"Unsupported Gemma-4 image input type: {type(img)}") + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + with Image.open(BytesIO(image_data)) as image: + pil_images.append(image.convert("RGB")) + + if not pil_images: + return None + + image_inputs = self.image_processor(pil_images, return_tensors="pt") + token_nums = image_inputs.pop("num_soft_tokens_per_image") + pixel_values = image_inputs["pixel_values"] + image_position_ids = image_inputs["image_position_ids"] + + valid_ids = [] + valid_start = 0 + for img, token_num in zip(images, token_nums): + token_num = int(token_num) + if img.token_num != token_num: + raise ValueError(f"Gemma-4 image token mismatch: allocated={img.token_num}, encoded={token_num}") + valid_ids.append([valid_start, valid_start + token_num]) + valid_start += token_num + + all_img_embeds = self.forward(pixel_values, image_position_ids) + if all_img_embeds.shape[0] != valid_start: + raise ValueError( + f"Gemma-4 image embed length mismatch: embeds={all_img_embeds.shape[0]}, tokens={valid_start}" + ) + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/gemma4/infer_struct.py b/lightllm/models/gemma4/infer_struct.py index 703ee6c68a..81a0b236b4 100644 --- a/lightllm/models/gemma4/infer_struct.py +++ b/lightllm/models/gemma4/infer_struct.py @@ -1,5 +1,6 @@ import torch from lightllm.common.basemodel import InferStateInfo +from lightllm.models.gemma4.triton_kernel.build_b_image_token_end import build_b_image_token_end class Gemma4InferStateInfo(InferStateInfo): @@ -15,6 +16,11 @@ def __init__(self): # E-series only: per-layer embeddings (PLE), shape (N, num_layers, hidden_size_per_layer_input). # Computed once in Gemma4PreLayerInfer; sliced per layer in the transformer block. self.per_layer_embeds = None + # Per-Q image-bidi end markers, shape (sum_new_tokens,) int32. + # 0 for non-image Q tokens; image-span end (in absolute request position) + # for tokens inside an image span. Consumed by + # `context_attention_fwd_gemma4_mm` on sliding-window layers. 
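# A worked example of these markers (assumed values): a request with no prompt
# cache whose spliced ids (see Gemma4Tokenizer.encode) place a 4-token image at
# positions 5..8, i.e. img.start_idx = 5, img.token_num = 4:
#   request position : 0 1 2 3 4 5 6 7 8 9
#   b_image_token_end: 0 0 0 0 0 9 9 9 9 0
# Every Q token inside the span carries the span's exclusive end (5 + 4 = 9),
# so the kernel lets it attend to every K with k_pos < 9; text tokens keep 0
# and stay strictly causal.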
+ self.b_image_token_end = None def init_some_extra_state(self, model): super().init_some_extra_state(model) @@ -33,4 +39,45 @@ def init_some_extra_state(self, model): ) if self.is_prefill: self.max_seq_len = self.max_kv_seq_len + self._build_b_image_token_end() return + + def _build_b_image_token_end(self): + # Scatter per-image end markers into a flat (sum_q,) int32 tensor for + # consumption by the image-aware prefill attention kernel. Style mirrors + # neo_chat_moe.get_neo_position. Chunked-prefill clipping (image span + # straddling cache/new boundary) is handled inside the kernel. + if not self.multimodal_params: + self.b_image_token_end = None + return + + b_image_start_idx = [] + b_image_len = [] + b_image_nums = [] + b_image_start_num = [] + image_start_num = 0 + for params in self.multimodal_params: + b_image_start_num.append(image_start_num) + images = params.get("images", []) + b_image_nums.append(len(images)) + for img in images: + b_image_start_idx.append(img["start_idx"]) + b_image_len.append(img["token_num"]) + image_start_num += 1 + + if image_start_num == 0: + self.b_image_token_end = None + return + + device = self.position_ids.device + self.b_image_token_end = torch.zeros(self.position_ids.shape[0], dtype=torch.int32, device=device) + build_b_image_token_end( + b_image_start_idx=torch.tensor(b_image_start_idx, dtype=torch.int32).cuda(non_blocking=True), + b_image_len=torch.tensor(b_image_len, dtype=torch.int32).cuda(non_blocking=True), + b_image_nums=torch.tensor(b_image_nums, dtype=torch.int32).cuda(non_blocking=True), + b_image_start_num=torch.tensor(b_image_start_num, dtype=torch.int32).cuda(non_blocking=True), + b_q_start_loc=self.b_q_start_loc, + b_ready_cache_len=self.b_ready_cache_len, + b_q_seq_len=self.b_q_seq_len, + b_image_token_end=self.b_image_token_end, + ) diff --git a/lightllm/models/gemma4/layer_infer/pre_layer_infer.py b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py index e5e1507c96..98c29c1cc2 100644 --- a/lightllm/models/gemma4/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py @@ -4,13 +4,16 @@ from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy from lightllm.distributed.communication_op import all_reduce from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer from lightllm.utils.envs_utils import get_env_start_args -class Gemma4PreLayerInfer(LlamaPreLayerInfer): +class Gemma4PreLayerInfer(LlamaMultimodalPreLayerInfer): def __init__(self, network_config): super().__init__(network_config) self.embed_scale = float(network_config["hidden_size"]) ** 0.5 + self.multimodal_text_embed_scale_ = self.embed_scale + self.pad_token_id_ = network_config.get("pad_token_id", 0) self.has_ple = bool(network_config.get("hidden_size_per_layer_input")) if self.has_ple: @@ -21,8 +24,8 @@ def __init__(self, network_config): self.ple_combine_scale_ = 2.0 ** -0.5 self.rms_norm_eps_ = network_config.get("rms_norm_eps", 1e-6) - def _compute_per_layer_embeds(self, input_ids, input_embdings, infer_state, layer_weight): - ple_embeds = layer_weight.embed_tokens_per_layer_weight_(input_ids) + def _compute_per_layer_embeds(self, input_ids_for_ple, input_embdings, infer_state, layer_weight): + ple_embeds = layer_weight.embed_tokens_per_layer_weight_(input_ids_for_ple) if self.tp_world_size_ > 1: all_reduce(ple_embeds, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) 
ple_embeds = ple_embeds * self.ple_embed_scale_ @@ -38,15 +41,38 @@ def _compute_per_layer_embeds(self, input_ids, input_embdings, infer_state, laye infer_state.per_layer_embeds = (ple_proj + ple_embeds) * self.ple_combine_scale_ def context_forward(self, input_ids, infer_state, layer_weight): - input_embdings = super().context_forward(input_ids, infer_state, layer_weight) - input_embdings = input_embdings * self.embed_scale + input_embdings = LlamaMultimodalPreLayerInfer.context_forward(self, input_ids, infer_state, layer_weight) if self.has_ple: - self._compute_per_layer_embeds(input_ids, input_embdings, infer_state, layer_weight) + image_token_end = getattr(infer_state, "b_image_token_end", None) + input_ids_for_ple = ( + input_ids + if image_token_end is None + else input_ids.masked_fill(image_token_end != 0, self.pad_token_id_) + ) + self._compute_per_layer_embeds(input_ids_for_ple, input_embdings, infer_state, layer_weight) return input_embdings def token_forward(self, input_ids, infer_state, layer_weight): - input_embdings = super().token_forward(input_ids, infer_state, layer_weight) + input_embdings = LlamaPreLayerInfer.token_forward(self, input_ids, infer_state, layer_weight) input_embdings = input_embdings * self.embed_scale if self.has_ple: self._compute_per_layer_embeds(input_ids, input_embdings, infer_state, layer_weight) return input_embdings + + def _tpsp_sp_split(self, input: torch.Tensor, infer_state): + if self.tp_world_size_ > 1 and get_env_start_args().enable_tpsp_mix_mode: + input = super()._tpsp_sp_split(input=input, infer_state=infer_state) + if self.has_ple and infer_state.per_layer_embeds is not None: + ple_shape = infer_state.per_layer_embeds.shape + per_layer_embeds = infer_state.per_layer_embeds.reshape(ple_shape[0], -1) + per_layer_embeds = sp_pad_copy( + in_tensor=per_layer_embeds, + sp_rank_id=self.tp_rank_, + sp_world_size=self.tp_world_size_, + alloc_func=self.alloc_tensor, + ) + infer_state.per_layer_embeds = per_layer_embeds.reshape( + per_layer_embeds.shape[0], ple_shape[1], ple_shape[2] + ) + return input + return input diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 3274c87b7f..b290f1ce84 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -5,6 +5,9 @@ from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.common.basemodel.infer_struct import InferStateInfo from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight +from lightllm.models.gemma4.triton_kernel.context_attention_fwd_gemma4_mm import ( + context_attention_fwd_gemma4_mm, +) from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd @@ -228,6 +231,32 @@ def _context_attention_kernel( ) -> torch.Tensor: _k, _v = self._get_layer_kv(infer_state) _q = q.view(-1, self.tp_q_head_num_, self.layer_head_dim_) + # Image bidirectional attention only applies on sliding-window layers + # (matches HF/vllm `use_bidirectional_attention="vision"`). Full-attn + # layers stay on the standard causal triton path. 
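# A small torch sketch of the mask rule this dispatch enables (it mirrors the
# reference_attention in the new kernel file), assuming a 10-token request with
# no prompt cache and an image span at positions 5..8:
import torch
q_pos = torch.arange(10)
k_pos = torch.arange(10)
q_image_end = torch.zeros(10, dtype=torch.long)
q_image_end[5:9] = 9                          # image span [5, 9)
causal = k_pos[None, :] <= q_pos[:, None]
image = k_pos[None, :] < q_image_end[:, None]
allow = causal | image
# Text rows (q_image_end == 0) are unchanged; rows 5..8 additionally see
# columns 5..8, i.e. fully bidirectional attention inside the image span.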
+ if ( + self.is_sliding + and self.network_config_.get("_gemma4_use_swa", False) + and getattr(infer_state, "b_image_token_end", None) is not None + ): + o_tensor = self.alloc_tensor(_q.shape, q.dtype) + sw = self.sliding_window_ if self.sliding_window_ > 0 else 0 + context_attention_fwd_gemma4_mm( + _q, + _k, + _v, + o_tensor, + infer_state.b_req_idx, + infer_state.b_q_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_q_seq_len, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_end, + sliding_window=sw, + ) + return o_tensor.view(q.shape) + # Sliding layers go through the secondary backend (FA3 with SWA when # available, else triton-with-SWA from path B). Full-attn layers go # through the primary triton backend (head_dim=512). diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index c900261c35..a2d3661f28 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -1,3 +1,4 @@ +import math import os import json import torch @@ -22,20 +23,17 @@ class Gemma4Tokenizer(BaseMultiModalTokenizer): - """ - Thin wrapper; Phase-A milestone only exercises the text path. Multimodal - splice logic will be added alongside the vision tower port (Phase B). - """ - - def __init__(self, tokenizer, model_cfg): + def __init__(self, tokenizer, model_cfg, image_processor=None): super().__init__(tokenizer) self.image_token_index = model_cfg.get("image_token_id", 258880) self.boi_token_index = model_cfg.get("boi_token_id", 255999) self.eoi_token_index = model_cfg.get("eoi_token_id", 258882) + self.image_processor = image_processor self.image_length = model_cfg.get("vision_soft_tokens_per_image", 280) - # Gemma-4's tokenizer ships with `add_bos_token=False`, and even - # `add_special_tokens=True` doesn't prepend ``. The model generates - # garbage without it, so we always prepend manually. + self.patch_size = getattr(self.image_processor, "patch_size", 16) + self.pooling_kernel_size = getattr(self.image_processor, "pooling_kernel_size", 3) + self.max_soft_tokens = getattr(self.image_processor, "max_soft_tokens", self.image_length) + # HF Gemma-4 tokenizer does not prepend BOS even with add_special_tokens=True. self.bos_token_id = tokenizer.bos_token_id def init_imageitem_extral_params(self, img, multi_params, sampling_params): @@ -45,32 +43,66 @@ def init_audioitem_extral_params(self, audio, multi_params, sampling_params): raise NotImplementedError def get_image_token_length(self, img): - return self.image_length + if self.image_processor is None or img.image_w <= 0 or img.image_h <= 0: + return self.image_length + + patch, kernel = self.patch_size, self.pooling_kernel_size + unit = patch * kernel + num_patches_orig = (img.image_h / patch) * (img.image_w / patch) + scale = math.sqrt(self.max_soft_tokens * kernel ** 2 / num_patches_orig) + target_h = max(unit, int(math.floor(img.image_h * scale / unit)) * unit) + target_w = max(unit, int(math.floor(img.image_w * scale / unit)) * unit) + num_patches = (target_h // patch) * (target_w // patch) + return min(num_patches // kernel ** 2, self.max_soft_tokens) def get_audio_token_length(self, audio): raise NotImplementedError def encode(self, prompt, multimodal_params=None, add_special_tokens=False): - # Text-only path for Phase A — reject image/audio input loudly so users - # know multimodal isn't wired yet. 
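# A standalone rendering of the get_image_token_length arithmetic above, with
# assumed processor values (patch_size=16, pooling_kernel_size=3,
# max_soft_tokens=256) and an assumed 1024x768 input image.
import math
patch, kernel, max_soft = 16, 3, 256
h, w = 768, 1024
unit = patch * kernel                                            # 48
num_patches_orig = (h / patch) * (w / patch)                     # 3072.0
scale = math.sqrt(max_soft * kernel ** 2 / num_patches_orig)     # ~0.866
target_h = max(unit, int(math.floor(h * scale / unit)) * unit)   # 624
target_w = max(unit, int(math.floor(w * scale / unit)) * unit)   # 864
num_patches = (target_h // patch) * (target_w // patch)          # 39 * 54 = 2106
tokens = min(num_patches // kernel ** 2, max_soft)               # 234 soft tokens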
- if multimodal_params is not None and ( - getattr(multimodal_params, "images", None) or getattr(multimodal_params, "audios", None) - ): - raise NotImplementedError( - "Gemma-4 multimodal (image/audio) inference is not yet implemented in LightLLM; " - "only text prompts are supported for now." - ) - input_ids = self.tokenizer(prompt).input_ids - # Auto-prepend for prompts (Gemma-4 generates garbage without it), - # but honour `add_special_tokens=False` so callers like stop-sequence - # encoding can opt out — otherwise stop strings get a leading BOS that - # never appears in generated output and never matches. + origin_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids if ( add_special_tokens and self.bos_token_id is not None - and (len(input_ids) == 0 or input_ids[0] != self.bos_token_id) + and (len(origin_ids) == 0 or origin_ids[0] != self.bos_token_id) ): - input_ids = [self.bos_token_id] + input_ids + origin_ids = [self.bos_token_id] + origin_ids + + images = [] if multimodal_params is None else getattr(multimodal_params, "images", []) + if not images: + return origin_ids + + input_ids = [] + image_id = 0 + start = 0 + while True: + try: + image_start = origin_ids.index(self.image_token_index, start) + except ValueError: + break + + input_ids.extend(origin_ids[start:image_start]) + image_end = image_start + 1 + while image_end < len(origin_ids) and origin_ids[image_end] == self.image_token_index: + image_end += 1 + if image_id >= len(images): + raise ValueError("image token error") + + img = images[image_id] + if not input_ids or input_ids[-1] != self.boi_token_index: + input_ids.append(self.boi_token_index) + img.start_idx = len(input_ids) + input_ids.extend(range(img.token_id, img.token_id + img.token_num)) + input_ids.append(self.eoi_token_index) + + if image_end < len(origin_ids) and origin_ids[image_end] == self.eoi_token_index: + image_end += 1 + start = image_end + image_id += 1 + + input_ids.extend(origin_ids[start:]) + image_cnt = len(images) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") return input_ids diff --git a/lightllm/models/gemma4/triton_kernel/__init__.py b/lightllm/models/gemma4/triton_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/gemma4/triton_kernel/build_b_image_token_end.py b/lightllm/models/gemma4/triton_kernel/build_b_image_token_end.py new file mode 100644 index 0000000000..bb5f383611 --- /dev/null +++ b/lightllm/models/gemma4/triton_kernel/build_b_image_token_end.py @@ -0,0 +1,172 @@ +"""GPU-resident builder for ``b_image_token_end``. + +Replaces a 3× D2H sync + Python per-batch-image slice-fill in CPU memory +with a single small H2D copy (image metadata) + one Triton kernel that +scatters the image-end markers into the flat-Q-token tensor on GPU. + +Adapted from neo_chat_moe's `get_neo_position_triton`. Same per-batch +program structure; we only emit the `b_image_token_end` scatter (no 3D +position_ids — gemma-4 uses 1D position ids). 
+""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _build_b_image_token_end_kernel( + B_Image_Start_Idx, # (num_imgs,) int32, image span start in absolute request position + B_Image_Len, # (num_imgs,) int32, image token count + B_Image_Nums, # (batch,) int32, per-batch image count + B_Image_Start_Num, # (batch,) int32, prefix-sum offset into flat per-image arrays + B_Q_Start_Loc, # (batch,) int32, per-batch start in flat layout + B_Ready_Cache_Len, # (batch,) int32, per-batch prompt-cache length + B_Q_Seq_Len, # (batch,) int32, per-batch new-token count + B_Image_Token_End, # (sum_q,) int32, output scatter target + BLOCK_SIZE: tl.constexpr, +): + cur_batch = tl.program_id(0) + cache_len = tl.load(B_Ready_Cache_Len + cur_batch) + q_seq_len = tl.load(B_Q_Seq_Len + cur_batch) + image_num = tl.load(B_Image_Nums + cur_batch) + image_start_num = tl.load(B_Image_Start_Num + cur_batch) + flat_start = tl.load(B_Q_Start_Loc + cur_batch) + + for i in range(image_num): + image_start_idx = tl.load(B_Image_Start_Idx + image_start_num + i) + image_len = tl.load(B_Image_Len + image_start_num + i) + image_end_idx = image_start_idx + image_len + # Flat layout offset of the image's first token within this batch. + flat_image_start = flat_start + image_start_idx - cache_len + + for j in range(0, image_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + in_image = off < image_len + # Only fill positions that fall inside this batch's NEW-tokens range + # (i.e., the part of the image that hasn't already been processed + # in a previous chunked-prefill chunk and isn't past the chunk's end). + in_new_tokens = (image_start_idx - cache_len + off >= 0) & (image_start_idx - cache_len + off < q_seq_len) + tl.store( + B_Image_Token_End + flat_image_start + off, + image_end_idx, + mask=in_image & in_new_tokens, + ) + + +def build_b_image_token_end( + b_image_start_idx: torch.Tensor, + b_image_len: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_q_start_loc: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_image_token_end: torch.Tensor, +): + batch_size = b_q_start_loc.shape[0] + assert b_image_nums.shape[0] == batch_size + grid = (batch_size,) + BLOCK_SIZE = 64 + _build_b_image_token_end_kernel[grid]( + b_image_start_idx, + b_image_len, + b_image_nums, + b_image_start_num, + b_q_start_loc, + b_ready_cache_len, + b_q_seq_len, + b_image_token_end, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +# --------------------------------------------------------------------------- +# Standalone correctness check +# --------------------------------------------------------------------------- + + +def _reference( + multimodal_params, + b_q_start_loc_cpu, + b_ready_cache_len_cpu, + b_q_seq_len_cpu, + sum_q, +): + out = torch.zeros((sum_q,), dtype=torch.int32) + for batch_idx, params in enumerate(multimodal_params): + cache_len = b_ready_cache_len_cpu[batch_idx] + new_len = b_q_seq_len_cpu[batch_idx] + flat_start = b_q_start_loc_cpu[batch_idx] + for img in params.get("images", []): + image_start_idx = img["start_idx"] + image_end_idx = image_start_idx + img["token_num"] + for j in range(img["token_num"]): + req_off = image_start_idx - cache_len + j + if req_off < 0 or req_off >= new_len: + continue + out[flat_start + req_off] = image_end_idx + return out + + +def _check(): + device = "cuda" + # Two batches. b0 has 1 image overlapping new tokens; b1 has 2 images, one + # fully cached and one in the new-token range. 
+ multimodal = [ + {"images": [{"start_idx": 5, "token_num": 4}]}, # b0: image at req[5..9) + { + "images": [ + {"start_idx": 0, "token_num": 3}, # fully cached + {"start_idx": 8, "token_num": 5}, # in new tokens + ] + }, + ] + b_q_start_loc = torch.tensor([0, 6], dtype=torch.int32) # b0 new=6, b1 new=10 + b_ready_cache_len = torch.tensor([2, 5], dtype=torch.int32) + b_q_seq_len = torch.tensor([6, 10], dtype=torch.int32) + sum_q = int(b_q_seq_len.sum().item()) + + ref = _reference( + multimodal, + b_q_start_loc.tolist(), + b_ready_cache_len.tolist(), + b_q_seq_len.tolist(), + sum_q, + ) + + b_image_start_idx = [] + b_image_len = [] + b_image_nums = [] + b_image_start_num = [] + image_start_num = 0 + for params in multimodal: + b_image_start_num.append(image_start_num) + b_image_nums.append(len(params["images"])) + for img in params["images"]: + b_image_start_idx.append(img["start_idx"]) + b_image_len.append(img["token_num"]) + image_start_num += 1 + + out_gpu = torch.zeros((sum_q,), dtype=torch.int32, device=device) + build_b_image_token_end( + b_image_start_idx=torch.tensor(b_image_start_idx, dtype=torch.int32, device=device), + b_image_len=torch.tensor(b_image_len, dtype=torch.int32, device=device), + b_image_nums=torch.tensor(b_image_nums, dtype=torch.int32, device=device), + b_image_start_num=torch.tensor(b_image_start_num, dtype=torch.int32, device=device), + b_q_start_loc=b_q_start_loc.to(device), + b_ready_cache_len=b_ready_cache_len.to(device), + b_q_seq_len=b_q_seq_len.to(device), + b_image_token_end=out_gpu, + ) + + out_cpu = out_gpu.cpu() + assert torch.equal(out_cpu, ref), f"\n got {out_cpu.tolist()}\n ref {ref.tolist()}" + print("ok", out_cpu.tolist()) + + +if __name__ == "__main__": + if torch.cuda.is_available(): + _check() + else: + print("No CUDA, skip.") diff --git a/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py b/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py new file mode 100644 index 0000000000..b8aa2199a3 --- /dev/null +++ b/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py @@ -0,0 +1,460 @@ +"""Gemma-4 prefill attention kernel with image bidirectional masking. + +Gemma-4 was trained with bidirectional attention inside each image span on its +sliding-window layers (matches HF/vllm `use_bidirectional_attention="vision"`). +Other lightllm multimodal models use causal attention on image tokens, so the +shared prefill kernel does not need this — keep the modification scoped to +this gemma4-private file rather than the common path. + +The kernel mirrors `context_flashattention_nopad._fwd_kernel` (paged KV via +req_to_token_indexs, prompt_cache_len for chunked prefill, sliding window +support, head_dim=256/512 with BLOCK_M reduction) and adds two ideas borrowed +from `lightllm-neo/.../context_attention_fwd_neo`: + +1. Per-Q `b_image_token_end` tensor of shape (sum_q,). For Q tokens inside an + image span it carries the span's end index; for text tokens it is 0. + The attention mask becomes `causal_mask | (k_pos < q_image_end)`. +2. K/V iteration upper bound is extended to `max(causal_end, block_image_end)` + so a Q tile in the middle of an image span actually loads K/V tiles past + its causal end. Without this, the bidi mask in the original diff was a + no-op on every tile but the last one of the image span. + +The standalone `reference_attention` and `check_once` are runnable as a script +for unit testing image bidi correctness. 
+""" + +import math +import torch +import triton +import triton.language as tl + +from lightllm.utils.device_utils import is_tesla + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + Out, + B_Start_Loc, + B_Seqlen, + Req_to_tokens, + B_req_idx, + B_Image_Token_End, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + kv_group_num, + b_prompt_cache_len, + H: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, +): + start_m = tl.program_id(0) + cur_bh = tl.program_id(1) + cur_batch = cur_bh // H + cur_head = cur_bh % H + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + total_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = total_len - prompt_cache_len # new tokens this step + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + block_start_loc = BLOCK_M * start_m + if block_start_loc >= cur_batch_seq_len: + return + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = block_start_loc + tl.arange(0, BLOCK_M) + q_valid = offs_m < cur_batch_seq_len + + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) + + # Per-Q image_end. 0 for non-image tokens, image-span end for image tokens. + q_image_end = tl.load( + B_Image_Token_End + cur_batch_in_all_start_index + offs_m, + mask=q_valid, + other=0, + ).to(tl.int32) + + # Absolute position in the request (prompt_cache_len + offset within new tokens). + q_pos = prompt_cache_len + offs_m # [M] + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + causal_end = tl.minimum(prompt_cache_len + block_start_loc + BLOCK_M, total_len) + block_image_end = tl.minimum(tl.max(q_image_end, axis=0), total_len) + block_end_loc = tl.maximum(causal_end, block_image_end) + + if SLIDING_WINDOW > 0: + win_start = block_start_loc + prompt_cache_len - (SLIDING_WINDOW - 1) + win_start = tl.maximum(win_start, 0) + win_start = (win_start // BLOCK_N) * BLOCK_N + else: + win_start = 0 + + for start_n in range(win_start, block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_pos = start_n + offs_n # [N] + k_valid = k_pos < block_end_loc + + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_valid, + other=0, + ).to(tl.int64) + + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) + qk = tl.dot(q, k) + + causal_mask = q_pos[:, None] >= k_pos[None, :] + if SLIDING_WINDOW > 0: + causal_mask = causal_mask & ((q_pos[:, None] - k_pos[None, :]) < SLIDING_WINDOW) + # Image bidi: a Q in image span [_, e) attends to all K with k_pos < e. + # For text Q (q_image_end == 0) this is k_pos < 0 = always False, so + # the union with causal_mask leaves text-attention unchanged. 
+ image_mask = k_pos[None, :] < q_image_end[:, None] + mask = (causal_mask | image_mask) & k_valid[None, :] + + qk = tl.where(mask, qk * sm_scale, -1.0e8) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + tl.store(Out + off_o, acc, mask=q_valid[:, None]) + + +@torch.no_grad() +def context_attention_fwd_gemma4_mm( + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, + b_image_token_end, + sliding_window: int = 0, +): + """Prefill attention with image bidirectional masking on sliding layers. + + Args: + b_image_token_end: int32 tensor of shape (sum_q,). For each Q token + position (in the flattened new-token layout), value is the image + span's end index (in absolute request position) if the token is + inside an image span, else 0. + """ + BLOCK_M = 128 if not is_tesla() else 64 + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128, 256, 512} + if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) + + sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + BLOCK_N = BLOCK_M + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + b_image_token_end, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + H=head, + BLOCK_DMODEL=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + SLIDING_WINDOW=int(sliding_window), + num_warps=num_warps, + num_stages=num_stages, + ) + + +# --------------------------------------------------------------------------- +# Reference implementation + standalone test harness +# --------------------------------------------------------------------------- + + +def reference_attention( + q, + k, + v, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + b_image_token_end, + sliding_window=0, +): + """Slow torch reference for the gemma4 mm prefill kernel.""" + device = q.device + dtype = q.dtype + sum_q, Hq, D = q.shape + Hk = k.shape[1] + kv_group_num = Hq // Hk + + out = torch.empty_like(q) + scale = 1.0 / math.sqrt(D) + + batch = b_seq_len.shape[0] + for b in range(batch): + req = int(b_req_idx[b].item()) + total_len = int(b_seq_len[b].item()) + prompt_len = int(b_prompt_cache_len[b].item()) + new_len = total_len - prompt_len + q_start = int(b_start_loc[b].item()) + + q_blk = q[q_start : q_start + new_len] # [M, Hq, D] + q_image_end = b_image_token_end[q_start : q_start + new_len].to(torch.int64) # [M] + + 
token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) + k_blk = k[token_locs] + v_blk = v[token_locs] + + k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) + v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) + + q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) + k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) + + causal = k_pos[None, :] <= q_pos[:, None] + if sliding_window > 0: + causal = causal & ((q_pos[:, None] - k_pos[None, :]) < sliding_window) + image = k_pos[None, :] < q_image_end[:, None] + allow = causal | image + + q_t = q_blk.permute(1, 0, 2).to(torch.float32) + k_t = k_hq.permute(1, 2, 0).to(torch.float32) + scores = torch.matmul(q_t, k_t) * scale + + neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) + scores = torch.where(allow[None, :, :], scores, neg) + p = torch.softmax(scores, dim=-1) + v_t = v_hq.permute(1, 0, 2).to(torch.float32) + out_hq = torch.matmul(p, v_t) + out[q_start : q_start + new_len] = out_hq.permute(1, 0, 2).to(dtype) + + return out + + +def make_test_case( + device="cuda", + dtype=torch.bfloat16, + batch=3, + Hq=8, + Hk=4, + D=256, + seed=0, + base_index=50000, + sliding_window=0, +): + torch.manual_seed(seed) + + prompt_lens = torch.randint(low=0, high=8, size=(batch,), device=device) + new_lens = torch.randint(low=4, high=24, size=(batch,), device=device) + total_lens = (prompt_lens + new_lens).to(torch.int32) + max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) + + b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) + cur = 0 + for b in range(batch): + b_start_loc[b] = cur + cur += int(new_lens[b].item()) + sum_q = cur + + b_seq_len = total_lens + b_prompt_cache_len = prompt_lens.to(torch.int32) + b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) + + sum_kv = int(total_lens.sum().item()) + kv_size = base_index + sum_kv + 1024 + pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index + + req_to_token_indexs = torch.zeros((batch, max_total_len), device=device, dtype=torch.int32) + p = 0 + for r in range(batch): + L = int(total_lens[r].item()) + req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) + p += L + + # Inject one image span per batch into the new-token region with prob 0.7. 
+ b_image_token_end = torch.zeros((sum_q,), device=device, dtype=torch.int32) + for b in range(batch): + M = int(new_lens[b].item()) + P = int(prompt_lens[b].item()) + start = int(b_start_loc[b].item()) + if M >= 4 and torch.rand((), device=device).item() > 0.3: + s = int(torch.randint(0, M - 2, (1,), device=device).item()) + span_len = int(torch.randint(2, max(3, M - s + 1), (1,), device=device).item()) + e = min(M, s + span_len) + # image_end is absolute (request-position) = prompt_len + new-offset + b_image_token_end[start + s : start + e] = P + e + + q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) + k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) + + return ( + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + b_image_token_end, + sliding_window, + ) + + +def check_once(seed=0, dtype=torch.bfloat16, sliding_window=0, D=256): + case = make_test_case(seed=seed, dtype=dtype, sliding_window=sliding_window, D=D) + ( + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + b_image_token_end, + sliding_window, + ) = case + + context_attention_fwd_gemma4_mm( + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + b_image_token_end, + sliding_window=sliding_window, + ) + + ref = reference_attention( + q, + k, + v, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + b_image_token_end, + sliding_window=sliding_window, + ) + + diff = (o - ref).abs() + max_abs = diff.max().item() + denom = ref.abs().max().item() + 1e-6 + max_rel = max_abs / denom + has_image = (b_image_token_end > 0).any().item() + print( + f"seed={seed} dtype={dtype} D={D} sw={sliding_window} has_image={has_image} " + f"max_abs={max_abs:.4e} max_rel={max_rel:.4e}" + ) + assert max_abs < 5e-2, f"max_abs too large: {max_abs}" + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + # Vary D, sliding window, and image presence. 
+ for seed in (0, 1, 2): + check_once(seed=seed, D=128, sliding_window=0) + check_once(seed=seed, D=128, sliding_window=4096) + check_once(seed=seed, D=256, sliding_window=4096) + print("ok") diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 9b9fe2569c..ce09632d2c 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -81,6 +81,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei tp_text_start_token_id=layer_weight.wte_weight_.tp_vocab_start_id, tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, tp_world_size=self.tp_world_size_, + text_embed_scale=getattr(self, "multimodal_text_embed_scale_", 1.0), ) if self.tp_world_size_ > 1: all_reduce(out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 9c9f19e73f..c353ee6d35 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -132,6 +132,12 @@ def get_tokenizer( elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) elif model_type == "gemma4": - tokenizer = Gemma4Tokenizer(tokenizer, model_cfg) + image_processor = None + if "vision_config" in model_cfg and model_cfg["vision_config"] is not None: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(tokenizer_name) + image_processor = processor.image_processor + tokenizer = Gemma4Tokenizer(tokenizer, model_cfg, image_processor=image_processor) return tokenizer diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 55f4704a31..35021178b7 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -13,6 +13,7 @@ from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel +from lightllm.models.gemma4.gemma4_visual import Gemma4VisionModel from lightllm.models.vit.model import VisionTransformer from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel @@ -97,6 +98,8 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() + elif self.model_type == "gemma4": + self.model = Gemma4VisionModel(data_type=kvargs["data_type"]) elif ( model_cfg.get("thinker_config", {}).get("vision_config", {}).get("model_type") == "qwen3_omni_moe_vision_encoder" diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 0c64520b2a..3368bdafd6 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -195,6 +195,9 @@ def has_vision_module(model_path: str) -> bool: return True elif model_type == "gemma3": return True + elif model_type == "gemma4": + model_cfg["vision_config"] + return model_cfg["vision_config"] is not None elif ( model_cfg.get("thinker_config", {}).get("vision_config", {}).get("model_type") == "qwen3_omni_moe_vision_encoder" From 08f066dfd0e1512397acdb5df80ca0a2dc8ace6a Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Tue, 12 May 2026 04:45:25 +0000 Subject: [PATCH 06/20] optimize sliding window 
--- lightllm/common/basemodel/attention/fa3/fp.py | 16 ++++--- .../common/basemodel/attention/triton/fp.py | 15 ++----- .../gqa/flash_decoding/gqa_flash_decoding.py | 3 +- .../gqa_flash_decoding_stage1.py | 43 +++++++----------- .../gqa_flash_decoding_stage2.py | 11 ++++- .../context_flashattention_nopad.py | 44 ++++++++++--------- .../context_attention_fwd_gemma4_mm.py | 44 +++++++++++-------- 7 files changed, 91 insertions(+), 85 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..e1f6959795 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -79,10 +79,12 @@ def _nomarl_prefill_att( ) -> torch.Tensor: self.backend: Fa3AttBackend = self.backend # for typing + window_size = (-1, -1) if att_control.use_sliding_window: - window_size = att_control.sliding_window - else: - window_size = (-1, -1) + left, right = att_control.sliding_window + left = max(int(left) - 1, 0) if left >= 0 else -1 + right = max(int(right) - 1, 0) if right >= 0 else -1 + window_size = (left, right) if att_control.use_att_sink: sink_weight: torch.Tensor = att_control.sink_weight @@ -209,10 +211,12 @@ def _normal_decode_att( att_control: AttControl, alloc_func=torch.empty, ): + window_size = (-1, -1) if att_control.use_sliding_window: - window_size = att_control.sliding_window - else: - window_size = (-1, -1) + left, right = att_control.sliding_window + left = max(int(left) - 1, 0) if left >= 0 else -1 + right = max(int(right) - 1, 0) if right >= 0 else -1 + window_size = (left, right) if att_control.use_att_sink: sink_weight: torch.Tensor = att_control.sink_weight diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index a8f3c4414b..1902960769 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -25,7 +25,6 @@ def prefill_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_att_sink is False if att_control.use_alibi: assert att_control.use_sliding_window is False, "alibi + sliding_window not supported" assert att_control.tp_alibi is not None @@ -70,13 +69,10 @@ def _nomarl_prefill_att( ): from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd - # Convert AttControl's (left, right) tuple to a single window length: - # we treat the window as `left + 1` (each query attends to itself plus - # the previous `left` tokens — same convention FA3 uses with causal=True). if att_control.use_sliding_window: - sliding_window = int(att_control.sliding_window[0]) + 1 + sliding_window = int(att_control.sliding_window[0]) else: - sliding_window = 0 + sliding_window = -1 out = alloc_func(q.shape, q.dtype) context_attention_fwd( @@ -111,7 +107,6 @@ def decode_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ): - assert att_control.use_att_sink is False if att_control.use_alibi: assert att_control.use_sliding_window is False, "alibi + sliding_window not supported" assert att_control.tp_alibi is not None @@ -120,8 +115,6 @@ def decode_att( q_head_num = q.shape[1] k_head_num = k.shape[1] if q_head_num == k_head_num: - # MHA decode path: SWA not yet wired here (only GQA path used by Gemma-4). 
- assert att_control.use_sliding_window is False, "SWA in MHA triton decode not implemented" return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) elif q_head_num > k_head_num: return self._normal_decode_gqa_flash_decoding_att( @@ -193,9 +186,9 @@ def _normal_decode_gqa_flash_decoding_att( ) if att_control.use_sliding_window: - sliding_window = int(att_control.sliding_window[0]) + 1 + sliding_window = int(att_control.sliding_window[0]) else: - sliding_window = 0 + sliding_window = -1 out = alloc_func(q.shape, q.dtype) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py index 979e54272c..55180d7adb 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py @@ -8,7 +8,7 @@ def gqa_token_decode_attention_flash_decoding( cache_v: torch.Tensor, out=None, alloc_tensor_func=torch.empty, - sliding_window: int = 0, + sliding_window: int = -1, ): batch_size = infer_state.batch_size q_head_num, head_dim = q.shape[1], q.shape[2] @@ -53,5 +53,6 @@ def gqa_token_decode_attention_flash_decoding( B_Seqlen=infer_state.b_seq_len, out=o_tensor.view(calcu_shape1), block_seq=BLOCK_SEQ, + sliding_window=sliding_window, ) return o_tensor diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index c5ceb9100b..85348324be 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -39,7 +39,8 @@ def _fwd_kernel_flash_decode_stage1( BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + SLIDING_WINDOW_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -47,17 +48,16 @@ def _fwd_kernel_flash_decode_stage1( grid_block_num = tl.num_programs(2) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + if USE_SLIDING_WINDOW: + kv_start_index = tl.maximum(cur_batch_seq_len - SLIDING_WINDOW_SIZE, 0) + cur_batch_seq_len = cur_batch_seq_len - kv_start_index + else: + kv_start_index = 0 + req_total_block_num = tl.cdiv(cur_batch_seq_len, BLOCK_SEQ) if block_index >= req_total_block_num: return - # Decode: q is at position cur_batch_seq_len - 1; SWA keeps K at positions - # >= cur_batch_seq_len - SLIDING_WINDOW. win_threshold below. - if SLIDING_WINDOW > 0: - win_threshold = cur_batch_seq_len - SLIDING_WINDOW - else: - win_threshold = 0 - cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs @@ -84,12 +84,8 @@ def _fwd_kernel_flash_decode_stage1( for start_n in range(0, block_n_size, 1): offs_n_new = start_n * BLOCK_N + offs_n n_mask = offs_n_new < cur_batch_end_index - if SLIDING_WINDOW > 0: - # Drop K positions that fall outside the sliding window from - # the current query (last token in the sequence). 
- n_mask = n_mask & (offs_n_new >= win_threshold) k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + kv_start_index + offs_n_new, mask=n_mask, other=0, ).to(tl.int64) @@ -122,18 +118,8 @@ def _fwd_kernel_flash_decode_stage1( + offs_d[None, :] ) off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + block_index - if SLIDING_WINDOW > 0: - # When SWA masks out every K this program saw, sum_exp stays 0 and - # acc/sum_exp would be NaN. Store zeros + log_exp_sum=-inf so stage2 - # naturally weights this slot to 0 in the final reduction. - safe_sum = tl.where(sum_exp > 0, sum_exp, 1.0) - out_acc = tl.where(sum_exp[:, None] > 0, acc / safe_sum[:, None], 0.0) - out_log = tl.where(sum_exp > 0, max_logic + tl.log(safe_sum), -float("inf")) - else: - out_acc = acc / sum_exp[:, None] - out_log = max_logic + tl.log(sum_exp) - tl.store(Mid_O + off_mid_o, out_acc) - tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, out_log) + tl.store(Mid_O + off_mid_o, acc / sum_exp[:, None]) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) return @@ -186,7 +172,7 @@ def flash_decode_stage1( mid_out, mid_out_logsumexp, block_seq, - sliding_window: int = 0, + sliding_window: int = -1, run_config: Optional[dict] = None, ): """ """ @@ -208,6 +194,8 @@ def flash_decode_stage1( block_num = mid_out.shape[2] grid = (batch, kv_head_num, block_num) gqa_group_size = q.shape[1] // k.shape[1] + use_sliding_window = sliding_window >= 0 + sliding_window_size = int(sliding_window) if use_sliding_window else 0 _fwd_kernel_flash_decode_stage1[grid]( q, @@ -242,7 +230,8 @@ def flash_decode_stage1( BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK_N, - SLIDING_WINDOW=int(sliding_window), + USE_SLIDING_WINDOW=use_sliding_window, + SLIDING_WINDOW_SIZE=sliding_window_size, num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py index 3abc7dc93b..a7c0db19f4 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py @@ -22,12 +22,17 @@ def _fwd_kernel_flash_decode_stage2( block_num, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + SLIDING_WINDOW_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + if USE_SLIDING_WINDOW: + kv_start_index = tl.maximum(cur_batch_seq_len - SLIDING_WINDOW_SIZE, 0) + cur_batch_seq_len = cur_batch_seq_len - kv_start_index block_num = tl.minimum(tl.cdiv(cur_batch_seq_len, BLOCK_SEQ), block_num) @@ -54,12 +59,14 @@ def _fwd_kernel_flash_decode_stage2( @torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq, sliding_window: int = -1): Lk = mid_out.shape[-1] assert Lk in {16, 32, 64, 128, 256, 512} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) block_num = mid_out.shape[2] + use_sliding_window = sliding_window >= 0 + sliding_window_size = int(sliding_window) if use_sliding_window else 0 
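+    # Convention: a negative sliding_window disables SWA; otherwise it is the
+    # total window size including the current token, so stage2 only reduces
+    # over the mid blocks that cover the last sliding_window positions.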
_fwd_kernel_flash_decode_stage2[grid]( B_Seqlen, @@ -79,6 +86,8 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): block_num, BLOCK_SEQ=block_seq, BLOCK_DMODEL=Lk, + USE_SLIDING_WINDOW=use_sliding_window, + SLIDING_WINDOW_SIZE=sliding_window_size, num_warps=4, num_stages=2, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index d8396c3bd6..f4c3c10ffe 100644 --- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -41,7 +41,8 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + SLIDING_WINDOW_SIZE: tl.constexpr, ): start_m = tl.program_id(0) cur_bh = tl.program_id(1) @@ -61,6 +62,7 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = block_start_loc + tl.arange(0, BLOCK_M) + q_pos = offs_m + prompt_cache_len off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh @@ -77,34 +79,31 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) - # When SLIDING_WINDOW > 0, the earliest k_pos relevant to any q in this - # block is `(block_start_loc + prompt_cache_len) - (SLIDING_WINDOW - 1)`. - # Round down to BLOCK_N to avoid loading blocks fully outside the window. - if SLIDING_WINDOW > 0: - win_start = block_start_loc + prompt_cache_len - (SLIDING_WINDOW - 1) - win_start = tl.maximum(win_start, 0) - win_start = (win_start // BLOCK_N) * BLOCK_N + if USE_SLIDING_WINDOW: + kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_SIZE + 1 + kv_start_index = tl.maximum(kv_start_index, 0) + block_kv_len = block_end_loc - kv_start_index else: - win_start = 0 + kv_start_index = 0 + block_kv_len = block_end_loc # causal (+ sliding-window) mask - for start_n in range(win_start, block_mask * block_end_loc, BLOCK_N): + for start_n in range(0, block_mask * block_kv_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) + k_pos = kv_start_index + start_n + offs_n # -- compute qk ---- kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), - mask=(start_n + offs_n) < block_end_loc, + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_pos < block_end_loc, other=0, ).to(tl.int64) off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) + k = tl.load(K + off_k, mask=k_pos[None, :] < block_end_loc, other=0.0) qk = tl.dot(q, k) - # causal: q_pos >= k_pos. q_pos = offs_m + prompt_cache_len, k_pos = start_n + offs_n. 
- mask = offs_m[:, None] + prompt_cache_len >= (start_n + offs_n[None, :]) - if SLIDING_WINDOW > 0: - # SWA: q_pos - k_pos < SLIDING_WINDOW - mask = mask & ((offs_m[:, None] + prompt_cache_len) - (start_n + offs_n[None, :]) < SLIDING_WINDOW) + mask = q_pos[:, None] >= k_pos[None, :] + if USE_SLIDING_WINDOW: + mask = mask & ((q_pos[:, None] - k_pos[None, :]) < SLIDING_WINDOW_SIZE) qk = tl.where(mask, qk * sm_scale, -1.0e8) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -118,7 +117,7 @@ def _fwd_kernel( acc = acc * alpha[:, None] # update acc off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0) + v = tl.load(V + off_v, mask=k_pos[:, None] < block_end_loc, other=0.0) p = p.to(v.dtype) acc = tl.dot(p, v, acc) # update m_i and l_i @@ -146,7 +145,7 @@ def context_attention_fwd( b_prompt_cache_len, max_input_len, req_to_token_indexs, - sliding_window: int = 0, + sliding_window: int = -1, ): BLOCK_M = 128 if not is_tesla() else 64 # shape constraints @@ -172,6 +171,8 @@ def context_attention_fwd( BLOCK_N = BLOCK_M num_warps = 4 if Lk <= 64 else 8 num_stages = 1 + use_sliding_window = sliding_window >= 0 + sliding_window_size = int(sliding_window) if use_sliding_window else 0 _fwd_kernel[grid]( q, @@ -203,7 +204,8 @@ def context_attention_fwd( BLOCK_DMODEL=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - SLIDING_WINDOW=int(sliding_window), + USE_SLIDING_WINDOW=use_sliding_window, + SLIDING_WINDOW_SIZE=sliding_window_size, num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py b/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py index b8aa2199a3..9ecf3938fa 100644 --- a/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py +++ b/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py @@ -63,7 +63,8 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - SLIDING_WINDOW: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + SLIDING_WINDOW_SIZE: tl.constexpr, ): start_m = tl.program_id(0) cur_bh = tl.program_id(1) @@ -112,16 +113,17 @@ def _fwd_kernel( block_image_end = tl.minimum(tl.max(q_image_end, axis=0), total_len) block_end_loc = tl.maximum(causal_end, block_image_end) - if SLIDING_WINDOW > 0: - win_start = block_start_loc + prompt_cache_len - (SLIDING_WINDOW - 1) - win_start = tl.maximum(win_start, 0) - win_start = (win_start // BLOCK_N) * BLOCK_N + if USE_SLIDING_WINDOW: + kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_SIZE + 1 + kv_start_index = tl.maximum(kv_start_index, 0) + block_kv_len = block_end_loc - kv_start_index else: - win_start = 0 + kv_start_index = 0 + block_kv_len = block_end_loc - for start_n in range(win_start, block_end_loc, BLOCK_N): + for start_n in range(0, block_kv_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - k_pos = start_n + offs_n # [N] + k_pos = kv_start_index + start_n + offs_n # [N] k_valid = k_pos < block_end_loc kv_loc = tl.load( @@ -135,8 +137,8 @@ def _fwd_kernel( qk = tl.dot(q, k) causal_mask = q_pos[:, None] >= k_pos[None, :] - if SLIDING_WINDOW > 0: - causal_mask = causal_mask & ((q_pos[:, None] - k_pos[None, :]) < SLIDING_WINDOW) + if USE_SLIDING_WINDOW: + causal_mask = causal_mask & ((q_pos[:, None] - k_pos[None, :]) < SLIDING_WINDOW_SIZE) # Image bidi: a Q in image span [_, e) attends to all K with k_pos < e. 
# For text Q (q_image_end == 0) this is k_pos < 0 = always False, so # the union with causal_mask leaves text-attention unchanged. @@ -183,7 +185,7 @@ def context_attention_fwd_gemma4_mm( max_input_len, req_to_token_indexs, b_image_token_end, - sliding_window: int = 0, + sliding_window: int = -1, ): """Prefill attention with image bidirectional masking on sliding layers. @@ -210,6 +212,8 @@ def context_attention_fwd_gemma4_mm( BLOCK_N = BLOCK_M num_warps = 4 if Lk <= 64 else 8 num_stages = 1 + use_sliding_window = sliding_window >= 0 + sliding_window_size = int(sliding_window) if use_sliding_window else 0 _fwd_kernel[grid]( q, @@ -242,7 +246,8 @@ def context_attention_fwd_gemma4_mm( BLOCK_DMODEL=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - SLIDING_WINDOW=int(sliding_window), + USE_SLIDING_WINDOW=use_sliding_window, + SLIDING_WINDOW_SIZE=sliding_window_size, num_warps=num_warps, num_stages=num_stages, ) @@ -263,9 +268,12 @@ def reference_attention( b_prompt_cache_len, req_to_token_indexs, b_image_token_end, - sliding_window=0, + sliding_window=-1, ): - """Slow torch reference for the gemma4 mm prefill kernel.""" + """Slow torch reference for the gemma4 mm prefill kernel. + + `sliding_window` is the total window size including self. < 0 disables SWA. + """ device = q.device dtype = q.dtype sum_q, Hq, D = q.shape @@ -297,7 +305,7 @@ def reference_attention( k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) causal = k_pos[None, :] <= q_pos[:, None] - if sliding_window > 0: + if sliding_window >= 0: causal = causal & ((q_pos[:, None] - k_pos[None, :]) < sliding_window) image = k_pos[None, :] < q_image_end[:, None] allow = causal | image @@ -325,7 +333,7 @@ def make_test_case( D=256, seed=0, base_index=50000, - sliding_window=0, + sliding_window=-1, ): torch.manual_seed(seed) @@ -391,7 +399,7 @@ def make_test_case( ) -def check_once(seed=0, dtype=torch.bfloat16, sliding_window=0, D=256): +def check_once(seed=0, dtype=torch.bfloat16, sliding_window=-1, D=256): case = make_test_case(seed=seed, dtype=dtype, sliding_window=sliding_window, D=D) ( q, @@ -454,7 +462,7 @@ def check_once(seed=0, dtype=torch.bfloat16, sliding_window=0, D=256): else: # Vary D, sliding window, and image presence. for seed in (0, 1, 2): - check_once(seed=seed, D=128, sliding_window=0) + check_once(seed=seed, D=128, sliding_window=-1) check_once(seed=seed, D=128, sliding_window=4096) check_once(seed=seed, D=256, sliding_window=4096) print("ok") From 7678de8011c99afe7d34c17068beb8f39c8c559f Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Tue, 12 May 2026 04:47:16 +0000 Subject: [PATCH 07/20] fix --- lightllm/models/gemma4/infer_struct.py | 4 ---- lightllm/models/gemma4/layer_infer/post_layer_infer.py | 5 ++--- .../models/gemma4/layer_infer/transformer_layer_infer.py | 8 +++----- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/lightllm/models/gemma4/infer_struct.py b/lightllm/models/gemma4/infer_struct.py index 81a0b236b4..fd2dfbe918 100644 --- a/lightllm/models/gemma4/infer_struct.py +++ b/lightllm/models/gemma4/infer_struct.py @@ -16,10 +16,6 @@ def __init__(self): # E-series only: per-layer embeddings (PLE), shape (N, num_layers, hidden_size_per_layer_input). # Computed once in Gemma4PreLayerInfer; sliced per layer in the transformer block. self.per_layer_embeds = None - # Per-Q image-bidi end markers, shape (sum_new_tokens,) int32. - # 0 for non-image Q tokens; image-span end (in absolute request position) - # for tokens inside an image span. 
Consumed by - # `context_attention_fwd_gemma4_mm` on sliding-window layers. self.b_image_token_end = None def init_some_extra_state(self, model): diff --git a/lightllm/models/gemma4/layer_infer/post_layer_infer.py b/lightllm/models/gemma4/layer_infer/post_layer_infer.py index 3b3423d645..f80fb10183 100644 --- a/lightllm/models/gemma4/layer_infer/post_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/post_layer_infer.py @@ -11,12 +11,11 @@ class Gemma4PostLayerInfer(LlamaPostLayerInfer): def __init__(self, network_config): super().__init__(network_config) self.eps_ = 1e-6 - self.final_logit_softcapping = network_config.get("final_logit_softcapping", None) + self.final_logit_softcapping = float(network_config.get("final_logit_softcapping", None)) def token_forward(self, input_embdings, infer_state, layer_weight): logits = super().token_forward(input_embdings, infer_state, layer_weight) if self.final_logit_softcapping is not None and self.final_logit_softcapping > 0: - cap = float(self.final_logit_softcapping) - # logits are fp32 already (LlamaPostLayerInfer allocates the output in fp32) + cap = self.final_logit_softcapping logits = torch.tanh(logits / cap) * cap return logits diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index b290f1ce84..9dab9438c5 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -195,12 +195,10 @@ def _get_o(self, input, infer_state: InferStateInfo, layer_weight: Gemma4Transfo # ----- Attention kernels (sliding window + per-layer KV reshape) --- def _att_control(self): - # FA3 consumes window_size per-call; the triton prefill/decode kernels - # mask out-of-window positions when SLIDING_WINDOW > 0 (see - # context_flashattention_nopad.py / gqa_flash_decoding_stage1.py). + # `sliding_window_` is the total window size including self. # `_gemma4_use_swa` is set by Gemma4TpPartModel._init_att_backend. 
if self.is_sliding and self.sliding_window_ > 0 and self.network_config_.get("_gemma4_use_swa", False): - w = self.sliding_window_ - 1 + w = self.sliding_window_ return AttControl(use_sliding_window=True, sliding_window=(w, w)) return AttControl(use_sliding_window=False, sliding_window=(-1, -1)) @@ -240,7 +238,7 @@ def _context_attention_kernel( and getattr(infer_state, "b_image_token_end", None) is not None ): o_tensor = self.alloc_tensor(_q.shape, q.dtype) - sw = self.sliding_window_ if self.sliding_window_ > 0 else 0 + sw = self.sliding_window_ if self.sliding_window_ > 0 else -1 context_attention_fwd_gemma4_mm( _q, _k, From 63c658ab5a56e32c0cf552639cddde07e5dc2d59 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 13 May 2026 06:55:19 +0000 Subject: [PATCH 08/20] simplify --- .../gemma4/layer_infer/post_layer_infer.py | 3 +- .../layer_weights/transformer_layer_weight.py | 6 +- lightllm/models/gemma4/model.py | 81 ++++++++++--------- 3 files changed, 48 insertions(+), 42 deletions(-) diff --git a/lightllm/models/gemma4/layer_infer/post_layer_infer.py b/lightllm/models/gemma4/layer_infer/post_layer_infer.py index f80fb10183..22bcf0508d 100644 --- a/lightllm/models/gemma4/layer_infer/post_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/post_layer_infer.py @@ -10,8 +10,7 @@ class Gemma4PostLayerInfer(LlamaPostLayerInfer): def __init__(self, network_config): super().__init__(network_config) - self.eps_ = 1e-6 - self.final_logit_softcapping = float(network_config.get("final_logit_softcapping", None)) + self.final_logit_softcapping = float(network_config.get("final_logit_softcapping")) def token_forward(self, input_embdings, infer_state, layer_weight): logits = super().token_forward(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py index 0e1dc46b0a..aaf2e84acc 100644 --- a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py @@ -224,9 +224,8 @@ def _init_moe(self): def _init_norm(self): hidden_size = self.network_config_["hidden_size"] - # Gemma-4 uses *standard* RMSNorm (x * rsqrt(var+eps) * w), NOT the - # gemma2/3 (1+w) variant. Using NoTpGEMMANormWeight here produces - # nothing but high-frequency-token gibberish ("de la de..."). + # Gemma-4 uses standard RMSNorm (x * rsqrt(var+eps) * w), NOT the + # gemma2/3 (1+w) variant - do not swap in NoTpGEMMANormWeight. 
self.q_norm_weight_ = RMSNormWeight( dim=self._layer_head_dim, weight_name=self._q_norm_weight_name, @@ -273,7 +272,6 @@ def _init_norm(self): weight_name=self._post_feedforward_layernorm_2_name, data_type=self.data_type_, ) - # scalar multiplier applied to the attention output self.layer_scalar_ = ParameterWeight( weight_name=self._layer_scalar_name, data_type=self.data_type_, diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index a2d3661f28..e5ff33fc23 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -2,7 +2,6 @@ import os import json import torch -from transformers import AutoConfig from lightllm.models.registry import ModelRegistry from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend @@ -15,7 +14,7 @@ from lightllm.models.gemma4.layer_infer.transformer_layer_infer import Gemma4TransformerLayerInfer from lightllm.models.gemma4.layer_weights.pre_and_post_layer_weight import Gemma4PreAndPostLayerWeight from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight -from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num, get_env_start_args +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num from lightllm.utils.log_utils import init_logger from lightllm.distributed.communication_op import dist_group_manager @@ -132,12 +131,15 @@ def _init_config(self): # under text_config; flatten it so downstream code sees text-model fields # at the top level (mirrors the gemma3 approach). if "text_config" in self.config: - hf_config = AutoConfig.from_pretrained(self.weight_dir_, trust_remote_code=True) - self.config = hf_config.text_config.to_dict() + self.config = self.config["text_config"].copy() repair_config(self.config, same_names=["num_attention_heads", "n_head"]) repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + self._reset_num_key_value_heads() + + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size if self.config.get("enable_moe_block", False): # LightLLM's MoE helpers use Qwen/DeepSeek-style field names. @@ -195,41 +197,25 @@ def _init_mem_manager(self): def _init_att_backend(self): # Gemma-4 has per-layer heterogeneous attention: sliding layers use # (head_dim=256, kv_heads=16); full-attn layers use (head_dim=512, - # kv_heads=4, k_eq_v). No single backend covers both: + # kv_heads=4, k_eq_v). No single generic backend setup covers both: # - FA3 caps head_dim at 256 -> can't run full-attn layers. - # - Triton handles head_dim=512 (kernels widened to Lk=512) but - # historically refused sliding_window. # - Flashinfer plans once per infer_state on a single shape -> can't # accommodate heterogeneous layout at all. # Strategy: run full-attn layers on triton (primary backend, this - # method) and sliding layers on a separate backend wired via - # _init_att_backend1 (FA3 when available; triton-with-SWA otherwise). - from lightllm.utils.sgl_utils import flash_attn_with_kvcache - - args = get_env_start_args() - backends = set(args.llm_prefill_att_backend + args.llm_decode_att_backend) - for backend_name in backends: - assert backend_name in ("auto", "triton", "fa3"), ( - "Gemma-4 requires triton or fa3 (per-layer dynamic head_dim / " - "num_kv_heads); flashinfer is not wired for the heterogeneous " - f"layout. 
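+        #   standard RMSNorm:  y = x * rsqrt(var + eps) * w
+        #   gemma2/3 variant:  y = x * rsqrt(var + eps) * (1 + w)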
Got --llm_*_att_backend={backend_name!r}." - ) - fa3_loadable = flash_attn_with_kvcache is not None - if "fa3" in backends: - assert fa3_loadable, ( - "Requested --llm_*_att_backend=fa3 but flash_attn_with_kvcache " - "did not import (sgl_kernel missing or wrong arch)." - ) + # method) and sliding layers on a separate backend wired in + # _init_att_backend1. + fa3_loadable = self._gemma4_fa3_loadable() # Full-attn layers always go through triton. self.prefill_att_backend = TritonAttBackend(model=self) self.decode_att_backend = TritonAttBackend(model=self) - # Decide sliding-layer backend kind here so _init_att_backend1 can - # honour it. User can force triton with --llm_*_att_backend triton; - # otherwise prefer FA3 when loadable. - user_forced_triton = backends == {"triton"} - self._gemma4_sliding_backend_kind = "fa3" if (fa3_loadable and not user_forced_triton) else "triton" + self._gemma4_sliding_prefill_backend_kind = self._resolve_gemma4_sliding_backend( + self.args.llm_prefill_att_backend[0], fa3_loadable + ) + self._gemma4_sliding_decode_backend_kind = self._resolve_gemma4_sliding_backend( + self.args.llm_decode_att_backend[0], fa3_loadable + ) # SWA is on regardless of which sliding backend was picked: FA3 # honours window_size per call, and the triton kernels in # context_flashattention_nopad.py / gqa_flash_decoding_stage1.py mask @@ -239,14 +225,37 @@ def _init_att_backend(self): def _init_att_backend1(self): # Sliding layers run on a dedicated backend so the head-dim/SWA # mismatch with full-attn layers doesn't force a single compromise. - if self._gemma4_sliding_backend_kind == "fa3": + self.prefill_att_backend1 = self._build_gemma4_sliding_backend(self._gemma4_sliding_prefill_backend_kind) + self.decode_att_backend1 = self._build_gemma4_sliding_backend(self._gemma4_sliding_decode_backend_kind) + + @staticmethod + def _gemma4_fa3_loadable(): + from lightllm.utils.sgl_utils import flash_attn_with_kvcache + + return flash_attn_with_kvcache is not None + + @staticmethod + def _resolve_gemma4_sliding_backend(backend_name, fa3_loadable): + assert backend_name in ("auto", "triton", "fa3"), ( + "Gemma-4 requires triton or fa3 for sliding layers; flashinfer is " + f"not wired for the heterogeneous layout. Got backend={backend_name!r}." + ) + if backend_name == "auto": + return "fa3" if fa3_loadable else "triton" + if backend_name == "fa3": + assert fa3_loadable, ( + "Requested --llm_*_att_backend=fa3 but flash_attn_with_kvcache " + "did not import (sgl_kernel missing or wrong arch)." 
+ ) + return backend_name + + def _build_gemma4_sliding_backend(self, backend_kind): + if backend_kind == "fa3": from lightllm.common.basemodel.attention.fa3.fp import Fa3AttBackend - self.prefill_att_backend1 = Fa3AttBackend(model=self) - self.decode_att_backend1 = Fa3AttBackend(model=self) - else: - self.prefill_att_backend1 = TritonAttBackend(model=self) - self.decode_att_backend1 = TritonAttBackend(model=self) + return Fa3AttBackend(model=self) + assert backend_kind == "triton" + return TritonAttBackend(model=self) def _init_custom(self): self._init_to_get_rotary_gemma4() From 300e5778ae6d02d95e5457a82cadd6400a4d45d5 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 13 May 2026 07:00:47 +0000 Subject: [PATCH 09/20] minor improvements --- .../basemodel/triton_kernel/norm/rmsnorm.py | 15 ++- .../layer_infer/transformer_layer_infer.py | 36 ++++--- .../models/llama/triton_kernel/rotary_emb.py | 96 ++++++++++--------- 3 files changed, 78 insertions(+), 69 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py index ca8f9a1c81..8dc8558922 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py @@ -18,6 +18,7 @@ def _rms_norm_fwd_fused( y_stride1, N, # number of columns in X eps, # epsilon to avoid division by zero + HAS_WEIGHT: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. @@ -32,14 +33,17 @@ def _rms_norm_fwd_fused( _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - # Normalize and apply linear transformation + # Normalize and optionally apply linear transformation for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = x * rstd - y = x_hat * w + y = x_hat + if HAS_WEIGHT: + y = x_hat * w # Write output tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) @@ -50,7 +54,9 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) # reshape input data into 2D tensor x_arg = x.view(-1, x.shape[-1]) y_arg = y.view(-1, x.shape[-1]) - assert x_arg.shape[-1] == weight.shape[0] and x_arg.shape == y_arg.shape + assert x_arg.shape == y_arg.shape + if weight is not None: + assert x_arg.shape[-1] == weight.shape[0] assert y.data_ptr() == y_arg.data_ptr() M, N = x_arg.shape # Less than 64KB per feature: enqueue fused kernel @@ -73,6 +79,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) y_arg.stride(1), N, eps, + HAS_WEIGHT=weight is not None, BLOCK_SIZE=BLOCK_SIZE, num_warps=rmsnorm_num_warps, ) diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 9dab9438c5..9cb00b76d0 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -4,6 +4,7 @@ from lightllm.common.basemodel.attention.base_att import AttControl from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight from 
lightllm.models.gemma4.triton_kernel.context_attention_fwd_gemma4_mm import ( context_attention_fwd_gemma4_mm, @@ -24,7 +25,7 @@ class Gemma4TransformerLayerInfer(LlamaTransformerLayerInfer): def __init__(self, layer_num, network_config): super().__init__(layer_num, network_config) - self.eps_ = 1e-6 + self.eps_ = network_config.get("rms_norm_eps", 1e-6) self.embed_dim_ = network_config["hidden_size"] self.is_moe = bool(network_config.get("enable_moe_block", False)) self.num_experts_per_tok = network_config.get("num_experts_per_tok", network_config.get("top_k_experts", 0)) @@ -83,8 +84,8 @@ def __init__(self, layer_num, network_config): self.kv_share_target_layer_ = None if self.is_kv_shared_: cutoff = total_layers - kv_shared_count - for j in range(layer_num - 1, -1, -1): - if j < cutoff and network_config["layer_types"][j] == layer_type: + for j in range(cutoff - 1, -1, -1): + if network_config["layer_types"][j] == layer_type: self.kv_share_target_layer_ = j break assert self.kv_share_target_layer_ is not None, ( @@ -97,8 +98,6 @@ def __init__(self, layer_num, network_config): # head_dim. Don't change to 0.25 — that double-counts with the table. self.partial_rotary_factor_ = 1.0 - # ----- QKV + attention --------------------------------------------- - def _rope_cos_sin(self, infer_state): # Tables are built in the model dtype (Gemma4TpPartModel._init_to_get_rotary_gemma4), # so they already match q/k dtype — no cast needed. @@ -113,10 +112,6 @@ def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4Trans q_heads = self.tp_q_head_num_ kv_heads = self.tp_k_head_num_ - # Q is always computed (even on KV-shared layers). RMSNormWeight's - # Triton kernel accepts 3D input (it views to 2D internally) and - # promotes to fp32 for the variance reduction, so feed bf16 (N, heads, - # head_dim) straight in — no Python-side reshape or dtype round-trip. q = layer_weight.q_proj.mm(input).view(-1, q_heads, head_dim) q = layer_weight.q_norm_weight_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) @@ -124,10 +119,7 @@ def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4Trans if self.is_kv_shared_: # K/V come from target layer's already-rotated, already-normed cache. - # Only rotate Q here. rotary_emb_fwd writes to k in place, so pass - # a 1-head throwaway tensor we can discard. - dummy_k = torch.empty((q.shape[0], 1, head_dim), dtype=q.dtype, device=q.device) - rotary_emb_fwd(q, dummy_k, cos, sin, partial_rotary_factor=self.partial_rotary_factor_) + rotary_emb_fwd(q, None, cos, sin, partial_rotary_factor=self.partial_rotary_factor_) q = q * math.sqrt(head_dim) if infer_state.need_dp_prefill_balance: q = infer_state._all_to_all_unbalance_get(data=q) @@ -137,16 +129,19 @@ def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4Trans k = layer_weight.k_proj.mm(input).view(-1, kv_heads, head_dim) if self.k_eq_v: # Full-attn k_eq_v variant (e.g. 31B): K weights serve as V. - v = k.clone() + v = k else: v = layer_weight.v_proj.mm(input).view(-1, kv_heads, head_dim) k = layer_weight.k_norm_weight_(input=k, eps=self.eps_, alloc_func=self.alloc_tensor) # V-norm: unweighted RMSNorm over head_dim (matches vllm's Gemma4 has_weight=False). 
- v_fp = v.float() - v_fp = v_fp * torch.rsqrt(v_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps_) - v = v_fp.to(input.dtype) + v = rmsnorm_forward( + x=v, + weight=None, + eps=self.eps_, + out=self.alloc_tensor(v.shape, dtype=v.dtype, device=v.device), + ) rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=self.partial_rotary_factor_) @@ -299,9 +294,10 @@ def _router_logits(self, residual, layer_weight: Gemma4TransformerLayerWeight) - router_input = residual.view(-1, self.embed_dim_).float() router_input = router_input * torch.rsqrt(router_input.pow(2).mean(dim=-1, keepdim=True) + self.eps_) router_input = router_input * self.router_root_scale - # bf16 weight auto-promotes against fp32 router_input; cast back to - # bf16 to feed moe_gate.mm. - router_input = (router_input * layer_weight.router_input_scale_.weight).to(torch.bfloat16) + # Match the gate weight dtype before matmul, consistent with the other + # MoE paths and compatible with fp16 / bf16 / fp32 runs. + moe_gate_dtype = layer_weight.moe_gate.data_type_ + router_input = (router_input * layer_weight.router_input_scale_.weight).to(moe_gate_dtype) # gate logits stay fp32 for top-k / softmax precision. return layer_weight.moe_gate.mm(router_input, use_custom_tensor_mananger=False).float() diff --git a/lightllm/models/llama/triton_kernel/rotary_emb.py b/lightllm/models/llama/triton_kernel/rotary_emb.py index c6d4f3010d..f87b9d9e02 100755 --- a/lightllm/models/llama/triton_kernel/rotary_emb.py +++ b/lightllm/models/llama/triton_kernel/rotary_emb.py @@ -23,6 +23,7 @@ def _rotary_kernel( max_total_len, HEAD_Q, HEAD_K, # N_CTX 代表要计算的上下文长度 + HAS_K: tl.constexpr, BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -73,55 +74,59 @@ def _rotary_kernel( Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) ) - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) - - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - - k0 = tl.load( - K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - k1 = tl.load( - K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos - - tl.store( - K + off_k0, - out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - tl.store( - K + off_k1, - out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) + if HAS_K: + off_k0 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range0[None, None, :] * stride_kd + ) + off_k1 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range1[None, None, :] * stride_kd + ) + + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd + + k0 = tl.load( + K 
+ off_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + k1 = tl.load( + K + off_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos + + tl.store( + K + off_k0, + out_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) + tl.store( + K + off_k1, + out_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) return @torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): +def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0): total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] + has_k = k is not None + head_num_q = q.shape[1] + head_num_k = k.shape[1] if has_k else 0 head_dim = int(q.shape[2] * partial_rotary_factor) assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + if has_k: + assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" BLOCK_SEQ = 16 BLOCK_HEAD = 4 @@ -139,9 +144,9 @@ def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): q.stride(0), q.stride(1), q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), + k.stride(0) if has_k else 0, + k.stride(1) if has_k else 0, + k.stride(2) if has_k else 0, cos.stride(0), cos.stride(1), sin.stride(0), @@ -149,6 +154,7 @@ def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): total_len, head_num_q, head_num_k, + HAS_K=has_k, BLOCK_HEAD=BLOCK_HEAD, BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=head_dim, From 50822f0258fb4325b412b6e6a1566e6b0f878dc7 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 13 May 2026 07:41:02 +0000 Subject: [PATCH 10/20] fix --- .../layer_infer/transformer_layer_infer.py | 50 +++++++++---------- lightllm/models/gemma4/model.py | 5 -- 2 files changed, 23 insertions(+), 32 deletions(-) diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 9cb00b76d0..0d267b84b2 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -39,12 +39,14 @@ def __init__(self, layer_num, network_config): # HF treats that as "fall back to num_key_value_heads". num_global_kv = network_config.get("num_global_key_value_heads") or network_config["num_key_value_heads"] + # Override parent's head_dim_ (hidden_size/num_heads = 224 on 31B, wrong + # for Gemma-4 — actual is 256 sliding / 512 full). 
if self.is_sliding: - self.layer_head_dim_ = network_config["head_dim"] + self.head_dim_ = network_config["head_dim"] total_kv_heads = network_config["num_key_value_heads"] self.k_eq_v = False else: - self.layer_head_dim_ = network_config["global_head_dim"] + self.head_dim_ = network_config["global_head_dim"] total_kv_heads = num_global_kv self.k_eq_v = network_config.get("attention_k_eq_v", True) @@ -108,7 +110,7 @@ def _rope_cos_sin(self, infer_state): def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: input = self._tpsp_allgather(input=input, infer_state=infer_state) - head_dim = self.layer_head_dim_ + head_dim = self.head_dim_ q_heads = self.tp_q_head_num_ kv_heads = self.tp_k_head_num_ @@ -179,20 +181,13 @@ def _post_cache_kv(self, cache_kv, infer_state, layer_weight): return return super()._post_cache_kv(cache_kv, infer_state, layer_weight) - def _get_o(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: - if infer_state.need_dp_prefill_balance: - input = infer_state._all_to_all_balance_get(data=input) - input = input.view(-1, self.tp_o_head_num_ * self.layer_head_dim_) - o_tensor = layer_weight.o_proj.mm(input) - o_tensor = self._tpsp_reduce(input=o_tensor, infer_state=infer_state) - return o_tensor - # ----- Attention kernels (sliding window + per-layer KV reshape) --- def _att_control(self): - # `sliding_window_` is the total window size including self. - # `_gemma4_use_swa` is set by Gemma4TpPartModel._init_att_backend. - if self.is_sliding and self.sliding_window_ > 0 and self.network_config_.get("_gemma4_use_swa", False): + # `sliding_window_` is the total window size including self. Sliding + # layers always run on a backend that consumes SWA (FA3 or the patched + # triton kernels — see Gemma4TpPartModel._init_att_backend1). + if self.is_sliding and self.sliding_window_ > 0: w = self.sliding_window_ return AttControl(use_sliding_window=True, sliding_window=(w, w)) return AttControl(use_sliding_window=False, sliding_window=(-1, -1)) @@ -201,17 +196,21 @@ def _get_layer_kv(self, infer_state: InferStateInfo): # KV-shared layers read from the target layer's cache slot. layer_idx = self.kv_share_target_layer_ if self.is_kv_shared_ else self.layer_num_ _k_raw, _v_raw = infer_state.mem_manager.get_att_input_params(layer_index=layer_idx) - # _k_raw / _v_raw shape (S, cache_slot_num, cache_slot_dim). + # _k_raw / _v_raw shape (S, cache_slot_num, cache_slot_dim). Use .view + # (not .reshape) so any non-contiguous layout from a future mem_manager + # backend fails loudly instead of silently copying — slice + view is + # O(1) on the standard MemoryManager layout (inner (kv_heads, head_dim) + # span is contiguous). kv_heads = self.tp_k_head_num_ - head_dim = self.layer_head_dim_ + head_dim = self.head_dim_ cache_slot_dim = self.kv_cache_slot_dim_ used_cache_slots = kv_heads * head_dim // cache_slot_dim if used_cache_slots == _k_raw.shape[1]: # Layout already matches this layer's natural shape. - return _k_raw.reshape(-1, kv_heads, head_dim), _v_raw.reshape(-1, kv_heads, head_dim) + return _k_raw.view(-1, kv_heads, head_dim), _v_raw.view(-1, kv_heads, head_dim) # Otherwise the K/V live in the first used_cache_slots; the rest is zero pad. 
- _k = _k_raw[:, :used_cache_slots, :].reshape(-1, kv_heads, head_dim) - _v = _v_raw[:, :used_cache_slots, :].reshape(-1, kv_heads, head_dim) + _k = _k_raw[:, :used_cache_slots, :].view(-1, kv_heads, head_dim) + _v = _v_raw[:, :used_cache_slots, :].view(-1, kv_heads, head_dim) return _k, _v def _context_attention_kernel( @@ -223,15 +222,12 @@ def _context_attention_kernel( out=None, ) -> torch.Tensor: _k, _v = self._get_layer_kv(infer_state) - _q = q.view(-1, self.tp_q_head_num_, self.layer_head_dim_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) # Image bidirectional attention only applies on sliding-window layers # (matches HF/vllm `use_bidirectional_attention="vision"`). Full-attn - # layers stay on the standard causal triton path. - if ( - self.is_sliding - and self.network_config_.get("_gemma4_use_swa", False) - and getattr(infer_state, "b_image_token_end", None) is not None - ): + # layers stay on the standard causal triton path. b_image_token_end is + # only built for prefills that actually carry images. + if self.is_sliding and getattr(infer_state, "b_image_token_end", None) is not None: o_tensor = self.alloc_tensor(_q.shape, q.dtype) sw = self.sliding_window_ if self.sliding_window_ > 0 else -1 context_attention_fwd_gemma4_mm( @@ -267,7 +263,7 @@ def _token_attention_kernel( out=None, ) -> torch.Tensor: _k, _v = self._get_layer_kv(infer_state) - _q = q.view(-1, self.tp_q_head_num_, self.layer_head_dim_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) att_state = infer_state.decode_att_state1 if self.is_sliding else infer_state.decode_att_state o_tensor = att_state.decode_att(q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor) return o_tensor.view(q.shape) diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index e5ff33fc23..c3f758ee80 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -216,11 +216,6 @@ def _init_att_backend(self): self._gemma4_sliding_decode_backend_kind = self._resolve_gemma4_sliding_backend( self.args.llm_decode_att_backend[0], fa3_loadable ) - # SWA is on regardless of which sliding backend was picked: FA3 - # honours window_size per call, and the triton kernels in - # context_flashattention_nopad.py / gqa_flash_decoding_stage1.py mask - # out-of-window positions when SLIDING_WINDOW > 0. - self.config["_gemma4_use_swa"] = True def _init_att_backend1(self): # Sliding layers run on a dedicated backend so the head-dim/SWA From b4b13cc58656bd3a355de1fb37fa2f9f99eb930e Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 13 May 2026 09:00:49 +0000 Subject: [PATCH 11/20] fix attention cuda graph --- lightllm/models/gemma4/infer_struct.py | 11 +++------- .../gemma4/layer_infer/pre_layer_infer.py | 7 +------ .../layer_infer/transformer_layer_infer.py | 16 ++++++--------- lightllm/models/gemma4/model.py | 20 ++++++++++--------- 4 files changed, 21 insertions(+), 33 deletions(-) diff --git a/lightllm/models/gemma4/infer_struct.py b/lightllm/models/gemma4/infer_struct.py index fd2dfbe918..867977e3ec 100644 --- a/lightllm/models/gemma4/infer_struct.py +++ b/lightllm/models/gemma4/infer_struct.py @@ -39,12 +39,10 @@ def init_some_extra_state(self, model): return def _build_b_image_token_end(self): - # Scatter per-image end markers into a flat (sum_q,) int32 tensor for - # consumption by the image-aware prefill attention kernel. Style mirrors - # neo_chat_moe.get_neo_position. 
Chunked-prefill clipping (image span - # straddling cache/new boundary) is handled inside the kernel. + device = self.position_ids.device + self.b_image_token_end = torch.zeros(self.position_ids.shape[0], dtype=torch.int32, device=device) + if not self.multimodal_params: - self.b_image_token_end = None return b_image_start_idx = [] @@ -62,11 +60,8 @@ def _build_b_image_token_end(self): image_start_num += 1 if image_start_num == 0: - self.b_image_token_end = None return - device = self.position_ids.device - self.b_image_token_end = torch.zeros(self.position_ids.shape[0], dtype=torch.int32, device=device) build_b_image_token_end( b_image_start_idx=torch.tensor(b_image_start_idx, dtype=torch.int32).cuda(non_blocking=True), b_image_len=torch.tensor(b_image_len, dtype=torch.int32).cuda(non_blocking=True), diff --git a/lightllm/models/gemma4/layer_infer/pre_layer_infer.py b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py index 98c29c1cc2..b76508e1da 100644 --- a/lightllm/models/gemma4/layer_infer/pre_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py @@ -43,12 +43,7 @@ def _compute_per_layer_embeds(self, input_ids_for_ple, input_embdings, infer_sta def context_forward(self, input_ids, infer_state, layer_weight): input_embdings = LlamaMultimodalPreLayerInfer.context_forward(self, input_ids, infer_state, layer_weight) if self.has_ple: - image_token_end = getattr(infer_state, "b_image_token_end", None) - input_ids_for_ple = ( - input_ids - if image_token_end is None - else input_ids.masked_fill(image_token_end != 0, self.pad_token_id_) - ) + input_ids_for_ple = input_ids.masked_fill(infer_state.b_image_token_end != 0, self.pad_token_id_) self._compute_per_layer_embeds(input_ids_for_ple, input_embdings, infer_state, layer_weight) return input_embdings diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 0d267b84b2..84ca1084d8 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -223,11 +223,9 @@ def _context_attention_kernel( ) -> torch.Tensor: _k, _v = self._get_layer_kv(infer_state) _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) - # Image bidirectional attention only applies on sliding-window layers - # (matches HF/vllm `use_bidirectional_attention="vision"`). Full-attn - # layers stay on the standard causal triton path. b_image_token_end is - # only built for prefills that actually carry images. - if self.is_sliding and getattr(infer_state, "b_image_token_end", None) is not None: + if self.is_sliding: + # Sliding layers always go through the gemma4_mm Triton kernel: it + # handles SWA + image bidirectional masking in one pass. o_tensor = self.alloc_tensor(_q.shape, q.dtype) sw = self.sliding_window_ if self.sliding_window_ > 0 else -1 context_attention_fwd_gemma4_mm( @@ -246,11 +244,9 @@ def _context_attention_kernel( ) return o_tensor.view(q.shape) - # Sliding layers go through the secondary backend (FA3 with SWA when - # available, else triton-with-SWA from path B). Full-attn layers go - # through the primary triton backend (head_dim=512). - att_state = infer_state.prefill_att_state1 if self.is_sliding else infer_state.prefill_att_state - o_tensor = att_state.prefill_att( + # Full-attn layers: head_dim=512, no SWA, no image bidi — standard + # triton via the primary backend. 
+ o_tensor = infer_state.prefill_att_state.prefill_att( q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor ) return o_tensor.view(q.shape) diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index c3f758ee80..42093a3f29 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -201,26 +201,28 @@ def _init_att_backend(self): # - FA3 caps head_dim at 256 -> can't run full-attn layers. # - Flashinfer plans once per infer_state on a single shape -> can't # accommodate heterogeneous layout at all. - # Strategy: run full-attn layers on triton (primary backend, this - # method) and sliding layers on a separate backend wired in - # _init_att_backend1. + # Strategy: + # - Prefill: sliding layers go through the gemma4_mm Triton kernel + # directly (handles SWA + image bidi); full-attn layers use the + # primary triton backend below. No FA3 in prefill — its + # image_token_end build asserts incompatible with SWA. Revisit + # when fa3 supports both simultaneously. + # - Decode: full-attn layers on triton (primary); sliding layers on + # fa3 (with SWA) when available — secondary backend set in + # _init_att_backend1. fa3_loadable = self._gemma4_fa3_loadable() # Full-attn layers always go through triton. self.prefill_att_backend = TritonAttBackend(model=self) self.decode_att_backend = TritonAttBackend(model=self) - self._gemma4_sliding_prefill_backend_kind = self._resolve_gemma4_sliding_backend( - self.args.llm_prefill_att_backend[0], fa3_loadable - ) self._gemma4_sliding_decode_backend_kind = self._resolve_gemma4_sliding_backend( self.args.llm_decode_att_backend[0], fa3_loadable ) def _init_att_backend1(self): - # Sliding layers run on a dedicated backend so the head-dim/SWA - # mismatch with full-attn layers doesn't force a single compromise. - self.prefill_att_backend1 = self._build_gemma4_sliding_backend(self._gemma4_sliding_prefill_backend_kind) + # Only decode needs the sliding-layer backend; prefill sliding goes + # through gemma4_mm Triton directly in the layer. 
self.decode_att_backend1 = self._build_gemma4_sliding_backend(self._gemma4_sliding_decode_backend_kind) @staticmethod From f19074b9212b43ac6367eee4aa071bc416399d14 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 14 May 2026 01:45:36 +0000 Subject: [PATCH 12/20] fused gelu gate up --- .../layer_infer/transformer_layer_infer.py | 19 +++++++++++-------- .../layer_weights/transformer_layer_weight.py | 19 +++++++------------ 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 84ca1084d8..4cadf87c84 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -9,6 +9,7 @@ from lightllm.models.gemma4.triton_kernel.context_attention_fwd_gemma4_mm import ( context_attention_fwd_gemma4_mm, ) +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd @@ -264,16 +265,18 @@ def _token_attention_kernel( o_tensor = att_state.decode_att(q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor) return o_tensor.view(q.shape) - # ----- FFN (Gemma gelu-tanh, separate gate/up/down) ---------------- + # ----- FFN (Gemma gelu-tanh, fused gate_up + down) ----------------- - def _ffn(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: + def _ffn_tp(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: + # Only override the inner core — the outer _ffn (tpsp_allgather + + # _ffn_tp + tpsp_reduce) is inherited from LlamaTransformerLayerInfer. + # Difference vs llama: gelu(tanh)+mul instead of silu+mul. 
input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) - gate = layer_weight.gate_proj.mm(input) - up = layer_weight.up_proj.mm(input) - ffn1 = nn.functional.gelu(gate, approximate="tanh") * up - gate = None - up = None + gate_up = layer_weight.gate_up_proj.mm(input) + ffn1 = self.alloc_tensor((input.size(0), gate_up.size(1) // 2), input.dtype) + silu_and_mul_fwd(gate_up, ffn1, use_gelu=True) + gate_up = None ffn2 = layer_weight.down_proj.mm(ffn1) ffn1 = None ffn2 = self._tpsp_reduce(input=ffn2, infer_state=infer_state) @@ -414,6 +417,6 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input_embdings = self._ffn_block(input_embdings, infer_state, layer_weight) + input_embdings = self._ffn(input_embdings, infer_state, layer_weight) return self._block_epilogue(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py index aaf2e84acc..6d9a5c2613 100644 --- a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py @@ -162,21 +162,16 @@ def _init_o(self): ) def _init_ffn(self): - self.gate_proj = ROWMMWeight( + # Packed gate+up: ROWMMWeight stitches `gate_proj` and `up_proj` weights + # along the output dim so the dense FFN runs one matmul + a fused + # gelu*mul kernel (mirrors llama's gate_up_proj path). + self.gate_up_proj = ROWMMWeight( in_dim=self.n_embed, - out_dims=[self.n_inter], - weight_names=self._gate_weight_name, + out_dims=[self.n_inter, self.n_inter], + weight_names=[self._gate_weight_name, self._up_weight_name], data_type=self.data_type_, bias_names=None, - quant_method=self.get_quant_method("gate_proj"), - ) - self.up_proj = ROWMMWeight( - in_dim=self.n_embed, - out_dims=[self.n_inter], - weight_names=self._up_weight_name, - data_type=self.data_type_, - bias_names=None, - quant_method=self.get_quant_method("up_proj"), + quant_method=self.get_quant_method("gate_up_proj"), ) self.down_proj = COLMMWeight( in_dim=self.n_inter, From 5b61450546d139111da2aa6569b36d54bb05737d Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 14 May 2026 04:45:39 +0000 Subject: [PATCH 13/20] add out_dtype --- .../meta_weights/mm_weight/mm_weight.py | 17 ++++++++- lightllm/common/quantization/awq.py | 5 +++ lightllm/common/quantization/deepgemm.py | 3 ++ lightllm/common/quantization/no_quant.py | 15 ++++++-- .../common/quantization/quantize_method.py | 1 + lightllm/common/quantization/w8a8.py | 7 ++++ lightllm/common/quantization/w8a8gx.py | 3 ++ .../layer_infer/transformer_layer_infer.py | 37 ++++++++----------- 8 files changed, 60 insertions(+), 28 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 5021699143..895482b491 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -54,10 +54,23 @@ def __init__( self.gen_weight_quant_param_names() def mm( - self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True + self, + input_tensor: torch.Tensor, + out: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + 
out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + # out_dtype: optional override that asks the quant backend to produce + # an output of a specified dtype (e.g. fp32) directly from the GEMM + # accumulator. Only NoQuantization currently honors values that differ + # from input dtype; other quant impls will assert. return self.quant_method.apply( - input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger, bias=self.bias + input_tensor, + self.mm_param, + out, + use_custom_tensor_mananger=use_custom_tensor_mananger, + bias=self.bias, + out_dtype=out_dtype, ) def gen_weight_quant_param_names(self): diff --git a/lightllm/common/quantization/awq.py b/lightllm/common/quantization/awq.py index f3c7623975..96cd2b2926 100644 --- a/lightllm/common/quantization/awq.py +++ b/lightllm/common/quantization/awq.py @@ -58,6 +58,7 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: raise NotImplementedError("AWQ online quantization is not supported yet.") @@ -92,7 +93,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "awq quant does not support out_dtype" qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point @@ -167,7 +170,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "awq_marlin quant does not support out_dtype" qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..901ec142e1 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -35,6 +35,7 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -75,7 +76,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "deepgemm-fp8w8a8-b128 quant does not support out_dtype" qweight = weight_pack.weight weight_scale = weight_pack.weight_scale input_scale = None diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py index fa926ad6f0..b0deaca9a4 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/no_quant.py @@ -18,20 +18,27 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager weight = weight_pack.weight.t() + target_dtype = out_dtype if out_dtype is not None else input_tensor.dtype if out is None: shape = (input_tensor.shape[0], weight.shape[1]) - dtype = input_tensor.dtype device = input_tensor.device if 
use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor(shape, dtype, device=device) + out = g_cache_manager.alloc_tensor(shape, target_dtype, device=device) else: - out = torch.empty(shape, dtype=dtype, device=device) + out = torch.empty(shape, dtype=target_dtype, device=device) + else: + assert out.dtype == target_dtype, ( + f"NoQuantization.apply: pre-allocated out.dtype={out.dtype} does not match " + f"requested out_dtype={target_dtype}" + ) if bias is None: - return torch.mm(input_tensor, weight, out=out) + return torch.mm(input_tensor, weight, out=out, out_dtype=target_dtype) + assert out_dtype is None, "NoQuantization.apply: out_dtype not supported when bias is set" return torch.addmm(bias, input_tensor, weight, out=out) def _create_weight( diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 95d8d806f9..d3f251ec84 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -55,6 +55,7 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index b3d29b0527..87142086fa 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -48,6 +48,7 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -85,7 +86,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "w8a8 quant does not support out_dtype" input_scale = None qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale @@ -147,7 +150,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "fp8w8a8 quant does not support out_dtype" qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) @@ -214,7 +219,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "fp8w8a8-b128 quant does not support out_dtype" qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale.t() input_scale = None # dynamic quantization for input tensor diff --git a/lightllm/common/quantization/w8a8gx.py b/lightllm/common/quantization/w8a8gx.py index c25136697d..a6a4065745 100644 --- a/lightllm/common/quantization/w8a8gx.py +++ b/lightllm/common/quantization/w8a8gx.py @@ -24,6 +24,7 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -62,7 +63,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: 
Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "fp8w8a8gxx quant does not support out_dtype" qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 4cadf87c84..af5635f3cc 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -267,10 +267,9 @@ def _token_attention_kernel( # ----- FFN (Gemma gelu-tanh, fused gate_up + down) ----------------- - def _ffn_tp(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: - # Only override the inner core — the outer _ffn (tpsp_allgather + - # _ffn_tp + tpsp_reduce) is inherited from LlamaTransformerLayerInfer. - # Difference vs llama: gelu(tanh)+mul instead of silu+mul. + def _ffn_dense( + self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) gate_up = layer_weight.gate_up_proj.mm(input) @@ -283,20 +282,14 @@ def _ffn_tp(self, input, infer_state: InferStateInfo, layer_weight: Gemma4Transf return ffn2 def _router_logits(self, residual, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: - # Manual unweighted RMSNorm — lightllm's RMSNormWeight has no - # has_weight=False mode, and bf16 variance over hidden_size loses too - # much precision. Keep the fp32 accumulation explicit. - router_input = residual.view(-1, self.embed_dim_).float() - router_input = router_input * torch.rsqrt(router_input.pow(2).mean(dim=-1, keepdim=True) + self.eps_) - router_input = router_input * self.router_root_scale - # Match the gate weight dtype before matmul, consistent with the other - # MoE paths and compatible with fp16 / bf16 / fp32 runs. - moe_gate_dtype = layer_weight.moe_gate.data_type_ - router_input = (router_input * layer_weight.router_input_scale_.weight).to(moe_gate_dtype) - # gate logits stay fp32 for top-k / softmax precision. - return layer_weight.moe_gate.mm(router_input, use_custom_tensor_mananger=False).float() - - def _moe_ffn(self, input, router_logits, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + # Mirrors vllm Gemma4Router: unweighted RMSNorm -> 1/sqrt(hidden) -> + # per-channel scale -> bf16xbf16 -> fp32 gate matmul for stable top-k. 
+ x = residual.view(-1, self.embed_dim_) + x = rmsnorm_forward(x=x, weight=None, eps=self.eps_, out=self.alloc_tensor(x.shape, dtype=x.dtype)) + x = x * self.router_root_scale * layer_weight.router_input_scale_.weight + return layer_weight.moe_gate.mm(x, out_dtype=torch.float32) + + def _ffn_moe(self, input, router_logits, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): input = input.view(-1, self.embed_dim_) input = self._tpsp_allgather(input=input, infer_state=infer_state) moe_out = layer_weight.experts.experts( @@ -314,12 +307,12 @@ def _moe_ffn(self, input, router_logits, infer_state: InferStateInfo, layer_weig moe_out = self._tpsp_reduce(input=moe_out, infer_state=infer_state) return moe_out - def _ffn_block(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + def _ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): residual = input_embdings dense_input = layer_weight.pre_feedforward_layernorm_weight_( input=residual, eps=self.eps_, alloc_func=self.alloc_tensor ) - dense_out = self._ffn(dense_input, infer_state, layer_weight) + dense_out = self._ffn_dense(dense_input, infer_state, layer_weight) dense_input = None if self.is_moe: @@ -331,7 +324,7 @@ def _ffn_block(self, input_embdings, infer_state: InferStateInfo, layer_weight: moe_input = layer_weight.pre_feedforward_layernorm_2_weight_( input=residual, eps=self.eps_, alloc_func=self.alloc_tensor ) - moe_out = self._moe_ffn(moe_input, router_logits, infer_state, layer_weight) + moe_out = self._ffn_moe(moe_input, router_logits, infer_state, layer_weight) moe_input = None router_logits = None moe_out = layer_weight.post_feedforward_layernorm_2_weight_( @@ -401,7 +394,7 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input_embdings.add_(o.view(-1, self.embed_dim_)) o = None - input_embdings = self._ffn_block(input_embdings, infer_state, layer_weight) + input_embdings = self._ffn(input_embdings, infer_state, layer_weight) return self._block_epilogue(input_embdings, infer_state, layer_weight) From c0ca2127bede0ff4177e0e99d65cf3cc2f6fa7ba Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 14 May 2026 05:56:28 +0000 Subject: [PATCH 14/20] minor improvements --- .../layer_infer/transformer_layer_infer.py | 62 +++++-------------- 1 file changed, 15 insertions(+), 47 deletions(-) diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index af5635f3cc..10168e5b70 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -342,54 +342,26 @@ def _ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4 # ----- block-level forwards (PLE fusion + layer_scalar at the end) ---- - def _apply_per_layer_embed(self, hidden_states, infer_state, layer_weight): - """E-series: gate hidden_states through per_layer_embed slice and add - the projected contribution back as a residual. Matches HF - Gemma4TextDecoderLayer.forward (lines 1401–1408 in transformers 5.5.4) - and vllm Gemma4DecoderLayer.forward (gemma4.py:744–752) — bf16 the - whole way, RMSNorm Triton kernel handles fp32 promotion internally. - - gate / projection weights are ROWMMWeight(tp_world_size=1) — replicated - across TP ranks — so we drive them through `.mm()` and never need an - intra-block all-reduce. 
In TPSP mix mode, per_layer_embeds has already - been token-split alongside hidden_states by Gemma4PreLayerInfer's - _tpsp_sp_split override, so rows line up element-wise here. - """ - # per_layer_embeds is (N, num_layers, ple_dim); slice this layer. - ple_slice = infer_state.per_layer_embeds[..., self.layer_num_, :] - flat = hidden_states.view(-1, self.embed_dim_) - gate = layer_weight.per_layer_input_gate_.mm(flat) # (N, ple_dim) - gate = nn.functional.gelu(gate, approximate="tanh") - gated = gate * ple_slice.view(-1, self.ple_dim_) - contrib = layer_weight.per_layer_projection_.mm(gated) # (N, hidden_size) - contrib = layer_weight.post_per_layer_input_norm_weight_( - input=contrib, eps=self.eps_, alloc_func=self.alloc_tensor - ) - flat.add_(contrib) - return hidden_states - - def _apply_layer_scalar(self, hidden_states, layer_weight): - hidden_states.mul_(layer_weight.layer_scalar_.weight) - return hidden_states - def _block_epilogue(self, hidden_states, infer_state, layer_weight): - """Shared tail for prefill/decode: PLE fusion (E-series only) then - layer_scalar.""" if self.has_ple_: - hidden_states = self._apply_per_layer_embed(hidden_states, infer_state, layer_weight) - return self._apply_layer_scalar(hidden_states, layer_weight) + ple_slice = infer_state.per_layer_embeds[..., self.layer_num_, :] + flat = hidden_states.view(-1, self.embed_dim_) + gate = layer_weight.per_layer_input_gate_.mm(flat) + gated = nn.functional.gelu(gate, approximate="tanh") * ple_slice.view(-1, self.ple_dim_) + contrib = layer_weight.per_layer_projection_.mm(gated) + contrib = layer_weight.post_per_layer_input_norm_weight_( + input=contrib, eps=self.eps_, alloc_func=self.alloc_tensor + ) + flat.add_(contrib) + hidden_states.mul_(layer_weight.layer_scalar_.weight) + return hidden_states def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): - # input_embdings is bf16 from the pre-layer / previous block; RMSNorm - # (att_norm, ffn_norm) handles fp32 promotion in its Triton kernel, - # so the entire residual stream stays in bf16. input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_), infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) + o = self.context_attention_forward(input1, infer_state, layer_weight) input1 = None - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) + # Gemma sandwich norm: post_attention_layernorm on the attn branch + # before the residual add, not on the post-add residual stream. 
o = self._ffn_norm(o, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -400,12 +372,8 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_), infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) + o = self.token_attention_forward(input1, infer_state, layer_weight) input1 = None - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._token_attention_kernel(q, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) o = self._ffn_norm(o, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None From 9499a00e166e4609d1d5fcbca6f4c55a69713f8e Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 14 May 2026 06:42:35 +0000 Subject: [PATCH 15/20] fix eos_token_ids --- lightllm/utils/config_utils.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 3368bdafd6..8383c9b54c 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -104,12 +104,28 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]: eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"]) if isinstance(eos_token_id, int): - return [eos_token_id] - if isinstance(eos_token_id, list): - return eos_token_id + eos_token_ids = [eos_token_id] + elif isinstance(eos_token_id, list): + eos_token_ids = list(eos_token_id) + else: + raise ValueError("error eos_token_id format in config.json") - assert False, "error eos_token_id format in config.json" - return + generation_config_path = os.path.join(model_path, "generation_config.json") + if os.path.exists(generation_config_path): + try: + with open(generation_config_path, "r") as file: + generation_eos = json.load(file).get("eos_token_id") + except Exception as exc: + logger.warning(f"failed to load eos_token_id from generation_config.json: {exc}") + generation_eos = None + if isinstance(generation_eos, int): + generation_eos = [generation_eos] + if isinstance(generation_eos, list): + for token_id in generation_eos: + if isinstance(token_id, int) and token_id not in eos_token_ids: + eos_token_ids.append(token_id) + + return eos_token_ids def get_model_architectures(model_path: str): From de7e220a47aa2039fb1102d079e2af23fadb1c4e Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 14 May 2026 06:42:54 +0000 Subject: [PATCH 16/20] for HF format --- lightllm/server/build_prompt.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 84044fccce..be0c660b2e 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -79,11 +79,31 @@ def _alias_reasoning_to_reasoning_content(messages: list) -> None: msg["reasoning_content"] = reasoning +def _normalize_multimodal_content_types(messages: list) -> None: + # OpenAI requests use content part types like `image_url` and `audio_url`. + # Model chat templates generally render modality tokens from `image` and + # `audio` parts while the raw media payload is carried separately in + # MultimodalParams. 
Preserve the original fields and normalize only the + # template-facing type to keep prompt tags aligned with media counts. + for msg in messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not isinstance(part, dict): + continue + if part.get("type") == "image_url": + part["type"] = "image" + elif part.get("type") == "audio_url": + part["type"] = "audio" + + async def build_prompt(request, tools) -> str: # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] _normalize_tool_call_arguments(messages) _alias_reasoning_to_reasoning_content(messages) + _normalize_multimodal_content_types(messages) kwargs = {"conversation": messages} if request.character_settings: From 109d27c01355d4644d2fa1ade1a6773bc59c7e9d Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 14 May 2026 07:34:53 +0000 Subject: [PATCH 17/20] fix window_size --- lightllm/common/basemodel/attention/fa3/fp.py | 16 ++++++---------- lightllm/common/basemodel/attention/triton/fp.py | 4 ++-- .../layer_infer/transformer_layer_infer.py | 5 +---- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index e1f6959795..952bb39d91 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -79,12 +79,10 @@ def _nomarl_prefill_att( ) -> torch.Tensor: self.backend: Fa3AttBackend = self.backend # for typing - window_size = (-1, -1) if att_control.use_sliding_window: - left, right = att_control.sliding_window - left = max(int(left) - 1, 0) if left >= 0 else -1 - right = max(int(right) - 1, 0) if right >= 0 else -1 - window_size = (left, right) + window_size = att_control.sliding_window + else: + window_size = (-1, -1) if att_control.use_att_sink: sink_weight: torch.Tensor = att_control.sink_weight @@ -211,12 +209,10 @@ def _normal_decode_att( att_control: AttControl, alloc_func=torch.empty, ): - window_size = (-1, -1) if att_control.use_sliding_window: - left, right = att_control.sliding_window - left = max(int(left) - 1, 0) if left >= 0 else -1 - right = max(int(right) - 1, 0) if right >= 0 else -1 - window_size = (left, right) + window_size = att_control.sliding_window + else: + window_size = (-1, -1) if att_control.use_att_sink: sink_weight: torch.Tensor = att_control.sink_weight diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index 1902960769..ff5432f288 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -70,7 +70,7 @@ def _nomarl_prefill_att( from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd if att_control.use_sliding_window: - sliding_window = int(att_control.sliding_window[0]) + sliding_window = int(att_control.sliding_window[0]) + 1 else: sliding_window = -1 @@ -186,7 +186,7 @@ def _normal_decode_gqa_flash_decoding_att( ) if att_control.use_sliding_window: - sliding_window = int(att_control.sliding_window[0]) + sliding_window = int(att_control.sliding_window[0]) + 1 else: sliding_window = -1 diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index 10168e5b70..ffc8bb4608 100644 --- 
a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -185,11 +185,8 @@ def _post_cache_kv(self, cache_kv, infer_state, layer_weight): # ----- Attention kernels (sliding window + per-layer KV reshape) --- def _att_control(self): - # `sliding_window_` is the total window size including self. Sliding - # layers always run on a backend that consumes SWA (FA3 or the patched - # triton kernels — see Gemma4TpPartModel._init_att_backend1). if self.is_sliding and self.sliding_window_ > 0: - w = self.sliding_window_ + w = self.sliding_window_ - 1 return AttControl(use_sliding_window=True, sliding_window=(w, w)) return AttControl(use_sliding_window=False, sliding_window=(-1, -1)) From 2ea258e92fb11fda295dcbaa89225d17e545a0ef Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 14 May 2026 07:48:47 +0000 Subject: [PATCH 18/20] fix window_size --- lightllm/common/basemodel/attention/triton/fp.py | 4 ++-- .../gqa/flash_decoding/gqa_flash_decoding_stage1.py | 2 +- .../gqa/flash_decoding/gqa_flash_decoding_stage2.py | 2 +- .../att/prefill_att/context_flashattention_nopad.py | 4 ++-- .../gemma4/layer_infer/transformer_layer_infer.py | 2 +- .../triton_kernel/context_attention_fwd_gemma4_mm.py | 10 ++++++---- 6 files changed, 13 insertions(+), 11 deletions(-) diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index ff5432f288..1902960769 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -70,7 +70,7 @@ def _nomarl_prefill_att( from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd if att_control.use_sliding_window: - sliding_window = int(att_control.sliding_window[0]) + 1 + sliding_window = int(att_control.sliding_window[0]) else: sliding_window = -1 @@ -186,7 +186,7 @@ def _normal_decode_gqa_flash_decoding_att( ) if att_control.use_sliding_window: - sliding_window = int(att_control.sliding_window[0]) + 1 + sliding_window = int(att_control.sliding_window[0]) else: sliding_window = -1 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index b800921ae7..d60e434627 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -49,7 +49,7 @@ def _fwd_kernel_flash_decode_stage1( cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) if USE_SLIDING_WINDOW: - kv_start_index = tl.maximum(cur_batch_seq_len - SLIDING_WINDOW_SIZE, 0) + kv_start_index = tl.maximum(cur_batch_seq_len - 1 - SLIDING_WINDOW_SIZE, 0) cur_batch_seq_len = cur_batch_seq_len - kv_start_index else: kv_start_index = 0 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py index a7c0db19f4..810abe1efa 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py @@ -31,7 +31,7 @@ def _fwd_kernel_flash_decode_stage2( offs_d = 
tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) if USE_SLIDING_WINDOW: - kv_start_index = tl.maximum(cur_batch_seq_len - SLIDING_WINDOW_SIZE, 0) + kv_start_index = tl.maximum(cur_batch_seq_len - 1 - SLIDING_WINDOW_SIZE, 0) cur_batch_seq_len = cur_batch_seq_len - kv_start_index block_num = tl.minimum(tl.cdiv(cur_batch_seq_len, BLOCK_SEQ), block_num) diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index f4c3c10ffe..7daf3a12e8 100644 --- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -80,7 +80,7 @@ def _fwd_kernel( block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) if USE_SLIDING_WINDOW: - kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_SIZE + 1 + kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_SIZE kv_start_index = tl.maximum(kv_start_index, 0) block_kv_len = block_end_loc - kv_start_index else: @@ -103,7 +103,7 @@ def _fwd_kernel( mask = q_pos[:, None] >= k_pos[None, :] if USE_SLIDING_WINDOW: - mask = mask & ((q_pos[:, None] - k_pos[None, :]) < SLIDING_WINDOW_SIZE) + mask = mask & ((q_pos[:, None] - k_pos[None, :]) <= SLIDING_WINDOW_SIZE) qk = tl.where(mask, qk * sm_scale, -1.0e8) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py index ffc8bb4608..e6dc5cc359 100644 --- a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -225,7 +225,7 @@ def _context_attention_kernel( # Sliding layers always go through the gemma4_mm Triton kernel: it # handles SWA + image bidirectional masking in one pass. o_tensor = self.alloc_tensor(_q.shape, q.dtype) - sw = self.sliding_window_ if self.sliding_window_ > 0 else -1 + sw = self.sliding_window_ - 1 if self.sliding_window_ > 0 else -1 context_attention_fwd_gemma4_mm( _q, _k, diff --git a/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py b/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py index 9ecf3938fa..b0ab70d7c7 100644 --- a/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py +++ b/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py @@ -114,7 +114,7 @@ def _fwd_kernel( block_end_loc = tl.maximum(causal_end, block_image_end) if USE_SLIDING_WINDOW: - kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_SIZE + 1 + kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_SIZE kv_start_index = tl.maximum(kv_start_index, 0) block_kv_len = block_end_loc - kv_start_index else: @@ -138,7 +138,8 @@ def _fwd_kernel( causal_mask = q_pos[:, None] >= k_pos[None, :] if USE_SLIDING_WINDOW: - causal_mask = causal_mask & ((q_pos[:, None] - k_pos[None, :]) < SLIDING_WINDOW_SIZE) + # SLIDING_WINDOW_SIZE is the FA-style offset (window = offset + 1 tokens). + causal_mask = causal_mask & ((q_pos[:, None] - k_pos[None, :]) <= SLIDING_WINDOW_SIZE) # Image bidi: a Q in image span [_, e) attends to all K with k_pos < e. 
# For text Q (q_image_end == 0) this is k_pos < 0 = always False, so # the union with causal_mask leaves text-attention unchanged. @@ -272,7 +273,8 @@ def reference_attention( ): """Slow torch reference for the gemma4 mm prefill kernel. - `sliding_window` is the total window size including self. < 0 disables SWA. + `sliding_window` is the FA-style offset (window = sliding_window + 1 tokens). + < 0 disables SWA. """ device = q.device dtype = q.dtype @@ -306,7 +308,7 @@ def reference_attention( causal = k_pos[None, :] <= q_pos[:, None] if sliding_window >= 0: - causal = causal & ((q_pos[:, None] - k_pos[None, :]) < sliding_window) + causal = causal & ((q_pos[:, None] - k_pos[None, :]) <= sliding_window) image = k_pos[None, :] < q_image_end[:, None] allow = causal | image From b297af59da369b1a5dabd9f412c059f967a5050f Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 14 May 2026 08:54:36 +0000 Subject: [PATCH 19/20] fix --- .../template/transformer_layer_infer_template.py | 12 +++++++----- lightllm/models/gemma4/model.py | 1 + 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 276b5856f9..f0cc129c09 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -103,11 +103,13 @@ def _context_attention_wrapper_run( ) -> torch.Tensor: if torch.cuda.is_current_stream_capturing(): q = q.contiguous() - cache_kv = cache_kv.contiguous() - _q, _cache_kv = ( - tensor_to_no_ref_tensor(q), - tensor_to_no_ref_tensor(cache_kv), - ) + # cache_kv is None for layers that own no K/V slot (e.g. gemma4 + # KV-shared layers, which read K/V from a prior layer's cache and + # ignore this arg in _context_attention_kernel). Skip the + # graph-input plumbing for it instead of crashing on None. + cache_kv = cache_kv.contiguous() if cache_kv is not None else None + _q = tensor_to_no_ref_tensor(q) + _cache_kv = tensor_to_no_ref_tensor(cache_kv) if cache_kv is not None else None pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() pre_capture_graph.__exit__(None, None, None) diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py index 42093a3f29..a4a1a81b83 100644 --- a/lightllm/models/gemma4/model.py +++ b/lightllm/models/gemma4/model.py @@ -223,6 +223,7 @@ def _init_att_backend(self): def _init_att_backend1(self): # Only decode needs the sliding-layer backend; prefill sliding goes # through gemma4_mm Triton directly in the layer. 
+ self.prefill_att_backend1 = None self.decode_att_backend1 = self._build_gemma4_sliding_backend(self._gemma4_sliding_decode_backend_kind) @staticmethod From 7a81e85c865a1a044892b5379b15491e78aa2415 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 15 May 2026 06:12:15 +0000 Subject: [PATCH 20/20] add reasoning_parser for gemma4 --- lightllm/server/api_cli.py | 1 + lightllm/server/api_openai.py | 14 ++++++++++-- lightllm/server/reasoning_parser.py | 34 +++++++++++++++++++++++++++++ lightllm/utils/config_utils.py | 4 ++++ 4 files changed, 51 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4a345000b0..7af5aa6b89 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -193,6 +193,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "step3", "nano_v3", "interns1", + "gemma4", ], default=None, help="reasoning parser type", diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index fe4f3b50b0..2f79d730d7 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -165,8 +165,8 @@ def _is_force_thinking_mode(request: ChatCompletionRequest) -> bool: return False if reasoning_parser in ["deepseek-v3"]: return request.chat_template_kwargs is not None and request.chat_template_kwargs.get("thinking") is True - if reasoning_parser in ["qwen3", "glm45", "nano_v3", "interns1"]: - # qwen3, glm45, nano_v3, and interns1 are reasoning by default + if reasoning_parser in ["qwen3", "glm45", "nano_v3", "interns1", "gemma4"]: + # qwen3, glm45, nano_v3, interns1, and gemma4 are reasoning by default; return not request.chat_template_kwargs or request.chat_template_kwargs.get("enable_thinking", True) is True return True # default @@ -315,6 +315,16 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req "seed": request.seed, } + # Gemma-4's reasoning delimiters (<|channel>=100, =101) are + # special tokens. The default skip_special_tokens=True would drop them + # from the decoded stream and the Gemma4Detector would be unable to + # find the reasoning boundary. Mirrors vllm's + # Gemma4ReasoningParser.adjust_request behaviour. Only applied when no + # explicit value is supplied so callers can still opt back into the + # default if they want. + if get_env_start_args().reasoning_parser == "gemma4" and "skip_special_tokens" not in sampling_params_dict: + sampling_params_dict["skip_special_tokens"] = False + if request.max_completion_tokens is not None: sampling_params_dict["max_new_tokens"] = request.max_completion_tokens elif request.max_tokens is not None: diff --git a/lightllm/server/reasoning_parser.py b/lightllm/server/reasoning_parser.py index 024be4f769..fc80cb2fa6 100644 --- a/lightllm/server/reasoning_parser.py +++ b/lightllm/server/reasoning_parser.py @@ -862,6 +862,33 @@ def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False) ) +class Gemma4Detector(BaseReasoningFormatDetector): + """ + Detector for Google Gemma-4 thinking models. + + Format: ``<|channel>thought\\n...reasoning...\\nanswer``. + Role label ``thought\\n`` is baked into the start token (cf. + GptOssDetector) so the base class strips it for free. + + Note: ``<|channel>`` and ```` are special tokens (ids 100/101). + The API layer forces ``skip_special_tokens=False`` when this parser is + active so the delimiters survive decoding (see ``api_openai.py``). 
+ """ + + THINK_START_TOKEN = "<|channel>thought\n" + THINK_END_TOKEN = "" + + def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False): + # force_reasoning ignored: Gemma-4's template never starts generation + # inside an open channel (ReasoningParser pins it to False too). + super().__init__( + self.THINK_START_TOKEN, + self.THINK_END_TOKEN, + force_reasoning=False, + stream_reasoning=stream_reasoning, + ) + + class ReasoningParser: """ Parser that handles both streaming and non-streaming scenarios for extracting @@ -887,6 +914,7 @@ class ReasoningParser: "step3": DeepSeekR1Detector, "nano_v3": NanoV3Detector, "interns1": Qwen3Detector, + "gemma4": Gemma4Detector, } def __init__( @@ -905,6 +933,12 @@ def __init__( # Special cases where we override force_reasoning if model_type.lower() in {"qwen3-thinking", "gpt-oss", "minimax"}: force_reasoning = True + elif model_type.lower() == "gemma4": + # Gemma-4's chat template never positions generation inside an open + # channel — see Gemma4Detector docstring. Pin to False so a + # request_enable_reasoning=True from the caller can't accidentally + # mark the parser as already inside reasoning. + force_reasoning = False # Only pass force_reasoning if explicitly set, let detectors use their defaults kwargs = {"stream_reasoning": stream_reasoning} diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 375f358b00..548a36aeb0 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -469,4 +469,8 @@ def get_reasoning_parser_for_model(model_path: str) -> Optional[str]: if model_type == "deepseek_r1": return "deepseek-r1" + # Gemma-4 (all variants share the same Harmony-like <|channel>... format) + if model_type == "gemma4": + return "gemma4" + return None
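
Note on the fused gate/up path introduced in PATCH 12/20: the packed gate_up_proj matmul followed by silu_and_mul_fwd(gate_up, ffn1, use_gelu=True) is expected to be numerically equivalent to the previous two-matmul gate/up path, assuming the use_gelu flag selects the tanh-approximated GELU as its name suggests. The sketch below is a plain-PyTorch reference of that contract, not a lightllm function; the layout assumption (gate rows first, then up rows) follows the out_dims=[n_inter, n_inter] / weight_names=[gate, up] packing in the weight change.

import torch
import torch.nn.functional as F

def gelu_and_mul_reference(gate_up: torch.Tensor) -> torch.Tensor:
    # gate_up packs [gate | up] along the last dimension.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.gelu(gate, approximate="tanh") * up

# Packed single matmul vs. the previous separate gate/up matmuls.
x = torch.randn(4, 8)
w_gate = torch.randn(16, 8)
w_up = torch.randn(16, 8)
w_gate_up = torch.cat([w_gate, w_up], dim=0)  # stitched along the output dim
fused = gelu_and_mul_reference(x @ w_gate_up.t())
unfused = F.gelu(x @ w_gate.t(), approximate="tanh") * (x @ w_up.t())
assert torch.allclose(fused, unfused)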
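
Note on the router-logits rewrite in PATCH 13/20: it keeps the same math as the manual path it replaces, namely an unweighted RMSNorm with fp32 accumulation, the router root scale, a per-channel input scale, and a gate matmul whose output is produced in fp32 so expert top-k and softmax are not run in bf16. A plain-PyTorch restatement of that computation; the parameter names here are illustrative, not lightllm attributes, and root_scale / input_scale stand in for the router_root_scale and router_input_scale_ factors from the patch.

import torch

def router_logits_reference(residual, gate_weight, input_scale, root_scale, eps):
    # Unweighted RMSNorm, accumulated in fp32.
    x = residual.float()
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Root scale and per-channel input scale, applied before the gate matmul.
    x = x * root_scale * input_scale.float()
    # Gate logits stay fp32, mirroring moe_gate.mm(x, out_dtype=torch.float32).
    return x @ gate_weight.float().t()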
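
Note on the two window_size commits (PATCH 17/20 and PATCH 18/20): together they settle on one convention, where the model config's sliding_window counts the query token itself, while AttControl and the attention kernels carry the FlashAttention-style offset (window minus one) and mask with <=. A small self-check of that equivalence, using hypothetical helper names rather than lightllm APIs:

import torch

def swa_mask_from_total(q_pos, k_pos, window_total):
    # window_total includes the query position itself (config convention).
    return (k_pos <= q_pos) & ((q_pos - k_pos) < window_total)

def swa_mask_from_offset(q_pos, k_pos, window_offset):
    # FA-style offset = window_total - 1, masked with <= (kernel convention).
    return (k_pos <= q_pos) & ((q_pos - k_pos) <= window_offset)

q_pos = torch.arange(16)[:, None]
k_pos = torch.arange(16)[None, :]
window_total = 4
assert torch.equal(
    swa_mask_from_total(q_pos, k_pos, window_total),
    swa_mask_from_offset(q_pos, k_pos, window_total - 1),
)
# Decode analogue: the newest query sits at position seq_len - 1, so the first
# visible KV index is max(seq_len - 1 - offset, 0), which is the kv_start_index
# change applied to the flash-decoding stage1/stage2 kernels.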