diff --git a/changelog.d/zeroinflated_sequential_chaining.added.md b/changelog.d/zeroinflated_sequential_chaining.added.md new file mode 100644 index 0000000..929c16f --- /dev/null +++ b/changelog.d/zeroinflated_sequential_chaining.added.md @@ -0,0 +1 @@ +`ZeroInflatedImputer` now imputes a list of targets **sequentially** (chained-equations): each numeric target is conditioned on the original predictors plus the previously-imputed targets, so the imputed vector preserves cross-variable joint structure instead of imputing each variable in isolation. Controlled by the new `sequential` parameter (default `True`); single-target lists are unaffected. diff --git a/microimpute/models/zero_inflated.py b/microimpute/models/zero_inflated.py index 8858446..5286596 100644 --- a/microimpute/models/zero_inflated.py +++ b/microimpute/models/zero_inflated.py @@ -49,13 +49,9 @@ from pydantic import validate_call from microimpute.config import RANDOM_STATE, VALIDATE_CONFIG -from microimpute.models.imputer import ( - Imputer, - ImputerResults, -) +from microimpute.models.imputer import Imputer, ImputerResults from microimpute.models.qrf import QRF - # Regime labels. Kept as module-level constants so downstream code can # match on them without magic strings. REGIME_THREE_SIGN = "THREE_SIGN" @@ -156,6 +152,7 @@ def __init__( base_imputer_kwargs: Optional[Dict[str, Any]] = None, zero_atol: float = 1e-6, classifier_type: str = "hist_gb", + sequential: bool = True, seed: Optional[int] = RANDOM_STATE, log_level: Optional[str] = "WARNING", ) -> None: @@ -164,6 +161,7 @@ def __init__( self.base_imputer_kwargs = dict(base_imputer_kwargs or {}) self.zero_atol = float(zero_atol) self.classifier_type = classifier_type + self.sequential = bool(sequential) # Filled in during fit(). self._regimes: Dict[str, str] = {} @@ -234,20 +232,32 @@ def fit( for v in imputed_variables if v in self.numeric_targets or v in constant_numeric_targets ] - for var in numeric_targets: + # Sequential (chained-equations) imputation: condition each numeric + # target on the original predictors plus the previously-fit numeric + # targets, so the imputed vector preserves cross-variable joint + # structure. ``imputed_variables`` order is the chain order; a + # single-target list is unaffected (no priors to chain on). + for position, var in enumerate(numeric_targets): + seq_predictors = ( + list(predictors) + numeric_targets[:position] + if self.sequential + else list(predictors) + ) y = X_train[var].to_numpy(dtype=float, copy=False) regime = _detect_regime( y, zero_atol=self.zero_atol, ) self._regimes[var] = regime - self._per_variable[var] = self._fit_single_numeric( + bundle = self._fit_single_numeric( X_train=X_train, - predictors=predictors, + predictors=seq_predictors, variable=var, regime=regime, y=y, ) + bundle["predictors"] = list(seq_predictors) + self._per_variable[var] = bundle # Non-numeric (categorical / boolean / constant) targets are # handled by a single auxiliary base imputer over their union. @@ -466,15 +476,23 @@ def _predict_single_draw( ) -> pd.DataFrame: out = pd.DataFrame(index=X_test.index) + # Carry imputed numeric targets forward as predictors for later + # targets (chained-equations imputation), matching the sequential + # conditioning used at fit time. ``X_aug`` accumulates imputed + # columns in ``imputed_variables`` order so each target's gate/base + # can condition on the ones already drawn. + X_aug = X_test.copy() for variable in self.imputed_variables: regime = self._regimes.get(variable) if regime is None: # Non-numeric target; handled by the auxiliary bundle. continue bundle = self._per_variable[variable] - out[variable] = self._predict_single_variable( - X_test, variable, bundle, quantile=quantile, **kwargs + values = self._predict_single_variable( + X_aug, variable, bundle, quantile=quantile, **kwargs ) + out[variable] = values + X_aug[variable] = np.asarray(values, dtype=float) # Merge in non-numeric target predictions from the auxiliary # single base imputer. @@ -511,7 +529,9 @@ def _predict_single_variable( ) return preds[variable].to_numpy(dtype=float) - X_pred = X_test[self.predictors].to_numpy(dtype=float, copy=False) + X_pred = X_test[bundle.get("predictors", self.predictors)].to_numpy( + dtype=float, copy=False + ) if kind == "zi_positive": clf = bundle["classifier"] diff --git a/tests/test_models/test_zero_inflated_chaining.py b/tests/test_models/test_zero_inflated_chaining.py new file mode 100644 index 0000000..2a8b885 --- /dev/null +++ b/tests/test_models/test_zero_inflated_chaining.py @@ -0,0 +1,78 @@ +"""Sequential (chained-equations) imputation in ZeroInflatedImputer. + +A correct chained imputer conditions each target on the previously +imputed ones, so it reproduces cross-variable correlation that is *not* +explained by the shared predictors. We construct two targets that are +correlated through an unobserved latent factor and confirm that the +sequential imputer recovers that correlation while the non-sequential +(per-variable independent) imputer does not. +""" + +import numpy as np +import pandas as pd + +from microimpute.models.zero_inflated import ZeroInflatedImputer + + +def _make_latent_correlated_frame(n: int, seed: int) -> pd.DataFrame: + """Two positive targets A, B that share a latent factor L *with + opposite sign*, so they are strongly NEGATIVELY correlated. + + The shared predictor X explains almost none of the A-B relationship; + it runs through L, which is never observed. An imputer that draws B + independently of A cannot reproduce this dependence. Only one that + conditions B on the already-imputed A recovers it, because A reveals + L (given X, A pins down L, which then pins down B). + """ + rng = np.random.default_rng(seed) + x = rng.normal(size=n) + latent = rng.normal(size=n) # unobserved + a = 10.0 + 0.2 * x + 1.5 * latent + 0.3 * rng.normal(size=n) + b = 20.0 + 0.2 * x - 1.5 * latent + 0.3 * rng.normal(size=n) + return pd.DataFrame({"x": x, "a": a, "b": b}) + + +def _imputed_correlation(sequential: bool) -> tuple[float, float]: + train = _make_latent_correlated_frame(n=4000, seed=0) + test = _make_latent_correlated_frame(n=4000, seed=1) + + imp = ZeroInflatedImputer(sequential=sequential, seed=0) + fitted = imp.fit( + X_train=train, + predictors=["x"], + imputed_variables=["a", "b"], + ) + preds = fitted.predict(test[["x"]]) + imputed_corr = float(np.corrcoef(preds["a"], preds["b"])[0, 1]) + true_corr = float(np.corrcoef(test["a"], test["b"])[0, 1]) + return imputed_corr, true_corr + + +def test_sequential_chaining_recovers_joint_correlation(): + seq_corr, true_corr = _imputed_correlation(sequential=True) + indep_corr, _ = _imputed_correlation(sequential=False) + + # The true A-B correlation is strongly negative (opposite latent loads). + assert true_corr < -0.85, true_corr + # Chained imputation conditions b on the already-imputed a (which + # reveals the latent factor), recovering the true negative dependence + # almost exactly. + assert seq_corr < -0.7, seq_corr + assert abs(seq_corr - true_corr) < 0.15, (seq_corr, true_corr) + # Non-sequential per-variable imputation never sees a when drawing b, + # so it misses most of the dependence. + assert seq_corr < indep_corr - 0.4, (seq_corr, indep_corr) + + +def test_single_target_is_unaffected_by_sequential_flag(): + """A one-variable list has no prior to chain on, so the flag is a no-op.""" + train = _make_latent_correlated_frame(n=1500, seed=2) + test = _make_latent_correlated_frame(n=1500, seed=3) + + out = {} + for sequential in (True, False): + imp = ZeroInflatedImputer(sequential=sequential, seed=7) + fitted = imp.fit(X_train=train, predictors=["x"], imputed_variables=["a"]) + out[sequential] = fitted.predict(test[["x"]])["a"].to_numpy() + + np.testing.assert_allclose(out[True], out[False])