Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/zero-inflated-numeric-chain.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preserve sequential numeric-target chaining inside `ZeroInflatedImputer`.
105 changes: 65 additions & 40 deletions microimpute/models/zero_inflated.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,20 +234,26 @@ def fit(
for v in imputed_variables
if v in self.numeric_targets or v in constant_numeric_targets
]
for var in numeric_targets:
numeric_target_set = set(numeric_targets)
previous_numeric_targets: List[str] = []
for var in imputed_variables:
if var not in numeric_target_set:
continue
y = X_train[var].to_numpy(dtype=float, copy=False)
regime = _detect_regime(
y,
zero_atol=self.zero_atol,
)
sequential_predictors = list(predictors) + previous_numeric_targets
self._regimes[var] = regime
self._per_variable[var] = self._fit_single_numeric(
X_train=X_train,
predictors=predictors,
predictors=sequential_predictors,
variable=var,
regime=regime,
y=y,
)
previous_numeric_targets.append(var)

# Non-numeric (categorical / boolean / constant) targets are
# handled by a single auxiliary base imputer over their union.
Expand Down Expand Up @@ -298,17 +304,24 @@ def _fit_single_numeric(
Returns a bundle dict with the regime, the gate classifier
(or None), and the base imputer(s) keyed by their role.
"""

def with_predictors(bundle: Dict[str, Any]) -> Dict[str, Any]:
bundle["predictors"] = list(predictors)
return bundle

X_pred = X_train[predictors].to_numpy(dtype=float, copy=False)

if regime == REGIME_DEGENERATE_ZERO:
return {"kind": "constant", "value": 0.0}
return with_predictors({"kind": "constant", "value": 0.0})

if regime in (REGIME_POSITIVE_ONLY, REGIME_NEGATIVE_ONLY):
# No gate; single base imputer on the full training set.
return {
"kind": "single",
"base": self._fit_base_single(X_train, predictors, variable),
}
return with_predictors(
{
"kind": "single",
"base": self._fit_base_single(X_train, predictors, variable),
}
)

if regime == REGIME_ZI_POSITIVE:
labels = (y > self.zero_atol).astype(int)
Expand All @@ -318,11 +331,13 @@ def _fit_single_numeric(
pos_base = self._fit_base_single(
X_train.loc[pos_mask], predictors, variable
)
return {
"kind": "zi_positive",
"classifier": clf,
"positive_base": pos_base,
}
return with_predictors(
{
"kind": "zi_positive",
"classifier": clf,
"positive_base": pos_base,
}
)

if regime == REGIME_ZI_NEGATIVE:
labels = (y < -self.zero_atol).astype(int)
Expand All @@ -332,11 +347,13 @@ def _fit_single_numeric(
neg_base = self._fit_base_single(
X_train.loc[neg_mask], predictors, variable
)
return {
"kind": "zi_negative",
"classifier": clf,
"negative_base": neg_base,
}
return with_predictors(
{
"kind": "zi_negative",
"classifier": clf,
"negative_base": neg_base,
}
)

if regime == REGIME_SIGN_ONLY:
# No zero class, but both signs present. Binary sign gate
Expand All @@ -346,16 +363,18 @@ def _fit_single_numeric(
clf.fit(X_pred, labels)
pos_mask = y > 0
neg_mask = ~pos_mask
return {
"kind": "sign_only",
"classifier": clf,
"positive_base": self._fit_base_single(
X_train.loc[pos_mask], predictors, variable
),
"negative_base": self._fit_base_single(
X_train.loc[neg_mask], predictors, variable
),
}
return with_predictors(
{
"kind": "sign_only",
"classifier": clf,
"positive_base": self._fit_base_single(
X_train.loc[pos_mask], predictors, variable
),
"negative_base": self._fit_base_single(
X_train.loc[neg_mask], predictors, variable
),
}
)

if regime == REGIME_THREE_SIGN:
# 0 / neg / pos three-way gate + two base imputers.
Expand All @@ -368,16 +387,18 @@ def _fit_single_numeric(
clf.fit(X_pred, labels)
pos_mask = y > self.zero_atol
neg_mask = y < -self.zero_atol
return {
"kind": "three_sign",
"classifier": clf,
"positive_base": self._fit_base_single(
X_train.loc[pos_mask], predictors, variable
),
"negative_base": self._fit_base_single(
X_train.loc[neg_mask], predictors, variable
),
}
return with_predictors(
{
"kind": "three_sign",
"classifier": clf,
"positive_base": self._fit_base_single(
X_train.loc[pos_mask], predictors, variable
),
"negative_base": self._fit_base_single(
X_train.loc[neg_mask], predictors, variable
),
}
)

raise ValueError(f"Unhandled regime {regime!r}")

Expand Down Expand Up @@ -465,16 +486,19 @@ def _predict_single_draw(
**kwargs: Any,
) -> pd.DataFrame:
out = pd.DataFrame(index=X_test.index)
X_test_augmented = X_test.copy()

for variable in self.imputed_variables:
regime = self._regimes.get(variable)
if regime is None:
# Non-numeric target; handled by the auxiliary bundle.
continue
bundle = self._per_variable[variable]
out[variable] = self._predict_single_variable(
X_test, variable, bundle, quantile=quantile, **kwargs
values = self._predict_single_variable(
X_test_augmented, variable, bundle, quantile=quantile, **kwargs
)
out[variable] = values
X_test_augmented[variable] = values

# Merge in non-numeric target predictions from the auxiliary
# single base imputer.
Expand All @@ -501,6 +525,7 @@ def _predict_single_variable(
) -> np.ndarray:
n = len(X_test)
kind = bundle["kind"]
predictors = bundle.get("predictors", self.predictors)

if kind == "constant":
return np.full(n, bundle["value"], dtype=float)
Expand All @@ -511,7 +536,7 @@ def _predict_single_variable(
)
return preds[variable].to_numpy(dtype=float)

X_pred = X_test[self.predictors].to_numpy(dtype=float, copy=False)
X_pred = X_test[predictors].to_numpy(dtype=float, copy=False)

if kind == "zi_positive":
clf = bundle["classifier"]
Expand Down
52 changes: 52 additions & 0 deletions tests/test_models/test_zero_inflated.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,55 @@ def test_positive_only_matches_bare_qrf(self) -> None:
abs(bare_preds.mean() - wrapped_preds.mean()) / max(bare_preds.mean(), 1.0)
< 0.25
)


class TestSequentialNumericTargets:
"""Multiple numeric targets should use earlier imputations as predictors."""

def test_later_numeric_targets_receive_previous_predictions(self) -> None:
from microimpute.models.zero_inflated import ZeroInflatedImputer

class SequentialProbeResult:
def __init__(self, variable: str, predictors: list[str]) -> None:
self.variable = variable
self.predictors = predictors

def predict(self, X_test: pd.DataFrame, **_kwargs) -> pd.DataFrame:
if self.variable == "y1":
return pd.DataFrame({"y1": np.full(len(X_test), 7.0)})
return pd.DataFrame({"y2": X_test["y1"].to_numpy(dtype=float) * 3.0})

class SequentialProbeImputer:
def __init__(self, **_kwargs) -> None:
pass

def fit(
self,
X_train: pd.DataFrame,
predictors: list[str],
imputed_variables: list[str],
) -> SequentialProbeResult:
assert len(imputed_variables) == 1
return SequentialProbeResult(imputed_variables[0], predictors)

data = pd.DataFrame(
{
"x": np.arange(20, dtype=float),
"y1": np.linspace(10.0, 29.0, 20),
"y2": np.linspace(100.0, 119.0, 20),
}
)

result = ZeroInflatedImputer(
base_imputer_class=SequentialProbeImputer,
).fit(data, predictors=["x"], imputed_variables=["y1", "y2"])

assert result._per_variable["y1"]["predictors"] == ["x"]
assert result._per_variable["y1"]["base"].predictors == ["x"]
assert result._per_variable["y2"]["predictors"] == ["x", "y1"]
assert result._per_variable["y2"]["base"].predictors == ["x", "y1"]

predictions = result.predict(pd.DataFrame({"x": [1.0, 2.0, 3.0]}))

np.testing.assert_array_equal(predictions["y1"].to_numpy(), 7.0)
np.testing.assert_array_equal(predictions["y2"].to_numpy(), 21.0)