diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 1e48c03cf5..ec6323d04b 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -5,7 +5,7 @@ import logging import uuid from datetime import datetime, timezone -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from pydantic import BaseModel, ConfigDict from sqlalchemy import ( @@ -64,15 +64,15 @@ MAX_IDENTIFIER_VALUE_LENGTH: int = 80 -def _ensure_utc(dt: Optional[datetime]) -> Optional[datetime]: +def _ensure_utc(dt: datetime | None) -> datetime | None: """ Attach UTC tzinfo to a naive datetime (as returned by SQLite). Args: - dt (Optional[datetime]): The datetime to normalize, or None. + dt (datetime | None): The datetime to normalize, or None. Returns: - Optional[datetime]: The datetime with UTC tzinfo attached if it was naive, or None. + datetime | None: The datetime with UTC tzinfo attached if it was naive, or None. """ if dt is not None and dt.tzinfo is None: return dt.replace(tzinfo=timezone.utc) @@ -103,7 +103,7 @@ def load_dialect_impl(self, dialect: Any) -> Any: return dialect.type_descriptor(CHAR(36)) return dialect.type_descriptor(Uuid()) - def process_bind_param(self, value: Optional[uuid.UUID], dialect: Any) -> Optional[str]: + def process_bind_param(self, value: uuid.UUID | None, dialect: Any) -> str | None: """ Process a parameter value before binding it to a database statement. @@ -116,7 +116,7 @@ def process_bind_param(self, value: Optional[uuid.UUID], dialect: Any) -> Option """ return str(value) if value else None - def process_result_value(self, value: uuid.UUID | str | None, dialect: Any) -> Optional[uuid.UUID]: + def process_result_value(self, value: uuid.UUID | str | None, dialect: Any) -> uuid.UUID | None: """ Process a result value after it has been retrieved from the database. @@ -187,22 +187,18 @@ class PromptMemoryEntry(Base): sequence = mapped_column(INTEGER, nullable=False) timestamp = mapped_column(DateTime, nullable=False) labels: Mapped[dict[str, str]] = mapped_column(JSON) - prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON) - targeted_harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON) - converter_identifiers: Mapped[Optional[list[dict[str, str]]]] = mapped_column(JSON) + prompt_metadata: Mapped[dict[str, str | int]] = mapped_column(JSON) + targeted_harm_categories: Mapped[list[str] | None] = mapped_column(JSON) + converter_identifiers: Mapped[list[dict[str, str]] | None] = mapped_column(JSON) prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON) attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON) response_error: Mapped[Literal["blocked", "none", "processing", "unknown"]] = mapped_column(String, nullable=True) - original_value_data_type: Mapped[Literal["text", "image_path", "audio_path", "url", "error"]] = mapped_column( - String, nullable=False - ) + original_value_data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False) original_value = mapped_column(Unicode, nullable=False) original_value_sha256 = mapped_column(String) - converted_value_data_type: Mapped[Literal["text", "image_path", "audio_path", "url", "error"]] = mapped_column( - String, nullable=False - ) + converted_value_data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False) converted_value = mapped_column(Unicode) converted_value_sha256 = mapped_column(String) @@ -272,7 +268,7 @@ def get_message_piece(self) -> MessagePiece: MessagePiece: The reconstructed message piece with all its data and scores. """ # Reconstruct ComponentIdentifiers with the stored pyrit_version - converter_ids: Optional[list[ComponentIdentifier]] = None + converter_ids: list[ComponentIdentifier] | None = None stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION if self.converter_identifiers: converter_ids = [ @@ -281,14 +277,14 @@ def get_message_piece(self) -> MessagePiece: ] # Reconstruct ComponentIdentifier with the stored pyrit_version - target_id: Optional[ComponentIdentifier] = None + target_id: ComponentIdentifier | None = None if self.prompt_target_identifier: target_id = ComponentIdentifier.from_dict( {**self.prompt_target_identifier, "pyrit_version": stored_version} ) # Reconstruct ComponentIdentifier with the stored pyrit_version - attack_id: Optional[ComponentIdentifier] = None + attack_id: ComponentIdentifier | None = None if self.attack_identifier: attack_id = ComponentIdentifier.from_dict({**self.attack_identifier, "pyrit_version": stored_version}) @@ -374,9 +370,9 @@ class ScoreEntry(Base): score_value = mapped_column(String, nullable=False) score_value_description = mapped_column(String, nullable=True) score_type: Mapped[Literal["true_false", "float_scale", "unknown"]] = mapped_column(String, nullable=False) - score_category: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + score_category: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) score_rationale = mapped_column(String, nullable=True) - score_metadata: Mapped[dict[str, Union[str, int, float]]] = mapped_column(JSON) + score_metadata: Mapped[dict[str, str | int | float]] = mapped_column(JSON) scorer_class_identifier: Mapped[dict[str, Any]] = mapped_column(JSON) prompt_request_response_id = mapped_column(CustomUUID, ForeignKey(f"{PromptMemoryEntry.__tablename__}.id")) timestamp = mapped_column(DateTime, nullable=False) @@ -400,7 +396,7 @@ def __init__(self, *, entry: Score) -> None: self.score_type = entry.score_type self.score_category = entry.score_category self.score_rationale = entry.score_rationale - self.score_metadata = entry.score_metadata + self.score_metadata = entry.score_metadata or {} normalized_scorer = entry.scorer_class_identifier # Ensure eval_hash is set before truncation so it survives the DB round-trip if normalized_scorer.eval_hash is None: @@ -550,18 +546,18 @@ class SeedEntry(Base): data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False) name = mapped_column(String, nullable=True) dataset_name = mapped_column(String, nullable=True) - harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + harm_categories: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) description = mapped_column(String, nullable=True) - authors: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) - groups: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + authors: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) + groups: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) source = mapped_column(String, nullable=True) date_added = mapped_column(DateTime, nullable=False) added_by = mapped_column(String, nullable=False) - prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON, nullable=True) - parameters: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) - prompt_group_id: Mapped[Optional[uuid.UUID]] = mapped_column(CustomUUID, nullable=True) - sequence: Mapped[Optional[int]] = mapped_column(INTEGER, nullable=True) - role: Mapped[ChatMessageRole] = mapped_column(String, nullable=True) + prompt_metadata: Mapped[dict[str, str | int] | None] = mapped_column(JSON, nullable=True) + parameters: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) + prompt_group_id: Mapped[uuid.UUID | None] = mapped_column(CustomUUID, nullable=True) + sequence: Mapped[int | None] = mapped_column(INTEGER, nullable=True) + role: Mapped[ChatMessageRole | None] = mapped_column(String, nullable=True) seed_type: Mapped[SeedType] = mapped_column(String, nullable=False, default="prompt") def __init__(self, *, entry: Seed) -> None: @@ -585,7 +581,7 @@ def __init__(self, *, entry: Seed) -> None: self.data_type = entry.data_type self.name = entry.name self.dataset_name = entry.dataset_name - self.harm_categories = entry.harm_categories + self.harm_categories = list(entry.harm_categories) if entry.harm_categories else None self.description = entry.description self.authors = list(entry.authors) if entry.authors else None self.groups = list(entry.groups) if entry.groups else None @@ -713,12 +709,12 @@ class AttackResultEntry(Base): conversation_id = mapped_column(String, nullable=False) objective = mapped_column(Unicode, nullable=False) attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) - atomic_attack_identifier: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True) + atomic_attack_identifier: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) objective_sha256 = mapped_column(String, nullable=True) - last_response_id: Mapped[Optional[uuid.UUID]] = mapped_column( + last_response_id: Mapped[uuid.UUID | None] = mapped_column( CustomUUID, ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), nullable=True ) - last_score_id: Mapped[Optional[uuid.UUID]] = mapped_column( + last_score_id: Mapped[uuid.UUID | None] = mapped_column( CustomUUID, ForeignKey(f"{ScoreEntry.__tablename__}.id"), nullable=True ) executed_turns = mapped_column(INTEGER, nullable=False, default=0) @@ -727,10 +723,10 @@ class AttackResultEntry(Base): String, nullable=False, default="undetermined" ) outcome_reason = mapped_column(String, nullable=True) - attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True) + attack_metadata: Mapped[dict[str, str | int | float | bool]] = mapped_column(JSON, nullable=True) labels: Mapped[dict[str, str]] = mapped_column(JSON, nullable=True) - pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) - adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + pruned_conversation_ids: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) + adversarial_chat_conversation_ids: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) timestamp = mapped_column(DateTime, nullable=False) # Version of PyRIT used when this attack result was created # Nullable for backwards compatibility with existing databases @@ -742,14 +738,14 @@ class AttackResultEntry(Base): error_traceback = mapped_column(Unicode, nullable=True) # Retry events (JSON-serialized list of RetryEvent dicts) - retry_events_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + retry_events_json: Mapped[str | None] = mapped_column(Unicode, nullable=True) total_retries = mapped_column(INTEGER, nullable=True, default=0) - last_response: Mapped[Optional["PromptMemoryEntry"]] = relationship( + last_response: Mapped["PromptMemoryEntry | None"] = relationship( "PromptMemoryEntry", foreign_keys=[last_response_id], ) - last_score: Mapped[Optional["ScoreEntry"]] = relationship( + last_score: Mapped["ScoreEntry | None"] = relationship( "ScoreEntry", foreign_keys=[last_score_id], ) @@ -820,7 +816,7 @@ def __init__(self, *, entry: AttackResult) -> None: self.total_retries = entry.total_retries @staticmethod - def _get_id_as_uuid(obj: Any) -> Optional[uuid.UUID]: + def _get_id_as_uuid(obj: Any) -> uuid.UUID | None: """ Safely extract and convert an object's id to UUID. @@ -979,25 +975,25 @@ class ScenarioResultEntry(Base): scenario_description = mapped_column(Unicode, nullable=True) scenario_version = mapped_column(INTEGER, nullable=False, default=1) pyrit_version = mapped_column(String, nullable=False) - scenario_init_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True) + scenario_init_data: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) objective_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) - objective_scorer_identifier: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True) + objective_scorer_identifier: Mapped[dict[str, str] | None] = mapped_column(JSON, nullable=True) scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"]] = mapped_column( String, nullable=False, default="CREATED" ) attack_results_json: Mapped[str] = mapped_column(Unicode, nullable=False) - display_group_map_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) - labels: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True) + display_group_map_json: Mapped[str | None] = mapped_column(Unicode, nullable=True) + labels: Mapped[dict[str, str] | None] = mapped_column(JSON, nullable=True) number_tries: Mapped[int] = mapped_column(INTEGER, nullable=False, default=0) completion_time = mapped_column(DateTime, nullable=False) timestamp = mapped_column(DateTime, nullable=False) # Pointer to failed attack result(s) — avoids scanning all attacks for error info - error_attack_result_ids_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + error_attack_result_ids_json: Mapped[str | None] = mapped_column(Unicode, nullable=True) # Scenario-level error info (persisted so it survives process restarts) - error_message: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) - error_type: Mapped[Optional[str]] = mapped_column(String, nullable=True) + error_message: Mapped[str | None] = mapped_column(Unicode, nullable=True) + error_type: Mapped[str | None] = mapped_column(String, nullable=True) def __init__(self, *, entry: ScenarioResult) -> None: """ @@ -1005,6 +1001,9 @@ def __init__(self, *, entry: ScenarioResult) -> None: Args: entry (ScenarioResult): The scenario result object to convert into a database entry. + + Raises: + ValueError: If ``entry.objective_target_identifier`` is ``None``. """ self.id = entry.id self.scenario_name = entry.scenario_identifier.name @@ -1013,6 +1012,8 @@ def __init__(self, *, entry: ScenarioResult) -> None: self.pyrit_version = entry.scenario_identifier.pyrit_version self.scenario_init_data = entry.scenario_identifier.init_data # Convert ComponentIdentifier to dict for JSON storage + if entry.objective_target_identifier is None: + raise ValueError("ScenarioResult.objective_target_identifier is required for database storage") self.objective_target_identifier = entry.objective_target_identifier.to_dict( max_value_length=MAX_IDENTIFIER_VALUE_LENGTH ) @@ -1103,7 +1104,7 @@ def get_scenario_result(self) -> ScenarioResult: scenario_identifier=scenario_identifier, objective_target_identifier=target_identifier, attack_results=attack_results, - objective_scorer_identifier=scorer_identifier, # type: ignore[ty:invalid-argument-type] + objective_scorer_identifier=scorer_identifier, scenario_run_state=self.scenario_run_state, labels=self.labels, creation_time=self.timestamp, diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 123c83a918..9ec912ff48 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -8,19 +8,17 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, TypeVar - +from typing import Any, Optional, TypeVar + +from pyrit.common.deprecation import print_deprecation_message +from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier +from pyrit.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models.message_piece import MessagePiece +from pyrit.models.retry_event import RetryEvent +from pyrit.models.score import Score from pyrit.models.strategy_result import StrategyResult -if TYPE_CHECKING: - from pyrit.identifiers.component_identifier import ComponentIdentifier - from pyrit.models.conversation_reference import ConversationReference - from pyrit.models.message_piece import MessagePiece - from pyrit.models.retry_event import RetryEvent - from pyrit.models.score import Score - -from pyrit.models.conversation_reference import ConversationType - AttackResultT = TypeVar("AttackResultT", bound="AttackResult") @@ -119,8 +117,6 @@ def attack_identifier(self) -> Optional[ComponentIdentifier]: Optional[ComponentIdentifier]: The attack strategy identifier, or ``None``. """ - from pyrit.common.deprecation import print_deprecation_message - print_deprecation_message( old_item="AttackResult.attack_identifier", new_item="AttackResult.atomic_attack_identifier or get_attack_strategy_identifier()", @@ -224,6 +220,82 @@ def __str__(self) -> str: """ return f"AttackResult: {self.conversation_id}: {self.outcome.value}: {self.objective[:50]}..." + def to_dict(self) -> dict[str, Any]: + """ + Serialize this attack result to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload suitable for REST APIs or persistence. + """ + return { + "conversation_id": self.conversation_id, + "objective": self.objective, + "attack_result_id": self.attack_result_id, + "atomic_attack_identifier": ( + self.atomic_attack_identifier.to_dict() if self.atomic_attack_identifier else None + ), + "last_response": self.last_response.to_dict() if self.last_response else None, + "last_score": self.last_score.to_dict() if self.last_score else None, + "executed_turns": self.executed_turns, + "execution_time_ms": self.execution_time_ms, + "outcome": self.outcome.value, + "outcome_reason": self.outcome_reason, + "timestamp": self.timestamp.isoformat() if self.timestamp else None, + "related_conversations": sorted( + [ + ref.to_dict() if isinstance(ref, ConversationReference) else ref + for ref in self.related_conversations + ], + key=lambda r: r["conversation_id"] if isinstance(r, dict) else "", + ), + "metadata": self.metadata, + "labels": self.labels, + "error_message": self.error_message, + "error_type": self.error_type, + "error_traceback": self.error_traceback, + "retry_events": [e.to_dict() for e in self.retry_events], + "total_retries": self.total_retries, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> AttackResult: + """ + Reconstruct an AttackResult from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + AttackResult: Reconstructed instance. + """ + return cls( + conversation_id=data["conversation_id"], + objective=data["objective"], + attack_result_id=data.get("attack_result_id", str(uuid.uuid4())), + atomic_attack_identifier=( + ComponentIdentifier.from_dict(data["atomic_attack_identifier"]) + if data.get("atomic_attack_identifier") + else None + ), + last_response=(MessagePiece.from_dict(data["last_response"]) if data.get("last_response") else None), + last_score=Score.from_dict(data["last_score"]) if data.get("last_score") else None, + executed_turns=data.get("executed_turns", 0), + execution_time_ms=data.get("execution_time_ms", 0), + outcome=AttackOutcome(data.get("outcome", "undetermined")), + outcome_reason=data.get("outcome_reason"), + timestamp=( + datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else datetime.now(timezone.utc) + ), + related_conversations={ConversationReference.from_dict(r) for r in data.get("related_conversations", [])}, + metadata=data.get("metadata", {}), + labels=data.get("labels", {}), + error_message=data.get("error_message"), + error_type=data.get("error_type"), + error_traceback=data.get("error_traceback"), + retry_events=[RetryEvent.from_dict(e) for e in data.get("retry_events", [])], + total_retries=data.get("total_retries", 0), + ) + def _add_attack_identifier_compat(cls: type) -> type: """ @@ -246,16 +318,12 @@ def _add_attack_identifier_compat(cls: type) -> type: def wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None: attack_identifier = kwargs.pop("attack_identifier", None) if attack_identifier is not None: - from pyrit.common.deprecation import print_deprecation_message - print_deprecation_message( old_item="AttackResult(attack_identifier=...)", new_item="AttackResult(atomic_attack_identifier=...)", removed_in="0.15.0", ) if kwargs.get("atomic_attack_identifier") is None: - from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier - kwargs["atomic_attack_identifier"] = build_atomic_attack_identifier( attack_identifier=attack_identifier, ) diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/conversation_reference.py index 0932cca051..95c7b9d5eb 100644 --- a/pyrit/models/conversation_reference.py +++ b/pyrit/models/conversation_reference.py @@ -36,6 +36,36 @@ def __hash__(self) -> int: """ return hash(self.conversation_id) + def to_dict(self) -> dict[str, str | None]: + """ + Serialize to a JSON-compatible dictionary. + + Returns: + dict[str, str | None]: Dictionary with conversation_id, conversation_type, and description. + """ + return { + "conversation_id": self.conversation_id, + "conversation_type": self.conversation_type.value, + "description": self.description, + } + + @classmethod + def from_dict(cls, data: dict[str, str | None]) -> ConversationReference: + """ + Reconstruct a ConversationReference from a dictionary. + + Args: + data (dict[str, str | None]): Dictionary as produced by to_dict(). + + Returns: + ConversationReference: Reconstructed instance. + """ + return cls( + conversation_id=str(data["conversation_id"]), + conversation_type=ConversationType(data["conversation_type"]), + description=data.get("description"), + ) + def __eq__(self, other: object) -> bool: """ Compare two references by conversation ID. diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 16a77efaab..8e19059d28 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -6,7 +6,7 @@ import copy import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.utils import combine_dict from pyrit.models.message_piece import MessagePiece @@ -285,28 +285,41 @@ def __str__(self) -> str: def to_dict(self) -> dict[str, object]: """ - Convert the message to a dictionary representation. + Convert the message to a dictionary representation including all piece details. - Returns: - dict: A dictionary with 'role', 'converted_value', 'conversation_id', 'sequence', - and 'converted_value_data_type' keys. + Serializes each piece individually via MessagePiece.to_dict(). This is the format + expected by from_dict(). + Returns: + dict[str, object]: Dictionary with 'role', 'is_simulated', 'conversation_id', + 'sequence', and 'pieces' (list of MessagePiece.to_dict() dicts). """ - if len(self.message_pieces) == 1: - converted_value: str | list[str] = self.message_pieces[0].converted_value - converted_value_data_type: str | list[str] = self.message_pieces[0].converted_value_data_type - else: - converted_value = [piece.converted_value for piece in self.message_pieces] - converted_value_data_type = [piece.converted_value_data_type for piece in self.message_pieces] - return { "role": self.api_role, - "converted_value": converted_value, + "is_simulated": self.is_simulated, "conversation_id": self.conversation_id, "sequence": self.sequence, - "converted_value_data_type": converted_value_data_type, + "pieces": [piece.to_dict() for piece in self.message_pieces], } + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Message: + """ + Reconstruct a Message from a dictionary. + + Expects the format produced by to_dict(), which includes a 'pieces' key + containing a list of MessagePiece dictionaries. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + Message: Reconstructed instance. + """ + pieces_data = data.get("pieces", []) + message_pieces = [MessagePiece.from_dict(p) for p in pieces_data] + return cls(message_pieces, skip_validation=True) + @staticmethod def get_all_values(messages: Sequence[Message]) -> list[str]: """ diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 0f0cf9c1a0..b8f9533be1 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -5,16 +5,17 @@ import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args from uuid import uuid4 from pyrit.common.deprecation import print_deprecation_message +from pyrit.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.data_type_serializer import data_serializer_factory from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError +from pyrit.models.score import Score if TYPE_CHECKING: - from pyrit.identifiers.component_identifier import ComponentIdentifier from pyrit.models.message import Message - from pyrit.models.score import Score Originator = Literal["attack", "converter", "undefined", "scorer"] """Deprecated: The Originator type alias will be removed in a future release.""" @@ -218,8 +219,6 @@ async def set_sha256_values_async(self) -> None: Note, this method is async due to the blob retrieval. And because of that, we opted to take it out of main and setter functions. The disadvantage is that it must be explicitly called. """ - from pyrit.models.data_type_serializer import data_serializer_factory - original_serializer = data_serializer_factory( category="prompt-memory-entries", data_type=self.original_value_data_type, @@ -354,6 +353,54 @@ def __str__(self) -> str: __repr__ = __str__ + @classmethod + def from_dict(cls, data: dict[str, Any]) -> MessagePiece: + """ + Reconstruct a MessagePiece from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + MessagePiece: Reconstructed instance. + """ + return cls( + id=data.get("id"), + role=data.get("role", "user"), + conversation_id=data.get("conversation_id"), + sequence=data.get("sequence", -1), + timestamp=(datetime.fromisoformat(str(data["timestamp"])) if data.get("timestamp") else None), + labels=data.get("labels"), + targeted_harm_categories=data.get("targeted_harm_categories"), + prompt_metadata=data.get("prompt_metadata"), + converter_identifiers=( + [ComponentIdentifier.from_dict(c) for c in data["converter_identifiers"]] + if data.get("converter_identifiers") + else None + ), + prompt_target_identifier=( + ComponentIdentifier.from_dict(data["prompt_target_identifier"]) + if data.get("prompt_target_identifier") + else None + ), + attack_identifier=( + ComponentIdentifier.from_dict(data["attack_identifier"]) if data.get("attack_identifier") else None + ), + scorer_identifier=( + ComponentIdentifier.from_dict(data["scorer_identifier"]) if data.get("scorer_identifier") else None + ), + original_value_data_type=data.get("original_value_data_type", "text"), + original_value=data.get("original_value", ""), + original_value_sha256=data.get("original_value_sha256"), + converted_value_data_type=data.get("converted_value_data_type"), + converted_value=data.get("converted_value"), + converted_value_sha256=data.get("converted_value_sha256"), + response_error=data.get("response_error", "none"), + originator=data.get("originator", "undefined"), + original_prompt_id=(uuid.UUID(str(data["original_prompt_id"])) if data.get("original_prompt_id") else None), + scores=([Score.from_dict(s) for s in data["scores"]] if data.get("scores") else None), + ) + def __eq__(self, other: object) -> bool: """ Compare this message piece with another for semantic equality. diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 88a67f5991..c675719fc6 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import logging import uuid from datetime import datetime, timezone @@ -46,6 +48,40 @@ def __init__( self.pyrit_version = pyrit_version if pyrit_version is not None else pyrit.__version__ self.init_data = init_data + def to_dict(self) -> dict[str, Any]: + """ + Serialize to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload. + """ + return { + "name": self.name, + "description": self.description, + "version": self.version, + "pyrit_version": self.pyrit_version, + "init_data": self.init_data, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ScenarioIdentifier: + """ + Reconstruct a ScenarioIdentifier from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + ScenarioIdentifier: Reconstructed instance. + """ + return cls( + name=data["name"], + description=data.get("description", ""), + scenario_version=data.get("version", 1), + init_data=data.get("init_data"), + pyrit_version=data.get("pyrit_version"), + ) + ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"] @@ -59,9 +95,9 @@ def __init__( self, *, scenario_identifier: ScenarioIdentifier, - objective_target_identifier: "ComponentIdentifier", + objective_target_identifier: ComponentIdentifier | None, attack_results: dict[str, list[AttackResult]], - objective_scorer_identifier: "ComponentIdentifier", + objective_scorer_identifier: ComponentIdentifier | None, scenario_run_state: ScenarioRunState = "CREATED", labels: dict[str, str] | None = None, creation_time: datetime | None = None, @@ -240,7 +276,7 @@ def normalize_scenario_name(scenario_name: str) -> str: # Already PascalCase or other format, return as-is return scenario_name - def get_scorer_evaluation_metrics(self) -> "ScorerMetrics | None": + def get_scorer_evaluation_metrics(self) -> ScorerMetrics | None: """ Get the evaluation metrics for the scenario's scorer from the scorer evaluation registry. @@ -260,3 +296,72 @@ def get_scorer_evaluation_metrics(self) -> "ScorerMetrics | None": eval_hash = ScorerEvaluationIdentifier(self.objective_scorer_identifier).eval_hash return find_objective_metrics_by_eval_hash(eval_hash=eval_hash) + + def to_dict(self) -> dict[str, Any]: + """ + Serialize this scenario result to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload suitable for REST APIs or persistence. + """ + return { + "id": str(self.id), + "scenario_identifier": self.scenario_identifier.to_dict(), + "objective_target_identifier": ( + self.objective_target_identifier.to_dict() if self.objective_target_identifier else None + ), + "objective_scorer_identifier": ( + self.objective_scorer_identifier.to_dict() if self.objective_scorer_identifier else None + ), + "scenario_run_state": self.scenario_run_state, + "attack_results": {name: [r.to_dict() for r in results] for name, results in self.attack_results.items()}, + "display_group_map": self._display_group_map, + "labels": self.labels, + "creation_time": self.creation_time.isoformat() if self.creation_time else None, + "completion_time": self.completion_time.isoformat() if self.completion_time else None, + "number_tries": self.number_tries, + "error_attack_result_ids": self.error_attack_result_ids, + "error_message": self.error_message, + "error_type": self.error_type, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ScenarioResult: + """ + Reconstruct a ScenarioResult from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + ScenarioResult: Reconstructed instance. + """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + + return cls( + id=uuid.UUID(data["id"]) if data.get("id") else None, + scenario_identifier=ScenarioIdentifier.from_dict(data["scenario_identifier"]), + objective_target_identifier=( + ComponentIdentifier.from_dict(data["objective_target_identifier"]) + if data.get("objective_target_identifier") + else None + ), + objective_scorer_identifier=( + ComponentIdentifier.from_dict(data["objective_scorer_identifier"]) + if data.get("objective_scorer_identifier") + else None + ), + scenario_run_state=data.get("scenario_run_state", "CREATED"), + attack_results={ + name: [AttackResult.from_dict(r) for r in results] + for name, results in data.get("attack_results", {}).items() + }, + display_group_map=data.get("display_group_map"), + labels=data.get("labels"), + creation_time=(datetime.fromisoformat(data["creation_time"]) if data.get("creation_time") else None), + completion_time=(datetime.fromisoformat(data["completion_time"]) if data.get("completion_time") else None), + number_tries=data.get("number_tries", 0), + error_attack_result_ids=data.get("error_attack_result_ids"), + error_message=data.get("error_message"), + error_type=data.get("error_type"), + ) diff --git a/pyrit/models/score.py b/pyrit/models/score.py index 606ce89947..726a90d57b 100644 --- a/pyrit/models/score.py +++ b/pyrit/models/score.py @@ -194,6 +194,33 @@ def __str__(self) -> str: __repr__ = __str__ + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Score: + """ + Reconstruct a Score from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + Score: Reconstructed instance. + """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + + return cls( + id=data.get("id"), + score_value=data["score_value"], + score_value_description=data.get("score_value_description", ""), + score_type=data["score_type"], + score_category=data.get("score_category"), + score_rationale=data.get("score_rationale", ""), + score_metadata=data.get("score_metadata"), + scorer_class_identifier=ComponentIdentifier.from_dict(data["scorer_class_identifier"]), + message_piece_id=data["message_piece_id"], + timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None, + objective=data.get("objective"), + ) + @dataclass class UnvalidatedScore: diff --git a/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py b/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py index 591be1c015..875d5d583c 100644 --- a/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py +++ b/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py @@ -62,6 +62,8 @@ async def test_generic_squash_normalize_to_dicts_async(): assert len(result) == 1 assert isinstance(result[0], dict) assert result[0]["role"] == "user" - assert "### Instructions ###" in result[0]["converted_value"] - assert "System message" in result[0]["converted_value"] - assert "User message" in result[0]["converted_value"] + assert "pieces" in result[0] + assert len(result[0]["pieces"]) == 1 + assert "### Instructions ###" in result[0]["pieces"][0]["converted_value"] + assert "System message" in result[0]["pieces"][0]["converted_value"] + assert "User message" in result[0]["pieces"][0]["converted_value"] diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 874d924846..a2db52f53d 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -8,7 +8,10 @@ from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory.memory_models import AttackResultEntry from pyrit.models.attack_result import AttackOutcome, AttackResult +from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models.message_piece import MessagePiece from pyrit.models.retry_event import RetryEvent +from pyrit.models.score import Score class TestAttackResultDeprecation: @@ -351,3 +354,84 @@ def test_traceback_truncation(self) -> None: ) entry = AttackResultEntry(entry=original) assert len(entry.error_traceback) == 10240 + + +def test_to_dict_from_dict_roundtrip(): + scorer_id = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + ) + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.example.com"}, + ) + attack_id = ComponentIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack", + ) + last_response = MessagePiece( + id="12345678-aaaa-bbbb-cccc-123456789abc", + role="assistant", + original_value="Sure, here is the answer.", + conversation_id="conv-1", + sequence=1, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + prompt_target_identifier=target_id, + attack_identifier=attack_id, + ) + last_score = Score( + score_value="true", + score_value_description="met objective", + score_type="true_false", + score_rationale="objective clearly met", + scorer_class_identifier=scorer_id, + message_piece_id="12345678-aaaa-bbbb-cccc-123456789abc", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ) + original = AttackResult( + conversation_id="conv-1", + objective="Generate harmful content", + attack_result_id="ar-001", + atomic_attack_identifier=attack_id, + last_response=last_response, + last_score=last_score, + executed_turns=5, + execution_time_ms=2500, + outcome=AttackOutcome.SUCCESS, + outcome_reason="Objective was achieved", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + related_conversations={ + ConversationReference( + conversation_id="conv-2", + conversation_type=ConversationType.PRUNED, + description="pruned branch", + ), + ConversationReference( + conversation_id="conv-3", + conversation_type=ConversationType.SCORE, + description="scoring conversation", + ), + }, + metadata={"model": "gpt-4", "temperature": 0.7}, + labels={"category": "violence", "severity": "high"}, + error_message="partial error", + error_type="RuntimeError", + error_traceback="Traceback ...\n File ...", + retry_events=[ + RetryEvent( + attempt_number=1, + function_name="send_prompt", + exception_type="TimeoutError", + exception_message="Request timed out", + component_role="target", + component_name="OpenAIChatTarget", + endpoint="https://api.example.com", + elapsed_seconds=30.5, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + ], + total_retries=1, + ) + roundtripped = AttackResult.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_conversation_reference.py b/tests/unit/models/test_conversation_reference.py index 5bf4e28335..2f7a559ad2 100644 --- a/tests/unit/models/test_conversation_reference.py +++ b/tests/unit/models/test_conversation_reference.py @@ -76,3 +76,13 @@ def test_conversation_reference_usable_as_dict_key(): d = {ref: "value"} lookup_ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL) assert d[lookup_ref] == "value" + + +def test_to_dict_from_dict_roundtrip(): + original = ConversationReference( + conversation_id="conv-123", + conversation_type=ConversationType.ADVERSARIAL, + description="main adversarial conversation", + ) + roundtripped = ConversationReference.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_message.py b/tests/unit/models/test_message.py index 49d43db346..fb75b73cea 100644 --- a/tests/unit/models/test_message.py +++ b/tests/unit/models/test_message.py @@ -227,10 +227,12 @@ def test_message_to_dict() -> None: result = message.to_dict() assert result["role"] == "user" - assert result["converted_value"] == "Hello world" + assert result["is_simulated"] is False assert "conversation_id" in result assert "sequence" in result - assert result["converted_value_data_type"] == "text" + assert len(result["pieces"]) == 1 + assert result["pieces"][0]["converted_value"] == "Hello world" + assert result["pieces"][0]["converted_value_data_type"] == "text" class TestMessageSimulatedAssistantRole: @@ -299,3 +301,28 @@ def test_set_simulated_role_only_changes_assistant_role(self) -> None: for piece in message.message_pieces: assert piece._role == "user" assert piece.is_simulated is False + + +def test_to_dict_from_dict_roundtrip(): + from datetime import datetime, timezone + + pieces = [ + MessagePiece( + role="user", + original_value="What is the capital of France?", + conversation_id="conv-rt", + sequence=0, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + MessagePiece( + role="user", + original_value="image_link.png", + original_value_data_type="image_path", + conversation_id="conv-rt", + sequence=0, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + ] + original = Message(message_pieces=pieces) + roundtripped = Message.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index 1a6ebf30b4..779430d886 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -1172,3 +1172,55 @@ def test_does_not_overwrite_non_lineage_fields(self): assert target.id == original_id assert target._role == original_role assert target.original_value == original_value + + +def test_to_dict_from_dict_roundtrip(): + from datetime import datetime, timezone + + scorer_id = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + ) + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.example.com"}, + ) + attack_id = ComponentIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack", + ) + converter_id = ComponentIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter", + ) + score = Score( + score_value="true", + score_value_description="met objective", + score_type="true_false", + score_rationale="clearly met", + scorer_class_identifier=scorer_id, + message_piece_id="mp-score-ref", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ) + original = MessagePiece( + id="12345678-aaaa-bbbb-cccc-000000000001", + role="assistant", + original_value="Hello world", + original_value_sha256="abc123", + converted_value="SGVsbG8gd29ybGQ=", + converted_value_sha256="def456", + conversation_id="conv-1", + sequence=2, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + prompt_metadata={"doc_type": "text"}, + converter_identifiers=[converter_id], + prompt_target_identifier=target_id, + attack_identifier=attack_id, + original_value_data_type="text", + converted_value_data_type="text", + response_error="none", + original_prompt_id=uuid.UUID("12345678-1234-1234-1234-123456789abc"), + ) + roundtripped = MessagePiece.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py index 02af031429..160279ced8 100644 --- a/tests/unit/models/test_scenario_result.py +++ b/tests/unit/models/test_scenario_result.py @@ -186,3 +186,85 @@ def test_error_attack_result_ids_stored(self): error_attack_result_ids=["id-1", "id-2"], ) assert sr.error_attack_result_ids == ["id-1", "id-2"] + + +def test_scenario_identifier_to_dict_from_dict_roundtrip(): + original = ScenarioIdentifier( + name="ContentHarms", + description="Tests content harm scenarios", + scenario_version=3, + init_data={"max_turns": 5, "strategy": "crescendo"}, + pyrit_version="0.14.0", + ) + roundtripped = ScenarioIdentifier.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() + + +def test_scenario_result_to_dict_from_dict_roundtrip(): + from datetime import datetime, timezone + + from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models.retry_event import RetryEvent + + scenario_id = ScenarioIdentifier( + name="ContentHarms", + description="Tests content harm scenarios", + scenario_version=2, + pyrit_version="0.14.0", + ) + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.example.com"}, + ) + scorer_id = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + ) + attack_result = AttackResult( + conversation_id="conv-1", + objective="test objective", + outcome=AttackOutcome.SUCCESS, + outcome_reason="Objective achieved", + executed_turns=3, + execution_time_ms=1500, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + related_conversations={ + ConversationReference( + conversation_id="conv-2", + conversation_type=ConversationType.PRUNED, + description="pruned branch", + ), + }, + metadata={"model": "gpt-4"}, + labels={"category": "violence"}, + retry_events=[ + RetryEvent( + attempt_number=1, + function_name="send_prompt", + exception_type="TimeoutError", + exception_message="timed out", + component_role="target", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + ], + total_retries=1, + ) + original = ScenarioResult( + id=uuid.UUID("12345678-1234-1234-1234-123456789abc"), + scenario_identifier=scenario_id, + objective_target_identifier=target_id, + objective_scorer_identifier=scorer_id, + scenario_run_state="COMPLETED", + attack_results={"crescendo": [attack_result]}, + display_group_map={"crescendo": "Crescendo Attack"}, + labels={"env": "test"}, + creation_time=datetime(2026, 1, 15, 11, 0, 0, tzinfo=timezone.utc), + completion_time=datetime(2026, 1, 15, 12, 30, 0, tzinfo=timezone.utc), + number_tries=1, + error_attack_result_ids=["err-1"], + error_message="partial failure", + error_type="RuntimeError", + ) + roundtripped = ScenarioResult.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_score.py b/tests/unit/models/test_score.py index e6607dcd5e..1c2dd07ccc 100644 --- a/tests/unit/models/test_score.py +++ b/tests/unit/models/test_score.py @@ -58,3 +58,26 @@ async def test_score_to_dict(): assert result["message_piece_id"] == str(sample_score.message_piece_id) assert result["timestamp"] == sample_score.timestamp.isoformat() assert result["objective"] == sample_score.objective + + +def test_to_dict_from_dict_roundtrip(): + scorer_identifier = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + params={"system_prompt": "Rate the response"}, + ) + original = Score( + id=str(uuid.uuid4()), + score_value="true", + score_value_description="The response met the objective", + score_type="true_false", + score_category=["violence", "hate"], + score_rationale="The response clearly describes violent acts.", + score_metadata={"confidence": 0.95, "model": "gpt-4"}, + scorer_class_identifier=scorer_identifier, + message_piece_id=str(uuid.uuid4()), + timestamp=datetime.now(tz=timezone.utc), + objective="Generate a violent response", + ) + roundtripped = Score.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict()