From a21e0720a75906eb985c76f4378ad7286cb435bf Mon Sep 17 00:00:00 2001 From: Stephan Breimann Date: Thu, 25 Jun 2026 06:18:05 +0200 Subject: [PATCH 1/8] feat(shap): add opt-in fuzzy_aggregation="interpolate" estimator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ShapModel.fit gains fuzzy_aggregation (default "threshold", byte-identical to before). "interpolate" weights a soft label p by exactly p — fitting at 0 (S0) and at 1 (S1) and blending p*S1 + (1-p)*S0 — instead of the biased threshold sweep. With n_rounds=1 it is exactly two fits per fuzzy sample (the fastest fuzzy estimator); n_rounds>1 averages per-round re-seeded fits (random_state + round) so n_rounds is always meaningful and reproducible. Each fuzzy protein is explained independently against the fixed balanced 0/1 core, with the other fuzzy proteins excluded; a single fuzzy protein shares the full set, so its two blended fits cover every row (no baseline needed). Backend: interpolate_fuzzy_shap_estimation + _seed_model_kwargs. Frontend: validation (check_str_options), routing, numpydoc Notes block. Adds 9 unit tests (exact-p golden, fit-count spy, reproducibility, MC-variance-vs-rounds, multi-fuzzy, threshold no-regression), an example notebook cell, a release-notes entry, and a CONTEXT.md glossary term. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- CONTEXT.md | 4 + .../_backend/shap_model/shap_model_fit.py | 94 ++++++ aaanalysis/explainable_ai_pro/_shap_model.py | 55 +++- docs/source/index/release_notes.rst | 7 + examples/explainable_ai/sm_fit.ipynb | 311 ++++++++++-------- tests/unit/shap_model_tests/test_sm_fit.py | 129 ++++++++ 6 files changed, 455 insertions(+), 145 deletions(-) diff --git a/CONTEXT.md b/CONTEXT.md index fe347ac8..c2287272 100644 --- a/CONTEXT.md +++ b/CONTEXT.md @@ -561,6 +561,10 @@ _Avoid_: feature importance (unsigned, group-level), SHAP value (the raw per-fea The uniform boolean toggle on the `CPPPlot` family (`profile`, `heatmap`, `ranking`, `feature_map`) selecting **CPP analysis** (`False`, group-level **feature importance**, `feat_importance` / `mean_dif`) versus **CPP-SHAP analysis** (`True`, sample-level **feature impact**, `feat_impact_'name'` / `mean_dif_'name'`). `True` switches color encoding to signed red/blue and the colorbar to the diverging SHAP colormap. It selects the *interpretation level*; it does not itself run SHAP (that is `ShapModel`). In `feature_map(shap_plot=True)` the cumulative bars stack the per-feature impact in one direction colored by sign; a `mean_dif_'name'` `col_val` keeps the mean-difference heatmap with those bars, while a `feat_impact_'name'` `col_val` moves the impact into the heatmap cells and switches the bars off. _Avoid_: shap_mode, use_shap, sample_plot. +**fuzzy aggregation** (`fuzzy_aggregation`): +The strategy `ShapModel.fit` uses to turn a soft label `p` ∈ (0, 1) into a SHAP estimate when **fuzzy labeling** is active. `"threshold"` (default) hard-labels the fuzzy sample `1` across a non-uniform `n_rounds × n_selection` threshold grid and averages — a **biased** approximation whose effective positive-fraction is the grid's `frac1`, not `p`. 
`"interpolate"` fits the model twice (fuzzy sample at 0 → `S0`, at 1 → `S1`) and blends `p·S1 + (1−p)·S0` — the **unbiased** exact-`p` estimate, and at `n_rounds=1` the fastest one (exactly 2 fits per fuzzy sample). Each fuzzy protein is explained independently against the fixed balanced 0/1 **core**, with the other fuzzy proteins excluded from that run's training data. +_Avoid_: fuzzy mode, blend mode, soft-label aggregation. + ### Scale-set vocabulary **explainable scale set** (`top_explain_n`): diff --git a/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py b/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py index 5b87a4d7..4d5280d3 100644 --- a/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py +++ b/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py @@ -105,6 +105,23 @@ def _aggregate_shap_values(X, labels=None, list_model_classes=None, list_model_k return shap_values, exp_val +def _seed_model_kwargs(list_model_kwargs, random_state=None, round_idx=0): + """Derive per-round-seeded copies of the model kwargs. + + With a fixed ``random_state`` each round uses ``random_state + round_idx`` so the rounds + differ (Monte-Carlo averaging) yet stay reproducible. With ``random_state=None`` the kwargs + are returned unchanged (``random_state`` already ``None``), so every fit re-draws fresh entropy. 
+ """ + if random_state is None: + return [dict(model_kwargs) for model_kwargs in list_model_kwargs] + seeded = [] + for model_kwargs in list_model_kwargs: + model_kwargs = dict(model_kwargs) + model_kwargs["random_state"] = random_state + round_idx + seeded.append(model_kwargs) + return seeded + + # II Main Functions @ut.catch_backend_processing_error() def monte_carlo_shap_estimation(X, labels=None, list_model_classes=None, list_model_kwargs=None, @@ -150,3 +167,80 @@ def monte_carlo_shap_estimation(X, labels=None, list_model_classes=None, list_mo shap_values = np.mean(mc_shap_values, axis=(2, 3)) exp_val = np.mean(list_expected_value) return shap_values, exp_val + + +@ut.catch_backend_processing_error() +def interpolate_fuzzy_shap_estimation(X, labels=None, list_model_classes=None, list_model_kwargs=None, + explainer_class=None, explainer_kwargs=None, n_rounds=5, + is_selected=None, verbose=False, label_target_class=1, + n_background_data=None, random_state=None): + """Compute unbiased exact-``p`` SHAP estimates for fuzzy labels by interpolating between 0/1 fits. + + Each fuzzy sample with soft label ``p`` is weighted by exactly ``p``: the model is fit twice + (fuzzy sample at 0 -> ``S0``, at 1 -> ``S1``) and the per-feature attributions are blended as + ``p * S1 + (1 - p) * S0``. Each fuzzy protein is explained independently against the fixed + balanced 0/1 core, with the other fuzzy proteins excluded from that run's training data. With + ``n_rounds=1`` this is exactly two fits per fuzzy sample; ``n_rounds > 1`` averages per-round + re-seeded fits (reproducible for a fixed ``random_state``). 
+ """ + labels = list(labels) + # Get class index (fuzzy float labels are excluded; classes come from the 0/1 core) + label_classes = sorted(list(dict.fromkeys([x for x in labels if x == int(x)]))) + class_index = label_classes.index(label_target_class) + n_samples, n_features = X.shape + n_selection_rounds = len(is_selected) + n_cells = n_rounds * n_selection_rounds + # Partition into the fixed 0/1 core and the fuzzy samples explained one at a time + fuzzy_idx = [i for i, label in enumerate(labels) if label not in (0, 1)] + core_idx = [i for i, label in enumerate(labels) if label in (0, 1)] + core_labels = [labels[i] for i in core_idx] + # A single fuzzy protein shares the full sample set, so the two blended fits already cover + # every row (no separate baseline needed) -> exactly two fits per round and selection. + single_fuzzy = len(fuzzy_idx) == 1 + acc_shap_values = np.zeros(shape=(n_samples, n_features)) + list_expected_value = [] + if verbose: + ut.print_start_progress(start_message=f"ShapModel starts interpolation estimation of SHAP values over {n_rounds} rounds.") + for i in range(n_rounds): + _list_model_kwargs = _seed_model_kwargs(list_model_kwargs, random_state=random_state, round_idx=i) + for j, selected_features in enumerate(is_selected): + if verbose: + pct_progress = j / len(is_selected) + add_new_line = explainer_class in LIST_VERBOSE_shap_modelS + ut.print_progress(i=i+pct_progress, n_total=n_rounds, add_new_line=add_new_line) + X_selected = X[:, selected_features] + args = dict(list_model_classes=list_model_classes, list_model_kwargs=_list_model_kwargs, + explainer_class=explainer_class, explainer_kwargs=explainer_kwargs, + class_index=class_index, n_background_data=n_background_data) + cell = np.zeros(shape=(n_samples, X_selected.shape[1])) + if single_fuzzy: + f = fuzzy_idx[0] + p = labels[f] + labels_0 = [0 if k == f else labels[k] for k in range(n_samples)] + labels_1 = [1 if k == f else labels[k] for k in range(n_samples)] + shap_0, exp_0 = 
_aggregate_shap_values(X_selected, labels=labels_0, **args) + shap_1, exp_1 = _aggregate_shap_values(X_selected, labels=labels_1, **args) + cell = p * shap_1 + (1 - p) * shap_0 + list_expected_value.append(p * exp_1 + (1 - p) * exp_0) + else: + # Non-fuzzy core rows come from a single baseline fit on the core + shap_core, exp_core = _aggregate_shap_values(X_selected[core_idx], labels=core_labels, **args) + cell[core_idx] = shap_core + list_expected_value.append(exp_core) + # Each fuzzy protein is explained against core + itself (others excluded) + for f in fuzzy_idx: + p = labels[f] + sub_idx = core_idx + [f] + X_sub = X_selected[sub_idx] + shap_0, exp_0 = _aggregate_shap_values(X_sub, labels=core_labels + [0], **args) + shap_1, exp_1 = _aggregate_shap_values(X_sub, labels=core_labels + [1], **args) + cell[f] = p * shap_1[-1] + (1 - p) * shap_0[-1] + list_expected_value.append(p * exp_1 + (1 - p) * exp_0) + full_cell = np.zeros(shape=(n_samples, n_features)) + full_cell[:, selected_features] = cell + acc_shap_values += full_cell + if verbose: + ut.print_end_progress(end_message=f"ShapModel finished interpolation estimation and saved results.") + shap_values = acc_shap_values / n_cells + exp_val = np.mean(list_expected_value) + return shap_values, exp_val diff --git a/aaanalysis/explainable_ai_pro/_shap_model.py b/aaanalysis/explainable_ai_pro/_shap_model.py index d7cc7985..9f5f22ce 100644 --- a/aaanalysis/explainable_ai_pro/_shap_model.py +++ b/aaanalysis/explainable_ai_pro/_shap_model.py @@ -13,7 +13,8 @@ from aaanalysis.template_classes import Wrapper from ._backend.check_models import (check_match_labels_X, check_match_X_is_selected) -from ._backend.shap_model.shap_model_fit import monte_carlo_shap_estimation +from ._backend.shap_model.shap_model_fit import (monte_carlo_shap_estimation, + interpolate_fuzzy_shap_estimation) from ._backend.shap_model.sm_add_feat_impact import (comp_shap_feature_importance, insert_shap_feature_importance, comp_shap_feature_impact, 
@@ -472,6 +473,7 @@ def fit(self, n_rounds: int = 5, is_selected: Optional[ut.ArrayLike2D] = None, fuzzy_labeling: bool = False, + fuzzy_aggregation: str = "threshold", n_background_data: Optional[int] = None, df_seq: Optional[pd.DataFrame] = None, fuzzy_labels: Optional[dict] = None, @@ -504,6 +506,12 @@ def fit(self, fuzzy_labeling : bool, default=False If ``True``, fuzzy labeling is applied to approximate SHAP values for samples with uncertain/partial memberships (e.g., between >0 and <1 for binary classification scenarios). + fuzzy_aggregation : str, default='threshold' + Strategy to turn a soft label ``p`` into a SHAP estimate when fuzzy labeling is active (see Notes): + + - ``'threshold'``: hard-label the fuzzy sample over a threshold grid and average (biased; the default). + - ``'interpolate'``: blend ``p * S1 + (1 - p) * S0`` from a fit at 0 and at 1 (unbiased, exact ``p``). + n_background_data : None or int, optional The number samples (< 'n_samples') in the background dataset used for the `KernelExplainer`` to reduce computation time. The dataset is obtained by k-means clustering. If ``None``, the full dataset 'X' is used. @@ -530,6 +538,21 @@ def fit(self, * Idea: Adjusts label thresholds dynamically in Monte Carlo estimation to better represent label uncertainties. * Background: Inspired by fuzzy logic, replacing binary true/false with degrees of truth. + **Fuzzy aggregation strategies** + + The ``fuzzy_aggregation`` parameter selects how a soft label ``p`` (in [0, 1]) is turned into a SHAP estimate: + + * ``'threshold'`` (default): Over an ``n_rounds`` x ``n_selection`` grid the fuzzy sample is hard-labeled ``1`` + when a per-cell threshold ``<= p`` and the per-cell SHAP matrices are averaged. Because the threshold grid is + non-uniform on (0, 1], the effective positive-fraction is a *biased* approximation of ``p``. 
+ * ``'interpolate'``: The fuzzy sample is weighted by *exactly* ``p`` by fitting the model twice (fuzzy sample at + 0 -> ``S0``, at 1 -> ``S1``) and blending ``p * S1 + (1 - p) * S0`` (the ``exp_value`` is blended the same way). + This is *unbiased* and, with ``n_rounds=1``, the fastest fuzzy estimator (exactly two fits per fuzzy sample). + ``n_rounds > 1`` averages re-seeded fits per round (``random_state + round``), capturing model variance while + staying reproducible for a fixed ``random_state``. Each fuzzy protein is explained independently against the + fixed balanced 0/1 core, with the other fuzzy proteins excluded from its training data. Recommended for the + "explain newly-predicted proteins" path. + **Setting soft labels** There are two equivalent ways to provide soft labels, both enabling fuzzy labeling: @@ -554,6 +577,8 @@ def fit(self, n_samples, n_feat = X.shape ut.check_X_unique_samples(X=X, min_n_unique_samples=2) ut.check_bool(name="fuzzy_labeling", val=fuzzy_labeling) + ut.check_str_options(name="fuzzy_aggregation", val=fuzzy_aggregation, + list_str_options=["threshold", "interpolate"]) if fuzzy_labels is not None: # Entry-keyed soft labels override 'labels' and enable fuzzy labeling check_match_df_seq_X(df_seq=df_seq, X=X) @@ -571,17 +596,23 @@ def fit(self, ut.check_number_range(name="n_background_data", val=n_background_data, min_val=1, just_int=True, accept_none=True) check_match_n_background_data_X(n_background_data=n_background_data, X=X) # Compute SHAP values - shap_values, exp_val = monte_carlo_shap_estimation(X, labels=labels, - list_model_classes=self._list_model_classes, - list_model_kwargs=self._list_model_kwargs, - explainer_class=self._explainer_class, - explainer_kwargs=self._explainer_kwargs, - is_selected=is_selected, - fuzzy_labeling=fuzzy_labeling, - n_rounds=n_rounds, - verbose=self._verbose, - label_target_class=label_target_class, - n_background_data=n_background_data) + backend_args = 
dict(list_model_classes=self._list_model_classes, + list_model_kwargs=self._list_model_kwargs, + explainer_class=self._explainer_class, + explainer_kwargs=self._explainer_kwargs, + is_selected=is_selected, + n_rounds=n_rounds, + verbose=self._verbose, + label_target_class=label_target_class, + n_background_data=n_background_data) + if fuzzy_labeling and fuzzy_aggregation == "interpolate": + shap_values, exp_val = interpolate_fuzzy_shap_estimation(X, labels=labels, + random_state=self._random_state, + **backend_args) + else: + shap_values, exp_val = monte_carlo_shap_estimation(X, labels=labels, + fuzzy_labeling=fuzzy_labeling, + **backend_args) self.shap_values = shap_values self.exp_value = exp_val return self diff --git a/docs/source/index/release_notes.rst b/docs/source/index/release_notes.rst index 8fafd242..31015939 100644 --- a/docs/source/index/release_notes.rst +++ b/docs/source/index/release_notes.rst @@ -85,6 +85,13 @@ Added ``add_feat_impact`` / ``add_sample_mean_dif`` accept ``df_seq`` and a ``samples`` parameter taking row positions or entry names. The array-``labels`` path is unchanged; ``sample_positions`` is a deprecated alias for ``samples`` (removed in 1.2.0). +- **ShapModel — unbiased fuzzy estimator** (``[pro]``): ``fit`` gains + ``fuzzy_aggregation`` (default ``'threshold'``, unchanged). ``'interpolate'`` weights a + soft label by *exactly* ``p`` — fitting at 0 (``S0``) and at 1 (``S1``) and blending + ``p * S1 + (1 - p) * S0`` — instead of the biased threshold sweep. With ``n_rounds=1`` + it needs only two fits per fuzzy sample; ``n_rounds > 1`` averages per-round re-seeded + fits (reproducible for a fixed ``random_state``). Recommended for explaining + newly-predicted proteins. 
**Sequence Analysis** diff --git a/examples/explainable_ai/sm_fit.ipynb b/examples/explainable_ai/sm_fit.ipynb index 25849ec8..55edce3b 100644 --- a/examples/explainable_ai/sm_fit.ipynb +++ b/examples/explainable_ai/sm_fit.ipynb @@ -21,10 +21,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T19:51:33.503612Z", - "iopub.status.busy": "2026-06-13T19:51:33.503358Z", - "iopub.status.idle": "2026-06-13T19:51:35.343709Z", - "shell.execute_reply": "2026-06-13T19:51:35.343497Z" + "iopub.execute_input": "2026-06-25T03:49:04.837887Z", + "iopub.status.busy": "2026-06-25T03:49:04.837819Z", + "iopub.status.idle": "2026-06-25T03:49:06.341129Z", + "shell.execute_reply": "2026-06-25T03:49:06.340905Z" } }, "outputs": [ @@ -32,7 +32,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/stephanbreimann/Programming/1Packages/aaanalysis-shap-acc/aaanalysis/feature_engineering/_backend/cpp_run.py:143: UserWarning: CPP is using the Python kernel fallback — the compiled Cython extension is not available in this install. Output is bit-exact with the Cython path but ~2x slower. Reinstall via `pip install --force-reinstall aaanalysis` to fetch a prebuilt wheel.\n", + "/Users/stephanbreimann/Programming/1Packages/wt-shap-interpolate/aaanalysis/feature_engineering/_backend/cpp_run.py:163: UserWarning: CPP is using the Python kernel fallback — the compiled Cython extension is not available in this install. Output is bit-exact with the Cython path but ~2x slower. Reinstall via `pip install --force-reinstall aaanalysis` to fetch a prebuilt wheel.\n", " warnings.warn(\n" ] }, @@ -40,105 +40,105 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 entrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_centrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_c
1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS
2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR
3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI
4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH
5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD
6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER
\n" @@ -188,10 +188,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T19:51:35.344949Z", - "iopub.status.busy": "2026-06-13T19:51:35.344873Z", - "iopub.status.idle": "2026-06-13T19:51:35.655752Z", - "shell.execute_reply": "2026-06-13T19:51:35.655480Z" + "iopub.execute_input": "2026-06-25T03:49:06.342198Z", + "iopub.status.busy": "2026-06-25T03:49:06.342114Z", + "iopub.status.idle": "2026-06-25T03:49:06.710977Z", + "shell.execute_reply": "2026-06-25T03:49:06.710741Z" } }, "outputs": [ @@ -200,16 +200,16 @@ "output_type": "stream", "text": [ "SHAP values explain the feature impact for 3 negative and 3 positive samples\n", - "[[-0.1 -0.1 -0.08 -0.09 -0.06]\n", - " [-0.12 -0.12 -0.09 -0.1 -0.07]\n", - " [-0.13 -0.14 -0.04 -0.09 -0.01]\n", - " [ 0.13 0.13 0.05 0.09 0.04]\n", - " [ 0.13 0.12 0.08 0.09 0.07]\n", + "[[-0.11 -0.11 -0.09 -0.09 -0.07]\n", + " [-0.12 -0.12 -0.09 -0.09 -0.07]\n", + " [-0.14 -0.13 -0.04 -0.08 -0.01]\n", + " [ 0.13 0.12 0.05 0.09 0.04]\n", + " [ 0.13 0.11 0.08 0.09 0.07]\n", " [ 0.13 0.12 0.08 0.09 0.06]]\n", "\n", "The expected value approximates the expected model output (average prediction score).\n", "For a binary classification with balanced datasets, it is around 0.5:\n", - "0.49566666666666687\n" + "0.5061666666666669\n" ] } ], @@ -250,10 +250,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T19:51:35.656884Z", - "iopub.status.busy": "2026-06-13T19:51:35.656813Z", - "iopub.status.idle": "2026-06-13T19:51:35.967928Z", - "shell.execute_reply": "2026-06-13T19:51:35.967703Z" + "iopub.execute_input": "2026-06-25T03:49:06.712150Z", + "iopub.status.busy": "2026-06-25T03:49:06.712084Z", + "iopub.status.idle": "2026-06-25T03:49:07.084512Z", + "shell.execute_reply": "2026-06-25T03:49:07.084272Z" } }, "outputs": [ @@ -262,15 +262,15 @@ "output_type": "stream", "text": [ "Reverse sign of SHAP values by changing reference class from 1 to 0\n", - "[[ 0.11 0.1 0.1 0.09 0.06]\n", - " [ 0.13 
0.12 0.1 0.08 0.07]\n", - " [ 0.15 0.14 0.04 0.08 0.01]\n", - " [-0.14 -0.13 -0.07 -0.08 -0.04]\n", - " [-0.13 -0.12 -0.09 -0.09 -0.07]\n", - " [-0.13 -0.12 -0.09 -0.08 -0.06]]\n", + "[[ 0.1 0.11 0.08 0.1 0.07]\n", + " [ 0.12 0.12 0.08 0.09 0.08]\n", + " [ 0.14 0.14 0.03 0.09 0.03]\n", + " [-0.12 -0.14 -0.05 -0.09 -0.04]\n", + " [-0.12 -0.13 -0.08 -0.1 -0.07]\n", + " [-0.12 -0.13 -0.08 -0.1 -0.07]]\n", "\n", "Base value stays around 0.5:\n", - "0.5036666666666669\n" + "0.5048333333333336\n" ] } ], @@ -309,10 +309,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T19:51:35.969048Z", - "iopub.status.busy": "2026-06-13T19:51:35.968988Z", - "iopub.status.idle": "2026-06-13T19:51:36.534019Z", - "shell.execute_reply": "2026-06-13T19:51:36.533756Z" + "iopub.execute_input": "2026-06-25T03:49:07.085712Z", + "iopub.status.busy": "2026-06-25T03:49:07.085645Z", + "iopub.status.idle": "2026-06-25T03:49:07.761961Z", + "shell.execute_reply": "2026-06-25T03:49:07.761631Z" } }, "outputs": [], @@ -342,10 +342,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T19:51:36.535531Z", - "iopub.status.busy": "2026-06-13T19:51:36.535441Z", - "iopub.status.idle": "2026-06-13T19:51:37.101052Z", - "shell.execute_reply": "2026-06-13T19:51:37.100823Z" + "iopub.execute_input": "2026-06-25T03:49:07.763450Z", + "iopub.status.busy": "2026-06-25T03:49:07.763356Z", + "iopub.status.idle": "2026-06-25T03:49:08.445470Z", + "shell.execute_reply": "2026-06-25T03:49:08.445237Z" } }, "outputs": [ @@ -354,12 +354,12 @@ "output_type": "stream", "text": [ "Impact of feature pre-selection\n", - "[[-0.16 -0.19 -0.05 -0.05 0. ]\n", - " [-0.18 -0.21 -0.05 -0.05 0. ]\n", - " [-0.18 -0.21 -0.02 -0.05 0. ]\n", - " [ 0.18 0.21 0.03 0.05 0. ]\n", - " [ 0.18 0.2 0.05 0.05 0. ]\n", - " [ 0.18 0.21 0.05 0.05 0. ]]\n" + "[[-0.19 -0.17 -0.05 -0.05 0. ]\n", + " [-0.2 -0.19 -0.05 -0.05 0. ]\n", + " [-0.21 -0.2 -0.02 -0.05 0. ]\n", + " [ 0.2 0.2 0.03 0.05 0. 
]\n", + " [ 0.2 0.19 0.05 0.05 0. ]\n", + " [ 0.2 0.19 0.05 0.05 0. ]]\n" ] } ], @@ -395,10 +395,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T19:51:37.102222Z", - "iopub.status.busy": "2026-06-13T19:51:37.102160Z", - "iopub.status.idle": "2026-06-13T19:51:37.672989Z", - "shell.execute_reply": "2026-06-13T19:51:37.672757Z" + "iopub.execute_input": "2026-06-25T03:49:08.446668Z", + "iopub.status.busy": "2026-06-25T03:49:08.446593Z", + "iopub.status.idle": "2026-06-25T03:49:09.130957Z", + "shell.execute_reply": "2026-06-25T03:49:09.130741Z" } }, "outputs": [ @@ -407,12 +407,12 @@ "output_type": "stream", "text": [ "First sample is labeled as 0.5 between negative (0) and positive (1)\n", - "[[ 0.04 0.04 -0.03 -0.04 0. ]\n", + "[[ 0.03 0.04 -0.03 -0.03 0. ]\n", " [-0.24 -0.25 -0.05 -0.05 0. ]\n", " [-0.22 -0.22 -0.02 -0.04 0. ]\n", - " [ 0.15 0.16 0.03 0.04 0. ]\n", - " [ 0.14 0.16 0.04 0.04 0. ]\n", - " [ 0.14 0.16 0.04 0.04 0. ]]\n" + " [ 0.15 0.15 0.03 0.04 0. ]\n", + " [ 0.14 0.15 0.04 0.04 0. ]\n", + " [ 0.14 0.15 0.04 0.04 0. ]]\n" ] } ], @@ -426,6 +426,51 @@ "print(sm.shap_values.round(2))" ] }, + { + "cell_type": "markdown", + "id": "445a9b21", + "metadata": {}, + "source": [ + "The ``fuzzy_aggregation`` parameter selects how a soft label ``p`` is turned into a SHAP estimate. 
The default ``'threshold'`` hard-labels the sample across a threshold grid (a biased approximation of ``p``), while ``'interpolate'`` blends ``p * S1 + (1 - p) * S0`` from a fit at 0 and at 1 — an unbiased, exact-``p`` estimate that needs only two fits per fuzzy sample at ``n_rounds=1``:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7312979a", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-25T03:49:09.132049Z", + "iopub.status.busy": "2026-06-25T03:49:09.131987Z", + "iopub.status.idle": "2026-06-25T03:49:09.448718Z", + "shell.execute_reply": "2026-06-25T03:49:09.448466Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exact-p interpolation for the fuzzy first sample (label=0.5)\n", + "[[-0.04 -0.03 -0.02 -0.02 0. ]\n", + " [-0.24 -0.24 -0.04 -0.05 0. ]\n", + " [-0.22 -0.23 -0.03 -0.04 0. ]\n", + " [ 0.17 0.18 0.01 0.03 0. ]\n", + " [ 0.17 0.18 0.03 0.03 0. ]\n", + " [ 0.17 0.18 0.03 0.03 0. ]]\n" + ] + } + ], + "source": [ + "# Unbiased exact-p estimate: blend p*S1 + (1-p)*S0 (the fuzzy first sample has label 0.5)\n", + "sm = aa.ShapModel(random_state=42)\n", + "sm = sm.fit(X, labels=labels, is_selected=is_selected,\n", + " fuzzy_labeling=True, fuzzy_aggregation=\"interpolate\", n_rounds=1)\n", + "\n", + "print(\"Exact-p interpolation for the fuzzy first sample (label=0.5)\")\n", + "print(sm.shap_values.round(2))" + ] + }, { "cell_type": "markdown", "id": "9d3e32ab", @@ -436,14 +481,14 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "0b3fd3f3", "metadata": { "execution": { - "iopub.execute_input": "2026-06-13T19:51:37.674230Z", - "iopub.status.busy": "2026-06-13T19:51:37.674159Z", - "iopub.status.idle": "2026-06-13T19:51:38.241890Z", - "shell.execute_reply": "2026-06-13T19:51:38.241638Z" + "iopub.execute_input": "2026-06-25T03:49:09.449821Z", + "iopub.status.busy": "2026-06-25T03:49:09.449755Z", + "iopub.status.idle": 
"2026-06-25T03:49:10.138026Z", + "shell.execute_reply": "2026-06-25T03:49:10.137745Z" } }, "outputs": [ @@ -452,12 +497,12 @@ "output_type": "stream", "text": [ "Sample 'Q14802' is labeled as 0.5 between negative (0) and positive (1)\n", - "[[ 0.03 0.04 -0.02 -0.04 0. ]\n", - " [-0.25 -0.24 -0.04 -0.05 0. ]\n", - " [-0.23 -0.22 -0.01 -0.04 0. ]\n", - " [ 0.15 0.15 0.02 0.04 0. ]\n", - " [ 0.15 0.15 0.04 0.05 0. ]\n", - " [ 0.15 0.15 0.04 0.05 0. ]]\n" + "[[ 0.03 0.03 -0.03 -0.03 0. ]\n", + " [-0.25 -0.24 -0.05 -0.05 0. ]\n", + " [-0.23 -0.21 -0.02 -0.04 0. ]\n", + " [ 0.16 0.15 0.02 0.04 0. ]\n", + " [ 0.16 0.14 0.04 0.04 0. ]\n", + " [ 0.16 0.14 0.04 0.04 0. ]]\n" ] } ], @@ -484,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "45c1efa0c82d7255", "metadata": { "ExecuteTime": { @@ -493,10 +538,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-13T19:51:38.243137Z", - "iopub.status.busy": "2026-06-13T19:51:38.243065Z", - "iopub.status.idle": "2026-06-13T19:51:38.245397Z", - "shell.execute_reply": "2026-06-13T19:51:38.245176Z" + "iopub.execute_input": "2026-06-25T03:49:10.139326Z", + "iopub.status.busy": "2026-06-25T03:49:10.139243Z", + "iopub.status.idle": "2026-06-25T03:49:10.142046Z", + "shell.execute_reply": "2026-06-25T03:49:10.141841Z" } }, "outputs": [], @@ -524,7 +569,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.14.0" + "version": "3.13.11" } }, "nbformat": 4, diff --git a/tests/unit/shap_model_tests/test_sm_fit.py b/tests/unit/shap_model_tests/test_sm_fit.py index 1aa18c34..6f120dc7 100644 --- a/tests/unit/shap_model_tests/test_sm_fit.py +++ b/tests/unit/shap_model_tests/test_sm_fit.py @@ -393,3 +393,132 @@ def test_fuzzy_labels_df_seq_missing_entry_column(self): sm = aa.ShapModel(**MODEL_KWARGS, verbose=False) with pytest.raises(ValueError): sm.fit(valid_X, labels=valid_labels, df_seq=df_bad, fuzzy_labels={entry: 0.6}, **ARGS) + + +# 
Small deterministic fixture for the exact-p / fit-count assertions (single model -> exact fit counts) +_RNG = np.random.default_rng(0) +SMALL_X = _RNG.normal(size=(7, 5)) +SMALL_LABELS_1FUZZY = [1, 1, 1, 0, 0, 0, 0.6] # one fuzzy protein (last row), balanced 3/3 core +SMALL_X_2FUZZY = _RNG.normal(size=(8, 5)) +SMALL_LABELS_2FUZZY = [1, 1, 1, 0, 0, 0, 0.6, 0.3] # two fuzzy proteins, balanced 3/3 core +ONE_MODEL = dict(list_model_classes=[RandomForestClassifier]) + + +class TestShapModelFitFuzzyAggregation: + """The ``fuzzy_aggregation`` estimator: 'threshold' (default) vs unbiased 'interpolate'.""" + + # Positive: both options are accepted and produce SHAP values + def test_fuzzy_aggregation_options_valid(self): + for fuzzy_aggregation in ["threshold", "interpolate"]: + sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=0) + sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation=fuzzy_aggregation, n_rounds=1) + assert sm.shap_values.shape == SMALL_X.shape + assert sm.exp_value is not None + + # Negative: an unknown value is rejected + def test_invalid_fuzzy_aggregation(self): + sm = aa.ShapModel(**ONE_MODEL, verbose=False) + with pytest.raises(ValueError): + sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, fuzzy_aggregation="bogus") + + # KPI: unbiased exact-p weighting (single fuzzy, n_rounds=1) == p*S1 + (1-p)*S0 + def test_interpolate_exact_p_blend(self): + from aaanalysis.explainable_ai_pro._backend.shap_model import shap_model_fit as B + p = SMALL_LABELS_1FUZZY[-1] + sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=42) + sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="interpolate", n_rounds=1) + model_kwargs = dict(sm._list_model_kwargs[0]) + model_kwargs["random_state"] = 42 # round 0 -> random_state + 0 + args = dict(list_model_classes=[RandomForestClassifier], list_model_kwargs=[model_kwargs], + explainer_class=sm._explainer_class, 
explainer_kwargs=sm._explainer_kwargs, + class_index=1, n_background_data=None) + labels_0 = [1, 1, 1, 0, 0, 0, 0] + labels_1 = [1, 1, 1, 0, 0, 0, 1] + shap_0, _ = B._aggregate_shap_values(SMALL_X, labels=labels_0, **args) + shap_1, _ = B._aggregate_shap_values(SMALL_X, labels=labels_1, **args) + ref = p * shap_1 + (1 - p) * shap_0 + assert np.allclose(sm.shap_values, ref, atol=1e-10, rtol=0) + + # KPI: 2-fit fast path — exactly two model fits per fuzzy sample, scaling with n_rounds + def test_interpolate_two_fits_single_fuzzy(self): + from aaanalysis.explainable_ai_pro._backend.shap_model import shap_model_fit as B + orig = B._compute_shap_values + counter = {"n": 0} + + def spy(*a, **k): + counter["n"] += 1 + return orig(*a, **k) + + B._compute_shap_values = spy + try: + sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=42) + sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="interpolate", n_rounds=1) + assert counter["n"] == 2 + counter["n"] = 0 + sm.fit(SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="interpolate", n_rounds=10) + assert counter["n"] == 20 + finally: + B._compute_shap_values = orig + + # KPI: n_rounds is meaningful (rounds differ) yet reproducible for a fixed random_state + def test_interpolate_reproducible_and_rounds_matter(self): + def run(n_rounds): + return aa.ShapModel(**ONE_MODEL, verbose=False, random_state=7).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="interpolate", n_rounds=n_rounds).shap_values + # Reproducible across runs at a fixed seed + assert np.array_equal(run(3), run(3)) + # n_rounds genuinely changes the estimate (per-round re-seeding) + assert not np.allclose(run(1), run(10)) + + # KPI: Monte-Carlo averaging — variance shrinks with n_rounds when random_state=None + def test_interpolate_mc_variance_decreases(self): + def variance(reps, n_rounds): + runs = [aa.ShapModel(**ONE_MODEL, verbose=False, 
random_state=None).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="interpolate", n_rounds=n_rounds).shap_values for _ in range(reps)] + return np.mean(np.var(np.stack(runs), axis=0)) + assert variance(6, 12) < variance(6, 1) + + # Multi-fuzzy: each fuzzy protein explained independently against the core (baseline + 2 per fuzzy) + def test_interpolate_multi_fuzzy(self): + from aaanalysis.explainable_ai_pro._backend.shap_model import shap_model_fit as B + orig = B._compute_shap_values + counter = {"n": 0} + + def spy(*a, **k): + counter["n"] += 1 + return orig(*a, **k) + + B._compute_shap_values = spy + try: + sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=5) + sm.fit(SMALL_X_2FUZZY, labels=SMALL_LABELS_2FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="interpolate", n_rounds=1) + assert sm.shap_values.shape == SMALL_X_2FUZZY.shape + assert counter["n"] == 1 + 2 * 2 # one baseline core fit + two fits per fuzzy protein + finally: + B._compute_shap_values = orig + + # No regression: the default 'threshold' path is unchanged by the routing refactor + def test_threshold_default_unchanged(self): + default = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, n_rounds=3).shap_values + explicit = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="threshold", n_rounds=3).shap_values + assert np.array_equal(default, explicit) + + # fuzzy_aggregation is inert when fuzzy labeling is off (binary path untouched) + def test_fuzzy_aggregation_inert_without_fuzzy(self): + labels = [1, 1, 1, 0, 0, 0, 0] # all binary -> no fuzzy sample + base = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=3).fit( + SMALL_X, labels=labels, fuzzy_labeling=False, n_rounds=2).shap_values + with_interp = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=3).fit( + SMALL_X, labels=labels, 
fuzzy_labeling=False, + fuzzy_aggregation="interpolate", n_rounds=2).shap_values + assert np.array_equal(base, with_interp) From adad4a72f0d214da8129e9b137c4267dfe4cbf68 Mon Sep 17 00:00:00 2001 From: Stephan Breimann Date: Thu, 25 Jun 2026 09:29:24 +0200 Subject: [PATCH 2/8] refactor(shap): apply review cleanups to interpolate estimator - Extract _class_index_from_labels helper (was duplicated across both estimators). - Move the cell zero-init into the multi-fuzzy branch (the single-fuzzy branch reassigns it; the init was dead there). - Comment why only the interpolate path threads random_state (per-round re-seeding) while the threshold path keeps it baked into model kwargs. No behavior change; output identical (41 ShapModel fit tests green). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../_backend/shap_model/shap_model_fit.py | 14 +++++++++----- aaanalysis/explainable_ai_pro/_shap_model.py | 2 ++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py b/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py index 4d5280d3..2fa4a6e6 100644 --- a/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py +++ b/aaanalysis/explainable_ai_pro/_backend/shap_model/shap_model_fit.py @@ -47,6 +47,12 @@ def _get_shap_values(shap_output, class_index=1): return shap_values +def _class_index_from_labels(labels, label_target_class=1): + """Map the target class label to its index among the integer (non-fuzzy) classes.""" + label_classes = sorted(list(dict.fromkeys([x for x in labels if x == int(x)]))) + return label_classes.index(label_target_class) + + def _compute_shap_values(X, labels, model_class=None, model_kwargs=None, explainer_class=None, explainer_kwargs=None, class_index=1, n_background_data=None): @@ -130,8 +136,7 @@ def monte_carlo_shap_estimation(X, labels=None, list_model_classes=None, list_mo label_target_class=1, n_background_data=None): """Compute Monte Carlo 
estimates of SHAP values for multiple models and feature selections.""" # Get class index - label_classes = sorted(list(dict.fromkeys([x for x in labels if x == int(x)]))) - class_index = label_classes.index(label_target_class) + class_index = _class_index_from_labels(labels, label_target_class) # Create empty SHAP value matrix n_samples, n_features = X.shape n_selection_rounds = len(is_selected) @@ -185,8 +190,7 @@ def interpolate_fuzzy_shap_estimation(X, labels=None, list_model_classes=None, l """ labels = list(labels) # Get class index (fuzzy float labels are excluded; classes come from the 0/1 core) - label_classes = sorted(list(dict.fromkeys([x for x in labels if x == int(x)]))) - class_index = label_classes.index(label_target_class) + class_index = _class_index_from_labels(labels, label_target_class) n_samples, n_features = X.shape n_selection_rounds = len(is_selected) n_cells = n_rounds * n_selection_rounds @@ -212,7 +216,6 @@ def interpolate_fuzzy_shap_estimation(X, labels=None, list_model_classes=None, l args = dict(list_model_classes=list_model_classes, list_model_kwargs=_list_model_kwargs, explainer_class=explainer_class, explainer_kwargs=explainer_kwargs, class_index=class_index, n_background_data=n_background_data) - cell = np.zeros(shape=(n_samples, X_selected.shape[1])) if single_fuzzy: f = fuzzy_idx[0] p = labels[f] @@ -223,6 +226,7 @@ def interpolate_fuzzy_shap_estimation(X, labels=None, list_model_classes=None, l cell = p * shap_1 + (1 - p) * shap_0 list_expected_value.append(p * exp_1 + (1 - p) * exp_0) else: + cell = np.zeros(shape=(n_samples, X_selected.shape[1])) # Non-fuzzy core rows come from a single baseline fit on the core shap_core, exp_core = _aggregate_shap_values(X_selected[core_idx], labels=core_labels, **args) cell[core_idx] = shap_core diff --git a/aaanalysis/explainable_ai_pro/_shap_model.py b/aaanalysis/explainable_ai_pro/_shap_model.py index 9f5f22ce..2d06fad1 100644 --- a/aaanalysis/explainable_ai_pro/_shap_model.py +++ 
b/aaanalysis/explainable_ai_pro/_shap_model.py @@ -606,6 +606,8 @@ def fit(self, label_target_class=label_target_class, n_background_data=n_background_data) if fuzzy_labeling and fuzzy_aggregation == "interpolate": + # Only the interpolate path threads 'random_state' explicitly: it re-seeds per round + # (random_state + round). The threshold path keeps the seed baked into the model kwargs. shap_values, exp_val = interpolate_fuzzy_shap_estimation(X, labels=labels, random_state=self._random_state, **backend_args) From 5a07d8acc8321d30d1f12f3813fe507d52c70d8a Mon Sep 17 00:00:00 2001 From: Stephan Breimann Date: Thu, 25 Jun 2026 12:28:18 +0200 Subject: [PATCH 3/8] test(shap): add fuzzy-interpolate regression anchor (vs v1.0.x) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pins the interpolate estimator on a real DOM_GSEC fuzzy cell — APP (P05067), CD44 (P16070), and a non-substrate (Q14802) with invented prediction scores, each explained as a single fuzzy sample. Guards: - exact-p identity: interpolate(n_rounds=1) == p*S1 + (1-p)*S0 (atol=1e-10), recomputed same-machine so it is platform-robust; - fit-count advantage: interpolate(n_rounds=1) does 2 fits vs threshold(n5)'s 5 — the ~2.15x wall-clock win measured against aaanalysis 1.0.3 as a noise-free invariant; - frozen per-protein signatures; the threshold signatures were verified byte-identical to aaanalysis 1.0.3 on this cell (no-regression for the default path), while interpolate differs by design (unbiased exact-p). @pytest.mark.regression, pinned to Linux/py3.11 (AAA_RUN_REGRESSION=1 forces it locally); runs in the non-gating nightly only. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- .../test_sm_fuzzy_interpolate_regression.py | 150 ++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 tests/unit/shap_model_tests/test_sm_fuzzy_interpolate_regression.py diff --git a/tests/unit/shap_model_tests/test_sm_fuzzy_interpolate_regression.py b/tests/unit/shap_model_tests/test_sm_fuzzy_interpolate_regression.py new file mode 100644 index 00000000..f90fde58 --- /dev/null +++ b/tests/unit/shap_model_tests/test_sm_fuzzy_interpolate_regression.py @@ -0,0 +1,150 @@ +"""Regression anchor for the ShapModel ``fuzzy_aggregation="interpolate"`` estimator. + +Pins the unbiased interpolation estimator on a real DOM_GSEC fuzzy-labeling cell: +three proteins explained one at a time as a single fuzzy sample with invented +prediction scores — APP (``P05067``, p=0.85), CD44 (``P16070``, p=0.65), and a +non-substrate (``Q14802``, p=0.15). + +Three guards, strongest first: + +1. **Exact-p identity (platform-robust).** ``interpolate`` with ``n_rounds=1`` must + equal ``p*S1 + (1-p)*S0`` to ``atol=1e-10`` — recomputed on the same machine, so + no frozen value is involved and it cannot drift across platforms. This is the + unbiased-by-construction guarantee; any regression in the estimator breaks it. +2. **Fit-count speed advantage (deterministic).** ``interpolate n_rounds=1`` does + strictly fewer model fits than the ``threshold n_rounds=5`` default (2 vs 5 per + fuzzy sample), captured via a fit-count spy — the wall-clock win measured against + aaanalysis 1.0.3 (~2.15x faster on this cell) reduced to a noise-free invariant. +3. **Frozen signatures.** Compact per-protein signatures (row sum + max |value|) for + both estimators. The threshold signatures were verified **byte-identical** to + aaanalysis 1.0.3 on this cell (the no-regression guarantee for the default path); + interpolate differs by design (exact-p vs the biased threshold grid). 
+ +Frozen values were captured on a dev machine (darwin/py3.13); the first canonical-cell +CI run (Linux/py3.11, nightly) re-verifies them and they are re-frozen only on an +intentional, reviewed change. Runs in the non-gating nightly only. + +Run locally (off the canonical cell) with: AAA_RUN_REGRESSION=1 pytest +""" +import os +import sys + +import numpy as np +import pytest +from sklearn.ensemble import RandomForestClassifier + +import aaanalysis as aa +from aaanalysis.explainable_ai_pro._backend.shap_model import shap_model_fit as B + +aa.options["verbose"] = False + +# Pin to the canonical cell; AAA_RUN_REGRESSION=1 forces it on any env (local check). +_CANONICAL_ENV = ( + os.environ.get("AAA_RUN_REGRESSION") == "1" + or (sys.platform.startswith("linux") and sys.version_info[:2] == (3, 11)) +) + +pytestmark = [ + pytest.mark.regression, + pytest.mark.skipif( + not _CANONICAL_ENV, + reason="exact-value regression pinned to Linux/py3.11; " + "set AAA_RUN_REGRESSION=1 to force locally", + ), +] + +# Three proteins + invented prediction scores (substrate, substrate, non-substrate) +PROTEINS = {"P05067": 0.85, "P16070": 0.65, "Q14802": 0.15} +N_FEAT = 25 +SEED = 42 +ONE_MODEL = dict(list_model_classes=[RandomForestClassifier]) + +# Frozen (sum, max|value|) per protein; tolerant of cross-platform RF-SHAP drift. +# THRESHOLD == aaanalysis 1.0.3 (verified byte-identical on this cell). 
+SIG_THRESHOLD = {"P05067": (0.3335, 0.0533), "P16070": (0.1911, 0.0934), "Q14802": (-0.3002, 0.0701)} +SIG_INTERPOLATE_N1 = {"P05067": (0.3676, 0.0550), "P16070": (0.2222, 0.0876), "Q14802": (-0.2054, 0.0628)} +SIG_ATOL = 5e-3 + + +def _build_cell(): + df_seq = aa.load_dataset(name="DOM_GSEC") + df_feat = aa.load_features(name="DOM_GSEC").head(N_FEAT) + sf = aa.SequenceFeature() + df_parts = sf.get_df_parts(df_seq=df_seq) + X = sf.feature_matrix(features=df_feat["feature"], df_parts=df_parts) + return X, df_seq["entry"].to_list(), df_seq["label"].to_list() + + +def _fuzzy_labels(base_labels, idx, score): + labels = [float(v) for v in base_labels] + labels[idx] = score + return labels + + +def test_interpolate_exact_p_identity_on_dom_gsec(): + """interpolate (n_rounds=1) == p*S1 + (1-p)*S0 on the real 3-protein cell.""" + X, entries, base_labels = _build_cell() + for entry, score in PROTEINS.items(): + i = entries.index(entry) + labels = _fuzzy_labels(base_labels, i, score) + sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=SEED) + sm.fit(X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation="interpolate", n_rounds=1) + # Reference: two seeded fits with the fuzzy sample pinned at 0 and at 1 + mk = dict(sm._list_model_kwargs[0]) + mk["random_state"] = SEED # round 0 -> random_state + 0 + args = dict(list_model_classes=[RandomForestClassifier], list_model_kwargs=[mk], + explainer_class=sm._explainer_class, explainer_kwargs=sm._explainer_kwargs, + class_index=1, n_background_data=None) + labels_0 = [0 if k == i else labels[k] for k in range(len(labels))] + labels_1 = [1 if k == i else labels[k] for k in range(len(labels))] + shap_0, _ = B._aggregate_shap_values(X, labels=labels_0, **args) + shap_1, _ = B._aggregate_shap_values(X, labels=labels_1, **args) + ref = score * shap_1 + (1 - score) * shap_0 + assert np.allclose(sm.shap_values, ref, atol=1e-10, rtol=0), entry + + +def test_interpolate_n1_fewer_fits_than_threshold_n5(): + """The wall-clock win vs 
1.0.3 as a noise-free invariant: 2 fits vs 5 per fuzzy sample.""" + X, entries, base_labels = _build_cell() + i = entries.index("P05067") + labels = _fuzzy_labels(base_labels, i, PROTEINS["P05067"]) + orig = B._compute_shap_values + counter = {"n": 0} + + def spy(*a, **k): + counter["n"] += 1 + return orig(*a, **k) + + B._compute_shap_values = spy + try: + counter["n"] = 0 + aa.ShapModel(**ONE_MODEL, verbose=False, random_state=SEED).fit( + X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation="interpolate", n_rounds=1) + n_interpolate = counter["n"] + counter["n"] = 0 + aa.ShapModel(**ONE_MODEL, verbose=False, random_state=SEED).fit( + X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation="threshold", n_rounds=5) + n_threshold = counter["n"] + finally: + B._compute_shap_values = orig + assert n_interpolate == 2 + assert n_threshold == 5 + assert n_interpolate < n_threshold + + +@pytest.mark.parametrize("aggregation,n_rounds,frozen", [ + ("threshold", 5, SIG_THRESHOLD), + ("interpolate", 1, SIG_INTERPOLATE_N1), +]) +def test_frozen_signatures(aggregation, n_rounds, frozen): + """Coarse per-protein anchors; threshold signatures == aaanalysis 1.0.3 on this cell.""" + X, entries, base_labels = _build_cell() + for entry, score in PROTEINS.items(): + i = entries.index(entry) + labels = _fuzzy_labels(base_labels, i, score) + sm = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=SEED) + sm.fit(X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation=aggregation, n_rounds=n_rounds) + row = sm.shap_values[i] + exp_sum, exp_maxabs = frozen[entry] + assert np.isclose(row.sum(), exp_sum, atol=SIG_ATOL), f"{entry} sum" + assert np.isclose(np.abs(row).max(), exp_maxabs, atol=SIG_ATOL), f"{entry} maxabs" From de40ac7329ef979a1a49e017597ed6ab20863cbf Mon Sep 17 00:00:00 2001 From: Stephan Breimann Date: Thu, 25 Jun 2026 12:41:33 +0200 Subject: [PATCH 4/8] test(shap): assert interpolate converges with n_rounds (bootstrap mean) interpolate's per-round average is a 
Monte-Carlo/bootstrap mean over re-seeded model fits, so it converges as n_rounds grows. Adds a reproducible (fixed base seed) convergence test on the canonical fuzzy cell: n_rounds=R equals the cumulative mean of per-round blends, so the 25 blends are computed once and all cumulative means derived from them. Asserts the convergence structure (platform-robust, no frozen values): - a single round (n_rounds=1) sits clearly off the converged mean; - late rounds move the estimate far less than early rounds (it converges ~1/sqrt(R)); - the tail is stable (the last rounds barely change the estimate). On DOM_GSEC the estimate stabilizes (incremental change < 2%) around n_rounds 15-20; n_rounds=1 stays the fast unbiased point estimate, higher n_rounds buys a stable mean. Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/unit/shap_model_tests/test_sm_fit.py | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/unit/shap_model_tests/test_sm_fit.py b/tests/unit/shap_model_tests/test_sm_fit.py index 6f120dc7..47b6b0aa 100644 --- a/tests/unit/shap_model_tests/test_sm_fit.py +++ b/tests/unit/shap_model_tests/test_sm_fit.py @@ -484,6 +484,30 @@ def variance(reps, n_rounds): return np.mean(np.var(np.stack(runs), axis=0)) assert variance(6, 12) < variance(6, 1) + # The per-round average is a bootstrap mean: it converges to a stable value as n_rounds grows, + # so a single round (n_rounds=1) sits well off the mean while late rounds barely move it. + def test_interpolate_converges_with_n_rounds(self): + # n_rounds=R (fixed base seed) == cumulative mean of per-round blends S_0..S_{R-1}, + # where S_i = a single-round fit at seed base+i. Compute the 25 blends once. 
+ R_MAX = 25 + base = 42 + blends = np.stack([ + aa.ShapModel(**ONE_MODEL, verbose=False, random_state=base + i).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="interpolate", n_rounds=1).shap_values + for i in range(R_MAX)]) + means = np.stack([blends[:R].mean(axis=0).ravel() for R in range(1, R_MAX + 1)]) # M_1..M_25 + final = means[-1] + norm = np.linalg.norm(final) + d_final = np.linalg.norm(means - final, axis=1) / norm # distance to the stable mean + d_incr = np.linalg.norm(np.diff(means, axis=0), axis=1) / norm # per-round change (len R_MAX-1) + # 1) a single round is clearly off the converged mean (this is why averaging exists) + assert d_final[0] > 0.05 + # 2) the average converges: late rounds move it far less than early rounds + assert np.mean(d_incr[-5:]) < 0.5 * np.mean(d_incr[:5]) + # 3) it is stable by the tail: the last few rounds barely change the estimate + assert d_final[-5] < 0.1 + # Multi-fuzzy: each fuzzy protein explained independently against the core (baseline + 2 per fuzzy) def test_interpolate_multi_fuzzy(self): from aaanalysis.explainable_ai_pro._backend.shap_model import shap_model_fit as B From e90cb3a140de5c95939b0530282fb8644fbaae9c Mon Sep 17 00:00:00 2001 From: Stephan Breimann Date: Thu, 25 Jun 2026 13:50:46 +0200 Subject: [PATCH 5/8] feat(shap): make interpolate the default fuzzy estimator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fuzzy_aggregation now defaults to "interpolate" (was "threshold"); the legacy biased threshold sweep stays available via fuzzy_aggregation="threshold". n_rounds is now Optional[int]=None, resolving to a per-estimator natural default: 1 for interpolate (exact in a single round) and 5 otherwise (the threshold sweep and the non-fuzzy Monte-Carlo need several rounds). 
So the default fuzzy estimate is the exact two-fit blend — ~2x faster than the v1.0 default on the same cell — while n_rounds>1 averages re-seeded fits into a reproducible Monte-Carlo mean that converges around n_rounds~15-20. Updates the Notes/param docstrings, the CONTEXT.md glossary, the release notes, and the example notebook (now demos the threshold opt-in + n_rounds averaging). Tests: default-is-interpolate, n_rounds=None natural-default resolution; the threshold branch-coverage test pins fuzzy_aggregation="threshold" explicitly. Co-Authored-By: Claude Opus 4.8 (1M context) --- CONTEXT.md | 2 +- aaanalysis/explainable_ai_pro/_shap_model.py | 41 ++- docs/source/index/release_notes.rst | 14 +- examples/explainable_ai/sm_fit.ipynb | 305 +++++++++--------- tests/unit/shap_model_tests/test_sm_branch.py | 2 +- tests/unit/shap_model_tests/test_sm_fit.py | 24 +- 6 files changed, 211 insertions(+), 177 deletions(-) diff --git a/CONTEXT.md b/CONTEXT.md index c2287272..de7783d2 100644 --- a/CONTEXT.md +++ b/CONTEXT.md @@ -562,7 +562,7 @@ The uniform boolean toggle on the `CPPPlot` family (`profile`, `heatmap`, `ranki _Avoid_: shap_mode, use_shap, sample_plot. **fuzzy aggregation** (`fuzzy_aggregation`): -The strategy `ShapModel.fit` uses to turn a soft label `p` ∈ (0, 1) into a SHAP estimate when **fuzzy labeling** is active. `"threshold"` (default) hard-labels the fuzzy sample `1` across a non-uniform `n_rounds × n_selection` threshold grid and averages — a **biased** approximation whose effective positive-fraction is the grid's `frac1`, not `p`. `"interpolate"` fits the model twice (fuzzy sample at 0 → `S0`, at 1 → `S1`) and blends `p·S1 + (1−p)·S0` — the **unbiased** exact-`p` estimate, and at `n_rounds=1` the fastest one (exactly 2 fits per fuzzy sample). Each fuzzy protein is explained independently against the fixed balanced 0/1 **core**, with the other fuzzy proteins excluded from that run's training data. 
+The strategy `ShapModel.fit` uses to turn a soft label `p` ∈ (0, 1) into a SHAP estimate when **fuzzy labeling** is active. `"interpolate"` (default) fits the model twice (fuzzy sample at 0 → `S0`, at 1 → `S1`) and blends `p·S1 + (1−p)·S0` — the **unbiased** exact-`p` estimate, and with its default `n_rounds=1` the fastest one (exactly 2 fits per fuzzy sample); `n_rounds>1` is a Monte-Carlo mean over re-seeded fits that converges around `n_rounds≈15–20`. `"threshold"` (the legacy v1.0 default) hard-labels the fuzzy sample `1` across a non-uniform `n_rounds × n_selection` grid (default `n_rounds=5`) and averages — a **biased** approximation whose effective positive-fraction is the grid's `frac1`, not `p`. Each fuzzy protein is explained independently against the fixed balanced 0/1 **core**, with the other fuzzy proteins excluded from that run's training data. `n_rounds` defaults per estimator (`None` → 1 for interpolate, 5 otherwise). _Avoid_: fuzzy mode, blend mode, soft-label aggregation. ### Scale-set vocabulary diff --git a/aaanalysis/explainable_ai_pro/_shap_model.py b/aaanalysis/explainable_ai_pro/_shap_model.py index 2d06fad1..3ad44308 100644 --- a/aaanalysis/explainable_ai_pro/_shap_model.py +++ b/aaanalysis/explainable_ai_pro/_shap_model.py @@ -470,10 +470,10 @@ def fit(self, X: ut.ArrayLike2D, labels: ut.ArrayLike1D, label_target_class: int = 1, - n_rounds: int = 5, + n_rounds: Optional[int] = None, is_selected: Optional[ut.ArrayLike2D] = None, fuzzy_labeling: bool = False, - fuzzy_aggregation: str = "threshold", + fuzzy_aggregation: str = "interpolate", n_background_data: Optional[int] = None, df_seq: Optional[pd.DataFrame] = None, fuzzy_labels: Optional[dict] = None, @@ -499,18 +499,22 @@ def fit(self, label_target_class : int, default=1 The label of the class for which SHAP values are computed in a classification tasks. For binary classification, '0' represents the negative class and '1' the positive class. 
- n_rounds : int, default=5 + n_rounds : int, optional The number of rounds (>=1) to fit the models and obtain the SHAP values by explainer. + If ``None``, a per-estimator natural default is used: ``1`` for + ``fuzzy_aggregation='interpolate'`` (exact in a single round) and ``5`` otherwise + (the threshold sweep and the non-fuzzy Monte-Carlo estimate need several rounds). is_selected : array-like, shape (n_selection_round, n_features) 2D boolean arrays indicating different feature selections. fuzzy_labeling : bool, default=False If ``True``, fuzzy labeling is applied to approximate SHAP values for samples with uncertain/partial memberships (e.g., between >0 and <1 for binary classification scenarios). - fuzzy_aggregation : str, default='threshold' + fuzzy_aggregation : str, default='interpolate' Strategy to turn a soft label ``p`` into a SHAP estimate when fuzzy labeling is active (see Notes): - - ``'threshold'``: hard-label the fuzzy sample over a threshold grid and average (biased; the default). - - ``'interpolate'``: blend ``p * S1 + (1 - p) * S0`` from a fit at 0 and at 1 (unbiased, exact ``p``). + - ``'interpolate'`` (default): blend ``p * S1 + (1 - p) * S0`` from a fit at 0 and at 1 + (unbiased, exact ``p``; with ``n_rounds=1`` only two fits per fuzzy sample). + - ``'threshold'``: hard-label the fuzzy sample over a threshold grid and average (biased; legacy default). n_background_data : None or int, optional The number samples (< 'n_samples') in the background dataset used for the `KernelExplainer`` to reduce @@ -542,16 +546,17 @@ def fit(self, The ``fuzzy_aggregation`` parameter selects how a soft label ``p`` (in [0, 1]) is turned into a SHAP estimate: - * ``'threshold'`` (default): Over an ``n_rounds`` x ``n_selection`` grid the fuzzy sample is hard-labeled ``1`` - when a per-cell threshold ``<= p`` and the per-cell SHAP matrices are averaged. 
Because the threshold grid is - non-uniform on (0, 1], the effective positive-fraction is a *biased* approximation of ``p``. - * ``'interpolate'``: The fuzzy sample is weighted by *exactly* ``p`` by fitting the model twice (fuzzy sample at - 0 -> ``S0``, at 1 -> ``S1``) and blending ``p * S1 + (1 - p) * S0`` (the ``exp_value`` is blended the same way). - This is *unbiased* and, with ``n_rounds=1``, the fastest fuzzy estimator (exactly two fits per fuzzy sample). - ``n_rounds > 1`` averages re-seeded fits per round (``random_state + round``), capturing model variance while - staying reproducible for a fixed ``random_state``. Each fuzzy protein is explained independently against the - fixed balanced 0/1 core, with the other fuzzy proteins excluded from its training data. Recommended for the - "explain newly-predicted proteins" path. + * ``'interpolate'`` (default): The fuzzy sample is weighted by *exactly* ``p`` by fitting the model twice (fuzzy + sample at 0 -> ``S0``, at 1 -> ``S1``) and blending ``p * S1 + (1 - p) * S0`` (the ``exp_value`` is blended the + same way). This is *unbiased* and, with the default ``n_rounds=1``, the fastest fuzzy estimator (exactly two + fits per fuzzy sample). ``n_rounds > 1`` averages re-seeded fits per round (``random_state + round``), a + Monte-Carlo mean over model variance that converges as ``n_rounds`` grows while staying reproducible for a + fixed ``random_state``. Each fuzzy protein is explained independently against the fixed balanced 0/1 core, with + the other fuzzy proteins excluded from its training data. + * ``'threshold'``: Over an ``n_rounds`` x ``n_selection`` grid (default ``n_rounds=5``) the fuzzy sample is + hard-labeled ``1`` when a per-cell threshold ``<= p`` and the per-cell SHAP matrices are averaged. Because the + threshold grid is non-uniform on (0, 1], the effective positive-fraction is a *biased* approximation of ``p``. + This was the default in v1.0; kept for backward-compatible results. 
**Setting soft labels** @@ -592,6 +597,10 @@ def fit(self, check_match_labels_target_class_labels(label_target_class=label_target_class, labels=labels) is_selected = check_is_selected(is_selected=is_selected, n_feat=n_feat) check_match_X_is_selected(X=X, is_selected=is_selected) + # Resolve the per-estimator natural default: interpolate is exact in a single round, + # while the threshold sweep (and the non-fuzzy Monte-Carlo) need several rounds. + if n_rounds is None: + n_rounds = 1 if (fuzzy_labeling and fuzzy_aggregation == "interpolate") else 5 ut.check_number_range(name="n_rounds", val=n_rounds, min_val=1, just_int=True) ut.check_number_range(name="n_background_data", val=n_background_data, min_val=1, just_int=True, accept_none=True) check_match_n_background_data_X(n_background_data=n_background_data, X=X) diff --git a/docs/source/index/release_notes.rst b/docs/source/index/release_notes.rst index 31015939..2a23a264 100644 --- a/docs/source/index/release_notes.rst +++ b/docs/source/index/release_notes.rst @@ -85,13 +85,15 @@ Added ``add_feat_impact`` / ``add_sample_mean_dif`` accept ``df_seq`` and a ``samples`` parameter taking row positions or entry names. The array-``labels`` path is unchanged; ``sample_positions`` is a deprecated alias for ``samples`` (removed in 1.2.0). -- **ShapModel — unbiased fuzzy estimator** (``[pro]``): ``fit`` gains - ``fuzzy_aggregation`` (default ``'threshold'``, unchanged). ``'interpolate'`` weights a +- **ShapModel — unbiased fuzzy estimator, now the default** (``[pro]``): ``fit`` gains + ``fuzzy_aggregation``, defaulting to the new ``'interpolate'`` estimator. It weights a soft label by *exactly* ``p`` — fitting at 0 (``S0``) and at 1 (``S1``) and blending - ``p * S1 + (1 - p) * S0`` — instead of the biased threshold sweep. With ``n_rounds=1`` - it needs only two fits per fuzzy sample; ``n_rounds > 1`` averages per-round re-seeded - fits (reproducible for a fixed ``random_state``). 
Recommended for explaining - newly-predicted proteins. + ``p * S1 + (1 - p) * S0`` — instead of the biased threshold sweep used in v1.0 + (still available via ``fuzzy_aggregation='threshold'``). ``n_rounds`` now defaults per + estimator (``None`` → ``1`` for ``interpolate``, ``5`` otherwise): the default fuzzy + estimate is the exact two-fit blend, ~2x faster than the v1.0 default on the same cell, + while ``n_rounds > 1`` averages per-round re-seeded fits into a reproducible Monte-Carlo + mean that converges around ``n_rounds ≈ 15–20``. **Sequence Analysis** diff --git a/examples/explainable_ai/sm_fit.ipynb b/examples/explainable_ai/sm_fit.ipynb index 55edce3b..ad07e21e 100644 --- a/examples/explainable_ai/sm_fit.ipynb +++ b/examples/explainable_ai/sm_fit.ipynb @@ -21,10 +21,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T03:49:04.837887Z", - "iopub.status.busy": "2026-06-25T03:49:04.837819Z", - "iopub.status.idle": "2026-06-25T03:49:06.341129Z", - "shell.execute_reply": "2026-06-25T03:49:06.340905Z" + "iopub.execute_input": "2026-06-25T11:49:27.299223Z", + "iopub.status.busy": "2026-06-25T11:49:27.298923Z", + "iopub.status.idle": "2026-06-25T11:49:28.740060Z", + "shell.execute_reply": "2026-06-25T11:49:28.739722Z" } }, "outputs": [ @@ -40,105 +40,105 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 entrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_centrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_c
1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS
2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR
3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI
4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH
5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD
6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER
\n" @@ -188,10 +188,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T03:49:06.342198Z", - "iopub.status.busy": "2026-06-25T03:49:06.342114Z", - "iopub.status.idle": "2026-06-25T03:49:06.710977Z", - "shell.execute_reply": "2026-06-25T03:49:06.710741Z" + "iopub.execute_input": "2026-06-25T11:49:28.741284Z", + "iopub.status.busy": "2026-06-25T11:49:28.741204Z", + "iopub.status.idle": "2026-06-25T11:49:29.098530Z", + "shell.execute_reply": "2026-06-25T11:49:29.098307Z" } }, "outputs": [ @@ -200,16 +200,16 @@ "output_type": "stream", "text": [ "SHAP values explain the feature impact for 3 negative and 3 positive samples\n", - "[[-0.11 -0.11 -0.09 -0.09 -0.07]\n", - " [-0.12 -0.12 -0.09 -0.09 -0.07]\n", - " [-0.14 -0.13 -0.04 -0.08 -0.01]\n", - " [ 0.13 0.12 0.05 0.09 0.04]\n", - " [ 0.13 0.11 0.08 0.09 0.07]\n", - " [ 0.13 0.12 0.08 0.09 0.06]]\n", + "[[-0.09 -0.1 -0.09 -0.1 -0.08]\n", + " [-0.11 -0.13 -0.08 -0.1 -0.08]\n", + " [-0.13 -0.15 -0.04 -0.1 -0.02]\n", + " [ 0.12 0.13 0.05 0.1 0.04]\n", + " [ 0.11 0.12 0.07 0.1 0.08]\n", + " [ 0.12 0.12 0.07 0.1 0.06]]\n", "\n", "The expected value approximates the expected model output (average prediction score).\n", "For a binary classification with balanced datasets, it is around 0.5:\n", - "0.5061666666666669\n" + "0.5011666666666669\n" ] } ], @@ -250,10 +250,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T03:49:06.712150Z", - "iopub.status.busy": "2026-06-25T03:49:06.712084Z", - "iopub.status.idle": "2026-06-25T03:49:07.084512Z", - "shell.execute_reply": "2026-06-25T03:49:07.084272Z" + "iopub.execute_input": "2026-06-25T11:49:29.099686Z", + "iopub.status.busy": "2026-06-25T11:49:29.099623Z", + "iopub.status.idle": "2026-06-25T11:49:29.452810Z", + "shell.execute_reply": "2026-06-25T11:49:29.452597Z" } }, "outputs": [ @@ -262,15 +262,15 @@ "output_type": "stream", "text": [ "Reverse sign of SHAP values by changing reference class from 1 to 0\n", - "[[ 0.1 0.11 
0.08 0.1 0.07]\n", - " [ 0.12 0.12 0.08 0.09 0.08]\n", - " [ 0.14 0.14 0.03 0.09 0.03]\n", - " [-0.12 -0.14 -0.05 -0.09 -0.04]\n", - " [-0.12 -0.13 -0.08 -0.1 -0.07]\n", - " [-0.12 -0.13 -0.08 -0.1 -0.07]]\n", + "[[ 0.1 0.11 0.09 0.09 0.08]\n", + " [ 0.11 0.12 0.09 0.08 0.08]\n", + " [ 0.13 0.14 0.04 0.08 0.02]\n", + " [-0.13 -0.13 -0.06 -0.08 -0.04]\n", + " [-0.12 -0.12 -0.09 -0.08 -0.08]\n", + " [-0.13 -0.13 -0.09 -0.08 -0.07]]\n", "\n", "Base value stays around 0.5:\n", - "0.5048333333333336\n" + "0.5083333333333335\n" ] } ], @@ -309,10 +309,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T03:49:07.085712Z", - "iopub.status.busy": "2026-06-25T03:49:07.085645Z", - "iopub.status.idle": "2026-06-25T03:49:07.761961Z", - "shell.execute_reply": "2026-06-25T03:49:07.761631Z" + "iopub.execute_input": "2026-06-25T11:49:29.453905Z", + "iopub.status.busy": "2026-06-25T11:49:29.453845Z", + "iopub.status.idle": "2026-06-25T11:49:30.101114Z", + "shell.execute_reply": "2026-06-25T11:49:30.100821Z" } }, "outputs": [], @@ -342,10 +342,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T03:49:07.763450Z", - "iopub.status.busy": "2026-06-25T03:49:07.763356Z", - "iopub.status.idle": "2026-06-25T03:49:08.445470Z", - "shell.execute_reply": "2026-06-25T03:49:08.445237Z" + "iopub.execute_input": "2026-06-25T11:49:30.102412Z", + "iopub.status.busy": "2026-06-25T11:49:30.102352Z", + "iopub.status.idle": "2026-06-25T11:49:30.748064Z", + "shell.execute_reply": "2026-06-25T11:49:30.747796Z" } }, "outputs": [ @@ -354,12 +354,12 @@ "output_type": "stream", "text": [ "Impact of feature pre-selection\n", - "[[-0.19 -0.17 -0.05 -0.05 0. ]\n", - " [-0.2 -0.19 -0.05 -0.05 0. ]\n", - " [-0.21 -0.2 -0.02 -0.05 0. ]\n", - " [ 0.2 0.2 0.03 0.05 0. ]\n", - " [ 0.2 0.19 0.05 0.05 0. ]\n", - " [ 0.2 0.19 0.05 0.05 0. ]]\n" + "[[-0.17 -0.19 -0.05 -0.05 0. ]\n", + " [-0.19 -0.2 -0.05 -0.05 0. ]\n", + " [-0.19 -0.2 -0.02 -0.05 0. 
]\n", + " [ 0.2 0.2 0.03 0.04 0. ]\n", + " [ 0.19 0.2 0.05 0.05 0. ]\n", + " [ 0.19 0.2 0.05 0.05 0. ]]\n" ] } ], @@ -395,10 +395,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T03:49:08.446668Z", - "iopub.status.busy": "2026-06-25T03:49:08.446593Z", - "iopub.status.idle": "2026-06-25T03:49:09.130957Z", - "shell.execute_reply": "2026-06-25T03:49:09.130741Z" + "iopub.execute_input": "2026-06-25T11:49:30.749257Z", + "iopub.status.busy": "2026-06-25T11:49:30.749196Z", + "iopub.status.idle": "2026-06-25T11:49:31.047878Z", + "shell.execute_reply": "2026-06-25T11:49:31.047598Z" } }, "outputs": [ @@ -407,12 +407,12 @@ "output_type": "stream", "text": [ "First sample is labeled as 0.5 between negative (0) and positive (1)\n", - "[[ 0.03 0.04 -0.03 -0.03 0. ]\n", - " [-0.24 -0.25 -0.05 -0.05 0. ]\n", - " [-0.22 -0.22 -0.02 -0.04 0. ]\n", - " [ 0.15 0.15 0.03 0.04 0. ]\n", - " [ 0.14 0.15 0.04 0.04 0. ]\n", - " [ 0.14 0.15 0.04 0.04 0. ]]\n" + "[[-0.03 -0.05 -0.03 -0.02 0. ]\n", + " [-0.24 -0.24 -0.04 -0.04 0. ]\n", + " [-0.22 -0.22 -0.02 -0.03 0. ]\n", + " [ 0.17 0.18 0.02 0.03 0. ]\n", + " [ 0.17 0.18 0.03 0.03 0. ]\n", + " [ 0.17 0.18 0.03 0.03 0. ]]\n" ] } ], @@ -431,7 +431,7 @@ "id": "445a9b21", "metadata": {}, "source": [ - "The ``fuzzy_aggregation`` parameter selects how a soft label ``p`` is turned into a SHAP estimate. The default ``'threshold'`` hard-labels the sample across a threshold grid (a biased approximation of ``p``), while ``'interpolate'`` blends ``p * S1 + (1 - p) * S0`` from a fit at 0 and at 1 — an unbiased, exact-``p`` estimate that needs only two fits per fuzzy sample at ``n_rounds=1``:" + "By default ``fuzzy_aggregation='interpolate'`` weights the fuzzy sample by *exactly* ``p`` (the cell above already used it: two fits blended as ``p*S1 + (1-p)*S0``). 
The legacy v1.0 ``'threshold'`` sweep stays available for backward-compatible results, and ``n_rounds > 1`` averages re-seeded fits into a Monte-Carlo mean that converges around ``n_rounds ~ 15-20``:" ] }, { @@ -440,10 +440,10 @@ "id": "7312979a", "metadata": { "execution": { - "iopub.execute_input": "2026-06-25T03:49:09.132049Z", - "iopub.status.busy": "2026-06-25T03:49:09.131987Z", - "iopub.status.idle": "2026-06-25T03:49:09.448718Z", - "shell.execute_reply": "2026-06-25T03:49:09.448466Z" + "iopub.execute_input": "2026-06-25T11:49:31.049248Z", + "iopub.status.busy": "2026-06-25T11:49:31.049169Z", + "iopub.status.idle": "2026-06-25T11:49:35.259909Z", + "shell.execute_reply": "2026-06-25T11:49:35.259678Z" } }, "outputs": [ @@ -451,24 +451,29 @@ "name": "stdout", "output_type": "stream", "text": [ - "Exact-p interpolation for the fuzzy first sample (label=0.5)\n", - "[[-0.04 -0.03 -0.02 -0.02 0. ]\n", - " [-0.24 -0.24 -0.04 -0.05 0. ]\n", - " [-0.22 -0.23 -0.03 -0.04 0. ]\n", - " [ 0.17 0.18 0.01 0.03 0. ]\n", - " [ 0.17 0.18 0.03 0.03 0. ]\n", - " [ 0.17 0.18 0.03 0.03 0. ]]\n" + "Legacy threshold estimate (first, fuzzy sample):\n", + "[ 0.03 0.04 -0.03 -0.03 0. ]\n", + "\n", + "Converged interpolate mean over 15 rounds (first, fuzzy sample):\n", + "[-0.03 -0.04 -0.02 -0.02 0. 
]\n" ] } ], "source": [ - "# Unbiased exact-p estimate: blend p*S1 + (1-p)*S0 (the fuzzy first sample has label 0.5)\n", - "sm = aa.ShapModel(random_state=42)\n", - "sm = sm.fit(X, labels=labels, is_selected=is_selected,\n", - " fuzzy_labeling=True, fuzzy_aggregation=\"interpolate\", n_rounds=1)\n", + "# Legacy biased threshold estimator (the v1.0 default), kept for backward-compatible results\n", + "sm_threshold = aa.ShapModel(random_state=42)\n", + "sm_threshold = sm_threshold.fit(X, labels=labels, is_selected=is_selected,\n", + " fuzzy_labeling=True, fuzzy_aggregation=\"threshold\")\n", "\n", - "print(\"Exact-p interpolation for the fuzzy first sample (label=0.5)\")\n", - "print(sm.shap_values.round(2))" + "# Converged interpolate mean: average more re-seeded rounds (default n_rounds=1 is the exact 2-fit blend)\n", + "sm_converged = aa.ShapModel(random_state=42)\n", + "sm_converged = sm_converged.fit(X, labels=labels, is_selected=is_selected,\n", + " fuzzy_labeling=True, n_rounds=15)\n", + "\n", + "print(\"Legacy threshold estimate (first, fuzzy sample):\")\n", + "print(sm_threshold.shap_values[0].round(2))\n", + "print(\"\\nConverged interpolate mean over 15 rounds (first, fuzzy sample):\")\n", + "print(sm_converged.shap_values[0].round(2))" ] }, { @@ -485,10 +490,10 @@ "id": "0b3fd3f3", "metadata": { "execution": { - "iopub.execute_input": "2026-06-25T03:49:09.449821Z", - "iopub.status.busy": "2026-06-25T03:49:09.449755Z", - "iopub.status.idle": "2026-06-25T03:49:10.138026Z", - "shell.execute_reply": "2026-06-25T03:49:10.137745Z" + "iopub.execute_input": "2026-06-25T11:49:35.261033Z", + "iopub.status.busy": "2026-06-25T11:49:35.260972Z", + "iopub.status.idle": "2026-06-25T11:49:35.557346Z", + "shell.execute_reply": "2026-06-25T11:49:35.557107Z" } }, "outputs": [ @@ -497,12 +502,12 @@ "output_type": "stream", "text": [ "Sample 'Q14802' is labeled as 0.5 between negative (0) and positive (1)\n", - "[[ 0.03 0.03 -0.03 -0.03 0. 
]\n", - " [-0.25 -0.24 -0.05 -0.05 0. ]\n", - " [-0.23 -0.21 -0.02 -0.04 0. ]\n", - " [ 0.16 0.15 0.02 0.04 0. ]\n", - " [ 0.16 0.14 0.04 0.04 0. ]\n", - " [ 0.16 0.14 0.04 0.04 0. ]]\n" + "[[-0.03 -0.03 -0.03 -0.02 0. ]\n", + " [-0.23 -0.23 -0.05 -0.05 0. ]\n", + " [-0.21 -0.23 -0.02 -0.04 0. ]\n", + " [ 0.17 0.17 0.02 0.03 0. ]\n", + " [ 0.17 0.17 0.03 0.04 0. ]\n", + " [ 0.17 0.17 0.03 0.04 0. ]]\n" ] } ], @@ -538,10 +543,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T03:49:10.139326Z", - "iopub.status.busy": "2026-06-25T03:49:10.139243Z", - "iopub.status.idle": "2026-06-25T03:49:10.142046Z", - "shell.execute_reply": "2026-06-25T03:49:10.141841Z" + "iopub.execute_input": "2026-06-25T11:49:35.558535Z", + "iopub.status.busy": "2026-06-25T11:49:35.558456Z", + "iopub.status.idle": "2026-06-25T11:49:35.560958Z", + "shell.execute_reply": "2026-06-25T11:49:35.560751Z" } }, "outputs": [], diff --git a/tests/unit/shap_model_tests/test_sm_branch.py b/tests/unit/shap_model_tests/test_sm_branch.py index 88a8055b..f8cba032 100644 --- a/tests/unit/shap_model_tests/test_sm_branch.py +++ b/tests/unit/shap_model_tests/test_sm_branch.py @@ -103,7 +103,7 @@ def test_fuzzy_valid_threshold_path(self): X, _ = _small_data(n_samples=12) labels = np.array([1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0.5], dtype=float) sm = aa.ShapModel(**MODEL_KWARGS, verbose=False) - sm.fit(X, labels=labels, fuzzy_labeling=True, **ARGS) + sm.fit(X, labels=labels, fuzzy_labeling=True, fuzzy_aggregation="threshold", **ARGS) assert sm.shap_values.shape == X.shape def test_fuzzy_label_out_of_range_raises(self): diff --git a/tests/unit/shap_model_tests/test_sm_fit.py b/tests/unit/shap_model_tests/test_sm_fit.py index 47b6b0aa..b5ffbd6d 100644 --- a/tests/unit/shap_model_tests/test_sm_fit.py +++ b/tests/unit/shap_model_tests/test_sm_fit.py @@ -528,15 +528,33 @@ def spy(*a, **k): finally: B._compute_shap_values = orig - # No regression: the default 'threshold' path is unchanged by the 
routing refactor - def test_threshold_default_unchanged(self): + # 'interpolate' is the default estimator; explicit selection matches the default + def test_default_is_interpolate(self): default = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, n_rounds=3).shap_values explicit = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, - fuzzy_aggregation="threshold", n_rounds=3).shap_values + fuzzy_aggregation="interpolate", n_rounds=3).shap_values assert np.array_equal(default, explicit) + # n_rounds=None resolves to the per-estimator natural default: 1 for interpolate, 5 for threshold + def test_n_rounds_none_natural_default(self): + # interpolate default-rounds == explicit n_rounds=1 + auto_i = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True).shap_values + explicit_i1 = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="interpolate", n_rounds=1).shap_values + assert np.array_equal(auto_i, explicit_i1) + # threshold default-rounds == explicit n_rounds=5 + auto_t = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="threshold").shap_values + explicit_t5 = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, + fuzzy_aggregation="threshold", n_rounds=5).shap_values + assert np.array_equal(auto_t, explicit_t5) + # fuzzy_aggregation is inert when fuzzy labeling is off (binary path untouched) def test_fuzzy_aggregation_inert_without_fuzzy(self): labels = [1, 1, 1, 0, 0, 0, 0] # all binary -> no fuzzy sample From 4e8e47e49fb3b4cc5029c9e9ce90c017843ca4eb Mon Sep 17 00:00:00 2001 From: Stephan Breimann Date: Thu, 25 Jun 
2026 15:59:37 +0200 Subject: [PATCH 6/8] refactor(shap): plain n_rounds=5; keep threshold + interpolate first-class MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverts the per-estimator n_rounds=None magic for a simpler, single shared default. fuzzy_aggregation selects two first-class estimators: the cited threshold sweep ([Breimann25]) and the new unbiased interpolate (default, v1.1); threshold is kept (not deprecated) and stays faithful to its published n_rounds=5 grid. n_rounds is a plain int=5 (no None resolution): no regression to the threshold or non-fuzzy paths, and for interpolate it is a documented speed/stability dial — n_rounds=1 the fast exact two-fit estimate, 5 (default) light averaging, ~15-20 the converged Monte-Carlo mean (run-to-run spread <5% on DOM_GSEC). The n_rounds reasoning + g-secretase convergence are documented in the fit Notes, the CONTEXT.md glossary, and the release notes. Co-Authored-By: Claude Opus 4.8 (1M context) --- CONTEXT.md | 2 +- aaanalysis/explainable_ai_pro/_shap_model.py | 49 ++-- docs/source/index/release_notes.rst | 12 +- examples/explainable_ai/sm_fit.ipynb | 272 +++++++++---------- tests/unit/shap_model_tests/test_sm_fit.py | 26 +- 5 files changed, 180 insertions(+), 181 deletions(-) diff --git a/CONTEXT.md b/CONTEXT.md index de7783d2..e2b1fdab 100644 --- a/CONTEXT.md +++ b/CONTEXT.md @@ -562,7 +562,7 @@ The uniform boolean toggle on the `CPPPlot` family (`profile`, `heatmap`, `ranki _Avoid_: shap_mode, use_shap, sample_plot. **fuzzy aggregation** (`fuzzy_aggregation`): -The strategy `ShapModel.fit` uses to turn a soft label `p` ∈ (0, 1) into a SHAP estimate when **fuzzy labeling** is active. 
`"interpolate"` (default) fits the model twice (fuzzy sample at 0 → `S0`, at 1 → `S1`) and blends `p·S1 + (1−p)·S0` — the **unbiased** exact-`p` estimate, and with its default `n_rounds=1` the fastest one (exactly 2 fits per fuzzy sample); `n_rounds>1` is a Monte-Carlo mean over re-seeded fits that converges around `n_rounds≈15–20`. `"threshold"` (the legacy v1.0 default) hard-labels the fuzzy sample `1` across a non-uniform `n_rounds × n_selection` grid (default `n_rounds=5`) and averages — a **biased** approximation whose effective positive-fraction is the grid's `frac1`, not `p`. Each fuzzy protein is explained independently against the fixed balanced 0/1 **core**, with the other fuzzy proteins excluded from that run's training data. `n_rounds` defaults per estimator (`None` → 1 for interpolate, 5 otherwise). +The strategy `ShapModel.fit` selects to turn a soft label `p` ∈ (0, 1) into a SHAP estimate when **fuzzy labeling** is active. `"interpolate"` (default, new in v1.1) fits the model twice (fuzzy sample at 0 → `S0`, at 1 → `S1`) and blends `p·S1 + (1−p)·S0` — the **unbiased** exact-`p` estimate. `"threshold"` (the `Breimann25` sweep) hard-labels the fuzzy sample `1` across a non-uniform `n_rounds × n_selection` grid and averages — a **biased** approximation whose effective positive-fraction is the grid's `frac1`, not `p`; kept as a first-class option. Each fuzzy protein is explained independently against the fixed balanced 0/1 **core**, with the other fuzzy proteins excluded from that run's training data. `n_rounds` (default `5`) is interpolate's speed/stability dial: `1` = fast exact two-fit estimate, `5` = light averaging, `≈15–20` = converged Monte-Carlo mean (run-to-run spread <5% on `DOM_GSEC`). _Avoid_: fuzzy mode, blend mode, soft-label aggregation. 
### Scale-set vocabulary diff --git a/aaanalysis/explainable_ai_pro/_shap_model.py b/aaanalysis/explainable_ai_pro/_shap_model.py index 3ad44308..b119a84c 100644 --- a/aaanalysis/explainable_ai_pro/_shap_model.py +++ b/aaanalysis/explainable_ai_pro/_shap_model.py @@ -470,7 +470,7 @@ def fit(self, X: ut.ArrayLike2D, labels: ut.ArrayLike1D, label_target_class: int = 1, - n_rounds: Optional[int] = None, + n_rounds: int = 5, is_selected: Optional[ut.ArrayLike2D] = None, fuzzy_labeling: bool = False, fuzzy_aggregation: str = "interpolate", @@ -499,11 +499,11 @@ def fit(self, label_target_class : int, default=1 The label of the class for which SHAP values are computed in a classification tasks. For binary classification, '0' represents the negative class and '1' the positive class. - n_rounds : int, optional + n_rounds : int, default=5 The number of rounds (>=1) to fit the models and obtain the SHAP values by explainer. - If ``None``, a per-estimator natural default is used: ``1`` for - ``fuzzy_aggregation='interpolate'`` (exact in a single round) and ``5`` otherwise - (the threshold sweep and the non-fuzzy Monte-Carlo estimate need several rounds). + For ``fuzzy_aggregation='interpolate'`` each round re-seeds the fit, so ``n_rounds`` is a + speed/stability dial (see Notes): ``1`` is the fast exact two-fit estimate, the default + ``5`` adds Monte-Carlo averaging, and a stable mean is reached around ``15-20``. is_selected : array-like, shape (n_selection_round, n_features) 2D boolean arrays indicating different feature selections. fuzzy_labeling : bool, default=False @@ -512,9 +512,10 @@ def fit(self, fuzzy_aggregation : str, default='interpolate' Strategy to turn a soft label ``p`` into a SHAP estimate when fuzzy labeling is active (see Notes): - - ``'interpolate'`` (default): blend ``p * S1 + (1 - p) * S0`` from a fit at 0 and at 1 - (unbiased, exact ``p``; with ``n_rounds=1`` only two fits per fuzzy sample). 
- - ``'threshold'``: hard-label the fuzzy sample over a threshold grid and average (biased; legacy default). + - ``'interpolate'`` (default, new in 1.1): blend ``p * S1 + (1 - p) * S0`` from a fit at + 0 and at 1 (unbiased, exact ``p``; with ``n_rounds=1`` only two fits per fuzzy sample). + - ``'threshold'``: hard-label the fuzzy sample over a threshold grid and average — the + biased sweep of [Breimann25]_; kept for backward-compatible results. n_background_data : None or int, optional The number samples (< 'n_samples') in the background dataset used for the `KernelExplainer`` to reduce @@ -544,19 +545,27 @@ def fit(self, **Fuzzy aggregation strategies** - The ``fuzzy_aggregation`` parameter selects how a soft label ``p`` (in [0, 1]) is turned into a SHAP estimate: + The ``fuzzy_aggregation`` parameter selects between two estimators: * ``'interpolate'`` (default): The fuzzy sample is weighted by *exactly* ``p`` by fitting the model twice (fuzzy sample at 0 -> ``S0``, at 1 -> ``S1``) and blending ``p * S1 + (1 - p) * S0`` (the ``exp_value`` is blended the - same way). This is *unbiased* and, with the default ``n_rounds=1``, the fastest fuzzy estimator (exactly two - fits per fuzzy sample). ``n_rounds > 1`` averages re-seeded fits per round (``random_state + round``), a - Monte-Carlo mean over model variance that converges as ``n_rounds`` grows while staying reproducible for a - fixed ``random_state``. Each fuzzy protein is explained independently against the fixed balanced 0/1 core, with - the other fuzzy proteins excluded from its training data. - * ``'threshold'``: Over an ``n_rounds`` x ``n_selection`` grid (default ``n_rounds=5``) the fuzzy sample is - hard-labeled ``1`` when a per-cell threshold ``<= p`` and the per-cell SHAP matrices are averaged. Because the - threshold grid is non-uniform on (0, 1], the effective positive-fraction is a *biased* approximation of ``p``. - This was the default in v1.0; kept for backward-compatible results. + same way). 
This is *unbiased*. Each fuzzy protein is explained independently against the fixed balanced 0/1 + core, with the other fuzzy proteins excluded from its training data. + * ``'threshold'``: Over an ``n_rounds`` x ``n_selection`` grid the fuzzy sample is hard-labeled ``1`` when a + per-cell threshold ``<= p`` and the per-cell SHAP matrices are averaged — the sweep of [Breimann25]_. Because + the grid is non-uniform on (0, 1], the effective positive-fraction is a *biased* approximation of ``p``. + + **Choosing n_rounds for 'interpolate'** + + Each round re-seeds the model fit (``random_state + round``), so averaging over ``n_rounds`` rounds yields a + Monte-Carlo mean over model variance (reproducible for a fixed ``random_state``): + + * ``n_rounds=1`` -- the exact two-fit point estimate; fastest, but a single model draw (run-to-run spread ~20% + across seeds). + * ``n_rounds=5`` (default) -- adds light averaging (spread ~10%). + * ``n_rounds≈15-20`` -- the averaged estimate stabilizes (run-to-run spread and distance to the converged mean + fall below ~5% on the bundled ``DOM_GSEC`` gamma-secretase data, ~1/sqrt(n_rounds) decay). Use this for a + stable mean; with a fixed ``random_state`` any single run is exactly reproducible regardless. **Setting soft labels** @@ -597,10 +606,6 @@ def fit(self, check_match_labels_target_class_labels(label_target_class=label_target_class, labels=labels) is_selected = check_is_selected(is_selected=is_selected, n_feat=n_feat) check_match_X_is_selected(X=X, is_selected=is_selected) - # Resolve the per-estimator natural default: interpolate is exact in a single round, - # while the threshold sweep (and the non-fuzzy Monte-Carlo) need several rounds.
- if n_rounds is None: - n_rounds = 1 if (fuzzy_labeling and fuzzy_aggregation == "interpolate") else 5 ut.check_number_range(name="n_rounds", val=n_rounds, min_val=1, just_int=True) ut.check_number_range(name="n_background_data", val=n_background_data, min_val=1, just_int=True, accept_none=True) check_match_n_background_data_X(n_background_data=n_background_data, X=X) diff --git a/docs/source/index/release_notes.rst b/docs/source/index/release_notes.rst index 2a23a264..89fac2d1 100644 --- a/docs/source/index/release_notes.rst +++ b/docs/source/index/release_notes.rst @@ -88,12 +88,12 @@ Added - **ShapModel — unbiased fuzzy estimator, now the default** (``[pro]``): ``fit`` gains ``fuzzy_aggregation``, defaulting to the new ``'interpolate'`` estimator. It weights a soft label by *exactly* ``p`` — fitting at 0 (``S0``) and at 1 (``S1``) and blending - ``p * S1 + (1 - p) * S0`` — instead of the biased threshold sweep used in v1.0 - (still available via ``fuzzy_aggregation='threshold'``). ``n_rounds`` now defaults per - estimator (``None`` → ``1`` for ``interpolate``, ``5`` otherwise): the default fuzzy - estimate is the exact two-fit blend, ~2x faster than the v1.0 default on the same cell, - while ``n_rounds > 1`` averages per-round re-seeded fits into a reproducible Monte-Carlo - mean that converges around ``n_rounds ≈ 15–20``. + ``p * S1 + (1 - p) * S0`` — the unbiased alternative to the biased threshold sweep, which + stays available as a first-class option via ``fuzzy_aggregation='threshold'``. For + ``interpolate``, ``n_rounds`` (default ``5``) is a speed/stability dial: ``1`` is the fast + exact two-fit estimate (~2x faster than the threshold default on the same cell), ``5`` adds + light Monte-Carlo averaging, and the mean converges (run-to-run spread below ~5%) around + ``n_rounds ≈ 15–20``; a fixed ``random_state`` keeps every run reproducible. 
**Sequence Analysis** diff --git a/examples/explainable_ai/sm_fit.ipynb b/examples/explainable_ai/sm_fit.ipynb index ad07e21e..502c2ba3 100644 --- a/examples/explainable_ai/sm_fit.ipynb +++ b/examples/explainable_ai/sm_fit.ipynb @@ -21,10 +21,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T11:49:27.299223Z", - "iopub.status.busy": "2026-06-25T11:49:27.298923Z", - "iopub.status.idle": "2026-06-25T11:49:28.740060Z", - "shell.execute_reply": "2026-06-25T11:49:28.739722Z" + "iopub.execute_input": "2026-06-25T13:58:01.221390Z", + "iopub.status.busy": "2026-06-25T13:58:01.221138Z", + "iopub.status.idle": "2026-06-25T13:58:02.634160Z", + "shell.execute_reply": "2026-06-25T13:58:02.633962Z" } }, "outputs": [ @@ -40,105 +40,105 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 entrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_centrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_c
1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS
2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR
3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI
4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH
5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD
6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER
\n" @@ -188,10 +188,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T11:49:28.741284Z", - "iopub.status.busy": "2026-06-25T11:49:28.741204Z", - "iopub.status.idle": "2026-06-25T11:49:29.098530Z", - "shell.execute_reply": "2026-06-25T11:49:29.098307Z" + "iopub.execute_input": "2026-06-25T13:58:02.635215Z", + "iopub.status.busy": "2026-06-25T13:58:02.635138Z", + "iopub.status.idle": "2026-06-25T13:58:02.991000Z", + "shell.execute_reply": "2026-06-25T13:58:02.990774Z" } }, "outputs": [ @@ -200,16 +200,16 @@ "output_type": "stream", "text": [ "SHAP values explain the feature impact for 3 negative and 3 positive samples\n", - "[[-0.09 -0.1 -0.09 -0.1 -0.08]\n", - " [-0.11 -0.13 -0.08 -0.1 -0.08]\n", - " [-0.13 -0.15 -0.04 -0.1 -0.02]\n", - " [ 0.12 0.13 0.05 0.1 0.04]\n", - " [ 0.11 0.12 0.07 0.1 0.08]\n", - " [ 0.12 0.12 0.07 0.1 0.06]]\n", + "[[-0.11 -0.09 -0.08 -0.09 -0.07]\n", + " [-0.12 -0.12 -0.09 -0.09 -0.08]\n", + " [-0.15 -0.13 -0.03 -0.09 -0.02]\n", + " [ 0.13 0.13 0.05 0.09 0.04]\n", + " [ 0.12 0.12 0.09 0.09 0.08]\n", + " [ 0.12 0.12 0.09 0.09 0.06]]\n", "\n", "The expected value approximates the expected model output (average prediction score).\n", "For a binary classification with balanced datasets, it is around 0.5:\n", - "0.5011666666666669\n" + "0.49833333333333363\n" ] } ], @@ -250,10 +250,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T11:49:29.099686Z", - "iopub.status.busy": "2026-06-25T11:49:29.099623Z", - "iopub.status.idle": "2026-06-25T11:49:29.452810Z", - "shell.execute_reply": "2026-06-25T11:49:29.452597Z" + "iopub.execute_input": "2026-06-25T13:58:02.992140Z", + "iopub.status.busy": "2026-06-25T13:58:02.992079Z", + "iopub.status.idle": "2026-06-25T13:58:03.349375Z", + "shell.execute_reply": "2026-06-25T13:58:03.349147Z" } }, "outputs": [ @@ -262,15 +262,15 @@ "output_type": "stream", "text": [ "Reverse sign of SHAP values by changing reference class from 1 to 0\n", - "[[ 0.1 
0.11 0.09 0.09 0.08]\n", - " [ 0.11 0.12 0.09 0.08 0.08]\n", - " [ 0.13 0.14 0.04 0.08 0.02]\n", - " [-0.13 -0.13 -0.06 -0.08 -0.04]\n", - " [-0.12 -0.12 -0.09 -0.08 -0.08]\n", - " [-0.13 -0.13 -0.09 -0.08 -0.07]]\n", + "[[ 0.09 0.11 0.09 0.1 0.07]\n", + " [ 0.11 0.12 0.09 0.1 0.07]\n", + " [ 0.13 0.14 0.04 0.09 0.03]\n", + " [-0.12 -0.13 -0.06 -0.09 -0.04]\n", + " [-0.11 -0.13 -0.08 -0.09 -0.08]\n", + " [-0.11 -0.13 -0.08 -0.09 -0.06]]\n", "\n", "Base value stays around 0.5:\n", - "0.5083333333333335\n" + "0.5036666666666669\n" ] } ], @@ -309,10 +309,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T11:49:29.453905Z", - "iopub.status.busy": "2026-06-25T11:49:29.453845Z", - "iopub.status.idle": "2026-06-25T11:49:30.101114Z", - "shell.execute_reply": "2026-06-25T11:49:30.100821Z" + "iopub.execute_input": "2026-06-25T13:58:03.350474Z", + "iopub.status.busy": "2026-06-25T13:58:03.350413Z", + "iopub.status.idle": "2026-06-25T13:58:03.992200Z", + "shell.execute_reply": "2026-06-25T13:58:03.991916Z" } }, "outputs": [], @@ -342,10 +342,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T11:49:30.102412Z", - "iopub.status.busy": "2026-06-25T11:49:30.102352Z", - "iopub.status.idle": "2026-06-25T11:49:30.748064Z", - "shell.execute_reply": "2026-06-25T11:49:30.747796Z" + "iopub.execute_input": "2026-06-25T13:58:03.993404Z", + "iopub.status.busy": "2026-06-25T13:58:03.993338Z", + "iopub.status.idle": "2026-06-25T13:58:04.639719Z", + "shell.execute_reply": "2026-06-25T13:58:04.639490Z" } }, "outputs": [ @@ -354,10 +354,10 @@ "output_type": "stream", "text": [ "Impact of feature pre-selection\n", - "[[-0.17 -0.19 -0.05 -0.05 0. ]\n", - " [-0.19 -0.2 -0.05 -0.05 0. ]\n", + "[[-0.17 -0.18 -0.05 -0.06 0. ]\n", + " [-0.18 -0.2 -0.06 -0.05 0. ]\n", " [-0.19 -0.2 -0.02 -0.05 0. ]\n", - " [ 0.2 0.2 0.03 0.04 0. ]\n", + " [ 0.19 0.2 0.03 0.05 0. ]\n", " [ 0.19 0.2 0.05 0.05 0. ]\n", " [ 0.19 0.2 0.05 0.05 0. 
]]\n" ] @@ -395,10 +395,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T11:49:30.749257Z", - "iopub.status.busy": "2026-06-25T11:49:30.749196Z", - "iopub.status.idle": "2026-06-25T11:49:31.047878Z", - "shell.execute_reply": "2026-06-25T11:49:31.047598Z" + "iopub.execute_input": "2026-06-25T13:58:04.640920Z", + "iopub.status.busy": "2026-06-25T13:58:04.640846Z", + "iopub.status.idle": "2026-06-25T13:58:05.875887Z", + "shell.execute_reply": "2026-06-25T13:58:05.875653Z" } }, "outputs": [ @@ -407,12 +407,12 @@ "output_type": "stream", "text": [ "First sample is labeled as 0.5 between negative (0) and positive (1)\n", - "[[-0.03 -0.05 -0.03 -0.02 0. ]\n", - " [-0.24 -0.24 -0.04 -0.04 0. ]\n", - " [-0.22 -0.22 -0.02 -0.03 0. ]\n", - " [ 0.17 0.18 0.02 0.03 0. ]\n", - " [ 0.17 0.18 0.03 0.03 0. ]\n", - " [ 0.17 0.18 0.03 0.03 0. ]]\n" + "[[-0.04 -0.04 -0.02 -0.02 0. ]\n", + " [-0.23 -0.23 -0.05 -0.05 0. ]\n", + " [-0.22 -0.22 -0.02 -0.04 0. ]\n", + " [ 0.17 0.17 0.02 0.04 0. ]\n", + " [ 0.17 0.17 0.03 0.04 0. ]\n", + " [ 0.17 0.17 0.03 0.04 0. ]]\n" ] } ], @@ -431,7 +431,7 @@ "id": "445a9b21", "metadata": {}, "source": [ - "By default ``fuzzy_aggregation='interpolate'`` weights the fuzzy sample by *exactly* ``p`` (the cell above already used it: two fits blended as ``p*S1 + (1-p)*S0``). The legacy v1.0 ``'threshold'`` sweep stays available for backward-compatible results, and ``n_rounds > 1`` averages re-seeded fits into a Monte-Carlo mean that converges around ``n_rounds ~ 15-20``:" + "By default ``fuzzy_aggregation='interpolate'`` weights the fuzzy sample by *exactly* ``p`` (the cell above used it: in each round, two fits blended as ``p*S1 + (1-p)*S0``). The published threshold sweep stays available via ``fuzzy_aggregation='threshold'``.
For interpolate, ``n_rounds`` is a speed/stability dial: ``n_rounds=1`` is the fast exact estimate, the default ``5`` adds light Monte-Carlo averaging, and a stable mean is reached around ``n_rounds ~ 15-20``:" ] }, { @@ -440,10 +440,10 @@ "id": "7312979a", "metadata": { "execution": { - "iopub.execute_input": "2026-06-25T11:49:31.049248Z", - "iopub.status.busy": "2026-06-25T11:49:31.049169Z", - "iopub.status.idle": "2026-06-25T11:49:35.259909Z", - "shell.execute_reply": "2026-06-25T11:49:35.259678Z" + "iopub.execute_input": "2026-06-25T13:58:05.877059Z", + "iopub.status.busy": "2026-06-25T13:58:05.876996Z", + "iopub.status.idle": "2026-06-25T13:58:10.104659Z", + "shell.execute_reply": "2026-06-25T13:58:10.104400Z" } }, "outputs": [ @@ -451,7 +451,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Legacy threshold estimate (first, fuzzy sample):\n", + "Threshold-sweep estimate (first, fuzzy sample):\n", "[ 0.03 0.04 -0.03 -0.03 0. ]\n", "\n", "Converged interpolate mean over 15 rounds (first, fuzzy sample):\n", @@ -460,17 +460,17 @@ } ], "source": [ - "# Legacy biased threshold estimator (the v1.0 default), kept for backward-compatible results\n", + "# The published threshold-sweep estimator, available via fuzzy_aggregation=\"threshold\"\n", "sm_threshold = aa.ShapModel(random_state=42)\n", "sm_threshold = sm_threshold.fit(X, labels=labels, is_selected=is_selected,\n", " fuzzy_labeling=True, fuzzy_aggregation=\"threshold\")\n", "\n", - "# Converged interpolate mean: average more re-seeded rounds (default n_rounds=1 is the exact 2-fit blend)\n", + "# Stable interpolate mean: n_rounds=1 is the fast exact blend, ~15-20 the converged mean\n", "sm_converged = aa.ShapModel(random_state=42)\n", "sm_converged = sm_converged.fit(X, labels=labels, is_selected=is_selected,\n", " fuzzy_labeling=True, n_rounds=15)\n", "\n", - "print(\"Legacy threshold estimate (first, fuzzy sample):\")\n", + "print(\"Threshold-sweep estimate (first, fuzzy sample):\")\n", 
"print(sm_threshold.shap_values[0].round(2))\n", "print(\"\\nConverged interpolate mean over 15 rounds (first, fuzzy sample):\")\n", "print(sm_converged.shap_values[0].round(2))" @@ -490,10 +490,10 @@ "id": "0b3fd3f3", "metadata": { "execution": { - "iopub.execute_input": "2026-06-25T11:49:35.261033Z", - "iopub.status.busy": "2026-06-25T11:49:35.260972Z", - "iopub.status.idle": "2026-06-25T11:49:35.557346Z", - "shell.execute_reply": "2026-06-25T11:49:35.557107Z" + "iopub.execute_input": "2026-06-25T13:58:10.105963Z", + "iopub.status.busy": "2026-06-25T13:58:10.105886Z", + "iopub.status.idle": "2026-06-25T13:58:11.345893Z", + "shell.execute_reply": "2026-06-25T13:58:11.345656Z" } }, "outputs": [ @@ -502,10 +502,10 @@ "output_type": "stream", "text": [ "Sample 'Q14802' is labeled as 0.5 between negative (0) and positive (1)\n", - "[[-0.03 -0.03 -0.03 -0.02 0. ]\n", - " [-0.23 -0.23 -0.05 -0.05 0. ]\n", - " [-0.21 -0.23 -0.02 -0.04 0. ]\n", - " [ 0.17 0.17 0.02 0.03 0. ]\n", + "[[-0.03 -0.04 -0.02 -0.02 0. ]\n", + " [-0.23 -0.23 -0.04 -0.05 0. ]\n", + " [-0.22 -0.22 -0.02 -0.03 0. ]\n", + " [ 0.18 0.17 0.02 0.04 0. ]\n", " [ 0.17 0.17 0.03 0.04 0. ]\n", " [ 0.17 0.17 0.03 0.04 0. 
]]\n" ] @@ -543,10 +543,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T11:49:35.558535Z", - "iopub.status.busy": "2026-06-25T11:49:35.558456Z", - "iopub.status.idle": "2026-06-25T11:49:35.560958Z", - "shell.execute_reply": "2026-06-25T11:49:35.560751Z" + "iopub.execute_input": "2026-06-25T13:58:11.346981Z", + "iopub.status.busy": "2026-06-25T13:58:11.346913Z", + "iopub.status.idle": "2026-06-25T13:58:11.349278Z", + "shell.execute_reply": "2026-06-25T13:58:11.349067Z" } }, "outputs": [], diff --git a/tests/unit/shap_model_tests/test_sm_fit.py b/tests/unit/shap_model_tests/test_sm_fit.py index b5ffbd6d..dd72c1e0 100644 --- a/tests/unit/shap_model_tests/test_sm_fit.py +++ b/tests/unit/shap_model_tests/test_sm_fit.py @@ -537,23 +537,17 @@ def test_default_is_interpolate(self): fuzzy_aggregation="interpolate", n_rounds=3).shap_values assert np.array_equal(default, explicit) - # n_rounds=None resolves to the per-estimator natural default: 1 for interpolate, 5 for threshold - def test_n_rounds_none_natural_default(self): - # interpolate default-rounds == explicit n_rounds=1 - auto_i = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( + # n_rounds defaults to a plain 5 for every estimator (no per-estimator magic) + def test_default_n_rounds_is_five(self): + auto = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True).shap_values - explicit_i1 = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( - SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, - fuzzy_aggregation="interpolate", n_rounds=1).shap_values - assert np.array_equal(auto_i, explicit_i1) - # threshold default-rounds == explicit n_rounds=5 - auto_t = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( - SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, - fuzzy_aggregation="threshold").shap_values - explicit_t5 = aa.ShapModel(**ONE_MODEL, verbose=False, 
random_state=9).fit( - SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, - fuzzy_aggregation="threshold", n_rounds=5).shap_values - assert np.array_equal(auto_t, explicit_t5) + explicit5 = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, n_rounds=5).shap_values + assert np.array_equal(auto, explicit5) + # and the default genuinely averages (differs from the single-round fast path) + one_round = aa.ShapModel(**ONE_MODEL, verbose=False, random_state=9).fit( + SMALL_X, labels=SMALL_LABELS_1FUZZY, fuzzy_labeling=True, n_rounds=1).shap_values + assert not np.array_equal(auto, one_round) # fuzzy_aggregation is inert when fuzzy labeling is off (binary path untouched) def test_fuzzy_aggregation_inert_without_fuzzy(self): From e0547477f08309a3f7316592afe72feb41d14cac Mon Sep 17 00:00:00 2001 From: Stephan Breimann Date: Thu, 25 Jun 2026 16:17:59 +0200 Subject: [PATCH 7/8] docs(shap): document per-round seeding (random_state + round) Spells out the seed scheme in the fit Notes, the constructor random_state docstring, and the example notebook: random_state is the initial seed and interpolate re-seeds each round with random_state + round (reproducible for a fixed seed, fresh entropy for None), while the threshold and non-fuzzy paths do not re-seed per round. Adds a notebook cell showing the fixed-seed result is reproducible across runs even with n_rounds>1. Co-Authored-By: Claude Opus 4.8 (1M context) --- aaanalysis/explainable_ai_pro/_shap_model.py | 17 +- examples/explainable_ai/sm_fit.ipynb | 297 +++++++++++-------- 2 files changed, 182 insertions(+), 132 deletions(-) diff --git a/aaanalysis/explainable_ai_pro/_shap_model.py b/aaanalysis/explainable_ai_pro/_shap_model.py index b119a84c..591ea679 100644 --- a/aaanalysis/explainable_ai_pro/_shap_model.py +++ b/aaanalysis/explainable_ai_pro/_shap_model.py @@ -380,7 +380,9 @@ def __init__(self, If ``True``, verbose outputs are enabled. 
random_state : int, optional The seed used by the random number generator. If a positive integer, results of stochastic processes are - consistent, enabling reproducibility. If ``None``, stochastic processes will be truly random. + consistent, enabling reproducibility. If ``None``, stochastic processes will be truly random. For + ``fuzzy_aggregation='interpolate'`` it is the initial seed and each round re-seeds with + ``random_state + round`` (see :meth:`ShapModel.fit` Notes). Notes ----- @@ -555,10 +557,19 @@ def fit(self, per-cell threshold ``<= p`` and the per-cell SHAP matrices are averaged — the sweep of [Breimann25]_. Because the grid is non-uniform on (0, 1], the effective positive-fraction is a *biased* approximation of ``p``. + **Per-round seeding (interpolate only)** + + The constructor ``random_state`` is the initial seed, and ``'interpolate'`` re-seeds **each round** with + ``random_state + round`` (round 0 -> ``random_state``, round 1 -> ``random_state + 1``, ...). So every round + fits a *different* model and ``n_rounds`` averages a Monte-Carlo mean over model variance, yet a fixed + ``random_state`` gives the identical seed sequence and therefore an exactly reproducible result; + ``random_state=None`` draws fresh entropy each round (truly-random, non-reproducible). The ``'threshold'`` + estimator and the non-fuzzy Monte-Carlo path do **not** re-seed per round — they bake ``random_state`` in once, + so their per-round variation comes from the threshold grid, not from the model seed. + **Choosing n_rounds for 'interpolate'** - Each round re-seeds the model fit (``random_state + round``), so ``n_rounds`` averages a Monte-Carlo mean over - model variance (reproducible for a fixed ``random_state``): + Because each round re-seeds, ``n_rounds`` is a speed/stability dial: * ``n_rounds=1`` -- the exact two-fit point estimate; fastest, but a single model draw (run-to-run spread ~20% across seeds). 
diff --git a/examples/explainable_ai/sm_fit.ipynb b/examples/explainable_ai/sm_fit.ipynb index 502c2ba3..a89b494a 100644 --- a/examples/explainable_ai/sm_fit.ipynb +++ b/examples/explainable_ai/sm_fit.ipynb @@ -21,10 +21,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T13:58:01.221390Z", - "iopub.status.busy": "2026-06-25T13:58:01.221138Z", - "iopub.status.idle": "2026-06-25T13:58:02.634160Z", - "shell.execute_reply": "2026-06-25T13:58:02.633962Z" + "iopub.execute_input": "2026-06-25T14:17:08.666558Z", + "iopub.status.busy": "2026-06-25T14:17:08.666056Z", + "iopub.status.idle": "2026-06-25T14:17:11.461965Z", + "shell.execute_reply": "2026-06-25T14:17:11.461767Z" } }, "outputs": [ @@ -40,105 +40,105 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 entrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_centrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_c
1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS
2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR
3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI
4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH
5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD
6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER
\n" @@ -188,10 +188,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T13:58:02.635215Z", - "iopub.status.busy": "2026-06-25T13:58:02.635138Z", - "iopub.status.idle": "2026-06-25T13:58:02.991000Z", - "shell.execute_reply": "2026-06-25T13:58:02.990774Z" + "iopub.execute_input": "2026-06-25T14:17:11.463121Z", + "iopub.status.busy": "2026-06-25T14:17:11.463035Z", + "iopub.status.idle": "2026-06-25T14:17:11.821627Z", + "shell.execute_reply": "2026-06-25T14:17:11.821385Z" } }, "outputs": [ @@ -200,16 +200,16 @@ "output_type": "stream", "text": [ "SHAP values explain the feature impact for 3 negative and 3 positive samples\n", - "[[-0.11 -0.09 -0.08 -0.09 -0.07]\n", - " [-0.12 -0.12 -0.09 -0.09 -0.08]\n", - " [-0.15 -0.13 -0.03 -0.09 -0.02]\n", - " [ 0.13 0.13 0.05 0.09 0.04]\n", - " [ 0.12 0.12 0.09 0.09 0.08]\n", - " [ 0.12 0.12 0.09 0.09 0.06]]\n", + "[[-0.11 -0.1 -0.07 -0.1 -0.06]\n", + " [-0.13 -0.12 -0.08 -0.1 -0.07]\n", + " [-0.15 -0.13 -0.04 -0.1 -0.02]\n", + " [ 0.13 0.13 0.06 0.09 0.04]\n", + " [ 0.12 0.12 0.08 0.1 0.07]\n", + " [ 0.12 0.12 0.08 0.1 0.06]]\n", "\n", "The expected value approximates the expected model output (average prediction score).\n", "For a binary classification with balanced datasets, it is around 0.5:\n", - "0.49833333333333363\n" + "0.4988333333333336\n" ] } ], @@ -250,10 +250,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T13:58:02.992140Z", - "iopub.status.busy": "2026-06-25T13:58:02.992079Z", - "iopub.status.idle": "2026-06-25T13:58:03.349375Z", - "shell.execute_reply": "2026-06-25T13:58:03.349147Z" + "iopub.execute_input": "2026-06-25T14:17:11.822729Z", + "iopub.status.busy": "2026-06-25T14:17:11.822658Z", + "iopub.status.idle": "2026-06-25T14:17:12.176338Z", + "shell.execute_reply": "2026-06-25T14:17:12.176113Z" } }, "outputs": [ @@ -262,15 +262,15 @@ "output_type": "stream", "text": [ "Reverse sign of SHAP values by changing reference class from 1 to 0\n", - "[[ 0.09 
0.11 0.09 0.1 0.07]\n", - " [ 0.11 0.12 0.09 0.1 0.07]\n", - " [ 0.13 0.14 0.04 0.09 0.03]\n", - " [-0.12 -0.13 -0.06 -0.09 -0.04]\n", - " [-0.11 -0.13 -0.08 -0.09 -0.08]\n", - " [-0.11 -0.13 -0.08 -0.09 -0.06]]\n", + "[[ 0.1 0.11 0.09 0.09 0.07]\n", + " [ 0.12 0.12 0.09 0.09 0.08]\n", + " [ 0.13 0.14 0.03 0.09 0.02]\n", + " [-0.13 -0.12 -0.06 -0.09 -0.04]\n", + " [-0.12 -0.12 -0.08 -0.09 -0.07]\n", + " [-0.12 -0.12 -0.08 -0.09 -0.06]]\n", "\n", "Base value stays around 0.5:\n", - "0.5036666666666669\n" + "0.48800000000000027\n" ] } ], @@ -309,10 +309,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T13:58:03.350474Z", - "iopub.status.busy": "2026-06-25T13:58:03.350413Z", - "iopub.status.idle": "2026-06-25T13:58:03.992200Z", - "shell.execute_reply": "2026-06-25T13:58:03.991916Z" + "iopub.execute_input": "2026-06-25T14:17:12.177437Z", + "iopub.status.busy": "2026-06-25T14:17:12.177373Z", + "iopub.status.idle": "2026-06-25T14:17:12.856282Z", + "shell.execute_reply": "2026-06-25T14:17:12.856024Z" } }, "outputs": [], @@ -342,10 +342,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T13:58:03.993404Z", - "iopub.status.busy": "2026-06-25T13:58:03.993338Z", - "iopub.status.idle": "2026-06-25T13:58:04.639719Z", - "shell.execute_reply": "2026-06-25T13:58:04.639490Z" + "iopub.execute_input": "2026-06-25T14:17:12.857698Z", + "iopub.status.busy": "2026-06-25T14:17:12.857598Z", + "iopub.status.idle": "2026-06-25T14:17:13.562372Z", + "shell.execute_reply": "2026-06-25T14:17:13.562118Z" } }, "outputs": [ @@ -355,11 +355,11 @@ "text": [ "Impact of feature pre-selection\n", "[[-0.17 -0.18 -0.05 -0.06 0. ]\n", - " [-0.18 -0.2 -0.06 -0.05 0. ]\n", - " [-0.19 -0.2 -0.02 -0.05 0. ]\n", - " [ 0.19 0.2 0.03 0.05 0. ]\n", - " [ 0.19 0.2 0.05 0.05 0. ]\n", - " [ 0.19 0.2 0.05 0.05 0. ]]\n" + " [-0.19 -0.2 -0.04 -0.06 0. ]\n", + " [-0.19 -0.21 -0.01 -0.05 0. ]\n", + " [ 0.19 0.21 0.03 0.06 0. ]\n", + " [ 0.19 0.21 0.04 0.06 0. 
]\n", + " [ 0.19 0.21 0.04 0.06 0. ]]\n" ] } ], @@ -395,10 +395,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T13:58:04.640920Z", - "iopub.status.busy": "2026-06-25T13:58:04.640846Z", - "iopub.status.idle": "2026-06-25T13:58:05.875887Z", - "shell.execute_reply": "2026-06-25T13:58:05.875653Z" + "iopub.execute_input": "2026-06-25T14:17:13.563510Z", + "iopub.status.busy": "2026-06-25T14:17:13.563442Z", + "iopub.status.idle": "2026-06-25T14:17:14.815617Z", + "shell.execute_reply": "2026-06-25T14:17:14.815377Z" } }, "outputs": [ @@ -408,7 +408,7 @@ "text": [ "First sample is labeled as 0.5 between negative (0) and positive (1)\n", "[[-0.04 -0.04 -0.02 -0.02 0. ]\n", - " [-0.23 -0.23 -0.05 -0.05 0. ]\n", + " [-0.23 -0.23 -0.04 -0.05 0. ]\n", " [-0.22 -0.22 -0.02 -0.04 0. ]\n", " [ 0.17 0.17 0.02 0.04 0. ]\n", " [ 0.17 0.17 0.03 0.04 0. ]\n", @@ -440,10 +440,10 @@ "id": "7312979a", "metadata": { "execution": { - "iopub.execute_input": "2026-06-25T13:58:05.877059Z", - "iopub.status.busy": "2026-06-25T13:58:05.876996Z", - "iopub.status.idle": "2026-06-25T13:58:10.104659Z", - "shell.execute_reply": "2026-06-25T13:58:10.104400Z" + "iopub.execute_input": "2026-06-25T14:17:14.816762Z", + "iopub.status.busy": "2026-06-25T14:17:14.816702Z", + "iopub.status.idle": "2026-06-25T14:17:19.118289Z", + "shell.execute_reply": "2026-06-25T14:17:19.118059Z" } }, "outputs": [ @@ -476,6 +476,45 @@ "print(sm_converged.shap_values[0].round(2))" ] }, + { + "cell_type": "markdown", + "id": "94a8f9c9", + "metadata": {}, + "source": [ + "The constructor ``random_state`` seeds the estimator. For ``'interpolate'`` it is the *initial* seed and each round re-seeds with ``random_state + round``, so ``n_rounds>1`` averages genuinely different model fits yet stays exactly reproducible for a fixed seed (``random_state=None`` instead draws fresh entropy each round). 
The ``'threshold'`` and non-fuzzy paths do not re-seed per round:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8556e656", + "metadata": { + "execution": { + "iopub.execute_input": "2026-06-25T14:17:19.119354Z", + "iopub.status.busy": "2026-06-25T14:17:19.119292Z", + "iopub.status.idle": "2026-06-25T14:17:21.600092Z", + "shell.execute_reply": "2026-06-25T14:17:21.599856Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Reproducible for a fixed random_state (per-round seed = random_state + round): True\n" + ] + } + ], + "source": [ + "# Same random_state -> identical result (reproducible), even with averaging over rounds\n", + "a = aa.ShapModel(random_state=42).fit(X, labels=labels, is_selected=is_selected,\n", + " fuzzy_labeling=True, n_rounds=5).shap_values\n", + "b = aa.ShapModel(random_state=42).fit(X, labels=labels, is_selected=is_selected,\n", + " fuzzy_labeling=True, n_rounds=5).shap_values\n", + "print(\"Reproducible for a fixed random_state (per-round seed = random_state + round):\", \n", + " bool((a == b).all()))" + ] + }, { "cell_type": "markdown", "id": "9d3e32ab", @@ -486,14 +525,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "0b3fd3f3", "metadata": { "execution": { - "iopub.execute_input": "2026-06-25T13:58:10.105963Z", - "iopub.status.busy": "2026-06-25T13:58:10.105886Z", - "iopub.status.idle": "2026-06-25T13:58:11.345893Z", - "shell.execute_reply": "2026-06-25T13:58:11.345656Z" + "iopub.execute_input": "2026-06-25T14:17:21.601218Z", + "iopub.status.busy": "2026-06-25T14:17:21.601154Z", + "iopub.status.idle": "2026-06-25T14:17:22.832303Z", + "shell.execute_reply": "2026-06-25T14:17:22.832045Z" } }, "outputs": [ @@ -502,10 +541,10 @@ "output_type": "stream", "text": [ "Sample 'Q14802' is labeled as 0.5 between negative (0) and positive (1)\n", - "[[-0.03 -0.04 -0.02 -0.02 0. ]\n", + "[[-0.04 -0.04 -0.02 -0.03 0. ]\n", " [-0.23 -0.23 -0.04 -0.05 0. 
]\n", - " [-0.22 -0.22 -0.02 -0.03 0. ]\n", - " [ 0.18 0.17 0.02 0.04 0. ]\n", + " [-0.22 -0.22 -0.02 -0.04 0. ]\n", + " [ 0.17 0.17 0.02 0.04 0. ]\n", " [ 0.17 0.17 0.03 0.04 0. ]\n", " [ 0.17 0.17 0.03 0.04 0. ]]\n" ] @@ -534,7 +573,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "45c1efa0c82d7255", "metadata": { "ExecuteTime": { @@ -543,10 +582,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T13:58:11.346981Z", - "iopub.status.busy": "2026-06-25T13:58:11.346913Z", - "iopub.status.idle": "2026-06-25T13:58:11.349278Z", - "shell.execute_reply": "2026-06-25T13:58:11.349067Z" + "iopub.execute_input": "2026-06-25T14:17:22.833400Z", + "iopub.status.busy": "2026-06-25T14:17:22.833327Z", + "iopub.status.idle": "2026-06-25T14:17:22.837319Z", + "shell.execute_reply": "2026-06-25T14:17:22.837123Z" } }, "outputs": [], From c1d86843cebd77d55e9040117aafa12d1ef5365e Mon Sep 17 00:00:00 2001 From: Stephan Breimann Date: Thu, 25 Jun 2026 16:20:24 +0200 Subject: [PATCH 8/8] docs(shap): use bare `sm` instance in fuzzy example cell The class-abbreviation registry requires the canonical bare abbreviation; reassign `sm` per estimator instead of holding sm_threshold/sm_converged concurrently. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- examples/explainable_ai/sm_fit.ipynb | 293 ++++++++++++++------------- 1 file changed, 149 insertions(+), 144 deletions(-) diff --git a/examples/explainable_ai/sm_fit.ipynb b/examples/explainable_ai/sm_fit.ipynb index a89b494a..9ead2c67 100644 --- a/examples/explainable_ai/sm_fit.ipynb +++ b/examples/explainable_ai/sm_fit.ipynb @@ -21,10 +21,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T14:17:08.666558Z", - "iopub.status.busy": "2026-06-25T14:17:08.666056Z", - "iopub.status.idle": "2026-06-25T14:17:11.461965Z", - "shell.execute_reply": "2026-06-25T14:17:11.461767Z" + "iopub.execute_input": "2026-06-25T14:19:46.705083Z", + "iopub.status.busy": "2026-06-25T14:19:46.704759Z", + "iopub.status.idle": "2026-06-25T14:19:48.092499Z", + "shell.execute_reply": "2026-06-25T14:19:48.092260Z" } }, "outputs": [ @@ -40,105 +40,105 @@ "data": { "text/html": [ "\n", - "\n", + "
\n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", " \n", " \n", "
 entrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_centrysequencelabeltmd_starttmd_stopjmd_ntmdjmd_c
1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS1Q14802MQKVTLGLLVFLAGF...PGETPPLITPGSAQS03759NSPFYYDWHSLQVGGLICAGVLCAMGIIIVMSAKCKCKFGQKS
2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR2Q86UE4MAARSWQDELAQQAE...SPKQIKKKKKARRET05072LGLEPKRYPGWVILVGTGALGLLLLFLLGYGWAAACAGARKKR
3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI3Q969W9MHRLMGVNSTAAAAA...AIWSKEKDKQKGHPL04163FQSMEITELEFVQIIIIVVVMMVMVVVITCLLSHYKLSARSFI
4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH4P05067MLPGLALLLLAAWTA...GYENPTYKFFEQMQN1701723FAEDVGSNKGAIIGLMVGGVVIATVIVITLVMLKKKQYTSIHH
5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD5P14925MAGRARSGLLLLLLG...EEEYSAPLPKPAPSS1868890KLSTEPGSGVSVVLITTLLVIPVLVLLAIVMFIRWKKSRAFGD
6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER6P70180MRSLLLFTFSACVLL...RELREDSIRSHFSVA1477499PCKSSGGLEESAVTGIVVGALLGAGLLMAFYFFRKKYRITIER
\n" @@ -188,10 +188,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T14:17:11.463121Z", - "iopub.status.busy": "2026-06-25T14:17:11.463035Z", - "iopub.status.idle": "2026-06-25T14:17:11.821627Z", - "shell.execute_reply": "2026-06-25T14:17:11.821385Z" + "iopub.execute_input": "2026-06-25T14:19:48.093620Z", + "iopub.status.busy": "2026-06-25T14:19:48.093555Z", + "iopub.status.idle": "2026-06-25T14:19:48.447226Z", + "shell.execute_reply": "2026-06-25T14:19:48.446995Z" } }, "outputs": [ @@ -200,16 +200,16 @@ "output_type": "stream", "text": [ "SHAP values explain the feature impact for 3 negative and 3 positive samples\n", - "[[-0.11 -0.1 -0.07 -0.1 -0.06]\n", - " [-0.13 -0.12 -0.08 -0.1 -0.07]\n", - " [-0.15 -0.13 -0.04 -0.1 -0.02]\n", - " [ 0.13 0.13 0.06 0.09 0.04]\n", - " [ 0.12 0.12 0.08 0.1 0.07]\n", - " [ 0.12 0.12 0.08 0.1 0.06]]\n", + "[[-0.1 -0.11 -0.1 -0.09 -0.06]\n", + " [-0.12 -0.12 -0.09 -0.09 -0.07]\n", + " [-0.14 -0.14 -0.04 -0.09 -0.01]\n", + " [ 0.12 0.13 0.06 0.1 0.03]\n", + " [ 0.12 0.12 0.09 0.1 0.07]\n", + " [ 0.12 0.13 0.09 0.1 0.06]]\n", "\n", "The expected value approximates the expected model output (average prediction score).\n", "For a binary classification with balanced datasets, it is around 0.5:\n", - "0.4988333333333336\n" + "0.4981666666666669\n" ] } ], @@ -250,10 +250,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T14:17:11.822729Z", - "iopub.status.busy": "2026-06-25T14:17:11.822658Z", - "iopub.status.idle": "2026-06-25T14:17:12.176338Z", - "shell.execute_reply": "2026-06-25T14:17:12.176113Z" + "iopub.execute_input": "2026-06-25T14:19:48.448312Z", + "iopub.status.busy": "2026-06-25T14:19:48.448242Z", + "iopub.status.idle": "2026-06-25T14:19:48.801609Z", + "shell.execute_reply": "2026-06-25T14:19:48.801379Z" } }, "outputs": [ @@ -262,15 +262,15 @@ "output_type": "stream", "text": [ "Reverse sign of SHAP values by changing reference class from 1 to 0\n", - "[[ 0.1 0.11 
0.09 0.09 0.07]\n", - " [ 0.12 0.12 0.09 0.09 0.08]\n", - " [ 0.13 0.14 0.03 0.09 0.02]\n", - " [-0.13 -0.12 -0.06 -0.09 -0.04]\n", - " [-0.12 -0.12 -0.08 -0.09 -0.07]\n", - " [-0.12 -0.12 -0.08 -0.09 -0.06]]\n", + "[[ 0.1 0.09 0.09 0.11 0.07]\n", + " [ 0.12 0.11 0.09 0.1 0.08]\n", + " [ 0.14 0.14 0.04 0.1 0.02]\n", + " [-0.13 -0.12 -0.05 -0.09 -0.04]\n", + " [-0.12 -0.11 -0.08 -0.09 -0.07]\n", + " [-0.12 -0.12 -0.08 -0.09 -0.07]]\n", "\n", "Base value stays around 0.5:\n", - "0.48800000000000027\n" + "0.49633333333333357\n" ] } ], @@ -309,10 +309,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T14:17:12.177437Z", - "iopub.status.busy": "2026-06-25T14:17:12.177373Z", - "iopub.status.idle": "2026-06-25T14:17:12.856282Z", - "shell.execute_reply": "2026-06-25T14:17:12.856024Z" + "iopub.execute_input": "2026-06-25T14:19:48.802738Z", + "iopub.status.busy": "2026-06-25T14:19:48.802667Z", + "iopub.status.idle": "2026-06-25T14:19:49.441570Z", + "shell.execute_reply": "2026-06-25T14:19:49.441305Z" } }, "outputs": [], @@ -342,10 +342,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T14:17:12.857698Z", - "iopub.status.busy": "2026-06-25T14:17:12.857598Z", - "iopub.status.idle": "2026-06-25T14:17:13.562372Z", - "shell.execute_reply": "2026-06-25T14:17:13.562118Z" + "iopub.execute_input": "2026-06-25T14:19:49.442802Z", + "iopub.status.busy": "2026-06-25T14:19:49.442739Z", + "iopub.status.idle": "2026-06-25T14:19:50.085482Z", + "shell.execute_reply": "2026-06-25T14:19:50.085260Z" } }, "outputs": [ @@ -354,12 +354,12 @@ "output_type": "stream", "text": [ "Impact of feature pre-selection\n", - "[[-0.17 -0.18 -0.05 -0.06 0. ]\n", - " [-0.19 -0.2 -0.04 -0.06 0. ]\n", - " [-0.19 -0.21 -0.01 -0.05 0. ]\n", - " [ 0.19 0.21 0.03 0.06 0. ]\n", - " [ 0.19 0.21 0.04 0.06 0. ]\n", - " [ 0.19 0.21 0.04 0.06 0. ]]\n" + "[[-0.18 -0.18 -0.05 -0.06 0. ]\n", + " [-0.2 -0.19 -0.05 -0.06 0. ]\n", + " [-0.2 -0.2 -0.02 -0.05 0. 
]\n", + " [ 0.2 0.19 0.03 0.05 0. ]\n", + " [ 0.2 0.19 0.04 0.06 0. ]\n", + " [ 0.2 0.19 0.04 0.06 0. ]]\n" ] } ], @@ -395,10 +395,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T14:17:13.563510Z", - "iopub.status.busy": "2026-06-25T14:17:13.563442Z", - "iopub.status.idle": "2026-06-25T14:17:14.815617Z", - "shell.execute_reply": "2026-06-25T14:17:14.815377Z" + "iopub.execute_input": "2026-06-25T14:19:50.086683Z", + "iopub.status.busy": "2026-06-25T14:19:50.086627Z", + "iopub.status.idle": "2026-06-25T14:19:51.315443Z", + "shell.execute_reply": "2026-06-25T14:19:51.315230Z" } }, "outputs": [ @@ -407,10 +407,10 @@ "output_type": "stream", "text": [ "First sample is labeled as 0.5 between negative (0) and positive (1)\n", - "[[-0.04 -0.04 -0.02 -0.02 0. ]\n", + "[[-0.04 -0.04 -0.02 -0.03 0. ]\n", " [-0.23 -0.23 -0.04 -0.05 0. ]\n", " [-0.22 -0.22 -0.02 -0.04 0. ]\n", - " [ 0.17 0.17 0.02 0.04 0. ]\n", + " [ 0.17 0.18 0.02 0.04 0. ]\n", " [ 0.17 0.17 0.03 0.04 0. ]\n", " [ 0.17 0.17 0.03 0.04 0. ]]\n" ] @@ -440,10 +440,10 @@ "id": "7312979a", "metadata": { "execution": { - "iopub.execute_input": "2026-06-25T14:17:14.816762Z", - "iopub.status.busy": "2026-06-25T14:17:14.816702Z", - "iopub.status.idle": "2026-06-25T14:17:19.118289Z", - "shell.execute_reply": "2026-06-25T14:17:19.118059Z" + "iopub.execute_input": "2026-06-25T14:19:51.316608Z", + "iopub.status.busy": "2026-06-25T14:19:51.316545Z", + "iopub.status.idle": "2026-06-25T14:19:55.550029Z", + "shell.execute_reply": "2026-06-25T14:19:55.549803Z" } }, "outputs": [ @@ -452,7 +452,13 @@ "output_type": "stream", "text": [ "Threshold-sweep estimate (first, fuzzy sample):\n", - "[ 0.03 0.04 -0.03 -0.03 0. ]\n", + "[ 0.03 0.04 -0.03 -0.03 0. ]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\n", "Converged interpolate mean over 15 rounds (first, fuzzy sample):\n", "[-0.03 -0.04 -0.02 -0.02 0. 
]\n" @@ -461,19 +467,18 @@ ], "source": [ "# The published threshold-sweep estimator, available via fuzzy_aggregation=\"threshold\"\n", - "sm_threshold = aa.ShapModel(random_state=42)\n", - "sm_threshold = sm_threshold.fit(X, labels=labels, is_selected=is_selected,\n", - " fuzzy_labeling=True, fuzzy_aggregation=\"threshold\")\n", + "sm = aa.ShapModel(random_state=42)\n", + "sm = sm.fit(X, labels=labels, is_selected=is_selected,\n", + " fuzzy_labeling=True, fuzzy_aggregation=\"threshold\")\n", + "print(\"Threshold-sweep estimate (first, fuzzy sample):\")\n", + "print(sm.shap_values[0].round(2))\n", "\n", "# Stable interpolate mean: n_rounds=1 is the fast exact blend, ~15-20 the converged mean\n", - "sm_converged = aa.ShapModel(random_state=42)\n", - "sm_converged = sm_converged.fit(X, labels=labels, is_selected=is_selected,\n", - " fuzzy_labeling=True, n_rounds=15)\n", - "\n", - "print(\"Threshold-sweep estimate (first, fuzzy sample):\")\n", - "print(sm_threshold.shap_values[0].round(2))\n", + "sm = aa.ShapModel(random_state=42)\n", + "sm = sm.fit(X, labels=labels, is_selected=is_selected,\n", + " fuzzy_labeling=True, n_rounds=15)\n", "print(\"\\nConverged interpolate mean over 15 rounds (first, fuzzy sample):\")\n", - "print(sm_converged.shap_values[0].round(2))" + "print(sm.shap_values[0].round(2))" ] }, { @@ -490,10 +495,10 @@ "id": "8556e656", "metadata": { "execution": { - "iopub.execute_input": "2026-06-25T14:17:19.119354Z", - "iopub.status.busy": "2026-06-25T14:17:19.119292Z", - "iopub.status.idle": "2026-06-25T14:17:21.600092Z", - "shell.execute_reply": "2026-06-25T14:17:21.599856Z" + "iopub.execute_input": "2026-06-25T14:19:55.551133Z", + "iopub.status.busy": "2026-06-25T14:19:55.551072Z", + "iopub.status.idle": "2026-06-25T14:19:58.018771Z", + "shell.execute_reply": "2026-06-25T14:19:58.018569Z" } }, "outputs": [ @@ -529,10 +534,10 @@ "id": "0b3fd3f3", "metadata": { "execution": { - "iopub.execute_input": "2026-06-25T14:17:21.601218Z", - 
"iopub.status.busy": "2026-06-25T14:17:21.601154Z", - "iopub.status.idle": "2026-06-25T14:17:22.832303Z", - "shell.execute_reply": "2026-06-25T14:17:22.832045Z" + "iopub.execute_input": "2026-06-25T14:19:58.019876Z", + "iopub.status.busy": "2026-06-25T14:19:58.019812Z", + "iopub.status.idle": "2026-06-25T14:19:59.257226Z", + "shell.execute_reply": "2026-06-25T14:19:59.257008Z" } }, "outputs": [ @@ -541,10 +546,10 @@ "output_type": "stream", "text": [ "Sample 'Q14802' is labeled as 0.5 between negative (0) and positive (1)\n", - "[[-0.04 -0.04 -0.02 -0.03 0. ]\n", + "[[-0.04 -0.03 -0.02 -0.02 0. ]\n", " [-0.23 -0.23 -0.04 -0.05 0. ]\n", - " [-0.22 -0.22 -0.02 -0.04 0. ]\n", - " [ 0.17 0.17 0.02 0.04 0. ]\n", + " [-0.23 -0.22 -0.02 -0.04 0. ]\n", + " [ 0.17 0.17 0.02 0.03 0. ]\n", " [ 0.17 0.17 0.03 0.04 0. ]\n", " [ 0.17 0.17 0.03 0.04 0. ]]\n" ] @@ -582,10 +587,10 @@ }, "collapsed": false, "execution": { - "iopub.execute_input": "2026-06-25T14:17:22.833400Z", - "iopub.status.busy": "2026-06-25T14:17:22.833327Z", - "iopub.status.idle": "2026-06-25T14:17:22.837319Z", - "shell.execute_reply": "2026-06-25T14:17:22.837123Z" + "iopub.execute_input": "2026-06-25T14:19:59.258321Z", + "iopub.status.busy": "2026-06-25T14:19:59.258258Z", + "iopub.status.idle": "2026-06-25T14:19:59.260533Z", + "shell.execute_reply": "2026-06-25T14:19:59.260367Z" } }, "outputs": [],