Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "cordon"
-version = "1.1.0"
+version = "1.1.1"
description = "Semantic anomaly detection for system log files"
readme = "README.md"
requires-python = ">=3.10,<3.15"
Expand Down
2 changes: 1 addition & 1 deletion src/cordon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from cordon.core.types import AnalysisResult, MergedBlock, ScoredWindow, TextWindow
from cordon.pipeline import SemanticLogAnalyzer

-__version__ = "1.1.0"
+__version__ = "1.1.1"

__all__ = [
"SemanticLogAnalyzer",
Expand Down
50 changes: 27 additions & 23 deletions src/cordon/embedding/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy.typing as npt
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from cordon.core.config import AnalysisConfig
from cordon.core.device import detect_device
Expand All @@ -25,6 +24,11 @@ class TransformerEmbedder:
def __init__(self, config: AnalysisConfig) -> None:
"""Initialize the embedder with a sentence-transformer model.

The model is initialized directly on the target device via the
``device`` constructor parameter. This ensures the internal
``_target_device`` attribute is set correctly so that ``encode()``
places input tensors on the same device as the model parameters.

Args:
config: Analysis configuration specifying model and device.

Expand All @@ -35,22 +39,26 @@ def __init__(self, config: AnalysisConfig) -> None:
self.device = detect_device(self.config.device)

try:
self.model = SentenceTransformer(config.model_name)
self.model = SentenceTransformer(config.model_name, device=str(self.device))
except Exception as error:
raise RuntimeError(
f"Failed to load sentence-transformer model '{config.model_name}'. "
f"Verify the model name is correct and you have network access "
f"for first-time downloads. Error: {error}"
) from error

self.model.to(self.device)
self._truncation_warned = False

def embed_windows(
self, windows: Iterable[TextWindow]
) -> Iterator[tuple[TextWindow, npt.NDArray[np.floating[Any]]]]:
"""Embed text windows into vector representations.

Encodes all windows in a single ``model.encode()`` call, delegating
batching, length-based sorting, and padding to sentence-transformers.
This avoids per-batch overhead from repeated DataLoader creation and
tokenization and allows optimal GPU utilization.

Args:
windows: Iterable of text windows to embed.

Expand All @@ -69,28 +77,24 @@ def embed_windows(
if torch.cuda.is_available():
torch.cuda.empty_cache()

batch_size = self.config.batch_size
total_batches = (len(window_list) + batch_size - 1) // batch_size

for batch_start_idx in tqdm(
range(0, len(window_list), batch_size),
desc="Generating embeddings",
total=total_batches,
unit="batch",
disable=not self.config.show_progress,
):
batch = window_list[batch_start_idx : batch_start_idx + batch_size]
texts = [window.content for window in batch]

embeddings = self.model.encode(
texts,
batch_size=len(batch),
show_progress_bar=False,
convert_to_numpy=True,
normalize_embeddings=True,
texts = [window.content for window in window_list]

all_embeddings: npt.NDArray[np.floating[Any]] = self.model.encode(
texts,
batch_size=self.config.batch_size,
show_progress_bar=self.config.show_progress,
convert_to_numpy=True,
normalize_embeddings=True,
)

if len(all_embeddings) != len(window_list):
raise ValueError(
f"model.encode() returned {len(all_embeddings)} embeddings "
f"for {len(window_list)} input windows. This indicates a "
f"sentence-transformers internal error."
)

yield from zip(batch, embeddings, strict=False)
yield from zip(window_list, all_embeddings, strict=True)

def _check_truncation_warning(self, windows: list[TextWindow]) -> None:
"""Check if windows are likely to be truncated and warn user.
Expand Down
9 changes: 7 additions & 2 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,13 @@ def embedder(self, mock_st: MagicMock) -> TransformerEmbedder:

mock_model = MagicMock()
rng = np.random.default_rng(0)
raw = rng.standard_normal((1, 384)).astype(np.float32)
mock_model.encode.return_value = raw / np.linalg.norm(raw, axis=1, keepdims=True)

def _fake_encode(texts: list[str], **kwargs: object) -> np.ndarray:  # type: ignore[type-arg]
    """Stand in for ``model.encode`` on the mocked model.

    Produces one unit-length float32 row of width 384 per input text, so
    the returned row count follows the size of the list it is called
    with. Any non-list input is treated as a single text and yields
    exactly one row. Keyword arguments are accepted and ignored.

    Args:
        texts: Texts the embedder asked the model to encode.
        **kwargs: Encode options passed by the caller; unused here.

    Returns:
        Array of shape ``(n, 384)`` whose rows each have L2 norm 1.
    """
    row_count = 1
    if isinstance(texts, list):
        row_count = len(texts)

    # Values are drawn from the ``rng`` created in the enclosing fixture.
    vectors = rng.standard_normal((row_count, 384)).astype(np.float32)
    row_norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / row_norms

mock_model.encode.side_effect = _fake_encode
mock_st.return_value = mock_model

config = AnalysisConfig(device="cpu", batch_size=2)
Expand Down
Loading