diff --git a/common/arg.cpp b/common/arg.cpp index 1ffaf704858..4d2d1091a57 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2806,9 +2806,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_BENCH})); add_opt(common_arg( {"--embd-normalize"}, "N", - string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize), + string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", (int) params.embd_normalize), [](common_params & params, int value) { - params.embd_normalize = value; + params.embd_normalize = (common_embd_norm) value; } ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( diff --git a/common/common.cpp b/common/common.cpp index b6a7626f2a1..a2457fc505a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1719,14 +1719,14 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto // Embedding utils // -void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) { +void common_embd_normalize(const float * inp, float * out, int n, common_embd_norm embd_norm) { double sum = 0.0; switch (embd_norm) { - case -1: // no normalisation + case COMMON_EMBD_NORM_NONE: // no normalisation sum = 1.0; break; - case 0: // max absolute + case COMMON_EMBD_NORM_MAX_ABS: // max absolute for (int i = 0; i < n; i++) { if (sum < std::abs(inp[i])) { sum = std::abs(inp[i]); @@ -1734,7 +1734,12 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) } sum /= 32760.0; // make an int16 range break; - case 2: // euclidean + case COMMON_EMBD_NORM_TAXICAB: // taxicab + for (int i = 0; i < n; i++) { + sum += std::abs(inp[i]); + } + break; + case COMMON_EMBD_NORM_EUCLIDEAN: // euclidean for (int i = 0; i < n; i++) { sum += inp[i] * inp[i]; } @@ -1742,9 +1747,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) break; default: // p-norm (euclidean is p-norm p=2) for (int i = 0; i < n; i++) { - sum += std::pow(std::abs(inp[i]), embd_norm); + sum += std::pow(std::abs(inp[i]), (int) embd_norm); } - sum = std::pow(sum, 1.0 / embd_norm); + sum = std::pow(sum, 1.0 / (int) embd_norm); break; } diff --git a/common/common.h b/common/common.h index 13f387271d8..e2281836597 100644 --- a/common/common.h +++ b/common/common.h @@ -404,6 +404,13 @@ enum common_reasoning_format { // see: https://github.com/ggml-org/llama.cpp/pull/15408 }; +enum common_embd_norm { + COMMON_EMBD_NORM_NONE = -1, + COMMON_EMBD_NORM_MAX_ABS = 0, + COMMON_EMBD_NORM_TAXICAB = 1, + COMMON_EMBD_NORM_EUCLIDEAN = 2, +}; + struct lr_opt { float lr0 = 1e-5; // learning rate at first epoch @@ -581,9 +588,9 @@ struct common_params { float val_split = 0.05f; // fraction of the data used for the validation set // embedding - bool embedding = false; // get only sentence embedding - int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) - std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix + bool embedding = false; // get only sentence embedding + common_embd_norm embd_normalize = COMMON_EMBD_NORM_EUCLIDEAN; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) + std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix std::string embd_sep = "\n"; // separator of embeddings std::string cls_sep = "\t"; // separator of classification sequences @@ -988,8 +995,7 @@ std::string common_detokenize( // Embedding utils // -// TODO: replace embd_norm with an enum -void common_embd_normalize(const float * inp, float * out, int n, int embd_norm); +void common_embd_normalize(const float * inp, float * out, int n, common_embd_norm embd_norm); float common_embd_similarity_cos(const float * embd1, const float * embd2, int n); diff --git a/examples/debug/debug.cpp b/examples/debug/debug.cpp index 761e7a2db54..27d31919cc4 100644 --- a/examples/debug/debug.cpp +++ b/examples/debug/debug.cpp @@ -79,7 +79,7 @@ struct output_data { data_size = n_floats; type_suffix = "-embeddings"; - if (params.embd_normalize >= 0) { + if (params.embd_normalize != COMMON_EMBD_NORM_NONE) { embd_norm.resize(n_floats); for (int i = 0; i < n_embd_count; i++) { common_embd_normalize(embd_raw+i*n_embd, embd_norm.data()+i*n_embd, n_embd, params.embd_normalize); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index f6a20ef9d07..b348920e77e 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -34,7 +34,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke } } -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) { +static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, common_embd_norm embd_norm) { const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); // clear previous kv_cache values (irrelevant for embeddings) @@ -77,14 +77,14 @@ static void print_raw_embeddings(const float * emb, int n_embd, const llama_model * model, enum llama_pooling_type pooling_type, - int embd_normalize) { + common_embd_norm embd_normalize) { const uint32_t n_cls_out = llama_model_n_cls_out(model); const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK); const int cols = is_rank ? std::min(n_embd, (int) n_cls_out) : n_embd; for (int j = 0; j < n_embd_count; ++j) { for (int i = 0; i < cols; ++i) { - if (embd_normalize == 0) { + if (embd_normalize == COMMON_EMBD_NORM_MAX_ABS) { LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); } else { LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); @@ -293,7 +293,7 @@ int main(int argc, char ** argv) { for (int j = 0; j < n_embd_count; j++) { LOG("embedding %d: ", j); for (int i = 0; i < std::min(3, n_embd_out); i++) { - if (params.embd_normalize == 0) { + if (params.embd_normalize == COMMON_EMBD_NORM_MAX_ABS) { LOG("%6.0f ", emb[j * n_embd_out + i]); } else { LOG("%9.6f ", emb[j * n_embd_out + i]); @@ -301,7 +301,7 @@ int main(int argc, char ** argv) { } LOG(" ... "); for (int i = n_embd_out - 3; i < n_embd_out; i++) { - if (params.embd_normalize == 0) { + if (params.embd_normalize == COMMON_EMBD_NORM_MAX_ABS) { LOG("%6.0f ", emb[j * n_embd_out + i]); } else { LOG("%9.6f ", emb[j * n_embd_out + i]); @@ -334,7 +334,7 @@ int main(int argc, char ** argv) { for (int j = 0; j < n_prompts; j++) { LOG("embedding %d: ", j); for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) { - if (params.embd_normalize == 0) { + if (params.embd_normalize == COMMON_EMBD_NORM_MAX_ABS) { LOG("%6.0f ", emb[j * n_embd_out + i]); } else { LOG("%9.6f ", emb[j * n_embd_out + i]); @@ -371,7 +371,7 @@ int main(int argc, char ** argv) { if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j); LOG("["); for (int i = 0;;) { // at least one iteration (n_embd > 0) - LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]); + LOG(params.embd_normalize == COMMON_EMBD_NORM_MAX_ABS ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]); i++; if (i < n_embd_out) LOG(","); else break; } diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 7d93ab1172c..c0096c253b1 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -108,7 +108,7 @@ static void batch_process(llama_context * ctx, llama_batch & batch, float * outp } float * out = output + batch.seq_id[i][0] * n_embd; - common_embd_normalize(embd, out, n_embd, 2); + common_embd_normalize(embd, out, n_embd, COMMON_EMBD_NORM_EUCLIDEAN); } } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index ab0d5944763..588c12ad36b 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -4877,9 +4877,9 @@ std::unique_ptr server_routes::handle_embeddings_impl(cons } } - int embd_normalize = params.embd_normalize; + common_embd_norm embd_normalize = params.embd_normalize; if (body.count("embd_normalize") != 0) { - embd_normalize = body.at("embd_normalize"); + embd_normalize = (common_embd_norm) body.at("embd_normalize").get(); if (meta->pooling_type == LLAMA_POOLING_TYPE_NONE) { SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", meta->pooling_type); } diff --git a/tools/server/server-task.h b/tools/server/server-task.h index bdadcff7652..3c69d71d596 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -93,7 +93,7 @@ struct task_params { common_chat_parser_params chat_parser_params; // Embeddings - int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) + common_embd_norm embd_normalize = COMMON_EMBD_NORM_EUCLIDEAN; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) json format_logit_bias(const std::vector & logit_bias) const; json to_json(bool only_metrics = false) const;