diff --git a/CONTEXT.md b/CONTEXT.md
index fe347ac8..e2b1fdab 100644
--- a/CONTEXT.md
+++ b/CONTEXT.md
@@ -561,6 +561,10 @@ _Avoid_: feature importance (unsigned, group-level), SHAP value (the raw per-fea
The uniform boolean toggle on the `CPPPlot` family (`profile`, `heatmap`, `ranking`, `feature_map`) selecting **CPP analysis** (`False`, group-level **feature importance**, `feat_importance` / `mean_dif`) versus **CPP-SHAP analysis** (`True`, sample-level **feature impact**, `feat_impact_'name'` / `mean_dif_'name'`). `True` switches color encoding to signed red/blue and the colorbar to the diverging SHAP colormap. It selects the *interpretation level*; it does not itself run SHAP (that is `ShapModel`). In `feature_map(shap_plot=True)` the cumulative bars stack the per-feature impact in one direction colored by sign; a `mean_dif_'name'` `col_val` keeps the mean-difference heatmap with those bars, while a `feat_impact_'name'` `col_val` moves the impact into the heatmap cells and switches the bars off.
_Avoid_: shap_mode, use_shap, sample_plot.
+**fuzzy aggregation** (`fuzzy_aggregation`):
+The strategy `ShapModel.fit` selects to turn a soft label `p` ∈ (0, 1) into a SHAP estimate when **fuzzy labeling** is active. `"interpolate"` (default, new in v1.1) fits the model twice (fuzzy sample at 0 → `S0`, at 1 → `S1`) and blends `p·S1 + (1−p)·S0` — the **unbiased** exact-`p` estimate. `"threshold"` (the `Breimann25` sweep) hard-labels the fuzzy sample `1` across a non-uniform `n_rounds × n_selection` grid and averages — a **biased** approximation whose effective positive-fraction is the grid's `frac1`, not `p`; kept as a first-class option. Each fuzzy protein is explained independently against the fixed balanced 0/1 **core**, with the other fuzzy proteins excluded from that run's training data. `n_rounds` (default `5`) is interpolate's speed/stability dial: `1` = fast exact two-fit estimate, `5` = light averaging, `≈15–20` = converged Monte-Carlo mean (run-to-run spread <5% on `DOM_GSEC`).
+_Avoid_: fuzzy mode, blend mode, soft-label aggregation.
+
### Scale-set vocabulary
**explainable scale set** (`top_explain_n`):
diff --git a/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py b/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py
index 5b87a4d7..2fa4a6e6 100644
--- a/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py
+++ b/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py
@@ -47,6 +47,12 @@ def _get_shap_values(shap_output, class_index=1):
return shap_values
+def _class_index_from_labels(labels, label_target_class=1):
+ """Map the target class label to its index among the integer (non-fuzzy) classes."""
+ label_classes = sorted(list(dict.fromkeys([x for x in labels if x == int(x)])))
+ return label_classes.index(label_target_class)
+
+
def _compute_shap_values(X, labels, model_class=None, model_kwargs=None,
explainer_class=None, explainer_kwargs=None,
class_index=1, n_background_data=None):
@@ -105,6 +111,23 @@ def _aggregate_shap_values(X, labels=None, list_model_classes=None, list_model_k
return shap_values, exp_val
+def _seed_model_kwargs(list_model_kwargs, random_state=None, round_idx=0):
+ """Derive per-round-seeded copies of the model kwargs.
+
+ With a fixed ``random_state`` each round uses ``random_state + round_idx`` so the rounds
+ differ (Monte-Carlo averaging) yet stay reproducible. With ``random_state=None`` the kwargs
+ are returned unchanged (``random_state`` already ``None``), so every fit re-draws fresh entropy.
+ """
+ if random_state is None:
+ return [dict(model_kwargs) for model_kwargs in list_model_kwargs]
+ seeded = []
+ for model_kwargs in list_model_kwargs:
+ model_kwargs = dict(model_kwargs)
+ model_kwargs["random_state"] = random_state + round_idx
+ seeded.append(model_kwargs)
+ return seeded
+
+
# II Main Functions
@ut.catch_backend_processing_error()
def monte_carlo_shap_estimation(X, labels=None, list_model_classes=None, list_model_kwargs=None,
@@ -113,8 +136,7 @@ def monte_carlo_shap_estimation(X, labels=None, list_model_classes=None, list_mo
label_target_class=1, n_background_data=None):
"""Compute Monte Carlo estimates of SHAP values for multiple models and feature selections."""
# Get class index
- label_classes = sorted(list(dict.fromkeys([x for x in labels if x == int(x)])))
- class_index = label_classes.index(label_target_class)
+ class_index = _class_index_from_labels(labels, label_target_class)
# Create empty SHAP value matrix
n_samples, n_features = X.shape
n_selection_rounds = len(is_selected)
@@ -150,3 +172,79 @@ def monte_carlo_shap_estimation(X, labels=None, list_model_classes=None, list_mo
shap_values = np.mean(mc_shap_values, axis=(2, 3))
exp_val = np.mean(list_expected_value)
return shap_values, exp_val
+
+
+@ut.catch_backend_processing_error()
+def interpolate_fuzzy_shap_estimation(X, labels=None, list_model_classes=None, list_model_kwargs=None,
+ explainer_class=None, explainer_kwargs=None, n_rounds=5,
+ is_selected=None, verbose=False, label_target_class=1,
+ n_background_data=None, random_state=None):
+ """Compute unbiased exact-``p`` SHAP estimates for fuzzy labels by interpolating between 0/1 fits.
+
+ Each fuzzy sample with soft label ``p`` is weighted by exactly ``p``: the model is fit twice
+ (fuzzy sample at 0 -> ``S0``, at 1 -> ``S1``) and the per-feature attributions are blended as
+ ``p * S1 + (1 - p) * S0``. Each fuzzy protein is explained independently against the fixed
+ balanced 0/1 core, with the other fuzzy proteins excluded from that run's training data. With
+ ``n_rounds=1`` this is exactly two fits per fuzzy sample; ``n_rounds > 1`` averages per-round
+ re-seeded fits (reproducible for a fixed ``random_state``).
+ """
+ labels = list(labels)
+ # Get class index (fuzzy float labels are excluded; classes come from the 0/1 core)
+ class_index = _class_index_from_labels(labels, label_target_class)
+ n_samples, n_features = X.shape
+ n_selection_rounds = len(is_selected)
+ n_cells = n_rounds * n_selection_rounds
+ # Partition into the fixed 0/1 core and the fuzzy samples explained one at a time
+ fuzzy_idx = [i for i, label in enumerate(labels) if label not in (0, 1)]
+ core_idx = [i for i, label in enumerate(labels) if label in (0, 1)]
+ core_labels = [labels[i] for i in core_idx]
+ # A single fuzzy protein shares the full sample set, so the two blended fits already cover
+ # every row (no separate baseline needed) -> exactly two fits per round and selection.
+ single_fuzzy = len(fuzzy_idx) == 1
+ acc_shap_values = np.zeros(shape=(n_samples, n_features))
+ list_expected_value = []
+ if verbose:
+ ut.print_start_progress(start_message=f"ShapModel starts interpolation estimation of SHAP values over {n_rounds} rounds.")
+ for i in range(n_rounds):
+ _list_model_kwargs = _seed_model_kwargs(list_model_kwargs, random_state=random_state, round_idx=i)
+ for j, selected_features in enumerate(is_selected):
+ if verbose:
+ pct_progress = j / len(is_selected)
+ add_new_line = explainer_class in LIST_VERBOSE_shap_modelS
+ ut.print_progress(i=i+pct_progress, n_total=n_rounds, add_new_line=add_new_line)
+ X_selected = X[:, selected_features]
+ args = dict(list_model_classes=list_model_classes, list_model_kwargs=_list_model_kwargs,
+ explainer_class=explainer_class, explainer_kwargs=explainer_kwargs,
+ class_index=class_index, n_background_data=n_background_data)
+ if single_fuzzy:
+ f = fuzzy_idx[0]
+ p = labels[f]
+ labels_0 = [0 if k == f else labels[k] for k in range(n_samples)]
+ labels_1 = [1 if k == f else labels[k] for k in range(n_samples)]
+ shap_0, exp_0 = _aggregate_shap_values(X_selected, labels=labels_0, **args)
+ shap_1, exp_1 = _aggregate_shap_values(X_selected, labels=labels_1, **args)
+ cell = p * shap_1 + (1 - p) * shap_0
+ list_expected_value.append(p * exp_1 + (1 - p) * exp_0)
+ else:
+ cell = np.zeros(shape=(n_samples, X_selected.shape[1]))
+ # Non-fuzzy core rows come from a single baseline fit on the core
+ shap_core, exp_core = _aggregate_shap_values(X_selected[core_idx], labels=core_labels, **args)
+ cell[core_idx] = shap_core
+ list_expected_value.append(exp_core)
+ # Each fuzzy protein is explained against core + itself (others excluded)
+ for f in fuzzy_idx:
+ p = labels[f]
+ sub_idx = core_idx + [f]
+ X_sub = X_selected[sub_idx]
+ shap_0, exp_0 = _aggregate_shap_values(X_sub, labels=core_labels + [0], **args)
+ shap_1, exp_1 = _aggregate_shap_values(X_sub, labels=core_labels + [1], **args)
+ cell[f] = p * shap_1[-1] + (1 - p) * shap_0[-1]
+ list_expected_value.append(p * exp_1 + (1 - p) * exp_0)
+ full_cell = np.zeros(shape=(n_samples, n_features))
+ full_cell[:, selected_features] = cell
+ acc_shap_values += full_cell
+ if verbose:
+ ut.print_end_progress(end_message=f"ShapModel finished interpolation estimation and saved results.")
+ shap_values = acc_shap_values / n_cells
+ exp_val = np.mean(list_expected_value)
+ return shap_values, exp_val
diff --git a/aaanalysis/explainable_ai_pro/_shap_model.py b/aaanalysis/explainable_ai_pro/_shap_model.py
index d7cc7985..591ea679 100644
--- a/aaanalysis/explainable_ai_pro/_shap_model.py
+++ b/aaanalysis/explainable_ai_pro/_shap_model.py
@@ -13,7 +13,8 @@
from aaanalysis.template_classes import Wrapper
from ._backend.check_models import (check_match_labels_X,
check_match_X_is_selected)
-from ._backend.shap_model.shap_model_fit import monte_carlo_shap_estimation
+from ._backend.shap_model.shap_model_fit import (monte_carlo_shap_estimation,
+ interpolate_fuzzy_shap_estimation)
from ._backend.shap_model.sm_add_feat_impact import (comp_shap_feature_importance,
insert_shap_feature_importance,
comp_shap_feature_impact,
@@ -379,7 +380,9 @@ def __init__(self,
If ``True``, verbose outputs are enabled.
random_state : int, optional
The seed used by the random number generator. If a positive integer, results of stochastic processes are
- consistent, enabling reproducibility. If ``None``, stochastic processes will be truly random.
+ consistent, enabling reproducibility. If ``None``, stochastic processes will be truly random. For
+ ``fuzzy_aggregation='interpolate'`` it is the initial seed and each round re-seeds with
+ ``random_state + round`` (see :meth:`ShapModel.fit` Notes).
Notes
-----
@@ -472,6 +475,7 @@ def fit(self,
n_rounds: int = 5,
is_selected: Optional[ut.ArrayLike2D] = None,
fuzzy_labeling: bool = False,
+ fuzzy_aggregation: str = "interpolate",
n_background_data: Optional[int] = None,
df_seq: Optional[pd.DataFrame] = None,
fuzzy_labels: Optional[dict] = None,
@@ -499,11 +503,22 @@ def fit(self,
For binary classification, '0' represents the negative class and '1' the positive class.
n_rounds : int, default=5
The number of rounds (>=1) to fit the models and obtain the SHAP values by explainer.
+ For ``fuzzy_aggregation='interpolate'`` each round re-seeds the fit, so ``n_rounds`` is a
+ speed/stability dial (see Notes): ``1`` is the fast exact two-fit estimate, the default
+ ``5`` adds Monte-Carlo averaging, and a stable mean is reached around ``15-20``.
is_selected : array-like, shape (n_selection_round, n_features)
2D boolean arrays indicating different feature selections.
fuzzy_labeling : bool, default=False
If ``True``, fuzzy labeling is applied to approximate SHAP values for samples with uncertain/partial
memberships (e.g., between >0 and <1 for binary classification scenarios).
+ fuzzy_aggregation : str, default='interpolate'
+ Strategy to turn a soft label ``p`` into a SHAP estimate when fuzzy labeling is active (see Notes):
+
+ - ``'interpolate'`` (default, new in 1.1): blend ``p * S1 + (1 - p) * S0`` from a fit at
+ 0 and at 1 (unbiased, exact ``p``; with ``n_rounds=1`` only two fits per fuzzy sample).
+ - ``'threshold'``: hard-label the fuzzy sample over a threshold grid and average — the
+ biased sweep of [Breimann25]_; kept for backward-compatible results.
+
n_background_data : None or int, optional
The number samples (< 'n_samples') in the background dataset used for the `KernelExplainer`` to reduce
computation time. The dataset is obtained by k-means clustering. If ``None``, the full dataset 'X' is used.
@@ -530,6 +545,39 @@ def fit(self,
* Idea: Adjusts label thresholds dynamically in Monte Carlo estimation to better represent label uncertainties.
* Background: Inspired by fuzzy logic, replacing binary true/false with degrees of truth.
+ **Fuzzy aggregation strategies**
+
+ The ``fuzzy_aggregation`` parameter selects between two estimators:
+
+ * ``'interpolate'`` (default): The fuzzy sample is weighted by *exactly* ``p`` by fitting the model twice (fuzzy
+ sample at 0 -> ``S0``, at 1 -> ``S1``) and blending ``p * S1 + (1 - p) * S0`` (the ``exp_value`` is blended the
+ same way). This is *unbiased*. Each fuzzy protein is explained independently against the fixed balanced 0/1
+ core, with the other fuzzy proteins excluded from its training data.
+ * ``'threshold'``: Over an ``n_rounds`` x ``n_selection`` grid the fuzzy sample is hard-labeled ``1`` when a
+ per-cell threshold ``<= p`` and the per-cell SHAP matrices are averaged — the sweep of [Breimann25]_. Because
+ the grid is non-uniform on (0, 1], the effective positive-fraction is a *biased* approximation of ``p``.
+
+ **Per-round seeding (interpolate only)**
+
+ The constructor ``random_state`` is the initial seed, and ``'interpolate'`` re-seeds **each round** with
+ ``random_state + round`` (round 0 -> ``random_state``, round 1 -> ``random_state + 1``, ...). So every round
+ fits a *different* model and ``n_rounds`` averages a Monte-Carlo mean over model variance, yet a fixed
+ ``random_state`` gives the identical seed sequence and therefore an exactly reproducible result;
+ ``random_state=None`` draws fresh entropy each round (truly-random, non-reproducible). The ``'threshold'``
+ estimator and the non-fuzzy Monte-Carlo path do **not** re-seed per round — they bake ``random_state`` in once,
+ so their per-round variation comes from the threshold grid, not from the model seed.
+
+ **Choosing n_rounds for 'interpolate'**
+
+ Because each round re-seeds, ``n_rounds`` is a speed/stability dial:
+
+ * ``n_rounds=1`` -- the exact two-fit point estimate; fastest, but a single model draw (run-to-run spread ~20%
+ across seeds).
+ * ``n_rounds=5`` (default) -- adds light averaging (spread ~10%).
+ * ``n_rounds≈15-20`` -- the averaged estimate stabilizes (run-to-run spread and distance to the converged mean
+ fall below ~5% on the bundled ``DOM_GSEC`` gamma-secretase data, ~1/sqrt(n_rounds) decay). Use this for a
+ stable mean; with a fixed ``random_state`` any single run is exactly reproducible regardless.
+
**Setting soft labels**
There are two equivalent ways to provide soft labels, both enabling fuzzy labeling:
@@ -554,6 +602,8 @@ def fit(self,
n_samples, n_feat = X.shape
ut.check_X_unique_samples(X=X, min_n_unique_samples=2)
ut.check_bool(name="fuzzy_labeling", val=fuzzy_labeling)
+ ut.check_str_options(name="fuzzy_aggregation", val=fuzzy_aggregation,
+ list_str_options=["threshold", "interpolate"])
if fuzzy_labels is not None:
# Entry-keyed soft labels override 'labels' and enable fuzzy labeling
check_match_df_seq_X(df_seq=df_seq, X=X)
@@ -571,17 +621,25 @@ def fit(self,
ut.check_number_range(name="n_background_data", val=n_background_data, min_val=1, just_int=True, accept_none=True)
check_match_n_background_data_X(n_background_data=n_background_data, X=X)
# Compute SHAP values
- shap_values, exp_val = monte_carlo_shap_estimation(X, labels=labels,
- list_model_classes=self._list_model_classes,
- list_model_kwargs=self._list_model_kwargs,
- explainer_class=self._explainer_class,
- explainer_kwargs=self._explainer_kwargs,
- is_selected=is_selected,
- fuzzy_labeling=fuzzy_labeling,
- n_rounds=n_rounds,
- verbose=self._verbose,
- label_target_class=label_target_class,
- n_background_data=n_background_data)
+ backend_args = dict(list_model_classes=self._list_model_classes,
+ list_model_kwargs=self._list_model_kwargs,
+ explainer_class=self._explainer_class,
+ explainer_kwargs=self._explainer_kwargs,
+ is_selected=is_selected,
+ n_rounds=n_rounds,
+ verbose=self._verbose,
+ label_target_class=label_target_class,
+ n_background_data=n_background_data)
+ if fuzzy_labeling and fuzzy_aggregation == "interpolate":
+ # Only the interpolate path threads 'random_state' explicitly: it re-seeds per round
+ # (random_state + round). The threshold path keeps the seed baked into the model kwargs.
+ shap_values, exp_val = interpolate_fuzzy_shap_estimation(X, labels=labels,
+ random_state=self._random_state,
+ **backend_args)
+ else:
+ shap_values, exp_val = monte_carlo_shap_estimation(X, labels=labels,
+ fuzzy_labeling=fuzzy_labeling,
+ **backend_args)
self.shap_values = shap_values
self.exp_value = exp_val
return self
diff --git a/docs/source/index/release_notes.rst b/docs/source/index/release_notes.rst
index 8fafd242..89fac2d1 100644
--- a/docs/source/index/release_notes.rst
+++ b/docs/source/index/release_notes.rst
@@ -85,6 +85,15 @@ Added
``add_feat_impact`` / ``add_sample_mean_dif`` accept ``df_seq`` and a ``samples``
parameter taking row positions or entry names. The array-``labels`` path is unchanged;
``sample_positions`` is a deprecated alias for ``samples`` (removed in 1.2.0).
+- **ShapModel — unbiased fuzzy estimator, now the default** (``[pro]``): ``fit`` gains
+ ``fuzzy_aggregation``, defaulting to the new ``'interpolate'`` estimator. It weights a
+ soft label by *exactly* ``p`` — fitting at 0 (``S0``) and at 1 (``S1``) and blending
+ ``p * S1 + (1 - p) * S0`` — the unbiased alternative to the biased threshold sweep, which
+ stays available as a first-class option via ``fuzzy_aggregation='threshold'``. For
+ ``interpolate``, ``n_rounds`` (default ``5``) is a speed/stability dial: ``1`` is the fast
+ exact two-fit estimate (~2x faster than the threshold default on the same cell), ``5`` adds
+ light Monte-Carlo averaging, and the mean converges (run-to-run spread below ~5%) around
+ ``n_rounds ≈ 15–20``; a fixed ``random_state`` keeps every run reproducible.
**Sequence Analysis**
diff --git a/examples/explainable_ai/sm_fit.ipynb b/examples/explainable_ai/sm_fit.ipynb
index 25849ec8..9ead2c67 100644
--- a/examples/explainable_ai/sm_fit.ipynb
+++ b/examples/explainable_ai/sm_fit.ipynb
@@ -21,10 +21,10 @@
},
"collapsed": false,
"execution": {
- "iopub.execute_input": "2026-06-13T19:51:33.503612Z",
- "iopub.status.busy": "2026-06-13T19:51:33.503358Z",
- "iopub.status.idle": "2026-06-13T19:51:35.343709Z",
- "shell.execute_reply": "2026-06-13T19:51:35.343497Z"
+ "iopub.execute_input": "2026-06-25T14:19:46.705083Z",
+ "iopub.status.busy": "2026-06-25T14:19:46.704759Z",
+ "iopub.status.idle": "2026-06-25T14:19:48.092499Z",
+ "shell.execute_reply": "2026-06-25T14:19:48.092260Z"
}
},
"outputs": [
@@ -32,7 +32,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "/Users/stephanbreimann/Programming/1Packages/aaanalysis-shap-acc/aaanalysis/feature_engineering/_backend/cpp_run.py:143: UserWarning: CPP is using the Python kernel fallback — the compiled Cython extension is not available in this install. Output is bit-exact with the Cython path but ~2x slower. Reinstall via `pip install --force-reinstall aaanalysis` to fetch a prebuilt wheel.\n",
+ "/Users/stephanbreimann/Programming/1Packages/wt-shap-interpolate/aaanalysis/feature_engineering/_backend/cpp_run.py:163: UserWarning: CPP is using the Python kernel fallback — the compiled Cython extension is not available in this install. Output is bit-exact with the Cython path but ~2x slower. Reinstall via `pip install --force-reinstall aaanalysis` to fetch a prebuilt wheel.\n",
" warnings.warn(\n"
]
},
@@ -40,105 +40,105 @@
"data": {
"text/html": [
"\n",
- "
\n",
+ "\n",
" \n",
" \n",
" | | \n",
- " entry | \n",
- " sequence | \n",
- " label | \n",
- " tmd_start | \n",
- " tmd_stop | \n",
- " jmd_n | \n",
- " tmd | \n",
- " jmd_c | \n",
+ " entry | \n",
+ " sequence | \n",
+ " label | \n",
+ " tmd_start | \n",
+ " tmd_stop | \n",
+ " jmd_n | \n",
+ " tmd | \n",
+ " jmd_c | \n",
"
\n",
" \n",
" \n",
" \n",
- " | 1 | \n",
- " Q14802 | \n",
- " MQKVTLGLLVFLAGF...PGETPPLITPGSAQS | \n",
- " 0 | \n",
- " 37 | \n",
- " 59 | \n",
- " NSPFYYDWHS | \n",
- " LQVGGLICAGVLCAMGIIIVMSA | \n",
- " KCKCKFGQKS | \n",
+ " 1 | \n",
+ " Q14802 | \n",
+ " MQKVTLGLLVFLAGF...PGETPPLITPGSAQS | \n",
+ " 0 | \n",
+ " 37 | \n",
+ " 59 | \n",
+ " NSPFYYDWHS | \n",
+ " LQVGGLICAGVLCAMGIIIVMSA | \n",
+ " KCKCKFGQKS | \n",
"
\n",
" \n",
- " | 2 | \n",
- " Q86UE4 | \n",
- " MAARSWQDELAQQAE...SPKQIKKKKKARRET | \n",
- " 0 | \n",
- " 50 | \n",
- " 72 | \n",
- " LGLEPKRYPG | \n",
- " WVILVGTGALGLLLLFLLGYGWA | \n",
- " AACAGARKKR | \n",
+ " 2 | \n",
+ " Q86UE4 | \n",
+ " MAARSWQDELAQQAE...SPKQIKKKKKARRET | \n",
+ " 0 | \n",
+ " 50 | \n",
+ " 72 | \n",
+ " LGLEPKRYPG | \n",
+ " WVILVGTGALGLLLLFLLGYGWA | \n",
+ " AACAGARKKR | \n",
"
\n",
" \n",
- " | 3 | \n",
- " Q969W9 | \n",
- " MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL | \n",
- " 0 | \n",
- " 41 | \n",
- " 63 | \n",
- " FQSMEITELE | \n",
- " FVQIIIIVVVMMVMVVVITCLLS | \n",
- " HYKLSARSFI | \n",
+ " 3 | \n",
+ " Q969W9 | \n",
+ " MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL | \n",
+ " 0 | \n",
+ " 41 | \n",
+ " 63 | \n",
+ " FQSMEITELE | \n",
+ " FVQIIIIVVVMMVMVVVITCLLS | \n",
+ " HYKLSARSFI | \n",
"
\n",
" \n",
- " | 4 | \n",
- " P05067 | \n",
- " MLPGLALLLLAAWTA...GYENPTYKFFEQMQN | \n",
- " 1 | \n",
- " 701 | \n",
- " 723 | \n",
- " FAEDVGSNKG | \n",
- " AIIGLMVGGVVIATVIVITLVML | \n",
- " KKKQYTSIHH | \n",
+ " 4 | \n",
+ " P05067 | \n",
+ " MLPGLALLLLAAWTA...GYENPTYKFFEQMQN | \n",
+ " 1 | \n",
+ " 701 | \n",
+ " 723 | \n",
+ " FAEDVGSNKG | \n",
+ " AIIGLMVGGVVIATVIVITLVML | \n",
+ " KKKQYTSIHH | \n",
"
\n",
" \n",
- " | 5 | \n",
- " P14925 | \n",
- " MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS | \n",
- " 1 | \n",
- " 868 | \n",
- " 890 | \n",
- " KLSTEPGSGV | \n",
- " SVVLITTLLVIPVLVLLAIVMFI | \n",
- " RWKKSRAFGD | \n",
+ " 5 | \n",
+ " P14925 | \n",
+ " MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS | \n",
+ " 1 | \n",
+ " 868 | \n",
+ " 890 | \n",
+ " KLSTEPGSGV | \n",
+ " SVVLITTLLVIPVLVLLAIVMFI | \n",
+ " RWKKSRAFGD | \n",
"
\n",
" \n",
- " | 6 | \n",
- " P70180 | \n",
- " MRSLLLFTFSACVLL...RELREDSIRSHFSVA | \n",
- " 1 | \n",
- " 477 | \n",
- " 499 | \n",
- " PCKSSGGLEE | \n",
- " SAVTGIVVGALLGAGLLMAFYFF | \n",
- " RKKYRITIER | \n",
+ " 6 | \n",
+ " P70180 | \n",
+ " MRSLLLFTFSACVLL...RELREDSIRSHFSVA | \n",
+ " 1 | \n",
+ " 477 | \n",
+ " 499 | \n",
+ " PCKSSGGLEE | \n",
+ " SAVTGIVVGALLGAGLLMAFYFF | \n",
+ " RKKYRITIER | \n",
"
\n",
" \n",
"
\n"
@@ -188,10 +188,10 @@
},
"collapsed": false,
"execution": {
- "iopub.execute_input": "2026-06-13T19:51:35.344949Z",
- "iopub.status.busy": "2026-06-13T19:51:35.344873Z",
- "iopub.status.idle": "2026-06-13T19:51:35.655752Z",
- "shell.execute_reply": "2026-06-13T19:51:35.655480Z"
+ "iopub.execute_input": "2026-06-25T14:19:48.093620Z",
+ "iopub.status.busy": "2026-06-25T14:19:48.093555Z",
+ "iopub.status.idle": "2026-06-25T14:19:48.447226Z",
+ "shell.execute_reply": "2026-06-25T14:19:48.446995Z"
}
},
"outputs": [
@@ -200,16 +200,16 @@
"output_type": "stream",
"text": [
"SHAP values explain the feature impact for 3 negative and 3 positive samples\n",
- "[[-0.1 -0.1 -0.08 -0.09 -0.06]\n",
- " [-0.12 -0.12 -0.09 -0.1 -0.07]\n",
- " [-0.13 -0.14 -0.04 -0.09 -0.01]\n",
- " [ 0.13 0.13 0.05 0.09 0.04]\n",
- " [ 0.13 0.12 0.08 0.09 0.07]\n",
- " [ 0.13 0.12 0.08 0.09 0.06]]\n",
+ "[[-0.1 -0.11 -0.1 -0.09 -0.06]\n",
+ " [-0.12 -0.12 -0.09 -0.09 -0.07]\n",
+ " [-0.14 -0.14 -0.04 -0.09 -0.01]\n",
+ " [ 0.12 0.13 0.06 0.1 0.03]\n",
+ " [ 0.12 0.12 0.09 0.1 0.07]\n",
+ " [ 0.12 0.13 0.09 0.1 0.06]]\n",
"\n",
"The expected value approximates the expected model output (average prediction score).\n",
"For a binary classification with balanced datasets, it is around 0.5:\n",
- "0.49566666666666687\n"
+ "0.4981666666666669\n"
]
}
],
@@ -250,10 +250,10 @@
},
"collapsed": false,
"execution": {
- "iopub.execute_input": "2026-06-13T19:51:35.656884Z",
- "iopub.status.busy": "2026-06-13T19:51:35.656813Z",
- "iopub.status.idle": "2026-06-13T19:51:35.967928Z",
- "shell.execute_reply": "2026-06-13T19:51:35.967703Z"
+ "iopub.execute_input": "2026-06-25T14:19:48.448312Z",
+ "iopub.status.busy": "2026-06-25T14:19:48.448242Z",
+ "iopub.status.idle": "2026-06-25T14:19:48.801609Z",
+ "shell.execute_reply": "2026-06-25T14:19:48.801379Z"
}
},
"outputs": [
@@ -262,15 +262,15 @@
"output_type": "stream",
"text": [
"Reverse sign of SHAP values by changing reference class from 1 to 0\n",
- "[[ 0.11 0.1 0.1 0.09 0.06]\n",
- " [ 0.13 0.12 0.1 0.08 0.07]\n",
- " [ 0.15 0.14 0.04 0.08 0.01]\n",
- " [-0.14 -0.13 -0.07 -0.08 -0.04]\n",
- " [-0.13 -0.12 -0.09 -0.09 -0.07]\n",
- " [-0.13 -0.12 -0.09 -0.08 -0.06]]\n",
+ "[[ 0.1 0.09 0.09 0.11 0.07]\n",
+ " [ 0.12 0.11 0.09 0.1 0.08]\n",
+ " [ 0.14 0.14 0.04 0.1 0.02]\n",
+ " [-0.13 -0.12 -0.05 -0.09 -0.04]\n",
+ " [-0.12 -0.11 -0.08 -0.09 -0.07]\n",
+ " [-0.12 -0.12 -0.08 -0.09 -0.07]]\n",
"\n",
"Base value stays around 0.5:\n",
- "0.5036666666666669\n"
+ "0.49633333333333357\n"
]
}
],
@@ -309,10 +309,10 @@
},
"collapsed": false,
"execution": {
- "iopub.execute_input": "2026-06-13T19:51:35.969048Z",
- "iopub.status.busy": "2026-06-13T19:51:35.968988Z",
- "iopub.status.idle": "2026-06-13T19:51:36.534019Z",
- "shell.execute_reply": "2026-06-13T19:51:36.533756Z"
+ "iopub.execute_input": "2026-06-25T14:19:48.802738Z",
+ "iopub.status.busy": "2026-06-25T14:19:48.802667Z",
+ "iopub.status.idle": "2026-06-25T14:19:49.441570Z",
+ "shell.execute_reply": "2026-06-25T14:19:49.441305Z"
}
},
"outputs": [],
@@ -342,10 +342,10 @@
},
"collapsed": false,
"execution": {
- "iopub.execute_input": "2026-06-13T19:51:36.535531Z",
- "iopub.status.busy": "2026-06-13T19:51:36.535441Z",
- "iopub.status.idle": "2026-06-13T19:51:37.101052Z",
- "shell.execute_reply": "2026-06-13T19:51:37.100823Z"
+ "iopub.execute_input": "2026-06-25T14:19:49.442802Z",
+ "iopub.status.busy": "2026-06-25T14:19:49.442739Z",
+ "iopub.status.idle": "2026-06-25T14:19:50.085482Z",
+ "shell.execute_reply": "2026-06-25T14:19:50.085260Z"
}
},
"outputs": [
@@ -354,12 +354,12 @@
"output_type": "stream",
"text": [
"Impact of feature pre-selection\n",
- "[[-0.16 -0.19 -0.05 -0.05 0. ]\n",
- " [-0.18 -0.21 -0.05 -0.05 0. ]\n",
- " [-0.18 -0.21 -0.02 -0.05 0. ]\n",
- " [ 0.18 0.21 0.03 0.05 0. ]\n",
- " [ 0.18 0.2 0.05 0.05 0. ]\n",
- " [ 0.18 0.21 0.05 0.05 0. ]]\n"
+ "[[-0.18 -0.18 -0.05 -0.06 0. ]\n",
+ " [-0.2 -0.19 -0.05 -0.06 0. ]\n",
+ " [-0.2 -0.2 -0.02 -0.05 0. ]\n",
+ " [ 0.2 0.19 0.03 0.05 0. ]\n",
+ " [ 0.2 0.19 0.04 0.06 0. ]\n",
+ " [ 0.2 0.19 0.04 0.06 0. ]]\n"
]
}
],
@@ -395,10 +395,10 @@
},
"collapsed": false,
"execution": {
- "iopub.execute_input": "2026-06-13T19:51:37.102222Z",
- "iopub.status.busy": "2026-06-13T19:51:37.102160Z",
- "iopub.status.idle": "2026-06-13T19:51:37.672989Z",
- "shell.execute_reply": "2026-06-13T19:51:37.672757Z"
+ "iopub.execute_input": "2026-06-25T14:19:50.086683Z",
+ "iopub.status.busy": "2026-06-25T14:19:50.086627Z",
+ "iopub.status.idle": "2026-06-25T14:19:51.315443Z",
+ "shell.execute_reply": "2026-06-25T14:19:51.315230Z"
}
},
"outputs": [
@@ -407,12 +407,12 @@
"output_type": "stream",
"text": [
"First sample is labeled as 0.5 between negative (0) and positive (1)\n",
- "[[ 0.04 0.04 -0.03 -0.04 0. ]\n",
- " [-0.24 -0.25 -0.05 -0.05 0. ]\n",
+ "[[-0.04 -0.04 -0.02 -0.03 0. ]\n",
+ " [-0.23 -0.23 -0.04 -0.05 0. ]\n",
" [-0.22 -0.22 -0.02 -0.04 0. ]\n",
- " [ 0.15 0.16 0.03 0.04 0. ]\n",
- " [ 0.14 0.16 0.04 0.04 0. ]\n",
- " [ 0.14 0.16 0.04 0.04 0. ]]\n"
+ " [ 0.17 0.18 0.02 0.04 0. ]\n",
+ " [ 0.17 0.17 0.03 0.04 0. ]\n",
+ " [ 0.17 0.17 0.03 0.04 0. ]]\n"
]
}
],
@@ -426,6 +426,100 @@
"print(sm.shap_values.round(2))"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "445a9b21",
+ "metadata": {},
+ "source": [
+ "By default ``fuzzy_aggregation='interpolate'`` weights the fuzzy sample by *exactly* ``p`` (the cell above used it: two fits blended as ``p*S1 + (1-p)*S0``). The published threshold sweep stays available via ``fuzzy_aggregation='threshold'``. For interpolate, ``n_rounds`` is a speed/stability dial: ``n_rounds=1`` is the fast exact estimate, the default ``5`` adds light Monte-Carlo averaging, and a stable mean is reached around ``n_rounds ~ 15-20``:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "7312979a",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-06-25T14:19:51.316608Z",
+ "iopub.status.busy": "2026-06-25T14:19:51.316545Z",
+ "iopub.status.idle": "2026-06-25T14:19:55.550029Z",
+ "shell.execute_reply": "2026-06-25T14:19:55.549803Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Threshold-sweep estimate (first, fuzzy sample):\n",
+ "[ 0.03 0.04 -0.03 -0.03 0. ]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Converged interpolate mean over 15 rounds (first, fuzzy sample):\n",
+ "[-0.03 -0.04 -0.02 -0.02 0. ]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The published threshold-sweep estimator, available via fuzzy_aggregation=\"threshold\"\n",
+ "sm = aa.ShapModel(random_state=42)\n",
+ "sm = sm.fit(X, labels=labels, is_selected=is_selected,\n",
+ " fuzzy_labeling=True, fuzzy_aggregation=\"threshold\")\n",
+ "print(\"Threshold-sweep estimate (first, fuzzy sample):\")\n",
+ "print(sm.shap_values[0].round(2))\n",
+ "\n",
+ "# Stable interpolate mean: n_rounds=1 is the fast exact blend, ~15-20 the converged mean\n",
+ "sm = aa.ShapModel(random_state=42)\n",
+ "sm = sm.fit(X, labels=labels, is_selected=is_selected,\n",
+ " fuzzy_labeling=True, n_rounds=15)\n",
+ "print(\"\\nConverged interpolate mean over 15 rounds (first, fuzzy sample):\")\n",
+ "print(sm.shap_values[0].round(2))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "94a8f9c9",
+ "metadata": {},
+ "source": [
+ "The constructor ``random_state`` seeds the estimator. For ``'interpolate'`` it is the *initial* seed and each round re-seeds with ``random_state + round``, so ``n_rounds>1`` averages genuinely different model fits yet stays exactly reproducible for a fixed seed (``random_state=None`` instead draws fresh entropy each round). The ``'threshold'`` and non-fuzzy paths do not re-seed per round:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "8556e656",
+ "metadata": {
+ "execution": {
+ "iopub.execute_input": "2026-06-25T14:19:55.551133Z",
+ "iopub.status.busy": "2026-06-25T14:19:55.551072Z",
+ "iopub.status.idle": "2026-06-25T14:19:58.018771Z",
+ "shell.execute_reply": "2026-06-25T14:19:58.018569Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reproducible for a fixed random_state (per-round seed = random_state + round): True\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Same random_state -> identical result (reproducible), even with averaging over rounds\n",
+ "a = aa.ShapModel(random_state=42).fit(X, labels=labels, is_selected=is_selected,\n",
+ " fuzzy_labeling=True, n_rounds=5).shap_values\n",
+ "b = aa.ShapModel(random_state=42).fit(X, labels=labels, is_selected=is_selected,\n",
+ " fuzzy_labeling=True, n_rounds=5).shap_values\n",
+ "print(\"Reproducible for a fixed random_state (per-round seed = random_state + round):\", \n",
+ " bool((a == b).all()))"
+ ]
+ },
{
"cell_type": "markdown",
"id": "9d3e32ab",
@@ -436,14 +530,14 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"id": "0b3fd3f3",
"metadata": {
"execution": {
- "iopub.execute_input": "2026-06-13T19:51:37.674230Z",
- "iopub.status.busy": "2026-06-13T19:51:37.674159Z",
- "iopub.status.idle": "2026-06-13T19:51:38.241890Z",
- "shell.execute_reply": "2026-06-13T19:51:38.241638Z"
+ "iopub.execute_input": "2026-06-25T14:19:58.019876Z",
+ "iopub.status.busy": "2026-06-25T14:19:58.019812Z",
+ "iopub.status.idle": "2026-06-25T14:19:59.257226Z",
+ "shell.execute_reply": "2026-06-25T14:19:59.257008Z"
}
},
"outputs": [
@@ -452,12 +546,12 @@
"output_type": "stream",
"text": [
"Sample 'Q14802' is labeled as 0.5 between negative (0) and positive (1)\n",
- "[[ 0.03 0.04 -0.02 -0.04 0. ]\n",
- " [-0.25 -0.24 -0.04 -0.05 0. ]\n",
- " [-0.23 -0.22 -0.01 -0.04 0. ]\n",
- " [ 0.15 0.15 0.02 0.04 0. ]\n",
- " [ 0.15 0.15 0.04 0.05 0. ]\n",
- " [ 0.15 0.15 0.04 0.05 0. ]]\n"
+ "[[-0.04 -0.03 -0.02 -0.02 0. ]\n",
+ " [-0.23 -0.23 -0.04 -0.05 0. ]\n",
+ " [-0.23 -0.22 -0.02 -0.04 0. ]\n",
+ " [ 0.17 0.17 0.02 0.03 0. ]\n",
+ " [ 0.17 0.17 0.03 0.04 0. ]\n",
+ " [ 0.17 0.17 0.03 0.04 0. ]]\n"
]
}
],
@@ -484,7 +578,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 10,
"id": "45c1efa0c82d7255",
"metadata": {
"ExecuteTime": {
@@ -493,10 +587,10 @@
},
"collapsed": false,
"execution": {
- "iopub.execute_input": "2026-06-13T19:51:38.243137Z",
- "iopub.status.busy": "2026-06-13T19:51:38.243065Z",
- "iopub.status.idle": "2026-06-13T19:51:38.245397Z",
- "shell.execute_reply": "2026-06-13T19:51:38.245176Z"
+ "iopub.execute_input": "2026-06-25T14:19:59.258321Z",
+ "iopub.status.busy": "2026-06-25T14:19:59.258258Z",
+ "iopub.status.idle": "2026-06-25T14:19:59.260533Z",
+ "shell.execute_reply": "2026-06-25T14:19:59.260367Z"
}
},
"outputs": [],
@@ -524,7 +618,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.14.0"
+ "version": "3.13.11"
}
},
"nbformat": 4,
diff --git a/tests/unit/shap_model_tests/test_sm_branch.py b/tests/unit/shap_model_tests/test_sm_branch.py
index 88a8055b..f8cba032 100644
--- a/tests/unit/shap_model_tests/test_sm_branch.py
+++ b/tests/unit/shap_model_tests/test_sm_branch.py
@@ -103,7 +103,7 @@ def test_fuzzy_valid_threshold_path(self):
X, _ = _small_data(n_samples=12)
labels = np.array([1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0.5], dtype=float)
sm = aa.ShapModel(**MODEL_KWARGS, verbose=False)
- sm.fit(X, labels=labels, fuzzy_labeling=True, **ARGS)
+ sm.fit(X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation="threshold", **ARGS)
assert sm.shap_values.shape == X.shape
def test_fuzzy_label_out_of_range_raises(self):
diff --git a/tests/unit/shap_model_tests/test_sm_fit.py b/tests/unit/shap_model_tests/test_sm_fit.py
index 1aa18c34..dd72c1e0 100644
--- a/tests/unit/shap_model_tests/test_sm_fit.py
+++ b/tests/unit/shap_model_tests/test_sm_fit.py
@@ -393,3 +393,168 @@ def test_fuzzy_labels_df_seq_missing_entry_column(self):
sm = aa.ShapModel(**MODEL_KWARGS, verbose=False)
with pytest.raises(ValueError):
sm.fit(valid_X, labels=valid_labels, df_seq=df_bad, fuzzy_labels={entry: 0.6}, **ARGS)
+
+
+# Small deterministic fixture for the exact-p / fit-count assertions (single model -> exact fit counts)
+_RNG = np.random.default_rng(0)
+SMALL_X = _RNG.normal(size=(7, 5))
+SMALL_LABELS_1FUZZY = [1, 1, 1, 0, 0, 0, 0.6] # one fuzzy protein (last row), balanced 3/3 core
+SMALL_X_2FUZZY = _RNG.normal(size=(8, 5))
+SMALL_LABELS_2FUZZY = [1, 1, 1, 0, 0, 0, 0.6, 0.3] # two fuzzy proteins, balanced 3/3 core
+ONE_MODEL = dict(list_model_classes=[RandomForestClassifier])
+
+
+class TestShapModelFitFuzzyAggregation:
+ """The ``fuzzy_aggregation`` estimator: 'threshold' (default) vs unbiased 'interpolate'."""
+
+ # Positive: both options are accepted and produce SHAP values
+ def test_fuzzy_aggregation_options_valid(self):
+ for fuzzy_aggregation in ["threshold", "interpolate"]:
+ sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=0)
+ sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True,
+ fuzzy_aggregation=fuzzy_aggregation, n_rounds=1)
+ assert sm.shap_values.shape == SMALL_X.shape
+ assert sm.exp_value is not None
+
+ # Negative: an unknown value is rejected
+ def test_invalid_fuzzy_aggregation(self):
+ sm = aa.ShapModel(**ONE_MODEL, verbose=False)
+ with pytest.raises(ValueError):
+ sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, fuzzy_aggregation="bogus")
+
+ # KPI: unbiased exact-p weighting (single fuzzy, n_rounds=1) == p*S1 + (1-p)*S0
+ def test_interpolate_exact_p_blend(self):
+ from aaanalysis.explainable_ai_pro._backend.shap_model import shap_model_fit as B
+ p = SMALL_LABELS_1FUZZY[-1]
+ sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=42)
+ sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True,
+ fuzzy_aggregation="interpolate", n_rounds=1)
+ model_kwargs = dict(sm._list_model_kwargs[0])
+ model_kwargs["random_state"] = 42 # round 0 -> random_state + 0
+ args = dict(list_model_classes=[RandomForestClassifier], list_model_kwargs=[model_kwargs],
+ explainer_class=sm._explainer_class, explainer_kwargs=sm._explainer_kwargs,
+ class_index=1, n_background_data=None)
+ labels_0 = [1, 1, 1, 0, 0, 0, 0]
+ labels_1 = [1, 1, 1, 0, 0, 0, 1]
+ shap_0, _ = B._aggregate_shap_values(SMALL_X, labels=labels_0, **args)
+ shap_1, _ = B._aggregate_shap_values(SMALL_X, labels=labels_1, **args)
+ ref = p * shap_1 + (1 - p) * shap_0
+ assert np.allclose(sm.shap_values, ref, atol=1e-10, rtol=0)
+
+ # KPI: 2-fit fast path — exactly two model fits per fuzzy sample, scaling with n_rounds
+ def test_interpolate_two_fits_single_fuzzy(self):
+ from aaanalysis.explainable_ai_pro._backend.shap_model import shap_model_fit as B
+ orig = B._compute_shap_values
+ counter = {"n": 0}
+
+ def spy(*a, **k):
+ counter["n"] += 1
+ return orig(*a, **k)
+
+ B._compute_shap_values = spy
+ try:
+ sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=42)
+ sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True,
+ fuzzy_aggregation="interpolate", n_rounds=1)
+ assert counter["n"] == 2
+ counter["n"] = 0
+ sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True,
+ fuzzy_aggregation="interpolate", n_rounds=10)
+ assert counter["n"] == 20
+ finally:
+ B._compute_shap_values = orig
+
+ # KPI: n_rounds is meaningful (rounds differ) yet reproducible for a fixed random_state
+ def test_interpolate_reproducible_and_rounds_matter(self):
+ def run(n_rounds):
+ return aa.ShapModel(**ONE_MODEL, verbose=False, random_state=7).fit(
+ SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True,
+ fuzzy_aggregation="interpolate", n_rounds=n_rounds).shap_values
+ # Reproducible across runs at a fixed seed
+ assert np.array_equal(run(3), run(3))
+ # n_rounds genuinely changes the estimate (per-round re-seeding)
+ assert not np.allclose(run(1), run(10))
+
+ # KPI: Monte-Carlo averaging — variance shrinks with n_rounds when random_state=None
+ def test_interpolate_mc_variance_decreases(self):
+ def variance(reps, n_rounds):
+ runs = [aa.ShapModel(**ONE_MODEL, verbose=False, random_state=None).fit(
+ SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True,
+ fuzzy_aggregation="interpolate", n_rounds=n_rounds).shap_values for _ in range(reps)]
+ return np.mean(np.var(np.stack(runs), axis=0))
+ assert variance(6, 12) < variance(6, 1)
+
+ # The per-round average is a bootstrap mean: it converges to a stable value as n_rounds grows,
+ # so a single round (n_rounds=1) sits well off the mean while late rounds barely move it.
+ def test_interpolate_converges_with_n_rounds(self):
+ # n_rounds=R (fixed base seed) == cumulative mean of per-round blends S_0..S_{R-1},
+ # where S_i = a single-round fit at seed base+i. Compute the 25 blends once.
+ R_MAX = 25
+ base = 42
+ blends = np.stack([
+ aa.ShapModel(**ONE_MODEL, verbose=False, random_state=base + i).fit(
+ SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True,
+ fuzzy_aggregation="interpolate", n_rounds=1).shap_values
+ for i in range(R_MAX)])
+ means = np.stack([blends[:R].mean(axis=0).ravel() for R in range(1, R_MAX + 1)]) # M_1..M_25
+ final = means[-1]
+ norm = np.linalg.norm(final)
+ d_final = np.linalg.norm(means - final, axis=1) / norm # distance to the stable mean
+ d_incr = np.linalg.norm(np.diff(means, axis=0), axis=1) / norm # per-round change (len R_MAX-1)
+ # 1) a single round is clearly off the converged mean (this is why averaging exists)
+ assert d_final[0] > 0.05
+ # 2) the average converges: late rounds move it far less than early rounds
+ assert np.mean(d_incr[-5:]) < 0.5 * np.mean(d_incr[:5])
+ # 3) it is stable by the tail: the last few rounds barely change the estimate
+ assert d_final[-5] < 0.1
+
+ # Multi-fuzzy: each fuzzy protein explained independently against the core (baseline + 2 per fuzzy)
+ def test_interpolate_multi_fuzzy(self):
+ from aaanalysis.explainable_ai_pro._backend.shap_model import shap_model_fit as B
+ orig = B._compute_shap_values
+ counter = {"n": 0}
+
+ def spy(*a, **k):
+ counter["n"] += 1
+ return orig(*a, **k)
+
+ B._compute_shap_values = spy
+ try:
+ sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=5)
+ sm.fit(SMALL_X_2FUZZY, labels=SMALL_LABELS_2FUZZY, fuzzy_labeling=True,
+ fuzzy_aggregation="interpolate", n_rounds=1)
+ assert sm.shap_values.shape == SMALL_X_2FUZZY.shape
+ assert counter["n"] == 1 + 2 * 2 # one baseline core fit + two fits per fuzzy protein
+ finally:
+ B._compute_shap_values = orig
+
+ # 'interpolate' is the default estimator; explicit selection matches the default
+ def test_default_is_interpolate(self):
+ default = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit(
+ SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, n_rounds=3).shap_values
+ explicit = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit(
+ SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True,
+ fuzzy_aggregation="interpolate", n_rounds=3).shap_values
+ assert np.array_equal(default, explicit)
+
+ # n_rounds defaults to a plain 5 for every estimator (no per-estimator magic)
+ def test_default_n_rounds_is_five(self):
+ auto = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit(
+ SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True).shap_values
+ explicit5 = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit(
+ SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, n_rounds=5).shap_values
+ assert np.array_equal(auto, explicit5)
+ # and the default genuinely averages (differs from the single-round fast path)
+ one_round = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit(
+ SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, n_rounds=1).shap_values
+ assert not np.array_equal(auto, one_round)
+
+ # fuzzy_aggregation is inert when fuzzy labeling is off (binary path untouched)
+ def test_fuzzy_aggregation_inert_without_fuzzy(self):
+ labels = [1, 1, 1, 0, 0, 0, 0] # all binary -> no fuzzy sample
+ base = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=3).fit(
+ SMALL_X, labels=labels, fuzzy_labeling=False, n_rounds=2).shap_values
+ with_interp = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=3).fit(
+ SMALL_X, labels=labels, fuzzy_labeling=False,
+ fuzzy_aggregation="interpolate", n_rounds=2).shap_values
+ assert np.array_equal(base, with_interp)
diff --git a/tests/unit/shap_model_tests/test_sm_fuzzy_interpolate_regression.py b/tests/unit/shap_model_tests/test_sm_fuzzy_interpolate_regression.py
new file mode 100644
index 00000000..f90fde58
--- /dev/null
+++ b/tests/unit/shap_model_tests/test_sm_fuzzy_interpolate_regression.py
@@ -0,0 +1,150 @@
+"""Regression anchor for the ShapModel ``fuzzy_aggregation="interpolate"`` estimator.
+
+Pins the unbiased interpolation estimator on a real DOM_GSEC fuzzy-labeling cell:
+three proteins explained one at a time as a single fuzzy sample with invented
+prediction scores — APP (``P05067``, p=0.85), CD44 (``P16070``, p=0.65), and a
+non-substrate (``Q14802``, p=0.15).
+
+Two guards, strongest first:
+
+1. **Exact-p identity (platform-robust).** ``interpolate`` with ``n_rounds=1`` must
+ equal ``p*S1 + (1-p)*S0`` to ``atol=1e-10`` — recomputed on the same machine, so
+ no frozen value is involved and it cannot drift across platforms. This is the
+ unbiased-by-construction guarantee; any regression in the estimator breaks it.
+2. **Fit-count speed advantage (deterministic).** ``interpolate n_rounds=1`` does
+ strictly fewer model fits than the ``threshold n_rounds=5`` default (2 vs 5 per
+ fuzzy sample), captured via a fit-count spy — the wall-clock win measured against
+ aaanalysis 1.0.3 (~2.15x faster on this cell) reduced to a noise-free invariant.
+3. **Frozen signatures.** Compact per-protein signatures (row sum + max |value|) for
+ both estimators. The threshold signatures were verified **byte-identical** to
+ aaanalysis 1.0.3 on this cell (the no-regression guarantee for the default path);
+ interpolate differs by design (exact-p vs the biased threshold grid).
+
+Frozen values were captured on a dev machine (darwin/py3.13); the first canonical-cell
+CI run (Linux/py3.11, nightly) re-verifies them and they are re-frozen only on an
+intentional, reviewed change. Runs in the non-gating nightly only.
+
+Run locally (off the canonical cell) with: AAA_RUN_REGRESSION=1 pytest
+"""
+import os
+import sys
+
+import numpy as np
+import pytest
+from sklearn.ensemble import RandomForestClassifier
+
+import aaanalysis as aa
+from aaanalysis.explainable_ai_pro._backend.shap_model import shap_model_fit as B
+
+aa.options["verbose"] = False
+
+# Pin to the canonical cell; AAA_RUN_REGRESSION=1 forces it on any env (local check).
+_CANONICAL_ENV = (
+ os.environ.get("AAA_RUN_REGRESSION") == "1"
+ or (sys.platform.startswith("linux") and sys.version_info[:2] == (3, 11))
+)
+
+pytestmark = [
+ pytest.mark.regression,
+ pytest.mark.skipif(
+ not _CANONICAL_ENV,
+ reason="exact-value regression pinned to Linux/py3.11; "
+ "set AAA_RUN_REGRESSION=1 to force locally",
+ ),
+]
+
+# Three proteins + invented prediction scores (substrate, substrate, non-substrate)
+PROTEINS = {"P05067": 0.85, "P16070": 0.65, "Q14802": 0.15}
+N_FEAT = 25
+SEED = 42
+ONE_MODEL = dict(list_model_classes=[RandomForestClassifier])
+
+# Frozen (sum, max|value|) per protein; tolerant of cross-platform RF-SHAP drift.
+# THRESHOLD == aaanalysis 1.0.3 (verified byte-identical on this cell).
+SIG_THRESHOLD = {"P05067": (0.3335, 0.0533), "P16070": (0.1911, 0.0934), "Q14802": (-0.3002, 0.0701)}
+SIG_INTERPOLATE_N1 = {"P05067": (0.3676, 0.0550), "P16070": (0.2222, 0.0876), "Q14802": (-0.2054, 0.0628)}
+SIG_ATOL = 5e-3
+
+
+def _build_cell():
+ df_seq = aa.load_dataset(name="DOM_GSEC")
+ df_feat = aa.load_features(name="DOM_GSEC").head(N_FEAT)
+ sf = aa.SequenceFeature()
+ df_parts = sf.get_df_parts(df_seq=df_seq)
+ X = sf.feature_matrix(features=df_feat["feature"], df_parts=df_parts)
+ return X, df_seq["entry"].to_list(), df_seq["label"].to_list()
+
+
+def _fuzzy_labels(base_labels, idx, score):
+ labels = [float(v) for v in base_labels]
+ labels[idx] = score
+ return labels
+
+
+def test_interpolate_exact_p_identity_on_dom_gsec():
+ """interpolate (n_rounds=1) == p*S1 + (1-p)*S0 on the real 3-protein cell."""
+ X, entries, base_labels = _build_cell()
+ for entry, score in PROTEINS.items():
+ i = entries.index(entry)
+ labels = _fuzzy_labels(base_labels, i, score)
+ sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=SEED)
+ sm.fit(X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation="interpolate", n_rounds=1)
+ # Reference: two seeded fits with the fuzzy sample pinned at 0 and at 1
+ mk = dict(sm._list_model_kwargs[0])
+ mk["random_state"] = SEED # round 0 -> random_state + 0
+ args = dict(list_model_classes=[RandomForestClassifier], list_model_kwargs=[mk],
+ explainer_class=sm._explainer_class, explainer_kwargs=sm._explainer_kwargs,
+ class_index=1, n_background_data=None)
+ labels_0 = [0 if k == i else labels[k] for k in range(len(labels))]
+ labels_1 = [1 if k == i else labels[k] for k in range(len(labels))]
+ shap_0, _ = B._aggregate_shap_values(X, labels=labels_0, **args)
+ shap_1, _ = B._aggregate_shap_values(X, labels=labels_1, **args)
+ ref = score * shap_1 + (1 - score) * shap_0
+ assert np.allclose(sm.shap_values, ref, atol=1e-10, rtol=0), entry
+
+
+def test_interpolate_n1_fewer_fits_than_threshold_n5():
+ """The wall-clock win vs 1.0.3 as a noise-free invariant: 2 fits vs 5 per fuzzy sample."""
+ X, entries, base_labels = _build_cell()
+ i = entries.index("P05067")
+ labels = _fuzzy_labels(base_labels, i, PROTEINS["P05067"])
+ orig = B._compute_shap_values
+ counter = {"n": 0}
+
+ def spy(*a, **k):
+ counter["n"] += 1
+ return orig(*a, **k)
+
+ B._compute_shap_values = spy
+ try:
+ counter["n"] = 0
+ aa.ShapModel(**ONE_MODEL, verbose=False, random_state=SEED).fit(
+ X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation="interpolate", n_rounds=1)
+ n_interpolate = counter["n"]
+ counter["n"] = 0
+ aa.ShapModel(**ONE_MODEL, verbose=False, random_state=SEED).fit(
+ X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation="threshold", n_rounds=5)
+ n_threshold = counter["n"]
+ finally:
+ B._compute_shap_values = orig
+ assert n_interpolate == 2
+ assert n_threshold == 5
+ assert n_interpolate < n_threshold
+
+
+@pytest.mark.parametrize("aggregation,n_rounds,frozen", [
+ ("threshold", 5, SIG_THRESHOLD),
+ ("interpolate", 1, SIG_INTERPOLATE_N1),
+])
+def test_frozen_signatures(aggregation, n_rounds, frozen):
+ """Coarse per-protein anchors; threshold signatures == aaanalysis 1.0.3 on this cell."""
+ X, entries, base_labels = _build_cell()
+ for entry, score in PROTEINS.items():
+ i = entries.index(entry)
+ labels = _fuzzy_labels(base_labels, i, score)
+ sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=SEED)
+ sm.fit(X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation=aggregation, n_rounds=n_rounds)
+ row = sm.shap_values[i]
+ exp_sum, exp_maxabs = frozen[entry]
+ assert np.isclose(row.sum(), exp_sum, atol=SIG_ATOL), f"{entry} sum"
+ assert np.isclose(np.abs(row).max(), exp_maxabs, atol=SIG_ATOL), f"{entry} maxabs"