diff --git a/include/llama.h b/include/llama.h
index 9f78aa9a056..e9dc1635093 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -769,6 +769,14 @@ extern "C" {
     // Check if the memory supports shifting
     LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);
 
+    // Re-create the context memory (KV cache) with new K/V data types, converting the current contents
+    // Returns false if ctx is NULL, if the requested types are not supported, or if the conversion fails
+    // NOTE: if creating the new memory fails after the old one was released, the context is left without memory
+    LLAMA_API bool llama_requantize_memory(
+            struct llama_context * ctx,
+                  enum ggml_type   ctk,
+                  enum ggml_type   ctv);
+
     //
     // State / sessions
     //
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index f59381a4d75..9b8f4b0a3b8 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -6,6 +6,8 @@
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-io.h"
+#include "llama-kv-cache.h"
+#include "llama-memory-recurrent.h"
 #include "llama-memory.h"
 #include "llama-mmap.h"
 #include "llama-model.h"
@@ -71,7 +73,7 @@ llama_context::llama_context(
     cparams.no_perf          = params.no_perf;
     cparams.pooling_type     = params.pooling_type;
     cparams.warmup           = false;
-
+    cparams.swa_full         = params.swa_full;
 
     cparams.n_ctx            = params.n_ctx           == 0    ? hparams.n_ctx_train           : params.n_ctx;
     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
@@ -2005,6 +2007,70 @@ int llama_context::decode(const llama_batch & batch_inp) {
     return 0;
 }
 
+// Replace the context memory with a new instance that uses the requested K/V types.
+// The current state is serialized to a host buffer, the old memory and scheduler are released,
+// a new memory is created and the saved state is loaded back into it.
+// Returns false on failure; if creating the new memory fails, the context is left without memory.
+bool llama_context::requantize_memory(ggml_type new_type_k, ggml_type new_type_v) {
+    if (!memory) {
+        return false;
+    }
+
+    if (!cparams.flash_attn && ggml_is_quantized(new_type_v)) {
+        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
+        return false;
+    }
+
+    // Base-level recurrent cache does not support quantization; it's hardcoded
+    // to f32/f32. The other cache implementations can be quantized.
+    if (llm_arch_is_recurrent(model.arch)) {
+        LLAMA_LOG_ERROR("%s: requantize not supported for recurrent cache\n", __func__);
+        return false;
+    }
+
+    // Read existing kvcache to host buffer
+    const size_t state_size = state_get_size();
+    std::vector<uint8_t> state_store(state_size);
+
+    if (state_get_data(state_store.data(), state_size) != state_size) {
+        LLAMA_LOG_ERROR("%s: error reading existing memory\n", __func__);
+        return false;
+    }
+
+    // Tear down existing kvcache
+    gf_res_reserve.reset();
+    sched.reset();
+    memory.reset();
+
+    llama_memory_params params_mem = {
+        /*.type_k   =*/ new_type_k,
+        /*.type_v   =*/ new_type_v,
+        /*.swa_full =*/ cparams.swa_full,
+        /*.ctx_type =*/ cparams.ctx_type,
+    };
+
+    // Create new kvcache
+    memory.reset(model.create_memory(params_mem, cparams));
+    if (!memory) {
+        // TODO: Yikes! Maybe more checks to ensure create_memory will succeed before we do this
+        //       Alternatively, we could try to rebuild using the prior types?
+        LLAMA_LOG_ERROR("%s: error requantizing memory\n", __func__);
+        return false;
+    }
+
+    // Reserve a new backend scheduler
+    sched_need_reserve = true;
+    sched_reserve();
+
+    // Restore kvcache
+    if (!state_set_data(state_store.data(), state_size)) {
+        LLAMA_LOG_ERROR("%s: error restoring kvcache\n", __func__);
+        return false;
+    }
+
+    return true;
+}
+
 //
 // output
 //
@@ -3797,6 +3863,19 @@ bool llama_memory_can_shift(llama_memory_t mem) {
     return mem->get_can_shift();
 }
 
+bool llama_requantize_memory(struct llama_context * ctx, ggml_type ctk, ggml_type ctv) {
+    if (!ctx) {
+        return false;
+    }
+
+    try {
+        return ctx->requantize_memory(ctk, ctv);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error requantizing memory: %s\n", __func__, err.what());
+        return false;
+    }
+}
+
 // llama state API
 
 // deprecated
diff --git a/src/llama-context.h b/src/llama-context.h
index 2af92b0f096..045ca978ec8 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -138,6 +138,8 @@ struct llama_context {
     int encode(const llama_batch & batch_inp);
     int decode(const llama_batch & batch_inp);
 
+    bool requantize_memory(ggml_type new_type_k, ggml_type new_type_v);
+
     //
     // state save/load
     //
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index fd227ee5a23..d8a81dbd091 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -42,6 +42,7 @@ struct llama_cparams {
     bool warmup; // TODO: remove [TAG_LLAMA_GRAPH_NO_WARMUP]
     bool op_offload;
     bool kv_unified;
+    bool swa_full;
     bool pipeline_parallel;
 
     enum llama_context_type ctx_type;
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 82da38e0b61..b6117cbfca6 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -12,6 +12,7 @@
 #include
 #include
 #include
+#include <vector>
 
 static bool ggml_is_power_of_2(int n) {
     return (n & (n - 1)) == 0;
@@ -73,6 +74,53 @@ static ggml_tensor * ggml_mul_mat_aux(
     return res;
 }
 
+// Check that a type id read from a serialized state can be used to size rows of n_per_row elements
+// The id comes from untrusted data, so it is validated before it is used with the ggml type helpers:
+// ids outside [0, GGML_TYPE_COUNT) and types whose block size is 0 or does not divide n_per_row are rejected
+static bool kv_state_type_is_valid(int32_t type_id, int64_t n_per_row) {
+    if (type_id < 0 || type_id >= (int32_t) GGML_TYPE_COUNT) {
+        return false;
+    }
+
+    const int64_t blck_size = ggml_blck_size((ggml_type) type_id);
+
+    return blck_size > 0 && n_per_row % blck_size == 0;
+}
+
+// Convert n_rows x n_per_row elements from src_type to dst_type via an f32 staging buffer
+// src_bytes holds the rows packed in src_type, dst_bytes receives them packed in dst_type
+// Returns false if the source cannot be dequantized or the destination requires an imatrix
+static bool kv_convert_rows(
+        ggml_type    src_type,
+        ggml_type    dst_type,
+        const void * src_bytes,
+        void       * dst_bytes,
+        int64_t      n_per_row,
+        int64_t      n_rows) {
+    const ggml_type_traits * src_traits = ggml_get_type_traits(src_type);
+
+    if (src_type != GGML_TYPE_F32 && src_traits->to_float == nullptr) {
+        LLAMA_LOG_ERROR("%s: cannot dequantize source type %s\n", __func__, ggml_type_name(src_type));
+        return false;
+    }
+    if (ggml_quantize_requires_imatrix(dst_type)) {
+        LLAMA_LOG_ERROR("%s: destination type %s requires an imatrix\n", __func__, ggml_type_name(dst_type));
+        return false;
+    }
+
+    const int64_t n = n_per_row * n_rows;
+    std::vector<float> buf(n);
+
+    if (src_type == GGML_TYPE_F32) {
+        std::memcpy(buf.data(), src_bytes, n * sizeof(float));
+    } else {
+        src_traits->to_float(src_bytes, buf.data(), n);
+    }
+
+    ggml_quantize_chunk(dst_type, buf.data(), dst_bytes, 0, n_rows, n_per_row, nullptr);
+    return true;
+}
+
 //
 // llama_kv_cache
 //
@@ -2248,22 +2296,31 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
         // Read type of key
         int32_t k_type_i_ref;
         io.read(&k_type_i_ref, sizeof(k_type_i_ref));
-        const int32_t k_type_i = (int32_t) k->type;
-        if (k_type_i != k_type_i_ref) {
-            LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
-            return false;
+        if (!kv_state_type_is_valid(k_type_i_ref, n_embd_k_gqa)) {
+            LLAMA_LOG_ERROR("%s: invalid key type (%d, layer %d)\n", __func__, k_type_i_ref, il);
+            return false;
+        }
+        const ggml_type k_type_src = (ggml_type) k_type_i_ref;
+        const ggml_type k_type_dst = k->type;
+        if (k_type_src != k_type_dst) {
+            LLAMA_LOG_DEBUG("%s: mismatched key type (%s != %s, layer %d); attempting conversion\n", __func__, ggml_type_name(k_type_src), ggml_type_name(k_type_dst), il);
         }
 
         // Read row size of key
         uint64_t k_size_row_ref;
         io.read(&k_size_row_ref, sizeof(k_size_row_ref));
-        const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
-        if (k_size_row != k_size_row_ref) {
+        const size_t k_size_row = ggml_row_size(k_type_src, n_embd_k_gqa);
+        if (k_size_row_ref != k_size_row) {
+            // Note: compute against src type so this also validates the conversion path
             LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
             return false;
         }
 
-        if (cell_count) {
+        if (!cell_count) {
+            continue;
+        }
+
+        if (k_type_src == k_type_dst) {
             if (sinfo.is_contiguous()) {
                 // Fast path: contiguous cells, single memcpy
                 io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row);
@@ -2274,6 +2331,30 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
                     io.read_tensor(k, dst_offset, k_size_row);
                 }
             }
+        } else {
+            const size_t k_size_row_dst = ggml_row_size(k_type_dst, n_embd_k_gqa);
+
+            std::vector<uint8_t> src_buf(cell_count * k_size_row);
+            std::vector<uint8_t> dst_buf(cell_count * k_size_row_dst);
+
+            io.read(src_buf.data(), src_buf.size());
+
+            if (!kv_convert_rows(k_type_src, k_type_dst, src_buf.data(), dst_buf.data(), n_embd_k_gqa, cell_count)) {
+                LLAMA_LOG_ERROR("%s: unable to convert between key types (layer %d)\n", __func__, il);
+                return false;
+            }
+
+            if (sinfo.is_contiguous()) {
+                // Fast path: contiguous cells, single memcpy
+                ggml_backend_tensor_set(k, dst_buf.data(), sinfo.head() * k_size_row_dst, dst_buf.size());
+            } else {
+                // Slow path: scatter to non-contiguous positions
+                for (uint32_t i = 0; i < cell_count; ++i) {
+                    const size_t dst_start  = i * k_size_row_dst;
+                    const size_t dst_offset = sinfo.idxs[0][i] * k_size_row_dst;
+                    ggml_backend_tensor_set(k, dst_buf.data() + dst_start, dst_offset, k_size_row_dst);
+                }
+            }
         }
     }
 
@@ -2291,22 +2372,31 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
             // Read type of value
             int32_t v_type_i_ref;
             io.read(&v_type_i_ref, sizeof(v_type_i_ref));
-            const int32_t v_type_i = (int32_t) v->type;
-            if (v_type_i != v_type_i_ref) {
-                LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
-                return false;
+            if (!kv_state_type_is_valid(v_type_i_ref, n_embd_v_gqa)) {
+                LLAMA_LOG_ERROR("%s: invalid value type (%d, layer %d)\n", __func__, v_type_i_ref, il);
+                return false;
+            }
+            const ggml_type v_type_src = (ggml_type) v_type_i_ref;
+            const ggml_type v_type_dst = v->type;
+            if (v_type_src != v_type_dst) {
+                LLAMA_LOG_DEBUG("%s: mismatched value type (%s != %s, layer %d); attempting conversion\n", __func__, ggml_type_name(v_type_src), ggml_type_name(v_type_dst), il);
             }
 
             // Read row size of value
             uint64_t v_size_row_ref;
             io.read(&v_size_row_ref, sizeof(v_size_row_ref));
-            const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
-            if (v_size_row != v_size_row_ref) {
+            const size_t v_size_row = ggml_row_size(v_type_src, n_embd_v_gqa);
+            if (v_size_row_ref != v_size_row) {
+                // Note: compute against src type so this also validates the conversion path
                 LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
                 return false;
             }
+
+            if (!cell_count) {
+                continue;
+            }
 
-            if (cell_count) {
+            if (v_type_src == v_type_dst) {
                 if (sinfo.is_contiguous()) {
                     // Fast path: contiguous cells, single memcpy
                     io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row);
@@ -2317,6 +2407,30 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
                         io.read_tensor(v, dst_offset, v_size_row);
                     }
                 }
+            } else {
+                const size_t v_size_row_dst = ggml_row_size(v_type_dst, n_embd_v_gqa);
+
+                std::vector<uint8_t> src_buf(cell_count * v_size_row);
+                std::vector<uint8_t> dst_buf(cell_count * v_size_row_dst);
+
+                io.read(src_buf.data(), src_buf.size());
+
+                if (!kv_convert_rows(v_type_src, v_type_dst, src_buf.data(), dst_buf.data(), n_embd_v_gqa, cell_count)) {
+                    LLAMA_LOG_ERROR("%s: unable to convert between value types (layer %d)\n", __func__, il);
+                    return false;
+                }
+
+                if (sinfo.is_contiguous()) {
+                    // Fast path: contiguous cells, single memcpy
+                    ggml_backend_tensor_set(v, dst_buf.data(), sinfo.head() * v_size_row_dst, dst_buf.size());
+                } else {
+                    // Slow path: scatter to non-contiguous positions
+                    for (uint32_t i = 0; i < cell_count; ++i) {
+                        const size_t dst_start  = i * v_size_row_dst;
+                        const size_t dst_offset = sinfo.idxs[0][i] * v_size_row_dst;
+                        ggml_backend_tensor_set(v, dst_buf.data() + dst_start, dst_offset, v_size_row_dst);
+                    }
+                }
             }
         }
     } else {
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 28f738c3feb..7b69a73dc79 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -1,5 +1,6 @@
 #include "server-context.h"
 
+#include "ggml.h"
 #include "server-chat.h"
 #include "server-common.h"
 #include "server-http.h"
@@ -2328,6 +2329,39 @@ struct server_context_impl {
                     res->n_erased = n_erased;
                     queue_results.send(std::move(res));
                 } break;
+            case SERVER_TASK_TYPE_REQUANTIZE_KVCACHE:
+                {
+                    if (!check_no_mtmd(task.id)) break;
+
+                    // If any slot is busy, defer this task for later
+                    bool deferred = false;
+                    for (auto & slot : slots) {
+                        if (slot.is_processing()) {
+                            SRV_DBG("slot %d is busy, defer task, id_task = %d\n", slot.id, task.id);
+                            queue_tasks.defer(std::move(task));
+                            deferred = true;
+                            break;
+                        }
+                    }
+                    if (deferred) break;
+
+                    ggml_type ctk = task.kvcache_action.ctk;
+                    ggml_type ctv = task.kvcache_action.ctv;
+
+                    // TODO - handle draft model
+                    if (!llama_requantize_memory(ctx_tgt, ctk, ctv)) {
+                        send_error(task, "Unable to quantize memory", ERROR_TYPE_INVALID_REQUEST);
+                        break;
+                    }
+
+                    // keep the cache types that get_meta() reports in sync with the new memory
+                    params_base.cache_type_k = ctk;
+                    params_base.cache_type_v = ctv;
+
+                    auto res = std::make_unique<server_task_result_requantize>();
+                    res->id = task.id;
+                    queue_results.send(std::move(res));
+                } break;
             case SERVER_TASK_TYPE_GET_LORA:
                 {
                     // TODO @ngxson : make lora_adapters a dedicated member of server_context
@@ -3581,6 +3615,8 @@ server_context_meta server_context::get_meta() const {
         /* json_webui_settings    */ impl->json_webui_settings, // Deprecated
         /* slot_n_ctx             */ impl->get_slot_n_ctx(),
         /* pooling_type           */ llama_pooling_type(impl->ctx_tgt),
+        /* cache_type_k           */ impl->params_base.cache_type_k,
+        /* cache_type_v           */ impl->params_base.cache_type_v,
 
         /* chat_params            */ impl->chat_params,
         /* chat_template_caps     */ common_chat_templates_get_caps(impl->chat_params.tmpls.get()),
@@ -4149,6 +4185,81 @@ void server_routes::init_routes() {
         return res;
     };
 
+    // POST /cache/requantize: convert the KV cache to the types named by the "ctk" / "ctv" parameters
+    this->post_cache_requantize = [this](const server_http_req & req) {
+        auto res = create_response();
+
+        std::string ctk = req.get_param("ctk");
+        std::string ctv = req.get_param("ctv");
+
+        // Supported KV cache types (from common/arg.cpp)
+        // TODO - might be better to just make the arg.cpp method public
+        const std::vector<ggml_type> kv_cache_types = {
+            GGML_TYPE_F32,
+            GGML_TYPE_F16,
+            GGML_TYPE_BF16,
+            GGML_TYPE_Q8_0,
+            GGML_TYPE_Q4_0,
+            GGML_TYPE_Q4_1,
+            GGML_TYPE_IQ4_NL,
+            GGML_TYPE_Q5_0,
+            GGML_TYPE_Q5_1,
+        };
+
+        ggml_type k = GGML_TYPE_F16;
+        ggml_type v = GGML_TYPE_F16;
+
+        // Convert string parameters to ggml_type
+        // NOTE: an empty parameter is treated as found and leaves the f16 default selected
+        bool found_k = ctk.empty();
+        bool found_v = ctv.empty();
+
+        for (const auto & type : kv_cache_types) {
+            const std::string type_name = ggml_type_name(type);
+            if (!found_k && type_name == ctk) {
+                k = type;
+                found_k = true;
+            }
+            if (!found_v && type_name == ctv) {
+                v = type;
+                found_v = true;
+            }
+            if (found_k && found_v) break;
+        }
+
+        if (!found_k) {
+            res->error(format_error_response("Unsupported cache type: " + ctk, ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        if (!found_v) {
+            res->error(format_error_response("Unsupported cache type: " + ctv, ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+
+        {
+            server_task task(SERVER_TASK_TYPE_REQUANTIZE_KVCACHE);
+            task.id = res->rd.get_new_id();
+            task.kvcache_action.ctk = k;
+            task.kvcache_action.ctv = v;
+            res->rd.post_task(std::move(task), true); // high-priority task
+        }
+
+        auto result = res->rd.next(req.should_stop);
+        if (!result) {
+            // connection was closed
+            GGML_ASSERT(req.should_stop());
+            return res;
+        }
+
+        if (result->is_error()) {
+            res->error(result->to_json());
+            return res;
+        }
+
+        res->ok({{"status", "ok"}});
+        return res;
+    };
+
     this->get_props = [this](const server_http_req &) {
         auto res = create_response(true);
 
@@ -4176,6 +4287,8 @@ void server_routes::init_routes() {
                 {"vision", meta->has_inp_image},
                 {"audio",  meta->has_inp_audio},
             } },
+            { "cache_type_k", std::string(ggml_type_name(meta->cache_type_k)) },
+            { "cache_type_v", std::string(ggml_type_name(meta->cache_type_v)) },
             { "media_marker", get_media_marker() },
             { "endpoint_slots", params.endpoint_slots },
             { "endpoint_props", params.endpoint_props },
diff --git a/tools/server/server-context.h b/tools/server/server-context.h
index 73caff54a46..83aa33a7b43 100644
--- a/tools/server/server-context.h
+++ b/tools/server/server-context.h
@@ -25,6 +25,8 @@ struct server_context_meta {
     json json_webui_settings; // Deprecated: use json_ui_settings instead (kept for backward compat)
     int slot_n_ctx;
     enum llama_pooling_type pooling_type;
+    ggml_type cache_type_k;
+    ggml_type cache_type_v;
 
     // chat params
     server_chat_params & chat_params;
@@ -119,6 +121,7 @@ struct server_routes {
     server_http_context::handler_t get_models;
     server_http_context::handler_t post_tokenize;
     server_http_context::handler_t post_detokenize;
+    server_http_context::handler_t post_cache_requantize;
     server_http_context::handler_t post_embeddings;
     server_http_context::handler_t post_embeddings_oai;
     server_http_context::handler_t post_rerank;
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 33de2e4d9ca..ec0819fe866 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -1979,6 +1979,10 @@ json server_task_result_apply_lora::to_json() {
     return json {{ "success", true }};
 }
 
+json server_task_result_requantize::to_json() {
+    return json {{ "success", true }};
+}
+
 //
 // server_prompt_cache
 //
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index bdadcff7652..d5ef9396f71 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -27,6 +27,7 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_ERASE,
     SERVER_TASK_TYPE_GET_LORA,
     SERVER_TASK_TYPE_SET_LORA,
+    SERVER_TASK_TYPE_REQUANTIZE_KVCACHE,
 };
 
 // TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
@@ -166,6 +167,12 @@ struct server_task {
     };
     slot_action slot_action;
 
+    struct kvcache_action {
+        ggml_type ctk = GGML_TYPE_F16;
+        ggml_type ctv = GGML_TYPE_F16;
+    };
+    kvcache_action kvcache_action;
+
     // used by SERVER_TASK_TYPE_METRICS
     bool metrics_reset_bucket = false;
 
@@ -584,6 +591,10 @@ struct server_task_result_apply_lora : server_task_result {
     virtual json to_json() override;
 };
 
+struct server_task_result_requantize : server_task_result {
+    virtual json to_json() override;
+};
+
 struct server_prompt_data {
     std::vector main;
     std::vector drft;
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 769e80a802f..d9c6121244b 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -210,6 +210,7 @@ int llama_server(int argc, char ** argv) {
     // Save & load slots
     ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
     ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));
+    ctx_http.post("/cache/requantize", ex_wrapper(routes.post_cache_requantize));
 
     // Google Cloud Platform (Vertex AI) compat
     ctx_http.register_gcp_compat();