From e6786f8eed7928085367286326126daab113097f Mon Sep 17 00:00:00 2001 From: romanlutz Date: Thu, 14 May 2026 21:48:33 -0700 Subject: [PATCH 1/2] REFACTOR: rename SeedDatasetProvider.fetch_dataset to fetch_dataset_async Per the project style guide, all async methods must end in '_async'. `SeedDatasetProvider.fetch_dataset` (and its 35 subclass overrides) violated this convention. This commit: - Adds new `fetch_dataset_async` on the base class. It is no longer decorated with `@abstractmethod`, so a subclass that overrides neither name can now be instantiated and fails with `NotImplementedError` at call time instead of at instantiation. - Keeps the legacy `fetch_dataset` as a non-abstract concrete shim that emits `DeprecationWarning` and delegates to `fetch_dataset_async`. This preserves backward compatibility for external CALLERS of the public API. - Adds bidirectional dispatch in the base so external IMPLEMENTERS who override only the legacy `fetch_dataset` continue to work, but `__init_subclass__` emits `DeprecationWarning` at class-definition time pointing them to the new name. - Renames `async def fetch_dataset` to `async def fetch_dataset_async` in all 35 internal subclasses (1 local + 34 remote loaders) and updates the `_RemoteDatasetLoader` class docstring to the new name. - Renames `.fetch_dataset(...)` call sites to `.fetch_dataset_async(...)` in all internal callers (tests, end-to-end test, integration smoke test, doc notebook pair). - Adds 5 new unit tests covering the deprecation bridge in both directions. All 655 dataset/setup/executor unit tests pass; ruff lint and format are clean. End-to-end discovery via `SeedDatasetProvider.get_all_providers()` still works. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- doc/code/datasets/4_dataset_coding.ipynb | 2 +- doc/code/datasets/4_dataset_coding.py | 2 +- .../local/local_dataset_loader.py | 2 +- .../remote/aegis_ai_content_safety_dataset.py | 2 +- .../remote/aya_redteaming_dataset.py | 2 +- .../remote/babelscape_alert_dataset.py | 2 +- .../remote/beaver_tails_dataset.py | 2 +- .../seed_datasets/remote/cbt_bench_dataset.py | 2 +- .../remote/ccp_sensitive_prompts_dataset.py | 2 +- .../remote/comic_jailbreak_dataset.py | 2 +- .../seed_datasets/remote/darkbench_dataset.py | 2 +- .../remote/equitymedqa_dataset.py | 2 +- .../remote/forbidden_questions_dataset.py | 2 +- .../seed_datasets/remote/harmbench_dataset.py | 2 +- .../remote/harmbench_multimodal_dataset.py | 2 +- .../remote/harmful_qa_dataset.py | 2 +- .../remote/jbb_behaviors_dataset.py | 2 +- .../remote/librai_do_not_answer_dataset.py | 2 +- ...llm_latent_adversarial_training_dataset.py | 2 +- .../remote/medsafetybench_dataset.py | 2 +- .../remote/mlcommons_ailuminate_dataset.py | 2 +- .../multilingual_vulnerability_dataset.py | 2 +- .../seed_datasets/remote/or_bench_dataset.py | 2 +- .../remote/pku_safe_rlhf_dataset.py | 2 +- .../remote/promptintel_dataset.py | 2 +- .../remote/red_team_social_bias_dataset.py | 2 +- .../remote/remote_dataset_loader.py | 2 +- .../remote/salad_bench_dataset.py | 2 +- .../remote/simple_safety_tests_dataset.py | 2 +- .../remote/sorry_bench_dataset.py | 2 +- .../seed_datasets/remote/sosbench_dataset.py | 2 +- .../remote/tdc23_redteaming_dataset.py | 2 +- .../remote/toxic_chat_dataset.py | 2 +- .../remote/transphobia_awareness_dataset.py | 2 +- .../remote/visual_leak_bench_dataset.py | 2 +- .../seed_datasets/remote/vlguard_dataset.py | 2 +- .../remote/vlsu_multimodal_dataset.py | 2 +- .../seed_datasets/remote/xstest_dataset.py | 2 +- .../seed_datasets/seed_dataset_provider.py | 60 ++++++++- tests/end_to_end/test_all_datasets.py | 2 +- 
.../test_seed_dataset_provider_integration.py | 4 +- .../test_aegis_ai_content_safety_dataset.py | 4 +- .../datasets/test_aya_redteaming_dataset.py | 6 +- .../datasets/test_babelscape_alert_dataset.py | 6 +- .../datasets/test_beaver_tails_dataset.py | 6 +- tests/unit/datasets/test_cbt_bench_dataset.py | 12 +- .../datasets/test_comic_jailbreak_dataset.py | 16 +-- tests/unit/datasets/test_darkbench_dataset.py | 4 +- .../unit/datasets/test_equitymedqa_dataset.py | 4 +- tests/unit/datasets/test_harmbench_dataset.py | 4 +- .../test_harmbench_multimodal_dataset.py | 8 +- .../unit/datasets/test_harmful_qa_dataset.py | 2 +- .../datasets/test_jbb_behaviors_dataset.py | 4 +- .../datasets/test_local_dataset_loader.py | 4 +- .../datasets/test_medsafetybench_dataset.py | 4 +- .../test_mlcommons_ailuminate_dataset.py | 2 +- tests/unit/datasets/test_or_bench_dataset.py | 6 +- .../datasets/test_pku_safe_rlhf_dataset.py | 6 +- .../unit/datasets/test_promptintel_dataset.py | 36 +++--- .../test_red_team_social_bias_dataset.py | 4 +- .../datasets/test_remote_dataset_loader.py | 2 +- .../unit/datasets/test_salad_bench_dataset.py | 4 +- .../datasets/test_seed_dataset_provider.py | 118 ++++++++++++++++-- .../datasets/test_simple_remote_datasets.py | 2 +- .../test_simple_safety_tests_dataset.py | 2 +- .../unit/datasets/test_sorry_bench_dataset.py | 6 +- .../unit/datasets/test_toxic_chat_dataset.py | 12 +- .../test_transphobia_awareness_dataset.py | 6 +- .../test_visual_leak_bench_dataset.py | 28 ++--- tests/unit/datasets/test_vlguard_dataset.py | 20 +-- .../datasets/test_vlsu_multimodal_dataset.py | 16 +-- 71 files changed, 319 insertions(+), 177 deletions(-) diff --git a/doc/code/datasets/4_dataset_coding.ipynb b/doc/code/datasets/4_dataset_coding.ipynb index f200670e3b..b95ea3e3fc 100644 --- a/doc/code/datasets/4_dataset_coding.ipynb +++ b/doc/code/datasets/4_dataset_coding.ipynb @@ -67,7 +67,7 @@ " def dataset_name(self) -> str:\n", " return \"dark_bench\"\n", "\n", - " async def 
fetch_dataset(self, *, cache: bool = True) -> SeedDataset:\n", + " async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset:\n", " # Fetch from HuggingFace\n", " data = await self._fetch_from_huggingface(\n", " dataset_name=\"apart/darkbench\",\n", diff --git a/doc/code/datasets/4_dataset_coding.py b/doc/code/datasets/4_dataset_coding.py index 61b7d95ca5..3614bdff1b 100644 --- a/doc/code/datasets/4_dataset_coding.py +++ b/doc/code/datasets/4_dataset_coding.py @@ -64,7 +64,7 @@ class SimpleDarkBench(_RemoteDatasetLoader): def dataset_name(self) -> str: return "dark_bench" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: # Fetch from HuggingFace data = await self._fetch_from_huggingface( dataset_name="apart/darkbench", diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 643fcd9786..4f4dcf8af4 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -53,7 +53,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return self._dataset_name - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Load the dataset from the local YAML file. 
diff --git a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py index df4c412fff..895ce87902 100644 --- a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py @@ -120,7 +120,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "aegis_content_safety" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch NVIDIA Aegis AI Content Safety dataset with optional filtering and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py index a513b8bb34..3056c35e25 100644 --- a/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aya_redteaming_dataset.py @@ -84,7 +84,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "aya_redteaming" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch Aya Red-teaming dataset with optional filtering and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py b/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py index 7ab6b2b01e..386d4190e6 100644 --- a/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/babelscape_alert_dataset.py @@ -51,7 +51,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "babelscape_alert" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch Babelscape ALERT dataset and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/beaver_tails_dataset.py b/pyrit/datasets/seed_datasets/remote/beaver_tails_dataset.py index 33b08af4e8..0405d85db2 100644 --- a/pyrit/datasets/seed_datasets/remote/beaver_tails_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/beaver_tails_dataset.py @@ -52,7 +52,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "beaver_tails" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch BeaverTails dataset from HuggingFace and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/cbt_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/cbt_bench_dataset.py index 2d8cdc1f62..eee6c2d0dd 100644 --- a/pyrit/datasets/seed_datasets/remote/cbt_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/cbt_bench_dataset.py @@ -52,7 +52,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "cbt_bench" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch CBT-Bench dataset from HuggingFace and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/ccp_sensitive_prompts_dataset.py b/pyrit/datasets/seed_datasets/remote/ccp_sensitive_prompts_dataset.py index 72ad749bb9..752a1f4c66 100644 --- a/pyrit/datasets/seed_datasets/remote/ccp_sensitive_prompts_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/ccp_sensitive_prompts_dataset.py @@ -39,7 +39,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "ccp_sensitive_prompts" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch CCP-sensitive prompts dataset and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py b/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py index 296f287fb2..a51e63b574 100644 --- a/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/comic_jailbreak_dataset.py @@ -137,7 +137,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "comic_jailbreak" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch ComicJailbreak dataset and return as SeedDataset of image+text pairs. diff --git a/pyrit/datasets/seed_datasets/remote/darkbench_dataset.py b/pyrit/datasets/seed_datasets/remote/darkbench_dataset.py index ebd447be24..97c5bbf5da 100644 --- a/pyrit/datasets/seed_datasets/remote/darkbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/darkbench_dataset.py @@ -47,7 +47,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "dark_bench" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch DarkBench dataset from HuggingFace and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/equitymedqa_dataset.py b/pyrit/datasets/seed_datasets/remote/equitymedqa_dataset.py index 0bc3959fe4..1c3377286a 100644 --- a/pyrit/datasets/seed_datasets/remote/equitymedqa_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/equitymedqa_dataset.py @@ -112,7 +112,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "equitymedqa" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch EquityMedQA dataset and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/forbidden_questions_dataset.py b/pyrit/datasets/seed_datasets/remote/forbidden_questions_dataset.py index a7a1a9955a..d884293eb1 100644 --- a/pyrit/datasets/seed_datasets/remote/forbidden_questions_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/forbidden_questions_dataset.py @@ -46,7 +46,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "forbidden_questions" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch Forbidden Questions dataset and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index f0de4ccc97..4afcdb2c15 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -49,7 +49,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "harmbench" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch HarmBench dataset and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py index 960bb90e97..6e063c6242 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_multimodal_dataset.py @@ -75,7 +75,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "harmbench_multimodal" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch HarmBench multimodal examples and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/harmful_qa_dataset.py b/pyrit/datasets/seed_datasets/remote/harmful_qa_dataset.py index 9f8171ad10..506ebdda01 100644 --- a/pyrit/datasets/seed_datasets/remote/harmful_qa_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmful_qa_dataset.py @@ -46,7 +46,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "harmful_qa" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch HarmfulQA dataset from HuggingFace and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py index 4198a0cfd6..e638b5144f 100644 --- a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py @@ -45,7 +45,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "jbb_behaviors" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch JBB-Behaviors dataset and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/librai_do_not_answer_dataset.py b/pyrit/datasets/seed_datasets/remote/librai_do_not_answer_dataset.py index 28cf707ed3..9f007aa349 100644 --- a/pyrit/datasets/seed_datasets/remote/librai_do_not_answer_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/librai_do_not_answer_dataset.py @@ -40,7 +40,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "librai_do_not_answer" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch LibrAI Do Not Answer dataset and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/llm_latent_adversarial_training_dataset.py b/pyrit/datasets/seed_datasets/remote/llm_latent_adversarial_training_dataset.py index e1cffbf7c5..f881fe0cb9 100644 --- a/pyrit/datasets/seed_datasets/remote/llm_latent_adversarial_training_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/llm_latent_adversarial_training_dataset.py @@ -39,7 +39,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "llm_lat_harmful" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch LLM-LAT harmful dataset and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/medsafetybench_dataset.py b/pyrit/datasets/seed_datasets/remote/medsafetybench_dataset.py index 34ea7334f6..ab23907230 100644 --- a/pyrit/datasets/seed_datasets/remote/medsafetybench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/medsafetybench_dataset.py @@ -73,7 +73,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "medsafetybench" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch MedSafetyBench dataset and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/mlcommons_ailuminate_dataset.py b/pyrit/datasets/seed_datasets/remote/mlcommons_ailuminate_dataset.py index a78c7ef6d8..f954bc9948 100644 --- a/pyrit/datasets/seed_datasets/remote/mlcommons_ailuminate_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/mlcommons_ailuminate_dataset.py @@ -65,7 +65,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "mlcommons_ailuminate" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch AILuminate dataset and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/multilingual_vulnerability_dataset.py b/pyrit/datasets/seed_datasets/remote/multilingual_vulnerability_dataset.py index 25a3986e41..4f4db77cd6 100644 --- a/pyrit/datasets/seed_datasets/remote/multilingual_vulnerability_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/multilingual_vulnerability_dataset.py @@ -40,7 +40,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "multilingual_vulnerability" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch Multilingual Vulnerability dataset and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/or_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/or_bench_dataset.py index b91cb34964..b4aa647d49 100644 --- a/pyrit/datasets/seed_datasets/remote/or_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/or_bench_dataset.py @@ -39,7 +39,7 @@ def __init__(self, *, split: str = "train") -> None: """ self.split = split - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch OR-Bench dataset from HuggingFace and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py b/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py index 79bb0bef48..2921bb032e 100644 --- a/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/pku_safe_rlhf_dataset.py @@ -72,7 +72,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "pku_safe_rlhf" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch PKU-SafeRLHF dataset and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py index 84d6c8f4a4..7af64c4fe4 100644 --- a/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/promptintel_dataset.py @@ -305,7 +305,7 @@ def _convert_record_to_seed_prompt(self, record: dict[str, Any]) -> Optional[See metadata=metadata, ) - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch prompts from the PromptIntel API and return as a SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/red_team_social_bias_dataset.py b/pyrit/datasets/seed_datasets/remote/red_team_social_bias_dataset.py index 55105e2c02..75b4fe1390 100644 --- a/pyrit/datasets/seed_datasets/remote/red_team_social_bias_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/red_team_social_bias_dataset.py @@ -42,7 +42,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "red_team_social_bias" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch Red Team Social Bias dataset and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 626b9febbf..ce6cd1b39f 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -49,7 +49,7 @@ class _RemoteDatasetLoader(SeedDatasetProvider, ABC): - HuggingFace Hub Subclasses must implement: - - fetch_dataset(): Fetch and return the dataset as a SeedDataset + - fetch_dataset_async(): Fetch and return the dataset as a SeedDataset - dataset_name property: Human-readable name for the dataset """ diff --git a/pyrit/datasets/seed_datasets/remote/salad_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/salad_bench_dataset.py index 67b90c5e4b..413dbd9155 100644 --- a/pyrit/datasets/seed_datasets/remote/salad_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/salad_bench_dataset.py @@ -67,7 +67,7 @@ def _parse_category(category: str) -> str: """ return re.sub(r"^O\d+:\s*", "", category) - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch SALAD-Bench dataset from HuggingFace and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/simple_safety_tests_dataset.py b/pyrit/datasets/seed_datasets/remote/simple_safety_tests_dataset.py index e25c8215dd..d92007c521 100644 --- a/pyrit/datasets/seed_datasets/remote/simple_safety_tests_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/simple_safety_tests_dataset.py @@ -46,7 +46,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "simple_safety_tests" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch SimpleSafetyTests dataset from HuggingFace and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py index 0167d9ca39..407fc8810d 100644 --- a/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/sorry_bench_dataset.py @@ -140,7 +140,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "sorry_bench" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch Sorry-Bench dataset and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/sosbench_dataset.py b/pyrit/datasets/seed_datasets/remote/sosbench_dataset.py index 1474b53882..61b215c4f5 100644 --- a/pyrit/datasets/seed_datasets/remote/sosbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/sosbench_dataset.py @@ -40,7 +40,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "sosbench" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch SOSBench dataset and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/tdc23_redteaming_dataset.py b/pyrit/datasets/seed_datasets/remote/tdc23_redteaming_dataset.py index c1ffd21b8f..104be4d106 100644 --- a/pyrit/datasets/seed_datasets/remote/tdc23_redteaming_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/tdc23_redteaming_dataset.py @@ -40,7 +40,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "tdc23_redteaming" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch TDC23-RedTeaming dataset and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/toxic_chat_dataset.py b/pyrit/datasets/seed_datasets/remote/toxic_chat_dataset.py index 5e23332a1f..fcfa682971 100644 --- a/pyrit/datasets/seed_datasets/remote/toxic_chat_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/toxic_chat_dataset.py @@ -83,7 +83,7 @@ def _extract_harm_categories(self, item: dict[str, Any]) -> list[str]: return categories - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch ToxicChat dataset from HuggingFace and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/transphobia_awareness_dataset.py b/pyrit/datasets/seed_datasets/remote/transphobia_awareness_dataset.py index 0be21d263f..73b8a8e5ef 100644 --- a/pyrit/datasets/seed_datasets/remote/transphobia_awareness_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/transphobia_awareness_dataset.py @@ -52,7 +52,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "transphobia_awareness" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch Transphobia-Awareness dataset and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py index 2f767fe429..acaa566502 100644 --- a/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/visual_leak_bench_dataset.py @@ -119,7 +119,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "visual_leak_bench" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch VisualLeakBench examples and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py index 2d3518a33b..af0362251c 100644 --- a/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlguard_dataset.py @@ -135,7 +135,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "vlguard" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch VLGuard multimodal examples and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py index 773248b1a4..f78c71b6d5 100644 --- a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py @@ -93,7 +93,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "ml_vlsu" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch ML-VLSU multimodal examples and return as SeedDataset. diff --git a/pyrit/datasets/seed_datasets/remote/xstest_dataset.py b/pyrit/datasets/seed_datasets/remote/xstest_dataset.py index 6c68248c4e..ea50eecf19 100644 --- a/pyrit/datasets/seed_datasets/remote/xstest_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/xstest_dataset.py @@ -42,7 +42,7 @@ def dataset_name(self) -> str: """Return the dataset name.""" return "xstest" - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch XSTest dataset and return as SeedDataset. 
diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 7ce5eb8c37..42472e68c8 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -4,6 +4,7 @@ import asyncio import inspect import logging +import warnings from abc import ABC, abstractmethod from dataclasses import fields as dc_fields from typing import Any, Optional @@ -25,7 +26,7 @@ class SeedDatasetProvider(ABC): both local and remote dataset providers. Subclasses must implement: - - fetch_dataset(): Fetch and return the dataset as a SeedDataset + - fetch_dataset_async(): Fetch and return the dataset as a SeedDataset - dataset_name property: Human-readable name for the dataset All subclasses also have a _metadata property that is optional to make @@ -40,10 +41,22 @@ def __init_subclass__(cls, **kwargs: Any) -> None: """ Automatically register non-abstract subclasses. - This is called when a class inherits from SeedDatasetProvider. + This is called when a class inherits from SeedDatasetProvider. A + deprecation warning is emitted for subclasses that still override the + legacy ``fetch_dataset`` instead of ``fetch_dataset_async``. """ super().__init_subclass__(**kwargs) - # Only register concrete (non-abstract) classes + if not inspect.isabstract(cls) and ( + cls.fetch_dataset is not SeedDatasetProvider.fetch_dataset + and cls.fetch_dataset_async is SeedDatasetProvider.fetch_dataset_async + ): + warnings.warn( + f"{cls.__name__} overrides the deprecated fetch_dataset method. 
" + "Rename the override to fetch_dataset_async; fetch_dataset will be " + "removed in a future release.", + DeprecationWarning, + stacklevel=2, + ) if not inspect.isabstract(cls) and getattr(cls, "should_register", True): SeedDatasetProvider._registry[cls.__name__] = cls logger.debug(f"Registered dataset provider: {cls.__name__}") @@ -58,11 +71,15 @@ def dataset_name(self) -> str: str: The dataset name (e.g., "HarmBench", "JailbreakBench JBB-Behaviors") """ - @abstractmethod - async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: """ Fetch the dataset and return as a SeedDataset. + Subclasses MUST override this method. The default implementation exists + only to provide a deprecation bridge for legacy subclasses that override + the old ``fetch_dataset`` name; in that case it dispatches to the legacy + method and emits a DeprecationWarning. + Args: cache: Whether to cache the fetched dataset. Defaults to True. Remote datasets will use DB_DATA_PATH for caching. @@ -71,8 +88,39 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: SeedDataset: The fetched dataset with prompts. Raises: + NotImplementedError: If the subclass overrides neither + ``fetch_dataset_async`` nor the legacy ``fetch_dataset``. Exception: If the dataset cannot be fetched or processed. """ + cls = type(self) + if cls.fetch_dataset is SeedDatasetProvider.fetch_dataset: + raise NotImplementedError(f"{cls.__name__} must implement fetch_dataset_async.") + warnings.warn( + f"{cls.__name__}.fetch_dataset is deprecated; rename the override to fetch_dataset_async.", + DeprecationWarning, + stacklevel=2, + ) + return await self.fetch_dataset(cache=cache) + + async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + """ + Fetch the dataset (deprecated alias of ``fetch_dataset_async``). + + Kept as a backward-compatibility shim for callers of the public API. 
+ Emits a DeprecationWarning and delegates to ``fetch_dataset_async``. + + Args: + cache: Whether to cache the fetched dataset. Defaults to True. + + Returns: + SeedDataset: The fetched dataset with prompts. + """ + warnings.warn( + "SeedDatasetProvider.fetch_dataset is deprecated; use fetch_dataset_async instead.", + DeprecationWarning, + stacklevel=2, + ) + return await self.fetch_dataset_async(cache=cache) async def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: """ @@ -283,7 +331,7 @@ async def fetch_single_dataset( logger.debug(f"Skipping {provider_name} - not in filter list") return None - dataset = await provider.fetch_dataset(cache=cache) + dataset = await provider.fetch_dataset_async(cache=cache) return (provider.dataset_name, dataset) # Create semaphore to limit concurrency diff --git a/tests/end_to_end/test_all_datasets.py b/tests/end_to_end/test_all_datasets.py index a13977b994..fbc6090852 100644 --- a/tests/end_to_end/test_all_datasets.py +++ b/tests/end_to_end/test_all_datasets.py @@ -55,7 +55,7 @@ def get_dataset_providers(): ) async def _fetch_with_retry(provider) -> SeedDataset: """Fetch a dataset with retry on transient network errors.""" - return await provider.fetch_dataset(cache=False) + return await provider.fetch_dataset_async(cache=False) @pytest.fixture(scope="module", autouse=True) diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index 85a8a80235..d4f2f5cfc2 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -49,7 +49,7 @@ async def test_fetch_dataset_smoke(self, name, provider_cls): logger.info(f"Smoke testing provider: {name}") provider = provider_cls() - dataset = await provider.fetch_dataset(cache=False) + dataset = await provider.fetch_dataset_async(cache=False) assert isinstance(dataset, SeedDataset), f"{name} did 
not return a SeedDataset" assert len(dataset.seeds) > 0, f"{name} returned an empty dataset" @@ -102,7 +102,7 @@ async def _fetch_dataset(self, *, cache=True): "__module__": __name__, # Concrete implementations satisfy ABC requirements "dataset_name": property(lambda self: captured_name), - "fetch_dataset": _fetch_dataset, + "fetch_dataset_async": _fetch_dataset, "_fetch_from_url": lambda self, **kw: [], } diff --git a/tests/unit/datasets/test_aegis_ai_content_safety_dataset.py b/tests/unit/datasets/test_aegis_ai_content_safety_dataset.py index cb8f75693a..de709d0307 100644 --- a/tests/unit/datasets/test_aegis_ai_content_safety_dataset.py +++ b/tests/unit/datasets/test_aegis_ai_content_safety_dataset.py @@ -46,7 +46,7 @@ async def test_fetch_dataset_filters_unsafe_only(mock_aegis_data): "pyrit.datasets.seed_datasets.remote.aegis_ai_content_safety_dataset.load_dataset", return_value=mock_aegis_data, ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) # Only unsafe, non-REDACTED prompts: "How to make a weapon?" and "How to hack a system?" @@ -63,7 +63,7 @@ async def test_fetch_dataset_with_harm_category_filter(mock_aegis_data): "pyrit.datasets.seed_datasets.remote.aegis_ai_content_safety_dataset.load_dataset", return_value=mock_aegis_data, ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 1 assert dataset.seeds[0].value == "How to hack a system?" 
diff --git a/tests/unit/datasets/test_aya_redteaming_dataset.py b/tests/unit/datasets/test_aya_redteaming_dataset.py index 00156b9f6b..6cc4fa19d7 100644 --- a/tests/unit/datasets/test_aya_redteaming_dataset.py +++ b/tests/unit/datasets/test_aya_redteaming_dataset.py @@ -29,7 +29,7 @@ async def test_fetch_dataset(mock_aya_data): loader = _AyaRedteamingDataset() with patch.object(loader, "_fetch_from_url", return_value=mock_aya_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -42,7 +42,7 @@ async def test_fetch_dataset_filters_by_harm_category(mock_aya_data): loader = _AyaRedteamingDataset(harm_categories=["Hate Speech"]) with patch.object(loader, "_fetch_from_url", return_value=mock_aya_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 1 assert dataset.seeds[0].value == "How to spread hate speech?" 
@@ -52,7 +52,7 @@ async def test_fetch_dataset_filters_by_harm_scope(mock_aya_data): loader = _AyaRedteamingDataset(harm_scope="local") with patch.object(loader, "_fetch_from_url", return_value=mock_aya_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 1 assert dataset.seeds[0].value == "Tell me something violent" diff --git a/tests/unit/datasets/test_babelscape_alert_dataset.py b/tests/unit/datasets/test_babelscape_alert_dataset.py index 739a4cf8b0..4768677628 100644 --- a/tests/unit/datasets/test_babelscape_alert_dataset.py +++ b/tests/unit/datasets/test_babelscape_alert_dataset.py @@ -39,11 +39,11 @@ class TestBabelscapeAlertDataset: """Test the Babelscape ALERT dataset loader.""" async def test_fetch_dataset_returns_seed_dataset(self, mock_alert_data): - """Test that fetch_dataset returns a SeedDataset with correct prompts.""" + """Test that fetch_dataset_async returns a SeedDataset with correct prompts.""" loader = _BabelscapeAlertDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_alert_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -54,7 +54,7 @@ async def test_fetch_dataset_includes_harm_categories(self, mock_alert_data): loader = _BabelscapeAlertDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_alert_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() first_prompt = dataset.seeds[0] assert first_prompt.harm_categories == ["crime_injury"] diff --git a/tests/unit/datasets/test_beaver_tails_dataset.py b/tests/unit/datasets/test_beaver_tails_dataset.py index 5c43467d98..0b2b5a3e3c 100644 --- a/tests/unit/datasets/test_beaver_tails_dataset.py +++ b/tests/unit/datasets/test_beaver_tails_dataset.py @@ -70,7 +70,7 @@ async def 
test_fetch_dataset_unsafe_only(self, mock_beaver_tails_data): loader = _BeaverTailsDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_beaver_tails_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 # Only unsafe entries @@ -85,7 +85,7 @@ async def test_fetch_dataset_all_entries(self, mock_beaver_tails_data): loader = _BeaverTailsDataset(unsafe_only=False) with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_beaver_tails_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 3 # All entries including safe @@ -120,7 +120,7 @@ def __iter__(self): loader = _BeaverTailsDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=MockDataset())): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # Both prompts should be preserved — untrusted text is never passed through Jinja assert len(dataset.seeds) == 2 assert dataset.seeds[0].value == "This contains {% endraw %} which is Jinja2 syntax" diff --git a/tests/unit/datasets/test_cbt_bench_dataset.py b/tests/unit/datasets/test_cbt_bench_dataset.py index c99ffe9b3c..bb0bac62af 100644 --- a/tests/unit/datasets/test_cbt_bench_dataset.py +++ b/tests/unit/datasets/test_cbt_bench_dataset.py @@ -69,7 +69,7 @@ async def test_fetch_dataset(self, mock_cbt_bench_data): loader = _CBTBenchDataset() with patch.object(loader, "_fetch_from_huggingface", return_value=mock_cbt_bench_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -95,7 +95,7 @@ async def test_fetch_dataset_with_custom_config(self, mock_cbt_bench_data): ) with patch.object(loader, "_fetch_from_huggingface", 
return_value=mock_cbt_bench_data) as mock_fetch: - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) assert len(dataset.seeds) == 2 mock_fetch.assert_called_once() @@ -110,7 +110,7 @@ async def test_fetch_dataset_situation_only(self, mock_cbt_bench_data_missing_th loader = _CBTBenchDataset() with patch.object(loader, "_fetch_from_huggingface", return_value=mock_cbt_bench_data_missing_thoughts): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 1 assert dataset.seeds[0].value == "A situation without thoughts." @@ -121,14 +121,14 @@ async def test_fetch_dataset_empty_raises(self, mock_cbt_bench_data_empty): with patch.object(loader, "_fetch_from_huggingface", return_value=mock_cbt_bench_data_empty): with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_fetch_dataset_metadata_includes_config(self, mock_cbt_bench_data): """Test that metadata includes the config name.""" loader = _CBTBenchDataset(config="distortions_seed") with patch.object(loader, "_fetch_from_huggingface", return_value=mock_cbt_bench_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() for seed in dataset.seeds: assert seed.metadata["config"] == "distortions_seed" @@ -138,7 +138,7 @@ async def test_fetch_dataset_source_url(self, mock_cbt_bench_data): loader = _CBTBenchDataset() with patch.object(loader, "_fetch_from_huggingface", return_value=mock_cbt_bench_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() for seed in dataset.seeds: assert seed.source == "https://huggingface.co/datasets/Psychotherapy-LLM/CBT-Bench" diff --git a/tests/unit/datasets/test_comic_jailbreak_dataset.py b/tests/unit/datasets/test_comic_jailbreak_dataset.py index ca6d7cfa3b..e7c470809a 100644 --- 
a/tests/unit/datasets/test_comic_jailbreak_dataset.py +++ b/tests/unit/datasets/test_comic_jailbreak_dataset.py @@ -69,7 +69,7 @@ async def test_fetch_dataset_creates_image_text_pairs(self): patch.object(loader, "_fetch_template_async", new_callable=AsyncMock, return_value="/fake/template.png"), patch.object(loader, "_render_comic_async", new_callable=AsyncMock, return_value="/fake/rendered.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 3 # 1 objective + 1 image + 1 text @@ -96,7 +96,7 @@ async def test_fetch_dataset_skips_empty_template_text(self): patch.object(loader, "_fetch_template_async", new_callable=AsyncMock, return_value="/fake/template.png"), patch.object(loader, "_render_comic_async", new_callable=AsyncMock, return_value="/fake/rendered.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) # Only article group (instruction text is empty): 1 objective + 1 image + 1 text assert len(dataset.seeds) == 3 @@ -111,7 +111,7 @@ async def test_fetch_dataset_multiple_templates(self): patch.object(loader, "_fetch_template_async", new_callable=AsyncMock, return_value="/fake/template.png"), patch.object(loader, "_render_comic_async", new_callable=AsyncMock, return_value="/fake/rendered.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) # 3 templates with text × 1 goal = 3 groups × 3 seeds = 9 assert len(dataset.seeds) == 9 @@ -126,7 +126,7 @@ async def test_fetch_dataset_max_examples(self): patch.object(loader, "_fetch_template_async", new_callable=AsyncMock, return_value="/fake/template.png"), patch.object(loader, "_render_comic_async", new_callable=AsyncMock, return_value="/fake/rendered.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await 
loader.fetch_dataset_async(cache=False) # max_examples=2 → at most 2 groups × 3 seeds = 6 assert len(dataset.seeds) <= 6 @@ -141,7 +141,7 @@ async def test_fetch_dataset_metadata(self): patch.object(loader, "_fetch_template_async", new_callable=AsyncMock, return_value="/fake/template.png"), patch.object(loader, "_render_comic_async", new_callable=AsyncMock, return_value="/fake/rendered.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) for seed in dataset.seeds: if isinstance(seed, SeedPrompt): @@ -159,7 +159,7 @@ async def test_fetch_dataset_authors(self): patch.object(loader, "_fetch_template_async", new_callable=AsyncMock, return_value="/fake/template.png"), patch.object(loader, "_render_comic_async", new_callable=AsyncMock, return_value="/fake/rendered.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) for seed in dataset.seeds: assert "Zhiyuan Yu" in seed.authors @@ -171,7 +171,7 @@ async def test_fetch_dataset_missing_goal_raises(self): with patch.object(loader, "_fetch_from_url", return_value=mock_data): with pytest.raises(ValueError, match="Missing keys"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_fetch_dataset_empty_goal_skipped(self): mock_data = [_make_example(Goal=" ")] @@ -182,7 +182,7 @@ async def test_fetch_dataset_empty_goal_skipped(self): patch.object(loader, "_fetch_template_async", new_callable=AsyncMock, return_value="/fake/template.png"), ): with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() class TestComicJailbreakTemplates: diff --git a/tests/unit/datasets/test_darkbench_dataset.py b/tests/unit/datasets/test_darkbench_dataset.py index 25caf6e33a..085bc5e552 100644 --- a/tests/unit/datasets/test_darkbench_dataset.py +++ b/tests/unit/datasets/test_darkbench_dataset.py @@ -21,7 +21,7 
@@ async def test_fetch_dataset(mock_darkbench_data): loader = _DarkBenchDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_darkbench_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -35,7 +35,7 @@ async def test_fetch_dataset_passes_config(mock_darkbench_data): loader = _DarkBenchDataset(config="custom", split="test") with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_darkbench_data)) as mock_fetch: - await loader.fetch_dataset() + await loader.fetch_dataset_async() mock_fetch.assert_called_once() call_kwargs = mock_fetch.call_args.kwargs diff --git a/tests/unit/datasets/test_equitymedqa_dataset.py b/tests/unit/datasets/test_equitymedqa_dataset.py index be48493d23..733e0f9c10 100644 --- a/tests/unit/datasets/test_equitymedqa_dataset.py +++ b/tests/unit/datasets/test_equitymedqa_dataset.py @@ -23,7 +23,7 @@ async def test_fetch_dataset_single_subset(mock_equitymedqa_data): loader = _EquityMedQADataset(subset_name="cc_manual") with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_equitymedqa_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) > 0 @@ -47,7 +47,7 @@ async def test_fetch_dataset_multiple_subsets(): with patch.object( loader, "_fetch_from_huggingface", new=AsyncMock(side_effect=[mock_cc_manual_data, mock_multimedqa_data]) ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) > 0 diff --git a/tests/unit/datasets/test_harmbench_dataset.py b/tests/unit/datasets/test_harmbench_dataset.py index 53d5e0050d..d9111eb972 100644 --- a/tests/unit/datasets/test_harmbench_dataset.py +++ 
b/tests/unit/datasets/test_harmbench_dataset.py @@ -21,7 +21,7 @@ async def test_fetch_dataset(mock_harmbench_data): loader = _HarmBenchDataset() with patch.object(loader, "_fetch_from_url", return_value=mock_harmbench_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -36,7 +36,7 @@ async def test_fetch_dataset_missing_keys_raises(): with patch.object(loader, "_fetch_from_url", return_value=bad_data): with pytest.raises(ValueError, match="Missing keys"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() def test_dataset_name(): diff --git a/tests/unit/datasets/test_harmbench_multimodal_dataset.py b/tests/unit/datasets/test_harmbench_multimodal_dataset.py index 8f5a1655a7..b0ad4af8c9 100644 --- a/tests/unit/datasets/test_harmbench_multimodal_dataset.py +++ b/tests/unit/datasets/test_harmbench_multimodal_dataset.py @@ -41,7 +41,7 @@ async def test_fetch_dataset(mock_harmbench_mm_data): patch.object(loader, "_fetch_from_url", return_value=mock_harmbench_mm_data), patch.object(loader, "_fetch_and_save_image_async", new=AsyncMock(return_value="/path/to/image.png")), ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) # Only multimodal entry => 2 prompts (image + text) @@ -80,7 +80,7 @@ async def test_fetch_dataset_skips_failed_images(): new=AsyncMock(side_effect=[Exception("download failed"), "/path/to/image.png"]), ), ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # First image failed, second succeeded => 2 prompts (image + text for second) assert len(dataset.seeds) == 2 @@ -109,7 +109,7 @@ async def test_fetch_dataset_filters_by_category(): patch.object(loader, "_fetch_from_url", return_value=data), patch.object(loader, "_fetch_and_save_image_async", new=AsyncMock(return_value="/path/to/image.png")), ): - dataset 
= await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # Only "illegal" category matched => 2 prompts (image + text) assert len(dataset.seeds) == 2 @@ -122,7 +122,7 @@ async def test_fetch_dataset_missing_keys_raises(): with patch.object(loader, "_fetch_from_url", return_value=bad_data): with pytest.raises(ValueError, match="Missing keys"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() def test_dataset_name(): diff --git a/tests/unit/datasets/test_harmful_qa_dataset.py b/tests/unit/datasets/test_harmful_qa_dataset.py index a0b762fbc8..911bd46fd0 100644 --- a/tests/unit/datasets/test_harmful_qa_dataset.py +++ b/tests/unit/datasets/test_harmful_qa_dataset.py @@ -40,7 +40,7 @@ async def test_fetch_dataset(self, mock_harmful_qa_data): loader = _HarmfulQADataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_harmful_qa_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 diff --git a/tests/unit/datasets/test_jbb_behaviors_dataset.py b/tests/unit/datasets/test_jbb_behaviors_dataset.py index a5df7a0d65..6bfe56a4ac 100644 --- a/tests/unit/datasets/test_jbb_behaviors_dataset.py +++ b/tests/unit/datasets/test_jbb_behaviors_dataset.py @@ -22,7 +22,7 @@ async def test_fetch_dataset(mock_jbb_data): loader = _JBBBehaviorsDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_jbb_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 # Empty behavior is skipped @@ -37,7 +37,7 @@ async def test_fetch_dataset_empty_raises(): with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=empty_data)): # Source wraps ValueError in generic Exception (see jbb_behaviors_dataset.py:122-124) with pytest.raises(Exception, 
match="Error loading JBB-Behaviors dataset"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() def test_dataset_name(): diff --git a/tests/unit/datasets/test_local_dataset_loader.py b/tests/unit/datasets/test_local_dataset_loader.py index d2f3052bcc..1331559475 100644 --- a/tests/unit/datasets/test_local_dataset_loader.py +++ b/tests/unit/datasets/test_local_dataset_loader.py @@ -42,7 +42,7 @@ async def test_fetch_dataset(self, tmp_path, valid_yaml_content): file_path.write_text(valid_yaml_content, encoding="utf-8") loader = _LocalDatasetLoader(file_path=file_path) - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert dataset.dataset_name == "test_dataset" @@ -52,4 +52,4 @@ async def test_fetch_dataset(self, tmp_path, valid_yaml_content): async def test_fetch_dataset_file_not_found(self): loader = _LocalDatasetLoader(file_path=Path("non_existent.yaml")) with pytest.raises(Exception): # noqa: B017 - await loader.fetch_dataset() + await loader.fetch_dataset_async() diff --git a/tests/unit/datasets/test_medsafetybench_dataset.py b/tests/unit/datasets/test_medsafetybench_dataset.py index 2409956a71..b718e0f8c1 100644 --- a/tests/unit/datasets/test_medsafetybench_dataset.py +++ b/tests/unit/datasets/test_medsafetybench_dataset.py @@ -21,7 +21,7 @@ async def test_fetch_dataset_generated_subset(mock_medsafety_data): loader = _MedSafetyBenchDataset(subset_name="generated") with patch.object(loader, "_fetch_from_url", return_value=mock_medsafety_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 * len(loader.sources) @@ -36,7 +36,7 @@ async def test_fetch_dataset_missing_keys_raises(): with patch.object(loader, "_fetch_from_url", return_value=bad_data): with pytest.raises(KeyError, match="No 'harmful_medical_request' or 'prompt' found"): - await 
loader.fetch_dataset() + await loader.fetch_dataset_async() def test_dataset_name(): diff --git a/tests/unit/datasets/test_mlcommons_ailuminate_dataset.py b/tests/unit/datasets/test_mlcommons_ailuminate_dataset.py index 281e0a3b42..acc1a69787 100644 --- a/tests/unit/datasets/test_mlcommons_ailuminate_dataset.py +++ b/tests/unit/datasets/test_mlcommons_ailuminate_dataset.py @@ -21,7 +21,7 @@ async def test_fetch_dataset(mock_ailuminate_data): loader = _MLCommonsAILuminateDataset() with patch.object(loader, "_fetch_from_url", return_value=mock_ailuminate_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 diff --git a/tests/unit/datasets/test_or_bench_dataset.py b/tests/unit/datasets/test_or_bench_dataset.py index 499f8f5e23..3c3a40c4fd 100644 --- a/tests/unit/datasets/test_or_bench_dataset.py +++ b/tests/unit/datasets/test_or_bench_dataset.py @@ -36,7 +36,7 @@ async def test_fetch_dataset(self, mock_or_bench_data): loader = _ORBench80KDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_or_bench_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -62,7 +62,7 @@ async def test_fetch_dataset(self, mock_or_bench_data): with patch.object( loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_or_bench_data) ) as mock_fetch: - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 mock_fetch.assert_called_once() @@ -84,7 +84,7 @@ async def test_fetch_dataset(self, mock_or_bench_data): with patch.object( loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_or_bench_data) ) as mock_fetch: - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 
mock_fetch.assert_called_once() diff --git a/tests/unit/datasets/test_pku_safe_rlhf_dataset.py b/tests/unit/datasets/test_pku_safe_rlhf_dataset.py index e21dd96e4c..9f38b430f4 100644 --- a/tests/unit/datasets/test_pku_safe_rlhf_dataset.py +++ b/tests/unit/datasets/test_pku_safe_rlhf_dataset.py @@ -40,7 +40,7 @@ async def test_fetch_dataset_includes_all_prompts(mock_pku_data): loader = _PKUSafeRLHFDataset(include_safe_prompts=True) with patch.object(loader, "_fetch_from_huggingface", new_callable=AsyncMock, return_value=mock_pku_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 3 @@ -51,7 +51,7 @@ async def test_fetch_dataset_excludes_safe_prompts(mock_pku_data): loader = _PKUSafeRLHFDataset(include_safe_prompts=False) with patch.object(loader, "_fetch_from_huggingface", new_callable=AsyncMock, return_value=mock_pku_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 values = [s.value for s in dataset.seeds] @@ -62,7 +62,7 @@ async def test_fetch_dataset_filters_by_harm_category(mock_pku_data): loader = _PKUSafeRLHFDataset(include_safe_prompts=True, filter_harm_categories=["Cybercrime"]) with patch.object(loader, "_fetch_from_huggingface", new_callable=AsyncMock, return_value=mock_pku_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # Only the first item has Cybercrime=True; safe item has no harm categories so it's excluded assert len(dataset.seeds) == 1 diff --git a/tests/unit/datasets/test_promptintel_dataset.py b/tests/unit/datasets/test_promptintel_dataset.py index 1f17bbec78..8c43bd56e2 100644 --- a/tests/unit/datasets/test_promptintel_dataset.py +++ b/tests/unit/datasets/test_promptintel_dataset.py @@ -134,20 +134,20 @@ def test_dataset_name(self, api_key): class TestPromptIntelDatasetFetch: - """Test fetch_dataset and data 
transformation.""" + """Test fetch_dataset_async and data transformation.""" async def test_fetch_no_api_key_raises(self): with patch.dict("os.environ", {}, clear=True): loader = _PromptIntelDataset() with pytest.raises(ValueError, match="API key is required"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_fetch_dataset_returns_seed_dataset(self, api_key, mock_promptintel_response): loader = _PromptIntelDataset(api_key=api_key) mock_resp = _make_mock_response(json_data=mock_promptintel_response) with patch("requests.get", return_value=mock_resp): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) # 2 prompts = 2 SeedPrompts @@ -158,7 +158,7 @@ async def test_seed_prompt_fields(self, api_key, mock_promptintel_response): mock_resp = _make_mock_response(json_data=mock_promptintel_response) with patch("requests.get", return_value=mock_resp): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # Find the first SeedPrompt prompts = [s for s in dataset.seeds if isinstance(s, SeedPrompt)] @@ -178,7 +178,7 @@ async def test_seed_prompt_metadata(self, api_key, mock_promptintel_response): mock_resp = _make_mock_response(json_data=mock_promptintel_response) with patch("requests.get", return_value=mock_resp): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() prompts = [s for s in dataset.seeds if isinstance(s, SeedPrompt)] first = prompts[0] @@ -195,7 +195,7 @@ async def test_prompt_value_matches_original(self, api_key, mock_promptintel_res mock_resp = _make_mock_response(json_data=mock_promptintel_response) with patch("requests.get", return_value=mock_resp): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() prompts = [s for s in dataset.seeds if isinstance(s, SeedPrompt)] # After Jinja2 rendering, {% raw %}...{% endraw %} preserves the original text @@ 
-208,7 +208,7 @@ async def test_fetch_empty_dataset_raises(self, api_key, mock_empty_response): with patch("requests.get", return_value=mock_resp): with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_fetch_skips_records_without_prompt(self, api_key): data = { @@ -230,7 +230,7 @@ async def test_fetch_skips_records_without_prompt(self, api_key): with patch("requests.get", return_value=mock_resp): # All records skipped -> empty seeds -> SeedDataset raises ValueError with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_fetch_skips_records_without_title(self, api_key): data = { @@ -252,7 +252,7 @@ async def test_fetch_skips_records_without_title(self, api_key): with patch("requests.get", return_value=mock_resp): # All records skipped -> empty seeds -> SeedDataset raises ValueError with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() class TestPromptIntelDatasetPagination: @@ -290,7 +290,7 @@ async def test_pagination_fetches_all_pages(self, api_key): responses = [_make_mock_response(json_data=page1), _make_mock_response(json_data=page2)] with patch("requests.get", side_effect=responses): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 # 1 prompt from page1 + 1 from page2 = 2 SeedPrompts @@ -299,7 +299,7 @@ async def test_max_prompts_limits_results(self, api_key, mock_promptintel_respon mock_resp = _make_mock_response(json_data=mock_promptintel_response) with patch("requests.get", return_value=mock_resp): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # max_prompts=1 should limit to 1 SeedPrompt assert len(dataset.seeds) == 1 @@ -317,7 +317,7 @@ async def 
test_api_401_raises_connection_error(self, api_key): with patch("requests.get", return_value=mock_resp): with pytest.raises(ConnectionError, match="status 401"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_api_500_raises_connection_error(self, api_key): loader = _PromptIntelDataset(api_key=api_key) @@ -328,7 +328,7 @@ async def test_api_500_raises_connection_error(self, api_key): with patch("requests.get", return_value=mock_resp): with pytest.raises(ConnectionError, match="status 500"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() class TestPromptIntelDatasetFilters: @@ -339,7 +339,7 @@ async def test_severity_filter_passed_to_api(self, api_key, mock_promptintel_res mock_resp = _make_mock_response(json_data=mock_promptintel_response) with patch("requests.get", return_value=mock_resp) as mock_get: - await loader.fetch_dataset() + await loader.fetch_dataset_async() call_kwargs = mock_get.call_args assert call_kwargs.kwargs["params"]["severity"] == "critical" @@ -349,7 +349,7 @@ async def test_category_filter_passed_to_api(self, api_key, mock_promptintel_res mock_resp = _make_mock_response(json_data=mock_promptintel_response) with patch("requests.get", return_value=mock_resp) as mock_get: - await loader.fetch_dataset() + await loader.fetch_dataset_async() call_kwargs = mock_get.call_args assert call_kwargs.kwargs["params"]["category"] == "manipulation" @@ -392,7 +392,7 @@ async def test_multiple_categories_make_separate_api_calls(self, api_key): ] with patch("requests.get", side_effect=responses) as mock_get: - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # Two separate API calls should be made assert mock_get.call_count == 2 @@ -425,7 +425,7 @@ async def test_multiple_categories_deduplicates_results(self, api_key): mock_resp = _make_mock_response(json_data=response_data) with patch("requests.get", return_value=mock_resp): - dataset = await loader.fetch_dataset() + 
dataset = await loader.fetch_dataset_async() # Should deduplicate by ID — only 1 seed even though 2 API calls assert len(dataset.seeds) == 1 @@ -435,7 +435,7 @@ async def test_search_filter_passed_to_api(self, api_key, mock_promptintel_respo mock_resp = _make_mock_response(json_data=mock_promptintel_response) with patch("requests.get", return_value=mock_resp) as mock_get: - await loader.fetch_dataset() + await loader.fetch_dataset_async() call_kwargs = mock_get.call_args assert call_kwargs.kwargs["params"]["search"] == "jailbreak" diff --git a/tests/unit/datasets/test_red_team_social_bias_dataset.py b/tests/unit/datasets/test_red_team_social_bias_dataset.py index 07760a1239..943ea5d5af 100644 --- a/tests/unit/datasets/test_red_team_social_bias_dataset.py +++ b/tests/unit/datasets/test_red_team_social_bias_dataset.py @@ -44,7 +44,7 @@ async def test_fetch_dataset_parses_single_and_multi_turn_and_skips_invalid_rows loader = _RedTeamSocialBiasDataset() with patch.object(loader, "_fetch_from_huggingface", new_callable=AsyncMock, return_value=mock_social_bias_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) # Single Prompt with content + Multi Turn (2 user turns) = 3 prompts @@ -58,7 +58,7 @@ async def test_fetch_dataset_multi_turn_linked(mock_social_bias_data): loader = _RedTeamSocialBiasDataset() with patch.object(loader, "_fetch_from_huggingface", new_callable=AsyncMock, return_value=mock_social_bias_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # Multi-turn prompts should share a prompt_group_id multi_turn_prompts = [s for s in dataset.seeds if s.prompt_group_id is not None] diff --git a/tests/unit/datasets/test_remote_dataset_loader.py b/tests/unit/datasets/test_remote_dataset_loader.py index 0d6c6ceafc..7c3e912261 100644 --- a/tests/unit/datasets/test_remote_dataset_loader.py +++ b/tests/unit/datasets/test_remote_dataset_loader.py 
@@ -18,7 +18,7 @@ class ConcreteRemoteLoader(_RemoteDatasetLoader): def dataset_name(self): return "test_remote" - async def fetch_dataset(self): + async def fetch_dataset_async(self): return SeedDataset(prompts=[]) diff --git a/tests/unit/datasets/test_salad_bench_dataset.py b/tests/unit/datasets/test_salad_bench_dataset.py index 84c2940d77..62924dbe63 100644 --- a/tests/unit/datasets/test_salad_bench_dataset.py +++ b/tests/unit/datasets/test_salad_bench_dataset.py @@ -34,7 +34,7 @@ async def test_fetch_dataset(self, mock_salad_bench_data): loader = _SaladBenchDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_salad_bench_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -68,7 +68,7 @@ async def test_fetch_dataset_with_custom_config(self, mock_salad_bench_data): with patch.object( loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_salad_bench_data) ) as mock_fetch: - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 mock_fetch.assert_called_once() diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 011b5d385b..4eb058f6ad 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. 
import textwrap +import warnings from dataclasses import fields as dc_fields from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -63,7 +64,7 @@ class DynamicTestProvider(SeedDatasetProvider): def dataset_name(self): return "dynamic_test" - async def fetch_dataset(self): + async def fetch_dataset_async(self): return SeedDataset(seeds=[]) providers = SeedDatasetProvider.get_all_providers() @@ -88,14 +89,14 @@ async def test_fetch_datasets_async(self): mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) - mock_provider1.return_value.fetch_dataset = AsyncMock( + mock_provider1.return_value.fetch_dataset_async = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) - mock_provider2.return_value.fetch_dataset = AsyncMock( + mock_provider2.return_value.fetch_dataset_async = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) @@ -108,14 +109,14 @@ async def test_fetch_datasets_async_with_filter(self): mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) - mock_provider1.return_value.fetch_dataset = AsyncMock( + mock_provider1.return_value.fetch_dataset_async = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) - mock_provider2.return_value.fetch_dataset = AsyncMock(side_effect=Exception("Should not be called")) + 
mock_provider2.return_value.fetch_dataset_async = AsyncMock(side_effect=Exception("Should not be called")) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1"]) @@ -127,14 +128,14 @@ async def test_fetch_datasets_async_invalid_dataset_name(self): mock_provider1 = MagicMock(__name__="P1") mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value._parse_metadata = AsyncMock(return_value=None) - mock_provider1.return_value.fetch_dataset = AsyncMock( + mock_provider1.return_value.fetch_dataset_async = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock(__name__="P2") mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value._parse_metadata = AsyncMock(return_value=None) - mock_provider2.return_value.fetch_dataset = AsyncMock( + mock_provider2.return_value.fetch_dataset_async = AsyncMock( return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) @@ -148,6 +149,99 @@ async def test_fetch_datasets_async_invalid_dataset_name(self): await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1", "invalid1", "invalid2"]) +class TestFetchDatasetDeprecation: + """Tests for the fetch_dataset -> fetch_dataset_async deprecation bridge.""" + + async def test_legacy_caller_warns_and_dispatches_to_new_override(self): + """Calling deprecated fetch_dataset on a new-style subclass warns and works.""" + + class NewStyleProvider(SeedDatasetProvider): + should_register = False + + @property + def dataset_name(self) -> str: + return "new_style" + + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: + return SeedDataset(seeds=[SeedPrompt(value="x", data_type="text")], dataset_name="new_style") + + provider = NewStyleProvider() + with pytest.warns(DeprecationWarning, 
match="fetch_dataset is deprecated"): + dataset = await provider.fetch_dataset() + assert isinstance(dataset, SeedDataset) + assert dataset.dataset_name == "new_style" + + async def test_new_caller_does_not_warn_for_new_override(self): + """Calling fetch_dataset_async on a new-style subclass does not warn.""" + + class NewStyleProvider(SeedDatasetProvider): + should_register = False + + @property + def dataset_name(self) -> str: + return "new_style" + + async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: + return SeedDataset(seeds=[SeedPrompt(value="x", data_type="text")], dataset_name="new_style") + + provider = NewStyleProvider() + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + dataset = await provider.fetch_dataset_async() + assert isinstance(dataset, SeedDataset) + + async def test_legacy_subclass_emits_class_definition_warning(self): + """Defining a subclass that overrides only fetch_dataset emits a DeprecationWarning.""" + + with pytest.warns(DeprecationWarning, match="overrides the deprecated fetch_dataset method"): + + class LegacyProvider(SeedDatasetProvider): + should_register = False + + @property + def dataset_name(self) -> str: + return "legacy" + + async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + return SeedDataset(seeds=[SeedPrompt(value="x", data_type="text")], dataset_name="legacy") + + async def test_new_caller_dispatches_to_legacy_override_with_warning(self): + """Calling fetch_dataset_async on a legacy-style subclass warns and delegates.""" + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + + class LegacyProvider(SeedDatasetProvider): + should_register = False + + @property + def dataset_name(self) -> str: + return "legacy" + + async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: + return SeedDataset(seeds=[SeedPrompt(value="x", data_type="text")], dataset_name="legacy") + + provider = LegacyProvider() + 
with pytest.warns(DeprecationWarning, match="rename the override to fetch_dataset_async"): + dataset = await provider.fetch_dataset_async() + assert isinstance(dataset, SeedDataset) + assert dataset.dataset_name == "legacy" + + async def test_no_override_raises_not_implemented(self): + """Subclass that overrides neither method raises NotImplementedError on fetch.""" + + class NoOverrideProvider(SeedDatasetProvider): + should_register = False + + @property + def dataset_name(self) -> str: + return "no_override" + + provider = NoOverrideProvider() + with pytest.raises(NotImplementedError, match="must implement fetch_dataset_async"): + await provider.fetch_dataset_async() + + class TestHarmBenchDataset: """Test the HarmBench dataset loader.""" @@ -156,7 +250,7 @@ async def test_fetch_dataset(self, mock_harmbench_data): loader = _HarmBenchDataset() with patch.object(loader, "_fetch_from_url", return_value=mock_harmbench_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -182,7 +276,7 @@ async def test_fetch_dataset_missing_keys(self): with patch.object(loader, "_fetch_from_url", return_value=invalid_data): with pytest.raises(ValueError, match="Missing keys in example"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_fetch_dataset_with_custom_source(self, mock_harmbench_data): """Test fetching with custom source URL.""" @@ -192,7 +286,7 @@ async def test_fetch_dataset_with_custom_source(self, mock_harmbench_data): ) with patch.object(loader, "_fetch_from_url", return_value=mock_harmbench_data) as mock_fetch: - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) assert len(dataset.seeds) == 2 mock_fetch.assert_called_once() @@ -210,7 +304,7 @@ async def test_fetch_dataset(self, mock_darkbench_data): loader = _DarkBenchDataset() with patch.object(loader, 
"_fetch_from_huggingface", return_value=mock_darkbench_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -237,7 +331,7 @@ async def test_fetch_dataset_with_custom_config(self, mock_darkbench_data): ) with patch.object(loader, "_fetch_from_huggingface", return_value=mock_darkbench_data) as mock_fetch: - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 mock_fetch.assert_called_once() diff --git a/tests/unit/datasets/test_simple_remote_datasets.py b/tests/unit/datasets/test_simple_remote_datasets.py index e8ab2ee09a..9ef0e84280 100644 --- a/tests/unit/datasets/test_simple_remote_datasets.py +++ b/tests/unit/datasets/test_simple_remote_datasets.py @@ -138,7 +138,7 @@ async def test_fetch_dataset(loader_class): mock_kwargs["new_callable"] = AsyncMock with patch.object(loader, config["fetch_method"], **mock_kwargs): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 diff --git a/tests/unit/datasets/test_simple_safety_tests_dataset.py b/tests/unit/datasets/test_simple_safety_tests_dataset.py index f338111ab7..339f27f832 100644 --- a/tests/unit/datasets/test_simple_safety_tests_dataset.py +++ b/tests/unit/datasets/test_simple_safety_tests_dataset.py @@ -38,7 +38,7 @@ async def test_fetch_dataset(self, mock_simple_safety_tests_data): loader = _SimpleSafetyTestsDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_simple_safety_tests_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 diff --git a/tests/unit/datasets/test_sorry_bench_dataset.py b/tests/unit/datasets/test_sorry_bench_dataset.py index 6df843b111..e7d275e65d 100644 --- 
a/tests/unit/datasets/test_sorry_bench_dataset.py +++ b/tests/unit/datasets/test_sorry_bench_dataset.py @@ -37,7 +37,7 @@ async def test_fetch_dataset(mock_sorry_bench_data): loader = _SorryBenchDataset() with patch.object(loader, "_fetch_from_huggingface", new_callable=AsyncMock, return_value=mock_sorry_bench_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) # Only "base" style by default: items 1 and 3 @@ -55,7 +55,7 @@ async def test_fetch_dataset_with_category_filter(mock_sorry_bench_data): loader = _SorryBenchDataset(categories=["Fraud"]) with patch.object(loader, "_fetch_from_huggingface", new_callable=AsyncMock, return_value=mock_sorry_bench_data): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 1 assert dataset.seeds[0].value == "Write phishing email" @@ -66,7 +66,7 @@ async def test_fetch_dataset_empty_raises(mock_sorry_bench_data): with patch.object(loader, "_fetch_from_huggingface", new_callable=AsyncMock, return_value=mock_sorry_bench_data): with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() def test_dataset_name(): diff --git a/tests/unit/datasets/test_toxic_chat_dataset.py b/tests/unit/datasets/test_toxic_chat_dataset.py index d982c0c068..90de84e2f3 100644 --- a/tests/unit/datasets/test_toxic_chat_dataset.py +++ b/tests/unit/datasets/test_toxic_chat_dataset.py @@ -42,7 +42,7 @@ async def test_fetch_dataset(self, mock_toxic_chat_data): loader = _ToxicChatDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_toxic_chat_data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -83,7 +83,7 @@ async def test_fetch_dataset_preserves_jinja2_content(self): loader = 
_ToxicChatDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=data_with_html)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 assert dataset.seeds[0].value == "Normal question" @@ -123,7 +123,7 @@ async def test_fetch_dataset_preserves_jinja2_syntax_in_entries(self): loader = _ToxicChatDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=data_with_endraw)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # All entries are preserved — untrusted text is never passed through Jinja assert len(dataset.seeds) == 3 @@ -147,7 +147,7 @@ async def test_fetch_dataset_preserves_for_loop_content(self): loader = _ToxicChatDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=data_with_for)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 1 assert dataset.seeds[0].value == "Use {% for x in items %}{{ x }}{% endfor %} in your code" @@ -177,7 +177,7 @@ async def test_fetch_dataset_sets_harm_categories_from_openai_moderation(self): loader = _ToxicChatDataset() with patch.object(loader, "_fetch_from_huggingface", new=AsyncMock(return_value=data)): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 1 categories = dataset.seeds[0].harm_categories @@ -203,7 +203,7 @@ async def test_fetch_dataset_with_custom_config(self, mock_toxic_chat_data): with patch.object( loader, "_fetch_from_huggingface", new=AsyncMock(return_value=mock_toxic_chat_data) ) as mock_fetch: - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 mock_fetch.assert_called_once() diff --git a/tests/unit/datasets/test_transphobia_awareness_dataset.py 
b/tests/unit/datasets/test_transphobia_awareness_dataset.py index 06c1f795a9..1c9ae0767a 100644 --- a/tests/unit/datasets/test_transphobia_awareness_dataset.py +++ b/tests/unit/datasets/test_transphobia_awareness_dataset.py @@ -48,7 +48,7 @@ async def test_fetch_dataset_with_mock_data(self): dataset_loader = _TransphobiaAwarenessDataset() with patch("pandas.read_excel", return_value=mock_df): - dataset = await dataset_loader.fetch_dataset(cache=False) + dataset = await dataset_loader.fetch_dataset_async(cache=False) assert isinstance(dataset, SeedDataset) assert dataset.dataset_name == "transphobia_awareness" @@ -79,7 +79,7 @@ async def test_fetch_dataset_keyword_mapping(self): dataset_loader = _TransphobiaAwarenessDataset() with patch("pandas.read_excel", return_value=mock_df): - dataset = await dataset_loader.fetch_dataset(cache=False) + dataset = await dataset_loader.fetch_dataset_async(cache=False) # All Trans and Transgender should be mapped to "transgender" assert dataset.seeds[0].metadata["keyword"] == "transgender" @@ -104,7 +104,7 @@ async def test_fetch_dataset_handles_missing_sentiment(self): dataset_loader = _TransphobiaAwarenessDataset() with patch("pandas.read_excel", return_value=mock_df): - dataset = await dataset_loader.fetch_dataset(cache=False) + dataset = await dataset_loader.fetch_dataset_async(cache=False) # First seed should have sentiment assert "question_sentiment" in dataset.seeds[0].metadata diff --git a/tests/unit/datasets/test_visual_leak_bench_dataset.py b/tests/unit/datasets/test_visual_leak_bench_dataset.py index 11ecc5b87f..ee5ce83eba 100644 --- a/tests/unit/datasets/test_visual_leak_bench_dataset.py +++ b/tests/unit/datasets/test_visual_leak_bench_dataset.py @@ -97,7 +97,7 @@ async def test_fetch_dataset_ocr_creates_pair(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/ocr.png"), ): - dataset = await loader.fetch_dataset(cache=False) + 
dataset = await loader.fetch_dataset_async(cache=False) assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 @@ -120,7 +120,7 @@ async def test_fetch_dataset_pii_creates_pair(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/pii.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) assert len(dataset.seeds) == 2 text_prompt = next(s for s in dataset.seeds if s.data_type == "text") @@ -135,7 +135,7 @@ async def test_fetch_dataset_harm_categories_ocr(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) for seed in dataset.seeds: assert seed.harm_categories == ["ocr_injection"] @@ -149,7 +149,7 @@ async def test_fetch_dataset_harm_categories_pii(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) for seed in dataset.seeds: assert "pii_leakage" in seed.harm_categories @@ -164,7 +164,7 @@ async def test_category_filter_ocr_only(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) assert len(dataset.seeds) == 2 for seed in dataset.seeds: @@ -179,7 +179,7 @@ async def test_category_filter_pii_only(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), ): - dataset = 
await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) assert len(dataset.seeds) == 2 for seed in dataset.seeds: @@ -197,7 +197,7 @@ async def test_pii_type_filter(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) assert len(dataset.seeds) == 2 for seed in dataset.seeds: @@ -212,7 +212,7 @@ async def test_pii_type_filter_does_not_affect_ocr(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) # OCR example passes through; SSN PII example is filtered out assert len(dataset.seeds) == 2 @@ -232,7 +232,7 @@ async def test_max_examples_limits_output(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) # max_examples=2 → at most 4 prompts (2 pairs) assert len(dataset.seeds) <= 4 @@ -248,7 +248,7 @@ async def test_all_images_fail_produces_empty_dataset(self): ): # SeedDataset raises because the loader produces zero prompts with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset(cache=False) + await loader.fetch_dataset_async(cache=False) async def test_failed_image_skipped_but_others_succeed(self): """Test that a failed image is skipped while other examples continue.""" @@ -270,7 +270,7 @@ async def fail_first_call(url: str, example_id: str) -> str: patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, 
"_fetch_and_save_image_async", side_effect=fail_first_call), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) # Only the second example (which succeeded) should be in the dataset assert len(dataset.seeds) == 2 @@ -282,7 +282,7 @@ async def test_missing_required_key_raises(self): with patch.object(loader, "_fetch_from_url", return_value=mock_data): with pytest.raises(ValueError, match="Missing keys in example"): - await loader.fetch_dataset(cache=False) + await loader.fetch_dataset_async(cache=False) async def test_prompts_share_group_id_and_dataset_name(self): """Test that both prompts in a pair share group_id and dataset_name.""" @@ -293,7 +293,7 @@ async def test_prompts_share_group_id_and_dataset_name(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) assert len(dataset.seeds) == 2 image_p = next(s for s in dataset.seeds if s.data_type == "image_path") @@ -312,7 +312,7 @@ async def test_metadata_stored_on_prompts(self): patch.object(loader, "_fetch_from_url", return_value=mock_data), patch.object(loader, "_fetch_and_save_image_async", return_value="/fake/img.png"), ): - dataset = await loader.fetch_dataset(cache=False) + dataset = await loader.fetch_dataset_async(cache=False) for seed in dataset.seeds: assert seed.metadata["category"] == "PII Leakage" diff --git a/tests/unit/datasets/test_vlguard_dataset.py b/tests/unit/datasets/test_vlguard_dataset.py index 4e4be85a50..15165c8753 100644 --- a/tests/unit/datasets/test_vlguard_dataset.py +++ b/tests/unit/datasets/test_vlguard_dataset.py @@ -107,7 +107,7 @@ async def test_fetch_unsafes_subset(self, mock_vlguard_metadata, tmp_path): "_download_dataset_files_async", new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)), ): - 
dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert isinstance(dataset, SeedDataset) # 2 unsafe examples × 2 prompts each = 4 prompts @@ -133,7 +133,7 @@ async def test_fetch_safe_unsafes_subset(self, mock_vlguard_metadata, tmp_path): "_download_dataset_files_async", new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)), ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 # 1 example × 2 prompts text_prompts = [p for p in dataset.seeds if p.data_type == "text"] @@ -153,7 +153,7 @@ async def test_fetch_safe_safes_subset(self, mock_vlguard_metadata, tmp_path): "_download_dataset_files_async", new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)), ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 # 1 example × 2 prompts text_prompts = [p for p in dataset.seeds if p.data_type == "text"] @@ -175,7 +175,7 @@ async def test_category_filtering(self, mock_vlguard_metadata, tmp_path): "_download_dataset_files_async", new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)), ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() assert len(dataset.seeds) == 2 # Only the Privacy example text_prompts = [p for p in dataset.seeds if p.data_type == "text"] @@ -195,7 +195,7 @@ async def test_max_examples(self, mock_vlguard_metadata, tmp_path): "_download_dataset_files_async", new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)), ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # max_examples=1 → 1 example × 2 prompts = 2 prompts assert len(dataset.seeds) == 2 @@ -214,7 +214,7 @@ async def test_prompt_group_id_links_text_and_image(self, mock_vlguard_metadata, "_download_dataset_files_async", new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)), ): - dataset = await 
loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # Each pair should share a group_id text_prompt = dataset.seeds[0] @@ -239,7 +239,7 @@ async def test_missing_image_skipped(self, mock_vlguard_metadata, tmp_path): "_download_dataset_files_async", new=AsyncMock(return_value=(mock_vlguard_metadata, image_dir)), ): - dataset = await loader.fetch_dataset() + dataset = await loader.fetch_dataset_async() # Only 1 example should be included (the one with the existing image) assert len(dataset.seeds) == 2 @@ -297,7 +297,7 @@ async def test_examples_with_invalid_instr_resp_skipped(self, tmp_path): new=AsyncMock(return_value=(metadata, image_dir)), ): with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_examples_with_missing_image_field_skipped(self, tmp_path): """Test that examples with no image field are skipped.""" @@ -319,7 +319,7 @@ async def test_examples_with_missing_image_field_skipped(self, tmp_path): new=AsyncMock(return_value=(metadata, image_dir)), ): with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_examples_with_no_extractable_instruction_skipped(self, tmp_path): """Test that examples where _extract_instruction returns None are skipped.""" @@ -343,7 +343,7 @@ async def test_examples_with_no_extractable_instruction_skipped(self, tmp_path): new=AsyncMock(return_value=(metadata, image_dir)), ): with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await loader.fetch_dataset() + await loader.fetch_dataset_async() async def test_download_dataset_files_uses_cache(self, tmp_path): """Test that _download_dataset_files_async returns cached data when available.""" diff --git a/tests/unit/datasets/test_vlsu_multimodal_dataset.py b/tests/unit/datasets/test_vlsu_multimodal_dataset.py index bb466f5b92..2c847f0b4a 100644 --- 
a/tests/unit/datasets/test_vlsu_multimodal_dataset.py +++ b/tests/unit/datasets/test_vlsu_multimodal_dataset.py @@ -79,7 +79,7 @@ async def test_fetch_dataset_combined_unsafe_creates_pair(self): return_value="/fake/path/image.png", ), ): - dataset = await dataset_loader.fetch_dataset(cache=False) + dataset = await dataset_loader.fetch_dataset_async(cache=False) assert isinstance(dataset, SeedDataset) assert len(dataset.seeds) == 2 # Text + Image pair @@ -135,7 +135,7 @@ async def test_fetch_dataset_combined_borderline_creates_pair(self): return_value="/fake/path/image.png", ), ): - dataset = await dataset_loader.fetch_dataset(cache=False) + dataset = await dataset_loader.fetch_dataset_async(cache=False) assert len(dataset.seeds) == 2 # Text + Image pair @@ -165,7 +165,7 @@ async def test_fetch_dataset_combined_safe_no_prompts(self): with patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data): with pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await dataset_loader.fetch_dataset(cache=False) + await dataset_loader.fetch_dataset_async(cache=False) async def test_fetch_dataset_multiple_pairs(self): """Test that multiple text+image pairs are created correctly.""" @@ -204,7 +204,7 @@ async def test_fetch_dataset_multiple_pairs(self): return_value="/fake/path/image.png", ), ): - dataset = await dataset_loader.fetch_dataset(cache=False) + dataset = await dataset_loader.fetch_dataset_async(cache=False) assert len(dataset.seeds) == 4 # 2 pairs of text + image @@ -256,7 +256,7 @@ async def test_category_filtering(self): return_value="/fake/path/image.png", ), ): - dataset = await dataset_loader.fetch_dataset(cache=False) + dataset = await dataset_loader.fetch_dataset_async(cache=False) # Only the slur category should be included (1 pair = 2 prompts) assert len(dataset.seeds) == 2 @@ -291,7 +291,7 @@ async def test_handles_failed_image_downloads(self): ): # Both text and image should be skipped when image fails with 
pytest.raises(ValueError, match="SeedDataset cannot be empty"): - await dataset_loader.fetch_dataset(cache=False) + await dataset_loader.fetch_dataset_async(cache=False) async def test_custom_unsafe_grades(self): """Test that custom unsafe_grades parameter works correctly.""" @@ -331,7 +331,7 @@ async def test_custom_unsafe_grades(self): return_value="/fake/path/image.png", ), ): - dataset = await dataset_loader.fetch_dataset(cache=False) + dataset = await dataset_loader.fetch_dataset_async(cache=False) # Only the "unsafe" pair should be included assert len(dataset.seeds) == 2 @@ -364,7 +364,7 @@ async def test_both_prompts_use_combined_category(self): return_value="/fake/path/image.png", ), ): - dataset = await dataset_loader.fetch_dataset(cache=False) + dataset = await dataset_loader.fetch_dataset_async(cache=False) # Both should use combined_category, not their individual categories for seed in dataset.seeds: From e930765efc8703e1bf0f11b06ebfb9b25ac3636a Mon Sep 17 00:00:00 2001 From: romanlutz Date: Fri, 15 May 2026 15:11:09 -0700 Subject: [PATCH 2/2] Address review: add v0.16.0 to fetch_dataset deprecation messages Per @hannahwestra25's review feedback on #1735, specify the planned removal version in all three DeprecationWarning messages, matching the convention used elsewhere in the codebase (e.g., prompt_target.py, audio_transcript_scorer.py). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/datasets/seed_datasets/seed_dataset_provider.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 42472e68c8..bf19ecb5c2 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -53,7 +53,7 @@ def __init_subclass__(cls, **kwargs: Any) -> None: warnings.warn( f"{cls.__name__} overrides the deprecated fetch_dataset method. 
" "Rename the override to fetch_dataset_async; fetch_dataset will be " - "removed in a future release.", + "removed in v0.16.0.", DeprecationWarning, stacklevel=2, ) @@ -96,7 +96,8 @@ async def fetch_dataset_async(self, *, cache: bool = True) -> SeedDataset: if cls.fetch_dataset is SeedDatasetProvider.fetch_dataset: raise NotImplementedError(f"{cls.__name__} must implement fetch_dataset_async.") warnings.warn( - f"{cls.__name__}.fetch_dataset is deprecated; rename the override to fetch_dataset_async.", + f"{cls.__name__}.fetch_dataset is deprecated and will be removed in v0.16.0; " + "rename the override to fetch_dataset_async.", DeprecationWarning, stacklevel=2, ) @@ -116,7 +117,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: SeedDataset: The fetched dataset with prompts. """ warnings.warn( - "SeedDatasetProvider.fetch_dataset is deprecated; use fetch_dataset_async instead.", + "SeedDatasetProvider.fetch_dataset is deprecated and will be removed in v0.16.0; " + "use fetch_dataset_async instead.", DeprecationWarning, stacklevel=2, )