From aadf2f7d2b5e125e6365207347a66e41cbeba523 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 11 Jun 2026 14:02:26 +0200 Subject: [PATCH 1/2] Support GCS dataset sources --- changelog.d/1776.added.md | 1 + docs/book/usage/simulations.md | 13 ++ policyengine_uk/data/dataset_sources.py | 134 +++++++++++++ policyengine_uk/simulation.py | 10 +- policyengine_uk/tests/test_dataset_sources.py | 189 ++++++++++++++++++ .../tests/test_simulation_dataset_sources.py | 33 ++- 6 files changed, 375 insertions(+), 5 deletions(-) create mode 100644 changelog.d/1776.added.md create mode 100644 policyengine_uk/data/dataset_sources.py create mode 100644 policyengine_uk/tests/test_dataset_sources.py diff --git a/changelog.d/1776.added.md b/changelog.d/1776.added.md new file mode 100644 index 000000000..c9087b2a3 --- /dev/null +++ b/changelog.d/1776.added.md @@ -0,0 +1 @@ +- Added direct `gs://` dataset loading for UK simulations, including support for GCS generations and PolicyEngine data-version metadata. diff --git a/docs/book/usage/simulations.md b/docs/book/usage/simulations.md index 3bbc2741b..38332414d 100644 --- a/docs/book/usage/simulations.md +++ b/docs/book/usage/simulations.md @@ -278,6 +278,19 @@ sim = Simulation(dataset=dataset) print(sim.calculate("household_net_income", 2026)) ``` +`Simulation` and `Microsimulation` can also load H5 files from local paths, +Hugging Face URLs, or Google Cloud Storage URLs: + +```python +sim = Microsimulation( + dataset="gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10" +) +``` + +For `gs://` URLs, a numeric suffix after `@` pins an exact GCS generation. A +non-numeric suffix pins the PolicyEngine data version stored in the object's +GCS metadata. + ### From survey datasets For population-level analysis, use survey data: diff --git a/policyengine_uk/data/dataset_sources.py b/policyengine_uk/data/dataset_sources.py new file mode 100644 index 000000000..2c1009a5f --- /dev/null +++ b/policyengine_uk/data/dataset_sources.py @@ -0,0 +1,134 @@ +import hashlib +import os +import tempfile +from pathlib import Path +from typing import Optional, Union + +from policyengine_core.tools.google_cloud import parse_gs_url + + +def materialize_gcs_dataset_url( + dataset_url: str, + *, + cache_dir: Optional[Union[str, os.PathLike]] = None, +) -> str: + """Download a GCS dataset URL to a local H5 path and return that path.""" + bucket_name, file_path, revision = parse_gs_url(dataset_url) + storage_client = _get_storage_client() + blob = _resolve_gcs_blob(storage_client, bucket_name, file_path, revision) + generation = _blob_generation(blob) + + local_path = _cached_dataset_path( + bucket_name=bucket_name, + file_path=file_path, + generation=generation, + cache_dir=cache_dir, + ) + if not local_path.exists(): + _download_blob(blob, local_path) + return str(local_path) + + +def _get_storage_client(): + try: + import google.auth + from google.auth import exceptions as auth_exceptions + from google.cloud import storage + except ImportError as exc: + raise ImportError( + "google-cloud-storage is required for gs:// dataset URLs. " + "Install it with: pip install google-cloud-storage" + ) from exc + + try: + credentials, project_id = google.auth.default() + except auth_exceptions.DefaultCredentialsError as exc: + raise RuntimeError( + "Google Cloud credentials are required for gs:// dataset URLs. " + "Set application default credentials or GOOGLE_APPLICATION_CREDENTIALS." + ) from exc + + return storage.Client(credentials=credentials, project=project_id) + + +def _resolve_gcs_blob( + storage_client, + bucket_name: str, + file_path: str, + revision: Optional[str], +): + bucket = storage_client.bucket(bucket_name) + + if revision is not None and revision.isdigit(): + blob = bucket.blob(file_path, generation=int(revision)) + blob.reload() + return blob + + current_blob = bucket.blob(file_path) + current_blob.reload() + if revision is None or _blob_metadata_version(current_blob) == revision: + return current_blob + + matching_blobs = [] + for blob in storage_client.list_blobs( + bucket_name, + prefix=file_path, + versions=True, + ): + if blob.name != file_path: + continue + if _blob_metadata_version(blob) == revision: + matching_blobs.append(blob) + + if not matching_blobs: + raise ValueError( + f"No GCS object version for gs://{bucket_name}/{file_path} has " + f"metadata version {revision!r}." + ) + + return max(matching_blobs, key=lambda blob: int(_blob_generation(blob))) + + +def _blob_metadata_version(blob) -> Optional[str]: + if getattr(blob, "metadata", None) is None: + blob.reload() + metadata = getattr(blob, "metadata", None) or {} + return metadata.get("version") + + +def _blob_generation(blob) -> str: + generation = getattr(blob, "generation", None) + if generation is None: + blob.reload() + generation = getattr(blob, "generation", None) + if generation is None: + raise ValueError(f"GCS object {blob.name!r} does not expose a generation.") + return str(generation) + + +def _cached_dataset_path( + *, + bucket_name: str, + file_path: str, + generation: str, + cache_dir: Optional[Union[str, os.PathLike]], +) -> Path: + if cache_dir is None: + cache_dir = Path(tempfile.gettempdir()) / "policyengine-uk-datasets" + else: + cache_dir = Path(cache_dir) + + cache_key = hashlib.sha256( + f"{bucket_name}\0{file_path}\0{generation}".encode() + ).hexdigest() + return cache_dir / cache_key / Path(file_path).name + + +def _download_blob(blob, local_path: Path) -> None: + local_path.parent.mkdir(parents=True, exist_ok=True) + temporary_path = local_path.with_name(f"{local_path.name}.tmp") + try: + blob.download_to_filename(str(temporary_path)) + os.replace(temporary_path, local_path) + finally: + temporary_path.unlink(missing_ok=True) diff --git a/policyengine_uk/simulation.py b/policyengine_uk/simulation.py index f68096e80..4179d1439 100644 --- a/policyengine_uk/simulation.py +++ b/policyengine_uk/simulation.py @@ -26,6 +26,7 @@ extend_single_year_dataset, reset_growthfactor_uprating, ) +from policyengine_uk.data.dataset_sources import materialize_gcs_dataset_url from policyengine_uk.utils.dependencies import get_variable_dependencies from policyengine_uk.reforms import create_structural_reforms_from_parameters from policyengine_uk.parameters.gov.simulation.labour_supply_responses.aliases import ( @@ -274,11 +275,14 @@ def build_from_dataset_source( if dataset_source.startswith("hf://"): self.build_from_url(dataset_source) return + if dataset_source.startswith("gs://"): + dataset_file = materialize_gcs_dataset_url(dataset_source) + self.build_from_file(dataset_file, cache_key=dataset_source) + return if "://" in dataset_source: raise ValueError( - "Only HuggingFace dataset URLs are supported directly by " - "policyengine-uk. Download or materialize other dataset " - "sources to a local file path before passing them to Simulation." + "Only HuggingFace, Google Cloud Storage, and local dataset " + "sources are supported by policyengine-uk." ) self.build_from_file(dataset_source) diff --git a/policyengine_uk/tests/test_dataset_sources.py b/policyengine_uk/tests/test_dataset_sources.py new file mode 100644 index 000000000..6e5206267 --- /dev/null +++ b/policyengine_uk/tests/test_dataset_sources.py @@ -0,0 +1,189 @@ +from pathlib import Path + +import pytest + +from policyengine_uk.data import dataset_sources +from policyengine_uk.data.dataset_sources import materialize_gcs_dataset_url + + +class FakeBlob: + def __init__(self, name, generation, metadata=None, contents=b"dataset"): + self.name = name + self.generation = str(generation) + self.metadata = metadata + self.contents = contents + self.download_count = 0 + self.reload_count = 0 + + def reload(self): + self.reload_count += 1 + + def download_to_filename(self, filename): + self.download_count += 1 + Path(filename).write_bytes(self.contents) + + +class FakeBucket: + def __init__(self, blobs, current_generations): + self.blobs = blobs + self.current_generations = current_generations + + def blob(self, file_path, generation=None): + if generation is None: + generation = self.current_generations[file_path] + return self.blobs[(file_path, str(generation))] + + +class FakeStorageClient: + def __init__(self, blobs, current_generations): + self.blobs = blobs + self.current_generations = current_generations + self.list_calls = [] + + def bucket(self, bucket_name): + return FakeBucket( + self.blobs[bucket_name], self.current_generations[bucket_name] + ) + + def list_blobs(self, bucket_name, prefix, versions): + self.list_calls.append((bucket_name, prefix, versions)) + return [ + blob + for (name, _generation), blob in self.blobs[bucket_name].items() + if name.startswith(prefix) + ] + + +def fake_client(*blobs, current_generation): + file_path = blobs[0].name + return FakeStorageClient( + blobs={"bucket": {(blob.name, blob.generation): blob for blob in blobs}}, + current_generations={"bucket": {file_path: str(current_generation)}}, + ) + + +def test_materialize_gcs_dataset_url_uses_numeric_generation(monkeypatch, tmp_path): + old_blob = FakeBlob( + "data/file.h5", + 123, + metadata={"version": "1.55.10"}, + contents=b"old", + ) + current_blob = FakeBlob( + "data/file.h5", + 456, + metadata={"version": "1.56.0"}, + contents=b"current", + ) + client = fake_client(old_blob, current_blob, current_generation=456) + monkeypatch.setattr(dataset_sources, "_get_storage_client", lambda: client) + + path = materialize_gcs_dataset_url( + "gs://bucket/data/file.h5@123", + cache_dir=tmp_path, + ) + + assert Path(path).read_bytes() == b"old" + assert old_blob.download_count == 1 + assert client.list_calls == [] + + +def test_materialize_gcs_dataset_url_uses_latest_matching_metadata_version( + monkeypatch, + tmp_path, +): + old_matching_blob = FakeBlob( + "data/file.h5", + 111, + metadata={"version": "1.55.10"}, + contents=b"old match", + ) + new_matching_blob = FakeBlob( + "data/file.h5", + 333, + metadata={"version": "1.55.10"}, + contents=b"new match", + ) + current_blob = FakeBlob( + "data/file.h5", + 444, + metadata={"version": "1.56.0"}, + contents=b"current", + ) + client = fake_client( + old_matching_blob, + new_matching_blob, + current_blob, + current_generation=444, + ) + monkeypatch.setattr(dataset_sources, "_get_storage_client", lambda: client) + + path = materialize_gcs_dataset_url( + "gs://bucket/data/file.h5@1.55.10", + cache_dir=tmp_path, + ) + + assert Path(path).read_bytes() == b"new match" + assert new_matching_blob.download_count == 1 + assert client.list_calls == [("bucket", "data/file.h5", True)] + + +def test_materialize_gcs_dataset_url_uses_current_blob_when_metadata_matches( + monkeypatch, + tmp_path, +): + current_blob = FakeBlob( + "data/file.h5", + 444, + metadata={"version": "1.55.10"}, + contents=b"current", + ) + client = fake_client(current_blob, current_generation=444) + monkeypatch.setattr(dataset_sources, "_get_storage_client", lambda: client) + + path = materialize_gcs_dataset_url( + "gs://bucket/data/file.h5@1.55.10", + cache_dir=tmp_path, + ) + + assert Path(path).read_bytes() == b"current" + assert current_blob.download_count == 1 + assert client.list_calls == [] + + +def test_materialize_gcs_dataset_url_errors_for_missing_metadata_version( + monkeypatch, + tmp_path, +): + current_blob = FakeBlob( + "data/file.h5", + 444, + metadata={"version": "1.56.0"}, + ) + client = fake_client(current_blob, current_generation=444) + monkeypatch.setattr(dataset_sources, "_get_storage_client", lambda: client) + + with pytest.raises(ValueError, match="metadata version '1.55.10'"): + materialize_gcs_dataset_url( + "gs://bucket/data/file.h5@1.55.10", + cache_dir=tmp_path, + ) + + +def test_materialize_gcs_dataset_url_reuses_cached_file(monkeypatch, tmp_path): + current_blob = FakeBlob( + "data/file.h5", + 444, + metadata={"version": "1.55.10"}, + contents=b"current", + ) + client = fake_client(current_blob, current_generation=444) + monkeypatch.setattr(dataset_sources, "_get_storage_client", lambda: client) + + dataset_url = "gs://bucket/data/file.h5@1.55.10" + first_path = materialize_gcs_dataset_url(dataset_url, cache_dir=tmp_path) + second_path = materialize_gcs_dataset_url(dataset_url, cache_dir=tmp_path) + + assert first_path == second_path + assert Path(second_path).read_bytes() == b"current" + assert current_blob.download_count == 1 diff --git a/policyengine_uk/tests/test_simulation_dataset_sources.py b/policyengine_uk/tests/test_simulation_dataset_sources.py index 5065fa69f..98834b68c 100644 --- a/policyengine_uk/tests/test_simulation_dataset_sources.py +++ b/policyengine_uk/tests/test_simulation_dataset_sources.py @@ -56,13 +56,42 @@ def fake_build_from_file(dataset_file, *, cache_key=None): assert captured == {"url": url} +def test_dataset_source_routes_gcs_urls_to_materialized_file(monkeypatch): + captured = {} + simulation = Simulation.__new__(Simulation) + url = "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10" + + def fake_materialize_gcs_dataset_url(dataset_url): + captured["url"] = dataset_url + return "/tmp/enhanced_frs_2023_24.h5" + + def fake_build_from_file(dataset_file, *, cache_key=None): + captured["dataset_file"] = dataset_file + captured["cache_key"] = cache_key + + monkeypatch.setattr( + simulation_module, + "materialize_gcs_dataset_url", + fake_materialize_gcs_dataset_url, + ) + monkeypatch.setattr(simulation, "build_from_file", fake_build_from_file) + + Simulation.build_from_dataset_source(simulation, url) + + assert captured == { + "url": url, + "dataset_file": "/tmp/enhanced_frs_2023_24.h5", + "cache_key": url, + } + + def test_dataset_source_rejects_unsupported_remote_urls(): simulation = Simulation.__new__(Simulation) - with pytest.raises(ValueError, match="Only HuggingFace dataset URLs"): + with pytest.raises(ValueError, match="Only HuggingFace, Google Cloud Storage"): Simulation.build_from_dataset_source( simulation, - "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5", + "s3://policyengine-uk-data-private/enhanced_frs_2023_24.h5", ) From 691571831c946959479bee47456aaf8603fcb69c Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 11 Jun 2026 14:51:09 +0200 Subject: [PATCH 2/2] Harden GCS dataset loading --- policyengine_uk/data/dataset_sources.py | 8 +++- policyengine_uk/simulation.py | 5 +++ policyengine_uk/tests/test_dataset_sources.py | 33 ++++++++++++++++ .../tests/test_simulation_dataset_sources.py | 38 +++++++++++++++++++ 4 files changed, 83 insertions(+), 1 deletion(-) diff --git a/policyengine_uk/data/dataset_sources.py b/policyengine_uk/data/dataset_sources.py index 2c1009a5f..ff279dc21 100644 --- a/policyengine_uk/data/dataset_sources.py +++ b/policyengine_uk/data/dataset_sources.py @@ -126,7 +126,13 @@ def _cached_dataset_path( def _download_blob(blob, local_path: Path) -> None: local_path.parent.mkdir(parents=True, exist_ok=True) - temporary_path = local_path.with_name(f"{local_path.name}.tmp") + fd, temporary_path_name = tempfile.mkstemp( + prefix=f".{local_path.name}.", + suffix=".tmp", + dir=local_path.parent, + ) + os.close(fd) + temporary_path = Path(temporary_path_name) try: blob.download_to_filename(str(temporary_path)) os.replace(temporary_path, local_path) diff --git a/policyengine_uk/simulation.py b/policyengine_uk/simulation.py index 4179d1439..75bb1569b 100644 --- a/policyengine_uk/simulation.py +++ b/policyengine_uk/simulation.py @@ -276,6 +276,11 @@ def build_from_dataset_source( self.build_from_url(dataset_source) return if dataset_source.startswith("gs://"): + if dataset_source in _url_dataset_cache: + multi_year_dataset = _url_dataset_cache[dataset_source] + self.build_from_multi_year_dataset(multi_year_dataset) + self.dataset = multi_year_dataset + return dataset_file = materialize_gcs_dataset_url(dataset_source) self.build_from_file(dataset_file, cache_key=dataset_source) return diff --git a/policyengine_uk/tests/test_dataset_sources.py b/policyengine_uk/tests/test_dataset_sources.py index 6e5206267..49dd499da 100644 --- a/policyengine_uk/tests/test_dataset_sources.py +++ b/policyengine_uk/tests/test_dataset_sources.py @@ -1,3 +1,4 @@ +import os from pathlib import Path import pytest @@ -13,6 +14,7 @@ def __init__(self, name, generation, metadata=None, contents=b"dataset"): self.metadata = metadata self.contents = contents self.download_count = 0 + self.download_filenames = [] self.reload_count = 0 def reload(self): @@ -20,6 +22,7 @@ def reload(self): def download_to_filename(self, filename): self.download_count += 1 + self.download_filenames.append(filename) Path(filename).write_bytes(self.contents) @@ -187,3 +190,33 @@ def test_materialize_gcs_dataset_url_reuses_cached_file(monkeypatch, tmp_path): assert first_path == second_path assert Path(second_path).read_bytes() == b"current" assert current_blob.download_count == 1 + + +def test_download_blob_uses_unique_temp_path_for_each_download(monkeypatch, tmp_path): + local_path = tmp_path / "cache" / "file.h5" + created_temp_paths = [] + + def fake_mkstemp(*, prefix, suffix, dir): + temporary_path = Path(dir) / f"{prefix}{len(created_temp_paths)}{suffix}" + fd = os.open(temporary_path, os.O_CREAT | os.O_EXCL | os.O_RDWR, 0o600) + created_temp_paths.append(temporary_path) + return fd, str(temporary_path) + + monkeypatch.setattr(dataset_sources.tempfile, "mkstemp", fake_mkstemp) + blob = FakeBlob("data/file.h5", 444, contents=b"first") + + dataset_sources._download_blob(blob, local_path) + local_path.unlink() + blob.contents = b"second" + dataset_sources._download_blob(blob, local_path) + + assert [ + Path(filename) for filename in blob.download_filenames + ] == created_temp_paths + assert len(set(created_temp_paths)) == 2 + assert all( + temporary_path.parent == local_path.parent + for temporary_path in created_temp_paths + ) + assert all(not temporary_path.exists() for temporary_path in created_temp_paths) + assert local_path.read_bytes() == b"second" diff --git a/policyengine_uk/tests/test_simulation_dataset_sources.py b/policyengine_uk/tests/test_simulation_dataset_sources.py index 98834b68c..5ce9f846a 100644 --- a/policyengine_uk/tests/test_simulation_dataset_sources.py +++ b/policyengine_uk/tests/test_simulation_dataset_sources.py @@ -85,6 +85,44 @@ def fake_build_from_file(dataset_file, *, cache_key=None): } +def test_dataset_source_reuses_cached_gcs_dataset_before_materializing(monkeypatch): + captured = {} + cached_dataset = object() + simulation = Simulation.__new__(Simulation) + url = "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.55.10" + + def fake_materialize_gcs_dataset_url(dataset_url): + raise AssertionError("Cached gs:// datasets should not be materialized.") + + def fake_build_from_file(dataset_file, *, cache_key=None): + raise AssertionError("Cached gs:// datasets should not be read from disk.") + + def fake_build_from_multi_year_dataset(dataset): + captured["dataset"] = dataset + + simulation_module._url_dataset_cache.pop(url, None) + simulation_module._url_dataset_cache[url] = cached_dataset + monkeypatch.setattr( + simulation_module, + "materialize_gcs_dataset_url", + fake_materialize_gcs_dataset_url, + ) + monkeypatch.setattr(simulation, "build_from_file", fake_build_from_file) + monkeypatch.setattr( + simulation, + "build_from_multi_year_dataset", + fake_build_from_multi_year_dataset, + ) + + try: + Simulation.build_from_dataset_source(simulation, url) + + assert captured["dataset"] is cached_dataset + assert simulation.dataset is cached_dataset + finally: + simulation_module._url_dataset_cache.pop(url, None) + + def test_dataset_source_rejects_unsupported_remote_urls(): simulation = Simulation.__new__(Simulation)