diff --git a/src/art/utils/iterate_dataset.py b/src/art/utils/iterate_dataset.py
index 6b546e79e..179318027 100644
--- a/src/art/utils/iterate_dataset.py
+++ b/src/art/utils/iterate_dataset.py
@@ -62,8 +62,7 @@ def iterate_dataset(
     for epoch in range(num_epochs):
         # Create indices and shuffle deterministically based on epoch
         indices = list(range(dataset_size))
-        random.seed(epoch)  # Ensure shuffling is the same for a given epoch
-        random.shuffle(indices)
+        random.Random(epoch).shuffle(indices)
 
         for i in range(0, dataset_size, groups_per_step):
             epoch_step = i // groups_per_step
diff --git a/tests/unit/test_iterate_dataset.py b/tests/unit/test_iterate_dataset.py
new file mode 100644
index 000000000..1909fcf21
--- /dev/null
+++ b/tests/unit/test_iterate_dataset.py
@@ -0,0 +1,45 @@
+import random
+
+from art.utils.iterate_dataset import iterate_dataset
+
+
+def test_iterate_dataset_is_deterministic_across_runs() -> None:
+    """Two identical calls must yield the same item batches in the same order."""
+    dataset = list(range(10))
+
+    first = [
+        batch.items
+        for batch in iterate_dataset(
+            dataset, groups_per_step=3, num_epochs=2, use_tqdm=False
+        )
+    ]
+    second = [
+        batch.items
+        for batch in iterate_dataset(
+            dataset, groups_per_step=3, num_epochs=2, use_tqdm=False
+        )
+    ]
+
+    assert first == second
+
+
+def test_iterate_dataset_does_not_reset_global_random_state() -> None:
+    """Advancing the iterator must not reseed or consume the global ``random`` RNG."""
+    dataset = list(range(10))
+
+    # Snapshot the global RNG so this test does not leak its seed into other tests.
+    saved_state = random.getstate()
+    try:
+        random.seed(12345)
+        expected = [random.random() for _ in range(3)]
+
+        random.seed(12345)
+        iterator = iterate_dataset(
+            dataset, groups_per_step=2, num_epochs=2, use_tqdm=False
+        )
+        next(iterator)
+        after_iteration = [random.random() for _ in range(3)]
+    finally:
+        random.setstate(saved_state)
+
+    assert after_iteration == expected