Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/zeroinflated_sequential_chaining.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
`ZeroInflatedImputer` now imputes a list of targets **sequentially** (chained-equations): each numeric target is conditioned on the original predictors plus the previously-imputed targets, so the imputed vector preserves cross-variable joint structure instead of imputing each variable in isolation. Controlled by the new `sequential` parameter (default `True`); single-target lists are unaffected.
42 changes: 31 additions & 11 deletions microimpute/models/zero_inflated.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,9 @@
from pydantic import validate_call

from microimpute.config import RANDOM_STATE, VALIDATE_CONFIG
from microimpute.models.imputer import (
Imputer,
ImputerResults,
)
from microimpute.models.imputer import Imputer, ImputerResults
from microimpute.models.qrf import QRF


# Regime labels. Kept as module-level constants so downstream code can
# match on them without magic strings.
REGIME_THREE_SIGN = "THREE_SIGN"
Expand Down Expand Up @@ -156,6 +152,7 @@ def __init__(
base_imputer_kwargs: Optional[Dict[str, Any]] = None,
zero_atol: float = 1e-6,
classifier_type: str = "hist_gb",
sequential: bool = True,
seed: Optional[int] = RANDOM_STATE,
log_level: Optional[str] = "WARNING",
) -> None:
Expand All @@ -164,6 +161,7 @@ def __init__(
self.base_imputer_kwargs = dict(base_imputer_kwargs or {})
self.zero_atol = float(zero_atol)
self.classifier_type = classifier_type
self.sequential = bool(sequential)

# Filled in during fit().
self._regimes: Dict[str, str] = {}
Expand Down Expand Up @@ -234,20 +232,32 @@ def fit(
for v in imputed_variables
if v in self.numeric_targets or v in constant_numeric_targets
]
for var in numeric_targets:
# Sequential (chained-equations) imputation: condition each numeric
# target on the original predictors plus the previously-fit numeric
# targets, so the imputed vector preserves cross-variable joint
# structure. ``imputed_variables`` order is the chain order; a
# single-target list is unaffected (no priors to chain on).
for position, var in enumerate(numeric_targets):
seq_predictors = (
list(predictors) + numeric_targets[:position]
if self.sequential
else list(predictors)
)
y = X_train[var].to_numpy(dtype=float, copy=False)
regime = _detect_regime(
y,
zero_atol=self.zero_atol,
)
self._regimes[var] = regime
self._per_variable[var] = self._fit_single_numeric(
bundle = self._fit_single_numeric(
X_train=X_train,
predictors=predictors,
predictors=seq_predictors,
variable=var,
regime=regime,
y=y,
)
bundle["predictors"] = list(seq_predictors)
self._per_variable[var] = bundle

# Non-numeric (categorical / boolean / constant) targets are
# handled by a single auxiliary base imputer over their union.
Expand Down Expand Up @@ -466,15 +476,23 @@ def _predict_single_draw(
) -> pd.DataFrame:
out = pd.DataFrame(index=X_test.index)

# Carry imputed numeric targets forward as predictors for later
# targets (chained-equations imputation), matching the sequential
# conditioning used at fit time. ``X_aug`` accumulates imputed
# columns in ``imputed_variables`` order so each target's gate/base
# can condition on the ones already drawn.
X_aug = X_test.copy()
for variable in self.imputed_variables:
regime = self._regimes.get(variable)
if regime is None:
# Non-numeric target; handled by the auxiliary bundle.
continue
bundle = self._per_variable[variable]
out[variable] = self._predict_single_variable(
X_test, variable, bundle, quantile=quantile, **kwargs
values = self._predict_single_variable(
X_aug, variable, bundle, quantile=quantile, **kwargs
)
out[variable] = values
X_aug[variable] = np.asarray(values, dtype=float)

# Merge in non-numeric target predictions from the auxiliary
# single base imputer.
Expand Down Expand Up @@ -511,7 +529,9 @@ def _predict_single_variable(
)
return preds[variable].to_numpy(dtype=float)

X_pred = X_test[self.predictors].to_numpy(dtype=float, copy=False)
X_pred = X_test[bundle.get("predictors", self.predictors)].to_numpy(
dtype=float, copy=False
)

if kind == "zi_positive":
clf = bundle["classifier"]
Expand Down
78 changes: 78 additions & 0 deletions tests/test_models/test_zero_inflated_chaining.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Sequential (chained-equations) imputation in ZeroInflatedImputer.

A correct chained imputer conditions each target on the previously
imputed ones, so it reproduces cross-variable correlation that is *not*
explained by the shared predictors. We construct two targets that are
correlated through an unobserved latent factor and confirm that the
sequential imputer recovers that correlation while the non-sequential
(per-variable independent) imputer does not.
"""

import numpy as np
import pandas as pd

from microimpute.models.zero_inflated import ZeroInflatedImputer


def _make_latent_correlated_frame(n: int, seed: int) -> pd.DataFrame:
"""Two positive targets A, B that share a latent factor L *with
opposite sign*, so they are strongly NEGATIVELY correlated.

The shared predictor X explains almost none of the A-B relationship;
it runs through L, which is never observed. An imputer that draws B
independently of A cannot reproduce this dependence. Only one that
conditions B on the already-imputed A recovers it, because A reveals
L (given X, A pins down L, which then pins down B).
"""
rng = np.random.default_rng(seed)
x = rng.normal(size=n)
latent = rng.normal(size=n) # unobserved
a = 10.0 + 0.2 * x + 1.5 * latent + 0.3 * rng.normal(size=n)
b = 20.0 + 0.2 * x - 1.5 * latent + 0.3 * rng.normal(size=n)
return pd.DataFrame({"x": x, "a": a, "b": b})


def _imputed_correlation(sequential: bool) -> tuple[float, float]:
train = _make_latent_correlated_frame(n=4000, seed=0)
test = _make_latent_correlated_frame(n=4000, seed=1)

imp = ZeroInflatedImputer(sequential=sequential, seed=0)
fitted = imp.fit(
X_train=train,
predictors=["x"],
imputed_variables=["a", "b"],
)
preds = fitted.predict(test[["x"]])
imputed_corr = float(np.corrcoef(preds["a"], preds["b"])[0, 1])
true_corr = float(np.corrcoef(test["a"], test["b"])[0, 1])
return imputed_corr, true_corr


def test_sequential_chaining_recovers_joint_correlation():
seq_corr, true_corr = _imputed_correlation(sequential=True)
indep_corr, _ = _imputed_correlation(sequential=False)

# The true A-B correlation is strongly negative (opposite latent loads).
assert true_corr < -0.85, true_corr
# Chained imputation conditions b on the already-imputed a (which
# reveals the latent factor), recovering the true negative dependence
# almost exactly.
assert seq_corr < -0.7, seq_corr
assert abs(seq_corr - true_corr) < 0.15, (seq_corr, true_corr)
# Non-sequential per-variable imputation never sees a when drawing b,
# so it misses most of the dependence.
assert seq_corr < indep_corr - 0.4, (seq_corr, indep_corr)


def test_single_target_is_unaffected_by_sequential_flag():
"""A one-variable list has no prior to chain on, so the flag is a no-op."""
train = _make_latent_correlated_frame(n=1500, seed=2)
test = _make_latent_correlated_frame(n=1500, seed=3)

out = {}
for sequential in (True, False):
imp = ZeroInflatedImputer(sequential=sequential, seed=7)
fitted = imp.fit(X_train=train, predictors=["x"], imputed_variables=["a"])
out[sequential] = fitted.predict(test[["x"]])["a"].to_numpy()

np.testing.assert_allclose(out[True], out[False])