From 70c2b1597f6f635c9a071cdab464e7470ca9acd8 Mon Sep 17 00:00:00 2001 From: pragnyanramtha Date: Sun, 17 May 2026 02:13:23 +0000 Subject: [PATCH] Preserve RNG state in dataset iteration --- src/art/utils/iterate_dataset.py | 3 +-- tests/unit/test_iterate_dataset.py | 36 ++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_iterate_dataset.py diff --git a/src/art/utils/iterate_dataset.py b/src/art/utils/iterate_dataset.py index 6b546e79e..179318027 100644 --- a/src/art/utils/iterate_dataset.py +++ b/src/art/utils/iterate_dataset.py @@ -62,8 +62,7 @@ def iterate_dataset( for epoch in range(num_epochs): # Create indices and shuffle deterministically based on epoch indices = list(range(dataset_size)) - random.seed(epoch) # Ensure shuffling is the same for a given epoch - random.shuffle(indices) + random.Random(epoch).shuffle(indices) for i in range(0, dataset_size, groups_per_step): epoch_step = i // groups_per_step diff --git a/tests/unit/test_iterate_dataset.py b/tests/unit/test_iterate_dataset.py new file mode 100644 index 000000000..1909fcf21 --- /dev/null +++ b/tests/unit/test_iterate_dataset.py @@ -0,0 +1,36 @@ +import random + +from art.utils.iterate_dataset import iterate_dataset + + +def test_iterate_dataset_is_deterministic_across_runs() -> None: + dataset = list(range(10)) + + first = [ + batch.items + for batch in iterate_dataset( + dataset, groups_per_step=3, num_epochs=2, use_tqdm=False + ) + ] + second = [ + batch.items + for batch in iterate_dataset( + dataset, groups_per_step=3, num_epochs=2, use_tqdm=False + ) + ] + + assert first == second + + +def test_iterate_dataset_does_not_reset_global_random_state() -> None: + dataset = list(range(10)) + + random.seed(12345) + expected = [random.random() for _ in range(3)] + + random.seed(12345) + iterator = iterate_dataset(dataset, groups_per_step=2, num_epochs=2, use_tqdm=False) + next(iterator) + after_iteration = [random.random() for _ in range(3)] + + assert after_iteration == expected