Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/assets/openapi.json

Large diffs are not rendered by default.

25 changes: 22 additions & 3 deletions libs/infinity_emb/infinity_emb/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,20 @@ class EngineArgs:
pooling_method, PoolingMethod or str: pooling method to use. Defaults to PoolingMethod.auto.
lengths_via_tokenize, bool: schedule by token usage. Defaults to False.
served_model_name, str: Defaults to readable name of model_name_or_path.
max_query_tokens, Optional[int]: rerank ceiling, head-truncate the query to at most
N tokens. A client may request fewer. Defaults to None (no limit).
max_tokens_per_doc, Optional[int]: rerank ceiling, head-truncate each document to at
most N tokens. A client may request fewer. Defaults to None (no limit).
max_pair_tokens, Optional[int]: rerank ceiling on the joined (query, document) pair.
A client may request fewer. Defaults to None (no limit).
"""

model_name_or_path: str = MANAGER.model_id[0]
batch_size: int = MANAGER.batch_size[0]
revision: Optional[str] = MANAGER.revision[0]
max_query_tokens: Optional[int] = MANAGER.max_query_tokens[0]
max_tokens_per_doc: Optional[int] = MANAGER.max_tokens_per_doc[0]
max_pair_tokens: Optional[int] = MANAGER.max_pair_tokens[0]
trust_remote_code: bool = MANAGER.trust_remote_code[0]
engine: InferenceEngine = InferenceEngine[MANAGER.engine[0]]
model_warmup: bool = MANAGER.model_warmup[0]
Expand Down Expand Up @@ -99,6 +108,10 @@ def __post_init__(self):
)
if self.revision is not None and self.revision == "":
object.__setattr__(self, "revision", None)
for limit_name in ("max_query_tokens", "max_tokens_per_doc", "max_pair_tokens"):
limit = getattr(self, limit_name)
if limit is not None and limit <= 0:
raise ValueError(f"{limit_name} must be a positive integer or None, got {limit}")
if isinstance(self.vector_disk_cache_path, bool):
object.__setattr__(
self,
Expand Down Expand Up @@ -163,9 +176,12 @@ def from_env(cls) -> list["EngineArgs"]:
embedding_dtype=embedding_dtype,
served_model_name=served_model_name,
onnx_disable_optimize=onnx_disable_optimize,
onnx_do_not_prefer_quantized=onnx_do_not_prefer_quantized
onnx_do_not_prefer_quantized=onnx_do_not_prefer_quantized,
max_query_tokens=max_query_tokens,
max_tokens_per_doc=max_tokens_per_doc,
max_pair_tokens=max_pair_tokens,
)
for model_name_or_path, batch_size, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name,onnx_disable_optimize,onnx_do_not_prefer_quantized in zip_longest(
for model_name_or_path, batch_size, revision, trust_remote_code, engine, model_warmup, device, compile, bettertransformer, dtype, pooling_method, lengths_via_tokenize, embedding_dtype, served_model_name,onnx_disable_optimize,onnx_do_not_prefer_quantized, max_query_tokens, max_tokens_per_doc, max_pair_tokens in zip_longest(
MANAGER.model_id,
MANAGER.batch_size,
MANAGER.revision,
Expand All @@ -181,6 +197,9 @@ def from_env(cls) -> list["EngineArgs"]:
MANAGER.embedding_dtype,
MANAGER.served_model_name,
MANAGER.onnx_disable_optimize,
MANAGER.onnx_do_not_prefer_quantized
MANAGER.onnx_do_not_prefer_quantized,
MANAGER.max_query_tokens,
MANAGER.max_tokens_per_doc,
MANAGER.max_pair_tokens,
)
]
17 changes: 16 additions & 1 deletion libs/infinity_emb/infinity_emb/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,18 @@ def v2(
**_construct("onnx_do_not_prefer_quantized"),
help="Do not use quantized onnx models by default if available",
),
max_query_tokens: list[int] = typer.Option(
**_construct("max_query_tokens"),
help="Rerank ceiling: head-truncate the query to at most N tokens before scoring. A client may request fewer. Unset disables the limit.",
),
max_tokens_per_doc: list[int] = typer.Option(
**_construct("max_tokens_per_doc"),
help="Rerank ceiling: head-truncate each document to at most N tokens before scoring. A client may request fewer. Unset disables the limit.",
),
max_pair_tokens: list[int] = typer.Option(
**_construct("max_pair_tokens"),
help="Rerank ceiling on the joined (query, document) pair, in tokens. A client may request fewer. Unset disables the limit.",
),
):
"""Infinity API ♾️ cli v2. MIT License. Copyright (c) 2023-now Michael Feil \n
\n
Expand Down Expand Up @@ -341,7 +353,10 @@ def v2(
bettertransformer=bettertransformer,
served_model_name=served_model_name,
onnx_disable_optimize=onnx_disable_optimize,
onnx_do_not_prefer_quantized=onnx_do_not_prefer_quantized
onnx_do_not_prefer_quantized=onnx_do_not_prefer_quantized,
max_query_tokens=max_query_tokens,
max_tokens_per_doc=max_tokens_per_doc,
max_pair_tokens=max_pair_tokens,
)

engine_args = []
Expand Down
52 changes: 51 additions & 1 deletion libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@
)


def _clamp_to_ceiling(requested: Optional[int], ceiling: Optional[int]) -> Optional[int]:
"""Clamp a client-requested rerank token budget to the model's startup ceiling.

The startup ceiling guards backend stability; a client may only lower a limit (to
trade quality for speed), never raise it. ``None`` means "no constraint from that
side", so the result is the smaller of the two set values, or ``None`` if neither is.
"""
candidates = [value for value in (requested, ceiling) if value is not None]
return min(candidates) if candidates else None


class AsyncEmbeddingEngine:
"""
An LLM engine that receives requests and embeds them asynchronously.
Expand Down Expand Up @@ -164,6 +175,9 @@ async def rerank(
docs: list[str],
raw_scores: bool = False,
top_n: Optional[int] = None,
max_query_tokens: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
max_pair_tokens: Optional[int] = None,
) -> tuple[list["RerankReturnType"], int]:
"""rerank multiple sentences

Expand All @@ -173,6 +187,16 @@ async def rerank(
raw_scores (bool): return raw scores instead of sigmoid
top_n (Optional[int]): number of top scores to return after reranking
if top_n is None, <= 0 or out of range, all scores are returned
max_query_tokens (Optional[int]): head-truncate the query to at most N tokens.
max_tokens_per_doc (Optional[int]): head-truncate each document to at most N
tokens.
max_pair_tokens (Optional[int]): cap the joined (query, document) pair to at most
N tokens, trimming the longest side first.

Each token budget is clamped to the model's startup ceiling
(``EngineArgs.max_*``): a client may lower a limit to trade quality for speed but
cannot raise it above the configured ceiling. None on both the request and the
ceiling disables that limit.

Raises:
ValueError: raised if engine is not started yet
Expand All @@ -189,6 +213,15 @@ async def rerank(
docs=docs,
raw_scores=raw_scores,
top_n=top_n,
max_query_tokens=_clamp_to_ceiling(
max_query_tokens, self._engine_args.max_query_tokens
),
max_tokens_per_doc=_clamp_to_ceiling(
max_tokens_per_doc, self._engine_args.max_tokens_per_doc
),
max_pair_tokens=_clamp_to_ceiling(
max_pair_tokens, self._engine_args.max_pair_tokens
),
)

return scores, usage
Expand Down Expand Up @@ -351,6 +384,9 @@ async def rerank(
docs: list[str],
raw_scores: bool = False,
top_n: Optional[int] = None,
max_query_tokens: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
max_pair_tokens: Optional[int] = None,
) -> tuple[list["RerankReturnType"], int]:
"""rerank multiple sentences

Expand All @@ -360,6 +396,12 @@ async def rerank(
docs (list[str]): docs to be reranked
raw_scores (bool): return raw scores instead of sigmoid
top_n (Optional[int]): number of top scores to return after reranking
max_query_tokens (Optional[int]): head-truncate the query to at most N tokens.
max_tokens_per_doc (Optional[int]): head-truncate each document to at most N
tokens.
max_pair_tokens (Optional[int]): cap the joined (query, document) pair to at
most N tokens, trimming the longest side first. Each budget is clamped to
the model's startup ceiling; a client may lower a limit but not raise it.

Raises:
ValueError: raised if engine is not started yet
Expand All @@ -370,7 +412,15 @@ async def rerank(
list[float]: list of scores
int: token usage
"""
return await self[model].rerank(query=query, docs=docs, raw_scores=raw_scores, top_n=top_n)
return await self[model].rerank(
query=query,
docs=docs,
raw_scores=raw_scores,
top_n=top_n,
max_query_tokens=max_query_tokens,
max_tokens_per_doc=max_tokens_per_doc,
max_pair_tokens=max_pair_tokens,
)

async def classify(
self, *, model: str, sentences: list[str], raw_scores: bool = False
Expand Down
35 changes: 34 additions & 1 deletion libs/infinity_emb/infinity_emb/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from functools import cached_property
from pathlib import Path
from typing import TypeVar
from typing import Optional, TypeVar

from infinity_emb.log_handler import logger
from infinity_emb.primitives import (
Expand Down Expand Up @@ -87,6 +87,21 @@ def _to_bool_multiple(value: list[str]) -> list[bool]:
def _to_int_multiple(value: list[str]) -> list[int]:
return [int(v) for v in value]

@staticmethod
def _to_optional_int_multiple(value: list[str]) -> list[Optional[int]]:
"""Parse a per-model list where an empty token or `none`/`null` disables the limit."""
parsed: list[Optional[int]] = []
for v in value:
v = v.strip()
if v == "" or v.lower() in {"none", "null"}:
parsed.append(None)
continue
number = int(v)
if number <= 0:
raise ValueError(f"token limit must be a positive integer, got `{v}`")
parsed.append(number)
return parsed

@cached_property
def api_key(self):
return self._optional_infinity_var("api_key", default="")
Expand All @@ -107,6 +122,24 @@ def batch_size(self):
self._optional_infinity_var_multiple("batch_size", default=["32"])
)

@cached_property
def max_query_tokens(self):
    """Per-model `max_query_tokens` settings, parsed to ``None`` or a positive int.

    The raw tokens come from ``_optional_infinity_var_multiple`` with a default of
    ``[""]`` and are converted by ``_to_optional_int_multiple``, which maps an empty
    token to ``None`` (no limit).
    """
    raw_tokens = self._optional_infinity_var_multiple("max_query_tokens", default=[""])
    return self._to_optional_int_multiple(raw_tokens)

@cached_property
def max_tokens_per_doc(self):
    """Per-model `max_tokens_per_doc` settings, parsed to ``None`` or a positive int.

    The raw tokens come from ``_optional_infinity_var_multiple`` with a default of
    ``[""]`` and are converted by ``_to_optional_int_multiple``, which maps an empty
    token to ``None`` (no limit).
    """
    raw_tokens = self._optional_infinity_var_multiple("max_tokens_per_doc", default=[""])
    return self._to_optional_int_multiple(raw_tokens)

@cached_property
def max_pair_tokens(self):
    """Per-model `max_pair_tokens` settings, parsed to ``None`` or a positive int.

    The raw tokens come from ``_optional_infinity_var_multiple`` with a default of
    ``[""]`` and are converted by ``_to_optional_int_multiple``, which maps an empty
    token to ``None`` (no limit).
    """
    raw_tokens = self._optional_infinity_var_multiple("max_pair_tokens", default=[""])
    return self._to_optional_int_multiple(raw_tokens)

@cached_property
def revision(self):
return self._optional_infinity_var_multiple("revision", default=[""])
Expand Down
32 changes: 31 additions & 1 deletion libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@


from infinity_emb._optional_imports import CHECK_PYDANTIC
from infinity_emb.primitives import EmbeddingEncodingFormat, Modality
from infinity_emb.primitives import (
EmbeddingEncodingFormat,
Modality,
)

CHECK_PYDANTIC.mark_required()
# pydantic 2.x is strictly needed starting v0.0.70
Expand Down Expand Up @@ -230,6 +233,33 @@ class RerankInput(BaseModel):
raw_scores: bool = False
model: str = "default/not-specified"
top_n: Optional[int] = Field(default=None, gt=0)
max_query_tokens: Optional[int] = Field(
default=None,
gt=0,
description=(
"Head-truncate the query to at most N tokens before scoring. Clamped to the "
"model's server-side ceiling: a request may lower this but not raise it above "
"the configured limit. Omit or null to use the server ceiling."
),
)
max_tokens_per_doc: Optional[int] = Field(
default=None,
gt=0,
description=(
"Head-truncate each document to at most N tokens before scoring (Cohere v2 "
"compatible). Clamped to the model's server-side ceiling: a request may lower "
"this but not raise it. Omit or null to use the server ceiling."
),
)
max_pair_tokens: Optional[int] = Field(
default=None,
gt=0,
description=(
"Cap the joined (query, document) pair to at most N tokens, trimming the longest "
"side first. Clamped to the model's server-side ceiling: a request may lower "
"this but not raise it. Omit or null to use the server ceiling."
),
)


class _ReRankObject(BaseModel):
Expand Down
18 changes: 17 additions & 1 deletion libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
OverloadStatus,
PredictSingle,
PrioritizedQueueItem,
RerankLimits,
RerankReturnType,
ReRankSingle,
get_inner_item,
Expand Down Expand Up @@ -180,15 +181,25 @@ async def rerank(
docs: list[str],
raw_scores: bool = False,
top_n: Optional[int] = None,
max_query_tokens: Optional[int] = None,
max_tokens_per_doc: Optional[int] = None,
max_pair_tokens: Optional[int] = None,
) -> tuple[list[RerankReturnType], int]:
"""Schedule a query to be reranked with documents. Awaits until reranked.

The token budgets here are already resolved (the engine has clamped them to the
model's startup ceilings); ``None`` disables the corresponding limit.

Args:
query (str): query for reranking
docs (list[str]): documents to be reranked
raw_scores (bool): return raw scores instead of sigmoid
top_n (Optional[int]): number of top scores to return after reranking
if top_n is None, <= 0 or out of range, all scores are returned
max_query_tokens (Optional[int]): head-truncate the query to N tokens.
max_tokens_per_doc (Optional[int]): head-truncate each document to N tokens.
max_pair_tokens (Optional[int]): cap the joined (query, document) pair to N
tokens, trimming the longest side first.

Raises:
ModelNotDeployedError: If loaded model does not expose `rerank`
Expand All @@ -202,7 +213,12 @@ async def rerank(
raise ModelNotDeployedError(
"the loaded moded cannot fullyfill `rerank`. " f"Options are {self.capabilities}."
)
rerankables = [ReRankSingle(query=query, document=doc) for doc in docs]
limits = RerankLimits(
max_query_tokens=max_query_tokens,
max_tokens_per_doc=max_tokens_per_doc,
max_pair_tokens=max_pair_tokens,
)
rerankables = [ReRankSingle(query=query, document=doc, limits=limits) for doc in docs]
scores, usage = await self._schedule(rerankables)

if not raw_scores:
Expand Down
3 changes: 3 additions & 0 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,9 @@ async def _rerank(data: RerankInput):
docs=data.documents,
raw_scores=data.raw_scores,
top_n=data.top_n,
max_query_tokens=data.max_query_tokens,
max_tokens_per_doc=data.max_tokens_per_doc,
max_pair_tokens=data.max_pair_tokens,
)

duration = (time.perf_counter() - start) * 1000
Expand Down
Loading