Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 267 additions & 30 deletions packages/populace-build/tests/test_us_fiscal_refresh_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,21 +403,28 @@ def test_release_gate_failures_reject_missing_critical_targets() -> None:
]


def test_fiscal_target_loss_weights_prioritize_national_totals() -> None:
def test_fiscal_target_loss_weights_ignore_roles_and_geography() -> None:
builder = _load_builder_module()
registry = TargetRegistry(
(
TargetSpec(
name="national_income_tax_total",
name="national_critical_role",
entity="household",
value=10.0,
value=100.0,
source="fixture",
metadata={"target_role": "federal_income_tax_total"},
),
TargetSpec(
name="distribution_row",
name="state_role_row",
entity="household",
value=10.0,
value=100.0,
source="fixture",
metadata={"state_fips": "06", "target_role": "tanf_total"},
),
TargetSpec(
name="ordinary_distribution_row",
entity="household",
value=100.0,
source="fixture",
),
),
Expand All @@ -426,27 +433,45 @@ def test_fiscal_target_loss_weights_prioritize_national_totals() -> None:

weights = builder._fiscal_target_loss_weights(registry)

assert weights.shape == (2,)
assert weights.shape == (3,)
assert weights.mean() == 1.0
assert weights[0] == weights[1] * builder.US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER
assert np.array_equal(weights, np.ones(3))


def test_fiscal_target_loss_weights_downweight_state_rows() -> None:
def test_fiscal_target_loss_weights_scale_by_sqrt_value_within_basis() -> None:
builder = _load_builder_module()
registry = TargetRegistry(
(
TargetSpec(
name="state_role_row",
entity="tax_unit",
name="amount_small",
entity="household",
value=100.0,
source="fixture",
metadata={"state_fips": "06", "target_role": "tanf_total"},
metadata={"source_measure_id": "payment_amount"},
),
TargetSpec(
name="national_row",
entity="tax_unit",
value=100.0,
name="amount_large",
entity="household",
value=300.0,
source="fixture",
metadata={"source_measure_id": "payment_amount"},
),
TargetSpec(
name="returns_small",
entity="household",
value=10.0,
source="fixture",
metadata={
"source_measure_id": "income_tax_liability_returns",
"count": "true",
},
),
TargetSpec(
name="returns_large",
entity="household",
value=30.0,
source="fixture",
metadata={"source_measure_id": "ctc_claims", "count": "true"},
),
),
country="us",
Expand All @@ -455,10 +480,13 @@ def test_fiscal_target_loss_weights_downweight_state_rows() -> None:
weights = builder._fiscal_target_loss_weights(registry)

assert weights.mean() == 1.0
assert weights[0] == weights[1] * builder.US_STATE_TARGET_LOSS_MULTIPLIER
assert np.isclose(weights[1] / weights[0], np.sqrt(3.0))
assert np.isclose(weights[3] / weights[2], np.sqrt(3.0))
assert weights[0] == weights[2]
assert weights[1] == weights[3]


def test_fiscal_target_loss_weights_scale_by_value_within_basis() -> None:
def test_fiscal_target_loss_weights_split_evenly_between_amount_and_count() -> None:
builder = _load_builder_module()
registry = TargetRegistry(
(
Expand All @@ -477,21 +505,52 @@ def test_fiscal_target_loss_weights_scale_by_value_within_basis() -> None:
metadata={"source_measure_id": "payment_amount"},
),
TargetSpec(
name="returns_small",
name="returns",
entity="household",
value=10.0,
source="fixture",
metadata={
"source_measure_id": "income_tax_liability_returns",
"count": "true",
},
metadata={"source_measure_id": "ctc_claims", "count": "true"},
),
),
country="us",
)

weights = builder._fiscal_target_loss_weights(registry)
bases = np.asarray(
[builder._fiscal_target_value_basis(spec) for spec in registry.specs],
dtype=object,
)

assert weights.mean() == 1.0
assert weights[bases == "amount"].sum() == weights[bases == "count"].sum()
assert np.isclose(weights[1] / weights[0], np.sqrt(3.0))


def test_fiscal_target_loss_weights_floor_zero_subunit_and_abs_values() -> None:
builder = _load_builder_module()
registry = TargetRegistry(
(
TargetSpec(
name="returns_large",
name="zero",
entity="household",
value=30.0,
value=0.0,
source="fixture",
metadata={"source_measure_id": "ctc_claims", "count": "true"},
metadata={"source_measure_id": "payment_amount"},
),
TargetSpec(
name="subunit",
entity="household",
value=0.25,
source="fixture",
metadata={"source_measure_id": "payment_amount"},
),
TargetSpec(
name="negative",
entity="household",
value=-9.0,
source="fixture",
signed=True,
metadata={"source_measure_id": "payment_amount"},
),
),
country="us",
Expand All @@ -500,13 +559,11 @@ def test_fiscal_target_loss_weights_scale_by_value_within_basis() -> None:
weights = builder._fiscal_target_loss_weights(registry)

assert weights.mean() == 1.0
assert weights[1] / weights[0] == 3.0
assert weights[3] / weights[2] == 3.0
assert weights[0] == weights[2]
assert weights[1] == weights[3]
assert weights[0] == weights[1]
assert np.isclose(weights[2] / weights[0], 3.0)


def test_fiscal_target_value_basis_keeps_person_counts_separate_from_amounts() -> None:
def test_fiscal_target_value_basis_uses_only_amount_and_count() -> None:
builder = _load_builder_module()
amount = TargetSpec(
name="amount",
Expand All @@ -515,6 +572,13 @@ def test_fiscal_target_value_basis_keeps_person_counts_separate_from_amounts() -
source="fixture",
metadata={"source_measure_id": "payment_amount"},
)
return_count = TargetSpec(
name="return_count",
entity="household",
value=100.0,
source="fixture",
metadata={"source_measure_id": "ctc_claims", "count": "true"},
)
person_count = TargetSpec(
name="person_count",
entity="household",
Expand All @@ -529,7 +593,180 @@ def test_fiscal_target_value_basis_keeps_person_counts_separate_from_amounts() -
)

assert builder._fiscal_target_value_basis(amount) == "amount"
assert builder._fiscal_target_value_basis(person_count) == "person_count"
assert builder._fiscal_target_value_basis(return_count) == "count"
assert builder._fiscal_target_value_basis(person_count) == "count"


def test_release_calibration_diagnostics_include_gate_failures(
    monkeypatch, tmp_path
) -> None:
    """Check the build payload handed to the calibration diagnostics writer.

    ``write_calibration_diagnostics`` and ``_sha256`` are replaced on the
    builder module, then ``_write_release_calibration_diagnostics`` is invoked
    with one gate failure. The recorded call must target
    ``calibration_diagnostics.json`` inside the release directory, and its
    ``build`` mapping must carry the stubbed base dataset hash, a failed
    ``release_gates`` entry and the health-input gate summary.
    """
    builder = _load_builder_module()
    recorded: dict[str, object] = {}

    def record_write_call(result, path, *, target_registry, build):
        # Keep every argument so the assertions below can inspect the payload.
        recorded.update(
            result=result,
            path=path,
            target_registry=target_registry,
            build=build,
        )
        return path

    monkeypatch.setattr(builder, "_sha256", lambda path: "base-sha")
    monkeypatch.setattr(builder, "write_calibration_diagnostics", record_write_call)

    builder._write_release_calibration_diagnostics(
        result=SimpleNamespace(),
        release_dir=tmp_path,
        registry=TargetRegistry((), country="us"),
        base_h5=tmp_path / "base.h5",
        compilation={"dropped_target_names": []},
        target_profile_gate=SimpleNamespace(
            passed=True, failures=(), details={"n": 1}
        ),
        health_input_gate=SimpleNamespace(
            passed=True, failures=(), details={"n": 2}
        ),
        audit_export_targets=False,
        gate_failures=["ctc failed"],
    )

    assert recorded["path"] == tmp_path / "calibration_diagnostics.json"
    build = recorded["build"]
    assert build["base_dataset_sha256"] == "base-sha"
    # A non-empty failure list must be reported as a failed release gate.
    expected_release_gates = {"passed": False, "failures": ["ctc failed"]}
    assert build["release_gates"] == expected_release_gates
    expected_health_signal = {"passed": True, "failures": [], "details": {"n": 2}}
    assert build["health_input_signal"] == expected_health_signal


def test_main_writes_diagnostics_before_post_calibration_gate_failure(
    monkeypatch, tmp_path
) -> None:
    """Run ``builder.main()`` with stubbed collaborators and a failing gate.

    Every collaborator that ``main`` is given here is replaced with a stub, and
    ``_release_gate_failures`` is forced to report ``"ctc failed"``. The test
    expects ``main`` to raise ``RuntimeError`` with the gate-failure message,
    while the diagnostics file still exists under ``out/releases/<release_id>``.
    It also checks that the diagnostics writer received the gate failures and
    that ``calibrate`` received ``target_loss_weights`` equal to ``[1.0]``.
    """
    builder = _load_builder_module()
    release_id = "populace-us-2024-gate-failure-test"

    # Minimal on-disk inputs; their contents are never parsed because the
    # loaders are stubbed below.
    base_h5 = tmp_path / "base.h5"
    base_h5.write_bytes(b"h5")
    facts = tmp_path / "facts.jsonl"
    facts.write_text("{}\n")
    out = tmp_path / "out"

    registry = TargetRegistry(
        (
            TargetSpec(
                name="amount",
                entity="household",
                measure="income",
                value=100.0,
                source="fixture",
                metadata={"source_measure_id": "payment_amount"},
            ),
        ),
        country="us",
    )
    result = SimpleNamespace(
        skipped=(),
        diagnostics=(),
        initial_loss=2.0,
        final_loss=1.0,
    )
    captured: dict[str, object] = {}

    class FakeFrame:
        pass

    def passing_gate(gate_name):
        # Resolved at call time, so builder.GateResult is only looked up when
        # the stubbed gate is actually invoked.
        return builder.GateResult(
            name=gate_name,
            passed=True,
            details={"checked": True},
        )

    def fake_calibrate(*args, **kwargs):
        captured["target_loss_weights"] = kwargs["target_loss_weights"]
        return result

    def fake_write_diagnostics(**kwargs):
        captured["diagnostics"] = kwargs
        release_dir = kwargs["release_dir"]
        release_dir.mkdir(parents=True, exist_ok=True)
        diagnostics_path = release_dir / "calibration_diagnostics.json"
        diagnostics_path.write_text("{}")
        return diagnostics_path

    cli_args = [
        "build_us_fiscal_refresh_release.py",
        "--base-h5",
        str(base_h5),
        "--ledger-facts",
        str(facts),
        "--out",
        str(out),
        "--release-id",
        release_id,
    ]
    monkeypatch.setattr(sys, "argv", cli_args)

    # Applied in insertion order, one setattr per builder attribute.
    builder_stubs = {
        "_git_dirty": lambda: False,
        "_sha256": lambda path: "base-sha",
        "_git_output": lambda *args: "commit",
        "_load_ledger_facts": lambda path: ({"fact": 1},),
        "compile_us_fiscal_target_registry": (
            lambda facts, *, target_period: registry
        ),
        "target_profile_coverage_gate": (
            lambda specs, requirements: passing_gate("target_profile_coverage")
        ),
        "_load_frame": lambda path: FakeFrame(),
        "_with_aca_marketplace_source_outputs": (
            lambda frame, specs, *, seed: frame
        ),
        "_health_input_signal_gate": (
            lambda frame: passing_gate("health_input_signal")
        ),
        "_materialize_target_frame": lambda frame, specs: (
            frame,
            registry,
            {"dropped_target_names": []},
        ),
        "calibrate": fake_calibrate,
        "_release_gate_failures": lambda *args: ["ctc failed"],
        "_write_release_calibration_diagnostics": fake_write_diagnostics,
    }
    for attribute, stub in builder_stubs.items():
        monkeypatch.setattr(builder, attribute, stub)

    raised = None
    try:
        builder.main()
    except RuntimeError as exc:
        raised = exc
    if raised is None:  # pragma: no cover - defensive assertion
        raise AssertionError("Expected post-calibration gate failure.")
    assert str(raised) == "Release gates failed: ctc failed"

    release_dir = out / "releases" / release_id
    assert (release_dir / "calibration_diagnostics.json").exists()
    assert captured["diagnostics"]["gate_failures"] == ["ctc failed"]
    assert np.array_equal(captured["target_loss_weights"], np.asarray([1.0]))


def test_release_gate_failures_reject_bad_ctc_fit() -> None:
Expand Down
Loading
Loading