From 37f25e1f53bde069c1ca6ecf8b32c8ec312440f3 Mon Sep 17 00:00:00 2001 From: Henry Tullis Date: Fri, 20 Jun 2025 22:50:46 -0700 Subject: [PATCH] compat(sklearn): address _validate_data depr. make Stabl more compatible with more recent versions of scikit-learn add backwards-compatible fix for warnings/errors related to the deprecation (1.6.0) and removal (1.7.0) of the scikit-learn BaseEstimator._validate_data function tested this fix using the 'Tutorial Notebook.ipynb' with various versions of scikit-learn: - 1.7.0: runs better than without the fix, however there are still errors and warnings unrelated to this fix - 1.6.0: runs as expected and reduces warnings - 1.3.2 (recommended for Stabl): runs as expected code written with assistance from Claude 3.7 --- stabl/preprocessing.py | 21 +++++++++++++++++++++ stabl/stabl.py | 21 +++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/stabl/preprocessing.py b/stabl/preprocessing.py index e536b3de..c74c1eb2 100644 --- a/stabl/preprocessing.py +++ b/stabl/preprocessing.py @@ -1,7 +1,10 @@ +from packaging import version + import numpy as np from sklearn.base import BaseEstimator from sklearn.feature_selection import SelectorMixin from sklearn.utils.validation import check_is_fitted +from sklearn import __version__ as sklearn_version def remove_low_info_samples(X, threshold=1.0): @@ -131,3 +134,21 @@ def _more_tags(self): # Useful to allow the use of nan values # For the transform function ;) return {"allow_nan": True} + + def _validate_data(self, *args, **kwargs): + """Data validation method compatible with multiple versions of sklearn. + + For sklearn < 1.6.0, overrides BaseEstimator._validate_data, which is deprecated in 1.6.0 + and removed in 1.7.0. 
For sklearn >= 1.6.0, this method calls sklearn.utils.validation.validate_data, + which performs the equivalent functionality.""" + + skl_ver = version.parse(sklearn_version) + + if skl_ver >= version.parse("1.6.0"): + # for scikit-learn >= 1.6.0 + from sklearn.utils.validation import validate_data + return validate_data(self, *args, **kwargs) + + else: + # for scikit-learn < 1.6.0 + return BaseEstimator._validate_data(self, *args, **kwargs) diff --git a/stabl/stabl.py b/stabl/stabl.py index ac111ecd..80ea9925 100644 --- a/stabl/stabl.py +++ b/stabl/stabl.py @@ -2,12 +2,14 @@ from pathlib import Path from warnings import warn import sys +from packaging import version import matplotlib.pyplot as plt import numpy as np import pandas as pd from joblib import Parallel, delayed from knockpy.knockoffs import GaussianSampler +from sklearn import __version__ as sklearn_version from sklearn.base import BaseEstimator, clone from sklearn.feature_selection import SelectorMixin, SelectFromModel from sklearn.linear_model import LogisticRegression, Lasso, ElasticNet @@ -1064,6 +1066,25 @@ def _validate_input(self): f"When injecting noise, the noise proportion must be between 0 and 1, " f"got {self.artificial_proportion}" ) + + def _validate_data(self, *args, **kwargs): + """Data validation method compatible with multiple versions of sklearn. + + For sklearn < 1.6.0, overrides BaseEstimator._validate_data, which is deprecated in 1.6.0 + and removed in 1.7.0. For sklearn >= 1.6.0, this method calls sklearn.utils.validation.validate_data, + which performs the equivalent functionality.""" + + skl_ver = version.parse(sklearn_version) + + if skl_ver >= version.parse("1.6.0"): + # for scikit-learn >= 1.6.0 + from sklearn.utils.validation import validate_data + return validate_data(self, *args, **kwargs) + + else: + # for scikit-learn < 1.6.0 + return BaseEstimator._validate_data(self, *args, **kwargs) + def _make_groups(self, X): """Make groups for self configuration.