From 3834349b83e19ac25592b4eb83deff4a83b554fe Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Wed, 17 Jun 2026 15:06:05 -0400 Subject: [PATCH] Use sqrt fiscal target loss weights --- .../tests/test_us_fiscal_refresh_builder.py | 297 ++++++++++++++++-- tools/build_us_fiscal_refresh_release.py | 163 +++++----- 2 files changed, 348 insertions(+), 112 deletions(-) diff --git a/packages/populace-build/tests/test_us_fiscal_refresh_builder.py b/packages/populace-build/tests/test_us_fiscal_refresh_builder.py index f6b1f91..3fda2ac 100644 --- a/packages/populace-build/tests/test_us_fiscal_refresh_builder.py +++ b/packages/populace-build/tests/test_us_fiscal_refresh_builder.py @@ -403,21 +403,28 @@ def test_release_gate_failures_reject_missing_critical_targets() -> None: ] -def test_fiscal_target_loss_weights_prioritize_national_totals() -> None: +def test_fiscal_target_loss_weights_ignore_roles_and_geography() -> None: builder = _load_builder_module() registry = TargetRegistry( ( TargetSpec( - name="national_income_tax_total", + name="national_critical_role", entity="household", - value=10.0, + value=100.0, source="fixture", metadata={"target_role": "federal_income_tax_total"}, ), TargetSpec( - name="distribution_row", + name="state_role_row", entity="household", - value=10.0, + value=100.0, + source="fixture", + metadata={"state_fips": "06", "target_role": "tanf_total"}, + ), + TargetSpec( + name="ordinary_distribution_row", + entity="household", + value=100.0, source="fixture", ), ), @@ -426,27 +433,45 @@ def test_fiscal_target_loss_weights_prioritize_national_totals() -> None: weights = builder._fiscal_target_loss_weights(registry) - assert weights.shape == (2,) + assert weights.shape == (3,) assert weights.mean() == 1.0 - assert weights[0] == weights[1] * builder.US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER + assert np.array_equal(weights, np.ones(3)) -def test_fiscal_target_loss_weights_downweight_state_rows() -> None: +def test_fiscal_target_loss_weights_scale_by_sqrt_value_within_basis() -> None: builder = _load_builder_module() registry = TargetRegistry( ( TargetSpec( - name="state_role_row", - entity="tax_unit", + name="amount_small", + entity="household", value=100.0, source="fixture", - metadata={"state_fips": "06", "target_role": "tanf_total"}, + metadata={"source_measure_id": "payment_amount"}, ), TargetSpec( - name="national_row", - entity="tax_unit", - value=100.0, + name="amount_large", + entity="household", + value=300.0, source="fixture", + metadata={"source_measure_id": "payment_amount"}, + ), + TargetSpec( + name="returns_small", + entity="household", + value=10.0, + source="fixture", + metadata={ + "source_measure_id": "income_tax_liability_returns", + "count": "true", + }, + ), + TargetSpec( + name="returns_large", + entity="household", + value=30.0, + source="fixture", + metadata={"source_measure_id": "ctc_claims", "count": "true"}, ), ), country="us", @@ -455,10 +480,13 @@ def test_fiscal_target_loss_weights_downweight_state_rows() -> None: weights = builder._fiscal_target_loss_weights(registry) assert weights.mean() == 1.0 - assert weights[0] == weights[1] * builder.US_STATE_TARGET_LOSS_MULTIPLIER + assert np.isclose(weights[1] / weights[0], np.sqrt(3.0)) + assert np.isclose(weights[3] / weights[2], np.sqrt(3.0)) + assert weights[0] == weights[2] + assert weights[1] == weights[3] -def test_fiscal_target_loss_weights_scale_by_value_within_basis() -> None: +def test_fiscal_target_loss_weights_split_evenly_between_amount_and_count() -> None: builder = _load_builder_module() registry = TargetRegistry( ( @@ -477,21 +505,52 @@ def test_fiscal_target_loss_weights_scale_by_value_within_basis() -> None: metadata={"source_measure_id": "payment_amount"}, ), TargetSpec( - name="returns_small", + name="returns", entity="household", value=10.0, source="fixture", - metadata={ - "source_measure_id": "income_tax_liability_returns", - "count": "true", - }, + metadata={"source_measure_id": "ctc_claims", "count": "true"}, ), + ), + country="us", + ) + + weights = builder._fiscal_target_loss_weights(registry) + bases = np.asarray( + [builder._fiscal_target_value_basis(spec) for spec in registry.specs], + dtype=object, + ) + + assert weights.mean() == 1.0 + assert weights[bases == "amount"].sum() == weights[bases == "count"].sum() + assert np.isclose(weights[1] / weights[0], np.sqrt(3.0)) + + +def test_fiscal_target_loss_weights_floor_zero_subunit_and_abs_values() -> None: + builder = _load_builder_module() + registry = TargetRegistry( + ( TargetSpec( - name="returns_large", + name="zero", entity="household", - value=30.0, + value=0.0, source="fixture", - metadata={"source_measure_id": "ctc_claims", "count": "true"}, + metadata={"source_measure_id": "payment_amount"}, + ), + TargetSpec( + name="subunit", + entity="household", + value=0.25, + source="fixture", + metadata={"source_measure_id": "payment_amount"}, + ), + TargetSpec( + name="negative", + entity="household", + value=-9.0, + source="fixture", + signed=True, + metadata={"source_measure_id": "payment_amount"}, ), ), country="us", @@ -500,13 +559,11 @@ def test_fiscal_target_loss_weights_scale_by_value_within_basis() -> None: weights = builder._fiscal_target_loss_weights(registry) assert weights.mean() == 1.0 - assert weights[1] / weights[0] == 3.0 - assert weights[3] / weights[2] == 3.0 - assert weights[0] == weights[2] - assert weights[1] == weights[3] + assert weights[0] == weights[1] + assert np.isclose(weights[2] / weights[0], 3.0) -def test_fiscal_target_value_basis_keeps_person_counts_separate_from_amounts() -> None: +def test_fiscal_target_value_basis_uses_only_amount_and_count() -> None: builder = _load_builder_module() amount = TargetSpec( name="amount", @@ -515,6 +572,13 @@ def test_fiscal_target_value_basis_keeps_person_counts_separate_from_amounts() - source="fixture", metadata={"source_measure_id": "payment_amount"}, ) + return_count = TargetSpec( + name="return_count", + entity="household", + value=100.0, + source="fixture", + metadata={"source_measure_id": "ctc_claims", "count": "true"}, + ) person_count = TargetSpec( name="person_count", entity="household", @@ -529,7 +593,180 @@ def test_fiscal_target_value_basis_keeps_person_counts_separate_from_amounts() - ) assert builder._fiscal_target_value_basis(amount) == "amount" - assert builder._fiscal_target_value_basis(person_count) == "person_count" + assert builder._fiscal_target_value_basis(return_count) == "count" + assert builder._fiscal_target_value_basis(person_count) == "count" + + +def test_release_calibration_diagnostics_include_gate_failures( + monkeypatch, tmp_path +) -> None: + builder = _load_builder_module() + captured: dict[str, object] = {} + + def fake_write_calibration_diagnostics(result, path, *, target_registry, build): + captured["result"] = result + captured["path"] = path + captured["target_registry"] = target_registry + captured["build"] = build + return path + + monkeypatch.setattr(builder, "_sha256", lambda path: "base-sha") + monkeypatch.setattr( + builder, "write_calibration_diagnostics", fake_write_calibration_diagnostics + ) + result = SimpleNamespace() + registry = TargetRegistry((), country="us") + profile_gate = SimpleNamespace(passed=True, failures=(), details={"n": 1}) + health_gate = SimpleNamespace(passed=True, failures=(), details={"n": 2}) + + builder._write_release_calibration_diagnostics( + result=result, + release_dir=tmp_path, + registry=registry, + base_h5=tmp_path / "base.h5", + compilation={"dropped_target_names": []}, + target_profile_gate=profile_gate, + health_input_gate=health_gate, + audit_export_targets=False, + gate_failures=["ctc failed"], + ) + + assert captured["path"] == tmp_path / "calibration_diagnostics.json" + build = captured["build"] + assert build["base_dataset_sha256"] == "base-sha" + assert build["release_gates"] == { + "passed": False, + "failures": ["ctc failed"], + } + assert build["health_input_signal"] == { + "passed": True, + "failures": [], + "details": {"n": 2}, + } + + +def test_main_writes_diagnostics_before_post_calibration_gate_failure( + monkeypatch, tmp_path +) -> None: + builder = _load_builder_module() + release_id = "populace-us-2024-gate-failure-test" + base_h5 = tmp_path / "base.h5" + facts = tmp_path / "facts.jsonl" + out = tmp_path / "out" + base_h5.write_bytes(b"h5") + facts.write_text("{}\n") + target_spec = TargetSpec( + name="amount", + entity="household", + measure="income", + value=100.0, + source="fixture", + metadata={"source_measure_id": "payment_amount"}, + ) + registry = TargetRegistry((target_spec,), country="us") + result = SimpleNamespace( + skipped=(), + diagnostics=(), + initial_loss=2.0, + final_loss=1.0, + ) + captured: dict[str, object] = {} + + class FakeFrame: + pass + + monkeypatch.setattr( + sys, + "argv", + [ + "build_us_fiscal_refresh_release.py", + "--base-h5", + str(base_h5), + "--ledger-facts", + str(facts), + "--out", + str(out), + "--release-id", + release_id, + ], + ) + monkeypatch.setattr(builder, "_git_dirty", lambda: False) + monkeypatch.setattr(builder, "_sha256", lambda path: "base-sha") + monkeypatch.setattr(builder, "_git_output", lambda *args: "commit") + monkeypatch.setattr(builder, "_load_ledger_facts", lambda path: ({"fact": 1},)) + monkeypatch.setattr( + builder, + "compile_us_fiscal_target_registry", + lambda facts, *, target_period: registry, + ) + monkeypatch.setattr( + builder, + "target_profile_coverage_gate", + lambda specs, requirements: builder.GateResult( + name="target_profile_coverage", + passed=True, + details={"checked": True}, + ), + ) + monkeypatch.setattr(builder, "_load_frame", lambda path: FakeFrame()) + monkeypatch.setattr( + builder, + "_with_aca_marketplace_source_outputs", + lambda frame, specs, *, seed: frame, + ) + monkeypatch.setattr( + builder, + "_health_input_signal_gate", + lambda frame: builder.GateResult( + name="health_input_signal", + passed=True, + details={"checked": True}, + ), + ) + monkeypatch.setattr( + builder, + "_materialize_target_frame", + lambda frame, specs: ( + frame, + registry, + {"dropped_target_names": []}, + ), + ) + + def fake_calibrate(*args, **kwargs): + captured["target_loss_weights"] = kwargs["target_loss_weights"] + return result + + def fake_write_diagnostics(**kwargs): + captured["diagnostics"] = kwargs + release_dir = kwargs["release_dir"] + release_dir.mkdir(parents=True, exist_ok=True) + (release_dir / "calibration_diagnostics.json").write_text("{}") + return release_dir / "calibration_diagnostics.json" + + monkeypatch.setattr(builder, "calibrate", fake_calibrate) + monkeypatch.setattr( + builder, + "_release_gate_failures", + lambda *args: ["ctc failed"], + ) + monkeypatch.setattr( + builder, + "_write_release_calibration_diagnostics", + fake_write_diagnostics, + ) + + try: + builder.main() + except RuntimeError as exc: + assert str(exc) == "Release gates failed: ctc failed" + else: # pragma: no cover - defensive assertion + raise AssertionError("Expected post-calibration gate failure.") + + release_dir = out / "releases" / release_id + assert (release_dir / "calibration_diagnostics.json").exists() + assert captured["diagnostics"]["gate_failures"] == ["ctc failed"] + assert np.array_equal(captured["target_loss_weights"], np.asarray([1.0])) def test_release_gate_failures_reject_bad_ctc_fit() -> None: diff --git a/tools/build_us_fiscal_refresh_release.py b/tools/build_us_fiscal_refresh_release.py index f108888..9e36565 100644 --- a/tools/build_us_fiscal_refresh_release.py +++ b/tools/build_us_fiscal_refresh_release.py @@ -69,25 +69,10 @@ POST_EXPORT_ABSOLUTE_TOLERANCE = 1_000_000.0 POST_EXPORT_RELATIVE_TOLERANCE = 5e-4 US_FISCAL_TARGET_LOSS_WEIGHTING = ( - "semantic_value_weighted_mape_by_measure_basis_target_scale_cap_1000pct" + "sqrt_value_weighted_mape_50_50_amount_count_target_scale_cap_1000pct" ) +US_FISCAL_TARGET_VALUE_WEIGHT_POWER = 0.5 US_FISCAL_TARGET_LOSS_CAP = 10.0 -US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER = 25.0 -US_STATE_TARGET_LOSS_MULTIPLIER = 0.25 -US_FISCAL_TARGET_ROLE_LOSS_MULTIPLIERS = { - "aca_spending": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "ctc_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "eitc_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "federal_income_tax_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "income_tax_before_credits_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "medicare_part_b_premium_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "refundable_ctc_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "snap_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "social_security_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "ssi_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "tanf_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, - "unemployment_compensation_total": US_NATIONAL_TOTAL_TARGET_LOSS_MULTIPLIER, -} US_CRITICAL_TARGET_FIT_REQUIREMENTS = ( { "name": ( @@ -1301,31 +1286,20 @@ def _write_npz(path: Path, *, result, registry: TargetRegistry) -> None: def _fiscal_target_loss_weights(registry: TargetRegistry) -> np.ndarray: - basis_weights = _fiscal_target_value_basis_weights(registry) - state_multipliers = np.asarray( - [ - US_STATE_TARGET_LOSS_MULTIPLIER if spec.metadata.get("state_fips") else 1.0 - for spec in registry.specs - ], - dtype=np.float64, - ) - multipliers = np.asarray( - [ - ( - 1.0 - if spec.metadata.get("state_fips") - else US_FISCAL_TARGET_ROLE_LOSS_MULTIPLIERS.get( - spec.metadata.get("target_role", ""), - 1.0, - ) - ) - for spec in registry.specs - ], - dtype=np.float64, + weights = _fiscal_target_value_basis_weights(registry) + bases = np.asarray( + [_fiscal_target_value_basis(spec) for spec in registry.specs], + dtype=object, ) - weights = basis_weights - weights *= state_multipliers - weights *= multipliers + unique_bases = sorted(set(bases.tolist())) + if not unique_bases: + return weights + basis_total = len(weights) / len(unique_bases) + for basis in unique_bases: + mask = bases == basis + current_total = weights[mask].sum() + if current_total > 0: + weights[mask] *= basis_total / current_total return weights / weights.mean() @@ -1339,11 +1313,12 @@ def _fiscal_target_value_basis_weights(registry: TargetRegistry) -> np.ndarray: [max(abs(float(spec.value)), 1.0) for spec in registry.specs], dtype=np.float64, ) + raw_weights = values**US_FISCAL_TARGET_VALUE_WEIGHT_POWER for basis in sorted(set(bases.tolist())): mask = bases == basis - mean_value = values[mask].mean() + mean_value = raw_weights[mask].mean() if mean_value > 0: - weights[mask] = values[mask] / mean_value + weights[mask] = raw_weights[mask] / mean_value return weights @@ -1351,35 +1326,17 @@ def _fiscal_target_value_basis(spec) -> str: metadata = spec.metadata measure_mode = metadata.get("measure_mode", "") source_measure_id = metadata.get("source_measure_id", "") - target_role = metadata.get("target_role", "") if metadata.get("count") == "true": - return ( - "return_count" - if _fiscal_target_is_return_count_measure(source_measure_id) - else "count" - ) + return "count" if measure_mode in {"count", "positive_count"}: - if metadata.get("count_map_to") == "person" or target_role in { - "aca_enrollment", - "aca_ptc_recipients", - "medicaid_enrollment", - "medicaid_chip_enrollment", - }: - return "person_count" return "count" if "enrollment" in source_measure_id or "recipients" in source_measure_id: - return "person_count" + return "count" if "return" in source_measure_id and "count" in source_measure_id: - return "return_count" + return "count" return "amount" -def _fiscal_target_is_return_count_measure(source_measure_id: str) -> bool: - return source_measure_id == "return_count" or source_measure_id.endswith( - ("_returns", "_claims") - ) - - def _release_gate_failures( result, compilation: Mapping[str, object], @@ -1525,6 +1482,52 @@ def _assert_release_gates( raise RuntimeError("Release gates failed: " + "; ".join(failures)) +def _write_release_calibration_diagnostics( + *, + result, + release_dir: Path, + registry: TargetRegistry, + base_h5: Path, + compilation: Mapping[str, object], + target_profile_gate: GateResult, + health_input_gate: GateResult | None, + audit_export_targets: bool, + gate_failures: Iterable[str], +) -> None: + """Write calibration diagnostics even when hard release gates fail.""" + failures = list(gate_failures) + write_calibration_diagnostics( + result, + release_dir / "calibration_diagnostics.json", + target_registry=registry, + build={ + "base_dataset_sha256": _sha256(base_h5), + "target_compilation": compilation, + "target_loss_weighting": US_FISCAL_TARGET_LOSS_WEIGHTING, + "target_loss_cap": US_FISCAL_TARGET_LOSS_CAP, + "target_profile_coverage": { + "passed": target_profile_gate.passed, + "failures": list(target_profile_gate.failures), + "details": dict(target_profile_gate.details), + }, + "health_input_signal": ( + { + "passed": health_input_gate.passed, + "failures": list(health_input_gate.failures), + "details": dict(health_input_gate.details), + } + if health_input_gate is not None + else None + ), + "release_gates": { + "passed": not failures, + "failures": failures, + }, + "post_export_target_audit": bool(audit_export_targets), + }, + ) + + def _target_final_estimate(result, target_name: str) -> float: for diagnostic in result.diagnostics: if diagnostic.name == f"{target_name}@{PERIOD}": @@ -1990,12 +1993,25 @@ def main() -> None: target_loss_weights=_fiscal_target_loss_weights(registry), target_loss_cap=US_FISCAL_TARGET_LOSS_CAP, ) - _assert_release_gates( + gate_failures = _release_gate_failures( result, compilation, target_profile_gate, health_input_gate, ) + _write_release_calibration_diagnostics( + result=result, + release_dir=release_dir, + registry=registry, + base_h5=base_h5, + compilation=compilation, + target_profile_gate=target_profile_gate, + health_input_gate=health_input_gate, + audit_export_targets=bool(args.audit_export_targets), + gate_failures=gate_failures, + ) + if gate_failures: + raise RuntimeError("Release gates failed: " + "; ".join(gate_failures)) export_frame = _strip_calibration_columns(base_frame, result.weights) dataset_path = artifact_root / DATASET_FILENAME @@ -2005,23 +2021,6 @@ def main() -> None: calibration_path = artifact_root / CALIBRATION_FILENAME _write_npz(calibration_path, result=result, registry=registry) - write_calibration_diagnostics( - result, - release_dir / "calibration_diagnostics.json", - target_registry=registry, - build={ - "base_dataset_sha256": _sha256(base_h5), - "target_compilation": compilation, - "target_loss_weighting": US_FISCAL_TARGET_LOSS_WEIGHTING, - "target_loss_cap": US_FISCAL_TARGET_LOSS_CAP, - "target_profile_coverage": { - "passed": target_profile_gate.passed, - "failures": list(target_profile_gate.failures), - "details": dict(target_profile_gate.details), - }, - "post_export_target_audit": bool(args.audit_export_targets), - }, - ) if not args.skip_reform_validation: _write_reform_validation(