# %% [markdown]
# # CAZyme family GH43 Analysis Workflow
#
# Analysis of the Glycoside Hydrolase 43 (GH43) family of CAZymes.
#
# **Requirements:** `pip install tmap`

# %% [markdown]
# ---
# ## Part A: From FASTA to TMAP
#
# Load the GH43 enzymes, a subset taken from dbCAN's full CAZyme dataset.

# %% [markdown]
# ### 1. Load sequences from FASTA

# %%
from itertools import islice

import numpy as np
from Bio import SeqIO

FASTA_PATH = "./cazyme_data/GH43.fa"
MAX_SEQS = 50_000  # hard cap on the number of FASTA records read

with open(FASTA_PATH) as f:
    parser = SeqIO.parse(f, "fasta")
    records = list(islice(parser, MAX_SEQS))
    # Peek one record past the cap so truncation is reported instead of silent:
    # later cells look sequences up by ID and would miss anything cut off here.
    truncated = next(parser, None) is not None

if not records:
    raise ValueError(f"No FASTA records found in {FASTA_PATH}")
if truncated:
    print(f"WARNING: {FASTA_PATH} has more than {MAX_SEQS} records; "
          "only the first MAX_SEQS were loaded")

ids_fasta = [rec.id for rec in records]
sequences = [str(rec.seq) for rec in records]
seq_lengths = [len(s) for s in sequences]

print(f"{len(sequences)} sequences loaded")
print(f"Length range: {min(seq_lengths)} -- {max(seq_lengths)} aa")
print(f"First ID: {ids_fasta[0]}")
print(f"First seq: {sequences[0][:60]}...")
# %% [markdown]
# ### 2. Compute sequence properties
#
# `sequence_properties` computes physicochemical descriptors from amino acid
# sequences, using Biopython's `ProteinAnalysis` (ProtParam).

# %%
from collections import Counter

import numpy as np
import pandas as pd
from Bio.SeqUtils.ProtParam import ProteinAnalysis

# Only sequences made entirely of these 20 residues are analysed.
STANDARD_AAs = set("ACDEFGHIKLMNPQRSTVWY")

# Residue classes used for the composition fractions.
_HYDROPHOBIC = set("AVILMFYW")
_POLAR = set("NQST")
_NEG_CHARGED = set("DE")
# NOTE(review): H is grouped with K/R as a basic residue; its side chain is
# mostly uncharged at pH 7.4 -- confirm this grouping is intended.
_POS_CHARGED = set("KRH")
_SPECIAL = set("C")  # Cysteine can form disulfide bonds, affecting properties

# pH at which the net charge is evaluated. The output column keeps its
# historical name "charge_at_ph7" even though the value is computed at 7.4.
CHARGE_PH = 7.4

PROP_KEYS = [
    "length", "molecular_weight", "isoelectric_point", "gravy",
    "charge_at_ph7", "aromaticity", "aliphatic_index",
    "frac_neg_charged", "frac_pos_charged", "frac_hydrophobic", "frac_polar", "n_cysteines",
]


def _compute_properties(seq, charge_ph):
    """Return the PROP_KEYS descriptors for one non-empty, standard-AA sequence."""
    pa = ProteinAnalysis(seq)
    n = len(seq)
    # Composition is taken from raw residue counts so the fractions are always
    # in [0, 1], independent of Biopython's composition helper (reported as
    # deprecated in recent releases -- confirm against the installed version).
    counts = Counter(seq)

    def frac(residues):
        return sum(counts[a] for a in residues) / n

    return {
        "length": n,
        "molecular_weight": pa.molecular_weight(),
        "isoelectric_point": pa.isoelectric_point(),
        "gravy": pa.gravy(),
        "charge_at_ph7": pa.charge_at_pH(charge_ph),
        "aromaticity": pa.aromaticity(),
        # Ikai-style index: mole-percent of A + 2.9*V + 3.9*(I + L)
        "aliphatic_index": (
            counts["A"] + 2.9 * counts["V"] + 3.9 * (counts["I"] + counts["L"])
        ) / n * 100,
        "frac_neg_charged": frac(_NEG_CHARGED),
        "frac_pos_charged": frac(_POS_CHARGED),
        "frac_hydrophobic": frac(_HYDROPHOBIC),
        "frac_polar": frac(_POLAR),
        "n_cysteines": counts["C"],
    }


def sequence_properties(records, charge_ph=CHARGE_PH):
    """Build a per-sequence table of physicochemical descriptors.

    Parameters
    ----------
    records : iterable
        Objects exposing ``.id`` and ``.seq`` (e.g. the records parsed from the
        FASTA file). The ID is cut at the first ``|``.
    charge_ph : float, optional
        pH at which the net charge is evaluated (default ``CHARGE_PH``).

    Returns
    -------
    pandas.DataFrame
        One row per record with an ``id`` column plus every key in
        ``PROP_KEYS``. Rows for empty sequences, sequences containing
        non-standard residues, or sequences whose analysis raised are kept but
        hold NaN in every property column.
    """
    n_invalid = 0
    rows = []
    for rec in records:
        seq = str(rec.seq).upper()
        row = {"id": rec.id.split("|")[0], **{k: np.nan for k in PROP_KEYS}}
        if seq and set(seq).issubset(STANDARD_AAs):
            try:
                # Computed into a separate dict so a failure part-way through
                # cannot leave a half-filled row behind.
                row.update(_compute_properties(seq, charge_ph))
            except Exception:
                n_invalid += 1
        else:
            n_invalid += 1
        rows.append(row)
    print(f"sequence_properties: {n_invalid}/{len(records)} sequences invalid (non-standard AAs or empty)")
    return pd.DataFrame(rows)


# %%
props_df = sequence_properties(records)

for col in PROP_KEYS:
    values = props_df[col]
    print(f"{col:25s} min={np.nanmin(values):8.1f} max={np.nanmax(values):10.1f}")

# %% [markdown]
# ---
# ## Part B: Pre-computed embeddings + annotations
#
# For production protein TMAPs, we have:
#
# 1. **Embeddings:** ESMC mean-pooled embeddings (.pt torch files).
# 2. **Annotation file:** a custom CAZyme annotation file, as UniProt's
#    auto-generated annotations are not accurate.
#
# ESM-c embeddings can be generated with:
# - `esm` Python package (Meta)
# - ESM Metagenomic Atlas API
# - Local inference with `fair-esm`
#
# NOTE(review): ESM C models appear to be distributed by EvolutionaryScale via
# the `esm` package, while `fair-esm` and the ESM Metagenomic Atlas serve the
# earlier ESM-2 models -- confirm which tool produced the embeddings used here.

# %% [markdown]
# ### 3. Load embeddings
#
# This assumes pre-computed embeddings saved as .pt torch files.
# The paths `./cazyme_data/GH43_esmc_300m.pt` and
# `./cazyme_data/merged_cazyme_annotations.tsv` are dataset-specific --
# replace them with your own.
# %%
import torch
import numpy as np
import pandas as pd

EMB_PATH = "./cazyme_data/GH43_esmc_300m.pt"

# torch.load unpickles the file; weights_only=True restricts it to tensors and
# plain containers so a tampered .pt file cannot run arbitrary code.
# `tensor` is a mapping of sequence key -> embedding, not a single tensor.
tensor = torch.load(EMB_PATH, map_location="cpu", weights_only=True)
keys = list(tensor.keys())
ids = np.array([k.split("|")[0] for k in keys])


def _to_numpy(entry):
    """Return one embedding as a float32 NumPy vector.

    Entries are either plain tensors or dicts with a "per_sequence" key.
    detach() + float32 cast make this work for tensors that track gradients or
    use dtypes NumPy cannot represent (e.g. bfloat16).
    """
    vec = entry if isinstance(entry, torch.Tensor) else entry["per_sequence"]
    return vec.detach().to(torch.float32).numpy()


embeddings = np.stack([_to_numpy(tensor[k]) for k in keys]).astype(np.float32)

n_dup_ids = len(ids) - len(set(ids.tolist()))
if n_dup_ids:
    print(f"WARNING: {n_dup_ids} duplicate IDs after trimming at '|'; "
          "ID-based lookups below will be ambiguous for them")

print(f"Embeddings: {embeddings.shape} ({embeddings.dtype})")
print(f"IDs: {len(ids)}")
print(f"Sample IDs: {ids[:5]}")

# %%
import pandas as pd
import pickle

MERGED_ANN = "./cazyme_data/merged_cazyme_annotations.tsv"
ann_df = pd.read_csv(MERGED_ANN, sep="\t")

# A left merge multiplies rows when the right-hand key is duplicated, which
# would silently break the row-for-row alignment between `meta` and
# `embeddings`. Keep the first row per key and let pandas verify uniqueness.
ann_unique = ann_df.drop_duplicates("genbank_id")
props_unique = props_df.drop_duplicates("id")
print(f"Dropped duplicate keys: annotations={len(ann_df) - len(ann_unique)}, "
      f"properties={len(props_df) - len(props_unique)}")

meta = (
    pd.DataFrame({"id": ids})
    .merge(ann_unique, left_on="id", right_on="genbank_id", how="left",
           validate="many_to_one")
    .merge(props_unique, on="id", how="left", validate="many_to_one")
    .drop(columns=["genbank_id"], errors="ignore")
)

# Unmatched rows (NaN) count as "not characterized"; eq(True) yields a clean
# bool column without relying on fillna's object-dtype downcasting.
meta["is_characterized"] = meta["is_characterized"].eq(True)

assert len(meta) == len(embeddings), (
    f"meta ({len(meta)} rows) is no longer aligned with embeddings ({len(embeddings)} rows)"
)

print(f"Annotation match rate: {meta['domain'].notna().mean():.1%}")
print(f"Mechanism match rate: {meta['Mechanism'].notna().mean():.1%}")
print(f"Characterized rate: {meta['is_characterized'].mean():.1%}")
print(f"Properties rate: {meta['length'].notna().mean():.1%}")
# %% [markdown]
# ### 4. Fit TMAP on embeddings
#
# Euclidean/Cosine are the recommended metrics for ESM-c embeddings.
# USearch handles the dense nearest-neighbor search automatically. This is the
# longest step, so the fitted model is cached as a .pkl file and reloaded on
# later runs instead of being refit.

# %%
from pathlib import Path

from tmap import TMAP
import pickle

TMAP_CACHE = Path("tmap.pkl")
REFIT_TMAP = False  # set True to ignore the cache and fit from scratch

model = None
if TMAP_CACHE.exists() and not REFIT_TMAP:
    # pickle.load can execute arbitrary code: only load a cache you wrote yourself.
    with TMAP_CACHE.open("rb") as f:
        model = pickle.load(f)
    # presumably embedding_ has one row per fitted point -- a mismatch means the
    # cache was built from a different embedding set, so discard it.
    if model.embedding_.shape[0] != len(embeddings):
        print(f"Cached model in {TMAP_CACHE} does not match current embeddings; refitting")
        model = None

if model is None:
    model = TMAP(metric="cosine", n_neighbors=100, seed=42).fit(embeddings)
    with TMAP_CACHE.open("wb") as f:
        pickle.dump(model, f)

print(f"Embedding: {model.embedding_.shape}")
print(f"Tree edges: {model.tree_.edges.shape[0]}")

# %% [markdown]
# ### 5. Visualize with `add_metadata`
#
# Now we combine everything into a single visualization. `add_metadata`
# auto-detects continuous vs categorical.

# %%
viz = model.to_tmapviz()
viz.title = "CAZyme Embedding Space — ESMC-300M"

# Columns shown as hover labels rather than colour layers.
LABEL_COLS = {"organism", "cazy_family_base", "id", "taxid",
              "char_protein_name", "char_uniprot", "char_pdb", "raw_annotation", "ec_numbers"}

for col in meta.columns:
    # Bool columns are numeric to pandas but are meant to be categorical here.
    is_numeric = (pd.api.types.is_numeric_dtype(meta[col])
                  and not pd.api.types.is_bool_dtype(meta[col]))
    if col in LABEL_COLS:
        viz.add_label(col, meta[col].fillna("").astype(str).tolist())
    elif is_numeric:
        viz.add_color_layout(col, meta[col].tolist())
    else:
        viz.add_color_layout(col, meta[col].fillna("nan").astype(str).tolist(), categorical=True)

print(f"Color layers ({len(viz.layouts)}): {[layout.name for layout in viz.layouts]}")
viz.show()

# %%
viz.write_html("protein_map_GH43_cosine.html")

# %%
model.plot_static(color_by=meta["cazy_family"], color_map="tab10", point_size=1)

# %% [markdown]
# ## Part C: Tree Exploration

# %% [markdown]
# ### 6. Distance correlation: cosine distance vs. TMAP hops
# We pick a query sequence and compute the raw cosine distance from it to every
# other sequence. Next we compute the TMAP 'hops' between the query and all the
# other sequences and plot the correlation between the two.

# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

query_id = "ACL75587.1"  # change here to update both this cell and the Diamond cell below
hits = np.flatnonzero(ids == query_id)
if len(hits) == 0:
    raise ValueError(f"{query_id} not found in ids")
QUERY_IDX = int(hits[0])
print(f"Query: {query_id} (idx={QUERY_IDX})")

# TMAP distances from query to all sequences
tmap_dists = np.asarray(model.distances_from(QUERY_IDX), dtype=float)

# Cosine distances — normalise first so dot product == cosine similarity.
# Zero-length vectors are left at zero instead of dividing by zero.
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings_normed = embeddings / np.where(norms > 0, norms, 1.0)
query_normed = embeddings_normed[QUERY_IDX:QUERY_IDX + 1]
cosine_sims = (query_normed @ embeddings_normed.T).flatten()
cosine_dists = 1.0 - cosine_sims

print(f"Self cosine similarity: {cosine_sims[QUERY_IDX]:.6f} (should be 1.0)")

corr_df = pd.DataFrame({
    "tmap_distance": tmap_dists,
    "cosine_distance": cosine_dists,
})

# Non-finite distances (e.g. points not reachable in the tree) would turn the
# Pearson coefficient into NaN, so correlate and plot finite pairs only.
finite = np.isfinite(corr_df["tmap_distance"]) & np.isfinite(corr_df["cosine_distance"])
corr_finite = corr_df[finite]
if not finite.all():
    print(f"Ignoring {(~finite).sum():,} points with non-finite distances")

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(corr_finite["cosine_distance"], corr_finite["tmap_distance"],
           alpha=0.5, s=2, rasterized=True, color='black')
ax.set_xlabel("Cosine Distance (Embeddings)")
ax.set_ylabel("TMAP Distance")
ax.set_title(f"Distance correlation from {query_id}")

corr_coef = corr_finite[["cosine_distance", "tmap_distance"]].corr().iloc[0, 1]
ax.text(0.05, 0.95, f"r = {corr_coef:.3f}",
        transform=ax.transAxes, fontsize=11, verticalalignment='top',
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.show()
print(f"Pearson r = {corr_coef:.4f}")

# %% [markdown]
# ### 7. Blastp-TMAP correlations
# Next we plot the Diamond blastp sequence similarity against the TMAP hops.
# %%
import subprocess
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import spearmanr

# QUERY_IDX, query_id and tmap_dists are set in the cosine distance cell above
BLAST_THREADS = 4


def _run_diamond(args):
    """Run a `diamond` sub-command, printing its stderr before re-raising on failure.

    Output is captured to keep the notebook quiet, so without this the reason
    for a non-zero exit status would be lost.
    """
    try:
        subprocess.run(["diamond", *args], check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as err:
        print(err.stderr)
        raise


id_to_seq = {rid.split("|")[0]: seq for rid, seq in zip(ids_fasta, sequences)}

if query_id not in id_to_seq:
    raise KeyError(
        f"{query_id} has an embedding but no FASTA sequence was loaded for it "
        "(was the FASTA truncated at MAX_SEQS?)"
    )

with open("diamond_db.fasta", "w") as f:
    for sid in ids:
        if sid in id_to_seq:
            f.write(f">{sid}\n{id_to_seq[sid]}\n")

with open("diamond_query.fasta", "w") as f:
    f.write(f">{query_id}\n{id_to_seq[query_id]}\n")

_run_diamond(["makedb", "--in", "diamond_db.fasta", "--db", "diamond_db"])

# Permissive thresholds so that as many database sequences as possible get a hit.
_run_diamond([
    "blastp",
    "--query", "diamond_query.fasta", "--db", "diamond_db",
    "--outfmt", "6", "qseqid", "sseqid", "bitscore", "evalue", "pident",
    "--out", "diamond_results.tsv",
    "--evalue", "100000",
    "--max-target-seqs", "0",
    "--min-score", "0",
    "--threads", str(BLAST_THREADS),
    "--ultra-sensitive",
])

blast_df = pd.read_csv("diamond_results.tsv", sep="\t",
                       names=["qseqid", "sseqid", "bitscore", "evalue", "pident"])

# If a subject appears on several rows, keep its highest-scoring alignment
# rather than whichever row happens to come last in the file.
best_hits = (
    blast_df.sort_values("bitscore", ascending=False)
    .drop_duplicates("sseqid")
)
print(f"Diamond hits: {len(best_hits):,} / {len(ids):,} sequences")

# Similarity: pident as fraction (0–1); NaN for no-hit
id_to_pident = dict(zip(best_hits["sseqid"], best_hits["pident"]))
blast_sim = np.array([id_to_pident.get(sid, np.nan) for sid in ids]) / 100.0

# TMAP similarity: 1 / (1 + distance), maps [0, ∞) → (0, 1]
tmap_d = np.where(np.isfinite(tmap_dists), tmap_dists, np.nan)
tmap_sim = 1.0 / (1.0 + tmap_d)

# Compare only sequences with both values; the query itself is excluded.
mask = np.isfinite(tmap_sim) & np.isfinite(blast_sim)
mask[QUERY_IDX] = False

r, p = spearmanr(tmap_sim[mask], blast_sim[mask])
print(f"Spearman r = {r:.3f} p = {p:.2e} N = {mask.sum():,}")

fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(tmap_sim[mask], blast_sim[mask], s=5, alpha=0.3, rasterized=True, color='black')
ax.set_xlabel("TMAP similarity (1 / (1 + tree distance))")
ax.set_ylabel("Sequence identity of match (pident / 100)")
ax.set_title(f"TMAP vs Diamond BLASTp | Spearman r = {r:.3f}")
plt.tight_layout()
plt.show()

# %% [markdown]
# ### 8. Visualize those two metrics in the TMAP

# %%
# tmap_sim and blast_sim come from the Diamond cell above; all three layers are
# oriented so that higher = more similar to the query.
cosine_sim = (1.0 - cosine_dists).astype(float)
blast_sim_c = blast_sim.copy()  # pident/100

viz.add_color_layout("tmap_similarity_from_query", tmap_sim.tolist(), color="magma")
viz.add_color_layout("cosine_similarity_from_query", cosine_sim.tolist(), color="magma")
viz.add_color_layout("diamond_pident_from_query", blast_sim_c.tolist(), color="magma")

viz.write_html("protein_map_GH43_cosine_with_distances.html")

# Rebind instead of mutating in place so re-running the cell is idempotent.
meta = meta.assign(
    tmap_similarity_from_query=tmap_sim,
    cosine_similarity_from_query=cosine_sim,
    diamond_pident_from_query=blast_sim_c,
)
# %% [markdown]
# ### 9. Plot the distances on the TMAP statically

# %%
import pandas as pd
import matplotlib.pyplot as plt


def plot_static_grouped(model, color_by, color_map=None, top_n=19,
                        point_size=1, savepath=None, **kwargs):
    """Draw ``model.plot_static`` with rare categories collapsed into "Other".

    Parameters
    ----------
    model
        Fitted model exposing ``plot_static(color_by=..., color_map=...,
        point_size=..., ax=..., **kwargs)`` that returns the axes it drew on.
    color_by : array-like
        One value per point. Non-numeric and boolean values are treated as
        categorical; other numeric values are passed through unchanged.
    color_map : str, optional
        Colormap name. Defaults to ``"tab20"`` for categorical data; for
        continuous data ``None`` is forwarded to ``plot_static``.
    top_n : int
        Number of most frequent categories kept; the rest become ``"Other"``.
    point_size : float
        Forwarded to ``plot_static``.
    savepath : str, optional
        If given, the figure is saved there at 600 dpi.
    **kwargs
        Forwarded to ``plot_static``.

    Returns
    -------
    The axes returned by ``model.plot_static``.
    """
    color_by = pd.Series(color_by).reset_index(drop=True)
    # Booleans count as numeric in pandas, but elsewhere in this notebook they
    # are treated as categories, so do the same here.
    is_bool = pd.api.types.is_bool_dtype(color_by)
    is_categorical = is_bool or not pd.api.types.is_numeric_dtype(color_by)

    if is_categorical:
        # Strings for bools; plain objects otherwise, so that "Other" can be
        # inserted even when the input used a pandas Categorical dtype.
        color_by = color_by.astype(str) if is_bool else color_by.astype(object)
        counts = color_by.dropna().value_counts()
        if len(counts) > top_n:
            top_cats = counts.nlargest(top_n).index
            keep = color_by.isin(top_cats) | color_by.isna()
            color_by = color_by.where(keep, "Other")  # rare -> "Other"
        color_map = color_map or "tab20"
        figsize = (9, 9)
    else:
        # continuous plots get a colorbar, which narrows the axes box and —
        # combined with set_aspect("equal", adjustable="datalim") — makes the
        # data limits expand vertically to compensate. Extra width keeps the
        # box proportions (and the visible data range) the same as categorical.
        figsize = (10, 8)

    _, ax = plt.subplots(figsize=figsize)
    ax = model.plot_static(color_by=color_by, color_map=color_map,
                           point_size=point_size, ax=ax, **kwargs)

    # Replace any legend drawn by plot_static with a compact one outside the axes.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
        ax.legend(markerscale=5, frameon=False, fontsize=7,
                  bbox_to_anchor=(1.0, 1.0), loc="upper left", borderaxespad=0.2)

    for spine in ax.spines.values():
        spine.set_visible(True)

    if savepath:
        ax.figure.savefig(savepath, dpi=600, bbox_inches="tight")
    return ax


# %%
# categorical — grouped into top 19 + "Other", tab20 colors, NaNs left as NaN
plot_static_grouped(model, meta["cazy_family"],
                    savepath="tmap_cazy_family_grouped.png")

plot_static_grouped(model, meta["domain"], color_map="tab10",
                    savepath="tmap_domain_grouped.png")

# continuous — passes straight through to plot_static's own colorbar path
plot_static_grouped(model, meta["diamond_pident_from_query"],
                    savepath="diamond_pident_from_query.png")

plot_static_grouped(model, meta["cosine_similarity_from_query"], color_map="magma",
                    savepath="cosine_similarity_from_query.png")

plt.show()

# %% [markdown]
# ---
# ## Classical Methods
#
# PCA, UMAP, and t-SNE applied to ESMC-300M embeddings, colored by taxonomic domain.
# t-SNE uses 50 PCA components as input to keep runtime manageable.
# %% [markdown]
# #### C.1 PCA

# %%
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

# Verify alignment — meta must have exactly one row per embedding
assert len(meta) == len(embeddings), (
    f"meta ({len(meta)} rows) != embeddings ({len(embeddings)} rows); "
    "a merge created duplicate rows — deduplicate meta.drop_duplicates('id') first"
)

# Labels drawn in light grey (missing / uninformative domain).
GREY_LABELS = {"nan", "NaN", "Unknown"}

# Fill actual NaN → "nan" so every point gets an explicit label
domain_series = meta["domain"].fillna("nan").reset_index(drop=True)

unique_domains = sorted(v for v in domain_series.unique() if v not in GREY_LABELS)
_cmap = plt.cm.Set1.colors if len(unique_domains) <= 9 else plt.cm.tab20.colors
domain_color_map = {d: _cmap[i % len(_cmap)] for i, d in enumerate(unique_domains)}
for lbl in GREY_LABELS:
    domain_color_map[lbl] = "#d3d3d3"

# Grey labels drawn first so colored points render on top
_grey_present = [lbl for lbl in sorted(GREY_LABELS) if (domain_series == lbl).any()]
plot_order = _grey_present + unique_domains

print(f"Proteins: {len(domain_series)}")
print(domain_series.value_counts(dropna=False).to_string())

# PCA-50 used as t-SNE input; PCA-2 plotted directly
pca50 = PCA(n_components=50, random_state=42)
X_pca50 = pca50.fit_transform(embeddings)

pca2 = PCA(n_components=2, random_state=42)
X_pca2 = pca2.fit_transform(embeddings)
print(f"\nPCA-2 explained variance: {pca2.explained_variance_ratio_.sum():.1%}")


# %%
def scatter_by_domain(coords, title, xlabel, ylabel):
    """Scatter a 2-D projection with one colour per taxonomic domain.

    Reads ``plot_order``, ``domain_series``, ``domain_color_map`` and
    ``GREY_LABELS`` from the PCA cell above.

    Parameters
    ----------
    coords : array of shape (n_points, 2)
        Projection to plot, row-aligned with ``domain_series``.
    title, xlabel, ylabel : str
        Axes title and axis labels.

    Returns
    -------
    matplotlib.axes.Axes
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    for domain in plot_order:
        in_domain = (domain_series == domain).values
        ax.scatter(coords[in_domain, 0], coords[in_domain, 1],
                   c=[domain_color_map[domain]], s=2,
                   alpha=0.4 if domain in GREY_LABELS else 0.6,
                   label=f"{domain} ({in_domain.sum()})", rasterized=True)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend(markerscale=5, framealpha=0.8, fontsize=9)
    fig.tight_layout()
    return ax


# %%
scatter_by_domain(
    X_pca2,
    "PCA — colored by domain",
    f"PC1 ({pca2.explained_variance_ratio_[0]:.1%})",
    f"PC2 ({pca2.explained_variance_ratio_[1]:.1%})",
)
plt.show()

# %% [markdown]
# #### C.2 UMAP

# %%
import umap

reducer = umap.UMAP(n_components=2, n_neighbors=20, metric="cosine", random_state=42)
X_umap = reducer.fit_transform(embeddings)

scatter_by_domain(X_umap, "UMAP — colored by domain", "UMAP 1", "UMAP 2")
plt.show()

# %% [markdown]
# #### C.3 tSNE

# %%
from sklearn.manifold import TSNE

tsne = TSNE(n_components=2, perplexity=30, random_state=42, verbose=1)
X_tsne = tsne.fit_transform(X_pca50)

scatter_by_domain(X_tsne, "t-SNE — colored by domain", "t-SNE 1", "t-SNE 2")
plt.show()