Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2806,9 +2806,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_BENCH}));
add_opt(common_arg(
{"--embd-normalize"}, "N",
string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", params.embd_normalize),
string_format("normalisation for embeddings (default: %d) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)", (int) params.embd_normalize),
[](common_params & params, int value) {
params.embd_normalize = value;
params.embd_normalize = (common_embd_norm) value;
}
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}));
add_opt(common_arg(
Expand Down
17 changes: 11 additions & 6 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1719,32 +1719,37 @@ std::string common_detokenize(const struct llama_vocab * vocab, const std::vecto
// Embedding utils
//

void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) {
void common_embd_normalize(const float * inp, float * out, int n, common_embd_norm embd_norm) {
double sum = 0.0;

switch (embd_norm) {
case -1: // no normalisation
case COMMON_EMBD_NORM_NONE: // no normalisation
sum = 1.0;
break;
case 0: // max absolute
case COMMON_EMBD_NORM_MAX_ABS: // max absolute
for (int i = 0; i < n; i++) {
if (sum < std::abs(inp[i])) {
sum = std::abs(inp[i]);
}
}
sum /= 32760.0; // make an int16 range
break;
case 2: // euclidean
case COMMON_EMBD_NORM_TAXICAB: // taxicab
for (int i = 0; i < n; i++) {
sum += std::abs(inp[i]);
}
break;
case COMMON_EMBD_NORM_EUCLIDEAN: // euclidean
for (int i = 0; i < n; i++) {
sum += inp[i] * inp[i];
}
sum = std::sqrt(sum);
break;
default: // p-norm (euclidean is p-norm p=2)
for (int i = 0; i < n; i++) {
sum += std::pow(std::abs(inp[i]), embd_norm);
sum += std::pow(std::abs(inp[i]), (int) embd_norm);
}
sum = std::pow(sum, 1.0 / embd_norm);
sum = std::pow(sum, 1.0 / (int) embd_norm);
break;
}

Expand Down
16 changes: 11 additions & 5 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,13 @@ enum common_reasoning_format {
// see: https://github.com/ggml-org/llama.cpp/pull/15408
};

enum common_embd_norm {
COMMON_EMBD_NORM_NONE = -1,
COMMON_EMBD_NORM_MAX_ABS = 0,
COMMON_EMBD_NORM_TAXICAB = 1,
COMMON_EMBD_NORM_EUCLIDEAN = 2,
};


struct lr_opt {
float lr0 = 1e-5; // learning rate at first epoch
Expand Down Expand Up @@ -581,9 +588,9 @@ struct common_params {
float val_split = 0.05f; // fraction of the data used for the validation set

// embedding
bool embedding = false; // get only sentence embedding
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
bool embedding = false; // get only sentence embedding
common_embd_norm embd_normalize = COMMON_EMBD_NORM_EUCLIDEAN; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
std::string embd_out = ""; // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
std::string embd_sep = "\n"; // separator of embeddings
std::string cls_sep = "\t"; // separator of classification sequences

Expand Down Expand Up @@ -988,8 +995,7 @@ std::string common_detokenize(
// Embedding utils
//

// TODO: replace embd_norm with an enum
void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
void common_embd_normalize(const float * inp, float * out, int n, common_embd_norm embd_norm);

float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);

Expand Down
2 changes: 1 addition & 1 deletion examples/debug/debug.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ struct output_data {
data_size = n_floats;
type_suffix = "-embeddings";

if (params.embd_normalize >= 0) {
if (params.embd_normalize != COMMON_EMBD_NORM_NONE) {
embd_norm.resize(n_floats);
for (int i = 0; i < n_embd_count; i++) {
common_embd_normalize(embd_raw+i*n_embd, embd_norm.data()+i*n_embd, n_embd, params.embd_normalize);
Expand Down
14 changes: 7 additions & 7 deletions examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
}
}

static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) {
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, common_embd_norm embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

// clear previous kv_cache values (irrelevant for embeddings)
Expand Down Expand Up @@ -77,14 +77,14 @@ static void print_raw_embeddings(const float * emb,
int n_embd,
const llama_model * model,
enum llama_pooling_type pooling_type,
int embd_normalize) {
common_embd_norm embd_normalize) {
const uint32_t n_cls_out = llama_model_n_cls_out(model);
const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;

for (int j = 0; j < n_embd_count; ++j) {
for (int i = 0; i < cols; ++i) {
if (embd_normalize == 0) {
if (embd_normalize == COMMON_EMBD_NORM_MAX_ABS) {
LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
} else {
LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
Expand Down Expand Up @@ -293,15 +293,15 @@ int main(int argc, char ** argv) {
for (int j = 0; j < n_embd_count; j++) {
LOG("embedding %d: ", j);
for (int i = 0; i < std::min(3, n_embd_out); i++) {
if (params.embd_normalize == 0) {
if (params.embd_normalize == COMMON_EMBD_NORM_MAX_ABS) {
LOG("%6.0f ", emb[j * n_embd_out + i]);
} else {
LOG("%9.6f ", emb[j * n_embd_out + i]);
}
}
LOG(" ... ");
for (int i = n_embd_out - 3; i < n_embd_out; i++) {
if (params.embd_normalize == 0) {
if (params.embd_normalize == COMMON_EMBD_NORM_MAX_ABS) {
LOG("%6.0f ", emb[j * n_embd_out + i]);
} else {
LOG("%9.6f ", emb[j * n_embd_out + i]);
Expand Down Expand Up @@ -334,7 +334,7 @@ int main(int argc, char ** argv) {
for (int j = 0; j < n_prompts; j++) {
LOG("embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) {
if (params.embd_normalize == 0) {
if (params.embd_normalize == COMMON_EMBD_NORM_MAX_ABS) {
LOG("%6.0f ", emb[j * n_embd_out + i]);
} else {
LOG("%9.6f ", emb[j * n_embd_out + i]);
Expand Down Expand Up @@ -371,7 +371,7 @@ int main(int argc, char ** argv) {
if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
LOG("[");
for (int i = 0;;) { // at least one iteration (n_embd > 0)
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
LOG(params.embd_normalize == COMMON_EMBD_NORM_MAX_ABS ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
i++;
if (i < n_embd_out) LOG(","); else break;
}
Expand Down
2 changes: 1 addition & 1 deletion examples/retrieval/retrieval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ static void batch_process(llama_context * ctx, llama_batch & batch, float * outp
}

float * out = output + batch.seq_id[i][0] * n_embd;
common_embd_normalize(embd, out, n_embd, 2);
common_embd_normalize(embd, out, n_embd, COMMON_EMBD_NORM_EUCLIDEAN);
}
}

Expand Down
4 changes: 2 additions & 2 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4877,9 +4877,9 @@ std::unique_ptr<server_res_generator> server_routes::handle_embeddings_impl(cons
}
}

int embd_normalize = params.embd_normalize;
common_embd_norm embd_normalize = params.embd_normalize;
if (body.count("embd_normalize") != 0) {
embd_normalize = body.at("embd_normalize");
embd_normalize = (common_embd_norm) body.at("embd_normalize").get<int>();
if (meta->pooling_type == LLAMA_POOLING_TYPE_NONE) {
SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", meta->pooling_type);
}
Expand Down
2 changes: 1 addition & 1 deletion tools/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct task_params {
common_chat_parser_params chat_parser_params;

// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
common_embd_norm embd_normalize = COMMON_EMBD_NORM_EUCLIDEAN; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)

json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
json to_json(bool only_metrics = false) const;
Expand Down