Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions stabl/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from packaging import version

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.feature_selection import SelectorMixin
from sklearn.utils.validation import check_is_fitted
from sklearn import __version__ as sklearn_version


def remove_low_info_samples(X, threshold=1.0):
Expand Down Expand Up @@ -131,3 +134,21 @@ def _more_tags(self):
# Useful to allow the use of nan values
# For the transform function ;)
return {"allow_nan": True}

def _validate_data(self, *args, **kwargs):
"""Data validation method compatible with multiple versions of sklearn.

For sklearn < 1.6.0, overloads BaseEstimator._validate_data, which is depricated in 1.6.0
and removed in 1.7.0. For sklearn >= 1.6.0, this method calls sklearn.utils.validation.validate_data,
which performs the equivalent functionality."""

skl_ver = version.parse(sklearn_version)

if skl_ver >= version.parse("1.6.0"):
# for scikit-learn >= 1.6.0
from sklearn.utils.validation import validate_data
return validate_data(self, *args, **kwargs)

else:
# for scikit-learn < 1.6.0
return BaseEstimator._validate_data(self, *args, **kwargs)
21 changes: 21 additions & 0 deletions stabl/stabl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
from pathlib import Path
from warnings import warn
import sys
from packaging import version

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from knockpy.knockoffs import GaussianSampler
from sklearn import __version__ as sklearn_version
from sklearn.base import BaseEstimator, clone
from sklearn.feature_selection import SelectorMixin, SelectFromModel
from sklearn.linear_model import LogisticRegression, Lasso, ElasticNet
Expand Down Expand Up @@ -1064,6 +1066,25 @@ def _validate_input(self):
f"When injecting noise, the noise proportion must be between 0 and 1, "
f"got {self.artificial_proportion}"
)

def _validate_data(self, *args, **kwargs):
"""Data validation method compatible with multiple versions of sklearn.

For sklearn < 1.6.0, overloads BaseEstimator._validate_data, which is depricated in 1.6.0
and removed in 1.7.0. For sklearn >= 1.6.0, this method calls sklearn.utils.validation.validate_data,
which performs the equivalent functionality."""

skl_ver = version.parse(sklearn_version)

if skl_ver >= version.parse("1.6.0"):
# for scikit-learn >= 1.6.0
from sklearn.utils.validation import validate_data
return validate_data(self, *args, **kwargs)

else:
# for scikit-learn < 1.6.0
return BaseEstimator._validate_data(self, *args, **kwargs)


def _make_groups(self, X):
"""Make groups for self configuration.
Expand Down