Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CONTEXT.md
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,14 @@ _Avoid_: shap_mode, use_shap, sample_plot.
The strategy `ShapModel.fit` selects to turn a soft label `p` ∈ (0, 1) into a SHAP estimate when **fuzzy labeling** is active. `"interpolate"` (default, new in v1.1) fits the model twice (fuzzy sample at 0 → `S0`, at 1 → `S1`) and blends `p·S1 + (1−p)·S0` — the **unbiased** exact-`p` estimate. `"threshold"` (the `Breimann25` sweep) hard-labels the fuzzy sample `1` across a non-uniform `n_rounds × n_selection` grid and averages — a **biased** approximation whose effective positive-fraction is the grid's `frac1`, not `p`; kept as a first-class option. Each fuzzy protein is explained independently against the fixed balanced 0/1 **core**, with the other fuzzy proteins excluded from that run's training data. `n_rounds` (default `5`) is interpolate's speed/stability dial: `1` = fast exact two-fit estimate, `5` = light averaging, `≈15–20` = converged Monte-Carlo mean (run-to-run spread <5% on `DOM_GSEC`).
_Avoid_: fuzzy mode, blend mode, soft-label aggregation.

**CPPStructurePlot**:
Public **pro** plotting class in `aaanalysis/feature_engineering_pro/` (abbr `csp`) that paints per-residue CPP / CPP-SHAP **feature impact** onto a 3D protein structure. Its single method `map_structure(df_feat, pdb=…|uniprot=…)` maps each feature to the residues it spans (`get_positions_`, shifted to absolute residue numbers by `start`) and aggregates `col_imp` per residue with the **same normalized-sum** `CPPPlot.profile` uses — never a re-implemented per-position loop. It **reuses** the shared CPP position backend and the `StructurePreprocessor` structure parser (no duplication; a thin chain-by-id Cα/pLDDT extractor is the only new structure code). Modes: `"impact"` (white→`COLOR_SHAP_POS`/`COLOR_SHAP_NEG` ramp with a `sign·sqrt` perceptual transform) and `"plddt"` (AlphaFold confidence palette); focus `"whole"`/`"fade"`/`"zoom"`. Returns a [[StructureView]]. The structure-side companion to `CPPPlot` for the **CPP-SHAP analysis** level.
_Avoid_: structure_plot, plot_structure (the verb-noun method is `map_structure`), CPPStructure (it is a plot class, suffix `Plot`).

**StructureView**:
The thin return wrapper of [[CPPStructurePlot]]`.map_structure`, exposing a **uniform** `show()` / `write_html(path)` / `savefig(path)` / `_repr_html_` surface over its two render backends (interactive `py3Dmol` and static matplotlib `mplot3d`) whose native objects (`py3Dmol.view` vs `Figure`) are otherwise incompatible. A **pure delegator** — no rendering logic, no state beyond the backend object and the mapped `dict_impact` / `max_abs`. The package's first non-`Axes` plotting return type, a deliberate, documented exception to the "return fig/ax" rule (`savefig` is matplotlib-only; `write_html` is the py3Dmol shareable-interactive output).
_Avoid_: view wrapper, plot handle (it is specifically the structure-render delegator), figure (it is not a matplotlib Figure).

### Scale-set vocabulary

**explainable scale set** (`top_explain_n`):
Expand Down
11 changes: 10 additions & 1 deletion aaanalysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
# the ImportError's ``.name`` (reliable for ModuleNotFoundError on the Python 3.11+ floor), never on a
# substring of the message. See .claude/rules/pro-core-boundary.md.
_EXTRA_MODULES = {
"pro": {"shap", "Bio", "biopython", "upsetplot", "UpSetPlot", "requests", "afragmenter"},
"pro": {"shap", "Bio", "biopython", "upsetplot", "UpSetPlot", "requests", "afragmenter", "py3Dmol"},
"embed": {"torch", "transformers", "sentencepiece", "huggingface_hub"},
"dev": {"IPython"},
}
Expand Down Expand Up @@ -171,6 +171,15 @@ def missing_feature_stub(feature_name, error, mode="pro"):
"AnnotationPreprocessor", e, mode="pro")


try:
from .feature_engineering_pro import CPPStructurePlot
__all__.append("CPPStructurePlot")
except ImportError as e:
CPPStructurePlot = None
globals()["CPPStructurePlot"] = missing_feature_stub(
"CPPStructurePlot", e, mode="pro")


try:
from .show_html import display_df
__all__.append("display_df")
Expand Down
18 changes: 18 additions & 0 deletions aaanalysis/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,24 @@ def _folder_path(super_folder, folder_name):
"SAMPLES_REL_NEG": COLOR_REL_NEG
}

# pLDDT confidence palette (AlphaFold-DB), ordered low -> high confidence. Read
# high -> low it is the familiar blue -> cyan -> yellow -> orange ramp used to
# paint per-residue AlphaFold model confidence onto a structure.
COLOR_PLDDT_VERY_LOW = '#FF7D45' # orange, pLDDT < 50
COLOR_PLDDT_LOW = '#FFDB13' # yellow, 50 <= pLDDT < 70
COLOR_PLDDT_CONFIDENT = '#65CBF3' # cyan, 70 <= pLDDT < 90
COLOR_PLDDT_VERY_HIGH = '#0053D6' # blue, pLDDT >= 90
COLOR_STRUCT_MISSING = '#BFBFBF' # gray for residues without a mapped value

# Continuous low -> high ramp consumed by the pLDDT structure colouring.
LIST_COLOR_PLDDT = [COLOR_PLDDT_VERY_LOW, COLOR_PLDDT_LOW,
COLOR_PLDDT_CONFIDENT, COLOR_PLDDT_VERY_HIGH]

DICT_COLOR_PLDDT = {"very_low": COLOR_PLDDT_VERY_LOW,
"low": COLOR_PLDDT_LOW,
"confident": COLOR_PLDDT_CONFIDENT,
"very_high": COLOR_PLDDT_VERY_HIGH}

DICT_COLOR_CAT = {"ASA/Volume": "tab:blue",
"Composition": "tab:orange",
"Conformation": "tab:green",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def _align_atom_values_to_target(target_seq: str,


# II Main Functions
# NOTE: load_structure / _collect_chain_residues / _resolve_best_chain are reused by
# feature_engineering_pro.CPPStructurePlot (no duplication) — they are load-bearing
# beyond StructurePreprocessor; keep their signatures and behaviour stable.
def load_structure(pdb_path):
"""Parse a PDB or mmCIF file and return a Bio.PDB Structure (quiet mode).

Expand Down
19 changes: 19 additions & 0 deletions aaanalysis/feature_engineering_pro/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
Pro feature-engineering plots: CPP feature impact on 3D structure (``pro`` extra).

Public objects: CPPStructurePlot.
Gated behind the ``pro`` extra (needs ``biopython``; ``py3Dmol`` for the
interactive backend, with a matplotlib fallback otherwise). Paints the per-residue
CPP / CPP-SHAP feature impact from a ``df_feat`` onto a protein structure, reusing
the shared CPP position backend (``feature_engineering``) and the structure parser
(``data_handling_pro``). Imported lazily from the top-level package and replaced by
an install-hint stub when ``biopython`` is absent.

See ``.claude/rules/pro-core-boundary.md`` for the pro/core boundary, ``CONTEXT.md``
for domain terms (CPPStructurePlot, StructureView).
"""
from ._cpp_structure_plot import CPPStructurePlot

__all__ = [
"CPPStructurePlot",
]
Empty file.
Empty file.
85 changes: 85 additions & 0 deletions aaanalysis/feature_engineering_pro/_backend/cpp_struct/colors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
This is a script for the backend colour ramps of the CPPStructurePlot class:
the signed white->SHAP-red / white->SHAP-blue impact ramp (with the
``sign * sqrt(|t|)`` perceptual transform that keeps faint impacts visible) and
the AlphaFold-style pLDDT confidence ramp. All colours come from the ``ut``
constants barrel; no hex is hardcoded here.
"""
import numpy as np
import matplotlib.colors as mcolors

import aaanalysis.utils as ut

# White anchor of the signed impact ramp (a neutral, not a domain colour).
_WHITE = "#FFFFFF"

# Reuse the package's SHAP ramps (``sns.light_palette`` via ``plot_get_cmap_``) so the
# 3D structure colours match CPPPlot.profile / feature_map exactly instead of a
# divergent linear interpolation. Both go white (index 0) -> saturated (index 100).
_N_RAMP = 101
_RAMP_POS = [mcolors.to_hex(c) for c in
ut.plot_get_cmap_(cmap=ut.STR_CMAP_SHAP, n_colors=_N_RAMP, only_pos=True)]
_RAMP_NEG = [mcolors.to_hex(c) for c in
ut.plot_get_cmap_(cmap=ut.STR_CMAP_SHAP, n_colors=_N_RAMP, only_neg=True)][::-1]


# I Helper Functions
def _lerp_hex(color_lo, color_hi, frac):
"""Linearly interpolate between two colours in RGB; return a hex string."""
lo = np.asarray(mcolors.to_rgb(color_lo), dtype=np.float64)
hi = np.asarray(mcolors.to_rgb(color_hi), dtype=np.float64)
frac = float(np.clip(frac, 0.0, 1.0))
return mcolors.to_hex(lo + (hi - lo) * frac)


# II Main Functions
def perceptual_transform(t):
"""Signed square-root transform ``sign(t) * sqrt(|t|)`` on ``t`` in ``[-1, 1]``.

Compresses large magnitudes and stretches small ones so faint but real
impacts stay visible; the sign is preserved and the output stays in
``[-1, 1]``.
"""
t = np.asarray(t, dtype=np.float64)
return np.sign(t) * np.sqrt(np.abs(t))


def impact_to_hex(impact, max_abs, color_pos=None, color_neg=None):
"""Map a signed impact to a hex colour on the white->SHAP-pos / white->SHAP-neg ramp.

``impact`` is normalised by ``max_abs`` to ``[-1, 1]`` and passed through the
``sign * sqrt`` transform to get a blend fraction in ``[0, 1]``. By default the
colour is read off the package SHAP ramp (so it matches the 2D CPP plots); a
custom ``color_pos`` / ``color_neg`` falls back to a white->colour interpolation.
Zero / non-finite impact and a non-positive ``max_abs`` map to white.
"""
if max_abs is None or max_abs <= 0 or not np.isfinite(impact) or impact == 0:
return _WHITE
t = float(np.clip(impact / max_abs, -1.0, 1.0))
frac = float(np.sqrt(abs(t))) # |sign * sqrt(t)| -> blend fraction in [0, 1]
idx = int(round(frac * (_N_RAMP - 1)))
if t > 0:
return _RAMP_POS[idx] if color_pos is None else _lerp_hex(_WHITE, color_pos, frac)
return _RAMP_NEG[idx] if color_neg is None else _lerp_hex(_WHITE, color_neg, frac)


def plddt_cmap():
"""Continuous low->high pLDDT colormap built from the ``ut.LIST_COLOR_PLDDT`` ramp."""
return mcolors.LinearSegmentedColormap.from_list("plddt", ut.LIST_COLOR_PLDDT)


def plddt_to_hex(plddt):
"""Map a pLDDT value (0-100) to a hex colour; non-finite -> gray."""
if plddt is None or not np.isfinite(plddt):
return ut.COLOR_STRUCT_MISSING
frac = float(np.clip(plddt / 100.0, 0.0, 1.0))
return mcolors.to_hex(plddt_cmap()(frac))


def color_for_residue(resi, dict_impact, max_abs, plddt, mode,
color_pos=None, color_neg=None):
"""Resolve the colour of a single residue for ``mode`` ('impact' or 'plddt')."""
if mode == "plddt":
return plddt_to_hex(plddt)
return impact_to_hex(dict_impact.get(resi, 0.0), max_abs,
color_pos=color_pos, color_neg=color_neg)
69 changes: 69 additions & 0 deletions aaanalysis/feature_engineering_pro/_backend/cpp_struct/mapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
This is a script for the backend feature->residue mapping of the CPPStructurePlot
class. It reuses the shared CPP feature backend (``get_positions_`` +
``get_df_pos_`` with the normalized-sum semantics) so the per-residue impact is
identical to what the CPP profile / feature map show, never a re-implemented
per-position loop.
"""
import numpy as np

import aaanalysis.utils as ut
# Shared CPP feature backend (the deliberately shared ``_backend/cpp/`` package,
# registered as SHARED_BACKEND_SUBPKGS for feature_engineering).
from aaanalysis.feature_engineering._backend.cpp.utils_feature import (
get_positions_, get_df_pos_)


# I Helper Functions
def _positions_union(feat_positions):
"""Flatten the comma-separated position strings into a sorted list of ints."""
positions = set()
for pos_str in feat_positions:
if not pos_str:
continue
for p in str(pos_str).split(","):
if p != "":
positions.add(int(p))
return sorted(positions)


# II Main Functions
def compute_residue_impact(df_feat=None, col_imp=None, start=1, tmd_len=20,
jmd_n_len=10, jmd_c_len=10, col_cat=None):
"""Map per-feature impact onto absolute residue numbers.

The feature positions are derived with the shared ``get_positions_`` helper
(so ``start`` shifts them to absolute residue numbers) and aggregated with
``get_df_pos_(value_type="sum")``, which divides each feature's value by the
number of positions it spans before summing per position — the same
normalized-sum the CPP profile uses. Summing across scale categories yields
one signed impact per residue.

Returns
-------
dict_impact : dict
``{resi: impact}`` for every residue in ``[start, stop]``.
max_abs : float
Maximum absolute per-residue impact (0.0 if none finite); used to
normalise the colour ramp.
positions_union : list of int
Sorted residue numbers actually spanned by ``df_feat`` (the auto window).
"""
col_cat = ut.COL_CAT if col_cat is None else col_cat
df_feat = df_feat.copy()
features = df_feat[ut.COL_FEATURE].to_list()
feat_positions = get_positions_(features=features, start=start, tmd_len=tmd_len,
jmd_n_len=jmd_n_len, jmd_c_len=jmd_c_len)
df_feat[ut.COL_POSITION] = feat_positions
if col_cat not in df_feat.columns:
df_feat[col_cat] = "feature"
stop = start + jmd_n_len + tmd_len + jmd_c_len - 1
df_pos = get_df_pos_(df_feat=df_feat, col_cat=col_cat, col_val=col_imp,
value_type="sum", start=start, stop=stop)
# Rows = scale categories, columns = positions; sum to one value per residue.
series = df_pos.sum(axis=0)
dict_impact = {int(p): float(v) for p, v in series.items()}
finite = [abs(v) for v in dict_impact.values() if np.isfinite(v)]
max_abs = max(finite) if finite else 0.0
positions_union = _positions_union(feat_positions)
return dict_impact, max_abs, positions_union
132 changes: 132 additions & 0 deletions aaanalysis/feature_engineering_pro/_backend/cpp_struct/render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
This is a script for the backend renderers of the CPPStructurePlot class: the
interactive py3Dmol cartoon (per-residue ``setStyle`` colouring, optional impact-
scaled sticks, fade / zoom focus) and the static matplotlib ``mplot3d`` Cα-scatter
fallback. Both paint the same colour ramp and are wrapped by ``StructureView``.
``py3Dmol`` is imported lazily so the matplotlib fallback works when it is absent.
"""
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import aaanalysis.utils as ut
from .colors import color_for_residue
from .view import StructureView


# I Helper Functions
def py3dmol_available():
"""Return ``True`` if the optional ``py3Dmol`` renderer is importable."""
try:
import py3Dmol # noqa: F401
return True
except ImportError:
return False


def _stick_radius(impact, max_abs, max_radius=0.4):
"""Stick radius proportional to ``|impact| / max_abs`` (0 when undefined)."""
if max_abs is None or max_abs <= 0 or not np.isfinite(impact):
return 0.0
return float(max_radius * min(1.0, abs(impact) / max_abs))


def _read_structure_text(pdb_path):
"""Read raw PDB / CIF text to feed py3Dmol's ``addModel``."""
fmt = "cif" if str(pdb_path).lower().endswith(".cif") else "pdb"
with open(str(pdb_path), "r", encoding="utf-8") as f:
return f.read(), fmt


# II Main Functions
def render_py3dmol(pdb_path, records, dict_impact, max_abs, mode,
focus, window_resis, size_by_impact, chain_id=None,
color_pos=None, color_neg=None, width=600, height=450):
"""Build a py3Dmol cartoon view coloured per residue and wrap it in a StructureView.

``addModel`` loads the whole (possibly multi-chain) structure, so every
per-residue ``setStyle`` / ``zoomTo`` selection is qualified by ``chain_id`` —
otherwise residue number 50 would be coloured on every chain that has one.
"""
import py3Dmol
pdb_text, fmt = _read_structure_text(pdb_path)
view = py3Dmol.view(width=width, height=height)
view.addModel(pdb_text, fmt)
view.setStyle({}, {"cartoon": {"color": ut.COLOR_STRUCT_MISSING}})
present_resis = {res["resi"] for res in records}
for res in records:
resi = res["resi"]
color = color_for_residue(resi, dict_impact, max_abs, res["plddt"], mode,
color_pos=color_pos, color_neg=color_neg)
in_window = window_resis is None or resi in window_resis
cartoon = {"color": color}
if focus == "fade" and not in_window:
cartoon["opacity"] = 0.2
style = {"cartoon": cartoon}
if size_by_impact and mode == "impact":
radius = _stick_radius(dict_impact.get(resi, 0.0), max_abs)
if radius > 0:
style["stick"] = {"radius": radius, "color": color}
sel = {"resi": str(resi)}
if chain_id is not None:
sel["chain"] = chain_id
view.setStyle(sel, style)
# Only zoom to window residues that actually exist in the structure, else the
# camera silently fails to focus on an empty selection.
zoom_resis = sorted((window_resis or set()) & present_resis)
if focus == "zoom" and zoom_resis:
sel = {"resi": [str(r) for r in zoom_resis]}
if chain_id is not None:
sel["chain"] = chain_id
view.zoomTo(sel)
else:
view.zoomTo()
view.setBackgroundColor("white")
return StructureView(backend="py3dmol", view=view, dict_impact=dict_impact,
max_abs=max_abs, mode=mode)


def render_mpl(records, dict_impact, max_abs, mode, focus, window_resis,
size_by_impact, color_pos=None, color_neg=None, figsize=(6, 6)):
"""Build a matplotlib ``mplot3d`` Cα scatter coloured per residue; wrap as StructureView."""
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111, projection="3d")
xs, ys, zs, rgba, sizes = [], [], [], [], []
win_coords = []
for res in records:
coord = res["coord"]
if not np.all(np.isfinite(coord)):
continue
resi = res["resi"]
color = color_for_residue(resi, dict_impact, max_abs, res["plddt"], mode,
color_pos=color_pos, color_neg=color_neg)
in_window = window_resis is None or resi in window_resis
alpha = 1.0 if (focus != "fade" or in_window) else 0.15
xs.append(coord[0])
ys.append(coord[1])
zs.append(coord[2])
rgba.append(mcolors.to_rgba(color, alpha=alpha))
size = 30.0
if size_by_impact and mode == "impact":
size = 30.0 + 220.0 * _stick_radius(dict_impact.get(resi, 0.0), max_abs)
sizes.append(size)
if in_window:
win_coords.append(coord)
if xs:
# Backbone trace plus the per-residue Cα scatter.
ax.plot(xs, ys, zs, color=ut.COLOR_STRUCT_MISSING, linewidth=0.8, alpha=0.6)
ax.scatter(xs, ys, zs, c=rgba, s=sizes, depthshade=False, edgecolors="none")
if focus == "zoom" and win_coords:
win = np.vstack(win_coords)
lo, hi = win.min(axis=0), win.max(axis=0)
pad = 0.1 * np.maximum(hi - lo, 1.0)
ax.set_xlim(lo[0] - pad[0], hi[0] + pad[0])
ax.set_ylim(lo[1] - pad[1], hi[1] + pad[1])
ax.set_zlim(lo[2] - pad[2], hi[2] + pad[2])
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
label = "pLDDT" if mode == "plddt" else "CPP feature impact"
ax.set_title(f"{label} on structure")
return StructureView(backend="mpl", fig=fig, ax=ax, dict_impact=dict_impact,
max_abs=max_abs, mode=mode)
Loading
Loading