Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 44 additions & 43 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ const std::map<std::string, common_speculative_type> common_speculative_type_fro
};

static std::string common_speculative_get_devices_str(const std::vector<ggml_backend_dev_t> & devices) {
if (devices.empty()) {
return "default";
}

std::string result;
for (size_t i = 0; i < devices.size(); i++) {
if (i > 0) result += ", ";
if (devices[i] == nullptr) {
continue;
}
if (!result.empty()) result += ", ";
result += ggml_backend_dev_name(devices[i]);
}
return result;
return result.empty() ? "default" : result;
}

struct common_speculative_config {
Expand Down Expand Up @@ -416,6 +415,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {

int32_t n_embd = 0;

bool kv_shared_with_target = false;

// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
// call to pair with, so it's stashed here until that next call fires.
Expand All @@ -442,7 +443,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");

n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_pre_norm width");

LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd);
Expand Down Expand Up @@ -471,6 +474,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {

llama_set_embeddings_pre_norm(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_pre_norm(ctx_dft, true, /*masked*/ true);
llama_set_mtp_source(ctx_dft, ctx_tgt);

kv_shared_with_target = llama_model_n_layer_kv(llama_get_model(ctx_dft)) == 0;

pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));

Expand All @@ -496,9 +502,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
if (N <= 0) {
return;
}

auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1) {
if (pos_max < N - 1 && !kv_shared_with_target) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
Expand Down Expand Up @@ -541,48 +548,42 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {

const size_t row_bytes = (size_t) n_embd * sizeof(float);

common_batch_clear(batch);
// if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode
if (!kv_shared_with_target) {
common_batch_clear(batch);

for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}

// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_pre_norm(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
}

//{
// // string with seq_ids in the batch
// std::stringstream ss;
// for (int i = 0; i < n_tokens; ++i) {
// ss << batch_in.seq_id[i][0] << ",";
// }
// LOG_WRN("%s: batch_in.seq_id = %s\n", __func__, ss.str().c_str());
//}
}
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};

// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}

for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}

set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}

const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
}
}

for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
Expand Down
1 change: 1 addition & 0 deletions conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
"Gemma3TextModel": "gemma",
"Gemma3nForCausalLM": "gemma",
"Gemma3nForConditionalGeneration": "gemma",
"Gemma4AssistantForCausalLM": "gemma",
"Gemma4ForConditionalGeneration": "gemma",
"GemmaForCausalLM": "gemma",
"Glm4ForCausalLM": "glm",
Expand Down
10 changes: 10 additions & 0 deletions conversion/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,16 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Gemma4AssistantForCausalLM")
class Gemma4AssistantModel(Gemma4Model):
model_arch = gguf.MODEL_ARCH.GEMMA4_ASSISTANT

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_embedding_length_out(self.hparams["backbone_hidden_size"])
self.gguf_writer.add_nextn_predict_layers(self.block_count)


@ModelBase.register("Gemma4ForConditionalGeneration")
class Gemma4VisionAudioModel(MmprojModel):
has_audio_encoder = True
Expand Down
24 changes: 24 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ class MODEL_ARCH(IntEnum):
GEMMA3 = auto()
GEMMA3N = auto()
GEMMA4 = auto()
GEMMA4_ASSISTANT = auto()
GEMMA_EMBEDDING = auto()
STARCODER2 = auto()
RWKV6 = auto()
Expand Down Expand Up @@ -856,6 +857,8 @@ class MODEL_TENSOR(IntEnum):
A_PER_DIM_K_SCALE = auto() # gemma4
A_PER_DIM_SCALE = auto() # gemma4
# nextn/mtp
NEXTN_PRE_PROJ = auto()
NEXTN_POST_PROJ = auto()
NEXTN_EH_PROJ = auto()
NEXTN_EMBED_TOKENS = auto()
NEXTN_ENORM = auto()
Expand Down Expand Up @@ -945,6 +948,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GEMMA3: "gemma3",
MODEL_ARCH.GEMMA3N: "gemma3n",
MODEL_ARCH.GEMMA4: "gemma4",
MODEL_ARCH.GEMMA4_ASSISTANT: "gemma4-assistant",
MODEL_ARCH.GEMMA_EMBEDDING: "gemma-embedding",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.RWKV6: "rwkv6",
Expand Down Expand Up @@ -1401,6 +1405,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.A_QF_FFN_DOWN: "a.proj_blk.{bid}.ffn_down",
MODEL_TENSOR.A_QF_FFN_NORM: "a.proj_blk.{bid}.ffn_norm",
# NextN/MTP
MODEL_TENSOR.NEXTN_PRE_PROJ: "nextn.pre_projection",
MODEL_TENSOR.NEXTN_POST_PROJ: "nextn.post_projection",
MODEL_TENSOR.NEXTN_EH_PROJ: "blk.{bid}.nextn.eh_proj",
MODEL_TENSOR.NEXTN_EMBED_TOKENS: "blk.{bid}.nextn.embed_tokens",
MODEL_TENSOR.NEXTN_ENORM: "blk.{bid}.nextn.enorm",
Expand Down Expand Up @@ -2481,6 +2487,24 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.PER_LAYER_PROJ_NORM,
MODEL_TENSOR.PER_LAYER_POST_NORM,
],
MODEL_ARCH.GEMMA4_ASSISTANT: [
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.NEXTN_PRE_PROJ,
MODEL_TENSOR.NEXTN_POST_PROJ,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.LAYER_OUT_SCALE,
],
MODEL_ARCH.GEMMA_EMBEDDING: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT,
Expand Down
8 changes: 8 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2245,6 +2245,14 @@ class TensorNameMap:
),

# NextN/MTP tensors
MODEL_TENSOR.NEXTN_PRE_PROJ: (
"pre_projection",
),

MODEL_TENSOR.NEXTN_POST_PROJ: (
"post_projection",
),

MODEL_TENSOR.NEXTN_EH_PROJ: (
"model.layers.{bid}.eh_proj",
),
Expand Down
5 changes: 5 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_GEMMA3, "gemma3" },
{ LLM_ARCH_GEMMA3N, "gemma3n" },
{ LLM_ARCH_GEMMA4, "gemma4" },
{ LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" },
{ LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" },
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
Expand Down Expand Up @@ -445,6 +446,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
{ LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" },
{ LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" },
{ LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" },
{ LLM_TENSOR_NEXTN_PRE_PROJ, "nextn.pre_projection" },
{ LLM_TENSOR_NEXTN_POST_PROJ, "nextn.post_projection" },
{ LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" },
{ LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" },
{ LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" },
Expand Down Expand Up @@ -757,6 +760,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_PRE_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_NEXTN_POST_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
// the model loader doesn't fault on the block index.
Expand Down
3 changes: 3 additions & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ enum llm_arch {
LLM_ARCH_GEMMA3,
LLM_ARCH_GEMMA3N,
LLM_ARCH_GEMMA4,
LLM_ARCH_GEMMA4_ASSISTANT,
LLM_ARCH_GEMMA_EMBEDDING,
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
Expand Down Expand Up @@ -549,6 +550,8 @@ enum llm_tensor {
LLM_TENSOR_INDEXER_PROJ,
LLM_TENSOR_INDEXER_ATTN_K,
LLM_TENSOR_INDEXER_ATTN_Q_B,
LLM_TENSOR_NEXTN_PRE_PROJ,
LLM_TENSOR_NEXTN_POST_PROJ,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,
Expand Down
Loading
Loading