From ff1fdfcda0c112962fb91ce3bad1a74bf32113c5 Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 6 Jun 2026 07:24:25 +0100 Subject: [PATCH] Chain numeric zero-inflated imputations --- .../zero-inflated-numeric-chain.fixed.md | 1 + microimpute/models/zero_inflated.py | 105 +++++++++++------- tests/test_models/test_zero_inflated.py | 52 +++++++++ 3 files changed, 118 insertions(+), 40 deletions(-) create mode 100644 changelog.d/zero-inflated-numeric-chain.fixed.md diff --git a/changelog.d/zero-inflated-numeric-chain.fixed.md b/changelog.d/zero-inflated-numeric-chain.fixed.md new file mode 100644 index 0000000..1741d0a --- /dev/null +++ b/changelog.d/zero-inflated-numeric-chain.fixed.md @@ -0,0 +1 @@ +Preserve sequential numeric-target chaining inside `ZeroInflatedImputer`. diff --git a/microimpute/models/zero_inflated.py b/microimpute/models/zero_inflated.py index 8858446..f643bc6 100644 --- a/microimpute/models/zero_inflated.py +++ b/microimpute/models/zero_inflated.py @@ -234,20 +234,26 @@ def fit( for v in imputed_variables if v in self.numeric_targets or v in constant_numeric_targets ] - for var in numeric_targets: + numeric_target_set = set(numeric_targets) + previous_numeric_targets: List[str] = [] + for var in imputed_variables: + if var not in numeric_target_set: + continue y = X_train[var].to_numpy(dtype=float, copy=False) regime = _detect_regime( y, zero_atol=self.zero_atol, ) + sequential_predictors = list(predictors) + previous_numeric_targets self._regimes[var] = regime self._per_variable[var] = self._fit_single_numeric( X_train=X_train, - predictors=predictors, + predictors=sequential_predictors, variable=var, regime=regime, y=y, ) + previous_numeric_targets.append(var) # Non-numeric (categorical / boolean / constant) targets are # handled by a single auxiliary base imputer over their union. @@ -298,17 +304,24 @@ def _fit_single_numeric( Returns a bundle dict with the regime, the gate classifier (or None), and the base imputer(s) keyed by their role. """ + + def with_predictors(bundle: Dict[str, Any]) -> Dict[str, Any]: + bundle["predictors"] = list(predictors) + return bundle + X_pred = X_train[predictors].to_numpy(dtype=float, copy=False) if regime == REGIME_DEGENERATE_ZERO: - return {"kind": "constant", "value": 0.0} + return with_predictors({"kind": "constant", "value": 0.0}) if regime in (REGIME_POSITIVE_ONLY, REGIME_NEGATIVE_ONLY): # No gate; single base imputer on the full training set. - return { - "kind": "single", - "base": self._fit_base_single(X_train, predictors, variable), - } + return with_predictors( + { + "kind": "single", + "base": self._fit_base_single(X_train, predictors, variable), + } + ) if regime == REGIME_ZI_POSITIVE: labels = (y > self.zero_atol).astype(int) @@ -318,11 +331,13 @@ def _fit_single_numeric( pos_base = self._fit_base_single( X_train.loc[pos_mask], predictors, variable ) - return { - "kind": "zi_positive", - "classifier": clf, - "positive_base": pos_base, - } + return with_predictors( + { + "kind": "zi_positive", + "classifier": clf, + "positive_base": pos_base, + } + ) if regime == REGIME_ZI_NEGATIVE: labels = (y < -self.zero_atol).astype(int) @@ -332,11 +347,13 @@ def _fit_single_numeric( neg_base = self._fit_base_single( X_train.loc[neg_mask], predictors, variable ) - return { - "kind": "zi_negative", - "classifier": clf, - "negative_base": neg_base, - } + return with_predictors( + { + "kind": "zi_negative", + "classifier": clf, + "negative_base": neg_base, + } + ) if regime == REGIME_SIGN_ONLY: # No zero class, but both signs present. Binary sign gate @@ -346,16 +363,18 @@ def _fit_single_numeric( clf.fit(X_pred, labels) pos_mask = y > 0 neg_mask = ~pos_mask - return { - "kind": "sign_only", - "classifier": clf, - "positive_base": self._fit_base_single( - X_train.loc[pos_mask], predictors, variable - ), - "negative_base": self._fit_base_single( - X_train.loc[neg_mask], predictors, variable - ), - } + return with_predictors( + { + "kind": "sign_only", + "classifier": clf, + "positive_base": self._fit_base_single( + X_train.loc[pos_mask], predictors, variable + ), + "negative_base": self._fit_base_single( + X_train.loc[neg_mask], predictors, variable + ), + } + ) if regime == REGIME_THREE_SIGN: # 0 / neg / pos three-way gate + two base imputers. @@ -368,16 +387,18 @@ def _fit_single_numeric( clf.fit(X_pred, labels) pos_mask = y > self.zero_atol neg_mask = y < -self.zero_atol - return { - "kind": "three_sign", - "classifier": clf, - "positive_base": self._fit_base_single( - X_train.loc[pos_mask], predictors, variable - ), - "negative_base": self._fit_base_single( - X_train.loc[neg_mask], predictors, variable - ), - } + return with_predictors( + { + "kind": "three_sign", + "classifier": clf, + "positive_base": self._fit_base_single( + X_train.loc[pos_mask], predictors, variable + ), + "negative_base": self._fit_base_single( + X_train.loc[neg_mask], predictors, variable + ), + } + ) raise ValueError(f"Unhandled regime {regime!r}") @@ -465,6 +486,7 @@ def _predict_single_draw( **kwargs: Any, ) -> pd.DataFrame: out = pd.DataFrame(index=X_test.index) + X_test_augmented = X_test.copy() for variable in self.imputed_variables: regime = self._regimes.get(variable) @@ -472,9 +494,11 @@ def _predict_single_draw( # Non-numeric target; handled by the auxiliary bundle. continue bundle = self._per_variable[variable] - out[variable] = self._predict_single_variable( - X_test, variable, bundle, quantile=quantile, **kwargs + values = self._predict_single_variable( + X_test_augmented, variable, bundle, quantile=quantile, **kwargs ) + out[variable] = values + X_test_augmented[variable] = values # Merge in non-numeric target predictions from the auxiliary # single base imputer. @@ -501,6 +525,7 @@ def _predict_single_variable( ) -> np.ndarray: n = len(X_test) kind = bundle["kind"] + predictors = bundle.get("predictors", self.predictors) if kind == "constant": return np.full(n, bundle["value"], dtype=float) @@ -511,7 +536,7 @@ def _predict_single_variable( ) return preds[variable].to_numpy(dtype=float) - X_pred = X_test[self.predictors].to_numpy(dtype=float, copy=False) + X_pred = X_test[predictors].to_numpy(dtype=float, copy=False) if kind == "zi_positive": clf = bundle["classifier"] diff --git a/tests/test_models/test_zero_inflated.py b/tests/test_models/test_zero_inflated.py index c75195e..1605a1e 100644 --- a/tests/test_models/test_zero_inflated.py +++ b/tests/test_models/test_zero_inflated.py @@ -303,3 +303,55 @@ def test_positive_only_matches_bare_qrf(self) -> None: abs(bare_preds.mean() - wrapped_preds.mean()) / max(bare_preds.mean(), 1.0) < 0.25 ) + + +class TestSequentialNumericTargets: + """Multiple numeric targets should use earlier imputations as predictors.""" + + def test_later_numeric_targets_receive_previous_predictions(self) -> None: + from microimpute.models.zero_inflated import ZeroInflatedImputer + + class SequentialProbeResult: + def __init__(self, variable: str, predictors: list[str]) -> None: + self.variable = variable + self.predictors = predictors + + def predict(self, X_test: pd.DataFrame, **_kwargs) -> pd.DataFrame: + if self.variable == "y1": + return pd.DataFrame({"y1": np.full(len(X_test), 7.0)}) + return pd.DataFrame({"y2": X_test["y1"].to_numpy(dtype=float) * 3.0}) + + class SequentialProbeImputer: + def __init__(self, **_kwargs) -> None: + pass + + def fit( + self, + X_train: pd.DataFrame, + predictors: list[str], + imputed_variables: list[str], + ) -> SequentialProbeResult: + assert len(imputed_variables) == 1 + return SequentialProbeResult(imputed_variables[0], predictors) + + data = pd.DataFrame( + { + "x": np.arange(20, dtype=float), + "y1": np.linspace(10.0, 29.0, 20), + "y2": np.linspace(100.0, 119.0, 20), + } + ) + + result = ZeroInflatedImputer( + base_imputer_class=SequentialProbeImputer, + ).fit(data, predictors=["x"], imputed_variables=["y1", "y2"]) + + assert result._per_variable["y1"]["predictors"] == ["x"] + assert result._per_variable["y1"]["base"].predictors == ["x"] + assert result._per_variable["y2"]["predictors"] == ["x", "y1"] + assert result._per_variable["y2"]["base"].predictors == ["x", "y1"] + + predictions = result.predict(pd.DataFrame({"x": [1.0, 2.0, 3.0]})) + + np.testing.assert_array_equal(predictions["y1"].to_numpy(), 7.0) + np.testing.assert_array_equal(predictions["y2"].to_numpy(), 21.0)