From a1add9aa67a0f66605a2c4428a880e235cf577fb Mon Sep 17 00:00:00 2001 From: Samuel Martin <39562021+SR-Martin@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:33:49 +0000 Subject: [PATCH 1/8] Fix bug Memory not initialised properly. --- model/model_dna_rate_variation.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/model_dna_rate_variation.cpp b/model/model_dna_rate_variation.cpp index eb42df44..a631ede4 100644 --- a/model/model_dna_rate_variation.cpp +++ b/model/model_dna_rate_variation.cpp @@ -256,7 +256,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { for(int j = 0; j < num_states_; j++) { W[i * num_states_ + j] = 0; for(int k = 0; k < num_states_; k++) { - C[i * num_states_ + row_index[j] + k] = 0; + C[i * mat_size + row_index[j] + k] = 0; } } } From 70a0ecabde69ff465d5a4c5062ff7f4ed8e95f8b Mon Sep 17 00:00:00 2001 From: Samuel Martin <39562021+SR-Martin@users.noreply.github.com> Date: Tue, 24 Feb 2026 14:48:15 +0000 Subject: [PATCH 2/8] MATs Option added to write mutation annotated tree in nexus format with --estimate-MAT. 
--- maple/cmaple.cpp | 16 +++- tree/tree.cpp | 228 +++++++++++++++++++++++++++++++++++++++++++++-- tree/tree.h | 15 +++- utils/tools.cpp | 8 ++ utils/tools.h | 5 ++ 5 files changed, 258 insertions(+), 14 deletions(-) diff --git a/maple/cmaple.cpp b/maple/cmaple.cpp index 76b029e8..af661c40 100644 --- a/maple/cmaple.cpp +++ b/maple/cmaple.cpp @@ -197,6 +197,16 @@ void cmaple::runCMAPLE(cmaple::Params &params) out << tree.exportTSV(); out.close(); } + + // export MAT if selected + if(params.output_MAT) + { + std::string filename = params.output_prefix + "_MAT.nex"; + std::cout << "Writing MAT to file " << filename << std::endl; + ofstream out = ofstream(filename); + out << tree.exportNexus(tree_format, false, true); + out.close(); + } // output log-likelihood of the tree if (cmaple::verbose_mode > cmaple::VB_QUIET) { @@ -218,10 +228,12 @@ void cmaple::runCMAPLE(cmaple::Params &params) // Show information about output files std::cout << "Analysis results written to:" << std::endl; std::cout << "Maximum-likelihood tree: " << output_treefile << std::endl; + if (params.output_MAT) + std::cout << "Estimated mutation-annotated tree (MAT): " << output_treefile + "_MAT.nwk" << std::endl; if (params.output_NEXUS || params.compute_SPRTA) - std::cout << "Tree in NEXUS format: " << output_treefile + ".nex" << std::endl; + std::cout << "Tree in NEXUS format: " << output_treefile + ".nex" << std::endl; if (params.compute_SPRTA && params.output_alternative_spr) - std::cout << "Meta data in TSV format: " << output_treefile + ".tsv" << std::endl; + std::cout << "Meta data in TSV format: " << output_treefile + ".tsv" << std::endl; /*if (params.compute_aLRT_SH) { std::cout << "Tree with aLRT-SH values: " << prefix + ".aLRT_SH.treefile" << std::endl; diff --git a/tree/tree.cpp b/tree/tree.cpp index cb126d21..47a6322c 100644 --- a/tree/tree.cpp +++ b/tree/tree.cpp @@ -196,7 +196,8 @@ std::string cmaple::Tree::exportNewick(const TreeType tree_type, } std::string cmaple::Tree::exportNexus(const 
TreeType tree_type, - const bool show_branch_supports) { + const bool show_branch_supports, + const bool show_mutations) { assert(aln); assert(model); @@ -208,10 +209,10 @@ std::string cmaple::Tree::exportNexus(const TreeType tree_type, // output the tree according to its type switch (tree_type) { case BIN_TREE: - return exportNexus(true, show_branch_supports_checked); + return exportNexus(true, show_branch_supports_checked, show_mutations); // break; case MUL_TREE: - return exportNexus(false, show_branch_supports_checked); + return exportNexus(false, show_branch_supports_checked, show_mutations); // break; case UNKNOWN_TREE: default: @@ -1418,7 +1419,8 @@ std::string cmaple::Tree::exportNodeString(const bool is_newick_format, const bool binary, const NumSeqsType node_vec_index, const bool print_internal_id, - const bool show_branch_supports) { + const bool show_branch_supports, + const bool show_mutations) { string internal_name = ""; PhyloNode& node = nodes[node_vec_index]; string output = "("; @@ -1499,6 +1501,15 @@ std::string cmaple::Tree::exportNodeString(const bool is_newick_format, } } } + + if(show_mutations) + { + std::string mutation_string = getMutationStringForNode(node); + if(mutation_string.size() > 0) + { + annotation_vec.push_back("mutationsInf={" + mutation_string + "}"); + } + } if (sh_alrt_str.length() > 0) { @@ -1564,11 +1575,11 @@ std::string cmaple::Tree::exportNodeString(const bool is_newick_format, output += exportNodeString(is_newick_format, binary, node.getNeighborIndex(RIGHT).getVectorIndex(), - print_internal_id, show_branch_supports); + print_internal_id, show_branch_supports, show_mutations); output += ","; output += exportNodeString(is_newick_format, binary, node.getNeighborIndex(LEFT).getVectorIndex(), - print_internal_id, show_branch_supports); + print_internal_id, show_branch_supports, show_mutations); } // output SH-aLRT in newick string @@ -1590,6 +1601,205 @@ std::string cmaple::Tree::exportNodeString(const bool 
is_newick_format, return output; } +std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) +{ + + std::string mutation_string = ""; + RealNumType blength = node.getUpperLength(); + if(blength <= 0.) { + return mutation_string; + } + + PositionType genome_size = aln->ref_seq.size(); + SeqRegion::SeqType seq_type = aln->getSeqType(); + StateType num_states = aln->num_states; + + Index parent_index = node.getNeighborIndex(TOP); + PhyloNode& parent_node = nodes[parent_index.getVectorIndex()]; + const std::unique_ptr& parent_regions = parent_node.getPartialLh(parent_index.getMiniIndex()); + // Note: getPartialLh is not const so cannot const cmaple::PhyloNode& node + const std::unique_ptr& child_regions = node.getPartialLh(TOP); + + PositionType pos = 0; + const SeqRegions& seqP_regions = *parent_regions; + const SeqRegions& seqC_regions = *child_regions; + size_t iseq1 = 0; + size_t iseq2 = 0; + + while(pos < genome_size) + { + PositionType end_pos = 0; + SeqRegions::getNextSharedSegment(pos, seqP_regions, seqC_regions, iseq1, iseq2, end_pos); + const auto* seqP_region = &seqP_regions[iseq1]; + const auto* seqC_region = &seqC_regions[iseq2]; + + // if the child of this branch does not observe its state directly then + // skip this branch. + if(seqC_region->plength_observation2node > 0) { + pos = end_pos + 1; + continue; + } + + // distance to last observation or root if last observation was across the root. 
+ RealNumType branch_length_to_observation = blength; + if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root <= 0) { + branch_length_to_observation = blength + seqP_region->plength_observation2node; + } + else if(seqP_region->plength_observation2root > 0) { + branch_length_to_observation = blength + seqP_region->plength_observation2root; + } + + if(seqP_region->type != seqC_region->type && + seqP_region->type <= TYPE_R && + seqC_region->type <= TYPE_R) + { + StateType stateA = seqP_region->type; + StateType stateB = seqC_region->type; + if(seqP_region->type == TYPE_R) + { + stateA = aln->ref_seq[static_cast::size_type>(end_pos)]; + } + if(seqC_region->type == TYPE_R) + { + stateB = aln->ref_seq[static_cast::size_type>(end_pos)]; + } + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(stateB, seq_type) + ":1.0,"; + } + else if(seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) + { + StateType stateA = seqP_region->type; + if(seqP_region->type == TYPE_R) + { + stateA = aln->ref_seq[static_cast::size_type>(end_pos)]; + } + // Calculate a weight vector giving the relative probabilities of observing + // each state at the O node. 
+ std::vector weight_vector(num_states); + RealNumType sum = 0.0; + for(StateType stateB = 0; stateB < num_states; stateB++) + { + RealNumType likelihoodB = seqC_region->getLH(stateB); + RealNumType prob = 0.; + if(stateB != stateA) + { + prob = likelihoodB * branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateB, end_pos); + } else { + prob = likelihoodB * (1 - branch_length_to_observation + * model->getMutationMatrixEntry(stateB, stateB, end_pos)); + } + weight_vector[stateB] = prob; + sum += prob; + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states, sum); + + // write out mutations + for(StateType stateB = 0; stateB < num_states; stateB++) + { + if (stateB != stateA && weight_vector[stateB] > 0.01) + { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(stateB, seq_type) + ":" + + std::to_string(weight_vector[stateB]) + "," ; + } + } + } + else if(seqP_region->type == TYPE_O && seqC_region->type <= TYPE_R) + { + StateType stateB = seqC_region->type; + if(seqC_region->type == TYPE_R) + { + stateB = aln->ref_seq[static_cast::size_type>(end_pos)]; + } + // Calculate a weight vector giving the relative probabilities of observing + // each state at the O node. 
+ std::vector weight_vector(num_states); + RealNumType sum = 0.0; + for(StateType stateA = 0; stateA < num_states; stateA++) + { + RealNumType likelihoodA = seqP_region->getLH(stateA); + RealNumType prob = 0; + if(stateA != stateB) + { + prob = likelihoodA * branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateB, end_pos); + } else { + prob = likelihoodA * (1 - branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateA, end_pos)); + } + weight_vector[stateA] = prob; + sum += prob; + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states, sum); + + // write out mutations + for(StateType stateA = 0; stateA < num_states; stateA++) + { + if (stateA != stateB && weight_vector[stateA] > 0.01) + { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos + 1) + + aln->convertState2Char(stateB, seq_type) + ":" + + std::to_string(weight_vector[stateA]) + "," ; + } + } + } + else if(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) + { + // Calculate a weight vector giving the relative probabilities of observing + // each state at each of the O nodes. 
+ std::vector weight_vector(num_states * num_states); + RealNumType sum = 0.0; + for(StateType stateA = 0; stateA < num_states; stateA++) { + RealNumType likelihoodA = seqP_region->getLH(stateA); + for(StateType stateB = 0; stateB < num_states; stateB++) { + RealNumType likelihoodB = seqC_region->getLH(stateB); + RealNumType prob = 0.; + if(stateA != stateB) { + prob = likelihoodA * likelihoodB * branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateB, end_pos); + } else { + prob = likelihoodA * likelihoodB * (1 - branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateA, end_pos)); + } + weight_vector[model->row_index[stateA] + stateB] = prob; + sum += prob; + } + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states * num_states, sum); + + // write out mutations + for(StateType stateA = 0; stateA < num_states; stateA++) + { + for(StateType stateB =0; stateB < num_states; stateB++) + { + if (stateA != stateB && weight_vector[model->row_index[stateA] + stateB] > 0.01) + { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos + 1) + + aln->convertState2Char(stateB, seq_type) + ":" + + std::to_string(weight_vector[model->row_index[stateA] + stateB]) + "," ; + } + } + } + } + pos = end_pos + 1; + } + + if(mutation_string.size() > 0) + { + // remove trailing comma + mutation_string = mutation_string.substr(0, mutation_string.size()-1); + } + return mutation_string; +} + std::string cmaple::Tree::exportNewick(const bool binary, const bool print_internal_id, const bool show_branch_supports) { @@ -1611,7 +1821,8 @@ std::string cmaple::Tree::exportNewick(const bool binary, } std::string cmaple::Tree::exportNexus(const bool binary, - const bool show_branch_supports) { + const bool show_branch_supports, + const bool show_mutations) { assert(annotations.size() == nodes.size()); // make sure tree is not empty @@ -1690,7 +1901,7 @@ std::string cmaple::Tree::exportNexus(const bool 
binary, return pre_output + convertIntToString(seq_names.size()) + mid_output_1 + list_leaf_names + mid_output_2 - + exportNodeString(false, binary, root_vector_index, true, show_branch_supports) + + exportNodeString(false, binary, root_vector_index, true, show_branch_supports, show_mutations) + ";" + post_output; } @@ -6753,7 +6964,6 @@ void calculateSampleCost_R_O(const SeqRegion& seq1_region, RealNumType& total_factor, const ModelBase* model) { PositionType pos = seq2_region.position; - assert(seq1_region.position == seq2_region.position); if (seq1_region.plength_observation2root >= 0) { RealNumType total_blength = seq1_region.plength_observation2root + blength; diff --git a/tree/tree.h b/tree/tree.h index 09401434..d692c2cf 100644 --- a/tree/tree.h +++ b/tree/tree.h @@ -446,12 +446,14 @@ class Tree { * (bifurcating tree), MUL_TREE (multifurcating tree) * @param[in] show_branch_supports TRUE to output the branch supports (aLRT-SH * values) + * @param[in] show_mutations TRUE to output estimated mutations along tree * @return A tree string in NEXUS format * @throw std::invalid\_argument if any of the following situations occur. * - tree\_type is unknown */ std::string exportNexus(const TreeType tree_type = BIN_TREE, - const bool show_branch_supports = true); + const bool show_branch_supports = true, + const bool show_mutations = false); /** Export a TSV file that contains useful information from SPRTA @@ -461,6 +463,11 @@ class Tree { /*! 
\endcond */ private: + /** + * Get mutation string for MATs + */ + std::string getMutationStringForNode(cmaple::PhyloNode& node); + /** Pointer to LoadTree method */ @@ -1698,7 +1705,8 @@ bool isDiffFromOrigPlacement( const bool binary, const cmaple::NumSeqsType node_vec_index, const bool print_internal_id, - const bool show_branch_supports); + const bool show_branch_supports, + const bool show_mutations = false); /** Export string of an alternative branch (for SPRTA) @@ -1799,7 +1807,8 @@ bool isDiffFromOrigPlacement( Export tree std::string in NEXUS format */ std::string exportNexus(const bool binary, - const bool show_branch_supports); + const bool show_branch_supports, + const bool show_mutations); /** Traverse the tree to export TSV content diff --git a/utils/tools.cpp b/utils/tools.cpp index 65c334be..68ca5afe 100644 --- a/utils/tools.cpp +++ b/utils/tools.cpp @@ -614,6 +614,7 @@ cmaple::Params::Params() { make_consistent = false; print_internal_ids = false; output_NEXUS = false; + output_MAT = false; ignore_input_annotations = false; allow_rerooting = true; compute_SPRTA = false; @@ -822,6 +823,13 @@ void cmaple::parseArg(int argc, char* argv[], Params& params) { continue; } + if (strcmp(argv[cnt], "--estimate-MAT") == 0 || + strcmp(argv[cnt], "-estimate-MAT") == 0) { + + params.output_MAT = true; + + continue; + } if (strcmp(argv[cnt], "--out-internal") == 0 || strcmp(argv[cnt], "-out-int") == 0) { diff --git a/utils/tools.h b/utils/tools.h index 2cb98ff8..28f04208 100644 --- a/utils/tools.h +++ b/utils/tools.h @@ -689,6 +689,11 @@ class Params { */ bool output_NEXUS; + /** + * TRUE to also output MAT in nexus format + */ + bool output_MAT; + /** * TRUE to compute the SPRTA branch supports */ From ba03b9bff0b3f560a45ccd3a22f16d5c5ebe5eb3 Mon Sep 17 00:00:00 2001 From: Samuel Martin <39562021+SR-Martin@users.noreply.github.com> Date: Wed, 25 Feb 2026 15:44:31 +0000 Subject: [PATCH 3/8] Updates to rate-variation models - EM is now used for rate-variation 
(scalar) model. - User can specify maximum number of iterations for EM rate estimation for both rate-variation and site-specific-matrix models. Use "--rv-max-EM-steps <n>". - Added option to perform rate estimation after each tree traversal during SPR search. Use "--estimate-rates-during-SPR". --- maple/cmaple.cpp | 4 +-- model/model.cpp | 3 +- model/model.h | 1 + model/model_dna_rate_variation.cpp | 55 +++++++++++++++++++----------- model/model_dna_rate_variation.h | 4 ++- tree/tree.cpp | 18 ++++++++-- tree/tree.h | 12 +++---- utils/tools.cpp | 30 +++++++++++++--- utils/tools.h | 15 +++++++- 9 files changed, 105 insertions(+), 37 deletions(-) diff --git a/maple/cmaple.cpp b/maple/cmaple.cpp index af661c40..8c8ea8fe 100644 --- a/maple/cmaple.cpp +++ b/maple/cmaple.cpp @@ -122,8 +122,8 @@ void cmaple::runCMAPLE(cmaple::Params &params) throw std::invalid_argument("Unknown Model " + params.sub_model_str); } assert(sub_model != cmaple::ModelBase::UNKNOWN); - bool useRateVariationModel = params.rate_variation || params.site_specific_rates; - Model model(aln.ref_seq.size(), useRateVariationModel, params.rate_variation, params.wt_pseudocount, params.rates_filename, sub_model, aln.getSeqType()); + bool useRateVariationModel = params.rate_variation || params.site_specific_rate_matrix; + Model model(aln.ref_seq.size(), useRateVariationModel, params.rate_variation, params.wt_pseudocount, params.rates_filename, params.rate_variation_max_num_EM_steps, sub_model, aln.getSeqType()); // If users only want to convert the alignment to another format -> convert it and terminate if (params.output_aln.length()) diff --git a/model/model.cpp b/model/model.cpp index ccbaac86..34a9af3b 100644 --- a/model/model.cpp +++ b/model/model.cpp @@ -11,6 +11,7 @@ cmaple::Model::Model( cmaple::PositionType ref_genome_size, bool _siteRates, cmaple::RealNumType wt_pseudocount, const std::string _rates_filename, + int _max_num_EM_steps, const cmaple::ModelBase::SubModel sub_model, const 
cmaple::SeqRegion::SeqType seqtype) : model_base(nullptr) { @@ -75,7 +76,7 @@ cmaple::Model::Model( cmaple::PositionType ref_genome_size, } case cmaple::SeqRegion::SEQ_DNA: { if(rate_variation){ - model_base = new ModelDNARateVariation(n_sub_model, ref_genome_size, _siteRates, wt_pseudocount, _rates_filename); + model_base = new ModelDNARateVariation(n_sub_model, ref_genome_size, _siteRates, wt_pseudocount, _rates_filename, _max_num_EM_steps); } else { model_base = new ModelDNA(n_sub_model); } diff --git a/model/model.h b/model/model.h index 0ca3ebe3..be9fe950 100644 --- a/model/model.h +++ b/model/model.h @@ -51,6 +51,7 @@ class Model { bool _siteRates, cmaple::RealNumType wt_pseudocount, const std::string _rates_filename, + int _max_num_EM_steps, const cmaple::ModelBase::SubModel sub_model = cmaple::ModelBase::DEFAULT, const cmaple::SeqRegion::SeqType seqtype = cmaple::SeqRegion::SEQ_AUTO); diff --git a/model/model_dna_rate_variation.cpp b/model/model_dna_rate_variation.cpp index a631ede4..2da4a163 100644 --- a/model/model_dna_rate_variation.cpp +++ b/model/model_dna_rate_variation.cpp @@ -4,8 +4,7 @@ using namespace cmaple; -ModelDNARateVariation::ModelDNARateVariation( const cmaple::ModelBase::SubModel sub_model, PositionType _genome_size, - bool _use_site_rates, cmaple::RealNumType _wt_pseudocount, std::string _rates_filename) +ModelDNARateVariation::ModelDNARateVariation(const cmaple::ModelBase::SubModel sub_model, PositionType _genome_size, bool _use_site_rates, cmaple::RealNumType _wt_pseudocount, std::string _rates_filename, int _max_num_EM_steps) : ModelDNA(sub_model) { genome_size = _genome_size; @@ -13,6 +12,7 @@ ModelDNARateVariation::ModelDNARateVariation( const cmaple::ModelBase::SubModel mat_size = row_index[num_states_]; waiting_time_pseudocount = _wt_pseudocount; rates_filename = _rates_filename; + max_num_EM_steps = _max_num_EM_steps; mutation_matrices = new RealNumType[mat_size * genome_size](); transposed_mutation_matrices = new 
RealNumType[mat_size * genome_size](); @@ -88,22 +88,39 @@ bool ModelDNARateVariation::updateMutationMatEmpirical() { void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) { rates_estimated = true; - if(use_site_rates) { - estimateRatePerSite(tree); - - } else { - if(rates_filename.size() == 0) { - RealNumType old_LK = -std::numeric_limits::infinity(); - RealNumType new_LK = tree->computeLh(); - int num_steps = 0; - while(new_LK - old_LK > 1 && num_steps < 20) { + if(rates_filename.size() == 0) { + RealNumType old_LK = -std::numeric_limits::infinity(); + RealNumType new_LK = tree->computeLh(); + if(cmaple::verbose_mode > VB_MIN) + { + std::cout << "Estimation mutation rates using EM..." << std::endl; + std::cout << "Starting log-LK: " << + std::setprecision(10) << new_LK << std::endl; + } + + int num_steps = 0; + while(abs(new_LK - old_LK) > 1 && num_steps < max_num_EM_steps) + { + if(use_site_rates) + { + estimateRatePerSite(tree); + } + else + { estimateRatesPerSitePerEntry(tree); - old_LK = new_LK; - new_LK = tree->computeLh(); - } - } + } + old_LK = new_LK; + tree->computeCumulativeRate(); + new_LK = tree->computeLh(); + if(cmaple::verbose_mode > VB_MIN) + { + std::cout << "EM round " << num_steps + 1 << ": " << + std::setprecision(10) << new_LK << std::endl; + } + num_steps++; + } } - + // Write out rate matrices to file if(cmaple::verbose_mode > VB_MIN) { @@ -128,7 +145,7 @@ void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) { } void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){ - std::cout << "Estimating mutation rate per site..." << std::endl; + //std::cout << "Estimating mutation rate per site..." 
<< std::endl; RealNumType* waiting_times = new RealNumType[num_states_ * genome_size]; RealNumType* num_substitutions = new RealNumType[genome_size]; for(int i = 0; i < genome_size; i++) { @@ -204,7 +221,7 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){ } else { RealNumType expected_rate_no_substitution = 0; for(int j = 0; j < num_states_; j++) { - RealNumType summand = waiting_times[i * num_states_ + j] * abs(diagonal_mut_mat[j]); + RealNumType summand = waiting_times[i * num_states_ + j] * abs(getDiagonalMutationMatrixEntry(j,i)); expected_rate_no_substitution += summand; } if(expected_rate_no_substitution <= 0.01) { @@ -224,7 +241,7 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){ RealNumType row_sum = 0; for(int stateB = 0; stateB < num_states_; stateB++) { if(stateA != stateB) { - RealNumType val = mutation_matrices[i * mat_size + (stateB + row_index[stateA])] * rates[i]; + RealNumType val = mutation_mat[stateB + row_index[stateA]] * rates[i]; mutation_matrices[i * mat_size + (stateB + row_index[stateA])] = val; transposed_mutation_matrices[i * mat_size + (stateA + row_index[stateB])] = val; freqi_freqj_Qijs[i * mat_size + (stateB + row_index[stateA])] = root_freqs[stateA] * inverse_root_freqs[stateB] * val; diff --git a/model/model_dna_rate_variation.h b/model/model_dna_rate_variation.h index 1be8bbed..3ad1a776 100644 --- a/model/model_dna_rate_variation.h +++ b/model/model_dna_rate_variation.h @@ -13,7 +13,8 @@ class Tree; class ModelDNARateVariation : public ModelDNA { public: ModelDNARateVariation( const cmaple::ModelBase::SubModel sub_model, PositionType _genome_size, - bool _use_site_rates, cmaple::RealNumType _wt_pseudocount, std::string _rates_filename); + bool _use_site_rates, cmaple::RealNumType _wt_pseudocount, std::string _rates_filename, + int _max_num_EM_steps); virtual ~ModelDNARateVariation(); void estimateRates(cmaple::Tree* tree); @@ -102,6 +103,7 @@ class ModelDNARateVariation : public ModelDNA 
{ cmaple::RealNumType waiting_time_pseudocount; std::string rates_filename; + int max_num_EM_steps; }; } \ No newline at end of file diff --git a/tree/tree.cpp b/tree/tree.cpp index 47a6322c..1dbec851 100644 --- a/tree/tree.cpp +++ b/tree/tree.cpp @@ -1031,11 +1031,9 @@ void cmaple::Tree::applySPRTemplate( template void cmaple::Tree::doRateEstimationTemplate(std::ostream& out_stream) { - if(params->rate_variation || params->site_specific_rates) { + if(params->rate_variation || params->site_specific_rate_matrix) { ModelDNARateVariation* rvModel = (ModelDNARateVariation*) model; rvModel->estimateRates(this); - //rvModel->setAllMatricesToDefault(); - computeCumulativeRate(); } } @@ -1127,6 +1125,13 @@ void cmaple::Tree::optimizeTreeTopology(const TreeSearchType tree_search_type, // traverse the tree from root to try improvements on the entire tree RealNumType improvement = improveEntireTree(tree_search_type, short_range_search); + + if(params->estimate_rates_during_SPR && + (params->rate_variation || params->site_specific_rate_matrix)) + { + ModelDNARateVariation* rvModel = (ModelDNARateVariation*) model; + rvModel->estimateRates(this); + } // if only compute SPRTA (~ tree search type = FAST), stop searching further, one round is enough if (tree_search_type == FAST_TREE_SEARCH) @@ -1155,6 +1160,13 @@ void cmaple::Tree::optimizeTreeTopology(const TreeSearchType tree_search_type, << "Tree log likelihood: " << computeLh() << std::endl; } + if(params->estimate_rates_during_SPR && + (params->rate_variation || params->site_specific_rate_matrix)) + { + ModelDNARateVariation* rvModel = (ModelDNARateVariation*) model; + rvModel->estimateRates(this); + } + // stop trying if the improvement is so small if (improvement < params->thresh_entire_tree_improvement) { break; diff --git a/tree/tree.h b/tree/tree.h index d692c2cf..dd2417ff 100644 --- a/tree/tree.h +++ b/tree/tree.h @@ -293,6 +293,12 @@ class Tree { */ std::unique_ptr& getPartialLhAtNode(const cmaple::Index index); + /** 
+ Compute cumulative rate of the ref genome + @throw std::logic\_error if the reference genome is empty + */ + void computeCumulativeRate(); + // ----------------- END OF PUBLIC APIs ------------------------------------ // // @@ -628,12 +634,6 @@ class Tree { Model* model, std::unique_ptr&& params); - /** - Compute cumulative rate of the ref genome - @throw std::logic\_error if the reference genome is empty - */ - void computeCumulativeRate(); - /*! Optimize the tree topology @throw std::logic\_error if unexpected values/behaviors found during the operations diff --git a/utils/tools.cpp b/utils/tools.cpp index 68ca5afe..1a95287d 100644 --- a/utils/tools.cpp +++ b/utils/tools.cpp @@ -624,7 +624,9 @@ cmaple::Params::Params() { min_support_alt_branches = 0.01; thresh_loglh_optimal_diff_fac = 1.0; rate_variation = false; - site_specific_rates = false; + site_specific_rate_matrix = false; + rate_variation_max_num_EM_steps = 20; + estimate_rates_during_SPR = false; wt_pseudocount = 0.1; rates_filename = ""; @@ -1284,8 +1286,10 @@ void cmaple::parseArg(int argc, char* argv[], Params& params) { continue; } if (strcmp(argv[cnt], "--site-specific-rates") == 0 || + strcmp(argv[cnt], "--site-specific-rate-matrix") == 0 || + strcmp(argv[cnt], "--site-specific-matrix") == 0 || strcmp(argv[cnt], "-ssr") == 0) { - params.site_specific_rates = true; + params.site_specific_rate_matrix = true; continue; } @@ -1302,6 +1306,24 @@ void cmaple::parseArg(int argc, char* argv[], Params& params) { continue; } + if (strcmp(argv[cnt], "--rv-max-EM-steps") == 0) { + cnt++; + if (cnt >= argc) { + outError("Use --rv-max-EM-steps "); + } + try { + params.rate_variation_max_num_EM_steps = convert_int(argv[cnt]); + } catch (std::invalid_argument e) { + outError(e.what()); + } + continue; + } + + if (strcmp(argv[cnt], "--estimate-rates-during-SPR") == 0) { + params.estimate_rates_during_SPR = true; + continue; + } + if (strcmp(argv[cnt], "--rates-filename") == 0) { cnt++; if (cnt >= argc) { @@ 
-1378,8 +1400,8 @@ void cmaple::parseArg(int argc, char* argv[], Params& params) { "if SPRTA is not computed. Please use " "`--sprta` if you want to compute SPRTA."); } - if(params.rate_variation && params.site_specific_rates) { - outError("Unable to use rate variation and site-specific rate matrices.\n" + if(params.rate_variation && params.site_specific_rate_matrix) { + outError("Unable to use rate-variation and site-specific rate matrices.\n" "Please choose either:\n\t \"--rate-variation\" for a rate multiplier at each genomic site, or \n" "\t\"--site-specific-rates\" for an independent rate matrix at each genomic site."); } diff --git a/utils/tools.h b/utils/tools.h index 28f04208..42f8e2de 100644 --- a/utils/tools.h +++ b/utils/tools.h @@ -665,7 +665,20 @@ class Params { /** * TRUE to allow an independent rate matrix for each genomic site. */ - bool site_specific_rates; + bool site_specific_rate_matrix; + + /** + * Maximum number of EM iterations for estimating rates. + */ + int rate_variation_max_num_EM_steps; + + /** + * When to perform rate estimation for rate variation and + * site-specific matrix models. + * If true then estimate rates after each SPR tree traversal. + * Otherwise only estimate rates after initial sample placement. + */ + bool estimate_rates_during_SPR; /** * Name of file containing rates for each genomic site. From dc6ddc308d1cf0eb6a511a7dcd180e5757ca02f0 Mon Sep 17 00:00:00 2001 From: Samuel Martin <39562021+SR-Martin@users.noreply.github.com> Date: Mon, 16 Mar 2026 16:01:29 +0000 Subject: [PATCH 4/8] Updates to scalar rate variation model Updates to scalar rate variation model: - Estimate rates using nodes of type 0. 
- Take into account distance from last observation when estimating waiting times - Add pseudocounts - Use std min and max - Tidy code and reduce duplicate code --- model/model_dna_rate_variation.cpp | 467 ++++++++++++++++++++--------- model/model_dna_rate_variation.h | 28 +- 2 files changed, 342 insertions(+), 153 deletions(-) diff --git a/model/model_dna_rate_variation.cpp b/model/model_dna_rate_variation.cpp index 2da4a163..eca23be3 100644 --- a/model/model_dna_rate_variation.cpp +++ b/model/model_dna_rate_variation.cpp @@ -4,11 +4,17 @@ using namespace cmaple; -ModelDNARateVariation::ModelDNARateVariation(const cmaple::ModelBase::SubModel sub_model, PositionType _genome_size, bool _use_site_rates, cmaple::RealNumType _wt_pseudocount, std::string _rates_filename, int _max_num_EM_steps) +ModelDNARateVariation::ModelDNARateVariation( + const cmaple::ModelBase::SubModel sub_model, + PositionType _genome_size, + bool _scalar_rate_model, + cmaple::RealNumType _wt_pseudocount, + std::string _rates_filename, + int _max_num_EM_steps) : ModelDNA(sub_model) { genome_size = _genome_size; - use_site_rates = _use_site_rates; + scalar_rate_model = _scalar_rate_model; mat_size = row_index[num_states_]; waiting_time_pseudocount = _wt_pseudocount; rates_filename = _rates_filename; @@ -20,7 +26,7 @@ ModelDNARateVariation::ModelDNARateVariation(const cmaple::ModelBase::SubModel s freqi_freqj_Qijs = new RealNumType[mat_size * genome_size](); freqj_transposedijs = new RealNumType[mat_size * genome_size](); - if(use_site_rates) { + if(scalar_rate_model) { rates = new cmaple::RealNumType[genome_size](); } if(rates_filename.length() > 0) { @@ -34,7 +40,7 @@ ModelDNARateVariation::~ModelDNARateVariation() { delete[] diagonal_mutation_matrices; delete[] freqi_freqj_Qijs; delete[] freqj_transposedijs; - if(use_site_rates) { + if(scalar_rate_model) { delete[] rates; } } @@ -91,34 +97,42 @@ void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) { if(rates_filename.size() == 0) { 
RealNumType old_LK = -std::numeric_limits::infinity(); RealNumType new_LK = tree->computeLh(); + if(cmaple::verbose_mode > VB_MIN) { - std::cout << "Estimation mutation rates using EM..." << std::endl; + std::string model = scalar_rate_model ? "scalar rate variation" : "site-specific rate matrix"; + std::cout << "Estimation mutation rates under " << model << " model..." << std::endl; std::cout << "Starting log-LK: " << std::setprecision(10) << new_LK << std::endl; } - int num_steps = 0; - while(abs(new_LK - old_LK) > 1 && num_steps < max_num_EM_steps) + if(scalar_rate_model) { - if(use_site_rates) - { - estimateRatePerSite(tree); - } - else - { - estimateRatesPerSitePerEntry(tree); - } - old_LK = new_LK; + estimateRatePerSite(tree); tree->computeCumulativeRate(); new_LK = tree->computeLh(); if(cmaple::verbose_mode > VB_MIN) { - std::cout << "EM round " << num_steps + 1 << ": " << - std::setprecision(10) << new_LK << std::endl; + std::cout << "After rate estimation: " << std::setprecision(10) << new_LK << std::endl; } - num_steps++; - } + } + else + { + int num_steps = 0; + while(abs(new_LK - old_LK) > 1 && num_steps < max_num_EM_steps) + { + estimateRatesPerSitePerEntry(tree); + old_LK = new_LK; + tree->computeCumulativeRate(); + new_LK = tree->computeLh(); + if(cmaple::verbose_mode > VB_MIN) + { + std::cout << "EM round " << num_steps + 1 << ": " << + std::setprecision(10) << new_LK << std::endl; + } + num_steps++; + } + } } // Write out rate matrices to file @@ -132,7 +146,7 @@ void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) { printMatrix(getOriginalRateMatrix(), &out_file); for(int i = 0; i < genome_size; i++) { out_file << "Position: " << i << std::endl; - if(use_site_rates) { + if(scalar_rate_model) { out_file << "Rate: " << rates[i] << std::endl; } out_file << "Rate Matrix: " << std::endl; @@ -182,61 +196,180 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){ const std::unique_ptr& child_regions = node.getPartialLh(TOP); 
PositionType pos = 0; - const SeqRegions& seq1_regions = *parent_regions; - const SeqRegions& seq2_regions = *child_regions; + const SeqRegions& seqP_regions = *parent_regions; + const SeqRegions& seqC_regions = *child_regions; size_t iseq1 = 0; size_t iseq2 = 0; while(pos < genome_size) { PositionType end_pos; - SeqRegions::getNextSharedSegment(pos, seq1_regions, seq2_regions, iseq1, iseq2, end_pos); - const auto* seq1_region = &seq1_regions[iseq1]; - const auto* seq2_region = &seq2_regions[iseq2]; + SeqRegions::getNextSharedSegment(pos, seqP_regions, seqC_regions, iseq1, iseq2, end_pos); + const auto* seqP_region = &seqP_regions[iseq1]; + const auto* seqC_region = &seqC_regions[iseq2]; - if(seq1_region->type == TYPE_R && seq2_region->type == TYPE_R) { + // if the child of this branch does not observe its state directly then + // skip this branch. + if(seqC_region->plength_observation2node > 0) { + pos = end_pos + 1; + continue; + } + + // distance to last observation or root if last observation was across the root. 
+ RealNumType branch_length_to_observation = blength; + if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root < 0) { + branch_length_to_observation += seqP_region->plength_observation2node; + } + else if(seqP_region->plength_observation2root >= 0) { + branch_length_to_observation += seqP_region->plength_observation2root; + } + + if(seqP_region->type == TYPE_R && seqC_region->type == TYPE_R) { // both states are type REF for(int i = pos; i <= end_pos; i++) { - int state = tree->aln->ref_seq[static_cast::size_type>(i)]; - waiting_times[i * num_states_ + state] += blength; + StateType state = tree->aln->ref_seq[static_cast::size_type>(i)]; + waiting_times[i * num_states_ + state] += branch_length_to_observation; } - } else if(seq1_region->type == seq2_region->type && seq1_region->type < TYPE_R) { + } else if(seqP_region->type == seqC_region->type && seqP_region->type < TYPE_R) { // both states are equal but not of type REF - for(int i = pos; i <= end_pos; i++) { - waiting_times[i * num_states_ + seq1_region->type] += blength; - } - } else if(seq1_region->type <= TYPE_R && seq2_region->type <= TYPE_R) { + waiting_times[end_pos * num_states_ + seqP_region->type] += branch_length_to_observation; + + } else if(seqP_region->type <= TYPE_R && seqC_region->type <= TYPE_R) { // both states are not equal - for(int i = pos; i <= end_pos; i++) { - num_substitutions[i] += 1; - } - } + StateType parent_state = seqP_region->type; + StateType child_state = seqC_region->type; + if(seqP_region->type == TYPE_R) { + parent_state = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; + } + if (seqC_region->type == TYPE_R) { + child_state = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; + } + // Case 1: Last observation was this side of the root node + if(seqP_region->plength_observation2root < 0) { + waiting_times[end_pos * num_states_ + parent_state] += branch_length_to_observation / 2; + waiting_times[end_pos * num_states_ + child_state] += 
branch_length_to_observation / 2; + num_substitutions[end_pos] += 1; + } else { + // Case 2: Last observation was the other side of the root. + // In this case there are two further cases - the mutation happened either side of the root. + // We calculate the relative likelihood of each case and use this to weight waiting times etc. + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + updateCountsAndWaitingTimesAcrossRoot(end_pos, parent_state, child_state, dist_to_root, dist_to_observed, waiting_times, num_substitutions); + } + } else if(seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) { + StateType parent_state = seqP_region->type; + if(seqP_region->type == TYPE_R) { + parent_state = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; + } + + // Get weight vector giving the relative probabilities of observing + // each state at the O node. + std::vector weight_vector = getRelativeProbabilityOfChildOStatesForRegion(seqC_region, parent_state, branch_length_to_observation, end_pos); + + // Case 1: Last observation was this side of the root node + if(seqP_region->plength_observation2root < 0) { + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType prob = weight_vector[child_state]; + if(child_state != parent_state) { + num_substitutions[end_pos] += prob; + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation/2; + waiting_times[end_pos * num_states_ + child_state] += prob * branch_length_to_observation/2; + } else { + waiting_times[end_pos * num_states_ + child_state] += prob * branch_length_to_observation; + } + } + } else { + // Case 2: Last observation was the other side of the root. 
+ RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType prob = weight_vector[child_state]; + updateCountsAndWaitingTimesAcrossRoot(end_pos, parent_state, child_state, dist_to_root, dist_to_observed, waiting_times, num_substitutions, prob); + } + } + } else if(seqP_region->type == TYPE_O && seqC_region->type <= TYPE_R) { + StateType child_state = seqC_region->type; + if(seqC_region->type == TYPE_R) { + child_state = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; + } + + // Calculate a weight vector giving the relative probabilities of observing + // each state at the O node. + std::vector weight_vector = getRelativeProbabilityOfParentOStatesForRegion(seqP_region, child_state, branch_length_to_observation, end_pos); + + // Case 1: Last observation was this side of the root node + if(seqP_region->plength_observation2root < 0) { + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + RealNumType prob = weight_vector[parent_state]; + if(child_state != parent_state) { + num_substitutions[end_pos] += prob; + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation/2; + waiting_times[end_pos * num_states_ + child_state] += prob * branch_length_to_observation/2; + } else { + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation; + } + } + } else { + // Case 2: Last observation was the other side of the root. 
+ RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + RealNumType prob = weight_vector[parent_state]; + updateCountsAndWaitingTimesAcrossRoot(end_pos, parent_state, child_state, dist_to_root, dist_to_observed, waiting_times, num_substitutions, prob); + } + } + } else if(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) { + // Get weight vector giving the relative probabilities of observing + // each state at each of the O nodes. + std::vector weight_vector = getRelativeProbabilityOfParentOChildOStatesForRegion(seqP_region, seqC_region, branch_length_to_observation, end_pos); + + // Case 1: Last observation was this side of the root node + if(seqP_region->plength_observation2root < 0) { + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType prob = weight_vector[row_index[parent_state] + child_state]; + if(child_state != parent_state) { + num_substitutions[end_pos] += prob; + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation/2; + waiting_times[end_pos * num_states_ + child_state] += prob * branch_length_to_observation/2; + } else { + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation; + } + } + } + } else { + // Case 2: Last observation was the other side of the root. 
+ RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType prob = weight_vector[row_index[parent_state] + child_state]; + updateCountsAndWaitingTimesAcrossRoot(end_pos, parent_state, child_state, dist_to_root, dist_to_observed, waiting_times, num_substitutions, prob); + } + } + } + } pos = end_pos + 1; } } + // calculate site-rate as number of substitutions at site / expected rate of no substitution (according to genome-wide rates). RealNumType rate_count = 0; for(int i = 0; i < genome_size; i++) { - if(num_substitutions[i] == 0) { - rates[i] = 0.001; - } else { - RealNumType expected_rate_no_substitution = 0; - for(int j = 0; j < num_states_; j++) { - RealNumType summand = waiting_times[i * num_states_ + j] * abs(getDiagonalMutationMatrixEntry(j,i)); - expected_rate_no_substitution += summand; - } - if(expected_rate_no_substitution <= 0.01) { - rates[i] = 1.; - } else { - rates[i] = num_substitutions[i] / expected_rate_no_substitution; - } + RealNumType expected_rate_no_substitution = 0; + for(int j = 0; j < num_states_; j++) { + RealNumType summand = waiting_times[i * num_states_ + j] * abs(diagonal_mut_mat[row_index[j] + j]); + expected_rate_no_substitution += summand; } + rates[i] = (num_substitutions[i]+1) / (expected_rate_no_substitution+1); rate_count += rates[i]; } + // normalise so average rate is 1. 
RealNumType average_rate = rate_count / genome_size; for(int i = 0; i < genome_size; i++) { - rates[i] /= average_rate; - rates[i] = MIN(100.0, MAX(0.0001, rates[i])); + rates[i] /= average_rate; + rates[i] = std::min(250.0, std::max(0.0001, rates[i])); for(int stateA = 0; stateA < num_states_; stateA++) { RealNumType row_sum = 0; for(int stateB = 0; stateB < num_states_; stateB++) { @@ -323,11 +456,11 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { // distance to last observation or root if last observation was across the root. RealNumType branch_length_to_observation = blength; - if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root <= 0) { - branch_length_to_observation = blength + seqP_region->plength_observation2node; + if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root < 0) { + branch_length_to_observation += seqP_region->plength_observation2node; } - else if(seqP_region->plength_observation2root > 0) { - branch_length_to_observation = blength + seqP_region->plength_observation2root; + else if(seqP_region->plength_observation2root >= 0) { + branch_length_to_observation += seqP_region->plength_observation2root; } if(seqP_region->type == TYPE_R && seqC_region->type == TYPE_R) { @@ -337,12 +470,10 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { W[i * num_states_ + state] += branch_length_to_observation; } } else if(seqP_region->type == seqC_region->type && seqP_region->type < TYPE_R) { - // both states are equal but not of type REF - for(int i = pos; i <= end_pos; i++) { - W[i * num_states_ + seqP_region->type] += branch_length_to_observation; - } + // both states are equal but not of type REF or O + W[end_pos * num_states_ + seqP_region->type] += branch_length_to_observation; } else if(seqP_region->type <= TYPE_R && seqC_region->type <= TYPE_R) { - //states are not equal + //states are not equal but neither is O StateType stateA 
= seqP_region->type; StateType stateB = seqC_region->type; if(seqP_region->type == TYPE_R) { @@ -352,19 +483,17 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { stateB = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; } // Case 1: Last observation was this side of the root node - if(seqP_region->plength_observation2root <= 0) { - for(int i = pos; i <= end_pos; i++) { - W[i * num_states_ + stateA] += branch_length_to_observation/2; - W[i * num_states_ + stateB] += branch_length_to_observation/2; - C[i * mat_size + stateB + row_index[stateA]] += 1; - } + if(seqP_region->plength_observation2root < 0) { + W[end_pos * num_states_ + stateA] += branch_length_to_observation/2; + W[end_pos * num_states_ + stateB] += branch_length_to_observation/2; + C[end_pos * mat_size + stateB + row_index[stateA]] += 1; } else { // Case 2: Last observation was the other side of the root. // In this case there are two further cases - the mutation happened either side of the root. // We calculate the relative likelihood of each case and use this to weight waiting times etc. RealNumType dist_to_root = seqP_region->plength_observation2root + blength; RealNumType dist_to_observed = seqP_region->plength_observation2node; - updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C); + updateCountsAndWaitingTimesAcrossRoot(end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C); } } else if(seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) { StateType stateA = seqP_region->type; @@ -372,24 +501,9 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { stateA = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; } - // Calculate a weight vector giving the relative probabilities of observing + // Get weight vector giving the relative probabilities of observing // each state at the O node. 
- std::vector weight_vector(num_states_); - RealNumType sum = 0.0; - for(StateType stateB = 0; stateB < num_states_; stateB++) { - RealNumType likelihoodB = std::round(seqC_region->getLH(stateB) * 1000) / 1000.0; - if(stateB != stateA) { - RealNumType prob = likelihoodB * branch_length_to_observation * getMutationMatrixEntry(stateA, stateB, end_pos); - weight_vector[stateB] += prob; - sum += prob; - } else { - RealNumType prob = likelihoodB * (1 - branch_length_to_observation * getMutationMatrixEntry(stateB, stateB, end_pos)); - weight_vector[stateB] += prob; - sum += prob; - } - } - // Normalise weight vector - normalize_arr(weight_vector.data(), num_states_, sum); + std::vector weight_vector = getRelativeProbabilityOfChildOStatesForRegion(seqC_region, stateA, branch_length_to_observation, end_pos); // Case 1: Last observation was this side of the root node if(seqP_region->plength_observation2root < 0) { @@ -410,7 +524,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { RealNumType dist_to_observed = seqP_region->plength_observation2node; for(StateType stateB = 0; stateB < num_states_; stateB++) { RealNumType prob = weight_vector[stateB]; - updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); + updateCountsAndWaitingTimesAcrossRoot(end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); } } } else if(seqP_region->type == TYPE_O && seqC_region->type <= TYPE_R) { @@ -418,27 +532,11 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { if(seqC_region->type == TYPE_R) { stateB = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; } - // Calculate a weight vector giving the relative probabilities of observing // each state at the O node. 
- std::vector weight_vector(num_states_); - RealNumType sum = 0.0; - for(StateType stateA = 0; stateA < num_states_; stateA++) { - RealNumType likelihoodA = std::round(seqP_region->getLH(stateA) * 1000)/1000.0; - if(stateA != stateB) { - RealNumType prob = likelihoodA * branch_length_to_observation * getMutationMatrixEntry(stateA, stateB, end_pos); - weight_vector[stateA] += prob; - sum += prob; - } else { - RealNumType prob = likelihoodA * (1 - branch_length_to_observation * getMutationMatrixEntry(stateA, stateA, end_pos)); - weight_vector[stateA] += prob; - sum += prob; - } - } - // Normalise weight vector - normalize_arr(weight_vector.data(), num_states_, sum); + std::vector weight_vector = getRelativeProbabilityOfParentOStatesForRegion(seqP_region, stateB, branch_length_to_observation, end_pos); - // Case 1: Last observation was this side of the root node + // Case 1: Last observation was this side of the root node if(seqP_region->plength_observation2root < 0) { for(StateType stateA = 0; stateA < num_states_; stateA++) { RealNumType prob = weight_vector[stateA]; @@ -457,31 +555,13 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { RealNumType dist_to_observed = seqP_region->plength_observation2node; for(StateType stateA = 0; stateA < num_states_; stateA++) { RealNumType prob = weight_vector[stateA]; - updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); - } + updateCountsAndWaitingTimesAcrossRoot(end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); + } } } else if(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) { - // Calculate a weight vector giving the relative probabilities of observing + // Get weight vector giving the relative probabilities of observing // each state at each of the O nodes. 
- std::vector weight_vector(num_states_ * num_states_); - RealNumType sum = 0.0; - for(StateType stateA = 0; stateA < num_states_; stateA++) { - RealNumType likelihoodA = std::round(seqP_region->getLH(stateA) * 1000)/1000.0; - for(StateType stateB = 0; stateB < num_states_; stateB++) { - RealNumType likelihoodB = std::round(seqC_region->getLH(stateB)*1000)/1000.0; - if(stateA != stateB) { - RealNumType prob = likelihoodA * likelihoodB * branch_length_to_observation * getMutationMatrixEntry(stateA, stateB, end_pos); - weight_vector[row_index[stateA] + stateB] += prob; - sum += prob; - } else { - RealNumType prob = likelihoodA * likelihoodB * (1 - branch_length_to_observation * getMutationMatrixEntry(stateA, stateA, end_pos)); - weight_vector[row_index[stateA] + stateB] += prob; - sum += prob; - } - } - } - // Normalise weight vector - normalize_arr(weight_vector.data(), num_states_*num_states_, sum); + std::vector weight_vector = getRelativeProbabilityOfParentOChildOStatesForRegion(seqP_region, seqC_region, branch_length_to_observation, end_pos); // Case 1: Last observation was this side of the root node if(seqP_region->plength_observation2root < 0) { @@ -499,17 +579,17 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { } } } else { - // Case 2: Last observation was the other side of the root. + // Case 2: Last observation was the other side of the root. 
RealNumType dist_to_root = seqP_region->plength_observation2root + blength; RealNumType dist_to_observed = seqP_region->plength_observation2node; for(StateType stateA = 0; stateA < num_states_; stateA++) { for(StateType stateB = 0; stateB < num_states_; stateB++) { RealNumType prob = weight_vector[row_index[stateA] + stateB]; - updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); + updateCountsAndWaitingTimesAcrossRoot(end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); } } } - } + } pos = end_pos + 1; } } @@ -628,7 +708,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { RealNumType val = mutation_matrices[i * mat_size + (stateB + row_index[stateA])]; //val /= average_rate; val /= total_rate; - val = MIN(250.0, MAX(0.001, val)); + val = std::min(250.0, std::max(0.001, val)); mutation_matrices[i * mat_size + (stateB + row_index[stateA])] = val; transposed_mutation_matrices[i * mat_size + (stateA + row_index[stateB])] = val; @@ -655,28 +735,35 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { delete[] W; } -void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( PositionType start, PositionType end, - StateType parent_state, StateType child_state, - RealNumType dist_to_root, RealNumType dist_to_observed, - RealNumType* waiting_times, RealNumType* counts, - RealNumType weight) +void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( + PositionType genome_pos, + StateType parent_state, StateType child_state, + RealNumType dist_to_root, RealNumType dist_to_observed, + RealNumType* waiting_times, RealNumType* counts, + RealNumType weight) { if(parent_state != child_state) { - for(int i = start; i <= end; i++) { - RealNumType p_root_is_state_parent = root_freqs[parent_state] * getMutationMatrixEntry(parent_state, child_state, i) * dist_to_root; - RealNumType p_root_is_state_child = root_freqs[child_state] * 
getMutationMatrixEntry(child_state, parent_state, i) * dist_to_observed; - RealNumType relative_root_is_state_parent = p_root_is_state_parent / (p_root_is_state_parent + p_root_is_state_child); - waiting_times[i * num_states_ + parent_state] += weight * relative_root_is_state_parent * dist_to_root/2; - waiting_times[i * num_states_ + child_state] += weight * relative_root_is_state_parent * dist_to_root/2; - counts[i * mat_size + child_state + row_index[parent_state]] += weight * relative_root_is_state_parent; + RealNumType p_root_is_state_parent = root_freqs[parent_state] * getMutationMatrixEntry(parent_state, child_state, genome_pos) * dist_to_root; + RealNumType p_root_is_state_child = root_freqs[child_state] * getMutationMatrixEntry(child_state, parent_state, genome_pos) * dist_to_observed; + RealNumType relative_root_is_state_parent = p_root_is_state_parent / (p_root_is_state_parent + p_root_is_state_child); + + // only update waiting times for this side of the root + waiting_times[genome_pos * num_states_ + parent_state] += weight * relative_root_is_state_parent * dist_to_root/2; + waiting_times[genome_pos * num_states_ + child_state] += weight * relative_root_is_state_parent * dist_to_root/2; - RealNumType relative_root_is_state_child = 1 - relative_root_is_state_parent; - waiting_times[i * num_states_ + child_state] += weight * relative_root_is_state_child * dist_to_root; + // Counts array depends on model type: + // scalar rate variation has int per position + // matrix rate variaition has 4x4 matrix per position + int index = genome_pos; + if(!scalar_rate_model) { + index = genome_pos * mat_size + child_state + row_index[parent_state]; } + counts[index] += weight * relative_root_is_state_parent; + + RealNumType relative_root_is_state_child = 1 - relative_root_is_state_parent; + waiting_times[genome_pos * num_states_ + child_state] += weight * relative_root_is_state_child * dist_to_root; } else { - for(int i = start; i <= end; i++) { - waiting_times[i * 
num_states_ + child_state] += weight * dist_to_root; - } + waiting_times[genome_pos * num_states_ + child_state] += weight * dist_to_root; } } @@ -789,4 +876,88 @@ void ModelDNARateVariation::readRatesFile() { else { std::cerr << "Unable to open rate matrix file " << rates_filename << std::endl; } +} + +std::vector ModelDNARateVariation::getRelativeProbabilityOfParentOStatesForRegion( + const SeqRegion* seqP_region, + StateType child_state, + RealNumType branch_length_to_obs, + PositionType genome_pos) +{ + assert(seqP_region->type == TYPE_O); + std::vector weight_vector(num_states_); + RealNumType sum = 0.0; + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + RealNumType likelihood = seqP_region->getLH(parent_state); + RealNumType site_specific_mutation_rate = getMutationMatrixEntry(parent_state, child_state, genome_pos); + if(parent_state != child_state) { + RealNumType prob = likelihood * branch_length_to_obs * site_specific_mutation_rate; + weight_vector[parent_state] += prob; + sum += prob; + } else { + RealNumType prob = likelihood * (1 - branch_length_to_obs * site_specific_mutation_rate); + weight_vector[parent_state] += prob; + sum += prob; + } + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states_, sum); + return weight_vector; +} + +std::vector ModelDNARateVariation::getRelativeProbabilityOfChildOStatesForRegion( + const SeqRegion* seqC_region, + StateType parent_state, + RealNumType branch_length_to_obs, + PositionType genome_pos) +{ + assert(seqC_region->type == TYPE_O); + std::vector weight_vector(num_states_); + RealNumType sum = 0.0; + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType likelihood = seqC_region->getLH(child_state); + RealNumType site_specific_mutation_rate = getMutationMatrixEntry(parent_state, child_state, genome_pos); + if(parent_state != child_state) { + RealNumType prob = likelihood * branch_length_to_obs * site_specific_mutation_rate; 
+ weight_vector[child_state] += prob; + sum += prob; + } else { + RealNumType prob = likelihood * (1 - branch_length_to_obs * site_specific_mutation_rate); + weight_vector[child_state] += prob; + sum += prob; + } + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states_, sum); + return weight_vector; +} + +std::vector ModelDNARateVariation::getRelativeProbabilityOfParentOChildOStatesForRegion( + const cmaple::SeqRegion* seqP_region, + const cmaple::SeqRegion* seqC_region, + RealNumType branch_length_to_obs, + PositionType genome_pos) +{ + assert(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O); + std::vector weight_vector(num_states_ * num_states_); + RealNumType sum = 0.0; + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + RealNumType parent_likelihood = seqP_region->getLH(parent_state); + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType child_likelihood = seqC_region->getLH(child_state); + RealNumType site_specific_mutation_rate = getMutationMatrixEntry(parent_state, child_state, genome_pos); + if(parent_state != child_state) { + RealNumType prob = parent_likelihood * child_likelihood * branch_length_to_obs * site_specific_mutation_rate; + weight_vector[row_index[parent_state] + child_state] += prob; + sum += prob; + } else { + RealNumType prob = parent_likelihood * child_likelihood * (1 - branch_length_to_obs * site_specific_mutation_rate); + weight_vector[row_index[parent_state] + child_state] += prob; + sum += prob; + } + } + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states_*num_states_, sum); + return weight_vector; } \ No newline at end of file diff --git a/model/model_dna_rate_variation.h b/model/model_dna_rate_variation.h index 3ad1a776..80e753a5 100644 --- a/model/model_dna_rate_variation.h +++ b/model/model_dna_rate_variation.h @@ -12,9 +12,14 @@ class Tree; /** Class of DNA evolutionary models with rate variation */ 
class ModelDNARateVariation : public ModelDNA { public: - ModelDNARateVariation( const cmaple::ModelBase::SubModel sub_model, PositionType _genome_size, - bool _use_site_rates, cmaple::RealNumType _wt_pseudocount, std::string _rates_filename, - int _max_num_EM_steps); + ModelDNARateVariation( + const cmaple::ModelBase::SubModel sub_model, + PositionType _genome_size, + bool _use_site_rates, + cmaple::RealNumType _wt_pseudocount, + std::string _rates_filename, + int _max_num_EM_steps); + virtual ~ModelDNARateVariation(); void estimateRates(cmaple::Tree* tree); @@ -80,7 +85,7 @@ class ModelDNARateVariation : public ModelDNA { private: - void updateCountsAndWaitingTimesAcrossRoot( PositionType start, PositionType end, + void updateCountsAndWaitingTimesAcrossRoot( PositionType genome_pos, StateType parent_state, StateType child_state, RealNumType dist_to_root, RealNumType dist_to_observed, RealNumType* waiting_times, RealNumType* counts, @@ -88,6 +93,19 @@ class ModelDNARateVariation : public ModelDNA { void readRatesFile(); + std::vector getRelativeProbabilityOfParentOStatesForRegion( const cmaple::SeqRegion* seqP_region, + StateType child_state, + RealNumType branch_length_to_obs, + PositionType genome_pos); + std::vector getRelativeProbabilityOfChildOStatesForRegion( const cmaple::SeqRegion* seqC_region, + StateType parent_state, + RealNumType branch_length_to_obs, + PositionType genome_pos); + std::vector getRelativeProbabilityOfParentOChildOStatesForRegion(const cmaple::SeqRegion* seqP_region, + const cmaple::SeqRegion* seqC_region, + RealNumType branch_length_to_obs, + PositionType genome_pos); + cmaple::PositionType genome_size; cmaple::RealNumType* mutation_matrices = nullptr; @@ -97,7 +115,7 @@ class ModelDNARateVariation : public ModelDNA { cmaple::RealNumType* freqj_transposedijs = nullptr; cmaple::RealNumType* rates = nullptr; uint16_t mat_size; - bool use_site_rates = false; + bool scalar_rate_model = false; bool rates_estimated = false; 
cmaple::RealNumType waiting_time_pseudocount; From 21319c7c7a8058bb7cfc6e62c320508b39b4db86 Mon Sep 17 00:00:00 2001 From: Samuel Martin <39562021+SR-Martin@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:48:22 +0000 Subject: [PATCH 5/8] Update model_dna_rate_variation.cpp Fix bug in scalar rate variation estimation. --- model/model_dna_rate_variation.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/model_dna_rate_variation.cpp b/model/model_dna_rate_variation.cpp index eca23be3..e4bb06cf 100644 --- a/model/model_dna_rate_variation.cpp +++ b/model/model_dna_rate_variation.cpp @@ -358,7 +358,7 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){ for(int i = 0; i < genome_size; i++) { RealNumType expected_rate_no_substitution = 0; for(int j = 0; j < num_states_; j++) { - RealNumType summand = waiting_times[i * num_states_ + j] * abs(diagonal_mut_mat[row_index[j] + j]); + RealNumType summand = waiting_times[i * num_states_ + j] * abs(diagonal_mut_mat[j]); expected_rate_no_substitution += summand; } rates[i] = (num_substitutions[i]+1) / (expected_rate_no_substitution+1); From fd95a17c645f95d7916e97450f96689eb98a681b Mon Sep 17 00:00:00 2001 From: Samuel Martin <39562021+SR-Martin@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:54:03 +0000 Subject: [PATCH 6/8] Updates to MAT estimates When writing out MATs, mutation support is now spread across all branches on the path between the nodes with the observations. When this path goes across the root, we consider separately the subpaths from the root and use the root distribution to estimate support. 
--- tree/tree.cpp | 201 ++++++++++++++++++++++++++++++++++++-------------- tree/tree.h | 25 +++++-- 2 files changed, 166 insertions(+), 60 deletions(-) diff --git a/tree/tree.cpp b/tree/tree.cpp index 1dbec851..35c0806d 100644 --- a/tree/tree.cpp +++ b/tree/tree.cpp @@ -1613,6 +1613,39 @@ std::string cmaple::Tree::exportNodeString(const bool is_newick_format, return output; } +std::string cmaple::Tree::getMutationStringAcrossRoot( + RealNumType dist_to_root, + RealNumType dist_to_observed, + StateType parent_state, + StateType child_state, + PositionType pos, + RealNumType weight, + SeqRegion::SeqType seq_type +) +{ + assert(parent_state != child_state); + std::string mutation_string = ""; + RealNumType p_root_is_state_parent = model->root_freqs[parent_state] * model->getMutationMatrixEntry(parent_state, child_state, pos) * dist_to_root; + RealNumType p_root_is_state_child = model->root_freqs[child_state] * model->getMutationMatrixEntry(child_state, parent_state, pos) * dist_to_observed; + RealNumType relative_root_is_state_parent = p_root_is_state_parent / (p_root_is_state_parent + p_root_is_state_child); + RealNumType parent_support = weight * relative_root_is_state_parent; + RealNumType child_support = weight * (1 - relative_root_is_state_parent); + + if(parent_support >= min_mutation_support) { + mutation_string += aln->convertState2Char(parent_state, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(child_state, seq_type) + + ":" + std::to_string(parent_support) + ","; + } + if(child_support >= min_mutation_support) { + mutation_string += aln->convertState2Char(child_state, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(parent_state, seq_type) + + ":" + std::to_string(child_support) + ","; + } + return mutation_string; +} + std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) { @@ -1645,46 +1678,50 @@ std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) const auto* seqP_region = 
&seqP_regions[iseq1]; const auto* seqC_region = &seqC_regions[iseq2]; - // if the child of this branch does not observe its state directly then - // skip this branch. - if(seqC_region->plength_observation2node > 0) { - pos = end_pos + 1; - continue; - } - // distance to last observation or root if last observation was across the root. RealNumType branch_length_to_observation = blength; - if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root <= 0) { - branch_length_to_observation = blength + seqP_region->plength_observation2node; + if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root < 0) { + branch_length_to_observation += seqP_region->plength_observation2node; } - else if(seqP_region->plength_observation2root > 0) { - branch_length_to_observation = blength + seqP_region->plength_observation2root; + if(seqC_region->plength_observation2node > 0 && seqC_region->plength_observation2root < 0) { + branch_length_to_observation += seqC_region->plength_observation2node; } + RealNumType blength_weight = blength / branch_length_to_observation; if(seqP_region->type != seqC_region->type && seqP_region->type <= TYPE_R && seqC_region->type <= TYPE_R) { - StateType stateA = seqP_region->type; - StateType stateB = seqC_region->type; - if(seqP_region->type == TYPE_R) - { - stateA = aln->ref_seq[static_cast::size_type>(end_pos)]; - } - if(seqC_region->type == TYPE_R) - { - stateB = aln->ref_seq[static_cast::size_type>(end_pos)]; + StateType stateA = seqP_region->type; + StateType stateB = seqC_region->type; + if(seqP_region->type == TYPE_R) + { + stateA = aln->ref_seq[static_cast::size_type>(pos)]; + } + if(seqC_region->type == TYPE_R) + { + stateB = aln->ref_seq[static_cast::size_type>(pos)]; + } + if(seqP_region->plength_observation2root < 0) + { + if(blength_weight >= min_mutation_support) { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(stateB, 
seq_type) + + ":" + std::to_string(blength_weight) + ","; } - mutation_string += aln->convertState2Char(stateA, seq_type) + - std::to_string(pos+1) + - aln->convertState2Char(stateB, seq_type) + ":1.0,"; + } else { + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + mutation_string += getMutationStringAcrossRoot(dist_to_root, dist_to_observed, stateA, stateB, pos, blength_weight, seq_type); + } } else if(seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) { StateType stateA = seqP_region->type; if(seqP_region->type == TYPE_R) { - stateA = aln->ref_seq[static_cast::size_type>(end_pos)]; + stateA = aln->ref_seq[static_cast::size_type>(pos)]; } // Calculate a weight vector giving the relative probabilities of observing // each state at the O node. @@ -1697,10 +1734,10 @@ std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) if(stateB != stateA) { prob = likelihoodB * branch_length_to_observation - * model->getMutationMatrixEntry(stateA, stateB, end_pos); + * model->getMutationMatrixEntry(stateA, stateB, pos); } else { prob = likelihoodB * (1 - branch_length_to_observation - * model->getMutationMatrixEntry(stateB, stateB, end_pos)); + * model->getMutationMatrixEntry(stateB, stateB, pos)); } weight_vector[stateB] = prob; sum += prob; @@ -1709,14 +1746,30 @@ std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) normalize_arr(weight_vector.data(), num_states, sum); // write out mutations - for(StateType stateB = 0; stateB < num_states; stateB++) - { - if (stateB != stateA && weight_vector[stateB] > 0.01) + if(seqP_region->plength_observation2root < 0) { + for(StateType stateB = 0; stateB < num_states; stateB++) { - mutation_string += aln->convertState2Char(stateA, seq_type) + - std::to_string(pos+1) + - aln->convertState2Char(stateB, seq_type) + ":" + - std::to_string(weight_vector[stateB]) + "," ; + if (stateB != stateA) 
+ { + RealNumType mutation_support = weight_vector[stateB] * blength_weight; + if(mutation_support >= min_mutation_support) + { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(stateB, seq_type) + ":" + + std::to_string(mutation_support) + "," ; + } + } + } + } else { + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType stateB = 0; stateB < num_states; stateB++) + { + if(stateA != stateB) { + RealNumType relative_stateB = weight_vector[stateB]; + mutation_string += getMutationStringAcrossRoot(dist_to_root, dist_to_observed, stateA, stateB, pos, relative_stateB * blength_weight, seq_type); + } } } } @@ -1725,7 +1778,7 @@ std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) StateType stateB = seqC_region->type; if(seqC_region->type == TYPE_R) { - stateB = aln->ref_seq[static_cast::size_type>(end_pos)]; + stateB = aln->ref_seq[static_cast::size_type>(pos)]; } // Calculate a weight vector giving the relative probabilities of observing // each state at the O node. 
@@ -1738,10 +1791,10 @@ std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) if(stateA != stateB) { prob = likelihoodA * branch_length_to_observation - * model->getMutationMatrixEntry(stateA, stateB, end_pos); + * model->getMutationMatrixEntry(stateA, stateB, pos); } else { prob = likelihoodA * (1 - branch_length_to_observation - * model->getMutationMatrixEntry(stateA, stateA, end_pos)); + * model->getMutationMatrixEntry(stateA, stateA, pos)); } weight_vector[stateA] = prob; sum += prob; @@ -1750,15 +1803,32 @@ std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) normalize_arr(weight_vector.data(), num_states, sum); // write out mutations - for(StateType stateA = 0; stateA < num_states; stateA++) - { - if (stateA != stateB && weight_vector[stateA] > 0.01) + if(seqP_region->plength_observation2root < 0) + { + for(StateType stateA = 0; stateA < num_states; stateA++) { - mutation_string += aln->convertState2Char(stateA, seq_type) + - std::to_string(pos + 1) + - aln->convertState2Char(stateB, seq_type) + ":" + - std::to_string(weight_vector[stateA]) + "," ; + if (stateA != stateB) + { + RealNumType mutation_support = weight_vector[stateA] * blength_weight; + if(mutation_support >= min_mutation_support) + { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos + 1) + + aln->convertState2Char(stateB, seq_type) + ":" + + std::to_string(mutation_support) + "," ; + } + } } + } else { + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType stateA = 0; stateA < num_states; stateA++) + { + if(stateA != stateB) { + RealNumType relative_stateA = weight_vector[stateA]; + mutation_string += getMutationStringAcrossRoot(dist_to_root, dist_to_observed, stateA, stateB, pos, relative_stateA * blength_weight, seq_type); + } + } } } else if(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) @@ 
-1767,17 +1837,19 @@ std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) // each state at each of the O nodes. std::vector weight_vector(num_states * num_states); RealNumType sum = 0.0; - for(StateType stateA = 0; stateA < num_states; stateA++) { + for(StateType stateA = 0; stateA < num_states; stateA++) + { RealNumType likelihoodA = seqP_region->getLH(stateA); - for(StateType stateB = 0; stateB < num_states; stateB++) { + for(StateType stateB = 0; stateB < num_states; stateB++) + { RealNumType likelihoodB = seqC_region->getLH(stateB); RealNumType prob = 0.; if(stateA != stateB) { prob = likelihoodA * likelihoodB * branch_length_to_observation - * model->getMutationMatrixEntry(stateA, stateB, end_pos); + * model->getMutationMatrixEntry(stateA, stateB, pos); } else { prob = likelihoodA * likelihoodB * (1 - branch_length_to_observation - * model->getMutationMatrixEntry(stateA, stateA, end_pos)); + * model->getMutationMatrixEntry(stateA, stateA, pos)); } weight_vector[model->row_index[stateA] + stateB] = prob; sum += prob; @@ -1787,18 +1859,37 @@ std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) normalize_arr(weight_vector.data(), num_states * num_states, sum); // write out mutations - for(StateType stateA = 0; stateA < num_states; stateA++) - { - for(StateType stateB =0; stateB < num_states; stateB++) + if(seqP_region->plength_observation2root < 0) { + for(StateType stateA = 0; stateA < num_states; stateA++) { - if (stateA != stateB && weight_vector[model->row_index[stateA] + stateB] > 0.01) + for(StateType stateB = 0; stateB < num_states; stateB++) { - mutation_string += aln->convertState2Char(stateA, seq_type) + - std::to_string(pos + 1) + - aln->convertState2Char(stateB, seq_type) + ":" + - std::to_string(weight_vector[model->row_index[stateA] + stateB]) + "," ; + if (stateA != stateB) + { + RealNumType mutation_support = weight_vector[model->row_index[stateA] + stateB] * blength_weight; + if(mutation_support >= 
min_mutation_support) + { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos + 1) + + aln->convertState2Char(stateB, seq_type) + ":" + + std::to_string(mutation_support) + "," ; + } + } } } + } else { + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType stateA = 0; stateA < num_states; stateA++) + { + for(StateType stateB = 0; stateB < num_states; stateB++) + { + if(stateA != stateB) { + RealNumType relative_stateAB = weight_vector[model->row_index[stateA] + stateB]; + mutation_string += getMutationStringAcrossRoot(dist_to_root, dist_to_observed, stateA, stateB, pos, relative_stateAB * blength_weight, seq_type); + } + } + } } } pos = end_pos + 1; diff --git a/tree/tree.h b/tree/tree.h index dd2417ff..be5d8e5c 100644 --- a/tree/tree.h +++ b/tree/tree.h @@ -469,11 +469,26 @@ class Tree { /*! \endcond */ private: - /** - * Get mutation string for MATs - */ - std::string getMutationStringForNode(cmaple::PhyloNode& node); - + /** + * Get mutation string for MATs + */ + std::string getMutationStringForNode(cmaple::PhyloNode& node); + + /** + * Minimum support for writing a mutation to MAT + */ + RealNumType min_mutation_support = 0.01; + /** + * Helper function when last observation of state goes across the root. + */ + std::string getMutationStringAcrossRoot( + RealNumType dist_to_root, + RealNumType dist_to_observed, + StateType parent_state, + StateType child_state, + PositionType pos, + RealNumType weight, + SeqRegion::SeqType seq_type ); /** Pointer to LoadTree method */ From f8fc7c255be16bac2b03104633c265d9ba09abb6 Mon Sep 17 00:00:00 2001 From: Samuel Martin <39562021+SR-Martin@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:32:42 +0000 Subject: [PATCH 7/8] Update usage_cmaple Add details on rate-variation related options to usage_cmaple(). 
--- utils/tools.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/utils/tools.cpp b/utils/tools.cpp index 1a95287d..3712a3c3 100644 --- a/utils/tools.cpp +++ b/utils/tools.cpp @@ -1530,6 +1530,17 @@ void cmaple::usage_cmaple() { << endl << " alternative SPRs." << endl + << "RATE VARIATION MODELS:" << endl + << " --rate-variation Use a model of rate variation where each site " << endl + << " has an independent scalar rate multiplier." << endl + << " --site-specific-rate-matrix Use a model of rate variation where each site " << endl + << " has an independent rate matrix." << endl + << " --estimate-rates-during-SPR Re-estimate rates after every SPR tree traversal " << endl + << " (default: only after initial tree construction)." << endl + << " --waiting-time-pseudocount Set the waiting-time pseudocount (default: 1)." << endl + << " --rv-max-EM-steps . Maximum number of steps to attempt for EM " << endl + << " convergence when estimating rates with " << endl + << " --site-specific-rate-matrix (default: 20)." << endl << endl; exit(0); From 54a0a290acda073d6386b88e87ca677c1f9b4e55 Mon Sep 17 00:00:00 2001 From: Samuel Martin <39562021+SR-Martin@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:35:44 +0000 Subject: [PATCH 8/8] Update usage_cmaple 2 Add --estimate-MAT to usage_cmaple(). --- utils/tools.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/utils/tools.cpp b/utils/tools.cpp index 3712a3c3..b044de03 100644 --- a/utils/tools.cpp +++ b/utils/tools.cpp @@ -1497,6 +1497,9 @@ void cmaple::usage_cmaple() { << " --mean-subs Specify the mean #substitutions per site" << endl << " that CMAPLE is effective. Default: 0.02." << endl + << " --estimate-MAT Write a mutation-annotated tree (MAT) to " << endl + << " nexus file." + << endl << " --seed Set a seed number for random generators." << endl << " -v Set the verbose mode "