diff --git a/common/speculative.cpp b/common/speculative.cpp index 4d1b61a13ad..52d590e51e4 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -33,16 +33,15 @@ const std::map common_speculative_type_fro }; static std::string common_speculative_get_devices_str(const std::vector & devices) { - if (devices.empty()) { - return "default"; - } - std::string result; for (size_t i = 0; i < devices.size(); i++) { - if (i > 0) result += ", "; + if (devices[i] == nullptr) { + continue; + } + if (!result.empty()) result += ", "; result += ggml_backend_dev_name(devices[i]); } - return result; + return result.empty() ? "default" : result; } struct common_speculative_config { @@ -416,6 +415,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { int32_t n_embd = 0; + bool kv_shared_with_target = false; + // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. // The last h-row of one process() call needs the first token of the NEXT // call to pair with, so it's stashed here until that next call fires. @@ -442,7 +443,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { auto * ctx_dft = this->params.ctx_dft; GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set"); - n_embd = llama_model_n_embd(llama_get_model(ctx_dft)); + n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft)); + GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) && + "MTP input row width must match the target h_pre_norm width"); LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__); LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd); @@ -471,6 +474,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false); llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true); + llama_set_mtp_source(ctx_dft, ctx_tgt); + + kv_shared_with_target = llama_model_n_layer_kv(llama_get_model(ctx_dft)) == 0; pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); @@ -496,9 +502,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { if (N <= 0) { return; } + auto * ctx_dft = this->params.ctx_dft; const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); - if (pos_max < N - 1) { + if (pos_max < N - 1 && !kv_shared_with_target) { LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - " "process() hook may not have run on every prefill ubatch " "(need_embd / logits=1 on every prompt position?). " @@ -541,48 +548,42 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { const size_t row_bytes = (size_t) n_embd * sizeof(float); - common_batch_clear(batch); + // if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode + if (!kv_shared_with_target) { + common_batch_clear(batch); - for (int k = 0; k < n_tokens; ++k) { - common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0); - } + for (int k = 0; k < n_tokens; ++k) { + common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0); + } - // shift the tgt embeddings to the right by one position - // assumes that the tokens in the batch are sequential for each sequence - // i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1] - // ^--- this is a problem - // TODO:this is generally true, but would be nice to assert it - { - const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt); - std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1)); + // shift the tgt embeddings to the right by one position + // assumes that the tokens in the batch are sequential for each sequence + // i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1] + // ^--- this is a problem + // TODO:this is generally true, but would be nice to assert it + { + const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt); + std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1)); + } - //{ - // // string with seq_ids in the batch - // std::stringstream ss; - // for (int i = 0; i < n_tokens; ++i) { - // ss << batch_in.seq_id[i][0] << ","; - // } - // LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str()); - //} - } + // fill the pending embeddings from a previous run + auto set_h = [&](int idx, const float * h_row) { + std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes); + }; - // fill the pending embeddings from a previous run - auto set_h = [&](int idx, const float * h_row) { - std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes); - }; + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_beg[seq_id] < 0) { + continue; + } - for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { - if (i_batch_beg[seq_id] < 0) { - continue; + set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); } - set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); - } - - const int32_t rc = llama_decode(ctx_dft, batch); - if (rc != 0) { - LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); - return false; + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); + return false; + } } for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { diff --git a/conversion/__init__.py b/conversion/__init__.py index 2c38123dff8..e3d919b61ad 100644 --- a/conversion/__init__.py +++ b/conversion/__init__.py @@ -73,6 +73,7 @@ "Gemma3TextModel": "gemma", "Gemma3nForCausalLM": "gemma", "Gemma3nForConditionalGeneration": "gemma", + "Gemma4AssistantForCausalLM": "gemma", "Gemma4ForConditionalGeneration": "gemma", "GemmaForCausalLM": "gemma", "Glm4ForCausalLM": "glm", diff --git a/conversion/gemma.py b/conversion/gemma.py index a6e14fbcb98..e8f50ebe048 100644 --- a/conversion/gemma.py +++ b/conversion/gemma.py @@ -765,6 +765,16 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Gemma4AssistantForCausalLM") +class Gemma4AssistantModel(Gemma4Model): + model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"]) + self.gguf_writer.add_nextn_predict_layers(self.block_count) + + @ModelBase.register("Gemma4ForConditionalGeneration") class Gemma4VisionAudioModel(MmprojModel): has_audio_encoder = True diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c25f217f990..24c5c296ff1 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -430,6 +430,7 @@ class MODEL_ARCH(IntEnum): GEMMA3 = auto() GEMMA3N = auto() GEMMA4 = auto() + GEMMA4_ASSISTANT = auto() GEMMA_EMBEDDING = auto() STARCODER2 = auto() RWKV6 = auto() @@ -856,6 +857,8 @@ class MODEL_TENSOR(IntEnum): A_PER_DIM_K_SCALE = auto() # gemma4 A_PER_DIM_SCALE = auto() # gemma4 # nextn/mtp + NEXTN_PRE_PROJ = auto() + NEXTN_POST_PROJ = auto() NEXTN_EH_PROJ = auto() NEXTN_EMBED_TOKENS = auto() NEXTN_ENORM = auto() @@ -945,6 +948,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA3: "gemma3", MODEL_ARCH.GEMMA3N: "gemma3n", MODEL_ARCH.GEMMA4: "gemma4", + MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant", MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.RWKV6: "rwkv6", @@ -1401,6 +1405,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.A_QF_FFN_DOWN: "a.proj_blk.{bid}.ffn_down", MODEL_TENSOR.A_QF_FFN_NORM: "a.proj_blk.{bid}.ffn_norm", # NextN/MTP + MODEL_TENSOR.NEXTN_PRE_PROJ: "nextn.pre_projection", + MODEL_TENSOR.NEXTN_POST_PROJ: "nextn.post_projection", MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj", MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens", MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm", @@ -2481,6 +2487,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.PER_LAYER_PROJ_NORM, MODEL_TENSOR.PER_LAYER_POST_NORM, ], + MODEL_ARCH.GEMMA4_ASSISTANT: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.NEXTN_PRE_PROJ, + MODEL_TENSOR.NEXTN_POST_PROJ, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.FFN_PRE_NORM, + MODEL_TENSOR.FFN_POST_NORM, + MODEL_TENSOR.LAYER_OUT_SCALE, + ], MODEL_ARCH.GEMMA_EMBEDDING: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index f40cb828201..4f0fe749405 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -2245,6 +2245,14 @@ class TensorNameMap: ), # NextN/MTP tensors + MODEL_TENSOR.NEXTN_PRE_PROJ: ( + "pre_projection", + ), + + MODEL_TENSOR.NEXTN_POST_PROJ: ( + "post_projection", + ), + MODEL_TENSOR.NEXTN_EH_PROJ: ( "model.layers.{bid}.eh_proj", ), diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index c9eead18aa3..aa76d20608b 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -57,6 +57,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA4, "gemma4" }, + { LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -445,6 +446,8 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_NEXTN_PRE_PROJ, "nextn.pre_projection" }, + { LLM_TENSOR_NEXTN_POST_PROJ, "nextn.post_projection" }, { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, @@ -757,6 +760,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PRE_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_POST_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so // the model loader doesn't fault on the block index. diff --git a/src/llama-arch.h b/src/llama-arch.h index 89cf16cc37c..cba0ed5d3d8 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -61,6 +61,7 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA4, + LLM_ARCH_GEMMA4_ASSISTANT, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -549,6 +550,8 @@ enum llm_tensor { LLM_TENSOR_INDEXER_PROJ, LLM_TENSOR_INDEXER_ATTN_K, LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_PRE_PROJ, + LLM_TENSOR_NEXTN_POST_PROJ, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3cc8ffa6668..96f35e63e69 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -30,6 +30,70 @@ static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { throw std::runtime_error("Unsupported ctx type"); } +static uint32_t ctx_type_to_embd_inp(const llama_hparams & hparams, llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return hparams.n_embd_inp(); + case LLAMA_CONTEXT_TYPE_MTP : return hparams.n_embd_out(); + } + throw std::runtime_error("Unsupported ctx type"); +} + +namespace { +struct src_mctx_reset_on_exit { + llama_memory_context_ptr * slot; + ~src_mctx_reset_on_exit() { if (slot) slot->reset(); } +}; + +static void llama_assert_gemma4_mtp_source_placement( + const llama_context * ctx, + const llama_context * src) { + if (!ctx || !src) { + return; + } + + const auto & model_dft = ctx->get_model(); + const auto & model_tgt = src->get_model(); + + if (model_dft.arch != LLM_ARCH_GEMMA4_ASSISTANT || model_tgt.arch != LLM_ARCH_GEMMA4) { + return; + } + + if (model_tgt.split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + return; + } + + const auto & hparams_dft = model_dft.hparams; + const auto & hparams_tgt = model_tgt.hparams; + + const int32_t il_tgt_full = (int32_t) hparams_tgt.n_layer - 1; + const int32_t il_tgt_swa = (int32_t) hparams_tgt.n_layer - 2; + + ggml_backend_dev_t dev_cpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!dev_cpu) { + throw std::runtime_error("Gemma 4 assistant MTP placement check failed: no CPU backend found"); + } + + const bool kv_offload = src->get_cparams().offload_kqv; + + for (uint32_t il_dft = 0; il_dft < hparams_dft.n_layer; ++il_dft) { + const int32_t il_tgt = hparams_dft.is_swa(il_dft) ? il_tgt_swa : il_tgt_full; + + ggml_backend_dev_t dev_dft = model_dft.dev_layer(il_dft); + ggml_backend_dev_t dev_kv = kv_offload ? model_tgt.dev_layer(il_tgt) : dev_cpu; + + if (dev_dft != dev_kv) { + throw std::runtime_error(format( + "Gemma 4 assistant MTP placement mismatch: draft layer %d is on %s, " + "but shared target KV layer %d is on %s", + (int) il_dft, + ggml_backend_dev_name(dev_dft), + (int) il_tgt, + ggml_backend_dev_name(dev_kv))); + } + } +} +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : @@ -368,7 +432,11 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); } - sched_reserve(); + // MTP draft contexts can't reserve until the source context is wired + // via llama_set_mtp_source — defer to the first decode. + if (cparams.ctx_type != LLAMA_CONTEXT_TYPE_MTP) { + sched_reserve(); + } if (!cparams.flash_attn) { if (ggml_is_quantized(params.type_v)) { @@ -442,6 +510,23 @@ void llama_context::sched_reserve() { } } + // When called from decode(), src_mctx_for_decode is already populated and + // we must not drop it on exit (process_ubatch still needs it). Snapshot + // only when sched_reserve runs standalone (e.g. lazy first-decode reserve + // when set_mtp_source flipped sched_need_reserve). + const bool owns_src_snapshot = src_ctx && !src_mctx_for_decode; + if (owns_src_snapshot) { + auto * src_memory = src_ctx->get_memory(); + if (!src_memory) { + throw std::runtime_error("MTP source context has no memory module"); + } + src_mctx_for_decode = src_memory->init_full(); + if (!src_mctx_for_decode) { + throw std::runtime_error("failed to initialize MTP source memory snapshot"); + } + } + src_mctx_reset_on_exit reserve_src_drop{owns_src_snapshot ? &src_mctx_for_decode : nullptr}; + // avoid reserving graphs with zero outputs - assume one output per sequence const int n_outputs = n_seqs; @@ -896,10 +981,9 @@ float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { throw std::runtime_error("no pre-norm embeddings"); } - const uint32_t n_embd = model.hparams.n_embd; + const uint32_t n_embd = model.hparams.n_embd_out(); if (!cparams.embeddings_pre_norm_masked) { - // unmasked: pre-norm rows are stored densely, indexed by raw token position. if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); } @@ -1105,6 +1189,18 @@ void llama_context::set_embeddings_pre_norm(bool value, bool masked) { cparams.embeddings_pre_norm_masked = masked; } +void llama_context::set_mtp_source(llama_context * src) { + if (src_ctx == src) { + return; + } + llama_assert_gemma4_mtp_source_placement(this, src); + src_ctx = src; + src_mctx_for_decode.reset(); + // worst-case compute buffers were reserved without knowing about the source + // memory; force a re-reserve so the next decode sees src views + sched_need_reserve = true; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1317,7 +1413,7 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd_inp(); + const int64_t n_embd = ctx_type_to_embd_inp(hparams, cparams.ctx_type); const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -1452,7 +1548,7 @@ int llama_context::encode(const llama_batch & batch_inp) { ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd = hparams.n_embd_out(); GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); } @@ -1627,7 +1723,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; const int64_t n_vocab = vocab.n_tokens(); - const int64_t n_embd = hparams.n_embd_inp(); + const int64_t n_embd = ctx_type_to_embd_inp(hparams, cparams.ctx_type); // when computing embeddings, all tokens are output const bool output_all = cparams.embeddings; @@ -1689,6 +1785,20 @@ int llama_context::decode(const llama_batch & batch_inp) { embd_seq.clear(); output_swaps.clear(); + src_mctx_reset_on_exit decode_src_drop{&src_mctx_for_decode}; + if (src_ctx) { + auto * src_memory = src_ctx->get_memory(); + if (!src_memory) { + LLAMA_LOG_ERROR("%s: MTP source context has no memory module\n", __func__); + return -2; + } + src_mctx_for_decode = src_memory->init_full(); + if (!src_mctx_for_decode) { + LLAMA_LOG_ERROR("%s: failed to snapshot MTP source memory\n", __func__); + return -2; + } + } + sched_reserve(); bool did_optimize = false; @@ -1903,7 +2013,7 @@ int llama_context::decode(const llama_batch & batch_inp) { ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; + const uint32_t n_embd = hparams.n_embd_out(); float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); @@ -1996,7 +2106,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); - const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); bool has_logits = true; @@ -2015,12 +2124,10 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { logits.size = has_logits ? n_vocab*n_outputs_max : 0; embd.size = has_embd ? n_embd_out*n_outputs_max : 0; - embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd_out*n_outputs_max : 0; if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { - // unmasked: pre-norm row exists for every token in the batch, not just - // those flagged via batch.logits[i] -> size by token count instead. - embd_pre_norm.size = (size_t) n_embd * n_batch; + embd_pre_norm.size = (size_t) n_embd_out * n_batch; } // Allocate backend sampling output buffers if there are backend samplers configured. @@ -2283,6 +2390,8 @@ llm_graph_params llama_context::graph_params( /*.cvec =*/ cvec.get(), /*.loras =*/ loras.get(), /*.mctx =*/ mctx, + /*.src_mctx =*/ src_mctx_for_decode.get(), + /*.src_model =*/ src_ctx ? &src_ctx->get_model() : nullptr, /*.cross =*/ &cross, /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, @@ -3575,6 +3684,10 @@ void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) ctx->set_embeddings_pre_norm(value, masked); } +void llama_set_mtp_source(llama_context * ctx, llama_context * src) { + ctx->set_mtp_source(src); +} + float * llama_get_embeddings_pre_norm(llama_context * ctx) { ctx->synchronize(); diff --git a/src/llama-context.h b/src/llama-context.h index d03f681d4a1..cffd8a83a1c 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -6,6 +6,7 @@ #include "llama-graph.h" #include "llama-adapter.h" #include "llama-impl.h" +#include "llama-memory.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -111,6 +112,7 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_pre_norm(bool value, bool masked); + void set_mtp_source(llama_context * src); void set_causal_attn(bool value); void set_warmup(bool value); @@ -275,6 +277,12 @@ struct llama_context { std::unique_ptr memory; + // external KV source used by MTP draft contexts. src_ctx is the target + // context whose memory we read; src_mctx_for_decode is a per-decode + // snapshot held for the duration of one decode/sched_reserve call. + llama_context * src_ctx = nullptr; + llama_memory_context_ptr src_mctx_for_decode; + // decode output (2-dimensional array: [n_outputs][n_vocab]) buffer_view logits = {nullptr, 0}; diff --git a/src/llama-ext.h b/src/llama-ext.h index edfa71c207c..9e1cf727996 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -85,6 +85,11 @@ using llama_memory_breakdown = std::mapget_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + src_mctx->get_swa() ->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); +} + +bool llm_graph_input_attn_src_kv_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.src_mctx); + + this->src_mctx = mctx; + + bool res = true; + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + return res; +} + void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(cross_kq_mask); @@ -953,6 +969,8 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : cvec (params.cvec), loras (params.loras), mctx (params.mctx), + src_mctx (params.src_mctx), + src_model (params.src_model), cross (params.cross), samplers (params.samplers), cb_func (params.cb), @@ -2441,6 +2459,82 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); } +llm_graph_input_attn_src_kv_iswa * llm_graph_context::build_attn_inp_src_kv_iswa() const { + GGML_ASSERT(src_mctx && "MTP draft graph requires src_mctx (set via llama_set_mtp_source)"); + + const auto * src_iswa = static_cast(src_mctx); + + auto inp = std::make_unique(hparams, cparams, src_iswa); + + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, src_iswa->get_base(), ubatch, cparams); + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + + inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, src_iswa->get_swa(), ubatch, cparams); + inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + + return (llm_graph_input_attn_src_kv_iswa *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_src_kv_iswa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + float kq_scale, + int il_assist, + int il_src) const { + const bool is_swa = hparams.is_swa(il_assist); + + const auto * src_iswa = inp->src_mctx; + const auto * src_cur = is_swa ? src_iswa->get_swa() : src_iswa->get_base(); + + const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask(); + + ggml_build_forward_expand(gf, q_cur); + + ggml_tensor * q = q_cur; + ggml_tensor * k = src_cur->get_k(ctx0, il_src); + ggml_tensor * v = src_cur->get_v(ctx0, il_src); + + // build_attn_mha splits q across k->ne[3] (the trunk's stream count). When the + // trunk runs kv_unified=false the assistant's ubatch only references a subset + // of streams (one per active draft seq); q->ne[2] is not divisible by the full + // n_stream and the view collapses tokens. Slice k/v down to exactly the streams + // referenced by this ubatch. Requires those streams to form a contiguous range. + if (k->ne[3] > 1 && (uint32_t) k->ne[3] != ubatch.n_seqs_unq) { + GGML_ASSERT(ubatch.n_seqs_unq > 0 && ubatch.seq_id_unq); + llama_seq_id min_s = ubatch.seq_id_unq[0]; + llama_seq_id max_s = ubatch.seq_id_unq[0]; + for (uint32_t s = 1; s < ubatch.n_seqs_unq; ++s) { + min_s = std::min(min_s, ubatch.seq_id_unq[s]); + max_s = std::max(max_s, ubatch.seq_id_unq[s]); + } + GGML_ASSERT((uint32_t)(max_s - min_s + 1) == ubatch.n_seqs_unq && + "MTP src-kv attn requires the active draft seq_ids to be contiguous"); + GGML_ASSERT((int64_t) max_s < k->ne[3] && "MTP assistant seq_id beyond trunk stream count"); + + k = ggml_view_4d(ctx0, k, k->ne[0], k->ne[1], k->ne[2], (int64_t) ubatch.n_seqs_unq, + k->nb[1], k->nb[2], k->nb[3], (size_t) min_s * k->nb[3]); + v = ggml_view_4d(ctx0, v, v->ne[0], v->ne[1], v->ne[2], (int64_t) ubatch.n_seqs_unq, + v->nb[1], v->nb[2], v->nb[3], (size_t) min_s * v->nb[3]); + } + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il_assist); + cb(cur, "kqv_out", il_assist); + + if (wo) { + cur = build_lora_mm(wo, cur, wo_s); + } + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_cross * inp, ggml_tensor * wo, diff --git a/src/llama-graph.h b/src/llama-graph.h index bf6778237e6..b8d147ca1f1 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -402,6 +402,37 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { const llama_kv_cache_iswa_context * mctx; }; +// mask-only input for attention against an external (read-only) ISWA KV cache. +// used by MTP draft graphs that attend to the target's KV without owning any. +class llm_graph_input_attn_src_kv_iswa : public llm_graph_input_i { +public: + llm_graph_input_attn_src_kv_iswa( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_iswa_context * src_mctx) : + hparams(hparams), + cparams(cparams), + src_mctx(src_mctx) { + } + ~llm_graph_input_attn_src_kv_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } + + ggml_tensor * self_kq_mask = nullptr; + ggml_tensor * self_kq_mask_cnv = nullptr; + ggml_tensor * self_kq_mask_swa = nullptr; + ggml_tensor * self_kq_mask_swa_cnv = nullptr; + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_iswa_context * src_mctx; +}; + class llm_graph_input_attn_cross : public llm_graph_input_i { public: llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} @@ -544,6 +575,11 @@ struct llm_graph_params { const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; const llama_memory_context_i * mctx; + // per-decode snapshot of an external memory module the graph reads from + // (never writes) — e.g. ctx_dft reading target KV during MTP draft. + // nullptr for a main decode. Rebound inside reuse-aware input classes. + const llama_memory_context_i * src_mctx; + const llama_model * src_model; const llama_cross * cross; std::map samplers; @@ -761,6 +797,8 @@ struct llm_graph_context { const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; const llama_memory_context_i * mctx; + const llama_memory_context_i * src_mctx; + const llama_model * src_model; const llama_cross * cross; std::map samplers; @@ -973,6 +1011,24 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_src_kv_iswa * build_attn_inp_src_kv_iswa() const; + + // Q-only attention against an external ISWA KV cache (no K/V projections, + // no writes). il_assist labels the attention block in the local graph for + // logging; il_src indexes the source K/V layer to attend to. + ggml_tensor * build_attn( + llm_graph_input_attn_src_kv_iswa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il_assist, + int il_src) const; + llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a49a055a630..c460d7822cd 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -2433,6 +2433,10 @@ uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } +llama_pos llama_kv_cache_context::seq_pos_max(llama_seq_id seq_id) const { + return kv->seq_pos_max(seq_id); +} + ggml_type llama_kv_cache_context::type_k() const { return kv->type_k(); } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 0b62dc7b232..b658d5e664b 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -350,6 +350,11 @@ class llama_kv_cache_context : public llama_memory_context_i { uint32_t get_n_kv() const; + // last position recorded in the cache for this sequence; -1 if absent. + // exposed for cross-context KV consumers (e.g. MTP draft) that need to + // anchor the source position without owning a memory module of their own. + llama_pos seq_pos_max(llama_seq_id seq_id) const; + ggml_type type_k() const; ggml_type type_v() const; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8bf20a716eb..6890048e340 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -134,6 +134,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_gemma3n(params); case LLM_ARCH_GEMMA4: return new llama_model_gemma4(params); + case LLM_ARCH_GEMMA4_ASSISTANT: + return new llama_model_gemma4_assistant(params); case LLM_ARCH_GEMMA_EMBEDDING: return new llama_model_gemma_embedding(params); case LLM_ARCH_STARCODER2: @@ -2311,6 +2313,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: case LLM_ARCH_GEMMA4: + case LLM_ARCH_GEMMA4_ASSISTANT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: @@ -2503,6 +2506,10 @@ int32_t llama_model_n_devices(const struct llama_model * model) { return (int32_t)model->devices.size(); } +int32_t llama_model_n_layer_kv(const struct llama_model * model) { + return (int32_t) model->hparams.n_layer_kv(); +} + ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i) { if (i < 0 || i >= (int)model->devices.size()) { return nullptr; diff --git a/src/llama-model.h b/src/llama-model.h index 01c87a75271..a3903c006c8 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -538,6 +538,10 @@ struct llama_model { struct ggml_tensor * output_s = nullptr; struct ggml_tensor * output_in_s = nullptr; + // NextN/MTP model-level projections + struct ggml_tensor * nextn_pre_proj = nullptr; + struct ggml_tensor * nextn_post_proj = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; diff --git a/src/models/gemma4.cpp b/src/models/gemma4.cpp index 4f9d8b18bc7..c787a5490ff 100644 --- a/src/models/gemma4.cpp +++ b/src/models/gemma4.cpp @@ -135,6 +135,213 @@ std::unique_ptr llama_model_gemma4::build_arch_graph(const ll return std::make_unique(*this, params); } +void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t) n_kv_shared_layers; + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + + if (hparams.n_layer == 4) { + type = LLM_TYPE_31B; + } +} + +void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + if (hparams.n_embd_out() == n_embd) { + throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + + const int64_t n_embd_backbone = hparams.n_embd_out(); + nextn_pre_proj = create_tensor(tn(LLM_TENSOR_NEXTN_PRE_PROJ, "weight"), { 2*n_embd_backbone, n_embd }, 0); + nextn_post_proj = create_tensor(tn(LLM_TENSOR_NEXTN_POST_PROJ, "weight"), { n_embd, n_embd_backbone }, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_ff = hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0); + + if (!hparams.is_swa(i)) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0); + } +} + +std::unique_ptr llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + GGML_ASSERT(src_mctx && "Gemma 4 assistant graph requires an MTP source (llama_set_mtp_source)"); + GGML_ASSERT(src_model && "Gemma 4 assistant graph requires a source model"); + GGML_ASSERT(src_model->tok_embd && "source model missing tok_embd"); + + const auto & src_hparams = src_model->hparams; + + // By convention the MTP draft reads from the trunk's final SWA and full layers. + const int32_t src_layer_full = (int32_t) src_hparams.n_layer - 1; + const int32_t src_layer_swa = (int32_t) src_hparams.n_layer - 2; + GGML_ASSERT(!src_hparams.is_swa(src_layer_full) && "trunk's last layer must be full attention"); + GGML_ASSERT( src_hparams.is_swa(src_layer_swa) && "trunk's penultimate layer must be SWA"); + + const int64_t n_embd_backbone = hparams.n_embd_out(); + + ggml_tensor * inp_tokens; + ggml_tensor * inp_h; + { + auto inp = std::make_unique(n_embd_backbone); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + inp_tokens = inp->tokens; + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens); + cb(inp->embd, "inp_h", -1); + ggml_set_input(inp->embd); + inp_h = inp->embd; + res->t_inp_embd = inp->embd; + + res->add_input(std::move(inp)); + } + + ggml_tensor * x = ggml_get_rows(ctx0, src_model->tok_embd, inp_tokens); + x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); + cb(x, "inp_embd_target", -1); + + ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0); + cb(xh, "inp_xh", -1); + + ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_pre_proj, xh); + cb(cur, "pre_proj", -1); + + auto * inp_attn = build_attn_inp_src_kv_iswa(); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const bool is_swa = hparams.is_swa(il); + const int32_t il_src = is_swa ? src_layer_swa : src_layer_full; + + const int64_t n_embd_head = hparams.n_embd_head_k(il); + const int64_t n_head = hparams.n_head(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur_norm, "attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs; + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr, + Qcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il, il_src); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, nullptr, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, attn_out); + + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + ggml_tensor * logits = build_lora_mm(model.output, cur); + cb(logits, "result_output", -1); + res->t_logits = logits; + + ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_post_proj, cur); + cb(h_next, "result_h_pre_norm", -1); + res->t_h_pre_norm = h_next; + + ggml_build_forward_expand(gf, logits); + ggml_build_forward_expand(gf, h_next); +} + // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { GGML_ASSERT(idx < (int) x->ne[2]); @@ -245,7 +452,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para } // TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } @@ -345,7 +552,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens] // TODO @ngxson : improve this - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids); } @@ -372,6 +579,12 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); diff --git a/src/models/models.h b/src/models/models.h index 7e551eb965b..bc20a82bc0e 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -797,6 +797,19 @@ struct llama_model_gemma4 : public llama_model_base { }; +struct llama_model_gemma4_assistant : public llama_model_base { + llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_gemma_embedding : public llama_model_base { llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index dc3189e1705..ca973a1d6ea 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -9,6 +9,7 @@ #include "build-info.h" #include "common.h" #include "llama.h" +#include "../../src/llama-ext.h" // staging API: llama_set_mtp_source #include "log.h" #include "sampling.h" #include "speculative.h" @@ -805,6 +806,11 @@ struct server_context_impl { cparams.n_rs_seq = 0; ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + if (spec_mtp) { + // MTP draft must know its target before the first decode + llama_set_mtp_source(ctx_dft.get(), ctx_tgt); + } + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); params_base.speculative.draft.ctx_tgt = ctx_tgt; @@ -824,6 +830,10 @@ struct server_context_impl { return false; } + // wire the source before any decode (the seq-rm probe below + // triggers sched_reserve which needs src for Gemma4-style MTP) + llama_set_mtp_source(ctx_dft.get(), ctx_tgt); + ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get()); params_base.speculative.draft.ctx_tgt = ctx_tgt;