Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 49 additions & 48 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Literal, Optional, Union
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
Expand Down Expand Up @@ -64,15 +64,15 @@
MAX_IDENTIFIER_VALUE_LENGTH: int = 80


def _ensure_utc(dt: Optional[datetime]) -> Optional[datetime]:
def _ensure_utc(dt: datetime | None) -> datetime | None:
"""
Attach UTC tzinfo to a naive datetime (as returned by SQLite).

Args:
dt (Optional[datetime]): The datetime to normalize, or None.
dt (datetime | None): The datetime to normalize, or None.

Returns:
Optional[datetime]: The datetime with UTC tzinfo attached if it was naive, or None.
datetime | None: The datetime with UTC tzinfo attached if it was naive, or None.
"""
if dt is not None and dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
Expand Down Expand Up @@ -103,7 +103,7 @@ def load_dialect_impl(self, dialect: Any) -> Any:
return dialect.type_descriptor(CHAR(36))
return dialect.type_descriptor(Uuid())

def process_bind_param(self, value: Optional[uuid.UUID], dialect: Any) -> Optional[str]:
def process_bind_param(self, value: uuid.UUID | None, dialect: Any) -> str | None:
"""
Process a parameter value before binding it to a database statement.

Expand All @@ -116,7 +116,7 @@ def process_bind_param(self, value: Optional[uuid.UUID], dialect: Any) -> Option
"""
return str(value) if value else None

def process_result_value(self, value: uuid.UUID | str | None, dialect: Any) -> Optional[uuid.UUID]:
def process_result_value(self, value: uuid.UUID | str | None, dialect: Any) -> uuid.UUID | None:
"""
Process a result value after it has been retrieved from the database.

Expand Down Expand Up @@ -187,22 +187,18 @@ class PromptMemoryEntry(Base):
sequence = mapped_column(INTEGER, nullable=False)
timestamp = mapped_column(DateTime, nullable=False)
labels: Mapped[dict[str, str]] = mapped_column(JSON)
prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON)
targeted_harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON)
converter_identifiers: Mapped[Optional[list[dict[str, str]]]] = mapped_column(JSON)
prompt_metadata: Mapped[dict[str, str | int]] = mapped_column(JSON)
targeted_harm_categories: Mapped[list[str] | None] = mapped_column(JSON)
converter_identifiers: Mapped[list[dict[str, str]] | None] = mapped_column(JSON)
prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON)
attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON)
response_error: Mapped[Literal["blocked", "none", "processing", "unknown"]] = mapped_column(String, nullable=True)

original_value_data_type: Mapped[Literal["text", "image_path", "audio_path", "url", "error"]] = mapped_column(
String, nullable=False
)
original_value_data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False)
original_value = mapped_column(Unicode, nullable=False)
original_value_sha256 = mapped_column(String)

converted_value_data_type: Mapped[Literal["text", "image_path", "audio_path", "url", "error"]] = mapped_column(
String, nullable=False
)
converted_value_data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False)
converted_value = mapped_column(Unicode)
converted_value_sha256 = mapped_column(String)

Expand Down Expand Up @@ -272,7 +268,7 @@ def get_message_piece(self) -> MessagePiece:
MessagePiece: The reconstructed message piece with all its data and scores.
"""
# Reconstruct ComponentIdentifiers with the stored pyrit_version
converter_ids: Optional[list[ComponentIdentifier]] = None
converter_ids: list[ComponentIdentifier] | None = None
stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION
if self.converter_identifiers:
converter_ids = [
Expand All @@ -281,14 +277,14 @@ def get_message_piece(self) -> MessagePiece:
]

# Reconstruct ComponentIdentifier with the stored pyrit_version
target_id: Optional[ComponentIdentifier] = None
target_id: ComponentIdentifier | None = None
if self.prompt_target_identifier:
target_id = ComponentIdentifier.from_dict(
{**self.prompt_target_identifier, "pyrit_version": stored_version}
)

# Reconstruct ComponentIdentifier with the stored pyrit_version
attack_id: Optional[ComponentIdentifier] = None
attack_id: ComponentIdentifier | None = None
if self.attack_identifier:
attack_id = ComponentIdentifier.from_dict({**self.attack_identifier, "pyrit_version": stored_version})

Expand Down Expand Up @@ -374,9 +370,9 @@ class ScoreEntry(Base):
score_value = mapped_column(String, nullable=False)
score_value_description = mapped_column(String, nullable=True)
score_type: Mapped[Literal["true_false", "float_scale", "unknown"]] = mapped_column(String, nullable=False)
score_category: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
score_category: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
score_rationale = mapped_column(String, nullable=True)
score_metadata: Mapped[dict[str, Union[str, int, float]]] = mapped_column(JSON)
score_metadata: Mapped[dict[str, str | int | float]] = mapped_column(JSON)
scorer_class_identifier: Mapped[dict[str, Any]] = mapped_column(JSON)
prompt_request_response_id = mapped_column(CustomUUID, ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"))
timestamp = mapped_column(DateTime, nullable=False)
Expand All @@ -400,7 +396,7 @@ def __init__(self, *, entry: Score) -> None:
self.score_type = entry.score_type
self.score_category = entry.score_category
self.score_rationale = entry.score_rationale
self.score_metadata = entry.score_metadata
self.score_metadata = entry.score_metadata or {}
normalized_scorer = entry.scorer_class_identifier
# Ensure eval_hash is set before truncation so it survives the DB round-trip
if normalized_scorer.eval_hash is None:
Expand Down Expand Up @@ -550,18 +546,18 @@ class SeedEntry(Base):
data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False)
name = mapped_column(String, nullable=True)
dataset_name = mapped_column(String, nullable=True)
harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
harm_categories: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
description = mapped_column(String, nullable=True)
authors: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
groups: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
authors: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
groups: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
source = mapped_column(String, nullable=True)
date_added = mapped_column(DateTime, nullable=False)
added_by = mapped_column(String, nullable=False)
prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON, nullable=True)
parameters: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
prompt_group_id: Mapped[Optional[uuid.UUID]] = mapped_column(CustomUUID, nullable=True)
sequence: Mapped[Optional[int]] = mapped_column(INTEGER, nullable=True)
role: Mapped[ChatMessageRole] = mapped_column(String, nullable=True)
prompt_metadata: Mapped[dict[str, str | int] | None] = mapped_column(JSON, nullable=True)
parameters: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
prompt_group_id: Mapped[uuid.UUID | None] = mapped_column(CustomUUID, nullable=True)
sequence: Mapped[int | None] = mapped_column(INTEGER, nullable=True)
role: Mapped[ChatMessageRole | None] = mapped_column(String, nullable=True)
seed_type: Mapped[SeedType] = mapped_column(String, nullable=False, default="prompt")

def __init__(self, *, entry: Seed) -> None:
Expand All @@ -585,7 +581,7 @@ def __init__(self, *, entry: Seed) -> None:
self.data_type = entry.data_type
self.name = entry.name
self.dataset_name = entry.dataset_name
self.harm_categories = entry.harm_categories
self.harm_categories = list(entry.harm_categories) if entry.harm_categories else None
self.description = entry.description
self.authors = list(entry.authors) if entry.authors else None
self.groups = list(entry.groups) if entry.groups else None
Expand Down Expand Up @@ -713,12 +709,12 @@ class AttackResultEntry(Base):
conversation_id = mapped_column(String, nullable=False)
objective = mapped_column(Unicode, nullable=False)
attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False)
atomic_attack_identifier: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True)
atomic_attack_identifier: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
objective_sha256 = mapped_column(String, nullable=True)
last_response_id: Mapped[Optional[uuid.UUID]] = mapped_column(
last_response_id: Mapped[uuid.UUID | None] = mapped_column(
CustomUUID, ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), nullable=True
)
last_score_id: Mapped[Optional[uuid.UUID]] = mapped_column(
last_score_id: Mapped[uuid.UUID | None] = mapped_column(
CustomUUID, ForeignKey(f"{ScoreEntry.__tablename__}.id"), nullable=True
)
executed_turns = mapped_column(INTEGER, nullable=False, default=0)
Expand All @@ -727,10 +723,10 @@ class AttackResultEntry(Base):
String, nullable=False, default="undetermined"
)
outcome_reason = mapped_column(String, nullable=True)
attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True)
attack_metadata: Mapped[dict[str, str | int | float | bool]] = mapped_column(JSON, nullable=True)
labels: Mapped[dict[str, str]] = mapped_column(JSON, nullable=True)
pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True)
pruned_conversation_ids: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
adversarial_chat_conversation_ids: Mapped[list[str] | None] = mapped_column(JSON, nullable=True)
timestamp = mapped_column(DateTime, nullable=False)
# Version of PyRIT used when this attack result was created
# Nullable for backwards compatibility with existing databases
Expand All @@ -742,14 +738,14 @@ class AttackResultEntry(Base):
error_traceback = mapped_column(Unicode, nullable=True)

# Retry events (JSON-serialized list of RetryEvent dicts)
retry_events_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True)
retry_events_json: Mapped[str | None] = mapped_column(Unicode, nullable=True)
total_retries = mapped_column(INTEGER, nullable=True, default=0)

last_response: Mapped[Optional["PromptMemoryEntry"]] = relationship(
last_response: Mapped["PromptMemoryEntry | None"] = relationship(
"PromptMemoryEntry",
foreign_keys=[last_response_id],
)
last_score: Mapped[Optional["ScoreEntry"]] = relationship(
last_score: Mapped["ScoreEntry | None"] = relationship(
"ScoreEntry",
foreign_keys=[last_score_id],
)
Expand Down Expand Up @@ -820,7 +816,7 @@ def __init__(self, *, entry: AttackResult) -> None:
self.total_retries = entry.total_retries

@staticmethod
def _get_id_as_uuid(obj: Any) -> Optional[uuid.UUID]:
def _get_id_as_uuid(obj: Any) -> uuid.UUID | None:
"""
Safely extract and convert an object's id to UUID.

Expand Down Expand Up @@ -979,32 +975,35 @@ class ScenarioResultEntry(Base):
scenario_description = mapped_column(Unicode, nullable=True)
scenario_version = mapped_column(INTEGER, nullable=False, default=1)
pyrit_version = mapped_column(String, nullable=False)
scenario_init_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True)
scenario_init_data: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
objective_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False)
objective_scorer_identifier: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True)
objective_scorer_identifier: Mapped[dict[str, str] | None] = mapped_column(JSON, nullable=True)
scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"]] = mapped_column(
String, nullable=False, default="CREATED"
)
attack_results_json: Mapped[str] = mapped_column(Unicode, nullable=False)
display_group_map_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True)
labels: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True)
display_group_map_json: Mapped[str | None] = mapped_column(Unicode, nullable=True)
labels: Mapped[dict[str, str] | None] = mapped_column(JSON, nullable=True)
number_tries: Mapped[int] = mapped_column(INTEGER, nullable=False, default=0)
completion_time = mapped_column(DateTime, nullable=False)
timestamp = mapped_column(DateTime, nullable=False)

# Pointer to failed attack result(s) — avoids scanning all attacks for error info
error_attack_result_ids_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True)
error_attack_result_ids_json: Mapped[str | None] = mapped_column(Unicode, nullable=True)

# Scenario-level error info (persisted so it survives process restarts)
error_message: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True)
error_type: Mapped[Optional[str]] = mapped_column(String, nullable=True)
error_message: Mapped[str | None] = mapped_column(Unicode, nullable=True)
error_type: Mapped[str | None] = mapped_column(String, nullable=True)

def __init__(self, *, entry: ScenarioResult) -> None:
"""
Initialize a ScenarioResultEntry from a ScenarioResult object.

Args:
entry (ScenarioResult): The scenario result object to convert into a database entry.

Raises:
ValueError: If ``entry.objective_target_identifier`` is ``None``.
"""
self.id = entry.id
self.scenario_name = entry.scenario_identifier.name
Expand All @@ -1013,6 +1012,8 @@ def __init__(self, *, entry: ScenarioResult) -> None:
self.pyrit_version = entry.scenario_identifier.pyrit_version
self.scenario_init_data = entry.scenario_identifier.init_data
# Convert ComponentIdentifier to dict for JSON storage
if entry.objective_target_identifier is None:
raise ValueError("ScenarioResult.objective_target_identifier is required for database storage")
self.objective_target_identifier = entry.objective_target_identifier.to_dict(
max_value_length=MAX_IDENTIFIER_VALUE_LENGTH
)
Expand Down Expand Up @@ -1103,7 +1104,7 @@ def get_scenario_result(self) -> ScenarioResult:
scenario_identifier=scenario_identifier,
objective_target_identifier=target_identifier,
attack_results=attack_results,
objective_scorer_identifier=scorer_identifier, # type: ignore[ty:invalid-argument-type]
objective_scorer_identifier=scorer_identifier,
scenario_run_state=self.scenario_run_state,
labels=self.labels,
creation_time=self.timestamp,
Expand Down
Loading
Loading