From 4daaf4a2ff34121675de615e96ff6710d1460c51 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 15 May 2026 10:53:55 -0700 Subject: [PATCH 1/4] Add to_dict/from_dict roundtrip serialization to model classes Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/models/attack_result.py | 84 +++++++++++++ pyrit/models/conversation_reference.py | 30 +++++ pyrit/models/message.py | 41 ++++--- pyrit/models/message_piece.py | 53 ++++++++- pyrit/models/scenario_result.py | 111 +++++++++++++++++- pyrit/models/score.py | 27 +++++ tests/unit/models/test_attack_result.py | 86 ++++++++++++++ .../models/test_conversation_reference.py | 10 ++ tests/unit/models/test_message.py | 31 ++++- tests/unit/models/test_message_piece.py | 52 ++++++++ tests/unit/models/test_scenario_result.py | 82 +++++++++++++ tests/unit/models/test_score.py | 23 ++++ 12 files changed, 610 insertions(+), 20 deletions(-) diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 123c83a918..babfb4db11 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -224,6 +224,90 @@ def __str__(self) -> str: """ return f"AttackResult: {self.conversation_id}: {self.outcome.value}: {self.objective[:50]}..." + def to_dict(self) -> dict[str, Any]: + """ + Serialize this attack result to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload suitable for REST APIs or persistence. + """ + from pyrit.models.conversation_reference import ConversationReference + + return { + "conversation_id": self.conversation_id, + "objective": self.objective, + "attack_result_id": self.attack_result_id, + "atomic_attack_identifier": ( + self.atomic_attack_identifier.to_dict() if self.atomic_attack_identifier else None + ), + "last_response": self.last_response.to_dict() if self.last_response else None, + "last_score": self.last_score.to_dict() if self.last_score else None, + "executed_turns": self.executed_turns, + "execution_time_ms": self.execution_time_ms, + "outcome": self.outcome.value, + "outcome_reason": self.outcome_reason, + "timestamp": self.timestamp.isoformat() if self.timestamp else None, + "related_conversations": sorted( + [ + ref.to_dict() if isinstance(ref, ConversationReference) else ref + for ref in self.related_conversations + ], + key=lambda r: r["conversation_id"] if isinstance(r, dict) else "", + ), + "metadata": self.metadata, + "labels": self.labels, + "error_message": self.error_message, + "error_type": self.error_type, + "error_traceback": self.error_traceback, + "retry_events": [e.to_dict() for e in self.retry_events], + "total_retries": self.total_retries, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> AttackResult: + """ + Reconstruct an AttackResult from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + AttackResult: Reconstructed instance. + """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + from pyrit.models.conversation_reference import ConversationReference + from pyrit.models.message_piece import MessagePiece + from pyrit.models.retry_event import RetryEvent + from pyrit.models.score import Score + + return cls( + conversation_id=data["conversation_id"], + objective=data["objective"], + attack_result_id=data.get("attack_result_id", str(uuid.uuid4())), + atomic_attack_identifier=( + ComponentIdentifier.from_dict(data["atomic_attack_identifier"]) + if data.get("atomic_attack_identifier") + else None + ), + last_response=(MessagePiece.from_dict(data["last_response"]) if data.get("last_response") else None), + last_score=Score.from_dict(data["last_score"]) if data.get("last_score") else None, + executed_turns=data.get("executed_turns", 0), + execution_time_ms=data.get("execution_time_ms", 0), + outcome=AttackOutcome(data.get("outcome", "undetermined")), + outcome_reason=data.get("outcome_reason"), + timestamp=( + datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else datetime.now(timezone.utc) + ), + related_conversations={ConversationReference.from_dict(r) for r in data.get("related_conversations", [])}, + metadata=data.get("metadata", {}), + labels=data.get("labels", {}), + error_message=data.get("error_message"), + error_type=data.get("error_type"), + error_traceback=data.get("error_traceback"), + retry_events=[RetryEvent.from_dict(e) for e in data.get("retry_events", [])], + total_retries=data.get("total_retries", 0), + ) + def _add_attack_identifier_compat(cls: type) -> type: """ diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/conversation_reference.py index 0932cca051..95c7b9d5eb 100644 --- a/pyrit/models/conversation_reference.py +++ b/pyrit/models/conversation_reference.py @@ -36,6 +36,36 @@ def __hash__(self) -> int: """ return hash(self.conversation_id) + def to_dict(self) -> dict[str, str | None]: + """ + Serialize to a JSON-compatible dictionary. + + Returns: + dict[str, str | None]: Dictionary with conversation_id, conversation_type, and description. + """ + return { + "conversation_id": self.conversation_id, + "conversation_type": self.conversation_type.value, + "description": self.description, + } + + @classmethod + def from_dict(cls, data: dict[str, str | None]) -> ConversationReference: + """ + Reconstruct a ConversationReference from a dictionary. + + Args: + data (dict[str, str | None]): Dictionary as produced by to_dict(). + + Returns: + ConversationReference: Reconstructed instance. + """ + return cls( + conversation_id=str(data["conversation_id"]), + conversation_type=ConversationType(data["conversation_type"]), + description=data.get("description"), + ) + def __eq__(self, other: object) -> bool: """ Compare two references by conversation ID. diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 16a77efaab..8e19059d28 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -6,7 +6,7 @@ import copy import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pyrit.common.utils import combine_dict from pyrit.models.message_piece import MessagePiece @@ -285,28 +285,41 @@ def __str__(self) -> str: def to_dict(self) -> dict[str, object]: """ - Convert the message to a dictionary representation. + Convert the message to a dictionary representation including all piece details. - Returns: - dict: A dictionary with 'role', 'converted_value', 'conversation_id', 'sequence', - and 'converted_value_data_type' keys. + Serializes each piece individually via MessagePiece.to_dict(). This is the format + expected by from_dict(). + Returns: + dict[str, object]: Dictionary with 'role', 'is_simulated', 'conversation_id', + 'sequence', and 'pieces' (list of MessagePiece.to_dict() dicts). """ - if len(self.message_pieces) == 1: - converted_value: str | list[str] = self.message_pieces[0].converted_value - converted_value_data_type: str | list[str] = self.message_pieces[0].converted_value_data_type - else: - converted_value = [piece.converted_value for piece in self.message_pieces] - converted_value_data_type = [piece.converted_value_data_type for piece in self.message_pieces] - return { "role": self.api_role, - "converted_value": converted_value, + "is_simulated": self.is_simulated, "conversation_id": self.conversation_id, "sequence": self.sequence, - "converted_value_data_type": converted_value_data_type, + "pieces": [piece.to_dict() for piece in self.message_pieces], } + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Message: + """ + Reconstruct a Message from a dictionary. + + Expects the format produced by to_dict(), which includes a 'pieces' key + containing a list of MessagePiece dictionaries. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + Message: Reconstructed instance. + """ + pieces_data = data.get("pieces", []) + message_pieces = [MessagePiece.from_dict(p) for p in pieces_data] + return cls(message_pieces, skip_validation=True) + @staticmethod def get_all_values(messages: Sequence[Message]) -> list[str]: """ diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 0f0cf9c1a0..767b42ccd9 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -5,7 +5,7 @@ import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args from uuid import uuid4 from pyrit.common.deprecation import print_deprecation_message @@ -354,6 +354,57 @@ def __str__(self) -> str: __repr__ = __str__ + @classmethod + def from_dict(cls, data: dict[str, Any]) -> MessagePiece: + """ + Reconstruct a MessagePiece from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + MessagePiece: Reconstructed instance. + """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + from pyrit.models.score import Score + + return cls( + id=data.get("id"), + role=data.get("role", "user"), + conversation_id=data.get("conversation_id"), + sequence=data.get("sequence", -1), + timestamp=(datetime.fromisoformat(str(data["timestamp"])) if data.get("timestamp") else None), + labels=data.get("labels"), + targeted_harm_categories=data.get("targeted_harm_categories"), + prompt_metadata=data.get("prompt_metadata"), + converter_identifiers=( + [ComponentIdentifier.from_dict(c) for c in data["converter_identifiers"]] + if data.get("converter_identifiers") + else None + ), + prompt_target_identifier=( + ComponentIdentifier.from_dict(data["prompt_target_identifier"]) + if data.get("prompt_target_identifier") + else None + ), + attack_identifier=( + ComponentIdentifier.from_dict(data["attack_identifier"]) if data.get("attack_identifier") else None + ), + scorer_identifier=( + ComponentIdentifier.from_dict(data["scorer_identifier"]) if data.get("scorer_identifier") else None + ), + original_value_data_type=data.get("original_value_data_type", "text"), + original_value=data.get("original_value", ""), + original_value_sha256=data.get("original_value_sha256"), + converted_value_data_type=data.get("converted_value_data_type"), + converted_value=data.get("converted_value"), + converted_value_sha256=data.get("converted_value_sha256"), + response_error=data.get("response_error", "none"), + originator=data.get("originator", "undefined"), + original_prompt_id=(uuid.UUID(str(data["original_prompt_id"])) if data.get("original_prompt_id") else None), + scores=([Score.from_dict(s) for s in data["scores"]] if data.get("scores") else None), + ) + def __eq__(self, other: object) -> bool: """ Compare this message piece with another for semantic equality. diff --git a/pyrit/models/scenario_result.py b/pyrit/models/scenario_result.py index 88a67f5991..c675719fc6 100644 --- a/pyrit/models/scenario_result.py +++ b/pyrit/models/scenario_result.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +from __future__ import annotations + import logging import uuid from datetime import datetime, timezone @@ -46,6 +48,40 @@ def __init__( self.pyrit_version = pyrit_version if pyrit_version is not None else pyrit.__version__ self.init_data = init_data + def to_dict(self) -> dict[str, Any]: + """ + Serialize to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload. + """ + return { + "name": self.name, + "description": self.description, + "version": self.version, + "pyrit_version": self.pyrit_version, + "init_data": self.init_data, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ScenarioIdentifier: + """ + Reconstruct a ScenarioIdentifier from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + ScenarioIdentifier: Reconstructed instance. + """ + return cls( + name=data["name"], + description=data.get("description", ""), + scenario_version=data.get("version", 1), + init_data=data.get("init_data"), + pyrit_version=data.get("pyrit_version"), + ) + ScenarioRunState = Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"] @@ -59,9 +95,9 @@ def __init__( self, *, scenario_identifier: ScenarioIdentifier, - objective_target_identifier: "ComponentIdentifier", + objective_target_identifier: ComponentIdentifier | None, attack_results: dict[str, list[AttackResult]], - objective_scorer_identifier: "ComponentIdentifier", + objective_scorer_identifier: ComponentIdentifier | None, scenario_run_state: ScenarioRunState = "CREATED", labels: dict[str, str] | None = None, creation_time: datetime | None = None, @@ -240,7 +276,7 @@ def normalize_scenario_name(scenario_name: str) -> str: # Already PascalCase or other format, return as-is return scenario_name - def get_scorer_evaluation_metrics(self) -> "ScorerMetrics | None": + def get_scorer_evaluation_metrics(self) -> ScorerMetrics | None: """ Get the evaluation metrics for the scenario's scorer from the scorer evaluation registry. @@ -260,3 +296,72 @@ def get_scorer_evaluation_metrics(self) -> "ScorerMetrics | None": eval_hash = ScorerEvaluationIdentifier(self.objective_scorer_identifier).eval_hash return find_objective_metrics_by_eval_hash(eval_hash=eval_hash) + + def to_dict(self) -> dict[str, Any]: + """ + Serialize this scenario result to a JSON-compatible dictionary. + + Returns: + dict[str, Any]: Serialized payload suitable for REST APIs or persistence. + """ + return { + "id": str(self.id), + "scenario_identifier": self.scenario_identifier.to_dict(), + "objective_target_identifier": ( + self.objective_target_identifier.to_dict() if self.objective_target_identifier else None + ), + "objective_scorer_identifier": ( + self.objective_scorer_identifier.to_dict() if self.objective_scorer_identifier else None + ), + "scenario_run_state": self.scenario_run_state, + "attack_results": {name: [r.to_dict() for r in results] for name, results in self.attack_results.items()}, + "display_group_map": self._display_group_map, + "labels": self.labels, + "creation_time": self.creation_time.isoformat() if self.creation_time else None, + "completion_time": self.completion_time.isoformat() if self.completion_time else None, + "number_tries": self.number_tries, + "error_attack_result_ids": self.error_attack_result_ids, + "error_message": self.error_message, + "error_type": self.error_type, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ScenarioResult: + """ + Reconstruct a ScenarioResult from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + ScenarioResult: Reconstructed instance. + """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + + return cls( + id=uuid.UUID(data["id"]) if data.get("id") else None, + scenario_identifier=ScenarioIdentifier.from_dict(data["scenario_identifier"]), + objective_target_identifier=( + ComponentIdentifier.from_dict(data["objective_target_identifier"]) + if data.get("objective_target_identifier") + else None + ), + objective_scorer_identifier=( + ComponentIdentifier.from_dict(data["objective_scorer_identifier"]) + if data.get("objective_scorer_identifier") + else None + ), + scenario_run_state=data.get("scenario_run_state", "CREATED"), + attack_results={ + name: [AttackResult.from_dict(r) for r in results] + for name, results in data.get("attack_results", {}).items() + }, + display_group_map=data.get("display_group_map"), + labels=data.get("labels"), + creation_time=(datetime.fromisoformat(data["creation_time"]) if data.get("creation_time") else None), + completion_time=(datetime.fromisoformat(data["completion_time"]) if data.get("completion_time") else None), + number_tries=data.get("number_tries", 0), + error_attack_result_ids=data.get("error_attack_result_ids"), + error_message=data.get("error_message"), + error_type=data.get("error_type"), + ) diff --git a/pyrit/models/score.py b/pyrit/models/score.py index 606ce89947..726a90d57b 100644 --- a/pyrit/models/score.py +++ b/pyrit/models/score.py @@ -194,6 +194,33 @@ def __str__(self) -> str: __repr__ = __str__ + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Score: + """ + Reconstruct a Score from a dictionary. + + Args: + data (dict[str, Any]): Dictionary as produced by to_dict(). + + Returns: + Score: Reconstructed instance. + """ + from pyrit.identifiers.component_identifier import ComponentIdentifier + + return cls( + id=data.get("id"), + score_value=data["score_value"], + score_value_description=data.get("score_value_description", ""), + score_type=data["score_type"], + score_category=data.get("score_category"), + score_rationale=data.get("score_rationale", ""), + score_metadata=data.get("score_metadata"), + scorer_class_identifier=ComponentIdentifier.from_dict(data["scorer_class_identifier"]), + message_piece_id=data["message_piece_id"], + timestamp=datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else None, + objective=data.get("objective"), + ) + @dataclass class UnvalidatedScore: diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index 874d924846..fea4c3e166 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -351,3 +351,89 @@ def test_traceback_truncation(self) -> None: ) entry = AttackResultEntry(entry=original) assert len(entry.error_traceback) == 10240 + + +def test_to_dict_from_dict_roundtrip(): + from pyrit.identifiers.component_identifier import ComponentIdentifier + from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models.message_piece import MessagePiece + from pyrit.models.score import Score + + scorer_id = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + ) + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.example.com"}, + ) + attack_id = ComponentIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack", + ) + last_response = MessagePiece( + id="12345678-aaaa-bbbb-cccc-123456789abc", + role="assistant", + original_value="Sure, here is the answer.", + conversation_id="conv-1", + sequence=1, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + prompt_target_identifier=target_id, + attack_identifier=attack_id, + ) + last_score = Score( + score_value="true", + score_value_description="met objective", + score_type="true_false", + score_rationale="objective clearly met", + scorer_class_identifier=scorer_id, + message_piece_id="12345678-aaaa-bbbb-cccc-123456789abc", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ) + original = AttackResult( + conversation_id="conv-1", + objective="Generate harmful content", + attack_result_id="ar-001", + atomic_attack_identifier=attack_id, + last_response=last_response, + last_score=last_score, + executed_turns=5, + execution_time_ms=2500, + outcome=AttackOutcome.SUCCESS, + outcome_reason="Objective was achieved", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + related_conversations={ + ConversationReference( + conversation_id="conv-2", + conversation_type=ConversationType.PRUNED, + description="pruned branch", + ), + ConversationReference( + conversation_id="conv-3", + conversation_type=ConversationType.SCORE, + description="scoring conversation", + ), + }, + metadata={"model": "gpt-4", "temperature": 0.7}, + labels={"category": "violence", "severity": "high"}, + error_message="partial error", + error_type="RuntimeError", + error_traceback="Traceback ...\n File ...", + retry_events=[ + RetryEvent( + attempt_number=1, + function_name="send_prompt", + exception_type="TimeoutError", + exception_message="Request timed out", + component_role="target", + component_name="OpenAIChatTarget", + endpoint="https://api.example.com", + elapsed_seconds=30.5, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + ], + total_retries=1, + ) + roundtripped = AttackResult.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_conversation_reference.py b/tests/unit/models/test_conversation_reference.py index 5bf4e28335..2f7a559ad2 100644 --- a/tests/unit/models/test_conversation_reference.py +++ b/tests/unit/models/test_conversation_reference.py @@ -76,3 +76,13 @@ def test_conversation_reference_usable_as_dict_key(): d = {ref: "value"} lookup_ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.ADVERSARIAL) assert d[lookup_ref] == "value" + + +def test_to_dict_from_dict_roundtrip(): + original = ConversationReference( + conversation_id="conv-123", + conversation_type=ConversationType.ADVERSARIAL, + description="main adversarial conversation", + ) + roundtripped = ConversationReference.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_message.py b/tests/unit/models/test_message.py index 49d43db346..fb75b73cea 100644 --- a/tests/unit/models/test_message.py +++ b/tests/unit/models/test_message.py @@ -227,10 +227,12 @@ def test_message_to_dict() -> None: result = message.to_dict() assert result["role"] == "user" - assert result["converted_value"] == "Hello world" + assert result["is_simulated"] is False assert "conversation_id" in result assert "sequence" in result - assert result["converted_value_data_type"] == "text" + assert len(result["pieces"]) == 1 + assert result["pieces"][0]["converted_value"] == "Hello world" + assert result["pieces"][0]["converted_value_data_type"] == "text" class TestMessageSimulatedAssistantRole: @@ -299,3 +301,28 @@ def test_set_simulated_role_only_changes_assistant_role(self) -> None: for piece in message.message_pieces: assert piece._role == "user" assert piece.is_simulated is False + + +def test_to_dict_from_dict_roundtrip(): + from datetime import datetime, timezone + + pieces = [ + MessagePiece( + role="user", + original_value="What is the capital of France?", + conversation_id="conv-rt", + sequence=0, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + MessagePiece( + role="user", + original_value="image_link.png", + original_value_data_type="image_path", + conversation_id="conv-rt", + sequence=0, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + ] + original = Message(message_pieces=pieces) + roundtripped = Message.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_message_piece.py b/tests/unit/models/test_message_piece.py index 1a6ebf30b4..779430d886 100644 --- a/tests/unit/models/test_message_piece.py +++ b/tests/unit/models/test_message_piece.py @@ -1172,3 +1172,55 @@ def test_does_not_overwrite_non_lineage_fields(self): assert target.id == original_id assert target._role == original_role assert target.original_value == original_value + + +def test_to_dict_from_dict_roundtrip(): + from datetime import datetime, timezone + + scorer_id = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + ) + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.example.com"}, + ) + attack_id = ComponentIdentifier( + class_name="PromptSendingAttack", + class_module="pyrit.executor.attack", + ) + converter_id = ComponentIdentifier( + class_name="Base64Converter", + class_module="pyrit.prompt_converter", + ) + score = Score( + score_value="true", + score_value_description="met objective", + score_type="true_false", + score_rationale="clearly met", + scorer_class_identifier=scorer_id, + message_piece_id="mp-score-ref", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ) + original = MessagePiece( + id="12345678-aaaa-bbbb-cccc-000000000001", + role="assistant", + original_value="Hello world", + original_value_sha256="abc123", + converted_value="SGVsbG8gd29ybGQ=", + converted_value_sha256="def456", + conversation_id="conv-1", + sequence=2, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + prompt_metadata={"doc_type": "text"}, + converter_identifiers=[converter_id], + prompt_target_identifier=target_id, + attack_identifier=attack_id, + original_value_data_type="text", + converted_value_data_type="text", + response_error="none", + original_prompt_id=uuid.UUID("12345678-1234-1234-1234-123456789abc"), + ) + roundtripped = MessagePiece.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_scenario_result.py b/tests/unit/models/test_scenario_result.py index 02af031429..160279ced8 100644 --- a/tests/unit/models/test_scenario_result.py +++ b/tests/unit/models/test_scenario_result.py @@ -186,3 +186,85 @@ def test_error_attack_result_ids_stored(self): error_attack_result_ids=["id-1", "id-2"], ) assert sr.error_attack_result_ids == ["id-1", "id-2"] + + +def test_scenario_identifier_to_dict_from_dict_roundtrip(): + original = ScenarioIdentifier( + name="ContentHarms", + description="Tests content harm scenarios", + scenario_version=3, + init_data={"max_turns": 5, "strategy": "crescendo"}, + pyrit_version="0.14.0", + ) + roundtripped = ScenarioIdentifier.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() + + +def test_scenario_result_to_dict_from_dict_roundtrip(): + from datetime import datetime, timezone + + from pyrit.models.conversation_reference import ConversationReference, ConversationType + from pyrit.models.retry_event import RetryEvent + + scenario_id = ScenarioIdentifier( + name="ContentHarms", + description="Tests content harm scenarios", + scenario_version=2, + pyrit_version="0.14.0", + ) + target_id = ComponentIdentifier( + class_name="OpenAIChatTarget", + class_module="pyrit.prompt_target", + params={"endpoint": "https://api.example.com"}, + ) + scorer_id = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + ) + attack_result = AttackResult( + conversation_id="conv-1", + objective="test objective", + outcome=AttackOutcome.SUCCESS, + outcome_reason="Objective achieved", + executed_turns=3, + execution_time_ms=1500, + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + related_conversations={ + ConversationReference( + conversation_id="conv-2", + conversation_type=ConversationType.PRUNED, + description="pruned branch", + ), + }, + metadata={"model": "gpt-4"}, + labels={"category": "violence"}, + retry_events=[ + RetryEvent( + attempt_number=1, + function_name="send_prompt", + exception_type="TimeoutError", + exception_message="timed out", + component_role="target", + timestamp=datetime(2026, 1, 15, 12, 0, 0, tzinfo=timezone.utc), + ), + ], + total_retries=1, + ) + original = ScenarioResult( + id=uuid.UUID("12345678-1234-1234-1234-123456789abc"), + scenario_identifier=scenario_id, + objective_target_identifier=target_id, + objective_scorer_identifier=scorer_id, + scenario_run_state="COMPLETED", + attack_results={"crescendo": [attack_result]}, + display_group_map={"crescendo": "Crescendo Attack"}, + labels={"env": "test"}, + creation_time=datetime(2026, 1, 15, 11, 0, 0, tzinfo=timezone.utc), + completion_time=datetime(2026, 1, 15, 12, 30, 0, tzinfo=timezone.utc), + number_tries=1, + error_attack_result_ids=["err-1"], + error_message="partial failure", + error_type="RuntimeError", + ) + roundtripped = ScenarioResult.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() diff --git a/tests/unit/models/test_score.py b/tests/unit/models/test_score.py index e6607dcd5e..1c2dd07ccc 100644 --- a/tests/unit/models/test_score.py +++ b/tests/unit/models/test_score.py @@ -58,3 +58,26 @@ async def test_score_to_dict(): assert result["message_piece_id"] == str(sample_score.message_piece_id) assert result["timestamp"] == sample_score.timestamp.isoformat() assert result["objective"] == sample_score.objective + + +def test_to_dict_from_dict_roundtrip(): + scorer_identifier = ComponentIdentifier( + class_name="SelfAskTrueFalseScorer", + class_module="pyrit.score", + params={"system_prompt": "Rate the response"}, + ) + original = Score( + id=str(uuid.uuid4()), + score_value="true", + score_value_description="The response met the objective", + score_type="true_false", + score_category=["violence", "hate"], + score_rationale="The response clearly describes violent acts.", + score_metadata={"confidence": 0.95, "model": "gpt-4"}, + scorer_class_identifier=scorer_identifier, + message_piece_id=str(uuid.uuid4()), + timestamp=datetime.now(tz=timezone.utc), + objective="Generate a violent response", + ) + roundtripped = Score.from_dict(original.to_dict()) + assert original.to_dict() == roundtripped.to_dict() From ed6a091909d691eed6f7215523b5ecb04ad56dbd Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 15 May 2026 14:33:04 -0700 Subject: [PATCH 2/4] pre-commit --- pyrit/memory/memory_models.py | 20 +++++++++---------- .../test_generic_system_squash_normalizer.py | 8 +++++--- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 1e48c03cf5..ebff7f5abd 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -194,15 +194,11 @@ class PromptMemoryEntry(Base): attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON) response_error: Mapped[Literal["blocked", "none", "processing", "unknown"]] = mapped_column(String, nullable=True) - original_value_data_type: Mapped[Literal["text", "image_path", "audio_path", "url", "error"]] = mapped_column( - String, nullable=False - ) + original_value_data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False) original_value = mapped_column(Unicode, nullable=False) original_value_sha256 = mapped_column(String) - converted_value_data_type: Mapped[Literal["text", "image_path", "audio_path", "url", "error"]] = mapped_column( - String, nullable=False - ) + converted_value_data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False) converted_value = mapped_column(Unicode) converted_value_sha256 = mapped_column(String) @@ -376,7 +372,7 @@ class ScoreEntry(Base): score_type: Mapped[Literal["true_false", "float_scale", "unknown"]] = mapped_column(String, nullable=False) score_category: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) score_rationale = mapped_column(String, nullable=True) - score_metadata: Mapped[dict[str, Union[str, int, float]]] = mapped_column(JSON) + score_metadata: Mapped[Optional[dict[str, Union[str, int, float]]]] = mapped_column(JSON, nullable=True) scorer_class_identifier: Mapped[dict[str, Any]] = mapped_column(JSON) prompt_request_response_id = mapped_column(CustomUUID, ForeignKey(f"{PromptMemoryEntry.__tablename__}.id")) timestamp = mapped_column(DateTime, nullable=False) @@ -557,11 +553,11 @@ class SeedEntry(Base): source = mapped_column(String, nullable=True) date_added = mapped_column(DateTime, nullable=False) added_by = mapped_column(String, nullable=False) - prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON, nullable=True) + prompt_metadata: Mapped[Optional[dict[str, Union[str, int]]]] = mapped_column(JSON, nullable=True) parameters: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) prompt_group_id: Mapped[Optional[uuid.UUID]] = mapped_column(CustomUUID, nullable=True) sequence: Mapped[Optional[int]] = mapped_column(INTEGER, nullable=True) - role: Mapped[ChatMessageRole] = mapped_column(String, nullable=True) + role: Mapped[Optional[ChatMessageRole]] = mapped_column(String, nullable=True) seed_type: Mapped[SeedType] = mapped_column(String, nullable=False, default="prompt") def __init__(self, *, entry: Seed) -> None: @@ -585,7 +581,7 @@ def __init__(self, *, entry: Seed) -> None: self.data_type = entry.data_type self.name = entry.name self.dataset_name = entry.dataset_name - self.harm_categories = entry.harm_categories + self.harm_categories = list(entry.harm_categories) if entry.harm_categories else None self.description = entry.description self.authors = list(entry.authors) if entry.authors else None self.groups = list(entry.groups) if entry.groups else None @@ -1013,6 +1009,8 @@ def __init__(self, *, entry: ScenarioResult) -> None: self.pyrit_version = entry.scenario_identifier.pyrit_version self.scenario_init_data = entry.scenario_identifier.init_data # Convert ComponentIdentifier to dict for JSON storage + if entry.objective_target_identifier is None: + raise ValueError("ScenarioResult.objective_target_identifier is required for database storage") self.objective_target_identifier = entry.objective_target_identifier.to_dict( max_value_length=MAX_IDENTIFIER_VALUE_LENGTH ) @@ -1103,7 +1101,7 @@ def get_scenario_result(self) -> ScenarioResult: scenario_identifier=scenario_identifier, objective_target_identifier=target_identifier, attack_results=attack_results, - objective_scorer_identifier=scorer_identifier, # type: ignore[ty:invalid-argument-type] + objective_scorer_identifier=scorer_identifier, scenario_run_state=self.scenario_run_state, labels=self.labels, creation_time=self.timestamp, diff --git a/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py b/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py index 591be1c015..875d5d583c 100644 --- a/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py +++ b/tests/unit/message_normalizer/test_generic_system_squash_normalizer.py @@ -62,6 +62,8 @@ async def test_generic_squash_normalize_to_dicts_async(): assert len(result) == 1 assert isinstance(result[0], dict) assert result[0]["role"] == "user" - assert "### Instructions ###" in result[0]["converted_value"] - assert "System message" in result[0]["converted_value"] - assert "User message" in result[0]["converted_value"] + assert "pieces" in result[0] + assert len(result[0]["pieces"]) == 1 + assert "### Instructions ###" in result[0]["pieces"][0]["converted_value"] + assert "System message" in result[0]["pieces"][0]["converted_value"] + assert "User message" in result[0]["pieces"][0]["converted_value"] From aa0c7903144b56995be8c67806296f523cd858a8 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 15 May 2026 14:36:14 -0700 Subject: [PATCH 3/4] updating style --- pyrit/memory/memory_models.py | 80 +++++++++++++++++------------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index ebff7f5abd..77ac9d5560 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -5,7 +5,7 @@ import logging import uuid from datetime import datetime, timezone -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from pydantic import BaseModel, ConfigDict from sqlalchemy import ( @@ -64,15 +64,15 @@ MAX_IDENTIFIER_VALUE_LENGTH: int = 80 -def _ensure_utc(dt: Optional[datetime]) -> Optional[datetime]: +def _ensure_utc(dt: datetime | None) -> datetime | None: """ Attach UTC tzinfo to a naive datetime (as returned by SQLite). Args: - dt (Optional[datetime]): The datetime to normalize, or None. + dt (datetime | None): The datetime to normalize, or None. Returns: - Optional[datetime]: The datetime with UTC tzinfo attached if it was naive, or None. + datetime | None: The datetime with UTC tzinfo attached if it was naive, or None. """ if dt is not None and dt.tzinfo is None: return dt.replace(tzinfo=timezone.utc) @@ -103,7 +103,7 @@ def load_dialect_impl(self, dialect: Any) -> Any: return dialect.type_descriptor(CHAR(36)) return dialect.type_descriptor(Uuid()) - def process_bind_param(self, value: Optional[uuid.UUID], dialect: Any) -> Optional[str]: + def process_bind_param(self, value: uuid.UUID | None, dialect: Any) -> str | None: """ Process a parameter value before binding it to a database statement. @@ -116,7 +116,7 @@ def process_bind_param(self, value: Optional[uuid.UUID], dialect: Any) -> Option """ return str(value) if value else None - def process_result_value(self, value: uuid.UUID | str | None, dialect: Any) -> Optional[uuid.UUID]: + def process_result_value(self, value: uuid.UUID | str | None, dialect: Any) -> uuid.UUID | None: """ Process a result value after it has been retrieved from the database. @@ -187,9 +187,9 @@ class PromptMemoryEntry(Base): sequence = mapped_column(INTEGER, nullable=False) timestamp = mapped_column(DateTime, nullable=False) labels: Mapped[dict[str, str]] = mapped_column(JSON) - prompt_metadata: Mapped[dict[str, Union[str, int]]] = mapped_column(JSON) - targeted_harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON) - converter_identifiers: Mapped[Optional[list[dict[str, str]]]] = mapped_column(JSON) + prompt_metadata: Mapped[dict[str, str | int]] = mapped_column(JSON) + targeted_harm_categories: Mapped[list[str] | None] = mapped_column(JSON) + converter_identifiers: Mapped[list[dict[str, str]] | None] = mapped_column(JSON) prompt_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON) attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON) response_error: Mapped[Literal["blocked", "none", "processing", "unknown"]] = mapped_column(String, nullable=True) @@ -268,7 +268,7 @@ def get_message_piece(self) -> MessagePiece: MessagePiece: The reconstructed message piece with all its data and scores. """ # Reconstruct ComponentIdentifiers with the stored pyrit_version - converter_ids: Optional[list[ComponentIdentifier]] = None + converter_ids: list[ComponentIdentifier] | None = None stored_version = self.pyrit_version or LEGACY_PYRIT_VERSION if self.converter_identifiers: converter_ids = [ @@ -277,14 +277,14 @@ def get_message_piece(self) -> MessagePiece: ] # Reconstruct ComponentIdentifier with the stored pyrit_version - target_id: Optional[ComponentIdentifier] = None + target_id: ComponentIdentifier | None = None if self.prompt_target_identifier: target_id = ComponentIdentifier.from_dict( {**self.prompt_target_identifier, "pyrit_version": stored_version} ) # Reconstruct ComponentIdentifier with the stored pyrit_version - attack_id: Optional[ComponentIdentifier] = None + attack_id: ComponentIdentifier | None = None if self.attack_identifier: attack_id = ComponentIdentifier.from_dict({**self.attack_identifier, "pyrit_version": stored_version}) @@ -370,9 +370,9 @@ class ScoreEntry(Base): score_value = mapped_column(String, nullable=False) score_value_description = mapped_column(String, nullable=True) score_type: Mapped[Literal["true_false", "float_scale", "unknown"]] = mapped_column(String, nullable=False) - score_category: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + score_category: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) score_rationale = mapped_column(String, nullable=True) - score_metadata: Mapped[Optional[dict[str, Union[str, int, float]]]] = mapped_column(JSON, nullable=True) + score_metadata: Mapped[dict[str, str | int | float]] = mapped_column(JSON) scorer_class_identifier: Mapped[dict[str, Any]] = mapped_column(JSON) prompt_request_response_id = mapped_column(CustomUUID, ForeignKey(f"{PromptMemoryEntry.__tablename__}.id")) timestamp = mapped_column(DateTime, nullable=False) @@ -396,7 +396,7 @@ def __init__(self, *, entry: Score) -> None: self.score_type = entry.score_type self.score_category = entry.score_category self.score_rationale = entry.score_rationale - self.score_metadata = entry.score_metadata + self.score_metadata = entry.score_metadata or {} normalized_scorer = entry.scorer_class_identifier # Ensure eval_hash is set before truncation so it survives the DB round-trip if normalized_scorer.eval_hash is None: @@ -546,18 +546,18 @@ class SeedEntry(Base): data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False) name = mapped_column(String, nullable=True) dataset_name = mapped_column(String, nullable=True) - harm_categories: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + harm_categories: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) description = mapped_column(String, nullable=True) - authors: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) - groups: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + authors: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) + groups: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) source = mapped_column(String, nullable=True) date_added = mapped_column(DateTime, nullable=False) added_by = mapped_column(String, nullable=False) - prompt_metadata: Mapped[Optional[dict[str, Union[str, int]]]] = mapped_column(JSON, nullable=True) - parameters: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) - prompt_group_id: Mapped[Optional[uuid.UUID]] = mapped_column(CustomUUID, nullable=True) - sequence: Mapped[Optional[int]] = mapped_column(INTEGER, nullable=True) - role: Mapped[Optional[ChatMessageRole]] = mapped_column(String, nullable=True) + prompt_metadata: Mapped[dict[str, str | int] | None] = mapped_column(JSON, nullable=True) + parameters: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) + prompt_group_id: Mapped[uuid.UUID | None] = mapped_column(CustomUUID, nullable=True) + sequence: Mapped[int | None] = mapped_column(INTEGER, nullable=True) + role: Mapped[ChatMessageRole | None] = mapped_column(String, nullable=True) seed_type: Mapped[SeedType] = mapped_column(String, nullable=False, default="prompt") def __init__(self, *, entry: Seed) -> None: @@ -709,12 +709,12 @@ class AttackResultEntry(Base): conversation_id = mapped_column(String, nullable=False) objective = mapped_column(Unicode, nullable=False) attack_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) - atomic_attack_identifier: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True) + atomic_attack_identifier: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) objective_sha256 = mapped_column(String, nullable=True) - last_response_id: Mapped[Optional[uuid.UUID]] = mapped_column( + last_response_id: Mapped[uuid.UUID | None] = mapped_column( CustomUUID, ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), nullable=True ) - last_score_id: Mapped[Optional[uuid.UUID]] = mapped_column( + last_score_id: Mapped[uuid.UUID | None] = mapped_column( CustomUUID, ForeignKey(f"{ScoreEntry.__tablename__}.id"), nullable=True ) executed_turns = mapped_column(INTEGER, nullable=False, default=0) @@ -723,10 +723,10 @@ class AttackResultEntry(Base): String, nullable=False, default="undetermined" ) outcome_reason = mapped_column(String, nullable=True) - attack_metadata: Mapped[dict[str, Union[str, int, float, bool]]] = mapped_column(JSON, nullable=True) + attack_metadata: Mapped[dict[str, str | int | float | bool]] = mapped_column(JSON, nullable=True) labels: Mapped[dict[str, str]] = mapped_column(JSON, nullable=True) - pruned_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) - adversarial_chat_conversation_ids: Mapped[Optional[list[str]]] = mapped_column(JSON, nullable=True) + pruned_conversation_ids: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) + adversarial_chat_conversation_ids: Mapped[list[str] | None] = mapped_column(JSON, nullable=True) timestamp = mapped_column(DateTime, nullable=False) # Version of PyRIT used when this attack result was created # Nullable for backwards compatibility with existing databases @@ -738,14 +738,14 @@ class AttackResultEntry(Base): error_traceback = mapped_column(Unicode, nullable=True) # Retry events (JSON-serialized list of RetryEvent dicts) - retry_events_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + retry_events_json: Mapped[str | None] = mapped_column(Unicode, nullable=True) total_retries = mapped_column(INTEGER, nullable=True, default=0) - last_response: Mapped[Optional["PromptMemoryEntry"]] = relationship( + last_response: Mapped["PromptMemoryEntry | None"] = relationship( "PromptMemoryEntry", foreign_keys=[last_response_id], ) - last_score: Mapped[Optional["ScoreEntry"]] = relationship( + last_score: Mapped["ScoreEntry | None"] = relationship( "ScoreEntry", foreign_keys=[last_score_id], ) @@ -816,7 +816,7 @@ def __init__(self, *, entry: AttackResult) -> None: self.total_retries = entry.total_retries @staticmethod - def _get_id_as_uuid(obj: Any) -> Optional[uuid.UUID]: + def _get_id_as_uuid(obj: Any) -> uuid.UUID | None: """ Safely extract and convert an object's id to UUID. @@ -975,25 +975,25 @@ class ScenarioResultEntry(Base): scenario_description = mapped_column(Unicode, nullable=True) scenario_version = mapped_column(INTEGER, nullable=False, default=1) pyrit_version = mapped_column(String, nullable=False) - scenario_init_data: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON, nullable=True) + scenario_init_data: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) objective_target_identifier: Mapped[dict[str, str]] = mapped_column(JSON, nullable=False) - objective_scorer_identifier: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True) + objective_scorer_identifier: Mapped[dict[str, str] | None] = mapped_column(JSON, nullable=True) scenario_run_state: Mapped[Literal["CREATED", "IN_PROGRESS", "COMPLETED", "FAILED", "CANCELLED"]] = mapped_column( String, nullable=False, default="CREATED" ) attack_results_json: Mapped[str] = mapped_column(Unicode, nullable=False) - display_group_map_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) - labels: Mapped[Optional[dict[str, str]]] = mapped_column(JSON, nullable=True) + display_group_map_json: Mapped[str | None] = mapped_column(Unicode, nullable=True) + labels: Mapped[dict[str, str] | None] = mapped_column(JSON, nullable=True) number_tries: Mapped[int] = mapped_column(INTEGER, nullable=False, default=0) completion_time = mapped_column(DateTime, nullable=False) timestamp = mapped_column(DateTime, nullable=False) # Pointer to failed attack result(s) — avoids scanning all attacks for error info - error_attack_result_ids_json: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) + error_attack_result_ids_json: Mapped[str | None] = mapped_column(Unicode, nullable=True) # Scenario-level error info (persisted so it survives process restarts) - error_message: Mapped[Optional[str]] = mapped_column(Unicode, nullable=True) - error_type: Mapped[Optional[str]] = mapped_column(String, nullable=True) + error_message: Mapped[str | None] = mapped_column(Unicode, nullable=True) + error_type: Mapped[str | None] = mapped_column(String, nullable=True) def __init__(self, *, entry: ScenarioResult) -> None: """ From b001e025425973d2581741fc4b9a2b9ee5b23b24 Mon Sep 17 00:00:00 2001 From: Richard Lundeen Date: Fri, 15 May 2026 14:50:08 -0700 Subject: [PATCH 4/4] pre-commit --- pyrit/memory/memory_models.py | 3 +++ pyrit/models/attack_result.py | 34 +++++++------------------ pyrit/models/message_piece.py | 10 +++----- tests/unit/models/test_attack_result.py | 8 +++--- 4 files changed, 18 insertions(+), 37 deletions(-) diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 77ac9d5560..ec6323d04b 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -1001,6 +1001,9 @@ def __init__(self, *, entry: ScenarioResult) -> None: Args: entry (ScenarioResult): The scenario result object to convert into a database entry. + + Raises: + ValueError: If ``entry.objective_target_identifier`` is ``None``. """ self.id = entry.id self.scenario_name = entry.scenario_identifier.name diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index babfb4db11..9ec912ff48 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -8,19 +8,17 @@ from dataclasses import dataclass, field from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, TypeVar - +from typing import Any, Optional, TypeVar + +from pyrit.common.deprecation import print_deprecation_message +from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier +from pyrit.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models.message_piece import MessagePiece +from pyrit.models.retry_event import RetryEvent +from pyrit.models.score import Score from pyrit.models.strategy_result import StrategyResult -if TYPE_CHECKING: - from pyrit.identifiers.component_identifier import ComponentIdentifier - from pyrit.models.conversation_reference import ConversationReference - from pyrit.models.message_piece import MessagePiece - from pyrit.models.retry_event import RetryEvent - from pyrit.models.score import Score - -from pyrit.models.conversation_reference import ConversationType - AttackResultT = TypeVar("AttackResultT", bound="AttackResult") @@ -119,8 +117,6 @@ def attack_identifier(self) -> Optional[ComponentIdentifier]: Optional[ComponentIdentifier]: The attack strategy identifier, or ``None``. """ - from pyrit.common.deprecation import print_deprecation_message - print_deprecation_message( old_item="AttackResult.attack_identifier", new_item="AttackResult.atomic_attack_identifier or get_attack_strategy_identifier()", @@ -231,8 +227,6 @@ def to_dict(self) -> dict[str, Any]: Returns: dict[str, Any]: Serialized payload suitable for REST APIs or persistence. """ - from pyrit.models.conversation_reference import ConversationReference - return { "conversation_id": self.conversation_id, "objective": self.objective, @@ -274,12 +268,6 @@ def from_dict(cls, data: dict[str, Any]) -> AttackResult: Returns: AttackResult: Reconstructed instance. """ - from pyrit.identifiers.component_identifier import ComponentIdentifier - from pyrit.models.conversation_reference import ConversationReference - from pyrit.models.message_piece import MessagePiece - from pyrit.models.retry_event import RetryEvent - from pyrit.models.score import Score - return cls( conversation_id=data["conversation_id"], objective=data["objective"], @@ -330,16 +318,12 @@ def _add_attack_identifier_compat(cls: type) -> type: def wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None: attack_identifier = kwargs.pop("attack_identifier", None) if attack_identifier is not None: - from pyrit.common.deprecation import print_deprecation_message - print_deprecation_message( old_item="AttackResult(attack_identifier=...)", new_item="AttackResult(atomic_attack_identifier=...)", removed_in="0.15.0", ) if kwargs.get("atomic_attack_identifier") is None: - from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier - kwargs["atomic_attack_identifier"] = build_atomic_attack_identifier( attack_identifier=attack_identifier, ) diff --git a/pyrit/models/message_piece.py b/pyrit/models/message_piece.py index 767b42ccd9..b8f9533be1 100644 --- a/pyrit/models/message_piece.py +++ b/pyrit/models/message_piece.py @@ -9,12 +9,13 @@ from uuid import uuid4 from pyrit.common.deprecation import print_deprecation_message +from pyrit.identifiers.component_identifier import ComponentIdentifier +from pyrit.models.data_type_serializer import data_serializer_factory from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError +from pyrit.models.score import Score if TYPE_CHECKING: - from pyrit.identifiers.component_identifier import ComponentIdentifier from pyrit.models.message import Message - from pyrit.models.score import Score Originator = Literal["attack", "converter", "undefined", "scorer"] """Deprecated: The Originator type alias will be removed in a future release.""" @@ -218,8 +219,6 @@ async def set_sha256_values_async(self) -> None: Note, this method is async due to the blob retrieval. And because of that, we opted to take it out of main and setter functions. The disadvantage is that it must be explicitly called. """ - from pyrit.models.data_type_serializer import data_serializer_factory - original_serializer = data_serializer_factory( category="prompt-memory-entries", data_type=self.original_value_data_type, @@ -365,9 +364,6 @@ def from_dict(cls, data: dict[str, Any]) -> MessagePiece: Returns: MessagePiece: Reconstructed instance. """ - from pyrit.identifiers.component_identifier import ComponentIdentifier - from pyrit.models.score import Score - return cls( id=data.get("id"), role=data.get("role", "user"), diff --git a/tests/unit/models/test_attack_result.py b/tests/unit/models/test_attack_result.py index fea4c3e166..a2db52f53d 100644 --- a/tests/unit/models/test_attack_result.py +++ b/tests/unit/models/test_attack_result.py @@ -8,7 +8,10 @@ from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier from pyrit.memory.memory_models import AttackResultEntry from pyrit.models.attack_result import AttackOutcome, AttackResult +from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models.message_piece import MessagePiece from pyrit.models.retry_event import RetryEvent +from pyrit.models.score import Score class TestAttackResultDeprecation: @@ -354,11 +357,6 @@ def test_traceback_truncation(self) -> None: def test_to_dict_from_dict_roundtrip(): - from pyrit.identifiers.component_identifier import ComponentIdentifier - from pyrit.models.conversation_reference import ConversationReference, ConversationType - from pyrit.models.message_piece import MessagePiece - from pyrit.models.score import Score - scorer_id = ComponentIdentifier( class_name="SelfAskTrueFalseScorer", class_module="pyrit.score",