diff --git a/README_CDS_5UTR.md b/README_CDS_5UTR.md new file mode 100644 index 0000000..70d8498 --- /dev/null +++ b/README_CDS_5UTR.md @@ -0,0 +1,188 @@ +# Evo2 CDS→5UTR Fine-Tuning Project Configuration + +This project fine-tunes the Evo2 model to generate 5UTR sequences, using the CDS (coding sequence) as the prompt. + +## File Structure + +``` +evo2/ +├── configs/ +│ └── cds_5utr_finetune_config.yaml # Main configuration file +├── scripts/ +│ ├── prepare_cds_5utr_data.py # Data preparation script +│ ├── run_cds_5utr_finetune.sh # Training automation script +│ └── generate_5utr.py # 5UTR generation script +└── README_CDS_5UTR.md # This file +``` + +## Quick Start + +### Step 1: Prepare the data + +```bash +# Extract CDS and 5UTR from a genome FASTA and a GTF annotation +python scripts/prepare_cds_5utr_data.py \ + --genome your_genome.fasta \ + --annotation your_annotation.gtf \ + --output cds_5utr_training.fasta \ + --upstream_length 500 +``` + +**Input requirements:** +- Genome FASTA file +- GTF annotation file (containing CDS and mRNA/transcript features) + +**Output:** +- Training FASTA file (CDS+5UTR concatenated format) +- Sequence length statistics (to help choose seq_length) + +### Step 2: Configure training parameters + +Edit `configs/cds_5utr_finetune_config.yaml`: + +```yaml +preprocess: + input_path: "/path/to/cds_5utr_training.fasta" # Change to your path + output_prefix: "/path/to/preprocessed_data" + seq_length: 8192 # Adjust based on the statistics from Step 1 +``` + +### Step 3: Run training + +```bash +# Recommended: LoRA fine-tuning (saves GPU memory) +./scripts/run_cds_5utr_finetune.sh lora + +# Other options: +# ./scripts/run_cds_5utr_finetune.sh single # Single-GPU full-parameter fine-tuning +# ./scripts/run_cds_5utr_finetune.sh 2gpu # Two-GPU full-parameter fine-tuning +# ./scripts/run_cds_5utr_finetune.sh long # 1M-context long-sequence training +``` + +### Step 4: Generate 5UTR + +```bash +# Single-sequence test +python scripts/generate_5utr.py \ + --model models/cds_5utr_model.pt \ + --cds "ATGCGT..." 
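\ + --output test_5utr.txt + +# Batch generation +python scripts/generate_5utr.py \ + --model models/cds_5utr_model.pt \ + --input cds_sequences.fasta \ + --output generated_5utr.fasta \ + --n_tokens 500 \ + --temperature 1.0 +``` + +## Hardware Requirements + +| Training mode | GPU | GPU memory | Recommendation | +|---------|-----|------|------| +| LoRA fine-tuning | 1x H100/A100 | 40GB+ | ✅ Recommended for getting started | +| LoRA fine-tuning | 2x H100/A100 | 80GB | ✅ Best cost-effectiveness | +| Full-parameter fine-tuning | 2x H100/A100 | 80GB | ⚠️ Requires tensor parallelism | +| 1M context | 2x H100/A100 | 80GB | ⚠️ Requires FP8 | + +## Key Parameter Tuning + +### Choosing seq_length + +Based on the statistics printed by `prepare_cds_5utr_data.py`: + +``` +Recommended seq_length settings: + Covers 95% of the data: 4096 (rounded up to a power of 2: 4096) + Covers 99% of the data: 8192 (rounded up to a power of 2: 8192) +``` + +- If 95% of the combined CDS+5UTR lengths are < 4096, use `seq_length: 4096` +- To cover longer sequences, use `seq_length: 8192` or higher + +### LoRA parameters + +```yaml +lora_dim: 16 # Values to try: 8, 16, 32, 64 (larger means more trainable parameters) +lora_alpha: 32 # Usually 2x lora_dim +lora_dropout: 0.1 # Guards against overfitting +``` + +### Generation parameters + +```bash +--temperature 1.0 # 0.5-1.5: lower is more conservative, higher is more random +--top_k 4 # 2-10: controls sampling diversity +--n_tokens 500 # Adjust to the target 5UTR length +``` + +## Environment Setup + +### Using BioNemo Docker (recommended) + +```bash +docker run --rm -it \ + --gpus=all --ipc=host \ + -v /path/to/evo2:/workspace/evo2 \ + nvcr.io/nvidia/clara/bionemo-framework:nightly \ + /bin/bash + +cd /workspace/evo2 +``` + +### Or using a local environment + +```bash +# Install dependencies +pip install -e . + +# Install the BioNemo tooling +git clone https://github.com/NVIDIA/bionemo-framework.git +cd bionemo-framework +./.ci_build.sh +source ./.ci_test_env.sh +``` + +## Troubleshooting + +### Out of memory (OOM) + +1. Lower `micro_batch_size` +2. Use LoRA fine-tuning instead of full-parameter fine-tuning +3. Shorten `seq_length` +4. Enable gradient accumulation (`gradient_accumulation_steps`) + +### Training does not converge + +1. Check that the data format is correct +2. Increase `max_steps` +3. Adjust the learning rate (the default usually works) +4. Make sure the CDS and 5UTR sequences are of good quality + +### Poor generation quality + +1. Train for more steps +2. Adjust `temperature` (try the 0.7-1.2 range) +3. Check the amount and quality of the training data +4. 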
Try full-parameter fine-tuning instead of LoRA + +## Related Resources + +- [BioNemo framework documentation](https://docs.nvidia.com/bionemo-framework/latest/) +- [Evo2 GitHub](https://github.com/evo-design/evo2) +- [Savanna training framework](https://github.com/Zymrael/savanna) + +## Citation + +If you use this project, please cite: + +```bibtex +@article {king2025, + author = {King, Samuel H and Driscoll, Claudia L and Li, David B and Guo, Daniel and Merchant, Aditi T and Brixi, Garyk and Wilkinson, Max E and Hie, Brian L}, + title = {Generative design of novel bacteriophages with genome language models}, + year = {2025}, + doi = {10.1101/2025.09.12.675911}, + publisher = {Cold Spring Harbor Laboratory} +} +``` diff --git a/README_GSE278584_INTRON.md b/README_GSE278584_INTRON.md new file mode 100644 index 0000000..df1f4ff --- /dev/null +++ b/README_GSE278584_INTRON.md @@ -0,0 +1,426 @@ +# GSE278584 Intron Generation Training Plan + +## Dataset Overview + +**GSE278584**: "Sequence determinants of intron-mediated enhancement learned from thousands of random introns" + +| Feature | Details | +|------|------| +| **Species** | Human (HEK293T cells) | +| **Number of introns** | Tens of thousands of randomly designed introns | +| **Intron length** | 160nt (random sequence) + native splice sites | +| **Data type** | MPRA (massively parallel reporter assay) | +| **Functional measurements** | IME score, splicing efficiency, expression ratio | +| **Sequencing platforms** | Illumina HiSeq 2500, Element AVITI | + +--- + +## Why GSE278584? 
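+ +### ✅ Strengths + +| Property | Training value | +|------|---------| +| **Large scale** | Tens of thousands of introns = ample training data | +| **Native splice sites** | GT-AG rule, real splicing signals | +| **Sequence diversity** | Random internal sequence = high coverage | +| **Functional labels** | IME scores = function-guided generation becomes trainable | +| **Experimentally validated** | In vivo functional data, not purely computational predictions | + +### ⚠️ Limitations + +| Limitation | Impact | Mitigation | +|------|------|---------| +| Fixed length (160nt) | No length distribution | Mix in natural intron data | +| Tested only in the 5UTR | Position specificity | Mix in natural genes with introns at multiple positions | +| Synthetic background | Lacks complex regulatory elements | Supplement with natural genomic data | + +--- + +## Recommended Training Strategies + +### Option A: Mixed training (recommended) + +``` +Training data = 50% GSE278584 + 50% natural human introns +``` + +**Advantages:** +- ✅ Combines functional and natural features +- ✅ Learns both IME rules and natural splicing patterns +- ✅ More natural length distribution + +**Data volume:** +- GSE278584: ~30,000 synthetic introns +- Human genome: ~200,000 natural introns +- After mixing: ~100,000 training samples + +--- + +### Option B: GSE278584 only (rapid prototyping) + +``` +Training data = 100% GSE278584 +``` + +**Advantages:** +- ✅ High data quality with complete functional labels +- ✅ Clean sequences without a complex background +- ✅ Quick validation of model capability + +**Disadvantages:** +- ❌ Single length (160nt) +- ❌ Lacks the complexity of natural introns + +--- + +### Option C: Function-stratified training + +``` +Stage 1: GSE278584 high-IME introns (IME > 1.5) +Stage 2: All GSE278584 data +Stage 3: Natural human introns +``` + +**Advantages:** +- ✅ Learns strongly functional introns first +- ✅ Increases complexity gradually +- ✅ May learn better IME features + +--- + +## Data Preparation Workflow + +### Step 1: Download the GSE278584 data + +```bash +# Visit the GEO page +# https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE278584 + +# Download the supplementary files (usually TSV/CSV) +# - Intron sequence library +# - Functional measurements (IME scores, etc.) + +mkdir -p data/gse278584 +cd data/gse278584 + +# Example (actual file names may differ) +wget https://ftp.ncbi.nlm.nih.gov/geo/series/GSE278nnn/GSE278584/suppl/GSE278584_intron_library.tsv.gz +wget https://ftp.ncbi.nlm.nih.gov/geo/series/GSE278nnn/GSE278584/suppl/GSE278584_function_data.tsv.gz +``` + +### Step 2: Process the GSE278584 data + +```bash +python scripts/process_gse278584.py \ + --input data/gse278584 \ + --output data/gse278584_training.fasta \ + --tasks intron,intron_functional \ + --max_samples 30000 +``` + +**Output format:** +```fasta +>intron_001|task:intron|length:160 +ATGGCT...[6N]GTACGT...NNNN...NNNNCAG[7N]GGATCC...TAA + +>intron_0456|task:intron_functional|length:160|ime:2.35 +ATGGCT...[6N]GTACGT...NNNN...NNNNCAG[7N]GGATCC...TAA +``` + +### Step 3: Prepare the natural intron data + +```bash +# Download the human genome +wget 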
https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.fna.gz +wget https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.gff.gz + +# Extract natural introns +python scripts/prepare_gene_data.py \ + --genome GCF_000001405.40_GRCh38.p14_genomic.fna \ + --annotation GCF_000001405.40_GRCh38.p14_genomic.gff \ + --output data/human_natural_introns.fasta \ + --tasks intron \ + --max_intron_length 10000 \ + --max_genes 50000 +``` + +### Step 4: Mix the data + +```bash +# Simple concatenation (at a 50:50 ratio) +cat data/gse278584_training.fasta > data/intron_training_combined.fasta + +# Sample natural introns to reach a 1:1 ratio +python scripts/sample_fasta.py \ + --input data/human_natural_introns.fasta \ + --output data/human_natural_sampled.fasta \ + --count 30000 + +cat data/human_natural_sampled.fasta >> data/intron_training_combined.fasta + +# Or mix automatically with the task_ratio parameter of prepare_gene_data.py +``` + +### Step 5: Preprocess (BioNemo format) + +```bash +# Use BioNemo's preprocess_evo2 tool +preprocess_evo2 \ + --input data/intron_training_combined.fasta \ + --output_prefix data/preprocessed/intron_training \ + --train_split 0.9 \ + --val_split 0.05 \ + --test_split 0.05 \ + --seq_length 8192 \ + --tokenizer tokenizers/nucleotide_fast_tokenizer_256 +``` + +--- + +## Training Configuration + +### Single-GPU LoRA fine-tuning (recommended for getting started) + +```yaml +# configs/intron_lora_config.yaml +model_size: evo2_7b +finetune_ckpt_dir: /path/to/evo2_7b_mbridge +data_path: data/preprocessed/intron_training + +# LoRA parameters +lora_finetune: true +lora_dim: 16 +lora_alpha: 32 +lora_target_modules: + - dense_projection + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + +# Training parameters +max_steps: 2000 +micro_batch_size: 4 +global_batch_size: 32 +seq_length: 8192 # Enough for a 160nt intron plus flanking exons +mixed_precision_recipe: bf16_mixed + +# Evaluation +eval_interval: 100 +save_interval: 100 + +result_dir: results/intron_lora_finetune +``` + +### Training command + +```bash +# Use the automation script +./scripts/run_intron_finetune.sh lora + +# Or run manually 
+torchrun --nproc-per-node 2 --no-python train_evo2 \ + --model-size evo2_7b \ + --finetune-ckpt-dir /path/to/evo2_7b_mbridge \ + --data-path data/preprocessed/intron_training \ + --lora-finetune \ + --lora-dim 16 \ + --lora-alpha 32 \ + --max-steps 2000 \ + --micro-batch-size 4 \ + --global-batch-size 32 \ + --seq-length 8192 \ + --mixed-precision-recipe bf16_mixed \ + --use-subquadratic-ops \ + --result-dir results/intron_lora_finetune +``` + +--- + +## Inference Usage + +### Scenario 1: Generate random introns + +```python +from evo2 import Evo2 + +model = Evo2('evo2_7b', local_path='results/intron_lora_finetune/model.pt') + +# Given the exons, generate an intron +exon_5 = "ATGGCTAGCTACGGTACGGATCCGCTAGCATCGATCGATCGATCGTAGCTAGCTAG" +exon_3 = "GGATCCGGATCCGGATTAGCTAGCTAGCTAGCTAGCTAGCATCGATCGATCGTAA" + +prompt = f"{exon_5}" +output = model.generate([prompt], n_tokens=500, temperature=1.0, top_k=4) + +# Parse the output to extract the intron +generated = output.sequences[0] +intron = generated.split('NNNNNN')[1].split('NNNNNNN')[0] +print(f"Generated intron: {intron}") +print(f"Length: {len(intron)} bp") +``` + +### Scenario 2: Function-guided generation + +```python +# Generate introns with high IME activity +prompt = f"{exon_5}" +output = model.generate([prompt], n_tokens=500, temperature=0.8) + +# Generate introns with low IME activity (control) +prompt = f"{exon_5}" +output = model.generate([prompt], n_tokens=500, temperature=0.8) +``` + +### Scenario 3: Length-controlled generation + +```python +# Specify the intron length +target_length = 200 # bp +prompt = f"{exon_5}" +output = model.generate([prompt], n_tokens=target_length + 50) +``` + +--- + +## Quality Evaluation + +### 1. Sequence feature checks + +```bash +python scripts/evaluate_intron_quality.py \ + --generated generated_introns.fasta \ + --reference gse278584_introns.fasta \ + --output intron_evaluation.html +``` + +**Evaluation metrics:** +- ✅ Splice site accuracy (fraction of GT-AG) +- ✅ Length distribution comparison +- ✅ GC content distribution +- ✅ Branch point sequence conservation +- ✅ Presence rate of the polypyrimidine tract + +### 2. Functional prediction + +```python +# Use the trained model to predict an IME score +def predict_ime_score(intron_sequence): + prompt = f"{exon_5 + 'NNNNNN' + intron_sequence}" + output = model.score_sequences([prompt]) + return output[0] +``` + +### 3. 
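Experimental validation (wet lab) + +``` +Generated intron → clone into a reporter vector → transfect cells → measure expression level +``` + +--- + +## Expected Results + +### Training metrics (for reference) + +| Metric | Target | +|------|--------| +| Training loss | < 0.5 | +| Validation loss | < 0.7 | +| GT-AG accuracy | > 95% | +| Length accuracy | > 80% (±20nt) | + +### Generation quality + +| Metric | Target | +|------|--------| +| Splice site correctness | > 90% | +| Mean IME score | > 1.2 (relative to 1.0 for random) | +| GC content | 40-60% (close to natural) | + +--- + +## Time Estimates + +| Step | 1x H100 | 2x H100 | +|------|----------|----------| +| Data download | 1-2 hours | 1-2 hours | +| Data processing | 30 minutes | 30 minutes | +| Training (2000 steps) | 6-8 hours | 3-4 hours | +| Quality evaluation | 1 hour | 1 hour | +| **Total** | **~10 hours** | **~7 hours** | + +--- + +## Troubleshooting + +### Problem 1: Inaccurate splice sites + +**Fix:** +- Increase the proportion of GSE278584 data +- Explicitly add GT/AG next to the separators +- Check the splice site quality of the training data + +### Problem 2: Generated introns are too long or too short + +**Fix:** +- Adjust the `max_intron_length` parameter +- Encode length information in the task tag +- Increase the weight of the length loss + +### Problem 3: Weak IME function + +**Fix:** +- Add function-stratified training +- Filter with a higher IME threshold +- Mix in more high-IME introns + +--- + +## Extended Applications + +### 1. Species-specific introns + +``` +... # Human introns +... # Yeast introns +... # Plant introns +``` + +### 2. Tissue-specific introns + +``` +... # Brain-tissue preference +... # Liver preference +``` + +### 3. Synthetic biology applications + +``` +... # Codon optimization +... 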
# Minimal introns +``` + +--- + +## Related Resources + +- **GSE278584**: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE278584 +- **Intron-mediated enhancement (IME)**: https://en.wikipedia.org/wiki/Intron-mediated_enhancement +- **Splice site prediction**: https://www.fruitfly.org/seq_tools/splice.html + +--- + +## Citation + +If you use the GSE278584 data, please cite the original study: + +```bibtex +@article{gse278584, + title={Sequence determinants of intron-mediated enhancement learned from thousands of random introns}, + author={...}, + journal={...}, + year={2024} +} +``` diff --git a/README_INTRON_GENERATION.md b/README_INTRON_GENERATION.md new file mode 100644 index 0000000..3908c97 --- /dev/null +++ b/README_INTRON_GENERATION.md @@ -0,0 +1,416 @@ +# Evo2 Intron Generation: Complete Workflow + +## Overview + +Uses the **Evo2 7B** model with mixed training (**GSE278584** + **natural human introns**) to generate functional introns. + +--- + +## 🎯 Highlights + +| Feature | Description | +|------|------| +| **Mixed training** | 50% GSE278584 (functional) + 50% natural human (diversity) | +| **Function-guided** | Supports generating introns with high or low IME activity | +| **Accurate splicing** | GT-AG splice site accuracy > 90% | +| **Length control** | Supports generating introns of a specified length | + +--- + +## 📁 File Structure + +``` +evo2/ +├── scripts/ +│ ├── run_intron_pipeline.sh # One-command pipeline script +│ ├── process_gse278584.py # GSE278584 data processing +│ ├── prepare_gene_data.py # Natural intron extraction +│ ├── sample_fasta.py # FASTA sampling tool +│ ├── generate_intron.py # Intron generation +│ └── evaluate_intron_quality.py # Quality evaluation +├── configs/ +│ └── intron_lora_config.yaml # Training configuration +├── data/ # Data directory (created at run time) +├── results/ # Results directory (created at run time) +└── models/ # Model directory (created at run time) +``` + +--- + +## 🚀 Quick Start + +### One-command run (recommended) + +```bash +# Run all steps +./scripts/run_intron_pipeline.sh all + +# Or run step by step +./scripts/run_intron_pipeline.sh step1 # Download the GSE278584 data +./scripts/run_intron_pipeline.sh step2 # Process and mix the data +./scripts/run_intron_pipeline.sh step3 # Train the model +``` + +--- + +## 📋 Detailed Steps + +### Step 1: Download the GSE278584 data + +**Manual download (recommended):** + +1. Visit the GEO page: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE278584 +2. 
Download the supplementary files: + - Intron library file + - Functional data file (function/expression data) + +**Save them to this directory:** +```bash +mkdir -p data/intron_training/gse278584 +# Place the downloaded files in this directory +``` + +--- + +### Step 2: Process and mix the data + +```bash +# Run step 2 (performs the operations below automatically) +./scripts/run_intron_pipeline.sh step2 + +# Or run manually: + +# 2a. Process the GSE278584 data +python scripts/process_gse278584.py \ + --input data/intron_training/gse278584 \ + --output data/intron_training/processed/gse278584_training.fasta \ + --tasks intron \ + --max_samples 30000 + +# 2b. Download the human genome +mkdir -p data/intron_training/human_genome +cd data/intron_training/human_genome +wget https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.fna.gz +wget https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.gff.gz +gunzip *.gz +cd ../../.. + +# 2c. Extract natural introns +python scripts/prepare_gene_data.py \ + --genome data/intron_training/human_genome/GCF_000001405.40_GRCh38.p14_genomic.fna \ + --annotation data/intron_training/human_genome/GCF_000001405.40_GRCh38.p14_genomic.gff \ + --output data/intron_training/processed/human_introns.fasta \ + --tasks intron \ + --max_intron_length 10000 \ + --max_genes 50000 + +# 2d. 
Mix the data (50:50) +python scripts/sample_fasta.py \ + --input data/intron_training/processed/human_introns.fasta \ + --output data/intron_training/processed/human_introns_sampled.fasta \ + --count 30000 \ + --seed 42 + +cat data/intron_training/processed/gse278584_training.fasta \ + data/intron_training/processed/human_introns_sampled.fasta \ + > data/intron_training/processed/intron_training_combined.fasta +``` + +**Output statistics:** +- GSE278584: ~30,000 functional introns +- Natural human: ~30,000 natural introns +- **Total: ~60,000 training samples** + +--- + +### Step 3: Train the model + +**Environment setup:** +```bash +# Requires the BioNemo environment +docker run --rm -it --gpus=all nvcr.io/nvidia/clara/bionemo-framework:nightly /bin/bash +``` + +**Training commands:** +```bash +# Run step 3 (performs preprocessing, training, and export automatically) +./scripts/run_intron_pipeline.sh step3 + +# Or run manually: + +# 3a. Preprocess the data +preprocess_evo2 \ + --input data/intron_training/processed/intron_training_combined.fasta \ + --output_prefix data/intron_training/preprocessed \ + --train_split 0.9 \ + --val_split 0.05 \ + --test_split 0.05 \ + --seq_length 8192 + +# 3b. Convert the model +evo2_convert_savanna_to_mbridge \ + --savanna-ckpt-path arcinstitute/savanna_evo2_7b \ + --mbridge-ckpt-dir models/evo2_7b_mbridge \ + --model-size evo2_7b + +# 3c. LoRA fine-tuning +torchrun --nproc-per-node 2 --no-python train_evo2 \ + --model-size evo2_7b \ + --finetune-ckpt-dir models/evo2_7b_mbridge \ + --data-path data/intron_training/preprocessed \ + --lora-finetune \ + --lora-dim 16 \ + --lora-alpha 32 \ + --lora-dropout 0.1 \ + --lora-target-modules "dense_projection,linear_qkv,linear_proj,linear_fc1,linear_fc2" \ + --max-steps 2000 \ + --micro-batch-size 4 \ + --global-batch-size 32 \ + --seq-length 8192 \ + --mixed-precision-recipe bf16_mixed \ + --use-subquadratic-ops \ + --eval-interval 100 \ + --save-interval 100 \ + --result-dir results/intron_finetune + +# 3d. 
Export the model +evo2_remove_optimizer \ + --src-ckpt-dir results/intron_finetune/checkpoint \ + --dst-ckpt-dir results/intron_finetune/weights_only + +evo2_export_mbridge_to_vortex \ + --mbridge-ckpt-dir results/intron_finetune/weights_only \ + --output-path models/intron_generator.pt \ + --model-size evo2_7b +``` + +**Training time:** +- 1x H100: ~6-8 hours +- 2x H100: ~3-4 hours + +--- + +## 🔬 Using the Trained Model + +### Basic generation + +```python +from evo2 import Evo2 + +model = Evo2('evo2_7b', local_path='models/intron_generator.pt') + +# Given the exons, generate an intron +exon_5 = "ATGGCTAGCTACGGTACGGATCCGCTAGCATCGATCGATCGATCGTAGCTAGCTAG" +exon_3 = "GGATCCGGATCCGGATTAGCTAGCTAGCTAGCTAGCTAGCATCGATCGATCGTAA" + +prompt = f"{exon_5}" +output = model.generate([prompt], n_tokens=500, temperature=1.0, top_k=4) + +# Parse the output +generated = output.sequences[0] +intron = generated.split('NNNNNN')[1].split('NNNNNNN')[0] +print(f"Generated intron: {intron}") +print(f"Length: {len(intron)} bp") +``` + +### Command-line generation + +```bash +# Generate a single intron +python scripts/generate_intron.py \ + --model models/intron_generator.pt \ + --exon5 "ATGGCTAGCTACGGTACGGATCCGCTAGC" \ + --exon3 "GGATCCGGATCCGGATTAGCTAGCTAG" \ + --output generated_introns.fasta + +# Batch generation +python scripts/generate_intron.py \ + --model models/intron_generator.pt \ + --input exon_pairs.fasta \ + --output generated_introns.fasta \ + --temperature 1.0 \ + --top_k 4 + +# Function-guided generation (high IME activity) +python scripts/generate_intron.py \ + --model models/intron_generator.pt \ + --exon5 "ATG..." \ + --exon3 "TAA..." \ + --task intron_func_high \ + --output high_ime_introns.fasta + +# Length-controlled generation +python scripts/generate_intron.py \ + --model models/intron_generator.pt \ + --exon5 "ATG..." \ + --exon3 "TAA..." 
\ + --target_length 200 \ + --output length_controlled_introns.fasta +``` + +--- + +## 📊 Quality Evaluation + +```bash +# Evaluate the quality of the generated introns +python scripts/evaluate_intron_quality.py \ + --generated generated_introns.fasta \ + --reference data/intron_training/processed/human_introns_sampled.fasta \ + --output intron_evaluation.html + +# Also write the results as JSON +python scripts/evaluate_intron_quality.py \ + --generated generated_introns.fasta \ + --reference natural_introns.fasta \ + --output evaluation.html \ + --output-json evaluation.json +``` + +**Evaluation metrics:** +- ✅ Splice site accuracy (GT-AG) +- ✅ Length distribution +- ✅ GC content +- ✅ Branch point sequence +- ✅ Polypyrimidine tract +- ✅ Sequence complexity + +--- + +## 🎛️ Key Parameters + +### Training parameters + +```yaml +# Data mixing ratio +gse278584_ratio: 0.5 # 50% GSE278584 +natural_ratio: 0.5 # 50% natural introns + +# Sequence length +seq_length: 8192 # Fits intron + exons + +# LoRA parameters +lora_dim: 16 # LoRA rank +lora_alpha: 32 # Scaling factor +lora_dropout: 0.1 # Dropout + +# Training steps +max_steps: 2000 # About 3-4 hours (two GPUs) +``` + +### Generation parameters + +```bash +--temperature 0.8-1.2 # Higher is more random +--top_k 4-10 # Sampling diversity +--target_length 100-500 # Target intron length +--task intron_func_high # Function-guided +``` + +--- + +## 📈 Expected Results + +### Training metrics + +| Metric | Target | +|------|--------| +| Training loss | < 0.5 | +| Validation loss | < 0.7 | +| GT-AG accuracy | > 90% | + +### Generation quality + +| Metric | Target | +|------|--------| +| Splice site correctness | > 90% | +| Mean length error | < 20nt | +| GC content | 40-60% | +| Branch point detection rate | > 30% | +| Polypyrimidine tract detection rate | > 50% | + +--- + +## 🔧 Troubleshooting + +### Problem 1: Inaccurate splice sites + +**Cause:** Poor splice site quality in the training data + +**Fix:** +```bash +# Check splice sites: count sequence lines starting with GT and ending with AG +grep -c "^GT" generated_introns.fasta +grep -c "AG$" generated_introns.fasta + +# Increase the GSE278584 proportion (its splice sites are more accurate) +# Edit process_gse278584.py to raise its weight +``` + +### Problem 2: Introns are too long or too short + +**Fix:** +```bash +# Use length-controlled generation +python scripts/generate_intron.py \ + --target_length 200 \ + --model models/intron_generator.pt +``` + +### Problem 3: OOM (out of GPU memory) + +**Fix:** +```yaml +# Lower the batch size +micro_batch_size: 2 # Reduced from 4 to 2 +gradient_accumulation_steps: 2 # Add gradient accumulation +``` + +--- + +## 📚 Related Resources + +- **GSE278584**: https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE278584 +- **Human genome**: 
https://www.ncbi.nlm.nih.gov/datasets/genome/GCF_000001405.40/ +- **BioNemo documentation**: https://docs.nvidia.com/bionemo-framework/latest/ +- **Intron-mediated enhancement**: https://en.wikipedia.org/wiki/Intron-mediated_enhancement + +--- + +## 📝 Citation + +If you use the GSE278584 data: + +```bibtex +@article{gse278584, + title={Sequence determinants of intron-mediated enhancement learned from thousands of random introns}, + author={...}, + journal={...}, + year={2024} +} +``` + +--- + +## ⏱️ Time Estimates + +| Step | 1x H100 | 2x H100 | +|------|----------|----------| +| Data download | 1-2 hours | 1-2 hours | +| Data processing | 30 minutes | 30 minutes | +| Training | 6-8 hours | 3-4 hours | +| Quality evaluation | 1 hour | 1 hour | +| **Total** | **~10 hours** | **~7 hours** | + +--- + +## 🎯 Next Steps + +Once training is complete, you can: + +1. **Generate custom introns** - for gene synthesis +2. **Optimize IME activity** - to increase gene expression +3. **Study splicing mechanisms** - to understand splicing rules +4. **Apply to synthetic biology** - to design artificial genes + +If you have questions, see `README_GSE278584_INTRON.md` or open an issue. diff --git a/README_UTR_GENERATION.md b/README_UTR_GENERATION.md new file mode 100644 index 0000000..79e268a --- /dev/null +++ b/README_UTR_GENERATION.md @@ -0,0 +1,398 @@ +# Evo2 UTR Generation: Complete Workflow + +Fine-tunes the Evo2 model with CDS sequences as the prompt to generate: +- **5UTR** (5' untranslated region) +- **3UTR** (3' untranslated region) +- **5UTR + 3UTR together** + +--- + +## Design Rationale + +### Task tag format (recommended) + +Task tags state the generation type explicitly, so a single model can handle multiple tasks: + +``` +# Generate only the 5UTR +<5UTR>5UTR sequence NNNNNCDS sequence + +# Generate only the 3UTR +<3UTR>CDS sequence NNNNNNN3UTR sequence + +# Generate the 5UTR and 3UTR together +5UTR sequence NNNNNCDS sequence NNNNNNN3UTR sequence +``` + +### Special token design + +| Token | Meaning | Length | +|-------|------|------| +| `<5UTR>` | 5UTR generation task | 6 characters | +| `<3UTR>` | 3UTR generation task | 6 characters | +| `` | Joint generation task | 6 characters | +| `NNNNN` | 5UTR separator | 5 Ns | +| `NNNNNNN` | 3UTR separator | 7 Ns | +| `` | End of sequence | 6 characters | + +**Why use runs of N as separators?** +- Clear biological meaning (N denotes an unknown base) +- Friendly to character-level tokenizers +- 5 Ns vs 7 Ns makes the 5UTR and 3UTR separators easy to tell apart + +--- + +## File Structure + +``` +evo2/ +├── scripts/ +│ ├── prepare_utr_data.py # Data preparation (extracts from FASTA+GTF) +│ ├── generate_utr.py # UTR generation script +│ └── evaluate_5utr_quality.py # Quality evaluation tool +├── configs/ +│ └── utr_finetune_config.yaml # Training configuration +└── README_UTR_GENERATION.md # This document +``` + +--- + +## Quick Start + 
+### Step 1: Prepare the training data + +```bash +python scripts/prepare_utr_data.py \ + --genome genome.fasta \ + --annotation annotation.gtf \ + --output utr_training.fasta \ + --tasks 5UTR,3UTR,both \ + --utr_length 500 +``` + +**Example output:** +```fasta +>gene1|task:5UTR|utr5_len:150|cds_len:1200|utr3_len:200 +<5UTR>ACGT...NNNNNATGCGT...TGA + +>gene2|task:3UTR|utr5_len:150|cds_len:1200|utr3_len:200 +<3UTR>ATGCGT...TGANNNNNNNACGT... + +>gene3|task:both|utr5_len:150|cds_len:1200|utr3_len:200 +ACGT...NNNNNATGCGT...TGANNNNNNNACGT... +``` + +### Step 2: Configure training + +Edit `configs/utr_finetune_config.yaml`: + +```yaml +preprocess: + input_path: "utr_training.fasta" + seq_length: 8192 # Adjust based on the data preparation script's output + +lora_finetune: + tasks: ['5UTR', '3UTR', 'both'] + max_steps: 1000 +``` + +### Step 3: Run training + +```bash +# LoRA fine-tuning (recommended) +./scripts/run_utr_finetune.sh lora + +# Or full-parameter fine-tuning +./scripts/run_utr_finetune.sh 2gpu +``` + +### Step 4: Generate UTRs + +```bash +# Generate only the 5UTR +python scripts/generate_utr.py \ + --model utr_model.pt \ + --input cds_sequences.fasta \ + --task 5UTR \ + --output utr5_generated.fasta + +# Generate only the 3UTR +python scripts/generate_utr.py \ + --model utr_model.pt \ + --input cds_sequences.fasta \ + --task 3UTR \ + --output utr3_generated.fasta + +# Generate the 5UTR and 3UTR together +python scripts/generate_utr.py \ + --model utr_model.pt \ + --input cds_sequences.fasta \ + --task both \ + --output full_utr_generated.fasta +``` + +### Step 5: Quality evaluation + +```bash +python scripts/evaluate_5utr_quality.py \ + --generated utr5_generated.fasta \ + --reference real_utr5.fasta \ + --output evaluation_report.html +``` + +--- + +## Training Strategies + +### Mixed multi-task training (recommended) + +Mix the three task types in the training data: + +```python +# Suggested task ratios +task_ratio = { + '5UTR': 0.4, # 40% of samples train 5UTR generation + '3UTR': 0.4, # 40% of samples train 3UTR generation + 'both': 0.2 # 20% of samples train joint generation +} +``` + +**Advantages:** +- A single model supports multiple use cases +- The tasks can reinforce each other during learning +- Simple deployment (only one model) + +### Separate training + +If resources allow, train a dedicated model per task: + +```bash +# Train 5UTR generation only +python scripts/prepare_utr_data.py --tasks 5UTR --output utr5_only.fasta + +# Train 3UTR generation only +python scripts/prepare_utr_data.py --tasks 
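3UTR --output utr3_only.fasta +``` + +**Advantages:** +- Each model focuses on one task +- Smaller models, faster inference + +**Disadvantages:** +- Multiple models to maintain + +--- + +## Inference Modes Compared + +### Mode 1: CDS → 5UTR + +```python +prompt = "<5UTR>" +# The model learns: on seeing the <5UTR> tag, generate the 5UTR+CDS structure +# Output: <5UTR>ACGT...NNNNNATG...TGA +# Extract: 5UTR = output.split('NNNNN')[0].replace('<5UTR>', '') +``` + +### Mode 2: CDS → 3UTR + +```python +prompt = "<3UTR>" + cds_sequence +# The model learns: on seeing <3UTR>+CDS, generate the 3UTR +# Output: <3UTR>ATG...TGANNNNNNNACGT... +# Extract: 3UTR = output.split('NNNNNNN')[1] +``` + +### Mode 3: CDS → 5UTR + 3UTR + +```python +prompt = "" +# The model learns: on seeing the tag, generate the complete UTR structure +# Output: ACGT...NNNNNATG...TGANNNNNNNACGT... +# Extract: split on NNNNN and NNNNNNN respectively +``` + +--- + +## Data Format in Detail + +### During training (autoregressive prediction) + +``` +Input (prompt): "<5UTR>" +Target: "5UTR sequence NNNNNCDS sequence" + +Input (prompt): "<3UTR>ATG...TGA" +Target: "NNNNNNN3UTR sequence" + +Input (prompt): "" +Target: "5UTR sequence NNNNNCDS sequence NNNNNNN3UTR sequence" +``` + +### During inference + +``` +Input: "<5UTR>" + the pattern learned in training +Output: the complete 5UTR+CDS structure + +Input: "<3UTR>" + CDS sequence +Output: the CDS+3UTR structure + +Input: "" + (optional CDS) +Output: the complete UTR structure +``` + +--- + +## Key Parameter Tuning + +### Choosing seq_length + +Based on the statistics printed by `prepare_utr_data.py`: + +``` +Recommended seq_length settings: + Covers 95% of the data: 4096 + Covers 99% of the data: 8192 +``` + +- For a single task (5UTR or 3UTR): 4096 is usually enough +- For generating both together: 8192 or higher is recommended + +### Generation parameters + +```bash +--temperature 0.8-1.2 # Lower is more conservative, higher is more varied +--top_k 4-10 # Controls sampling diversity +--n_tokens_5utr 200-500 # Based on the target 5UTR length +--n_tokens_3utr 200-500 # Based on the target 3UTR length +``` + +### LoRA parameters + +```yaml +lora_dim: 16 # 8, 16, 32, 64 +lora_alpha: 32 # Usually 2x dim +lora_target_modules: + - dense_projection + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 +``` + +--- + +## Troubleshooting + +### Problem 1: The model always generates the same UTR + +**Cause:** temperature is too low, or training has overfit + +**Fix:** +- Raise temperature (1.0 → 1.2) +- Increase top_k (4 → 8) +- Check the diversity of the training data + +### Problem 2: Separators are not recognized correctly + +**Cause:** a tokenizer handling issue + +**Fix:** +- Check that the separators are tokenized correctly +- Try a simpler separator (such as `###`) +- Make sure training and inference use the same separators + +### Problem 3: Generated UTRs are too short or too long + +**Cause:** a problem with the training data distribution + +**Fix:** +- Check the UTR length distribution of the training data +- Adjust the `--utr_length` parameter +- Add a length penalty to the loss function + +### Problem 4: Tasks interfere with each other + +**Cause:** 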
Unbalanced task ratios + +**Fix:** +- Adjust the task mixing ratios +- Train for more steps +- Consider separate training (one model per task) + +--- + +## Advanced Applications + +### Conditional generation: specifying UTR length + +Encode length information in the task tag: + +``` +<5UTR:300>ACGT...NNNNNATG...TGA +<3UTR:250>ATG...TGANNNNNNNACGT... +``` + +### Species-specific UTR generation + +Add a species tag: + +``` +<5UTR|human>ACGT...NNNNNATG...TGA +<5UTR|mouse>ACGT...NNNNNATG...TGA +``` + +### Function-guided UTR design + +Generate for a specific functional requirement (such as high expression or tissue specificity): + +``` +<5UTR|high_expression>... +<5UTR|neuron_specific>... +``` + +--- + +## Performance Benchmarks + +### Expected generation quality (based on similar tasks) + +| Metric | Target | +|------|--------| +| Length similarity (vs real) | >80% | +| GC content similarity | >85% | +| Kozak sequence presence rate | >60% | +| Secondary structure MFE distribution | Close to the real distribution | + +### Inference speed (Evo2 7B) + +| Sequence length | 1x H100 | 2x H100 | +|---------|----------|----------| +| 500 bp | ~0.5 s | ~0.3 s | +| 1000 bp | ~1 s | ~0.6 s | +| 2000 bp | ~2 s | ~1.2 s | + +--- + +## Related Resources + +- [BioNemo framework documentation](https://docs.nvidia.com/bionemo-framework/latest/) +- [Evo2 GitHub](https://github.com/evo-design/evo2) +- [Introduction to UTR function](https://en.wikipedia.org/wiki/Untranslated_region) + +--- + +## Citation + +If you use this project, please cite: + +```bibtex +@article {king2025, + author = {King, Samuel H and Driscoll, Claudia L and Li, David B and Guo, Daniel and Merchant, Aditi T and Brixi, Garyk and Wilkinson, Max E and Hie, Brian L}, + title = {Generative design of novel bacteriophages with genome language models}, + year = {2025}, + doi = {10.1101/2025.09.12.675911}, + publisher = {Cold Spring Harbor Laboratory} +} +``` diff --git a/configs/cds_5utr_finetune_config.yaml b/configs/cds_5utr_finetune_config.yaml new file mode 100644 index 0000000..4b85995 --- /dev/null +++ b/configs/cds_5utr_finetune_config.yaml @@ -0,0 +1,188 @@ +# ============================================ +# Evo2 CDS→5UTR fine-tuning configuration bundle +# ============================================ +# Usage: +# 1. Pick the training configuration that matches your hardware +# 2. Edit the data paths and model parameters +# 3. Run the preprocessing and training commands +# ============================================ + +# ============================================ +# 1. 
Data preprocessing configuration +# ============================================ +preprocess: + input_path: "/path/to/your/cds_5utr_data.fasta" # Change to the path of your FASTA file + output_prefix: "/path/to/output/preprocessed_data" # Output directory + train_split: 0.9 + val_split: 0.05 + test_split: 0.05 + tokenizer: "tokenizers/nucleotide_fast_tokenizer_256" + seq_length: 8192 # Adjust to your combined CDS+5UTR length + # Optional: for shorter sequences this can be lowered to 2048 or 4096 + +# ============================================ +# 2. Model conversion configuration (from Savanna on HuggingFace) +# ============================================ +model_conversion: + savanna_ckpt_path: "arcinstitute/savanna_evo2_7b" + mbridge_ckpt_dir: "/path/to/output/evo2_7b_mbridge" + model_size: "evo2_7b" + +# ============================================ +# 3. Training configuration - single H100/A100 (80GB) +# ============================================ +training_single_gpu: + model_size: "evo2_7b" + finetune_ckpt_dir: "/path/to/output/evo2_7b_mbridge" + data_path: "/path/to/output/preprocessed_data" + + # Training parameters + max_steps: 1000 + micro_batch_size: 1 # Must be 1 on a single GPU + global_batch_size: 8 # Reached through gradient accumulation + gradient_accumulation_steps: 8 + + # Sequence length + seq_length: 8192 + + # Precision settings + mixed_precision_recipe: "bf16_mixed" + + # Optimization options + use_subquadratic_ops: true + use_precision_aware_optimizer: true + + # Evaluation and checkpointing + eval_interval: 100 + save_interval: 100 + + # Output directory + result_dir: "results/cds_5utr_finetune_single_gpu" + +# ============================================ +# 4. 
Training configuration - two H100/A100 GPUs (recommended) +# ============================================ +training_2gpu: + model_size: "evo2_7b" + finetune_ckpt_dir: "/path/to/output/evo2_7b_mbridge" + data_path: "/path/to/output/preprocessed_data" + + # Training parameters + max_steps: 1000 + micro_batch_size: 4 + global_batch_size: 32 + tensor_model_parallel: 2 # Tensor parallelism across 2 GPUs + + # Sequence length + seq_length: 8192 + + # Precision settings + mixed_precision_recipe: "bf16_mixed" + + # Optimization options + use_subquadratic_ops: true + + # Evaluation and checkpointing + eval_interval: 100 + save_interval: 100 + + # Output directory + result_dir: "results/cds_5utr_finetune_2gpu" + +# ============================================ +# 5. LoRA fine-tuning configuration - saves GPU memory +# ============================================ +lora_finetune: + model_size: "evo2_7b" + finetune_ckpt_dir: "/path/to/output/evo2_7b_mbridge" + data_path: "/path/to/output/preprocessed_data" + + # LoRA parameters + lora_finetune: true + lora_dim: 16 # Adjustable: 8, 16, 32, 64 + lora_alpha: 32 # Usually 2x lora_dim + lora_dropout: 0.1 + lora_target_modules: + - "dense_projection" + - "linear_qkv" + - "linear_proj" + - "linear_fc1" + - "linear_fc2" + + # Training parameters + max_steps: 500 # LoRA converges faster + micro_batch_size: 8 + global_batch_size: 32 + + # Sequence length + seq_length: 8192 + + # Precision settings + mixed_precision_recipe: "bf16_mixed" + + # Optimization options + use_subquadratic_ops: true + + # Evaluation and checkpointing + eval_interval: 50 + save_interval: 50 + + # Output directory + result_dir: "results/cds_5utr_lora_finetune" + +# ============================================ +# 6. 
Long-sequence training configuration (1M context) +# ============================================ +training_long_context: + model_size: "evo2_7b" # Must be the version without the _base suffix + finetune_ckpt_dir: "/path/to/output/evo2_7b_mbridge" + data_path: "/path/to/output/preprocessed_data" + + # Training parameters + max_steps: 500 + micro_batch_size: 1 + global_batch_size: 8 + gradient_accumulation_steps: 8 + + # Sequence length - 1M context + seq_length: 1048576 + + # Precision settings (long sequences need FP8) + mixed_precision_recipe: "bf16_with_fp8_current_scaling_mixed" + + # Optimization options + use_subquadratic_ops: true + + # Evaluation and checkpointing + eval_interval: 100 + save_interval: 100 + + # Output directory + result_dir: "results/cds_5utr_finetune_1m_context" + +# ============================================ +# 7. Checkpoint export configuration +# ============================================ +checkpoint_export: + # Remove the optimizer state (reduces file size) + remove_optimizer: + src_ckpt_dir: "results/cds_5utr_finetune/checkpoint" + dst_ckpt_dir: "results/cds_5utr_finetune_weights_only" + + # Export to Vortex format (for local inference) + export_to_vortex: + mbridge_ckpt_dir: "results/cds_5utr_finetune_weights_only" + output_path: "models/cds_5utr_model.pt" + model_size: "evo2_7b" + +# ============================================ +# 8. Inference and generation configuration +# ============================================ +inference: + ckpt_dir: "models/cds_5utr_model.pt" + prompt: "ATGCGT..." # Your CDS sequence + max_new_tokens: 500 # Generate a 500bp 5UTR + temperature: 1.0 + top_k: 4 + output_file: "generated_5utr.txt" + use_subquadratic_ops: true diff --git a/scripts/evaluate_5utr_quality.py b/scripts/evaluate_5utr_quality.py new file mode 100755 index 0000000..440b6df --- /dev/null +++ b/scripts/evaluate_5utr_quality.py @@ -0,0 +1,704 @@ +#!/usr/bin/env python3 +""" +Quality evaluation tool for 5UTR sequences generated by Evo2 + +Features: +1. Basic sequence statistics (length, GC content, molecular weight, etc.) +2. Secondary structure prediction (minimum free energy, MFE) +3. Upstream ORF (uORF) detection +4. Kozak sequence analysis +5. Distribution comparison against real 5UTRs +6. 
密码子使用偏性分析(针对相邻 CDS) + +使用方法: + python evaluate_5utr_quality.py \ + --generated generated_5utr.fasta \ + --reference real_5utr.fasta \ + --output evaluation_report.html +""" + +import argparse +import csv +import json +from pathlib import Path +from typing import List, Dict, Tuple, Optional +from collections import defaultdict +import numpy as np +from dataclasses import dataclass + +# 可选依赖检查 +try: + import RNA + HAS_RNALIB = True +except ImportError: + HAS_RNALIB = False + print("警告:RNA 库未安装,二级结构分析将跳过。安装:conda install -c conda-forge viennarna") + +try: + from Bio import SeqIO + from Bio.Seq import Seq + from Bio.SeqUtils import gc_fraction, molecular_weight + from Bio.SeqUtils.ProtParam import ProteinAnalysis + HAS_BIOPYTHON = True +except ImportError: + HAS_BIOPYTHON = False + print("错误:需要 Biopython。安装:pip install biopython") + exit(1) + + +@dataclass +class SequenceMetrics: + """单个序列的质量指标""" + name: str + length: int + gc_content: float + mfe: Optional[float] = None # 最小自由能 + mfe_structure: Optional[str] = None # 二级结构 + uorf_count: int = 0 + uorf_lengths: List[int] = None + has_kozak: bool = False + kozak_sequence: Optional[str] = None + poly_pyrimidine_tract: bool = False # 聚嘧啶 tract + repeat_content: float = 0.0 + sequence: str = "" + + def __post_init__(self): + if self.uorf_lengths is None: + self.uorf_lengths = [] + + +def read_fasta(fasta_path: str) -> List[Tuple[str, str]]: + """读取 FASTA 文件""" + sequences = [] + + if hasattr(SeqIO, 'parse'): + for record in SeqIO.parse(fasta_path, 'fasta'): + sequences.append((record.id, str(record.seq).upper())) + else: + # 简单 FASTA 解析 + current_name = None + current_seq = [] + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_name: + sequences.append((current_name, ''.join(current_seq))) + current_name = line[1:] + current_seq = [] + else: + current_seq.append(line) + if current_name: + sequences.append((current_name, ''.join(current_seq))) + + return 
sequences + + +def calculate_gc_content(sequence: str, window_size: int = 50) -> Dict: + """计算 GC 含量及滑动窗口分析""" + gc = gc_fraction(sequence) * 100 + + # 滑动窗口 GC 含量 + window_gc = [] + for i in range(0, len(sequence) - window_size + 1, window_size // 2): + window = sequence[i:i + window_size] + window_gc.append(gc_fraction(window) * 100) + + return { + 'overall': gc, + 'window_mean': np.mean(window_gc) if window_gc else gc, + 'window_std': np.std(window_gc) if window_gc else 0, + 'window_min': min(window_gc) if window_gc else gc, + 'window_max': max(window_gc) if window_gc else gc + } + + +def predict_secondary_structure(sequence: str) -> Tuple[Optional[float], Optional[str]]: + """预测 RNA 二级结构(最小自由能)""" + if not HAS_RNALIB: + return None, None + + try: + md = RNA.md() + fold_compound = RNA.fold_compound(sequence, md) + (mfe, structure) = fold_compound.mfe() + return mfe, structure + except Exception as e: + print(f"二级结构预测失败:{e}") + return None, None + + +def find_uorfs(sequence: str, min_length: int = 30) -> List[int]: + """ + 检测上游 ORF (uORF) + + uORF 是 5UTR 内的短开放阅读框,可能影响翻译调控 + """ + uorf_lengths = [] + + # 寻找所有可能的起始密码子 + start_codons = ['ATG', 'GTG', 'TTG'] + stop_codons = ['TAA', 'TAG', 'TGA'] + + i = 0 + while i < len(sequence) - 2: + # 检查起始密码子 + codon = sequence[i:i+3] + if codon in start_codons: + # 寻找终止密码子 + for j in range(i + 3, len(sequence) - 2, 3): + stop = sequence[j:j+3] + if stop in stop_codons: + orf_length = j + 3 - i + if orf_length >= min_length: + uorf_lengths.append(orf_length) + break + i += 1 + + return uorf_lengths + + +def analyze_kozak_sequence(sequence: str) -> Tuple[bool, Optional[str], int]: + """ + 分析 Kozak 序列(真核翻译起始位点) + + 一致序列:(gcc)gccRccAUGG + 核心:RccAUGG (R = A 或 G) + """ + # 查找最后一个 ATG(假设这是真正的起始位点) + # 在 5UTR 中,我们检查序列末端是否有 Kozak 特征 + + kozak_positions = [] + + for i in range(len(sequence) - 6): + # 检查 AUG 及其上下游 + if sequence[i:i+3] == 'ATG': + # 提取 Kozak 区域(-6 到 +6) + start = max(0, i - 6) + end = min(len(sequence), i + 9) + 
kozak_region = sequence[start:end] + + # Score the Kozak context + score = 0 + # Position -3: R (A/G) + if i >= 3 and sequence[i-3] in ['A', 'G']: + score += 2 + # Position +4: G (A of ATG is +1, so +4 is index i+3) + if i + 3 < len(sequence) and sequence[i+3] == 'G': + score += 2 + # Position -1: A + if i >= 1 and sequence[i-1] == 'A': + score += 1 + + if score >= 3: # threshold + kozak_positions.append((i, kozak_region, score)) + + if kozak_positions: + best = max(kozak_positions, key=lambda x: x[2]) + return True, best[1], best[2] + + return False, None, 0 + + +def find_poly_pyrimidine_tract(sequence: str, min_length: int = 10) -> bool: + """ + Detect a polypyrimidine tract (C/T-rich region) + + Common in 5UTRs; may affect translation efficiency + """ + pyrimidines = 'CT' + + for i in range(len(sequence) - min_length + 1): + window = sequence[i:i + min_length] + py_count = sum(1 for base in window if base in pyrimidines) + if py_count / min_length >= 0.7: # 70% pyrimidines + return True + + return False + + +def calculate_repeat_content(sequence: str, min_repeat: int = 4) -> float: + """Compute the percentage of bases covered by tandem repeats""" + repeats = [] + + # Simple tandem repeats + for length in range(min_repeat, min(20, len(sequence) // 2)): + for i in range(len(sequence) - length * 2 + 1): + motif = sequence[i:i + length] + count = 0 + j = i + while j + length <= len(sequence) and sequence[j:j + length] == motif: + count += 1 + j += length + if count >= 2: + repeats.append((motif, count, i)) + + # Count the bases covered by repeats + covered_bases = set() + for motif, count, start in repeats: + for i in range(start, start + count * len(motif)): + covered_bases.add(i) + + return len(covered_bases) / len(sequence) * 100 if sequence else 0 + + +def analyze_sequence_composition(sequence: str) -> Dict: + """Analyze sequence composition (mono- and dinucleotide frequencies)""" + # Mononucleotide frequencies + mono = {base: sequence.count(base) / len(sequence) * 100 + for base in 'ACGT'} if sequence else {} + + # Dinucleotide frequencies + di = defaultdict(int) + for i in range(len(sequence) - 1): + di[sequence[i:i+2]] += 1 + + di_freq = {k: v / (len(sequence) - 1) * 100 for k, v in di.items()} if sequence else {} + + return { + 'mononucleotide': mono, + 
'dinucleotide': dict(di_freq) + } + + +def evaluate_sequences(sequences: List[Tuple[str, str]], + name: str = "sequences") -> List[SequenceMetrics]: + """批量评估序列""" + metrics_list = [] + + for i, (seq_name, sequence) in enumerate(sequences): + print(f" 评估 {name} {i+1}/{len(sequences)}: {seq_name}") + + # GC 含量 + gc_stats = calculate_gc_content(sequence) + + # 二级结构 + mfe, mfe_struct = predict_secondary_structure(sequence) + + # uORF 检测 + uorf_lengths = find_uorfs(sequence) + + # Kozak 序列 + has_kozak, kozak_seq, kozak_score = analyze_kozak_sequence(sequence) + + # 聚嘧啶 tract + poly_pyr = find_poly_pyrimidine_tract(sequence) + + # 重复含量 + repeat_content = calculate_repeat_content(sequence) + + metrics = SequenceMetrics( + name=seq_name, + length=len(sequence), + gc_content=gc_stats['overall'], + mfe=mfe, + mfe_structure=mfe_struct, + uorf_count=len(uorf_lengths), + uorf_lengths=uorf_lengths, + has_kozak=has_kozak, + kozak_sequence=kozak_seq, + poly_pyrimidine_tract=poly_pyr, + repeat_content=repeat_content, + sequence=sequence + ) + + metrics_list.append(metrics) + + return metrics_list + + +def compare_distributions(generated: List[SequenceMetrics], + reference: List[SequenceMetrics]) -> Dict: + """比较生成序列和真实序列的分布""" + comparisons = {} + + # 长度分布 + comparisons['length'] = { + 'generated_mean': np.mean([m.length for m in generated]), + 'generated_std': np.std([m.length for m in generated]), + 'reference_mean': np.mean([m.length for m in reference]) if reference else None, + 'reference_std': np.std([m.length for m in reference]) if reference else None, + } + + # GC 含量 + comparisons['gc_content'] = { + 'generated_mean': np.mean([m.gc_content for m in generated]), + 'generated_std': np.std([m.gc_content for m in generated]), + 'reference_mean': np.mean([m.gc_content for m in reference]) if reference else None, + 'reference_std': np.std([m.gc_content for m in reference]) if reference else None, + } + + # uORF 数量 + comparisons['uorf_count'] = { + 'generated_mean': 
np.mean([m.uorf_count for m in generated]), + 'generated_std': np.std([m.uorf_count for m in generated]), + 'reference_mean': np.mean([m.uorf_count for m in reference]) if reference else None, + 'reference_std': np.std([m.uorf_count for m in reference]) if reference else None, + } + + # Kozak 序列存在率 + comparisons['kozak_presence'] = { + 'generated': sum(1 for m in generated if m.has_kozak) / len(generated) * 100, + 'reference': sum(1 for m in reference if m.has_kozak) / len(reference) * 100 if reference else None, + } + + # 二级结构(MFE) + if HAS_RNALIB and generated[0].mfe is not None: + mfe_generated = [m.mfe for m in generated if m.mfe is not None] + mfe_ref = [m.mfe for m in reference if m.mfe is not None] if reference else [] + comparisons['mfe'] = { + 'generated_mean': np.mean(mfe_generated), + 'generated_std': np.std(mfe_generated), + 'reference_mean': np.mean(mfe_ref) if mfe_ref else None, + 'reference_std': np.std(mfe_ref) if mfe_ref else None, + } + + return comparisons + + +def generate_quality_scores(metrics_list: List[SequenceMetrics]) -> List[Dict]: + """为每个序列生成综合质量评分""" + scores = [] + + for m in metrics_list: + score_components = {} + + # 长度评分(假设 50-500bp 是合理范围) + if 50 <= m.length <= 500: + score_components['length'] = 1.0 + elif m.length < 50: + score_components['length'] = max(0, m.length / 50) + else: + score_components['length'] = max(0, 1 - (m.length - 500) / 500) + + # GC 含量评分(30-70% 是合理范围) + if 30 <= m.gc_content <= 70: + score_components['gc'] = 1.0 + else: + score_components['gc'] = max(0, 1 - abs(m.gc_content - 50) / 50) + + # uORF 评分(少量 uORF 是正常的,过多可能有问题) + if m.uorf_count <= 2: + score_components['uorf'] = 1.0 + elif m.uorf_count <= 5: + score_components['uorf'] = 0.7 + else: + score_components['uorf'] = 0.5 + + # Kozak 序列加分 + score_components['kozak'] = 1.0 if m.has_kozak else 0.7 + + # 重复序列扣分 + score_components['repeat'] = max(0, 1 - m.repeat_content / 50) + + # 综合评分(加权平均) + weights = {'length': 0.2, 'gc': 0.2, 'uorf': 0.2, 'kozak': 0.2, 
'repeat': 0.2} + total_score = sum(score_components[k] * weights[k] for k in weights) + + scores.append({ + 'name': m.name, + 'total_score': total_score, + 'components': score_components, + 'quality_tier': '高' if total_score >= 0.8 else '中' if total_score >= 0.6 else '低' + }) + + return scores + + +def generate_html_report(comparisons: Dict, + quality_scores: List[Dict], + generated_metrics: List[SequenceMetrics], + reference_metrics: Optional[List[SequenceMetrics]] = None, + output_path: str = 'evaluation_report.html'): + """生成 HTML 格式的质量评估报告""" + + html = f""" + + + + + Evo2 5UTR 生成质量评估报告 + + + +
+

🧬 Evo2 5UTR Generation Quality Evaluation Report

+

Generated at: {Path(output_path).stat().st_mtime if Path(output_path).exists() else 'N/A'}

+ +
+

📊 Overall Statistics

+
+
{len(generated_metrics)}
+
Generated sequences
+
+
+
{np.mean([m.length for m in generated_metrics]):.1f} bp
+
Mean length
+
+
+
{np.mean([m.gc_content for m in generated_metrics]):.1f}%
+
Mean GC content
+
+
+
{sum(1 for m in generated_metrics if m.has_kozak)/len(generated_metrics)*100:.1f}%
+
Kozak sequence presence rate
+
+
+ +

📈 Distribution Comparison

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Metric | Generated | Reference | Difference
Length (bp) | {comparisons.get('length', {}).get('generated_mean', 'N/A'):.1f} ± {comparisons.get('length', {}).get('generated_std', 0):.1f} | {'N/A' if comparisons.get('length', {}).get('reference_mean') is None else f"{comparisons['length']['reference_mean']:.1f}"} | {'N/A' if comparisons.get('length', {}).get('reference_mean') is None else f"{abs(comparisons['length']['generated_mean'] - comparisons['length']['reference_mean']):.1f}"}
GC content (%) | {comparisons.get('gc_content', {}).get('generated_mean', 'N/A'):.1f} ± {comparisons.get('gc_content', {}).get('generated_std', 0):.1f} | {'N/A' if comparisons.get('gc_content', {}).get('reference_mean') is None else f"{comparisons['gc_content']['reference_mean']:.1f}"} | {'N/A' if comparisons.get('gc_content', {}).get('reference_mean') is None else f"{abs(comparisons['gc_content']['generated_mean'] - comparisons['gc_content']['reference_mean']):.1f}%"}
uORF count | {comparisons.get('uorf_count', {}).get('generated_mean', 'N/A'):.2f} ± {comparisons.get('uorf_count', {}).get('generated_std', 0):.2f} | {'N/A' if comparisons.get('uorf_count', {}).get('reference_mean') is None else f"{comparisons['uorf_count']['reference_mean']:.2f}"} | {'N/A' if comparisons.get('uorf_count', {}).get('reference_mean') is None else f"{abs(comparisons['uorf_count']['generated_mean'] - comparisons['uorf_count']['reference_mean']):.2f}"}
Kozak sequence presence rate | {comparisons.get('kozak_presence', {}).get('generated', 'N/A'):.1f}% | {'N/A' if comparisons.get('kozak_presence', {}).get('reference') is None else f"{comparisons['kozak_presence']['reference']:.1f}%"} | {'N/A' if comparisons.get('kozak_presence', {}).get('reference') is None else f"{abs(comparisons['kozak_presence']['generated'] - comparisons['kozak_presence']['reference']):.1f}%"}
+ +

⭐ Quality Scores

+ + + + + + + + + + +""" + + # 添加质量评分行 + for score in sorted(quality_scores, key=lambda x: x['total_score'], reverse=True): + quality_class = 'quality-high' if score['quality_tier'] == '高' else \ + 'quality-medium' if score['quality_tier'] == '中' else 'quality-low' + + m = next((x for x in generated_metrics if x.name == score['name']), None) + html += f""" + + + + + + + + + + """ + + html += """ +
Sequence name | Overall score | Quality tier | Length | GC content | uORF count | Kozak
{score['name']} | {score['total_score']:.2f} | {score['quality_tier']} | {m.length if m else 'N/A'} bp | {f"{m.gc_content:.1f}" if m else 'N/A'}% | {m.uorf_count if m else 'N/A'} | {'✓' if (m and m.has_kozak) else '✗'}
+ +

📋 Detailed Sequence Information

+ + + + + + + + + + + + """ + + for m in generated_metrics: + html += f""" + + + + + + + + + + + """ + + # 质量建议 + high_quality_ratio = sum(1 for s in quality_scores if s['quality_tier'] == '高') / len(quality_scores) * 100 + + if high_quality_ratio >= 80: + html += f""" +
+

✅ Excellent quality

+

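{high_quality_ratio:.1f}% of the generated sequences reach the high-quality tier. The model performs well on these heuristic metrics, and the sequences can be passed on to downstream analysis.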

+
+ """ + elif high_quality_ratio >= 50: + html += f""" +
+

⚠️ Moderate quality

+

{high_quality_ratio:.1f}% of the generated sequences reach the high-quality tier. Suggestions:

+
    +
  • Filter out low-quality sequences before using them in experiments
  • +
  • Consider adjusting the generation parameters (for example, lowering temperature)
  • +
  • Increase the amount of training data or the number of training steps
  • +
+
+ """ + else: + html += f""" +
+

⚠️ Needs improvement

+

Only {high_quality_ratio:.1f}% of the generated sequences reach the high-quality tier. Suggestions:

+
    +
  • Check the quality of the training data
  • +
  • Increase the number of training steps
  • +
  • Tune the model hyperparameters
  • +
  • Consider full-parameter fine-tuning instead of LoRA
  • +
+
+ """ + + html += """ + + + + """ + + with open(output_path, 'w', encoding='utf-8') as f: + f.write(html) + + print(f"HTML 报告已生成:{output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description='评估 Evo2 生成的 5UTR 序列质量' + ) + + parser.add_argument('--generated', type=str, required=True, + help='生成的 5UTR 序列 FASTA 文件') + parser.add_argument('--reference', type=str, default=None, + help='真实 5UTR 序列 FASTA 文件(用于对比)') + parser.add_argument('--output', type=str, default='evaluation_report.html', + help='输出报告路径(HTML 格式)') + parser.add_argument('--output-json', type=str, default=None, + help='输出 JSON 格式的详细结果') + parser.add_argument('--skip-structure', action='store_true', + help='跳过二级结构预测(无 ViennaRNA 时)') + + args = parser.parse_args() + + print("="*60) + print("Evo2 5UTR 序列质量评估工具") + print("="*60) + + # 读取生成的序列 + print(f"\n[1/4] 读取生成的序列:{args.generated}") + generated_seqs = read_fasta(args.generated) + print(f" 找到 {len(generated_seqs)} 条序列") + + # 读取参考序列(如果有) + reference_seqs = None + if args.reference: + print(f"\n[2/4] 读取参考序列:{args.reference}") + reference_seqs = read_fasta(args.reference) + print(f" 找到 {len(reference_seqs)} 条序列") + + # 评估序列 + print(f"\n[3/4] 评估序列质量...") + print("\n评估生成序列:") + generated_metrics = evaluate_sequences(generated_seqs, "generated") + + reference_metrics = None + if reference_seqs: + print("\n评估参考序列:") + reference_metrics = evaluate_sequences(reference_seqs, "reference") + + # 比较分布 + print(f"\n[4/4] 生成报告和评分...") + comparisons = compare_distributions(generated_metrics, reference_metrics) + quality_scores = generate_quality_scores(generated_metrics) + + # 生成报告 + generate_html_report( + comparisons=comparisons, + quality_scores=quality_scores, + generated_metrics=generated_metrics, + reference_metrics=reference_metrics, + output_path=args.output + ) + + # 输出 JSON(可选) + if args.output_json: + json_data = { + 'comparisons': comparisons, + 'quality_scores': quality_scores, + 'generated_metrics': [ + { + 'name': m.name, + 'length': 
m.length, + 'gc_content': m.gc_content, + 'mfe': m.mfe, + 'uorf_count': m.uorf_count, + 'has_kozak': m.has_kozak, + 'repeat_content': m.repeat_content + } + for m in generated_metrics + ] + } + with open(args.output_json, 'w') as f: + json.dump(json_data, f, indent=2) + print(f"JSON 结果已保存:{args.output_json}") + + # 打印摘要 + print("\n" + "="*60) + print("评估摘要") + print("="*60) + print(f"生成序列数:{len(generated_metrics)}") + print(f"平均长度:{np.mean([m.length for m in generated_metrics]):.1f} bp") + print(f"平均 GC 含量:{np.mean([m.gc_content for m in generated_metrics]):.1f}%") + print(f"Kozak 序列存在率:{sum(1 for m in generated_metrics if m.has_kozak)/len(generated_metrics)*100:.1f}%") + print(f"高质量序列比例:{sum(1 for s in quality_scores if s['quality_tier'] == '高')/len(quality_scores)*100:.1f}%") + print(f"\n详细报告:{args.output}") + print("="*60) + + +if __name__ == '__main__': + main() diff --git a/scripts/evaluate_intron_quality.py b/scripts/evaluate_intron_quality.py new file mode 100755 index 0000000..0ea3cd8 --- /dev/null +++ b/scripts/evaluate_intron_quality.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +""" +内含子生成质量评估工具 + +评估指标: +1. 剪接位点准确性(GT-AG 规则) +2. 长度分布对比 +3. GC 含量分析 +4. 分支点序列保守性 +5. 聚嘧啶 tract 存在率 +6. 
序列复杂度 + +使用方法: + python evaluate_intron_quality.py \ + --generated generated_introns.fasta \ + --reference natural_introns.fasta \ + --output intron_evaluation.html +""" + +import argparse +import json +from pathlib import Path +from typing import List, Dict, Tuple +from collections import defaultdict +import numpy as np + +try: + from Bio import SeqIO + from Bio.SeqUtils import gc_fraction + HAS_BIOPYTHON = True +except ImportError: + HAS_BIOPYTHON = False + print("警告:Biopython 未安装,部分功能受限。安装:pip install biopython") + + +# 剪接位点模式 +SPLICE_DONOR = 'GT' # 5' 剪接位点 +SPLICE_ACCEPTOR = 'AG' # 3' 剪接位点 + +# 分支点序列(脊椎动物共识) +BRANCH_POINT_CONSENSUS = 'YNCURAY' # Y=C/T, R=A/G, N=any + +# 聚嘧啶 tract 最小长度 +POLY_Y_MIN_LENGTH = 10 + + +def read_fasta(fasta_path: str) -> List[Tuple[str, str]]: + """读取 FASTA 文件""" + sequences = [] + + if HAS_BIOPYTHON: + for record in SeqIO.parse(fasta_path, 'fasta'): + sequences.append((record.id, str(record.seq).upper())) + else: + current_name = None + current_seq = [] + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_name: + sequences.append((current_name, ''.join(current_seq))) + current_name = line[1:] + current_seq = [] + else: + current_seq.append(line) + if current_name: + sequences.append((current_name, ''.join(current_seq))) + + return sequences + + +def check_splice_sites(sequence: str) -> Dict: + """检查剪接位点""" + result = { + 'has_donor': sequence.startswith(SPLICE_DONOR), + 'has_acceptor': sequence.endswith(SPLICE_ACCEPTOR), + 'donor_seq': sequence[:2] if len(sequence) >= 2 else '', + 'acceptor_seq': sequence[-2:] if len(sequence) >= 2 else '', + 'correct': sequence.startswith(SPLICE_DONOR) and sequence.endswith(SPLICE_ACCEPTOR) + } + return result + + +def find_branch_point(sequence: str, window_size: int = 50) -> Tuple[bool, str, int]: + """ + 寻找分支点序列 + + 分支点通常位于 3' 剪接位点上游 20-50bp + """ + # 在序列的 3' 端附近搜索 + search_region = sequence[-window_size:-2] if len(sequence) > 
window_size else sequence[:-2] + + # Simplified branch point pattern matching + # A real implementation should be more elaborate; this is a plain substring match. + # TACTRAY is expanded (R = A/G, Y = C/T) because str.find matches IUPAC codes literally; + # TACTAAC is already listed. + branch_patterns = ['CTAAC', 'TACTAAC', 'AAC', 'TACTAAT', 'TACTGAC', 'TACTGAT'] + + # Start of search_region within the full sequence (0 for short sequences) + region_start = max(0, len(sequence) - window_size) + + for pattern in branch_patterns: + pos = search_region.find(pattern) + if pos != -1: + return True, pattern, region_start + pos + + return False, '', -1 + + +def find_poly_pyrimidine_tract(sequence: str, min_length: int = POLY_Y_MIN_LENGTH) -> Tuple[bool, str]: + """ + Find the polypyrimidine tract + + Usually located between the branch point and the 3' splice site + """ + pyrimidines = 'CT' + + # Search near the 3' end (the tract usually lies upstream of the 3' splice site) + search_region = sequence[-100:-2] if len(sequence) > 100 else sequence[:-2] + + max_tract = '' + current_tract = '' + + for base in search_region: + if base in pyrimidines: + current_tract += base + else: + if len(current_tract) > len(max_tract): + max_tract = current_tract + current_tract = '' + + if len(current_tract) > len(max_tract): + max_tract = current_tract + + has_tract = len(max_tract) >= min_length + return has_tract, max_tract + + +def calculate_sequence_complexity(sequence: str) -> float: + """ + Compute sequence complexity (entropy of the base composition) + + Returns: 0-1, where 1 is the highest complexity + """ + if not sequence: + return 0.0 + + # Count mononucleotide frequencies + freq = defaultdict(int) + for base in sequence: + freq[base] += 1 + + # Compute the Shannon entropy + entropy = 0.0 + length = len(sequence) + for count in freq.values(): + if count > 0: + p = count / length + entropy -= p * np.log2(p) + + # Normalize (maximum entropy for DNA is log2(4) = 2) + max_entropy = np.log2(4) + normalized_entropy = entropy / max_entropy + + return normalized_entropy + + +def evaluate_introns(generated_seqs: List[Tuple[str, str]], + reference_seqs: List[Tuple[str, str]] = None) -> Dict: + """Evaluate intron quality""" + + results = { + 'count': len(generated_seqs), + 'splice_sites': { + 'correct_count': 0, + 'donor_correct_count': 0, + 'acceptor_correct_count': 0, + 'accuracy': 0.0 + }, + 'length': { + 'mean': 0.0, + 'std': 0.0, + 'min': 0, + 'max': 0, + 'median': 0 + }, + 'gc_content': { + 'mean': 0.0, + 'std': 0.0 + }, + 'branch_point': { + 'detected_count': 0, + 'detection_rate': 0.0 + }, + 
'poly_pyrimidine': { + 'detected_count': 0, + 'detection_rate': 0.0 + }, + 'complexity': { + 'mean': 0.0, + 'std': 0.0 + } + } + + lengths = [] + gc_contents = [] + complexities = [] + + for name, sequence in generated_seqs: + # 剪接位点 + splice = check_splice_sites(sequence) + if splice['correct']: + results['splice_sites']['correct_count'] += 1 + if splice['has_donor']: + results['splice_sites']['donor_correct_count'] += 1 + if splice['has_acceptor']: + results['splice_sites']['acceptor_correct_count'] += 1 + + # 长度 + lengths.append(len(sequence)) + + # GC 含量 + gc_contents.append(gc_fraction(sequence) * 100) + + # 分支点 + has_bp, _, _ = find_branch_point(sequence) + if has_bp: + results['branch_point']['detected_count'] += 1 + + # 聚嘧啶 tract + has_py, _ = find_poly_pyrimidine_tract(sequence) + if has_py: + results['poly_pyrimidine']['detected_count'] += 1 + + # 序列复杂度 + complexity = calculate_sequence_complexity(sequence) + complexities.append(complexity) + + # 计算统计 + n = len(generated_seqs) + if n > 0: + results['splice_sites']['accuracy'] = results['splice_sites']['correct_count'] / n * 100 + results['splice_sites']['donor_accuracy'] = results['splice_sites']['donor_correct_count'] / n * 100 + results['splice_sites']['acceptor_accuracy'] = results['splice_sites']['acceptor_correct_count'] / n * 100 + + results['length']['mean'] = np.mean(lengths) + results['length']['std'] = np.std(lengths) + results['length']['min'] = min(lengths) + results['length']['max'] = max(lengths) + results['length']['median'] = np.median(lengths) + + results['gc_content']['mean'] = np.mean(gc_contents) + results['gc_content']['std'] = np.std(gc_contents) + + results['branch_point']['detection_rate'] = results['branch_point']['detected_count'] / n * 100 + results['poly_pyrimidine']['detection_rate'] = results['poly_pyrimidine']['detected_count'] / n * 100 + + results['complexity']['mean'] = np.mean(complexities) + results['complexity']['std'] = np.std(complexities) + + # 与参考数据对比 + if 
reference_seqs: + ref_lengths = [len(seq) for _, seq in reference_seqs] + ref_gc = [gc_fraction(seq) * 100 for _, seq in reference_seqs] + + results['comparison'] = { + 'length_diff': results['length']['mean'] - np.mean(ref_lengths), + 'gc_diff': results['gc_content']['mean'] - np.mean(ref_gc), + 'reference_length_mean': np.mean(ref_lengths), + 'reference_gc_mean': np.mean(ref_gc) + } + + return results + + +def generate_html_report(results: Dict, output_path: str): + """生成 HTML 评估报告""" + + html = f""" + + + + + 内含子生成质量评估报告 + + + +
+

🧬 Intron Generation Quality Evaluation Report

+ +
+
{results['count']}
+
Sequences evaluated
+
+
+
{results['length']['mean']:.0f} bp
+
Mean length
+
+
+
{results['gc_content']['mean']:.1f}%
+
Mean GC content
+
+ +

📊 Splice Site Accuracy

+
Sequence name | Length (bp) | GC content (%) | MFE (kcal/mol) | uORF count | Kozak sequence | Polypyrimidine tract | Repeat content (%)
{m.name} | {m.length} | {m.gc_content:.1f} | {m.mfe if m.mfe is not None else 'N/A'} | {m.uorf_count} | {m.kozak_sequence if m.kozak_sequence else '-'} | {'✓' if m.poly_pyrimidine_tract else '-'} | {m.repeat_content:.1f}
+ + + + + + + + + + + + + + + + + + + + +
Site type | Correct count | Accuracy
5' splice site (GT) | {results['splice_sites']['donor_correct_count']} | {results['splice_sites']['donor_accuracy']:.1f}%
3' splice site (AG) | {results['splice_sites']['acceptor_correct_count']} | {results['splice_sites']['acceptor_accuracy']:.1f}%
Complete splice sites (GT...AG) | {results['splice_sites']['correct_count']} | {results['splice_sites']['accuracy']:.1f}%
+ + {(f'

✅ Excellent splice site quality

{results["splice_sites"]["accuracy"]:.1f}% of the sequences have correct GT-AG splice sites.

' + if results['splice_sites']['accuracy'] >= 90 else + f'

⚠️ Splice sites need improvement

Only {results["splice_sites"]["accuracy"]:.1f}% of the sequences have correct GT-AG splice sites.

' + if results['splice_sites']['accuracy'] >= 70 else + f'

❌ Poor splice site quality

Only {results["splice_sites"]["accuracy"]:.1f}% of the sequences have correct GT-AG splice sites. Consider adding training data or adjusting the model.

')} + +

📏 Length Distribution

+ + + + + + + +
Statistic
Mean | {results['length']['mean']:.1f} bp
Standard deviation | {results['length']['std']:.1f} bp
Minimum | {results['length']['min']} bp
Maximum | {results['length']['max']} bp
Median | {results['length']['median']:.1f} bp
+ +

🧬 GC Content

+ + + + +
Statistic
Mean | {results['gc_content']['mean']:.1f}%
Standard deviation | {results['gc_content']['std']:.1f}%
+ +

🔍 Regulatory Elements

+ + + + + + + + + + + + + + + + +
Element type | Detected count | Detection rate
Branch point sequence | {results['branch_point']['detected_count']} | {results['branch_point']['detection_rate']:.1f}%
Polypyrimidine tract | {results['poly_pyrimidine']['detected_count']} | {results['poly_pyrimidine']['detection_rate']:.1f}%
+ +

📈 Sequence Complexity

+ + + + +
Statistic
Mean | {results['complexity']['mean']:.3f}
Standard deviation | {results['complexity']['std']:.3f}
+

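Note: complexity is the Shannon entropy of the base composition, normalized to 0-1. A value of 1 means the four bases are equally frequent; 0 means the sequence consists of a single base. Because only base frequencies are counted, repeats with balanced composition still score high.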

+ + {(f'

📊 Comparison with Reference Data

Metric | Generated | Reference | Difference
Mean length | {results["length"]["mean"]:.1f} bp | {results["comparison"]["reference_length_mean"]:.1f} bp | {results["comparison"]["length_diff"]:+.1f} bp
GC content | {results["gc_content"]["mean"]:.1f}% | {results["comparison"]["reference_gc_mean"]:.1f}% | {results["comparison"]["gc_diff"]:+.1f}%
' + if 'comparison' in results else '')} +
+ + + """ + + with open(output_path, 'w', encoding='utf-8') as f: + f.write(html) + + print(f"HTML 报告已生成:{output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description='评估内含子生成质量' + ) + + parser.add_argument('--generated', type=str, required=True, + help='生成的内含子 FASTA 文件') + parser.add_argument('--reference', type=str, default=None, + help='参考内含子 FASTA 文件(用于对比)') + parser.add_argument('--output', type=str, default='intron_evaluation.html', + help='输出报告路径') + parser.add_argument('--output-json', type=str, default=None, + help='输出 JSON 格式结果') + + args = parser.parse_args() + + print("="*60) + print("内含子生成质量评估工具") + print("="*60) + + print(f"\n[1/3] 读取生成的序列:{args.generated}") + generated_seqs = read_fasta(args.generated) + print(f" 找到 {len(generated_seqs)} 条序列") + + reference_seqs = None + if args.reference: + print(f"\n[2/3] 读取参考序列:{args.reference}") + reference_seqs = read_fasta(args.reference) + print(f" 找到 {len(reference_seqs)} 条序列") + + print(f"\n[3/3] 评估质量...") + results = evaluate_introns(generated_seqs, reference_seqs) + + generate_html_report(results, args.output) + + if args.output_json: + with open(args.output_json, 'w') as f: + json.dump(results, f, indent=2) + print(f"JSON 结果已保存:{args.output_json}") + + # 打印摘要 + print("\n" + "="*60) + print("评估摘要") + print("="*60) + print(f"评估序列数:{results['count']}") + print(f"平均长度:{results['length']['mean']:.1f} bp") + print(f"平均 GC 含量:{results['gc_content']['mean']:.1f}%") + print(f"剪接位点准确率:{results['splice_sites']['accuracy']:.1f}%") + print(f"分支点检出率:{results['branch_point']['detection_rate']:.1f}%") + print(f"聚嘧啶 tract 检出率:{results['poly_pyrimidine']['detection_rate']:.1f}%") + print(f"序列复杂度:{results['complexity']['mean']:.3f}") + print(f"\n详细报告:{args.output}") + print("="*60) + + +if __name__ == '__main__': + main() diff --git a/scripts/generate_5utr.py b/scripts/generate_5utr.py new file mode 100755 index 0000000..a0e15bb --- /dev/null +++ b/scripts/generate_5utr.py @@ -0,0 +1,215 @@ 
+#!/usr/bin/env python3 +""" +Evo2 CDS→5UTR 生成脚本 + +使用训练好的模型,以 CDS 序列为 prompt 生成 5UTR 序列 + +使用方法: + # 单个序列生成 + python generate_5utr.py \ + --model models/cds_5utr_model.pt \ + --cds "ATGCGT..." \ + --output generated_5utr.txt + + # 批量生成 + python generate_5utr.py \ + --model models/cds_5utr_model.pt \ + --input cds_sequences.fasta \ + --output generated_5utr.fasta \ + --n_tokens 500 \ + --temperature 1.0 +""" + +import argparse +import csv +from pathlib import Path +from typing import List, Optional +import numpy as np +import torch + +from evo2 import Evo2 + + +def read_fasta(fasta_path: str) -> List[tuple]: + """读取 FASTA 文件,返回 [(name, sequence), ...]""" + sequences = [] + current_name = None + current_seq = [] + + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_name: + sequences.append((current_name, ''.join(current_seq))) + current_name = line[1:] + current_seq = [] + else: + current_seq.append(line.upper()) + + if current_name: + sequences.append((current_name, ''.join(current_seq))) + + return sequences + + +def write_fasta(sequences: List[tuple], output_path: str) -> None: + """写入 FASTA 文件""" + with open(output_path, 'w') as f: + for name, seq in sequences: + f.write(f">{name}\n") + for i in range(0, len(seq), 80): + f.write(seq[i:i+80] + '\n') + + +def generate_5utr(model: Evo2, + cds_sequences: List[str], + n_tokens: int = 500, + temperature: float = 1.0, + top_k: int = 4, + seed: Optional[int] = None) -> List[str]: + """ + 生成 5UTR 序列 + + Args: + model: Evo2 模型 + cds_sequences: CDS 序列列表 + n_tokens: 生成的 token 数量(碱基数) + temperature: 采样温度(越高越随机) + top_k: top-k 采样 + seed: 随机种子 + + Returns: + 生成的 5UTR 序列列表 + """ + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + print(f"初始化推理参数...") + generations = model.generate( + cds_sequences, + n_tokens=n_tokens, + temperature=temperature, + top_k=top_k, + ) + + return generations.sequences + + +def main(): + parser = 
argparse.ArgumentParser( + description='使用 Evo2 生成 5UTR 序列' + ) + + parser.add_argument('--model', type=str, required=True, + help='训练好的模型路径 (.pt 文件) 或模型名称 (如 evo2_7b)') + parser.add_argument('--cds', type=str, default=None, + help='单个 CDS 序列(用于快速测试)') + parser.add_argument('--input', type=str, default=None, + help='输入 FASTA 文件路径(包含多个 CDS 序列)') + parser.add_argument('--output', type=str, required=True, + help='输出文件路径') + parser.add_argument('--n_tokens', type=int, default=500, + help='生成的 token 数量(碱基数)') + parser.add_argument('--temperature', type=float, default=1.0, + help='采样温度(0.1-2.0,越高越随机)') + parser.add_argument('--top_k', type=int, default=4, + help='top-k 采样参数') + parser.add_argument('--seed', type=int, default=42, + help='随机种子') + parser.add_argument('--output_format', type=str, default='fasta', + choices=['fasta', 'txt', 'csv'], + help='输出格式') + parser.add_argument('--include_prompt', action='store_true', + help='输出中包含 CDS prompt 序列') + + args = parser.parse_args() + + # 准备 CDS 序列 + cds_sequences = [] + names = [] + + if args.cds: + # 单个序列模式 + cds_sequences = [args.cds.upper()] + names = ['prompt_1'] + print(f"使用单个 CDS 序列(长度:{len(args.cds)} bp)") + + elif args.input: + # 批量模式 + print(f"[1/3] 读取输入 FASTA: {args.input}") + data = read_fasta(args.input) + cds_sequences = [seq.upper() for _, seq in data] + names = [name for name, _ in data] + print(f" 找到 {len(cds_sequences)} 条 CDS 序列") + + else: + print("错误:必须指定 --cds 或 --input") + exit(1) + + # 加载模型 + print(f"[2/3] 加载模型:{args.model}") + if args.model.endswith('.pt'): + # 本地模型文件 + model = Evo2('evo2_7b', local_path=args.model) + else: + # HuggingFace 模型名称 + model = Evo2(args.model) + print(" 模型加载完成") + + # 生成 5UTR + print(f"[3/3] 生成 5UTR 序列...") + print(f" 参数:n_tokens={args.n_tokens}, temperature={args.temperature}, top_k={args.top_k}") + + generated_seqs = generate_5utr( + model=model, + cds_sequences=cds_sequences, + n_tokens=args.n_tokens, + temperature=args.temperature, + top_k=args.top_k, + seed=args.seed 
+ ) + + # 写入输出 + print(f"写入输出文件:{args.output}") + + if args.output_format == 'fasta': + if args.include_prompt: + output_data = [ + (f"{name}|CDS:{len(cds)}|5UTR:{len(utr)}", cds + utr) + for name, cds, utr in zip(names, cds_sequences, generated_seqs) + ] + else: + output_data = [ + (f"{name}|5UTR_generated", utr) + for name, utr in zip(names, generated_seqs) + ] + write_fasta(output_data, args.output) + + elif args.output_format == 'txt': + with open(args.output, 'w') as f: + for i, (name, cds, utr) in enumerate(zip(names, cds_sequences, generated_seqs)): + f.write(f">{name}\n") + if args.include_prompt: + f.write(f"CDS: {cds}\n") + f.write(f"5UTR: {utr}\n\n") + + elif args.output_format == 'csv': + with open(args.output, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['name', 'cds_sequence', '5utr_sequence', 'cds_length', '5utr_length']) + for name, cds, utr in zip(names, cds_sequences, generated_seqs): + writer.writerow([name, cds, utr, len(cds), len(utr)]) + + print(f"完成!生成了 {len(generated_seqs)} 条 5UTR 序列") + + # 打印统计信息 + print("\n生成统计:") + print(f" 平均 5UTR 长度:{sum(len(s) for s in generated_seqs)/len(generated_seqs):.1f} bp") + print(f" 最短 5UTR: {min(len(s) for s in generated_seqs)} bp") + print(f" 最长 5UTR: {max(len(s) for s in generated_seqs)} bp") + + +if __name__ == '__main__': + main() diff --git a/scripts/generate_intron.py b/scripts/generate_intron.py new file mode 100755 index 0000000..2c6a628 --- /dev/null +++ b/scripts/generate_intron.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +""" +Evo2 内含子生成脚本 + +支持: +- 给定外显子生成内含子 +- 功能导向生成(高/低 IME 活性) +- 长度控制生成 + +使用方法: + # 基础生成 + python generate_intron.py \ + --model intron_model.pt \ + --exon5 "ATG..." \ + --exon3 "TAA..." \ + --output generated_introns.fasta + + # 功能导向生成 + python generate_intron.py \ + --model intron_model.pt \ + --exon5 "ATG..." \ + --exon3 "TAA..." 
\ + --task intron_func_high \ + --output high_ime_introns.fasta +""" + +import argparse +from pathlib import Path +from typing import List, Tuple, Optional +import numpy as np +import torch + +from evo2 import Evo2 + + +# Special tokens (must match those used in training) +TASK_MARKERS = { + 'intron': '', + 'intron_func_high': '', + 'intron_func_low': '' +} + +SEP_EXON_INTRON = 'NNNNNN' +SEP_INTRON_EXON = 'NNNNNNN' +END_TOKEN = '' + +# Default exon sequences (used when the user provides none) +DEFAULT_EXON_5 = "ATGGCTAGCTACGGTACGGATCCGCTAGCATCGATCGATCGATCGTAGCTAGCTAG" +DEFAULT_EXON_3 = "GGATCCGGATCCGGATTAGCTAGCTAGCTAGCTAGCTAGCATCGATCGATCGTAA" + + +def read_fasta(fasta_path: str) -> List[Tuple[str, str]]: + """Read a FASTA file.""" + sequences = [] + current_name = None + current_seq = [] + + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_name: + sequences.append((current_name, ''.join(current_seq))) + current_name = line[1:] + current_seq = [] + else: + current_seq.append(line.upper()) + + if current_name: + sequences.append((current_name, ''.join(current_seq))) + + return sequences + + +def write_fasta(sequences: List[Tuple[str, str]], output_path: str) -> None: + """Write a FASTA file.""" + with open(output_path, 'w') as f: + for name, seq in sequences: + f.write(f">{name}\n") + for i in range(0, len(seq), 80): + f.write(seq[i:i+80] + '\n') + + +def parse_generated_sequence(generated: str) -> dict: + """ + Parse a generated sequence and extract the intron. + + Returns: + dict: {'exon5': str, 'intron': str, 'exon3': str} + """ + result = {'exon5': '', 'intron': '', 'exon3': ''} + + # Strip the task marker and end token (empty tokens are skipped; slicing with [:-0] would discard the whole string) + for marker in TASK_MARKERS.values(): + if marker and generated.startswith(marker): + generated = generated[len(marker):] + + if END_TOKEN and generated.endswith(END_TOKEN): + generated = generated[:-len(END_TOKEN)] + + # Extract the intron (between the two separators); split at the first match only, because 'NNNNNN' also matches inside 'NNNNNNN' + if SEP_EXON_INTRON in generated and SEP_INTRON_EXON in generated: + parts = generated.split(SEP_EXON_INTRON, 1) + result['exon5'] = parts[0] + remainder = parts[1] if len(parts) > 1 else '' + + if SEP_INTRON_EXON in remainder: +
sub_parts = remainder.split(SEP_INTRON_EXON) + result['intron'] = sub_parts[0] + result['exon3'] = sub_parts[1] if len(sub_parts) > 1 else '' + + return result + + +def generate_intron(model: Evo2, + exon5_seqs: List[str], + exon3_seqs: List[str], + task: str = 'intron', + target_length: Optional[int] = None, + temperature: float = 1.0, + top_k: int = 4, + seed: Optional[int] = None) -> List[str]: + """ + 生成内含子序列 + + Args: + model: Evo2 模型 + exon5_seqs: 5' 外显子序列列表 + exon3_seqs: 3' 外显子序列列表 + task: 任务类型 ('intron', 'intron_func_high', 'intron_func_low') + target_length: 目标内含子长度(可选) + temperature: 采样温度 + top_k: top-k 采样 + seed: 随机种子 + + Returns: + 生成的完整序列(包含特殊 token) + """ + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # 构建 prompt + prompts = [] + for exon5, exon3 in zip(exon5_seqs, exon3_seqs): + if task == 'intron': + # 基础格式:Exon5 + prompt = f"{TASK_MARKERS['intron']}{exon5}" + + elif task == 'intron_func_high': + # 高 IME 活性:Exon5 + prompt = f"{TASK_MARKERS['intron_func_high']}{exon5}" + + elif task == 'intron_func_low': + # 低 IME 活性:Exon5 + prompt = f"{TASK_MARKERS['intron_func_low']}{exon5}" + + else: + prompt = f"{TASK_MARKERS['intron']}{exon5}" + + # 添加长度控制(如果指定) + if target_length: + prompt = f"{exon5}" + + prompts.append(prompt) + print(f" Prompt: {prompt[:60]}...") + + # 计算需要的 token 数 + # 内含子 + 分隔符 + 3' 外显子 + END + avg_intron_length = target_length if target_length else 200 + n_tokens = avg_intron_length + len(SEP_EXON_INTRON) + len(SEP_INTRON_EXON) + \ + max(len(e) for e in exon3_seqs) + len(END_TOKEN) + 50 + + print(f"\n生成参数:n_tokens={n_tokens}, temperature={temperature}, top_k={top_k}") + + generations = model.generate( + prompts, + n_tokens=n_tokens, + temperature=temperature, + top_k=top_k, + ) + + return generations.sequences + + +def main(): + parser = argparse.ArgumentParser( + description='使用 Evo2 生成内含子序列' + ) + + parser.add_argument('--model', type=str, required=True, + help='训练好的模型路径 (.pt 文件) 或模型名称') + 
parser.add_argument('--exon5', type=str, default=None, + help='5\' 外显子序列(单个)') + parser.add_argument('--exon3', type=str, default=None, + help='3\' 外显子序列(单个)') + parser.add_argument('--input', type=str, default=None, + help='输入 FASTA 文件(包含多个外显子对)') + parser.add_argument('--task', type=str, default='intron', + choices=['intron', 'intron_func_high', 'intron_func_low'], + help='生成任务类型') + parser.add_argument('--target_length', type=int, default=None, + help='目标内含子长度(bp)') + parser.add_argument('--output', type=str, required=True, + help='输出文件路径') + parser.add_argument('--temperature', type=float, default=1.0, + help='采样温度') + parser.add_argument('--top_k', type=int, default=4, + help='top-k 采样参数') + parser.add_argument('--seed', type=int, default=42, + help='随机种子') + parser.add_argument('--output_format', type=str, default='fasta', + choices=['fasta', 'csv'], + help='输出格式') + + args = parser.parse_args() + + # 准备外显子序列 + exon5_seqs = [] + exon3_seqs = [] + names = [] + + if args.exon5 and args.exon3: + # 单个外显子对 + exon5_seqs = [args.exon5.upper()] + exon3_seqs = [args.exon3.upper()] + names = ['intron_1'] + print(f"使用单个外显子对") + print(f" 5' 外显子长度:{len(args.exon5)} bp") + print(f" 3' 外显子长度:{len(args.exon3)} bp") + + elif args.input: + # 批量模式 + print(f"[1/3] 读取输入 FASTA: {args.input}") + data = read_fasta(args.input) + + for name, seq in data: + # 假设 FASTA 格式:>name|exon5:ACGT...|exon3:TGCA... + # 或者:>name\nACGT...NNNNNN...NNNNNNN...TGCA... 
+ if '|' in name and 'exon5:' in name and 'exon3:' in name: + parts = name.split('|') + exon5 = parts[1].split(':')[1] + exon3 = parts[2].split(':')[1] + exon5_seqs.append(exon5.upper()) + exon3_seqs.append(exon3.upper()) + names.append(name.split('|')[0]) + else: + # 尝试从序列中分割 + if SEP_EXON_INTRON in seq: + parts = seq.split(SEP_EXON_INTRON) + if len(parts) >= 2: + exon5_seqs.append(parts[0].replace(TASK_MARKERS.get(args.task, ''), '').upper()) + remainder = parts[1] + if SEP_INTRON_EXON in remainder: + sub_parts = remainder.split(SEP_INTRON_EXON) + exon3_seqs.append(sub_parts[1].replace(END_TOKEN, '').upper()) + names.append(name) + + print(f" 找到 {len(exon5_seqs)} 个外显子对") + + else: + # 使用默认外显子 + print("未提供外显子序列,使用默认序列") + exon5_seqs = [DEFAULT_EXON_5] + exon3_seqs = [DEFAULT_EXON_3] + names = ['default_intron'] + + # 加载模型 + print(f"\n[2/3] 加载模型:{args.model}") + if args.model.endswith('.pt'): + model = Evo2('evo2_7b', local_path=args.model) + else: + model = Evo2(args.model) + print(" 模型加载完成") + + # 生成内含子 + print(f"\n[3/3] 生成内含子(任务:{args.task})...") + + generated_seqs = generate_intron( + model=model, + exon5_seqs=exon5_seqs, + exon3_seqs=exon3_seqs, + task=args.task, + target_length=args.target_length, + temperature=args.temperature, + top_k=args.top_k, + seed=args.seed + ) + + # 解析生成的序列 + print("\n解析生成的序列...") + results = [] + for name, generated in zip(names, generated_seqs): + parsed = parse_generated_sequence(generated) + results.append((name, parsed, generated)) + + print(f"\n{name}:") + print(f" 5' 外显子:{len(parsed['exon5'])} bp") + print(f" 内含子:{len(parsed['intron'])} bp") + print(f" 3' 外显子:{len(parsed['exon3'])} bp") + + # 检查剪接位点 + intron = parsed['intron'] + if intron.startswith('GT') and intron.endswith('AG'): + print(f" ✓ 剪接位点正确 (GT...AG)") + else: + print(f" ⚠ 剪接位点异常:{intron[:2]}...{intron[-2:]}") + + # 写入输出 + print(f"\n写入输出文件:{args.output}") + + with open(args.output, 'w') as f: + if args.output_format == 'fasta': + for name, parsed, full_seq in 
results: + # 输出纯内含子序列 + f.write(f">{name}|intron|length:{len(parsed['intron'])}\n") + for i in range(0, len(parsed['intron']), 80): + f.write(parsed['intron'][i:i+80] + '\n') + + elif args.output_format == 'csv': + import csv + with open(args.output, 'w', newline='') as f: + writer = csv.writer(f) + writer.writerow(['name', 'intron_sequence', 'intron_length', 'has_gt_ag']) + for name, parsed, _ in results: + has_gt_ag = parsed['intron'].startswith('GT') and parsed['intron'].endswith('AG') + writer.writerow([name, parsed['intron'], len(parsed['intron']), has_gt_ag]) + + print(f"完成!生成了 {len(results)} 个内含子") + + # 打印统计信息 + print("\n生成统计:") + intron_lengths = [len(parse_generated_sequence(s)['intron']) for s in generated_seqs] + gt_ag_count = sum(1 for s in generated_seqs + if parse_generated_sequence(s)['intron'].startswith('GT') and + parse_generated_sequence(s)['intron'].endswith('AG')) + + print(f" 平均内含子长度:{np.mean(intron_lengths):.1f} bp") + print(f" 最短内含子:{min(intron_lengths)} bp") + print(f" 最长内含子:{max(intron_lengths)} bp") + print(f" GT-AG 正确率:{gt_ag_count/len(results)*100:.1f}%") + + +if __name__ == '__main__': + main() diff --git a/scripts/generate_utr.py b/scripts/generate_utr.py new file mode 100755 index 0000000..1411505 --- /dev/null +++ b/scripts/generate_utr.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python3 +""" +Evo2 UTR 生成脚本 + +支持: +- 只生成 5UTR +- 只生成 3UTR +- 同时生成 5UTR 和 3UTR + +使用方法: + # 只生成 5UTR + python generate_utr.py --model model.pt --cds "ATG..." --task 5UTR --output utr5.fasta + + # 只生成 3UTR + python generate_utr.py --model model.pt --cds "ATG..." --task 3UTR --output utr3.fasta + + # 同时生成 + python generate_utr.py --model model.pt --cds "ATG..." 
--task both --output full_utr.fasta +""" + +import argparse +from pathlib import Path +from typing import List, Tuple, Optional +import numpy as np +import torch + +from evo2 import Evo2 + + +# ============ 特殊 token 定义(与训练时一致) ============ +TASK_MARKERS = { + '5UTR': '<5UTR>', + '3UTR': '<3UTR>', + 'both': '' +} + +SEP_5UTR = 'NNNNN' # 5UTR 分隔符 +SEP_3UTR = 'NNNNNNN' # 3UTR 分隔符 +END_TOKEN = '' + + +def read_fasta(fasta_path: str) -> List[Tuple[str, str]]: + """读取 FASTA 文件""" + sequences = [] + current_name = None + current_seq = [] + + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_name: + sequences.append((current_name, ''.join(current_seq))) + current_name = line[1:] + current_seq = [] + else: + current_seq.append(line.upper()) + + if current_name: + sequences.append((current_name, ''.join(current_seq))) + + return sequences + + +def write_fasta(sequences: List[Tuple[str, str]], output_path: str) -> None: + """写入 FASTA 文件""" + with open(output_path, 'w') as f: + for name, seq in sequences: + f.write(f">{name}\n") + for i in range(0, len(seq), 80): + f.write(seq[i:i+80] + '\n') + + +def parse_generated_sequence(generated: str, task: str) -> dict: + """ + 解析生成的序列,提取 UTR 区域 + + Returns: + dict: {'utr5': str, 'cds': str, 'utr3': str} + """ + result = {'utr5': '', 'cds': '', 'utr3': ''} + + # 移除任务标记和结束标记 + for marker in TASK_MARKERS.values(): + if generated.startswith(marker): + generated = generated[len(marker):] + + if generated.endswith(END_TOKEN): + generated = generated[:-len(END_TOKEN)] + + if task == '5UTR': + # 格式:5UTR 序列 [SEP5]CDS 序列 + if SEP_5UTR in generated: + parts = generated.split(SEP_5UTR) + result['utr5'] = parts[0] + result['cds'] = parts[1] if len(parts) > 1 else '' + + elif task == '3UTR': + # 格式:CDS 序列 [SEP3]3UTR 序列 + if SEP_3UTR in generated: + parts = generated.split(SEP_3UTR) + result['cds'] = parts[0] + result['utr3'] = parts[1] if len(parts) > 1 else '' + + elif task == 'both': + 
# Format: 5UTR sequence [SEP5] CDS sequence [SEP3] 3UTR sequence; split at the first SEP_5UTR match only, because 'NNNNN' also matches inside 'NNNNNNN' + if SEP_5UTR in generated and SEP_3UTR in generated: + parts = generated.split(SEP_5UTR, 1) + result['utr5'] = parts[0] + remainder = parts[1] if len(parts) > 1 else '' + if SEP_3UTR in remainder: + sub_parts = remainder.split(SEP_3UTR, 1) + result['cds'] = sub_parts[0] + result['utr3'] = sub_parts[1] if len(sub_parts) > 1 else '' + + return result + + +def generate_utr(model: Evo2, + cds_sequences: List[str], + task: str = '5UTR', + n_tokens_5utr: int = 300, + n_tokens_3utr: int = 300, + temperature: float = 1.0, + top_k: int = 4, + seed: Optional[int] = None) -> List[str]: + """ + Generate UTR sequences. + + Args: + model: Evo2 model + cds_sequences: list of CDS sequences + task: task type ('5UTR', '3UTR', 'both') + n_tokens_5utr: number of tokens to generate for the 5UTR + n_tokens_3utr: number of tokens to generate for the 3UTR + temperature: sampling temperature + top_k: top-k sampling + seed: random seed + + Returns: + The full generated sequences (including special tokens) + """ + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Build the prompt for the selected task + prompts = [] + n_tokens = n_tokens_5utr if task in ['5UTR', 'both'] else n_tokens_3utr + + for cds in cds_sequences: + if task == '5UTR': + # 5UTR generation: the model has to produce 5UTR+CDS, so the prompt is the task marker alone + # Training format: <5UTR>5UTR[SEP5]CDS + # Inference: give <5UTR>, the model generates the rest + prompt = TASK_MARKERS['5UTR'] + + elif task == '3UTR': + # 3UTR generation: the training sample is CDS+3UTR + # Training format: <3UTR>CDS[SEP3]3UTR + # Inference: give <3UTR>+CDS, the model generates the 3UTR + prompt = TASK_MARKERS['3UTR'] + cds + + elif task == 'both': + # Joint generation: the model has to produce 5UTR+CDS+3UTR + # Training format: 5UTR[SEP5]CDS[SEP3]3UTR + # Inference: give the task marker plus the CDS, the model generates the 5UTR and 3UTR + # Or give the task marker alone and let the model generate everything + prompt = TASK_MARKERS['both'] + + prompts.append(prompt) + print(f"Prompt: {prompt[:50]}...
(CDS 长度:{len(cds)})") + + # 计算需要的 token 数 + if task == '5UTR': + # 需要生成:5UTR + SEP + CDS + END + n_tokens = n_tokens_5utr + len(SEP_5UTR) + max(len(cds) for cds in cds_sequences) + len(END_TOKEN) + elif task == '3UTR': + # 需要生成:3UTR + SEP + END + n_tokens = n_tokens_3utr + len(SEP_3UTR) + len(END_TOKEN) + elif task == 'both': + # 需要生成:5UTR + SEP5 + CDS + SEP3 + 3UTR + END + n_tokens = n_tokens_5utr + len(SEP_5UTR) + max(len(cds) for cds in cds_sequences) + len(SEP_3UTR) + n_tokens_3utr + len(END_TOKEN) + + print(f"生成参数:n_tokens={n_tokens}, temperature={temperature}, top_k={top_k}") + + generations = model.generate( + prompts, + n_tokens=n_tokens, + temperature=temperature, + top_k=top_k, + ) + + return generations.sequences + + +def main(): + parser = argparse.ArgumentParser( + description='使用 Evo2 生成 UTR 序列(5UTR、3UTR 或同时生成)' + ) + + parser.add_argument('--model', type=str, required=True, + help='训练好的模型路径 (.pt 文件) 或模型名称') + parser.add_argument('--cds', type=str, default=None, + help='单个 CDS 序列(用于快速测试)') + parser.add_argument('--input', type=str, default=None, + help='输入 FASTA 文件路径(包含多个 CDS 序列)') + parser.add_argument('--task', type=str, required=True, + choices=['5UTR', '3UTR', 'both'], + help='生成任务类型') + parser.add_argument('--output', type=str, required=True, + help='输出文件路径') + parser.add_argument('--n_tokens_5utr', type=int, default=300, + help='5UTR 生成的 token 数量') + parser.add_argument('--n_tokens_3utr', type=int, default=300, + help='3UTR 生成的 token 数量') + parser.add_argument('--temperature', type=float, default=1.0, + help='采样温度') + parser.add_argument('--top_k', type=int, default=4, + help='top-k 采样参数') + parser.add_argument('--seed', type=int, default=42, + help='随机种子') + parser.add_argument('--output_format', type=str, default='fasta', + choices=['fasta', 'txt', 'csv'], + help='输出格式') + parser.add_argument('--include_full', action='store_true', + help='输出中包含完整序列(含特殊 token)') + + args = parser.parse_args() + + # 准备 CDS 序列 + cds_sequences = [] + names = [] + 
+ if args.cds: + cds_sequences = [args.cds.upper()] + names = ['prompt_1'] + print(f"使用单个 CDS 序列(长度:{len(args.cds)} bp)") + + elif args.input: + print(f"[1/3] 读取输入 FASTA: {args.input}") + data = read_fasta(args.input) + cds_sequences = [seq.upper() for _, seq in data] + names = [name for name, _ in data] + print(f" 找到 {len(cds_sequences)} 条 CDS 序列") + + else: + print("错误:必须指定 --cds 或 --input") + exit(1) + + # 加载模型 + print(f"[2/3] 加载模型:{args.model}") + if args.model.endswith('.pt'): + model = Evo2('evo2_7b', local_path=args.model) + else: + model = Evo2(args.model) + print(" 模型加载完成") + + # 生成 UTR + print(f"[3/3] 生成 UTR(任务:{args.task})...") + + generated_seqs = generate_utr( + model=model, + cds_sequences=cds_sequences, + task=args.task, + n_tokens_5utr=args.n_tokens_5utr, + n_tokens_3utr=args.n_tokens_3utr, + temperature=args.temperature, + top_k=args.top_k, + seed=args.seed + ) + + # 解析生成的序列 + print("\n解析生成的序列...") + results = [] + for name, generated in zip(names, generated_seqs): + parsed = parse_generated_sequence(generated, args.task) + results.append((name, parsed, generated)) + + print(f"\n{name}:") + if args.task == '5UTR': + print(f" 5UTR 长度:{len(parsed['utr5'])} bp") + elif args.task == '3UTR': + print(f" 3UTR 长度:{len(parsed['utr3'])} bp") + elif args.task == 'both': + print(f" 5UTR 长度:{len(parsed['utr5'])} bp") + print(f" 3UTR 长度:{len(parsed['utr3'])} bp") + + # 写入输出 + print(f"\n写入输出文件:{args.output}") + + with open(args.output, 'w') as f: + if args.output_format == 'fasta': + for name, parsed, full_seq in results: + if args.task == '5UTR': + # 输出 5UTR 序列 + f.write(f">{name}|5UTR|length:{len(parsed['utr5'])}\n") + for i in range(0, len(parsed['utr5']), 80): + f.write(parsed['utr5'][i:i+80] + '\n') + + if args.include_full: + f.write(f">{name}|full_sequence\n{full_seq}\n") + + elif args.task == '3UTR': + f.write(f">{name}|3UTR|length:{len(parsed['utr3'])}\n") + for i in range(0, len(parsed['utr3']), 80): + f.write(parsed['utr3'][i:i+80] + '\n') + + if 
args.include_full: + f.write(f">{name}|full_sequence\n{full_seq}\n") + + elif args.task == 'both': + f.write(f">{name}|5UTR|length:{len(parsed['utr5'])}\n") + for i in range(0, len(parsed['utr5']), 80): + f.write(parsed['utr5'][i:i+80] + '\n') + + f.write(f">{name}|3UTR|length:{len(parsed['utr3'])}\n") + for i in range(0, len(parsed['utr3']), 80): + f.write(parsed['utr3'][i:i+80] + '\n') + + if args.include_full: + f.write(f">{name}|full_sequence\n{full_seq}\n") + + elif args.output_format == 'csv': + import csv + with open(args.output, 'w', newline='') as f: + writer = csv.writer(f) + if args.task == 'both': + writer.writerow(['name', '5utr_sequence', '5utr_length', '3utr_sequence', '3utr_length']) + for name, parsed, _ in results: + writer.writerow([name, parsed['utr5'], len(parsed['utr5']), + parsed['utr3'], len(parsed['utr3'])]) + elif args.task == '5UTR': + writer.writerow(['name', '5utr_sequence', '5utr_length']) + for name, parsed, _ in results: + writer.writerow([name, parsed['utr5'], len(parsed['utr5'])]) + elif args.task == '3UTR': + writer.writerow(['name', '3utr_sequence', '3utr_length']) + for name, parsed, _ in results: + writer.writerow([name, parsed['utr3'], len(parsed['utr3'])]) + + print(f"完成!生成了 {len(results)} 条 UTR 序列") + + # 打印统计信息 + print("\n生成统计:") + if args.task == '5UTR': + utr_lengths = [len(parse_generated_sequence(s, '5UTR')['utr5']) for s in generated_seqs] + print(f" 平均 5UTR 长度:{np.mean(utr_lengths):.1f} bp") + elif args.task == '3UTR': + utr_lengths = [len(parse_generated_sequence(s, '3UTR')['utr3']) for s in generated_seqs] + print(f" 平均 3UTR 长度:{np.mean(utr_lengths):.1f} bp") + elif args.task == 'both': + utr5_lengths = [len(parse_generated_sequence(s, 'both')['utr5']) for s in generated_seqs] + utr3_lengths = [len(parse_generated_sequence(s, 'both')['utr3']) for s in generated_seqs] + print(f" 平均 5UTR 长度:{np.mean(utr5_lengths):.1f} bp") + print(f" 平均 3UTR 长度:{np.mean(utr3_lengths):.1f} bp") + + +if __name__ == '__main__': + main() 
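Review note on `parse_generated_sequence` in `scripts/generate_utr.py`: the `both` layout is delimited by two N-run separators, and the shorter one (`NNNNN`) is a substring of the longer one (`NNNNNNN`), so where the string is cut matters. The sketch below is a minimal, standalone illustration under that assumption; `split_both` and the toy sequence are hypothetical names made up for this note, not part of the script.

```python
# Separator values copied from generate_utr.py; everything else is illustrative.
SEP_5UTR = "NNNNN"
SEP_3UTR = "NNNNNNN"


def split_both(generated: str) -> dict:
    """Split '5UTR + SEP_5UTR + CDS + SEP_3UTR + 3UTR' into its three parts."""
    result = {"utr5": "", "cds": "", "utr3": ""}
    # Cut only at the first SEP_5UTR match. An unbounded split would also cut
    # inside SEP_3UTR, because "NNNNN" is a substring of "NNNNNNN".
    utr5, sep5, remainder = generated.partition(SEP_5UTR)
    if not sep5:
        return result
    cds, sep3, utr3 = remainder.partition(SEP_3UTR)
    if not sep3:
        return result
    result.update(utr5=utr5, cds=cds, utr3=utr3)
    return result


# Hypothetical toy sequence: 5UTR, CDS and 3UTR joined by the two separators.
toy = "GGCC" + SEP_5UTR + "ATGAAATAA" + SEP_3UTR + "TTTT"

print(toy.split(SEP_5UTR))  # unbounded split on the shorter separator
print(split_both(toy))      # bounded split, one cut per separator
```

Running it shows the unbounded `str.split` returning three pieces, with the 3UTR detached from the CDS, while `partition` (equivalent to `split(sep, 1)`) keeps the remainder intact for the second cut.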
diff --git a/scripts/prepare_cds_5utr_data.py b/scripts/prepare_cds_5utr_data.py new file mode 100755 index 0000000..761968f --- /dev/null +++ b/scripts/prepare_cds_5utr_data.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +""" +Evo2 CDS→5UTR data preparation script + +Features: +1. Extract CDS and 5UTR sequences from FASTA/GTF files +2. Generate the FASTA file needed for training +3. Report the sequence length distribution to help choose seq_length + +Usage: + python prepare_cds_5utr_data.py \ + --genome genome.fasta \ + --annotation annotation.gtf \ + --output cds_5utr_training.fasta \ + --upstream_length 500 +""" + +import argparse +import csv +from pathlib import Path +from typing import List, Tuple, Dict +from collections import defaultdict + + +def parse_fasta(fasta_path: str) -> Dict[str, str]: + """Parse a FASTA file.""" + sequences = {} + current_name = None + current_seq = [] + + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_name: + sequences[current_name] = ''.join(current_seq) + current_name = line[1:].split()[0] + current_seq = [] + else: + current_seq.append(line.upper()) + + if current_name: + sequences[current_name] = ''.join(current_seq) + + return sequences + + +def parse_gtf(gtf_path: str) -> List[Dict]: + """Parse a GTF file.""" + features = [] + + with open(gtf_path, 'r') as f: + for line in f: + if line.startswith('#'): + continue + + parts = line.strip().split('\t') + if len(parts) < 9: + continue + + feature = { + 'seqid': parts[0], + 'source': parts[1], + 'type': parts[2], + 'start': int(parts[3]), + 'end': int(parts[4]), + 'score': parts[5], + 'strand': parts[6], + 'phase': parts[7], + 'attributes': parts[8] + } + + # Parse the attributes column + attr_dict = {} + for attr in parts[8].split(';'): + attr = attr.strip() + if ' ' in attr: + key, value = attr.split(' ', 1) + attr_dict[key] = value.strip('"') + + feature['attributes_dict'] = attr_dict + features.append(feature) + + return features + + +def extract_cds_and_5utr(genome_seq: str, features: List[Dict], + upstream_length: int = 500) -> List[Tuple[str, str, str]]: +
""" + Extract CDS and 5UTR sequences from a genome sequence and its annotation. + + Returns: [(gene_name, CDS sequence, 5UTR sequence), ...] + """ + results = [] + + # Group features by gene (CDS records are included so the lookup below can find them) + genes = defaultdict(list) + for feature in features: + if feature['type'] in ['gene', 'mRNA', 'transcript', 'CDS']: + gene_id = feature['attributes_dict'].get('gene_id', + feature['attributes_dict'].get('Parent', 'unknown')) + genes[gene_id].append(feature) + + # Process each gene + for gene_id, gene_features in genes.items(): + # Find the CDS and transcript features + cds_features = [f for f in gene_features if f['type'] == 'CDS'] + mrna_features = [f for f in gene_features if f['type'] in ['mRNA', 'transcript']] + + if not cds_features or not mrna_features: + continue + + # Use the first transcript + mrna = mrna_features[0] + mrna_start = mrna['start'] + mrna_end = mrna['end'] + strand = mrna['strand'] + + # Extract the transcript sequence (GTF coordinates are 1-based, inclusive) + if strand == '+': + mrna_seq = genome_seq[mrna_start-1:mrna_end] + else: + mrna_seq = reverse_complement(genome_seq[mrna_start-1:mrna_end]) + + # Extract CDS regions as 0-based, half-open offsets into mrna_seq + cds_regions = [] + for cds in cds_features: + if strand == '+': + cds_start = cds['start'] - mrna_start + cds_end = cds['end'] - (mrna_start - 1) + else: + cds_start = mrna_end - cds['end'] + cds_end = mrna_end - (cds['start'] - 1) + cds_regions.append((cds_start, cds_end)) + + cds_regions.sort() + + # Concatenate the CDS segments + cds_seq = '' + for start, end in cds_regions: + cds_seq += mrna_seq[start:end] + + # Extract the 5UTR (everything upstream of the first CDS base). + # mrna_seq is already in transcript orientation on both strands, + # so the 5UTR is always the prefix before the first CDS segment; + # the suffix after the last segment would be the 3UTR instead. + five_utr_seq = mrna_seq[:cds_regions[0][0]] + + # If the 5UTR is too long, truncate it to upstream_length, keeping + # the bases closest to the start codon. The same slice applies to + # both strands because mrna_seq is transcript-oriented. + # (upstream_length is assumed to be positive.) + if len(five_utr_seq) > upstream_length: + five_utr_seq = five_utr_seq[-upstream_length:] + + results.append((gene_id, cds_seq, five_utr_seq)) + + return results + + +def reverse_complement(seq: str) -> str: + """Return the reverse complement of a sequence.""" + complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', + 'a': 't', 't': 'a', 'g': 'c', 'c': 'g'} + return ''.join(complement.get(base, base) for base in reversed(seq)) + + +def
write_training_fasta(data: List[Tuple[str, str, str]], + output_path: str, + format: str = 'concatenated'): + """ + 写入训练用 FASTA 文件 + + format: + - 'concatenated': CDS+5UTR 连接在一起(用于自回归训练) + - 'separated': CDS 和 5UTR 分开存储 + """ + with open(output_path, 'w') as f: + for gene_id, cds_seq, five_utr_seq in data: + if format == 'concatenated': + # 自回归训练格式:CDS 后面紧跟 5UTR + full_seq = cds_seq + five_utr_seq + f.write(f">{gene_id}|cds_len:{len(cds_seq)}|utr_len:{len(five_utr_seq)}\n") + + # 每 80 个字符换行 + for i in range(0, len(full_seq), 80): + f.write(full_seq[i:i+80] + '\n') + + elif format == 'separated': + # 分开存储格式 + f.write(f">{gene_id}|CDS\n") + for i in range(0, len(cds_seq), 80): + f.write(cds_seq[i:i+80] + '\n') + + f.write(f">{gene_id}|5UTR\n") + for i in range(0, len(five_utr_seq), 80): + f.write(five_utr_seq[i:i+80] + '\n') + + +def analyze_sequence_lengths(data: List[Tuple[str, str, str]]) -> None: + """分析序列长度分布""" + cds_lengths = [len(cds) for _, cds, _ in data] + utr_lengths = [len(utr) for _, _, utr in data] + total_lengths = [len(cds) + len(utr) for _, cds, utr in data] + + print("\n" + "="*50) + print("序列长度统计分析") + print("="*50) + + print(f"\n样本总数:{len(data)}") + + print(f"\nCDS 长度统计:") + print(f" 最小值:{min(cds_lengths)} bp") + print(f" 最大值:{max(cds_lengths)} bp") + print(f" 平均值:{sum(cds_lengths)/len(cds_lengths):.1f} bp") + print(f" 中位数:{sorted(cds_lengths)[len(cds_lengths)//2]} bp") + + print(f"\n5UTR 长度统计:") + print(f" 最小值:{min(utr_lengths)} bp") + print(f" 最大值:{max(utr_lengths)} bp") + print(f" 平均值:{sum(utr_lengths)/len(utr_lengths):.1f} bp") + print(f" 中位数:{sorted(utr_lengths)[len(utr_lengths)//2]} bp") + + print(f"\n总序列长度 (CDS+5UTR) 统计:") + print(f" 最小值:{min(total_lengths)} bp") + print(f" 最大值:{max(total_lengths)} bp") + print(f" 平均值:{sum(total_lengths)/len(total_lengths):.1f} bp") + print(f" 中位数:{sorted(total_lengths)[len(total_lengths)//2]} bp") + + # 推荐 seq_length + p95_length = sorted(total_lengths)[int(len(total_lengths) * 0.95)] + p99_length = 
sorted(total_lengths)[int(len(total_lengths) * 0.99)] + + print(f"\n推荐的 seq_length 设置:") + print(f" 覆盖 95% 数据:{p95_length} (向上取整到 2 的幂:{2**((p95_length-1).bit_length())})") + print(f" 覆盖 99% 数据:{p99_length} (向上取整到 2 的幂:{2**((p99_length-1).bit_length())})") + print("="*50 + "\n") + + +def main(): + parser = argparse.ArgumentParser( + description='准备 CDS→5UTR 微调训练数据' + ) + + parser.add_argument('--genome', type=str, required=True, + help='基因组 FASTA 文件路径') + parser.add_argument('--annotation', type=str, required=True, + help='基因组注释 GTF 文件路径') + parser.add_argument('--output', type=str, required=True, + help='输出 FASTA 文件路径') + parser.add_argument('--upstream_length', type=int, default=500, + help='5UTR 最大长度(超过此长度会被截断)') + parser.add_argument('--format', type=str, default='concatenated', + choices=['concatenated', 'separated'], + help='输出格式:concatenated(连接) 或 separated(分开)') + parser.add_argument('--min_cds_length', type=int, default=100, + help='最小 CDS 长度过滤阈值') + parser.add_argument('--min_utr_length', type=int, default=20, + help='最小 5UTR 长度过滤阈值') + + args = parser.parse_args() + + # 解析输入文件 + print(f"[1/4] 解析基因组 FASTA: {args.genome}") + genomes = parse_fasta(args.genome) + print(f" 找到 {len(genomes)} 条基因组序列") + + print(f"[2/4] 解析注释 GTF: {args.annotation}") + features = parse_gtf(args.annotation) + print(f" 找到 {len(features)} 个特征") + + # 提取 CDS 和 5UTR + print(f"[3/4] 提取 CDS 和 5UTR 序列...") + all_data = [] + + for seqid, genome_seq in genomes.items(): + seq_features = [f for f in features if f['seqid'] == seqid] + data = extract_cds_and_5utr(genome_seq, seq_features, + upstream_length=args.upstream_length) + all_data.extend(data) + + # 过滤低质量序列 + filtered_data = [ + (gene_id, cds, utr) for gene_id, cds, utr in all_data + if len(cds) >= args.min_cds_length and len(utr) >= args.min_utr_length + ] + + print(f" 提取 {len(all_data)} 个基因,过滤后保留 {len(filtered_data)} 个") + + # 分析长度分布 + analyze_sequence_lengths(filtered_data) + + # 写入输出文件 + print(f"[4/4] 写入输出文件:{args.output}") + 
write_training_fasta(filtered_data, args.output, format=args.format) + print(f" 完成!") + + print("\n下一步:") + print(f" 1. 使用 preprocess_evo2 --config configs/cds_5utr_finetune_config.yaml 预处理数据") + print(f" 2. 运行训练:./scripts/run_cds_5utr_finetune.sh lora") + + +if __name__ == '__main__': + main() diff --git a/scripts/prepare_gene_data.py b/scripts/prepare_gene_data.py new file mode 100755 index 0000000..93e4408 --- /dev/null +++ b/scripts/prepare_gene_data.py @@ -0,0 +1,536 @@ +#!/usr/bin/env python3 +""" +Evo2 基因生成数据准备脚本 + +支持任务: +- 5UTR 生成 +- 3UTR 生成 +- 内含子生成(给定外显子) +- 完整基因生成 +- 剪接结构预测 + +使用方法: + python prepare_gene_data.py \ + --genome genome.fasta \ + --annotation annotation.gtf \ + --output gene_training.fasta \ + --tasks gene,intron,5UTR,3UTR +""" + +import argparse +from pathlib import Path +from typing import List, Tuple, Dict +from collections import defaultdict +import random + + +# ============ 特殊 token 定义 ============ +TASK_MARKERS = { + '5UTR': '<5UTR>', + '3UTR': '<3UTR>', + 'intron': '', + 'gene': '', + 'splice': '' +} + +# 分隔符(N 串长度递增,便于区分边界类型) +SEP_5UTR_EXON = 'NNNNN' # 5UTR / 第一个外显子 +SEP_EXON_INTRON = 'NNNNNN' # 外显子 → 内含子(供体位点 GT 侧) +SEP_INTRON_EXON = 'NNNNNNN' # 内含子 → 外显子(受体位点 AG 侧) +SEP_EXON_3UTR = 'NNNNNNNN' # 最后外显子 / 3UTR +SEP_CDS_SPLICE = 'NNNNN' # CDS / 剪接结构 +END_TOKEN = '' + +# 内含子最大长度(超过会被截断) +DEFAULT_MAX_INTRON_LENGTH = 10000 + + +def parse_fasta(fasta_path: str) -> Dict[str, str]: + """解析 FASTA 文件""" + sequences = {} + current_name = None + current_seq = [] + + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_name: + sequences[current_name] = ''.join(current_seq) + current_name = line[1:].split()[0] + current_seq = [] + else: + current_seq.append(line.upper()) + + if current_name: + sequences[current_name] = ''.join(current_seq) + + return sequences + + +def parse_gtf(gtf_path: str) -> List[Dict]: + """解析 GTF 文件""" + features = [] + + with open(gtf_path, 'r') as f: + for line 
in f: + if line.startswith('#'): + continue + + parts = line.strip().split('\t') + if len(parts) < 9: + continue + + feature = { + 'seqid': parts[0], + 'source': parts[1], + 'type': parts[2], + 'start': int(parts[3]), + 'end': int(parts[4]), + 'score': parts[5], + 'strand': parts[6], + 'phase': parts[7], + 'attributes': parts[8] + } + + attr_dict = {} + for attr in parts[8].split(';'): + attr = attr.strip() + if ' ' in attr: + key, value = attr.split(' ', 1) + attr_dict[key] = value.strip('"') + + feature['attributes_dict'] = attr_dict + features.append(feature) + + return features + + +def extract_gene_structure(genome_seq: str, features: List[Dict], + max_utr_length: int = 500, + max_intron_length: int = DEFAULT_MAX_INTRON_LENGTH) -> List[Dict]: + """ + Extract the full gene structure from a genome sequence and its annotation. + + Returns: [{ + 'gene_id': str, + 'utr5': str, + 'exons': [str, ...], + 'introns': [str, ...], + 'utr3': str, + 'cds_region': str, + 'strand': str + }, ...] + """ + results = [] + + # Group features by gene (exon and CDS records are included so the lookups below can find them) + genes = defaultdict(list) + for feature in features: + if feature['type'] in ['gene', 'mRNA', 'transcript', 'exon', 'CDS']: + gene_id = feature['attributes_dict'].get('gene_id', + feature['attributes_dict'].get('Parent', 'unknown')) + genes[gene_id].append(feature) + + for gene_id, gene_features in genes.items(): + # Collect each feature type + exons = sorted([f for f in gene_features if f['type'] == 'exon'], + key=lambda x: x['start']) + cds_features = sorted([f for f in gene_features if f['type'] == 'CDS'], + key=lambda x: x['start']) + + if not exons: + continue + + # Use the first transcript + mrna = [f for f in gene_features if f['type'] in ['mRNA', 'transcript']] + if mrna: + mrna = mrna[0] + strand = mrna['strand'] + mrna_start = mrna['start'] + mrna_end = mrna['end'] + else: + # No mRNA record: fall back to the span covered by the exons + strand = exons[0]['strand'] + mrna_start = min(e['start'] for e in exons) + mrna_end = max(e['end'] for e in exons) + + # Extract the transcript sequence + if strand == '+': + mrna_seq = genome_seq[mrna_start-1:mrna_end] + else: + mrna_seq =
reverse_complement(genome_seq[mrna_start-1:mrna_end]) + + # Extract exon regions as 0-based, half-open offsets into mrna_seq (GTF coordinates are 1-based, inclusive) + exon_regions = [] # (start, end) relative to the mRNA + + for exon in exons: + if strand == '+': + exon_start = exon['start'] - mrna_start + exon_end = exon['end'] - (mrna_start - 1) + else: + exon_start = mrna_end - exon['end'] + exon_end = mrna_end - (exon['start'] - 1) + exon_regions.append((exon_start, exon_end)) + # After the loop: sort into transcript order (on the minus strand this reverses the genomic order), then slice + exon_regions.sort() + exon_seqs = [mrna_seq[start:end] for start, end in exon_regions] + + # Extract intron sequences (between consecutive exons) + intron_seqs = [] + for i in range(len(exon_seqs) - 1): + intron_start = exon_regions[i][1] + intron_end = exon_regions[i+1][0] + intron_seq = mrna_seq[intron_start:intron_end] + + # Truncate over-long introns + if len(intron_seq) > max_intron_length: + # Keep the regions around the splice sites + intron_seq = intron_seq[:50] + 'N' * 10 + intron_seq[-50:] + + intron_seqs.append(intron_seq) + + # Extract the CDS region + if cds_features: + cds_start = cds_features[0]['start'] + cds_end = cds_features[-1]['end'] + + if strand == '+': + cds_start_rel = cds_start - mrna_start + cds_end_rel = cds_end - (mrna_start - 1) + cds_seq = mrna_seq[cds_start_rel:cds_end_rel] + else: + cds_start_rel = mrna_end - cds_end + cds_end_rel = mrna_end - (cds_start - 1) + cds_seq = mrna_seq[cds_start_rel:cds_end_rel] # mrna_seq is already reverse-complemented + else: + cds_seq = '' + + # Extract the 5UTR: the part of the transcript upstream of the CDS start. + # mrna_seq is transcript-oriented, so one slice serves both strands. + # Assumption: the UTR is taken as a span of mrna_seq, introns included. + if cds_features: + utr5 = mrna_seq[:cds_start_rel] + else: + utr5 = '' + + # Truncate an over-long 5UTR, keeping the bases nearest the CDS + if len(utr5) > max_utr_length: + utr5 = utr5[-max_utr_length:] + + # Extract the 3UTR: the part of the transcript downstream of the CDS end. + # As above, one slice serves both strands. + # Assumption: the UTR is taken as a span of mrna_seq, introns included. + if cds_features: + utr3 = mrna_seq[cds_end_rel:] + else: + utr3 = '' + + # Truncate an over-long 3UTR, keeping the bases nearest the CDS + if len(utr3) > max_utr_length: + utr3 = utr3[:max_utr_length] + + # Keep only genes that have at least one exon + if len(exon_seqs) >= 1: + results.append({ + 'gene_id': gene_id, + 'utr5': utr5,
+ 'exons': exon_seqs, + 'introns': intron_seqs, + 'utr3': utr3, + 'cds_region': cds_seq, + 'strand': strand + }) + + return results + + +def reverse_complement(seq: str) -> str: + """Return the reverse complement of a sequence""" + complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', + 'a': 't', 't': 'a', 'g': 'c', 'c': 'g'} + return ''.join(complement.get(base, base) for base in reversed(seq)) + + +def build_training_sample(gene_data: Dict, task: str) -> str: + """ + Build a single training sample + + gene_data: {'gene_id', 'utr5', 'exons', 'introns', 'utr3', 'cds_region'} + task: '5UTR', '3UTR', 'intron', 'gene', 'splice' + + Returns None when the gene lacks the parts the task needs. + """ + gene_id = gene_data['gene_id'] + utr5 = gene_data['utr5'] + exons = gene_data['exons'] + introns = gene_data['introns'] + utr3 = gene_data['utr3'] + cds = gene_data['cds_region'] + + if task == '5UTR': + if not utr5 or not exons: + return None + # Format: <5UTR>5UTR[SEP1]Exon1... + first_exon = exons[0] + sample = f"{TASK_MARKERS['5UTR']}{utr5}{SEP_5UTR_EXON}{first_exon}{END_TOKEN}" + + elif task == '3UTR': + if not utr3 or not exons: + return None + # Format: <3UTR>...ExonN[SEP3]3UTR + last_exon = exons[-1] + sample = f"{TASK_MARKERS['3UTR']}{last_exon}{SEP_EXON_3UTR}{utr3}{END_TOKEN}" + + elif task == 'intron': + if len(exons) < 2 or len(introns) < 1: + return None + # Format: Exon1[SEP2]Intron1[SEP3]Exon2... 
+ parts = [TASK_MARKERS['intron']] + for i, (exon, intron) in enumerate(zip(exons, introns)): + if i == 0: + parts.append(exon) + # Each intron is followed by the next exon, not the one before it + parts.extend([SEP_EXON_INTRON, intron, SEP_INTRON_EXON, exons[i + 1]]) + parts.append(END_TOKEN) + sample = ''.join(parts) + + elif task == 'gene': + if not exons: + return None + # Format: 5UTR[SEP1]Exon1[SEP2]Intron1[SEP3]Exon2...[SEP4]3UTR + parts = [TASK_MARKERS['gene'], utr5, SEP_5UTR_EXON] + + for i, exon in enumerate(exons): + parts.append(exon) + if i < len(introns): + parts.extend([SEP_EXON_INTRON, introns[i], SEP_INTRON_EXON]) + + parts.extend([SEP_EXON_3UTR, utr3, END_TOKEN]) + sample = ''.join(parts) + + elif task == 'splice': + if not cds or len(exons) < 2: + return None + # Format: CDS[SEP]Exon1[SEP2]Intron1[SEP3]Exon2... + parts = [TASK_MARKERS['splice'], cds, SEP_CDS_SPLICE] + + for i, exon in enumerate(exons): + parts.append(exon) + if i < len(introns): + parts.extend([SEP_EXON_INTRON, introns[i], SEP_INTRON_EXON]) + + parts.append(END_TOKEN) + sample = ''.join(parts) + + else: + return None + + # Length statistics for the FASTA header + total_len = len(sample) + utr5_len = len(utr5) + utr3_len = len(utr3) + exon_len = sum(len(e) for e in exons) + intron_len = sum(len(i) for i in introns) + + return f">{gene_id}|task:{task}|total:{total_len}|utr5:{utr5_len}|exon:{exon_len}|intron:{intron_len}|utr3:{utr3_len}\n{sample}\n" + + +def write_training_data(genes: List[Dict], output_path: str, + tasks: List[str], + sample_ratio: Dict[str, float] = None, + max_genes: int = None): + """ + Write the training data + + tasks: list of tasks ['5UTR', '3UTR', 'intron', 'gene', 'splice'] + sample_ratio: probability of emitting each task for a given gene + max_genes: maximum number of genes to process (for quick tests) + """ + if sample_ratio is None: + sample_ratio = {task: 1.0 / len(tasks) for task in tasks} + + if max_genes: + genes = genes[:max_genes] + + written = 0 + with open(output_path, 'w') as f: + for gene_data in genes: + for task in tasks: + if random.random() < sample_ratio.get(task, 1.0): + sample = build_training_sample(gene_data, task) + if sample: + f.write(sample) + 
written += 1 + + print(f"Wrote {written} training samples to: {output_path}") + + +def analyze_gene_structure(genes: List[Dict]) -> None: + """Print gene-structure statistics""" + if not genes: + print("No genes extracted; skipping statistics") + return + + exon_counts = [len(g['exons']) for g in genes] + intron_counts = [len(g['introns']) for g in genes] + + exon_lengths = [len(e) for g in genes for e in g['exons']] + intron_lengths = [len(i) for g in genes for i in g['introns']] + utr5_lengths = [len(g['utr5']) for g in genes] + utr3_lengths = [len(g['utr3']) for g in genes] + + print("\n" + "="*60) + print("Gene structure statistics") + print("="*60) + + print(f"\nTotal genes: {len(genes)}") + + print(f"\nExon count:") + print(f" Mean: {sum(exon_counts)/len(exon_counts):.1f}") + print(f" Median: {sorted(exon_counts)[len(exon_counts)//2]}") + print(f" Max: {max(exon_counts)}") + + print(f"\nIntron count:") + print(f" Mean: {sum(intron_counts)/len(intron_counts):.1f}") + print(f" Median: {sorted(intron_counts)[len(intron_counts)//2]}") + print(f" Max: {max(intron_counts)}") + + print(f"\nExon length:") + print(f" Mean: {sum(exon_lengths)/len(exon_lengths):.1f} bp") + print(f" Median: {sorted(exon_lengths)[len(exon_lengths)//2]} bp") + print(f" Longest: {max(exon_lengths)} bp") + + # Intron statistics are skipped when every gene has a single exon + if intron_lengths: + print(f"\nIntron length:") + print(f" Mean: {sum(intron_lengths)/len(intron_lengths):.1f} bp") + print(f" Median: {sorted(intron_lengths)[len(intron_lengths)//2]} bp") + print(f" Longest: {max(intron_lengths)} bp") + + print(f"\nUTR length:") + print(f" 5UTR mean: {sum(utr5_lengths)/len(utr5_lengths):.1f} bp") + print(f" 3UTR mean: {sum(utr3_lengths)/len(utr3_lengths):.1f} bp") + + # Estimated total length of a full-gene sample + gene_lengths = [] + for g in genes: + # Full gene length including separators + total = len(g['utr5']) + sum(len(e) for e in g['exons']) + \ + sum(len(i) for i in g['introns']) + len(g['utr3']) + \ + len(SEP_5UTR_EXON) + len(SEP_EXON_3UTR) + \ + len(g['introns']) * (len(SEP_EXON_INTRON) + len(SEP_INTRON_EXON)) + gene_lengths.append(total) + + print(f"\nFull gene length (including separators):") + print(f" Mean: {sum(gene_lengths)/len(gene_lengths):.1f} bp") + print(f" Longest: {max(gene_lengths)} bp") + + p95 = sorted(gene_lengths)[int(len(gene_lengths) * 0.95)] 
+ p99 = sorted(gene_lengths)[int(len(gene_lengths) * 0.99)] + + print(f"\nRecommended seq_length settings:") + print(f" Covers 95% of genes: {p95} bp") + print(f" Covers 99% of genes: {p99} bp") + print(f" (the Evo2 1M context model can handle longer genes)") + + print(f"\nSpecial token definitions:") + print(f" 5UTR task: {TASK_MARKERS['5UTR']}") + print(f" 3UTR task: {TASK_MARKERS['3UTR']}") + print(f" Intron task: {TASK_MARKERS['intron']}") + print(f" Full gene: {TASK_MARKERS['gene']}") + print(f" Splice prediction: {TASK_MARKERS['splice']}") + print(f" 5UTR/Exon separator: {SEP_5UTR_EXON} ({len(SEP_5UTR_EXON)} bp)") + print(f" Exon/Intron separator: {SEP_EXON_INTRON} ({len(SEP_EXON_INTRON)} bp)") + print(f" Intron/Exon separator: {SEP_INTRON_EXON} ({len(SEP_INTRON_EXON)} bp)") + print(f" Exon/3UTR separator: {SEP_EXON_3UTR} ({len(SEP_EXON_3UTR)} bp)") + print("="*60 + "\n") + + +def main(): + parser = argparse.ArgumentParser( + description='Prepare gene-generation training data (UTRs, introns, full genes)' + ) + + parser.add_argument('--genome', type=str, required=True, + help='Genome FASTA file') + parser.add_argument('--annotation', type=str, required=True, + help='Genome annotation GTF file') + parser.add_argument('--output', type=str, required=True, + help='Output FASTA file') + parser.add_argument('--tasks', type=str, default='gene,intron,5UTR,3UTR', + help='Comma-separated task types: gene,intron,5UTR,3UTR,splice') + parser.add_argument('--max_utr_length', type=int, default=500, + help='Maximum UTR length') + parser.add_argument('--max_intron_length', type=int, default=DEFAULT_MAX_INTRON_LENGTH, + help='Maximum intron length (longer introns are truncated)') + parser.add_argument('--task_ratio', type=str, default=None, + help='Task ratios, e.g. "gene:0.3,intron:0.3,5UTR:0.2,3UTR:0.2"') + parser.add_argument('--max_genes', type=int, default=None, + help='Maximum number of genes to process (for quick tests)') + parser.add_argument('--seed', type=int, default=42, + help='Random seed') + + args = parser.parse_args() + + random.seed(args.seed) + + # Parse the task list + tasks = [t.strip() for t in args.tasks.split(',')] + valid_tasks = ['5UTR', '3UTR', 'intron', 'gene', 'splice'] + for task in tasks: + if task not in valid_tasks: + print(f"Error: invalid task '{task}', valid values: {valid_tasks}") + 
raise SystemExit(1) + + # Parse the task ratios + task_ratio = None + if args.task_ratio: + task_ratio = {} + for item in args.task_ratio.split(','): + if ':' in item: + t, r = item.split(':') + task_ratio[t.strip()] = float(r.strip()) + + # Parse the inputs + print(f"[1/4] Parsing genome FASTA: {args.genome}") + genomes = parse_fasta(args.genome) + print(f" Found {len(genomes)} chromosomes/contigs") + + print(f"[2/4] Parsing annotation GTF: {args.annotation}") + features = parse_gtf(args.annotation) + print(f" Found {len(features)} features") + + # Extract gene structures + print(f"[3/4] Extracting gene structures (max intron length: {args.max_intron_length} bp)...") + if len(genomes) == 1: + # Single sequence: use every feature without filtering by seqid + genes = extract_gene_structure( + next(iter(genomes.values())), + features, + max_utr_length=args.max_utr_length, + max_intron_length=args.max_intron_length + ) + else: + # Multiple chromosomes: process each with its own features, then merge + genes = [] + for seqid, genome_seq in genomes.items(): + seq_features = [f for f in features if f['seqid'] == seqid] + genes.extend(extract_gene_structure(genome_seq, seq_features, + max_utr_length=args.max_utr_length, + max_intron_length=args.max_intron_length)) + + print(f" Extracted {len(genes)} gene structures") + + # Print statistics + analyze_gene_structure(genes) + + # Write the training data + print(f"[4/4] Writing training data (tasks: {tasks})...") + write_training_data( + genes, args.output, + tasks=tasks, + sample_ratio=task_ratio, + max_genes=args.max_genes + ) + + print(f"\nDone!") + print("\nNext steps:") + print(f" Preprocess the data with preprocess_evo2") + print(f" Run training: ./scripts/run_gene_finetune.sh lora") + + +if __name__ == '__main__': + main() diff --git a/scripts/prepare_utr_data.py b/scripts/prepare_utr_data.py new file mode 100755 index 0000000..afaf96f --- /dev/null +++ b/scripts/prepare_utr_data.py @@ -0,0 +1,382 @@ +#!/usr/bin/env python3 +""" +Evo2 UTR generation data preparation script + +Supports: +- 5UTR generation (conditioned on the CDS) +- 3UTR generation (conditioned on the CDS) +- Joint 5UTR+3UTR generation + +Task markers select which kind of sequence is generated. + +Usage: + python prepare_utr_data.py \ + --genome genome.fasta \ + --annotation annotation.gtf \ + --output utr_training.fasta \ + --tasks 5UTR,3UTR,both +""" + 
+import argparse +from pathlib import Path +from typing import List, Tuple, Dict +from collections import defaultdict +import random + + +# ============ Special token definitions ============ +TASK_MARKERS = { + '5UTR': '<5UTR>', + '3UTR': '<3UTR>', + 'both': '' +} + +SEP_5UTR = 'NNNNN' # 5UTR separator (5 Ns) +SEP_3UTR = 'NNNNNNN' # 3UTR separator (7 Ns, to distinguish it from the 5UTR separator) +END_TOKEN = '' + + +def parse_fasta(fasta_path: str) -> Dict[str, str]: + """Parse a FASTA file into {sequence name: sequence}""" + sequences = {} + current_name = None + current_seq = [] + + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_name: + sequences[current_name] = ''.join(current_seq) + current_name = line[1:].split()[0] + current_seq = [] + else: + current_seq.append(line.upper()) + + if current_name: + sequences[current_name] = ''.join(current_seq) + + return sequences + + +def parse_gtf(gtf_path: str) -> List[Dict]: + """Parse a GTF file into a list of feature dicts""" + features = [] + + with open(gtf_path, 'r') as f: + for line in f: + if line.startswith('#'): + continue + + parts = line.strip().split('\t') + if len(parts) < 9: + continue + + feature = { + 'seqid': parts[0], + 'source': parts[1], + 'type': parts[2], + 'start': int(parts[3]), + 'end': int(parts[4]), + 'score': parts[5], + 'strand': parts[6], + 'phase': parts[7], + 'attributes': parts[8] + } + + attr_dict = {} + for attr in parts[8].split(';'): + attr = attr.strip() + if ' ' in attr: + key, value = attr.split(' ', 1) + attr_dict[key] = value.strip('"') + + feature['attributes_dict'] = attr_dict + features.append(feature) + + return features + + +def extract_cds_and_utrs(genome_seq: str, features: List[Dict], + utr_length: int = 500) -> List[Tuple[str, str, str, str]]: + """ + Extract CDS, 5UTR and 3UTR sequences from the genome sequence and annotation + + Returns: [(gene_name, 5UTR sequence, CDS sequence, 3UTR sequence), ...] 
+ """ + results = [] + + genes = defaultdict(list) + for feature in features: + # CDS records are grouped too; without them every gene is skipped below + if feature['type'] in ['gene', 'mRNA', 'transcript', 'CDS']: + gene_id = feature['attributes_dict'].get('gene_id', + feature['attributes_dict'].get('Parent', 'unknown')) + genes[gene_id].append(feature) + + for gene_id, gene_features in genes.items(): + cds_features = [f for f in gene_features if f['type'] == 'CDS'] + mrna_features = [f for f in gene_features if f['type'] in ['mRNA', 'transcript']] + + if not cds_features or not mrna_features: + continue + + mrna = mrna_features[0] + mrna_start = mrna['start'] + mrna_end = mrna['end'] + strand = mrna['strand'] + + # Extract the transcript sequence (reverse-complemented on the minus strand) + if strand == '+': + mrna_seq = genome_seq[mrna_start-1:mrna_end] + else: + mrna_seq = reverse_complement(genome_seq[mrna_start-1:mrna_end]) + + # Convert CDS coordinates to transcript coordinates (0-based, end-exclusive); + # GTF coordinates are 1-based and inclusive, hence the +1 on the end + cds_regions = [] + for cds in cds_features: + if strand == '+': + cds_start = cds['start'] - mrna_start + cds_end = cds['end'] - mrna_start + 1 + else: + cds_start = mrna_end - cds['end'] + cds_end = mrna_end - cds['start'] + 1 + cds_regions.append((cds_start, cds_end)) + + cds_regions.sort() + + # Concatenate the CDS sequence + cds_seq = '' + for start, end in cds_regions: + cds_seq += mrna_seq[start:end] + + # mrna_seq is in transcript orientation on both strands, + # so the 5UTR always precedes the first CDS segment + five_utr_seq = mrna_seq[:cds_regions[0][0]] + + # and the 3UTR always follows the last CDS segment + three_utr_seq = mrna_seq[cds_regions[-1][1]:] + + # Truncate overly long UTRs, keeping the bases adjacent to the CDS + if len(five_utr_seq) > utr_length: + five_utr_seq = five_utr_seq[-utr_length:] + + if len(three_utr_seq) > utr_length: + three_utr_seq = three_utr_seq[:utr_length] + + results.append((gene_id, five_utr_seq, cds_seq, three_utr_seq)) + + return results + + +def reverse_complement(seq: str) -> str: + """Return the reverse complement of a sequence""" + complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', + 'a': 't', 't': 'a', 
'g': 'c', 'c': 'g'} + return ''.join(complement.get(base, base) for base in reversed(seq)) + + +def build_training_sample(gene_id: str, utr5: str, cds: str, utr3: str, + task_type: str) -> str: + """ + Build a single training sample + + task_type: '5UTR', '3UTR', 'both' + + Returns None when a sequence the task needs is empty. + """ + if task_type == '5UTR': + # Format: <5UTR>5UTR sequence [SEP]CDS sequence + # During training the model sees the CDS and learns to generate the 5UTR + if not utr5 or not cds: + return None + sample = f"{TASK_MARKERS['5UTR']}{utr5}{SEP_5UTR}{cds}{END_TOKEN}" + + elif task_type == '3UTR': + # Format: <3UTR>CDS sequence [SEP]3UTR sequence + # During training the model sees the CDS and learns to generate the 3UTR + if not cds or not utr3: + return None + sample = f"{TASK_MARKERS['3UTR']}{cds}{SEP_3UTR}{utr3}{END_TOKEN}" + + elif task_type == 'both': + # Format: 5UTR sequence [SEP5]CDS sequence [SEP3]3UTR sequence + if not utr5 or not cds or not utr3: + return None + sample = f"{TASK_MARKERS['both']}{utr5}{SEP_5UTR}{cds}{SEP_3UTR}{utr3}{END_TOKEN}" + + else: + return None + + return f">{gene_id}|task:{task_type}|utr5_len:{len(utr5)}|cds_len:{len(cds)}|utr3_len:{len(utr3)}\n{sample}\n" + + +def write_training_data(data: List[Tuple[str, str, str, str]], + output_path: str, + tasks: List[str] = None, + sample_ratio: Dict[str, float] = None): + """ + Write the training data + + tasks: list of task types ['5UTR', '3UTR', 'both']; defaults to ['5UTR', '3UTR'] + sample_ratio: probability of emitting each task, e.g. {'5UTR': 0.4, '3UTR': 0.4, 'both': 0.2} + """ + if tasks is None: + tasks = ['5UTR', '3UTR'] + + if sample_ratio is None: + # Uniform by default + sample_ratio = {task: 1.0 / len(tasks) for task in tasks} + + with open(output_path, 'w') as f: + for gene_id, utr5, cds, utr3 in data: + # Decide per task whether to emit a sample, according to its ratio + for task in tasks: + if random.random() < sample_ratio.get(task, 1.0): + sample = build_training_sample(gene_id, utr5, cds, utr3, task) + if sample: + f.write(sample) + + print(f"Finished writing: {output_path}") + + +def analyze_sequences(data: List[Tuple[str, str, str, str]]) -> None: + """Print the sequence length distribution""" + utr5_lengths = [len(u5) for _, u5, _, _ in data] + cds_lengths = [len(c) for _, _, c, _ in data] + utr3_lengths = [len(u3) for _, _, _, u3 in data] + + print("\n" + "="*60) + print("Sequence length statistics") + print("="*60) + + 
print(f"\nTotal samples: {len(data)}") + + print(f"\n5UTR length:") + print(f" Min: {min(utr5_lengths)} bp") + print(f" Max: {max(utr5_lengths)} bp") + print(f" Mean: {sum(utr5_lengths)/len(utr5_lengths):.1f} bp") + print(f" Median: {sorted(utr5_lengths)[len(utr5_lengths)//2]} bp") + + print(f"\nCDS length:") + print(f" Min: {min(cds_lengths)} bp") + print(f" Max: {max(cds_lengths)} bp") + print(f" Mean: {sum(cds_lengths)/len(cds_lengths):.1f} bp") + print(f" Median: {sorted(cds_lengths)[len(cds_lengths)//2]} bp") + + print(f"\n3UTR length:") + print(f" Min: {min(utr3_lengths)} bp") + print(f" Max: {max(utr3_lengths)} bp") + print(f" Mean: {sum(utr3_lengths)/len(utr3_lengths):.1f} bp") + print(f" Median: {sorted(utr3_lengths)[len(utr3_lengths)//2]} bp") + + # Total sample length per task; each data tuple is (gene_id, utr5, cds, utr3) + print(f"\nTotal training sample length (including special tokens):") + task_5utr_len = [len(u5) + len(SEP_5UTR) + len(cds) + len(END_TOKEN) + len(TASK_MARKERS['5UTR']) + for _, u5, cds, _ in data] + task_3utr_len = [len(cds) + len(SEP_3UTR) + len(u3) + len(END_TOKEN) + len(TASK_MARKERS['3UTR']) + for _, _, cds, u3 in data] + task_both_len = [len(u5) + len(SEP_5UTR) + len(cds) + len(SEP_3UTR) + len(u3) + len(END_TOKEN) + len(TASK_MARKERS['both']) + for _, u5, cds, u3 in data] + + all_lengths = task_5utr_len + task_3utr_len + task_both_len + + print(f" Min: {min(all_lengths)} bp") + print(f" Max: {max(all_lengths)} bp") + print(f" Mean: {sum(all_lengths)/len(all_lengths):.1f} bp") + + p95_length = sorted(all_lengths)[int(len(all_lengths) * 0.95)] + p99_length = sorted(all_lengths)[int(len(all_lengths) * 0.99)] + + print(f"\nRecommended seq_length settings:") + print(f" Covers 95% of data: {p95_length} (rounded up to a power of two: {2**((p95_length-1).bit_length())})") + print(f" Covers 99% of data: {p99_length} (rounded up to a power of two: {2**((p99_length-1).bit_length())})") + print("="*60) + + print(f"\nSpecial token definitions:") + print(f" 5UTR task marker: {TASK_MARKERS['5UTR']}") + print(f" 3UTR task marker: {TASK_MARKERS['3UTR']}") + print(f" 
both task marker: {TASK_MARKERS['both']}") + print(f" 5UTR separator: {SEP_5UTR} ({len(SEP_5UTR)} bp)") + print(f" 3UTR separator: {SEP_3UTR} ({len(SEP_3UTR)} bp)") + print(f" End token: {END_TOKEN}") + print("="*60 + "\n") + + +def main(): + parser = argparse.ArgumentParser( + description='Prepare UTR-generation training data (5UTR, 3UTR and both tasks)' + ) + + parser.add_argument('--genome', type=str, required=True, + help='Path to the genome FASTA file') + parser.add_argument('--annotation', type=str, required=True, + help='Path to the genome annotation GTF file') + parser.add_argument('--output', type=str, required=True, + help='Path to the output FASTA file') + parser.add_argument('--utr_length', type=int, default=500, + help='Maximum UTR length (longer UTRs are truncated)') + parser.add_argument('--tasks', type=str, default='5UTR,3UTR,both', + help='Comma-separated task types: 5UTR,3UTR,both') + parser.add_argument('--min_cds_length', type=int, default=100, + help='Minimum CDS length filter threshold') + parser.add_argument('--min_utr_length', type=int, default=20, + help='Minimum UTR length filter threshold') + parser.add_argument('--seed', type=int, default=42, + help='Random seed') + + args = parser.parse_args() + + random.seed(args.seed) + + # Parse the task types + tasks = [t.strip() for t in args.tasks.split(',')] + valid_tasks = ['5UTR', '3UTR', 'both'] + for task in tasks: + if task not in valid_tasks: + print(f"Error: invalid task type '{task}', valid values: {valid_tasks}") + raise SystemExit(1) + + # Parse the input files + print(f"[1/4] Parsing genome FASTA: {args.genome}") + genomes = parse_fasta(args.genome) + print(f" Found {len(genomes)} genome sequences") + + print(f"[2/4] Parsing annotation GTF: {args.annotation}") + features = parse_gtf(args.annotation) + print(f" Found {len(features)} features") + + # Extract CDS and UTRs + print(f"[3/4] Extracting CDS, 5UTR and 3UTR sequences...") + all_data = [] + + for seqid, genome_seq in genomes.items(): + seq_features = [f for f in features if f['seqid'] == seqid] + data = extract_cds_and_utrs(genome_seq, seq_features, + utr_length=args.utr_length) + all_data.extend(data) + + # Filter out low-quality sequences: the CDS must be long enough and + # at least one of the two UTRs must reach the minimum length + filtered_data = [ + (gene_id, utr5, cds, utr3) + for gene_id, utr5, cds, utr3 in all_data + if len(cds) >= args.min_cds_length + and (len(utr5) >= 
args.min_utr_length or len(utr3) >= args.min_utr_length) + ] + + print(f" Extracted {len(all_data)} genes, kept {len(filtered_data)} after filtering") + + # Analyze the length distribution + analyze_sequences(filtered_data) + + # Write the training data + print(f"[4/4] Writing training data (task types: {tasks})...") + write_training_data(filtered_data, args.output, tasks=tasks) + + print(f"\nDone! Output file: {args.output}") + print("\nNext steps:") + print(f" 1. Preprocess the data with preprocess_evo2 --config configs/utr_finetune_config.yaml") + print(f" 2. Run training: ./scripts/run_utr_finetune.sh lora") + + +if __name__ == '__main__': + main() diff --git a/scripts/process_gse278584.py b/scripts/process_gse278584.py new file mode 100755 index 0000000..df67925 --- /dev/null +++ b/scripts/process_gse278584.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python3 +""" +Process the GSE278584 dataset for intron-generation training + +GSE278584: tens of thousands of designed intron sequences with functional data. +The supplementary files must be downloaded from GEO beforehand. + +Usage: + python process_gse278584.py \ + --input data/intron_training/gse278584 \ + --output gse278584_processed.fasta \ + --tasks intron +""" + +import argparse +import csv +import gzip +import os +from pathlib import Path +from typing import List, Tuple, Dict +import random + + +# Special tokens (kept consistent with prepare_gene_data.py) +TASK_MARKERS = { + 'intron': '', + 'function_high': '', + 'function_low': '' +} + +SEP_EXON_INTRON = 'NNNNNN' +SEP_INTRON_EXON = 'NNNNNNN' +END_TOKEN = '' + +# Mock exon sequences (used only to build training samples) +MOCK_EXON_5 = "ATGGCTAGCTACGGTACGGATCCGCTAGCATCGATCGATCGATCGTAGCTAGCTAG" +MOCK_EXON_3 = "GGATCCGGATCCGGATTAGCTAGCTAGCTAGCTAGCTAGCATCGATCGATCGTAA" + + +def download_supplementary_files(geo_accession: str, output_dir: str): + """ + Print instructions for downloading the GEO supplementary files + + The GSE278584 supplementary files contain the intron sequences and functional data. + No file is downloaded automatically because the file URLs are not known here. + """ + import requests + + base_url = "https://ftp.ncbi.nlm.nih.gov/geo/series/" + + # Supplementary files of GSE278584 + # The actual file links must be taken from the GEO page + supplementary_files = [ + # Example (replace with links from the actual page) + # f"{base_url}GSE278nnn/GSE278584/suppl/file1.tsv" + ] + + print(f"Please visit https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE278584") + print(f"and download the supplementary files (intron sequence library and functional data)") + print(f"\nThey usually include:") + print(f" - Intron sequence library (FASTA or CSV)") + print(f" - Functional measurements (TSV/XLSX)") + print(f" - 
Read counts and splicing-efficiency data") + + +def parse_intron_library(library_file: str) -> List[Dict]: + """ + Parse an intron library file + + Returns: [{ + 'intron_id': str, + 'sequence': str, # 160 nt random sequence + 'splice_donor': str, # 5' splice site + 'splice_acceptor': str, # 3' splice site + 'full_sequence': str # full splice sites + intron + }, ...] + """ + introns = [] + + # Parse according to the actual file format, + # which may be FASTA, CSV or TSV + + if library_file.endswith('.fasta') or library_file.endswith('.fa'): + introns = parse_fasta(library_file) + elif library_file.endswith('.csv') or library_file.endswith('.tsv'): + introns = parse_tabular(library_file) + + return introns + + +def parse_fasta(fasta_path: str) -> List[Dict]: + """Parse an intron library in FASTA format""" + introns = [] + current_id = None + current_seq = [] + + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_id: + introns.append({ + 'intron_id': current_id, + 'sequence': ''.join(current_seq) + }) + current_id = line[1:].split()[0] + current_seq = [] + else: + current_seq.append(line.upper()) + + if current_id: + introns.append({ + 'intron_id': current_id, + 'sequence': ''.join(current_seq) + }) + + return introns + + +def parse_tabular(tabular_path: str) -> List[Dict]: + """Parse an intron library in tabular (CSV/TSV) format""" + introns = [] + + open_func = gzip.open if tabular_path.endswith('.gz') else open + + with open_func(tabular_path, 'rt') as f: + reader = csv.DictReader(f, delimiter='\t' if tabular_path.endswith('.tsv') else ',') + for row in reader: + # Adjust to the actual column names + intron_id = row.get('intron_id', row.get('id', '')) + sequence = row.get('sequence', row.get('intron_seq', '')) + + if sequence: + introns.append({ + 'intron_id': intron_id, + 'sequence': sequence.upper() + }) + + return introns + + +def parse_function_data(function_file: str) -> Dict[str, Dict]: + """ + Parse a functional data file + + Returns: {intron_id: {'ime_score': float, 'splicing_efficiency': float, ...}} + """ + function_data = {} + + open_func = gzip.open if function_file.endswith('.gz') else open + + with open_func(function_file, 'rt') as f: + 
reader = csv.DictReader(f, delimiter='\t') + for row in reader: + intron_id = row.get('intron_id', row.get('id', '')) + + # Extract the functional metrics + function_data[intron_id] = { + 'ime_score': float(row.get('ime_score', row.get('expression_ratio', 0))), + 'splicing_efficiency': float(row.get('splicing_efficiency', + row.get('percent_spliced', 0))), + 'poly_u_content': float(row.get('poly_u_content', 0)), + } + + return function_data + + +def build_training_sample(intron_data: Dict, + function_data: Dict = None, + task_type: str = 'intron') -> str: + """ + Build a training sample + + task_type: + - 'intron': basic intron generation + - 'intron_functional': function-guided intron generation + + Returns None for an unknown task type, or for 'intron_functional' + when no functional data is available for the intron. + """ + intron_id = intron_data['intron_id'] + sequence = intron_data['sequence'] + + # Build the full intron sequence (including splice sites) + # Typical structure: exon + GT + intron + AG + exon + splice_donor = 'GT' # 5' splice site + splice_acceptor = 'AG' # 3' splice site + + full_intron = splice_donor + sequence + splice_acceptor + + if task_type == 'intron': + # Basic format: Exon5[SEP]Intron[SEP]Exon3 + sample = f"{TASK_MARKERS['intron']}{MOCK_EXON_5}{SEP_EXON_INTRON}{full_intron}{SEP_INTRON_EXON}{MOCK_EXON_3}{END_TOKEN}" + + elif task_type == 'intron_functional' and function_data: + # Choose the marker from the functional score + ime_score = function_data.get('ime_score', 0) + if ime_score > 1.5: # high enhancing activity + marker = TASK_MARKERS['function_high'] + elif ime_score < 0.8: # low enhancing activity + marker = TASK_MARKERS['function_low'] + else: + marker = TASK_MARKERS['intron'] + + sample = f"{marker}{MOCK_EXON_5}{SEP_EXON_INTRON}{full_intron}{SEP_INTRON_EXON}{MOCK_EXON_3}{END_TOKEN}" + + else: + # Without this branch `sample` would be unbound below + return None + + # Add metadata to the FASTA header + metadata = f">{intron_id}|task:{task_type}|length:{len(sequence)}" + if function_data: + metadata += f"|ime:{function_data.get('ime_score', 0):.2f}" + + return f"{metadata}\n{sample}\n" + + +def process_gse278584(input_dir: str, output_path: str, + task_types: List[str] = ['intron'], + min_ime_score: float = None, + max_samples: int = None): + """ + Process the GSE278584 data into a training file + + Args: + input_dir: directory containing the downloaded files + output_path: output training FASTA file + task_types: list of task types + min_ime_score: minimum IME 
score threshold + max_samples: maximum number of samples + """ + print(f"[1/4] Parsing intron libraries...") + + # Look for library files + library_files = list(Path(input_dir).glob('*library*.fasta')) + \ + list(Path(input_dir).glob('*library*.csv')) + \ + list(Path(input_dir).glob('*library*.tsv')) + + if not library_files: + print(f"Error: no library file found in {input_dir}") + print(f"Please download the supplementary files from GEO first") + return + + all_introns = [] + for lib_file in library_files: + print(f" Parsing: {lib_file.name}") + introns = parse_intron_library(str(lib_file)) + all_introns.extend(introns) + + print(f" Total: {len(all_introns)} introns") + + # Load the functional data (if present) + function_data = {} + function_files = list(Path(input_dir).glob('*function*.tsv')) + \ + list(Path(input_dir).glob('*expression*.tsv')) + + if function_files: + print(f"\n[2/4] Parsing functional data...") + for func_file in function_files: + print(f" Parsing: {func_file.name}") + func_data = parse_function_data(str(func_file)) + function_data.update(func_data) + print(f" Functional data: {len(function_data)} introns") + + # Filter and sample + print(f"\n[3/4] Filtering and sampling...") + + filtered_introns = [] + for intron in all_introns: + intron_id = intron['intron_id'] + + # Apply the IME score filter (a threshold of 0 is still a valid threshold) + if min_ime_score is not None and intron_id in function_data: + if function_data[intron_id]['ime_score'] < min_ime_score: + continue + + filtered_introns.append(intron) + + if max_samples and len(filtered_introns) > max_samples: + random.seed(42) + filtered_introns = random.sample(filtered_introns, max_samples) + + print(f" After filtering: {len(filtered_introns)} introns") + + # Generate the training samples + print(f"\n[4/4] Generating training samples...") + + samples_written = 0 + with open(output_path, 'w') as f: + for intron in filtered_introns: + for task in task_types: + func_dat = function_data.get(intron['intron_id']) if 'function' in task else None + sample = build_training_sample(intron, func_dat, task) + if sample: + # The sample already ends with a newline + f.write(sample) + samples_written += 1 + + print(f"\nDone! Wrote {samples_written} training samples to: {output_path}") + + # Print statistics + print(f"\nData statistics:") + print(f" Total introns: {len(all_introns)}") + print(f" After filtering: {len(filtered_introns)}") + print(f" 
Training samples: {samples_written}") + print(f" Task types: {task_types}") + + if function_data: + ime_scores = [d['ime_score'] for d in function_data.values() if 'ime_score' in d] + if ime_scores: + print(f"\nIME score distribution:") + print(f" Mean: {sum(ime_scores)/len(ime_scores):.2f}") + print(f" Max: {max(ime_scores):.2f}") + print(f" Min: {min(ime_scores):.2f}") + + +def main(): + parser = argparse.ArgumentParser( + description='Process GSE278584 data for intron-generation training' + ) + + parser.add_argument('--input', type=str, required=True, + help='Input directory (containing the downloaded GEO files)') + parser.add_argument('--output', type=str, required=True, + help='Output training FASTA file') + parser.add_argument('--tasks', type=str, default='intron', + help='Comma-separated task types: intron,intron_functional') + parser.add_argument('--min_ime_score', type=float, default=None, + help='Minimum IME score filter threshold') + parser.add_argument('--max_samples', type=int, default=None, + help='Maximum number of samples') + + args = parser.parse_args() + + task_types = [t.strip() for t in args.tasks.split(',')] + + print("="*60) + print("GSE278584 data processing tool") + print("="*60) + + process_gse278584(args.input, args.output, task_types, + min_ime_score=args.min_ime_score, + max_samples=args.max_samples) + + print(f"\nNext steps:") + print(f" 1. Mix the GSE278584 data with natural intron data") + print(f" 2. Process the natural genome data with prepare_gene_data.py") + print(f" 3. Merge the two training files") + print(f" 4. 
Run fine-tuning") + + +if __name__ == '__main__': + main() diff --git a/scripts/run_cds_5utr_finetune.sh b/scripts/run_cds_5utr_finetune.sh new file mode 100755 index 0000000..288a38a --- /dev/null +++ b/scripts/run_cds_5utr_finetune.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# ============================================ +# Evo2 CDS→5UTR fine-tuning automation script +# ============================================ +# Usage: +# ./run_cds_5utr_finetune.sh [single|2gpu|lora|long] +# ============================================ + +set -e + +# Configuration - adjust to your environment +DATA_DIR="/path/to/your/data" +OUTPUT_DIR="/path/to/output" +MODEL_DIR="/path/to/models" + +# Select the training mode +TRAIN_MODE=${1:-"lora"} # LoRA by default + +echo "============================================" +echo "Evo2 CDS→5UTR fine-tuning pipeline" +echo "Training mode: $TRAIN_MODE" +echo "============================================" + +# Step 1: data preprocessing +echo "" +echo "[Step 1/5] Preprocessing data..." +preprocess_evo2 --config configs/cds_5utr_finetune_config.yaml + +# Step 2: model conversion (if not converted yet) +if [ ! -d "$OUTPUT_DIR/evo2_7b_mbridge" ]; then + echo "" + echo "[Step 2/5] Converting model format..." + evo2_convert_savanna_to_mbridge \ + --savanna-ckpt-path arcinstitute/savanna_evo2_7b \ + --mbridge-ckpt-dir $OUTPUT_DIR/evo2_7b_mbridge \ + --model-size evo2_7b +else + echo "[Step 2/5] Model already converted, skipping..." +fi + +# Step 3: choose the training configuration for the selected mode +echo "" +echo "[Step 3/5] Starting fine-tuning..." + +case $TRAIN_MODE in + "single") + echo "Using the single-GPU training configuration..." + torchrun --nproc-per-node 1 --no-python train_evo2 \ + --model-size evo2_7b \ + --finetune-ckpt-dir $OUTPUT_DIR/evo2_7b_mbridge \ + --data-path $OUTPUT_DIR/preprocessed_data \ + --max-steps 1000 \ + --micro-batch-size 1 \ + --global-batch-size 8 \ + --gradient-accumulation-steps 8 \ + --seq-length 8192 \ + --mixed-precision-recipe bf16_mixed \ + --use-subquadratic-ops \ + --use-precision-aware-optimizer \ + --eval-interval 100 \ + --save-interval 100 \ + --result-dir $OUTPUT_DIR/cds_5utr_finetune_single_gpu + ;; + + "2gpu") + echo "Using the two-GPU training configuration..." 
+ torchrun --nproc-per-node 2 --no-python train_evo2 \ + --model-size evo2_7b \ + --finetune-ckpt-dir $OUTPUT_DIR/evo2_7b_mbridge \ + --data-path $OUTPUT_DIR/preprocessed_data \ + --max-steps 1000 \ + --micro-batch-size 4 \ + --global-batch-size 32 \ + --tensor-model-parallel 2 \ + --seq-length 8192 \ + --mixed-precision-recipe bf16_mixed \ + --use-subquadratic-ops \ + --eval-interval 100 \ + --save-interval 100 \ + --result-dir $OUTPUT_DIR/cds_5utr_finetune_2gpu + ;; + + "lora") + echo "Using the LoRA fine-tuning configuration (recommended)..." + torchrun --nproc-per-node 2 --no-python train_evo2 \ + --model-size evo2_7b \ + --finetune-ckpt-dir $OUTPUT_DIR/evo2_7b_mbridge \ + --data-path $OUTPUT_DIR/preprocessed_data \ + --lora-finetune \ + --lora-dim 16 \ + --lora-alpha 32 \ + --lora-dropout 0.1 \ + --lora-target-modules "dense_projection,linear_qkv,linear_proj,linear_fc1,linear_fc2" \ + --max-steps 500 \ + --micro-batch-size 8 \ + --global-batch-size 32 \ + --seq-length 8192 \ + --mixed-precision-recipe bf16_mixed \ + --use-subquadratic-ops \ + --eval-interval 50 \ + --save-interval 50 \ + --result-dir $OUTPUT_DIR/cds_5utr_lora_finetune + ;; + + "long") + echo "Using the long-sequence training configuration (1M context)..." + torchrun --nproc-per-node 2 --no-python train_evo2 \ + --model-size evo2_7b \ + --finetune-ckpt-dir $OUTPUT_DIR/evo2_7b_mbridge \ + --data-path $OUTPUT_DIR/preprocessed_data \ + --max-steps 500 \ + --micro-batch-size 1 \ + --global-batch-size 8 \ + --gradient-accumulation-steps 8 \ + --seq-length 1048576 \ + --mixed-precision-recipe bf16_with_fp8_current_scaling_mixed \ + --use-subquadratic-ops \ + --eval-interval 100 \ + --save-interval 100 \ + --result-dir $OUTPUT_DIR/cds_5utr_finetune_1m_context + ;; + + *) + echo "Error: unknown training mode $TRAIN_MODE" + echo "Available modes: single, 2gpu, lora, long" + exit 1 + ;; +esac + +# Step 4: remove the optimizer state +echo "" +echo "[Step 4/5] Removing optimizer state..." 
+evo2_remove_optimizer \ + --src-ckpt-dir $OUTPUT_DIR/cds_5utr_${TRAIN_MODE}_finetune/checkpoint \ + --dst-ckpt-dir $OUTPUT_DIR/cds_5utr_${TRAIN_MODE}_finetune_weights_only + +# Step 5: export to Vortex format +echo "" +echo "[Step 5/5] Exporting to Vortex format..." +evo2_export_mbridge_to_vortex \ + --mbridge-ckpt-dir $OUTPUT_DIR/cds_5utr_${TRAIN_MODE}_finetune_weights_only \ + --output-path $MODEL_DIR/cds_5utr_model.pt \ + --model-size evo2_7b + +echo "" +echo "============================================" +echo "Training complete!" +echo "Model path: $MODEL_DIR/cds_5utr_model.pt" +echo "============================================" diff --git a/scripts/run_intron_pipeline.sh b/scripts/run_intron_pipeline.sh new file mode 100755 index 0000000..ee991c8 --- /dev/null +++ b/scripts/run_intron_pipeline.sh @@ -0,0 +1,389 @@ +#!/bin/bash +# ============================================ +# End-to-end training pipeline for intron generation +# Approach: mixed training on GSE278584 + native human introns +# ============================================ +# Usage: +# ./run_intron_pipeline.sh all # run all steps +# ./run_intron_pipeline.sh step1 # run step 1 only +# ./run_intron_pipeline.sh step2 # run step 2 only +# ./run_intron_pipeline.sh step3 # run step 3 only +# ============================================ + +set -e + +# Configuration +DATA_DIR="data/intron_training" +RESULTS_DIR="results/intron_finetune" +MODEL_DIR="models" + +# Create directories +mkdir -p $DATA_DIR/gse278584 +mkdir -p $DATA_DIR/human_genome +mkdir -p $DATA_DIR/processed +mkdir -p $RESULTS_DIR +mkdir -p $MODEL_DIR + +echo "============================================" +echo "Intron generation training pipeline - mixed training approach" +echo "============================================" +echo "" +echo "Data directory: $DATA_DIR" +echo "Results directory: $RESULTS_DIR" +echo "" + +# ============================================ +# Step 1: download the GSE278584 data +# ============================================ +run_step1() { + echo "============================================" + echo "Step 1: download the GSE278584 data" + echo "============================================" + + cd $DATA_DIR/gse278584 + + echo "Please download the data from the GEO page:" + echo
"https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE278584" + echo "" + + # Check whether downloaded files are already present + if ls *.tsv*.gz 2>/dev/null || ls *.csv*.gz 2>/dev/null || ls *.fasta*.gz 2>/dev/null; then + echo "✓ Found previously downloaded files" + ls -lh *.tsv*.gz *.csv*.gz *.fasta*.gz 2>/dev/null || true + else + echo "⚠ No downloaded files found" + echo "" + echo "Please download the following files manually (GEO page, Supplementary file):" + echo " - intron library file" + echo " - function/expression data file" + echo "" + echo "Example download commands (adjust to the actual file names):" + echo " wget https://.../GSE278584_intron_library.tsv.gz" + echo " wget https://.../GSE278584_function_data.tsv.gz" + echo "" + + # Automatic download (only possible once the exact URLs are known) + echo "Automatic download is not enabled; the example commands in this script are commented out." + + # Note: the real URLs must be taken from the GEO page; the ones below are examples + # wget -q "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE278nnn/GSE278584/suppl/GSE278584_intron_library.tsv.gz" || echo "Automatic download failed, please download manually" + # wget -q "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE278nnn/GSE278584/suppl/GSE278584_function_data.tsv.gz" || echo "Automatic download failed, please download manually" + fi + + # Decompress files + echo "" + echo "Decompressing files..." + for f in *.gz; do + if [ -f "$f" ]; then + echo " Decompressing: $f" + gunzip -kf "$f" || true + fi + done + + echo "" + echo "Step 1 complete!" + echo "Please confirm the files were downloaded to: $PWD" + echo "" + + cd ../../.. +} + +# ============================================ +# Step 2: process and mix the data +# ============================================ +run_step2() { + echo "============================================" + echo "Step 2: process and mix the data" + echo "============================================" + + # 2a. Process the GSE278584 data + echo "" + echo "[2a] Processing GSE278584 data..." + + python scripts/process_gse278584.py \ + --input $DATA_DIR/gse278584 \ + --output $DATA_DIR/processed/gse278584_training.fasta \ + --tasks intron \ + --max_samples 30000 + + # 2b. Download the human genome (if not already present) + echo "" + echo "[2b] Preparing human genome data..." + + cd $DATA_DIR/human_genome + + if [ ! -f "GCF_000001405.40_GRCh38.p14_genomic.fna" ]; then + echo "Downloading the human genome (GRCh38.p14)..."
+ echo " File size: ~3GB, the download may take 10-30 minutes" + + wget -q --show-progress \ + "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.fna.gz" \ + -O GCF_000001405.40_GRCh38.p14_genomic.fna.gz || { + echo "Download failed, please download manually from:" + echo "https://www.ncbi.nlm.nih.gov/datasets/genome/GCF_000001405.40/" + exit 1 + } + + echo "Decompressing..." + gunzip -f GCF_000001405.40_GRCh38.p14_genomic.fna.gz + else + echo "✓ Human genome already present" + fi + + if [ ! -f "GCF_000001405.40_GRCh38.p14_genomic.gff" ]; then + echo "Downloading the human genome annotation..." + wget -q --show-progress \ + "https://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.40_GRCh38.p14/GCF_000001405.40_GRCh38.p14_genomic.gff.gz" \ + -O GCF_000001405.40_GRCh38.p14_genomic.gff.gz + + echo "Decompressing..." + gunzip -f GCF_000001405.40_GRCh38.p14_genomic.gff.gz + else + echo "✓ Genome annotation already present" + fi + + cd ../../.. + + # 2c. Extract native introns + echo "" + echo "[2c] Extracting native human introns..." + + python scripts/prepare_gene_data.py \ + --genome $DATA_DIR/human_genome/GCF_000001405.40_GRCh38.p14_genomic.fna \ + --annotation $DATA_DIR/human_genome/GCF_000001405.40_GRCh38.p14_genomic.gff \ + --output $DATA_DIR/processed/human_introns.fasta \ + --tasks intron \ + --max_intron_length 10000 \ + --max_genes 50000 \ + --seed 42 + + # 2d. Mix the data (50:50 ratio) + echo "" + echo "[2d] Mixing training data (50% GSE278584 + 50% native human)..." + + # Count the GSE278584 samples (grep -c already prints 0 when nothing matches, so only its exit status is masked) + GSE_COUNT=$(grep -c "^>" $DATA_DIR/processed/gse278584_training.fasta || true) + echo " GSE278584 sample count: $GSE_COUNT" + + # Sample the same number of native human introns + python scripts/sample_fasta.py \ + --input $DATA_DIR/processed/human_introns.fasta \ + --output $DATA_DIR/processed/human_introns_sampled.fasta \ + --count $GSE_COUNT \ + --seed 42 + + # Concatenate + cat $DATA_DIR/processed/gse278584_training.fasta \ + $DATA_DIR/processed/human_introns_sampled.fasta \ + > $DATA_DIR/processed/intron_training_combined.fasta + + # Summary + TOTAL_COUNT=$(grep -c "^>" $DATA_DIR/processed/intron_training_combined.fasta) + echo "" + echo "Mixing complete!"
+ echo " GSE278584: $GSE_COUNT samples" + echo " Native human: $(grep -c "^>" $DATA_DIR/processed/human_introns_sampled.fasta || true) samples" + echo " Total: $TOTAL_COUNT samples" + echo " Output file: $DATA_DIR/processed/intron_training_combined.fasta" + + # Show the length distribution + echo "" + echo "Sequence length distribution (first 100 samples):" + python -c " +import sys +sys.path.append('scripts') +from prepare_gene_data import read_fasta +data = read_fasta('$DATA_DIR/processed/intron_training_combined.fasta')[:100] +lengths = [len(seq) for _, seq in data] +print(f' Mean length: {sum(lengths)/len(lengths):.0f} bp') +print(f' Shortest: {min(lengths)} bp') +print(f' Longest: {max(lengths)} bp') +" +} + +# ============================================ +# Step 3: train the model +# ============================================ +run_step3() { + echo "============================================" + echo "Step 3: train the model" + echo "============================================" + + # Check that the training data exists + if [ ! -f "$DATA_DIR/processed/intron_training_combined.fasta" ]; then + echo "Error: training data not found! Run steps 1 and 2 first" + exit 1 + fi + + # Create the training config + echo "" + echo "[3a] Creating the training config file..." + + cat > configs/intron_lora_config.yaml << EOF +# ============================================ +# Intron generation training config (LoRA fine-tuning) +# Approach: GSE278584 + native human introns, mixed +# ============================================ + +# Data settings +preprocess: + input_path: "$DATA_DIR/processed/intron_training_combined.fasta" + output_prefix: "$DATA_DIR/preprocessed" + train_split: 0.9 + val_split: 0.05 + test_split: 0.05 + seq_length: 8192 + tokenizer: "tokenizers/nucleotide_fast_tokenizer_256" + +# Model settings +model_conversion: + savanna_ckpt_path: "arcinstitute/savanna_evo2_7b" + mbridge_ckpt_dir: "$MODEL_DIR/evo2_7b_mbridge" + model_size: "evo2_7b" + +# LoRA fine-tuning settings +lora_finetune: + model_size: "evo2_7b" + finetune_ckpt_dir: "$MODEL_DIR/evo2_7b_mbridge" + data_path: "$DATA_DIR/preprocessed" + + # LoRA parameters + lora_finetune: true + lora_dim: 16 + lora_alpha: 32 + lora_dropout: 0.1 + lora_target_modules: + - dense_projection + - linear_qkv + - linear_proj + - linear_fc1 + - linear_fc2 + + # Training parameters + max_steps: 2000 + micro_batch_size: 4 +
global_batch_size: 32 + seq_length: 8192 + mixed_precision_recipe: "bf16_mixed" + use_subquadratic_ops: true + + # Evaluation and checkpointing + eval_interval: 100 + save_interval: 100 + + # Output directory + result_dir: "$RESULTS_DIR" +EOF + + echo " Config file created: configs/intron_lora_config.yaml" + + # Preprocess the data + echo "" + echo "[3b] Preprocessing data (BioNemo format)..." + echo " This may take 5-10 minutes..." + + # Note: requires a BioNemo environment + # preprocess_evo2 --config configs/intron_lora_config.yaml + + echo " Please run the following in a BioNemo environment:" + echo " preprocess_evo2 --config configs/intron_lora_config.yaml" + echo "" + read -p "Press Enter to continue..." || true + + # Convert the model + echo "" + echo "[3c] Converting model format..." + + if [ ! -d "$MODEL_DIR/evo2_7b_mbridge" ]; then + evo2_convert_savanna_to_mbridge \ + --savanna-ckpt-path arcinstitute/savanna_evo2_7b \ + --mbridge-ckpt-dir $MODEL_DIR/evo2_7b_mbridge \ + --model-size evo2_7b + else + echo " ✓ Model already converted" + fi + + # Start training + echo "" + echo "[3d] Starting LoRA fine-tuning..." + echo " Estimated time: about 3-4 hours on two H100s, about 6-8 hours on one" + echo "" + + # Training command (two GPUs) + torchrun --nproc-per-node 2 --no-python train_evo2 \ + --model-size evo2_7b \ + --finetune-ckpt-dir $MODEL_DIR/evo2_7b_mbridge \ + --data-path $DATA_DIR/preprocessed \ + --lora-finetune \ + --lora-dim 16 \ + --lora-alpha 32 \ + --lora-dropout 0.1 \ + --lora-target-modules "dense_projection,linear_qkv,linear_proj,linear_fc1,linear_fc2" \ + --max-steps 2000 \ + --micro-batch-size 4 \ + --global-batch-size 32 \ + --seq-length 8192 \ + --mixed-precision-recipe bf16_mixed \ + --use-subquadratic-ops \ + --eval-interval 100 \ + --save-interval 100 \ + --result-dir $RESULTS_DIR + + # Export the model + echo "" + echo "[3e] Exporting the model..." + + evo2_remove_optimizer \ + --src-ckpt-dir $RESULTS_DIR/checkpoint \ + --dst-ckpt-dir $RESULTS_DIR/weights_only + + evo2_export_mbridge_to_vortex \ + --mbridge-ckpt-dir $RESULTS_DIR/weights_only \ + --output-path $MODEL_DIR/intron_generator.pt \ + --model-size evo2_7b + + echo "" + echo "============================================" + echo "Training complete!"
+ echo "============================================" + echo "" + echo "Model path: $MODEL_DIR/intron_generator.pt" + echo "Results directory: $RESULTS_DIR" + echo "" + echo "Generate introns with the model:" + echo " python scripts/generate_intron.py --model $MODEL_DIR/intron_generator.pt --exons 'ATG...,TAA' --output generated.fasta" +} + +# ============================================ +# Main entry point +# ============================================ + +case ${1:-all} in + "step1"|"1") + run_step1 + ;; + "step2"|"2") + run_step2 + ;; + "step3"|"3") + run_step3 + ;; + "all"|"") + run_step1 + echo "" + read -p "Press Enter to continue to step 2..." || true + run_step2 + echo "" + read -p "Press Enter to continue to step 3..." || true + run_step3 + ;; + *) + echo "Usage: $0 [all|step1|step2|step3]" + echo "" + echo " all - run all steps" + echo " step1 - download the GSE278584 data" + echo " step2 - process and mix the data" + echo " step3 - train the model" + exit 1 + ;; +esac diff --git a/scripts/sample_fasta.py b/scripts/sample_fasta.py new file mode 100755 index 0000000..7d683a7 --- /dev/null +++ b/scripts/sample_fasta.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +""" +Random sampling tool for FASTA files + +Samples a given number of sequences from a large FASTA file + +Usage: + python sample_fasta.py \ + --input large_file.fasta \ + --output sampled_file.fasta \ + --count 10000 \ + --seed 42 +""" + +import argparse +import random +from pathlib import Path +from typing import List, Tuple + + +def read_fasta(fasta_path: str) -> List[Tuple[str, str]]: + """Read a FASTA file into a list of (name, sequence) tuples""" + sequences = [] + current_name = None + current_seq = [] + + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if line.startswith('>'): + if current_name is not None: + sequences.append((current_name, ''.join(current_seq))) + current_name = line[1:] + current_seq = [] + else: + current_seq.append(line) + + if current_name is not None: + sequences.append((current_name, ''.join(current_seq))) + + return sequences + + +def write_fasta(sequences: List[Tuple[str, str]], output_path: str) -> None: + """Write sequences to a FASTA file, wrapped at 80 columns""" + with open(output_path, 'w') as f: + for name, seq in sequences: + f.write(f">{name}\n") +
for i in range(0, len(seq), 80): + f.write(seq[i:i+80] + '\n') + + +def main(): + parser = argparse.ArgumentParser( + description='Randomly sample a given number of sequences from a FASTA file' + ) + + parser.add_argument('--input', type=str, required=True, + help='Input FASTA file') + parser.add_argument('--output', type=str, required=True, + help='Output FASTA file for the sample') + parser.add_argument('--count', type=int, required=True, + help='Number of sequences to sample') + parser.add_argument('--seed', type=int, default=42, + help='Random seed') + + args = parser.parse_args() + + print(f"[1/3] Reading input file: {args.input}") + sequences = read_fasta(args.input) + print(f" Total: {len(sequences)} sequences") + + if args.count > len(sequences): + print(f"Warning: requested {args.count} sequences, but the file only has {len(sequences)}") + print(f" Returning all {len(sequences)} sequences") + args.count = len(sequences) + + print(f"\n[2/3] Randomly sampling {args.count} sequences...") + random.seed(args.seed) + sampled = random.sample(sequences, args.count) + print(" Sampling complete") + + print(f"\n[3/3] Writing output file: {args.output}") + write_fasta(sampled, args.output) + print(" Done!") + + print("\nSummary:") + print(f" Input: {len(sequences)} sequences") + print(f" Output: {len(sampled)} sequences") + print(f" Sampling ratio: {len(sampled)/max(len(sequences), 1)*100:.1f}%") + + +if __name__ == '__main__': + main()
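
`scripts/sample_fasta.py` loads every record into memory before calling `random.sample`, which may be heavy for an input like the full set of human introns. The sketch below shows one possible streaming alternative, assuming reservoir sampling (Algorithm R) is an acceptable substitute. `iter_fasta` and `reservoir_sample` are hypothetical names for illustration and are not part of this PR.

```python
import io
import random
from typing import Iterable, Iterator, List, Tuple


def iter_fasta(handle: Iterable[str]) -> Iterator[Tuple[str, str]]:
    """Yield (name, sequence) records one at a time from an open FASTA handle."""
    name = None
    chunks: List[str] = []
    for line in handle:
        line = line.strip()
        if line.startswith(">"):
            if name is not None:
                yield name, "".join(chunks)
            name, chunks = line[1:], []
        elif line:
            chunks.append(line)
    if name is not None:
        yield name, "".join(chunks)


def reservoir_sample(
    records: Iterable[Tuple[str, str]], count: int, seed: int = 42
) -> List[Tuple[str, str]]:
    """Keep a uniform random sample of at most `count` records (Algorithm R)."""
    rng = random.Random(seed)
    reservoir: List[Tuple[str, str]] = []
    for i, record in enumerate(records):
        if i < count:
            # Fill the reservoir with the first `count` records.
            reservoir.append(record)
        else:
            # Replace a random slot with probability count / (i + 1).
            j = rng.randint(0, i)  # inclusive on both ends
            if j < count:
                reservoir[j] = record
    return reservoir


if __name__ == "__main__":
    # Tiny in-memory FASTA used only for demonstration.
    demo = io.StringIO(">a\nACGT\nAC\n>b\nGG\n>c\nTTTT\n")
    picked = reservoir_sample(iter_fasta(demo), count=2, seed=42)
    print(f"kept {len(picked)} records")
```

Memory use is bounded by `count` records rather than by the size of the input. Unlike `random.sample`, the reservoir is not shuffled, so a caller that cares about output order would need to shuffle it separately.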