From 90ed3dff7dd466b8959c9ae0aa95834fbc8703af Mon Sep 17 00:00:00 2001 From: Max Ghenis Date: Sat, 6 Jun 2026 10:36:50 +0100 Subject: [PATCH] Preserve QRF feature order during prediction --- changelog.d/195.fixed.md | 1 + microimpute/models/qrf.py | 31 +++++++++++ microimpute/models/zero_inflated.py | 44 +++++++++++++--- tests/test_models/test_qrf.py | 25 +++++++++ tests/test_models/test_zero_inflated.py | 69 +++++++++++++++++++++++++ 5 files changed, 163 insertions(+), 7 deletions(-) create mode 100644 changelog.d/195.fixed.md diff --git a/changelog.d/195.fixed.md b/changelog.d/195.fixed.md new file mode 100644 index 0000000..e2bfd05 --- /dev/null +++ b/changelog.d/195.fixed.md @@ -0,0 +1 @@ +Preserved fitted feature order for QRF/RFC prediction and propagated numeric-categorical overrides into zero-inflated nested base imputers. diff --git a/microimpute/models/qrf.py b/microimpute/models/qrf.py index c190d5b..8edc5d0 100644 --- a/microimpute/models/qrf.py +++ b/microimpute/models/qrf.py @@ -50,6 +50,7 @@ def __init__(self, seed: int, logger): self.var_type = None self.categories = None self.label_map = None + self.feature_columns: List[str] = [] def fit( self, @@ -99,10 +100,24 @@ def fit( fit_kwargs = {} if sample_weight is not None: fit_kwargs["sample_weight"] = np.asarray(sample_weight, dtype=float) + self.feature_columns = list(X.columns) self.classifier.fit(X, y_encoded, **fit_kwargs) + def _align_features(self, X: pd.DataFrame) -> pd.DataFrame: + """Reorder prediction features to the fitted sklearn column contract.""" + if not self.feature_columns: + return X + missing = [column for column in self.feature_columns if column not in X.columns] + if missing: + raise ValueError( + "Prediction data is missing fitted feature columns for " + f"{self.output_column}: {missing}" + ) + return X.loc[:, self.feature_columns] + def predict(self, X: pd.DataFrame, return_probs: bool = False) -> pd.Series: """Predict classes or probabilities.""" + X = self._align_features(X) if return_probs: probs = self.classifier.predict_proba(X) # Return both probabilities and the original category labels @@ -151,6 +166,7 @@ def __init__(self, seed: int, logger): self.logger = logger self.qrf = None self.output_column = None + self.feature_columns: List[str] = [] # Create the RNG once at construction so that repeated predict() # calls consume state progressively and return different draws. self._rng = np.random.default_rng(self.seed) @@ -192,8 +208,21 @@ def fit( fit_kwargs = {} if sample_weight is not None: fit_kwargs["sample_weight"] = np.asarray(sample_weight, dtype=float) + self.feature_columns = list(X.columns) self.qrf.fit(X, y.values.ravel(), **fit_kwargs) + def _align_features(self, X: pd.DataFrame) -> pd.DataFrame: + """Reorder prediction features to the fitted QRF column contract.""" + if not self.feature_columns: + return X + missing = [column for column in self.feature_columns if column not in X.columns] + if missing: + raise ValueError( + "Prediction data is missing fitted feature columns for " + f"{self.output_column}: {missing}" + ) + return X.loc[:, self.feature_columns] + def predict( self, X: pd.DataFrame, @@ -218,6 +247,8 @@ def predict( across quantiles for a given row and is used when the caller supplies an explicit ``quantiles`` list. """ + X = self._align_features(X) + # Deterministic path: user asked for a specific quantile — query the # QRF directly so that for any row i, # prediction(q_low) <= prediction(q_mid) <= prediction(q_high). diff --git a/microimpute/models/zero_inflated.py b/microimpute/models/zero_inflated.py index 5286596..fc9b517 100644 --- a/microimpute/models/zero_inflated.py +++ b/microimpute/models/zero_inflated.py @@ -232,6 +232,9 @@ def fit( for v in imputed_variables if v in self.numeric_targets or v in constant_numeric_targets ] + nested_not_numeric_categorical = list( + dict.fromkeys([*(not_numeric_categorical or []), *numeric_targets]) + ) # Sequential (chained-equations) imputation: condition each numeric # target on the original predictors plus the previously-fit numeric # targets, so the imputed vector preserves cross-variable joint @@ -255,6 +258,7 @@ def fit( variable=var, regime=regime, y=y, + not_numeric_categorical=nested_not_numeric_categorical, ) bundle["predictors"] = list(seq_predictors) self._per_variable[var] = bundle @@ -302,6 +306,7 @@ def _fit_single_numeric( variable: str, regime: str, y: np.ndarray, + not_numeric_categorical: Optional[List[str]] = None, ) -> Dict[str, Any]: """Fit the gate and base imputer(s) for one numeric target. @@ -317,7 +322,12 @@ def _fit_single_numeric( # No gate; single base imputer on the full training set. return { "kind": "single", - "base": self._fit_base_single(X_train, predictors, variable), + "base": self._fit_base_single( + X_train, + predictors, + variable, + not_numeric_categorical=not_numeric_categorical, + ), } if regime == REGIME_ZI_POSITIVE: @@ -326,7 +336,10 @@ def _fit_single_numeric( clf.fit(X_pred, labels) pos_mask = y > self.zero_atol pos_base = self._fit_base_single( - X_train.loc[pos_mask], predictors, variable + X_train.loc[pos_mask], + predictors, + variable, + not_numeric_categorical=not_numeric_categorical, ) return { "kind": "zi_positive", @@ -340,7 +353,10 @@ def _fit_single_numeric( clf.fit(X_pred, labels) neg_mask = y < -self.zero_atol neg_base = self._fit_base_single( - X_train.loc[neg_mask], predictors, variable + X_train.loc[neg_mask], + predictors, + variable, + not_numeric_categorical=not_numeric_categorical, ) return { "kind": "zi_negative", @@ -360,10 +376,16 @@ def _fit_single_numeric( "kind": "sign_only", "classifier": clf, "positive_base": self._fit_base_single( - X_train.loc[pos_mask], predictors, variable + X_train.loc[pos_mask], + predictors, + variable, + not_numeric_categorical=not_numeric_categorical, ), "negative_base": self._fit_base_single( - X_train.loc[neg_mask], predictors, variable + X_train.loc[neg_mask], + predictors, + variable, + not_numeric_categorical=not_numeric_categorical, ), } @@ -382,10 +404,16 @@ def _fit_single_numeric( "kind": "three_sign", "classifier": clf, "positive_base": self._fit_base_single( - X_train.loc[pos_mask], predictors, variable + X_train.loc[pos_mask], + predictors, + variable, + not_numeric_categorical=not_numeric_categorical, ), "negative_base": self._fit_base_single( - X_train.loc[neg_mask], predictors, variable + X_train.loc[neg_mask], + predictors, + variable, + not_numeric_categorical=not_numeric_categorical, ), } @@ -396,6 +424,7 @@ def _fit_base_single( X_train: pd.DataFrame, predictors: List[str], variable: str, + not_numeric_categorical: Optional[List[str]] = None, ) -> ImputerResults: """Fit a single base Imputer on a (possibly filtered) slice.""" imputer = self.base_imputer_class( @@ -406,6 +435,7 @@ def _fit_base_single( X_train=X_train, predictors=predictors, imputed_variables=[variable], + not_numeric_categorical=not_numeric_categorical, ) diff --git a/tests/test_models/test_qrf.py b/tests/test_models/test_qrf.py index 4da9388..e3099f7 100644 --- a/tests/test_models/test_qrf.py +++ b/tests/test_models/test_qrf.py @@ -145,6 +145,31 @@ def test_qrf_sequential_imputation(diabetes_data: pd.DataFrame) -> None: ), "Imputation order should affect results" +def test_qrf_model_prediction_reorders_to_fitted_feature_order() -> None: + """Prediction must honor the fitted sklearn feature order contract.""" + rng = np.random.default_rng(42) + X_fit = pd.DataFrame( + { + "b": rng.normal(size=120), + "a": rng.normal(size=120), + } + ) + y = pd.Series( + 2.0 * X_fit["a"] - X_fit["b"] + rng.normal(scale=0.1, size=120), + name="target", + ) + + model = _QRFModel(seed=42, logger=logging.getLogger(__name__)) + model.fit(X_fit[["b", "a"]], y, n_estimators=10) + + X_reordered = X_fit[["a", "b"]].head(12) + predictions = model.predict(X_reordered, exact_quantile=0.5) + + assert isinstance(predictions, pd.Series) + assert predictions.shape == (12,) + assert not predictions.isna().any() + + def test_qrf_beta_distribution_sampling(): """Test different mean_quantile values for beta distribution sampling.""" np.random.seed(42) diff --git a/tests/test_models/test_zero_inflated.py b/tests/test_models/test_zero_inflated.py index c75195e..7621c14 100644 --- a/tests/test_models/test_zero_inflated.py +++ b/tests/test_models/test_zero_inflated.py @@ -303,3 +303,72 @@ def test_positive_only_matches_bare_qrf(self) -> None: abs(bare_preds.mean() - wrapped_preds.mean()) / max(bare_preds.mean(), 1.0) < 0.25 ) + + +class TestSequentialPredictorTyping: + """Numeric overrides must reach nested per-regime base imputers.""" + + def test_chained_numeric_targets_stay_numeric_in_nested_base_fits(self) -> None: + from microimpute.models.zero_inflated import ZeroInflatedImputer + + rng = np.random.default_rng(42) + y1 = np.tile(np.arange(20, dtype=float), 10) + positive = y1 < 3 + data = pd.DataFrame( + { + "x": rng.normal(size=len(y1)), + # Numeric in the full target block, but low-cardinality inside + # the y2 positive-regime slice where it becomes a predictor. + "y1": y1, + "y2": np.where( + positive, + rng.exponential(scale=100.0, size=len(y1)) + 1.0, + 0.0, + ), + } + ) + + imputer = ZeroInflatedImputer(base_imputer_class=QRF) + fitted = imputer.fit( + data, + predictors=["x"], + imputed_variables=["y1", "y2"], + ) + + y2_base = fitted._per_variable["y2"]["positive_base"] + assert "y1" not in y2_base.dummy_processor.dummy_mapping + + predictions = fitted.predict(data[["x"]].head(20)) + assert set(predictions.columns) == {"y1", "y2"} + assert not predictions.isna().any().any() + + def test_not_numeric_categorical_applies_to_chained_base_predictors(self) -> None: + from microimpute.models.zero_inflated import ZeroInflatedImputer + + rng = np.random.default_rng(42) + n = 180 + data = pd.DataFrame( + { + "x": rng.normal(size=n), + # Low-cardinality, equally spaced numeric target. Without + # the override this becomes a categorical predictor inside + # the nested base QRF for y2. + "y1": rng.choice([0.0, 1.0, 2.0], size=n), + "y2": rng.exponential(scale=100.0, size=n) + 1.0, + } + ) + + imputer = ZeroInflatedImputer(base_imputer_class=QRF) + fitted = imputer.fit( + data, + predictors=["x"], + imputed_variables=["y1", "y2"], + not_numeric_categorical=["y1"], + ) + + y2_base = fitted._per_variable["y2"]["base"] + assert "y1" not in y2_base.dummy_processor.dummy_mapping + + predictions = fitted.predict(data[["x"]].head(20)) + assert set(predictions.columns) == {"y1", "y2"} + assert not predictions.isna().any().any()