Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,11 @@ extern "C" {
// Check if the memory supports shifting
LLAMA_API bool llama_memory_can_shift(llama_memory_t mem);

LLAMA_API bool llama_requantize_memory(
struct llama_context * ctx,
ggml_type ctk,
ggml_type ctv);

//
// State / sessions
//
Expand Down
77 changes: 76 additions & 1 deletion src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
#include "llama-kv-cache.h"
#include "llama-memory-recurrent.h"
#include "llama-memory.h"
#include "llama-mmap.h"
#include "llama-model.h"
Expand Down Expand Up @@ -71,7 +73,7 @@ llama_context::llama_context(
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;

cparams.swa_full = params.swa_full;

cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
Expand Down Expand Up @@ -2005,6 +2007,66 @@ int llama_context::decode(const llama_batch & batch_inp) {
return 0;
}

bool llama_context::requantize_memory(ggml_type new_type_k, ggml_type new_type_v) {
if (!memory) {
return false;
}

if (!cparams.flash_attn && ggml_is_quantized(new_type_v)) {
LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
return false;
}

// Base-level recurrent cache does not support quantization; it's hardcoded
// to f32/f32. The other cache implementations can be quantized.
if (llm_arch_is_recurrent(model.arch)) {
LLAMA_LOG_ERROR("%s: requantize not supported for recurrent cache\n", __func__);
return false;
}

// Read existing kvcache to host buffer
const size_t state_size = state_get_size();
std::vector<uint8_t> state_store(state_size);

if (state_get_data(state_store.data(), state_size) != state_size) {
LLAMA_LOG_ERROR("%s: error reading existing memory\n", __func__);
return false;
}

// Tear down existing kvcache
gf_res_reserve.reset();
sched.reset();
memory.reset();

llama_memory_params params_mem = {
/*.type_k =*/ new_type_k,
/*.type_v =*/ new_type_v,
/*.swa_full =*/ cparams.swa_full,
/*.ctx_type= */ cparams.ctx_type,
};

// Create new kvcache
memory.reset(model.create_memory(params_mem, cparams));
if (!memory) {
// TODO: Yikes! Maybe more checks to ensure create_memory will succeed before we do this
// Alternatively, we could try to rebuild using the prior types?
LLAMA_LOG_ERROR("%s: error requantizing memory\n", __func__);
return false;
}

// Reserve a new backend scheduler
sched_need_reserve = true;
sched_reserve();

// Restore kvcache
if (!state_set_data(state_store.data(), state_size)) {
LLAMA_LOG_ERROR("%s: error restoring kvcache\n", __func__);
return false;
}

return true;
}

//
// output
//
Expand Down Expand Up @@ -3797,6 +3859,19 @@ bool llama_memory_can_shift(llama_memory_t mem) {
return mem->get_can_shift();
}

bool llama_requantize_memory(struct llama_context * ctx, ggml_type ctk, ggml_type ctv) {
if (!ctx) {
return false;
}

try {
return ctx->requantize_memory(ctk, ctv);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: error requantizing memory: %s\n", __func__, err.what());
return false;
}
}

// llama state API

// deprecated
Expand Down
2 changes: 2 additions & 0 deletions src/llama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ struct llama_context {
int encode(const llama_batch & batch_inp);
int decode(const llama_batch & batch_inp);

bool requantize_memory(ggml_type new_type_k, ggml_type new_type_v);

//
// state save/load
//
Expand Down
1 change: 1 addition & 0 deletions src/llama-cparams.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ struct llama_cparams {
bool warmup; // TODO: remove [TAG_LLAMA_GRAPH_NO_WARMUP]
bool op_offload;
bool kv_unified;
bool swa_full;
bool pipeline_parallel;

enum llama_context_type ctx_type;
Expand Down
119 changes: 105 additions & 14 deletions src/llama-kv-cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>

static bool ggml_is_power_of_2(int n) {
return (n & (n - 1)) == 0;
Expand Down Expand Up @@ -73,6 +74,38 @@ static ggml_tensor * ggml_mul_mat_aux(
return res;
}

// Convert n_rows x n_per_row elements from src_type to dst_type via an f32 staging buffer
static bool kv_convert_rows(
ggml_type src_type,
ggml_type dst_type,
const void * src_bytes,
void * dst_bytes,
int64_t n_per_row,
int64_t n_rows) {
const ggml_type_traits * src_traits = ggml_get_type_traits(src_type);

if (src_type != GGML_TYPE_F32 && src_traits->to_float == nullptr) {
LLAMA_LOG_ERROR("%s: cannot dequantize source type %s\n", __func__, ggml_type_name(src_type));
return false;
}
if (ggml_quantize_requires_imatrix(dst_type)) {
LLAMA_LOG_ERROR("%s: destination type %s requires an imatrix\n", __func__, ggml_type_name(dst_type));
return false;
}

const int64_t n = n_per_row * n_rows;
std::vector<float> buf(n);

if (src_type == GGML_TYPE_F32) {
std::memcpy(buf.data(), src_bytes, n * sizeof(float));
} else {
src_traits->to_float(src_bytes, buf.data(), n);
}

ggml_quantize_chunk(dst_type, buf.data(), dst_bytes, 0, n_rows, n_per_row, nullptr);
return true;
}

//
// llama_kv_cache
//
Expand Down Expand Up @@ -2248,22 +2281,27 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read type of key
int32_t k_type_i_ref;
io.read(&k_type_i_ref, sizeof(k_type_i_ref));
const int32_t k_type_i = (int32_t) k->type;
if (k_type_i != k_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
return false;
const ggml_type k_type_src = (ggml_type) k_type_i_ref;
const ggml_type k_type_dst = k->type;
if (k_type_src != k_type_dst) {
LLAMA_LOG_DEBUG("%s: mismatched key type (%s != %s, layer %d); attempting conversion\n", __func__, ggml_type_name(k_type_src), ggml_type_name(k_type_dst), il);
}

// Read row size of key
uint64_t k_size_row_ref;
io.read(&k_size_row_ref, sizeof(k_size_row_ref));
const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa);
if (k_size_row != k_size_row_ref) {
const size_t k_size_row = ggml_row_size(k_type_src, n_embd_k_gqa);
if (k_size_row_ref != k_size_row) {
// Note: compute against src type so this also validates the conversion path
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
return false;
}

if (cell_count) {
if (!cell_count) {
continue;
}

if (k_type_src == k_type_dst) {
if (sinfo.is_contiguous()) {
// Fast path: contiguous cells, single memcpy
io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row);
Expand All @@ -2274,6 +2312,30 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
io.read_tensor(k, dst_offset, k_size_row);
}
}
} else {
const size_t k_size_row_dst = ggml_row_size(k_type_dst, n_embd_k_gqa);

std::vector<uint8_t> src_buf(cell_count * k_size_row);
std::vector<uint8_t> dst_buf(cell_count * k_size_row_dst);

io.read(src_buf.data(), src_buf.size());

if (!kv_convert_rows(k_type_src, k_type_dst, src_buf.data(), dst_buf.data(), n_embd_k_gqa, cell_count)) {
LLAMA_LOG_ERROR("%s: unable to convert between key types (layer %d)\n", __func__, il);
return false;
}

if (sinfo.is_contiguous()) {
// Fast path: contiguous cells, single memcpy
ggml_backend_tensor_set(k, dst_buf.data(), sinfo.head() * k_size_row_dst, dst_buf.size());
} else {
// Slow path: scatter to non-contiguous positions
for (uint32_t i = 0; i < cell_count; ++i) {
const size_t dst_start = i * k_size_row_dst;
const size_t dst_offset = sinfo.idxs[0][i] * k_size_row_dst;
ggml_backend_tensor_set(k, dst_buf.data() + dst_start, dst_offset, k_size_row_dst);
}
}
}
}

Expand All @@ -2291,22 +2353,27 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
// Read type of value
int32_t v_type_i_ref;
io.read(&v_type_i_ref, sizeof(v_type_i_ref));
const int32_t v_type_i = (int32_t) v->type;
if (v_type_i != v_type_i_ref) {
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return false;
const ggml_type v_type_src = (ggml_type) v_type_i_ref;
const ggml_type v_type_dst = v->type;
if (v_type_src != v_type_dst) {
LLAMA_LOG_DEBUG("%s: mismatched value type (%s != %s, layer %d); attempting conversion\n", __func__, ggml_type_name(v_type_src), ggml_type_name(v_type_dst), il);
}

// Read row size of value
uint64_t v_size_row_ref;
io.read(&v_size_row_ref, sizeof(v_size_row_ref));
const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa);
if (v_size_row != v_size_row_ref) {
const size_t v_size_row = ggml_row_size(v_type_src, n_embd_v_gqa);
if (v_size_row_ref != v_size_row) {
// Note: compute against src type so this also validates the conversion path
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
return false;
}

if (!cell_count) {
continue;
}

if (cell_count) {
if (v_type_src == v_type_dst) {
if (sinfo.is_contiguous()) {
// Fast path: contiguous cells, single memcpy
io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row);
Expand All @@ -2317,6 +2384,30 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
io.read_tensor(v, dst_offset, v_size_row);
}
}
} else {
const size_t v_size_row_dst = ggml_row_size(v_type_dst, n_embd_v_gqa);

std::vector<uint8_t> src_buf(cell_count * v_size_row);
std::vector<uint8_t> dst_buf(cell_count * v_size_row_dst);

io.read(src_buf.data(), src_buf.size());

if (!kv_convert_rows(v_type_src, v_type_dst, src_buf.data(), dst_buf.data(), n_embd_v_gqa, cell_count)) {
LLAMA_LOG_ERROR("%s: unable to convert between value types (layer %d)\n", __func__, il);
return false;
}

if (sinfo.is_contiguous()) {
// Fast path: contiguous cells, single memcpy
ggml_backend_tensor_set(v, dst_buf.data(), sinfo.head() * v_size_row_dst, dst_buf.size());
} else {
// Slow path: scatter to non-contiguous positions
for (uint32_t i = 0; i < cell_count; ++i) {
const size_t dst_start = i * v_size_row_dst;
const size_t dst_offset = sinfo.idxs[0][i] * v_size_row_dst;
ggml_backend_tensor_set(v, dst_buf.data() + dst_start, dst_offset, v_size_row_dst);
}
}
}
}
} else {
Expand Down
Loading