diff --git a/maple/cmaple.cpp b/maple/cmaple.cpp index 76b029e8..8c8ea8fe 100644 --- a/maple/cmaple.cpp +++ b/maple/cmaple.cpp @@ -122,8 +122,8 @@ void cmaple::runCMAPLE(cmaple::Params ¶ms) throw std::invalid_argument("Unknown Model " + params.sub_model_str); } assert(sub_model != cmaple::ModelBase::UNKNOWN); - bool useRateVariationModel = params.rate_variation || params.site_specific_rates; - Model model(aln.ref_seq.size(), useRateVariationModel, params.rate_variation, params.wt_pseudocount, params.rates_filename, sub_model, aln.getSeqType()); + bool useRateVariationModel = params.rate_variation || params.site_specific_rate_matrix; + Model model(aln.ref_seq.size(), useRateVariationModel, params.rate_variation, params.wt_pseudocount, params.rates_filename, params.rate_variation_max_num_EM_steps, sub_model, aln.getSeqType()); // If users only want to convert the alignment to another format -> convert it and terminate if (params.output_aln.length()) @@ -197,6 +197,16 @@ void cmaple::runCMAPLE(cmaple::Params ¶ms) out << tree.exportTSV(); out.close(); } + + // export MAT if selected + if(params.output_MAT) + { + std::string filename = params.output_prefix + "_MAT.nex"; + std::cout << "Writing MAT to file " << filename << std::endl; + ofstream out = ofstream(filename); + out << tree.exportNexus(tree_format, false, true); + out.close(); + } // output log-likelihood of the tree if (cmaple::verbose_mode > cmaple::VB_QUIET) { @@ -218,10 +228,12 @@ void cmaple::runCMAPLE(cmaple::Params ¶ms) // Show information about output files std::cout << "Analysis results written to:" << std::endl; std::cout << "Maximum-likelihood tree: " << output_treefile << std::endl; + if (params.output_MAT) + std::cout << "Estimated mutation-annotated tree (MAT): " << output_treefile + "_MAT.nwk" << std::endl; if (params.output_NEXUS || params.compute_SPRTA) - std::cout << "Tree in NEXUS format: " << output_treefile + ".nex" << std::endl; + std::cout << "Tree in NEXUS format: " << output_treefile + ".nex" << std::endl; if (params.compute_SPRTA && params.output_alternative_spr) - std::cout << "Meta data in TSV format: " << output_treefile + ".tsv" << std::endl; + std::cout << "Meta data in TSV format: " << output_treefile + ".tsv" << std::endl; /*if (params.compute_aLRT_SH) { std::cout << "Tree with aLRT-SH values: " << prefix + ".aLRT_SH.treefile" << std::endl; diff --git a/model/model.cpp b/model/model.cpp index ccbaac86..34a9af3b 100644 --- a/model/model.cpp +++ b/model/model.cpp @@ -11,6 +11,7 @@ cmaple::Model::Model( cmaple::PositionType ref_genome_size, bool _siteRates, cmaple::RealNumType wt_pseudocount, const std::string _rates_filename, + int _max_num_EM_steps, const cmaple::ModelBase::SubModel sub_model, const cmaple::SeqRegion::SeqType seqtype) : model_base(nullptr) { @@ -75,7 +76,7 @@ cmaple::Model::Model( cmaple::PositionType ref_genome_size, } case cmaple::SeqRegion::SEQ_DNA: { if(rate_variation){ - model_base = new ModelDNARateVariation(n_sub_model, ref_genome_size, _siteRates, wt_pseudocount, _rates_filename); + model_base = new ModelDNARateVariation(n_sub_model, ref_genome_size, _siteRates, wt_pseudocount, _rates_filename, _max_num_EM_steps); } else { model_base = new ModelDNA(n_sub_model); } diff --git a/model/model.h b/model/model.h index 0ca3ebe3..be9fe950 100644 --- a/model/model.h +++ b/model/model.h @@ -51,6 +51,7 @@ class Model { bool _siteRates, cmaple::RealNumType wt_pseudocount, const std::string _rates_filename, + int _max_num_EM_steps, const cmaple::ModelBase::SubModel sub_model = cmaple::ModelBase::DEFAULT, const cmaple::SeqRegion::SeqType seqtype = cmaple::SeqRegion::SEQ_AUTO); diff --git a/model/model_dna_rate_variation.cpp b/model/model_dna_rate_variation.cpp index eb42df44..e4bb06cf 100644 --- a/model/model_dna_rate_variation.cpp +++ b/model/model_dna_rate_variation.cpp @@ -4,15 +4,21 @@ using namespace cmaple; -ModelDNARateVariation::ModelDNARateVariation( const cmaple::ModelBase::SubModel sub_model, PositionType _genome_size, - bool _use_site_rates, cmaple::RealNumType _wt_pseudocount, std::string _rates_filename) +ModelDNARateVariation::ModelDNARateVariation( + const cmaple::ModelBase::SubModel sub_model, + PositionType _genome_size, + bool _scalar_rate_model, + cmaple::RealNumType _wt_pseudocount, + std::string _rates_filename, + int _max_num_EM_steps) : ModelDNA(sub_model) { genome_size = _genome_size; - use_site_rates = _use_site_rates; + scalar_rate_model = _scalar_rate_model; mat_size = row_index[num_states_]; waiting_time_pseudocount = _wt_pseudocount; rates_filename = _rates_filename; + max_num_EM_steps = _max_num_EM_steps; mutation_matrices = new RealNumType[mat_size * genome_size](); transposed_mutation_matrices = new RealNumType[mat_size * genome_size](); @@ -20,7 +26,7 @@ ModelDNARateVariation::ModelDNARateVariation( const cmaple::ModelBase::SubModel freqi_freqj_Qijs = new RealNumType[mat_size * genome_size](); freqj_transposedijs = new RealNumType[mat_size * genome_size](); - if(use_site_rates) { + if(scalar_rate_model) { rates = new cmaple::RealNumType[genome_size](); } if(rates_filename.length() > 0) { @@ -34,7 +40,7 @@ ModelDNARateVariation::~ModelDNARateVariation() { delete[] diagonal_mutation_matrices; delete[] freqi_freqj_Qijs; delete[] freqj_transposedijs; - if(use_site_rates) { + if(scalar_rate_model) { delete[] rates; } } @@ -88,22 +94,47 @@ bool ModelDNARateVariation::updateMutationMatEmpirical() { void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) { rates_estimated = true; - if(use_site_rates) { - estimateRatePerSite(tree); - - } else { - if(rates_filename.size() == 0) { - RealNumType old_LK = -std::numeric_limits::infinity(); - RealNumType new_LK = tree->computeLh(); + if(rates_filename.size() == 0) { + RealNumType old_LK = -std::numeric_limits::infinity(); + RealNumType new_LK = tree->computeLh(); + + if(cmaple::verbose_mode > VB_MIN) + { + std::string model = scalar_rate_model ? "scalar rate variation" : "site-specific rate matrix"; + std::cout << "Estimation mutation rates under " << model << " model..." << std::endl; + std::cout << "Starting log-LK: " << + std::setprecision(10) << new_LK << std::endl; + } + + if(scalar_rate_model) + { + estimateRatePerSite(tree); + tree->computeCumulativeRate(); + new_LK = tree->computeLh(); + if(cmaple::verbose_mode > VB_MIN) + { + std::cout << "After rate estimation: " << std::setprecision(10) << new_LK << std::endl; + } + } + else + { int num_steps = 0; - while(new_LK - old_LK > 1 && num_steps < 20) { + while(abs(new_LK - old_LK) > 1 && num_steps < max_num_EM_steps) + { estimateRatesPerSitePerEntry(tree); old_LK = new_LK; + tree->computeCumulativeRate(); new_LK = tree->computeLh(); - } - } + if(cmaple::verbose_mode > VB_MIN) + { + std::cout << "EM round " << num_steps + 1 << ": " << + std::setprecision(10) << new_LK << std::endl; + } + num_steps++; + } + } } - + // Write out rate matrices to file if(cmaple::verbose_mode > VB_MIN) { @@ -115,7 +146,7 @@ void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) { printMatrix(getOriginalRateMatrix(), &out_file); for(int i = 0; i < genome_size; i++) { out_file << "Position: " << i << std::endl; - if(use_site_rates) { + if(scalar_rate_model) { out_file << "Rate: " << rates[i] << std::endl; } out_file << "Rate Matrix: " << std::endl; @@ -128,7 +159,7 @@ void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) { } void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){ - std::cout << "Estimating mutation rate per site..." << std::endl; + //std::cout << "Estimating mutation rate per site..." << std::endl; RealNumType* waiting_times = new RealNumType[num_states_ * genome_size]; RealNumType* num_substitutions = new RealNumType[genome_size]; for(int i = 0; i < genome_size; i++) { @@ -165,66 +196,185 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){ const std::unique_ptr& child_regions = node.getPartialLh(TOP); PositionType pos = 0; - const SeqRegions& seq1_regions = *parent_regions; - const SeqRegions& seq2_regions = *child_regions; + const SeqRegions& seqP_regions = *parent_regions; + const SeqRegions& seqC_regions = *child_regions; size_t iseq1 = 0; size_t iseq2 = 0; while(pos < genome_size) { PositionType end_pos; - SeqRegions::getNextSharedSegment(pos, seq1_regions, seq2_regions, iseq1, iseq2, end_pos); - const auto* seq1_region = &seq1_regions[iseq1]; - const auto* seq2_region = &seq2_regions[iseq2]; + SeqRegions::getNextSharedSegment(pos, seqP_regions, seqC_regions, iseq1, iseq2, end_pos); + const auto* seqP_region = &seqP_regions[iseq1]; + const auto* seqC_region = &seqC_regions[iseq2]; + + // if the child of this branch does not observe its state directly then + // skip this branch. + if(seqC_region->plength_observation2node > 0) { + pos = end_pos + 1; + continue; + } + + // distance to last observation or root if last observation was across the root. + RealNumType branch_length_to_observation = blength; + if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root < 0) { + branch_length_to_observation += seqP_region->plength_observation2node; + } + else if(seqP_region->plength_observation2root >= 0) { + branch_length_to_observation += seqP_region->plength_observation2root; + } - if(seq1_region->type == TYPE_R && seq2_region->type == TYPE_R) { + if(seqP_region->type == TYPE_R && seqC_region->type == TYPE_R) { // both states are type REF for(int i = pos; i <= end_pos; i++) { - int state = tree->aln->ref_seq[static_cast::size_type>(i)]; - waiting_times[i * num_states_ + state] += blength; + StateType state = tree->aln->ref_seq[static_cast::size_type>(i)]; + waiting_times[i * num_states_ + state] += branch_length_to_observation; } - } else if(seq1_region->type == seq2_region->type && seq1_region->type < TYPE_R) { + } else if(seqP_region->type == seqC_region->type && seqP_region->type < TYPE_R) { // both states are equal but not of type REF - for(int i = pos; i <= end_pos; i++) { - waiting_times[i * num_states_ + seq1_region->type] += blength; - } - } else if(seq1_region->type <= TYPE_R && seq2_region->type <= TYPE_R) { + waiting_times[end_pos * num_states_ + seqP_region->type] += branch_length_to_observation; + + } else if(seqP_region->type <= TYPE_R && seqC_region->type <= TYPE_R) { // both states are not equal - for(int i = pos; i <= end_pos; i++) { - num_substitutions[i] += 1; - } - } + StateType parent_state = seqP_region->type; + StateType child_state = seqC_region->type; + if(seqP_region->type == TYPE_R) { + parent_state = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; + } + if (seqC_region->type == TYPE_R) { + child_state = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; + } + // Case 1: Last observation was this side of the root node + if(seqP_region->plength_observation2root < 0) { + waiting_times[end_pos * num_states_ + parent_state] += branch_length_to_observation / 2; + waiting_times[end_pos * num_states_ + child_state] += branch_length_to_observation / 2; + num_substitutions[end_pos] += 1; + } else { + // Case 2: Last observation was the other side of the root. + // In this case there are two further cases - the mutation happened either side of the root. + // We calculate the relative likelihood of each case and use this to weight waiting times etc. + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + updateCountsAndWaitingTimesAcrossRoot(end_pos, parent_state, child_state, dist_to_root, dist_to_observed, waiting_times, num_substitutions); + } + } else if(seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) { + StateType parent_state = seqP_region->type; + if(seqP_region->type == TYPE_R) { + parent_state = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; + } + + // Get weight vector giving the relative probabilities of observing + // each state at the O node. + std::vector weight_vector = getRelativeProbabilityOfChildOStatesForRegion(seqC_region, parent_state, branch_length_to_observation, end_pos); + + // Case 1: Last observation was this side of the root node + if(seqP_region->plength_observation2root < 0) { + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType prob = weight_vector[child_state]; + if(child_state != parent_state) { + num_substitutions[end_pos] += prob; + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation/2; + waiting_times[end_pos * num_states_ + child_state] += prob * branch_length_to_observation/2; + } else { + waiting_times[end_pos * num_states_ + child_state] += prob * branch_length_to_observation; + } + } + } else { + // Case 2: Last observation was the other side of the root. + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType prob = weight_vector[child_state]; + updateCountsAndWaitingTimesAcrossRoot(end_pos, parent_state, child_state, dist_to_root, dist_to_observed, waiting_times, num_substitutions, prob); + } + } + } else if(seqP_region->type == TYPE_O && seqC_region->type <= TYPE_R) { + StateType child_state = seqC_region->type; + if(seqC_region->type == TYPE_R) { + child_state = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; + } + + // Calculate a weight vector giving the relative probabilities of observing + // each state at the O node. + std::vector weight_vector = getRelativeProbabilityOfParentOStatesForRegion(seqP_region, child_state, branch_length_to_observation, end_pos); + + // Case 1: Last observation was this side of the root node + if(seqP_region->plength_observation2root < 0) { + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + RealNumType prob = weight_vector[parent_state]; + if(child_state != parent_state) { + num_substitutions[end_pos] += prob; + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation/2; + waiting_times[end_pos * num_states_ + child_state] += prob * branch_length_to_observation/2; + } else { + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation; + } + } + } else { + // Case 2: Last observation was the other side of the root. + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + RealNumType prob = weight_vector[parent_state]; + updateCountsAndWaitingTimesAcrossRoot(end_pos, parent_state, child_state, dist_to_root, dist_to_observed, waiting_times, num_substitutions, prob); + } + } + } else if(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) { + // Get weight vector giving the relative probabilities of observing + // each state at each of the O nodes. + std::vector weight_vector = getRelativeProbabilityOfParentOChildOStatesForRegion(seqP_region, seqC_region, branch_length_to_observation, end_pos); + + // Case 1: Last observation was this side of the root node + if(seqP_region->plength_observation2root < 0) { + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType prob = weight_vector[row_index[parent_state] + child_state]; + if(child_state != parent_state) { + num_substitutions[end_pos] += prob; + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation/2; + waiting_times[end_pos * num_states_ + child_state] += prob * branch_length_to_observation/2; + } else { + waiting_times[end_pos * num_states_ + parent_state] += prob * branch_length_to_observation; + } + } + } + } else { + // Case 2: Last observation was the other side of the root. + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType prob = weight_vector[row_index[parent_state] + child_state]; + updateCountsAndWaitingTimesAcrossRoot(end_pos, parent_state, child_state, dist_to_root, dist_to_observed, waiting_times, num_substitutions, prob); + } + } + } + } pos = end_pos + 1; } } + // calculate site-rate as number of substitutions at site / expected rate of no substitution (according to genome-wide rates). RealNumType rate_count = 0; for(int i = 0; i < genome_size; i++) { - if(num_substitutions[i] == 0) { - rates[i] = 0.001; - } else { - RealNumType expected_rate_no_substitution = 0; - for(int j = 0; j < num_states_; j++) { - RealNumType summand = waiting_times[i * num_states_ + j] * abs(diagonal_mut_mat[j]); - expected_rate_no_substitution += summand; - } - if(expected_rate_no_substitution <= 0.01) { - rates[i] = 1.; - } else { - rates[i] = num_substitutions[i] / expected_rate_no_substitution; - } + RealNumType expected_rate_no_substitution = 0; + for(int j = 0; j < num_states_; j++) { + RealNumType summand = waiting_times[i * num_states_ + j] * abs(diagonal_mut_mat[j]); + expected_rate_no_substitution += summand; } + rates[i] = (num_substitutions[i]+1) / (expected_rate_no_substitution+1); rate_count += rates[i]; } + // normalise so average rate is 1. RealNumType average_rate = rate_count / genome_size; for(int i = 0; i < genome_size; i++) { - rates[i] /= average_rate; - rates[i] = MIN(100.0, MAX(0.0001, rates[i])); + rates[i] /= average_rate; + rates[i] = std::min(250.0, std::max(0.0001, rates[i])); for(int stateA = 0; stateA < num_states_; stateA++) { RealNumType row_sum = 0; for(int stateB = 0; stateB < num_states_; stateB++) { if(stateA != stateB) { - RealNumType val = mutation_matrices[i * mat_size + (stateB + row_index[stateA])] * rates[i]; + RealNumType val = mutation_mat[stateB + row_index[stateA]] * rates[i]; mutation_matrices[i * mat_size + (stateB + row_index[stateA])] = val; transposed_mutation_matrices[i * mat_size + (stateA + row_index[stateB])] = val; freqi_freqj_Qijs[i * mat_size + (stateB + row_index[stateA])] = root_freqs[stateA] * inverse_root_freqs[stateB] * val; @@ -256,7 +406,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { for(int j = 0; j < num_states_; j++) { W[i * num_states_ + j] = 0; for(int k = 0; k < num_states_; k++) { - C[i * num_states_ + row_index[j] + k] = 0; + C[i * mat_size + row_index[j] + k] = 0; } } } @@ -306,11 +456,11 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { // distance to last observation or root if last observation was across the root. RealNumType branch_length_to_observation = blength; - if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root <= 0) { - branch_length_to_observation = blength + seqP_region->plength_observation2node; + if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root < 0) { + branch_length_to_observation += seqP_region->plength_observation2node; } - else if(seqP_region->plength_observation2root > 0) { - branch_length_to_observation = blength + seqP_region->plength_observation2root; + else if(seqP_region->plength_observation2root >= 0) { + branch_length_to_observation += seqP_region->plength_observation2root; } if(seqP_region->type == TYPE_R && seqC_region->type == TYPE_R) { @@ -320,12 +470,10 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { W[i * num_states_ + state] += branch_length_to_observation; } } else if(seqP_region->type == seqC_region->type && seqP_region->type < TYPE_R) { - // both states are equal but not of type REF - for(int i = pos; i <= end_pos; i++) { - W[i * num_states_ + seqP_region->type] += branch_length_to_observation; - } + // both states are equal but not of type REF or O + W[end_pos * num_states_ + seqP_region->type] += branch_length_to_observation; } else if(seqP_region->type <= TYPE_R && seqC_region->type <= TYPE_R) { - //states are not equal + //states are not equal but neither is O StateType stateA = seqP_region->type; StateType stateB = seqC_region->type; if(seqP_region->type == TYPE_R) { @@ -335,19 +483,17 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { stateB = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; } // Case 1: Last observation was this side of the root node - if(seqP_region->plength_observation2root <= 0) { - for(int i = pos; i <= end_pos; i++) { - W[i * num_states_ + stateA] += branch_length_to_observation/2; - W[i * num_states_ + stateB] += branch_length_to_observation/2; - C[i * mat_size + stateB + row_index[stateA]] += 1; - } + if(seqP_region->plength_observation2root < 0) { + W[end_pos * num_states_ + stateA] += branch_length_to_observation/2; + W[end_pos * num_states_ + stateB] += branch_length_to_observation/2; + C[end_pos * mat_size + stateB + row_index[stateA]] += 1; } else { // Case 2: Last observation was the other side of the root. // In this case there are two further cases - the mutation happened either side of the root. // We calculate the relative likelihood of each case and use this to weight waiting times etc. RealNumType dist_to_root = seqP_region->plength_observation2root + blength; RealNumType dist_to_observed = seqP_region->plength_observation2node; - updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C); + updateCountsAndWaitingTimesAcrossRoot(end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C); } } else if(seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) { StateType stateA = seqP_region->type; @@ -355,24 +501,9 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { stateA = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; } - // Calculate a weight vector giving the relative probabilities of observing + // Get weight vector giving the relative probabilities of observing // each state at the O node. - std::vector weight_vector(num_states_); - RealNumType sum = 0.0; - for(StateType stateB = 0; stateB < num_states_; stateB++) { - RealNumType likelihoodB = std::round(seqC_region->getLH(stateB) * 1000) / 1000.0; - if(stateB != stateA) { - RealNumType prob = likelihoodB * branch_length_to_observation * getMutationMatrixEntry(stateA, stateB, end_pos); - weight_vector[stateB] += prob; - sum += prob; - } else { - RealNumType prob = likelihoodB * (1 - branch_length_to_observation * getMutationMatrixEntry(stateB, stateB, end_pos)); - weight_vector[stateB] += prob; - sum += prob; - } - } - // Normalise weight vector - normalize_arr(weight_vector.data(), num_states_, sum); + std::vector weight_vector = getRelativeProbabilityOfChildOStatesForRegion(seqC_region, stateA, branch_length_to_observation, end_pos); // Case 1: Last observation was this side of the root node if(seqP_region->plength_observation2root < 0) { @@ -393,7 +524,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { RealNumType dist_to_observed = seqP_region->plength_observation2node; for(StateType stateB = 0; stateB < num_states_; stateB++) { RealNumType prob = weight_vector[stateB]; - updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); + updateCountsAndWaitingTimesAcrossRoot(end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); } } } else if(seqP_region->type == TYPE_O && seqC_region->type <= TYPE_R) { @@ -401,27 +532,11 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { if(seqC_region->type == TYPE_R) { stateB = tree->aln->ref_seq[static_cast::size_type>(end_pos)]; } - // Calculate a weight vector giving the relative probabilities of observing // each state at the O node. - std::vector weight_vector(num_states_); - RealNumType sum = 0.0; - for(StateType stateA = 0; stateA < num_states_; stateA++) { - RealNumType likelihoodA = std::round(seqP_region->getLH(stateA) * 1000)/1000.0; - if(stateA != stateB) { - RealNumType prob = likelihoodA * branch_length_to_observation * getMutationMatrixEntry(stateA, stateB, end_pos); - weight_vector[stateA] += prob; - sum += prob; - } else { - RealNumType prob = likelihoodA * (1 - branch_length_to_observation * getMutationMatrixEntry(stateA, stateA, end_pos)); - weight_vector[stateA] += prob; - sum += prob; - } - } - // Normalise weight vector - normalize_arr(weight_vector.data(), num_states_, sum); + std::vector weight_vector = getRelativeProbabilityOfParentOStatesForRegion(seqP_region, stateB, branch_length_to_observation, end_pos); - // Case 1: Last observation was this side of the root node + // Case 1: Last observation was this side of the root node if(seqP_region->plength_observation2root < 0) { for(StateType stateA = 0; stateA < num_states_; stateA++) { RealNumType prob = weight_vector[stateA]; @@ -440,31 +555,13 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { RealNumType dist_to_observed = seqP_region->plength_observation2node; for(StateType stateA = 0; stateA < num_states_; stateA++) { RealNumType prob = weight_vector[stateA]; - updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); - } + updateCountsAndWaitingTimesAcrossRoot(end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); + } } } else if(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) { - // Calculate a weight vector giving the relative probabilities of observing + // Get weight vector giving the relative probabilities of observing // each state at each of the O nodes. - std::vector weight_vector(num_states_ * num_states_); - RealNumType sum = 0.0; - for(StateType stateA = 0; stateA < num_states_; stateA++) { - RealNumType likelihoodA = std::round(seqP_region->getLH(stateA) * 1000)/1000.0; - for(StateType stateB = 0; stateB < num_states_; stateB++) { - RealNumType likelihoodB = std::round(seqC_region->getLH(stateB)*1000)/1000.0; - if(stateA != stateB) { - RealNumType prob = likelihoodA * likelihoodB * branch_length_to_observation * getMutationMatrixEntry(stateA, stateB, end_pos); - weight_vector[row_index[stateA] + stateB] += prob; - sum += prob; - } else { - RealNumType prob = likelihoodA * likelihoodB * (1 - branch_length_to_observation * getMutationMatrixEntry(stateA, stateA, end_pos)); - weight_vector[row_index[stateA] + stateB] += prob; - sum += prob; - } - } - } - // Normalise weight vector - normalize_arr(weight_vector.data(), num_states_*num_states_, sum); + std::vector weight_vector = getRelativeProbabilityOfParentOChildOStatesForRegion(seqP_region, seqC_region, branch_length_to_observation, end_pos); // Case 1: Last observation was this side of the root node if(seqP_region->plength_observation2root < 0) { @@ -482,17 +579,17 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { } } } else { - // Case 2: Last observation was the other side of the root. + // Case 2: Last observation was the other side of the root. RealNumType dist_to_root = seqP_region->plength_observation2root + blength; RealNumType dist_to_observed = seqP_region->plength_observation2node; for(StateType stateA = 0; stateA < num_states_; stateA++) { for(StateType stateB = 0; stateB < num_states_; stateB++) { RealNumType prob = weight_vector[row_index[stateA] + stateB]; - updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); + updateCountsAndWaitingTimesAcrossRoot(end_pos, stateA, stateB, dist_to_root, dist_to_observed, W, C, prob); } } } - } + } pos = end_pos + 1; } } @@ -611,7 +708,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { RealNumType val = mutation_matrices[i * mat_size + (stateB + row_index[stateA])]; //val /= average_rate; val /= total_rate; - val = MIN(250.0, MAX(0.001, val)); + val = std::min(250.0, std::max(0.001, val)); mutation_matrices[i * mat_size + (stateB + row_index[stateA])] = val; transposed_mutation_matrices[i * mat_size + (stateA + row_index[stateB])] = val; @@ -638,28 +735,35 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) { delete[] W; } -void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( PositionType start, PositionType end, - StateType parent_state, StateType child_state, - RealNumType dist_to_root, RealNumType dist_to_observed, - RealNumType* waiting_times, RealNumType* counts, - RealNumType weight) +void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( + PositionType genome_pos, + StateType parent_state, StateType child_state, + RealNumType dist_to_root, RealNumType dist_to_observed, + RealNumType* waiting_times, RealNumType* counts, + RealNumType weight) { if(parent_state != child_state) { - for(int i = start; i <= end; i++) { - RealNumType p_root_is_state_parent = root_freqs[parent_state] * getMutationMatrixEntry(parent_state, child_state, i) * dist_to_root; - RealNumType p_root_is_state_child = root_freqs[child_state] * getMutationMatrixEntry(child_state, parent_state, i) * dist_to_observed; - RealNumType relative_root_is_state_parent = p_root_is_state_parent / (p_root_is_state_parent + p_root_is_state_child); - waiting_times[i * num_states_ + parent_state] += weight * relative_root_is_state_parent * dist_to_root/2; - waiting_times[i * num_states_ + child_state] += weight * relative_root_is_state_parent * dist_to_root/2; - counts[i * mat_size + child_state + row_index[parent_state]] += weight * relative_root_is_state_parent; + RealNumType p_root_is_state_parent = root_freqs[parent_state] * getMutationMatrixEntry(parent_state, child_state, genome_pos) * dist_to_root; + RealNumType p_root_is_state_child = root_freqs[child_state] * getMutationMatrixEntry(child_state, parent_state, genome_pos) * dist_to_observed; + RealNumType relative_root_is_state_parent = p_root_is_state_parent / (p_root_is_state_parent + p_root_is_state_child); + + // only update waiting times for this side of the root + waiting_times[genome_pos * num_states_ + parent_state] += weight * relative_root_is_state_parent * dist_to_root/2; + waiting_times[genome_pos * num_states_ + child_state] += weight * relative_root_is_state_parent * dist_to_root/2; - RealNumType relative_root_is_state_child = 1 - relative_root_is_state_parent; - waiting_times[i * num_states_ + child_state] += weight * relative_root_is_state_child * dist_to_root; + // Counts array depends on model type: + // scalar rate variation has int per position + // matrix rate variaition has 4x4 matrix per position + int index = genome_pos; + if(!scalar_rate_model) { + index = genome_pos * mat_size + child_state + row_index[parent_state]; } + counts[index] += weight * relative_root_is_state_parent; + + RealNumType relative_root_is_state_child = 1 - relative_root_is_state_parent; + waiting_times[genome_pos * num_states_ + child_state] += weight * relative_root_is_state_child * dist_to_root; } else { - for(int i = start; i <= end; i++) { - waiting_times[i * num_states_ + child_state] += weight * dist_to_root; - } + waiting_times[genome_pos * num_states_ + child_state] += weight * dist_to_root; } } @@ -772,4 +876,88 @@ void ModelDNARateVariation::readRatesFile() { else { std::cerr << "Unable to open rate matrix file " << rates_filename << std::endl; } +} + +std::vector ModelDNARateVariation::getRelativeProbabilityOfParentOStatesForRegion( + const SeqRegion* seqP_region, + StateType child_state, + RealNumType branch_length_to_obs, + PositionType genome_pos) +{ + assert(seqP_region->type == TYPE_O); + std::vector weight_vector(num_states_); + RealNumType sum = 0.0; + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + RealNumType likelihood = seqP_region->getLH(parent_state); + RealNumType site_specific_mutation_rate = getMutationMatrixEntry(parent_state, child_state, genome_pos); + if(parent_state != child_state) { + RealNumType prob = likelihood * branch_length_to_obs * site_specific_mutation_rate; + weight_vector[parent_state] += prob; + sum += prob; + } else { + RealNumType prob = likelihood * (1 - branch_length_to_obs * site_specific_mutation_rate); + weight_vector[parent_state] += prob; + sum += prob; + } + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states_, sum); + return weight_vector; +} + +std::vector ModelDNARateVariation::getRelativeProbabilityOfChildOStatesForRegion( + const SeqRegion* seqC_region, + StateType parent_state, + RealNumType branch_length_to_obs, + PositionType genome_pos) +{ + assert(seqC_region->type == TYPE_O); + std::vector weight_vector(num_states_); + RealNumType sum = 0.0; + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType likelihood = seqC_region->getLH(child_state); + RealNumType site_specific_mutation_rate = getMutationMatrixEntry(parent_state, child_state, genome_pos); + if(parent_state != child_state) { + RealNumType prob = likelihood * branch_length_to_obs * site_specific_mutation_rate; + weight_vector[child_state] += prob; + sum += prob; + } else { + RealNumType prob = likelihood * (1 - branch_length_to_obs * site_specific_mutation_rate); + weight_vector[child_state] += prob; + sum += prob; + } + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states_, sum); + return weight_vector; +} + +std::vector ModelDNARateVariation::getRelativeProbabilityOfParentOChildOStatesForRegion( + const cmaple::SeqRegion* seqP_region, + const cmaple::SeqRegion* seqC_region, + RealNumType branch_length_to_obs, + PositionType genome_pos) +{ + assert(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O); + std::vector weight_vector(num_states_ * num_states_); + RealNumType sum = 0.0; + for(StateType parent_state = 0; parent_state < num_states_; parent_state++) { + RealNumType parent_likelihood = seqP_region->getLH(parent_state); + for(StateType child_state = 0; child_state < num_states_; child_state++) { + RealNumType child_likelihood = seqC_region->getLH(child_state); + RealNumType site_specific_mutation_rate = getMutationMatrixEntry(parent_state, child_state, genome_pos); + if(parent_state != child_state) { + RealNumType prob = parent_likelihood * child_likelihood * branch_length_to_obs * site_specific_mutation_rate; + weight_vector[row_index[parent_state] + child_state] += prob; + sum += prob; + } else { + RealNumType prob = parent_likelihood * child_likelihood * (1 - branch_length_to_obs * site_specific_mutation_rate); + weight_vector[row_index[parent_state] + child_state] += prob; + sum += prob; + } + } + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states_*num_states_, sum); + return weight_vector; } \ No newline at end of file diff --git a/model/model_dna_rate_variation.h b/model/model_dna_rate_variation.h index 1be8bbed..80e753a5 100644 --- a/model/model_dna_rate_variation.h +++ b/model/model_dna_rate_variation.h @@ -12,8 +12,14 @@ class Tree; /** Class of DNA evolutionary models with rate variation */ class ModelDNARateVariation : public ModelDNA { public: - ModelDNARateVariation( const cmaple::ModelBase::SubModel sub_model, PositionType _genome_size, - bool _use_site_rates, cmaple::RealNumType _wt_pseudocount, std::string _rates_filename); + ModelDNARateVariation( + const cmaple::ModelBase::SubModel sub_model, + PositionType _genome_size, + bool _use_site_rates, + cmaple::RealNumType _wt_pseudocount, + std::string _rates_filename, + int _max_num_EM_steps); + virtual ~ModelDNARateVariation(); void estimateRates(cmaple::Tree* tree); @@ -79,7 +85,7 @@ class ModelDNARateVariation : public ModelDNA { private: - void updateCountsAndWaitingTimesAcrossRoot( PositionType start, PositionType end, + void updateCountsAndWaitingTimesAcrossRoot( PositionType genome_pos, StateType parent_state, StateType child_state, RealNumType dist_to_root, RealNumType dist_to_observed, RealNumType* waiting_times, RealNumType* counts, @@ -87,6 +93,19 @@ class ModelDNARateVariation : public ModelDNA { void readRatesFile(); + std::vector getRelativeProbabilityOfParentOStatesForRegion( const cmaple::SeqRegion* seqP_region, + StateType child_state, + RealNumType branch_length_to_obs, + PositionType genome_pos); + std::vector getRelativeProbabilityOfChildOStatesForRegion( const cmaple::SeqRegion* seqC_region, + StateType parent_state, + RealNumType branch_length_to_obs, + PositionType genome_pos); + std::vector getRelativeProbabilityOfParentOChildOStatesForRegion(const cmaple::SeqRegion* seqP_region, + const cmaple::SeqRegion* seqC_region, + RealNumType branch_length_to_obs, + PositionType genome_pos); + cmaple::PositionType genome_size; cmaple::RealNumType* mutation_matrices = nullptr; @@ -96,12 +115,13 @@ class ModelDNARateVariation : public ModelDNA { cmaple::RealNumType* freqj_transposedijs = nullptr; cmaple::RealNumType* rates = nullptr; uint16_t mat_size; - bool use_site_rates = false; + bool scalar_rate_model = false; bool rates_estimated = false; cmaple::RealNumType waiting_time_pseudocount; std::string rates_filename; + int max_num_EM_steps; }; } \ No newline at end of file diff --git a/tree/tree.cpp b/tree/tree.cpp index cb126d21..35c0806d 100644 --- a/tree/tree.cpp +++ b/tree/tree.cpp @@ -196,7 +196,8 @@ std::string cmaple::Tree::exportNewick(const TreeType tree_type, } std::string cmaple::Tree::exportNexus(const TreeType tree_type, - const bool show_branch_supports) { + const bool show_branch_supports, + const bool show_mutations) { assert(aln); assert(model); @@ -208,10 +209,10 @@ std::string cmaple::Tree::exportNexus(const TreeType tree_type, // output the tree according to its type switch (tree_type) { case BIN_TREE: - return exportNexus(true, show_branch_supports_checked); + return exportNexus(true, show_branch_supports_checked, show_mutations); // break; case MUL_TREE: - return exportNexus(false, show_branch_supports_checked); + return exportNexus(false, show_branch_supports_checked, show_mutations); // break; case UNKNOWN_TREE: default: @@ -1030,11 +1031,9 @@ void cmaple::Tree::applySPRTemplate( template void cmaple::Tree::doRateEstimationTemplate(std::ostream& out_stream) { - if(params->rate_variation || params->site_specific_rates) { + if(params->rate_variation || params->site_specific_rate_matrix) { ModelDNARateVariation* rvModel = (ModelDNARateVariation*) model; rvModel->estimateRates(this); - //rvModel->setAllMatricesToDefault(); - computeCumulativeRate(); } } @@ -1126,6 +1125,13 @@ void cmaple::Tree::optimizeTreeTopology(const TreeSearchType tree_search_type, // traverse the tree from root to try improvements on the entire tree RealNumType improvement = improveEntireTree(tree_search_type, short_range_search); + + if(params->estimate_rates_during_SPR && + (params->rate_variation || params->site_specific_rate_matrix)) + { + ModelDNARateVariation* rvModel = (ModelDNARateVariation*) model; + rvModel->estimateRates(this); + } // if only compute SPRTA (~ tree search type = FAST), stop searching further, one round is enough if (tree_search_type == FAST_TREE_SEARCH) @@ -1154,6 +1160,13 @@ void cmaple::Tree::optimizeTreeTopology(const TreeSearchType tree_search_type, << "Tree log likelihood: " << computeLh() << std::endl; } + if(params->estimate_rates_during_SPR && + (params->rate_variation || params->site_specific_rate_matrix)) + { + ModelDNARateVariation* rvModel = (ModelDNARateVariation*) model; + rvModel->estimateRates(this); + } + // stop trying if the improvement is so small if (improvement < params->thresh_entire_tree_improvement) { break; @@ -1418,7 +1431,8 @@ std::string cmaple::Tree::exportNodeString(const bool is_newick_format, const bool binary, const NumSeqsType node_vec_index, const bool print_internal_id, - const bool show_branch_supports) { + const bool show_branch_supports, + const bool show_mutations) { string internal_name = ""; PhyloNode& node = nodes[node_vec_index]; string output = "("; @@ -1499,6 +1513,15 @@ std::string cmaple::Tree::exportNodeString(const bool is_newick_format, } } } + + if(show_mutations) + { + std::string mutation_string = getMutationStringForNode(node); + if(mutation_string.size() > 0) + { + annotation_vec.push_back("mutationsInf={" + mutation_string + "}"); + } + } if (sh_alrt_str.length() > 0) { @@ -1564,11 +1587,11 @@ std::string cmaple::Tree::exportNodeString(const bool is_newick_format, output += exportNodeString(is_newick_format, binary, node.getNeighborIndex(RIGHT).getVectorIndex(), - print_internal_id, show_branch_supports); + print_internal_id, show_branch_supports, show_mutations); output += ","; output += exportNodeString(is_newick_format, binary, node.getNeighborIndex(LEFT).getVectorIndex(), - print_internal_id, show_branch_supports); + print_internal_id, show_branch_supports, show_mutations); } // output SH-aLRT in newick string @@ -1590,6 +1613,296 @@ std::string cmaple::Tree::exportNodeString(const bool is_newick_format, return output; } +std::string cmaple::Tree::getMutationStringAcrossRoot( + RealNumType dist_to_root, + RealNumType dist_to_observed, + StateType parent_state, + StateType child_state, + PositionType pos, + RealNumType weight, + SeqRegion::SeqType seq_type +) +{ + assert(parent_state != child_state); + std::string mutation_string = ""; + RealNumType p_root_is_state_parent = model->root_freqs[parent_state] * model->getMutationMatrixEntry(parent_state, child_state, pos) * dist_to_root; + RealNumType p_root_is_state_child = model->root_freqs[child_state] * model->getMutationMatrixEntry(child_state, parent_state, pos) * dist_to_observed; + RealNumType relative_root_is_state_parent = p_root_is_state_parent / (p_root_is_state_parent + p_root_is_state_child); + RealNumType parent_support = weight * relative_root_is_state_parent; + RealNumType child_support = weight * (1 - relative_root_is_state_parent); + + if(parent_support >= min_mutation_support) { + mutation_string += aln->convertState2Char(parent_state, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(child_state, seq_type) + + ":" + std::to_string(parent_support) + ","; + } + if(child_support >= min_mutation_support) { + mutation_string += aln->convertState2Char(child_state, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(parent_state, seq_type) + + ":" + std::to_string(child_support) + ","; + } + return mutation_string; +} + +std::string cmaple::Tree::getMutationStringForNode(cmaple::PhyloNode& node) +{ + + std::string mutation_string = ""; + RealNumType blength = node.getUpperLength(); + if(blength <= 0.) { + return mutation_string; + } + + PositionType genome_size = aln->ref_seq.size(); + SeqRegion::SeqType seq_type = aln->getSeqType(); + StateType num_states = aln->num_states; + + Index parent_index = node.getNeighborIndex(TOP); + PhyloNode& parent_node = nodes[parent_index.getVectorIndex()]; + const std::unique_ptr& parent_regions = parent_node.getPartialLh(parent_index.getMiniIndex()); + // Note: getPartialLh is not const so cannot const cmaple::PhyloNode& node + const std::unique_ptr& child_regions = node.getPartialLh(TOP); + + PositionType pos = 0; + const SeqRegions& seqP_regions = *parent_regions; + const SeqRegions& seqC_regions = *child_regions; + size_t iseq1 = 0; + size_t iseq2 = 0; + + while(pos < genome_size) + { + PositionType end_pos = 0; + SeqRegions::getNextSharedSegment(pos, seqP_regions, seqC_regions, iseq1, iseq2, end_pos); + const auto* seqP_region = &seqP_regions[iseq1]; + const auto* seqC_region = &seqC_regions[iseq2]; + + // distance to last observation or root if last observation was across the root. + RealNumType branch_length_to_observation = blength; + if(seqP_region->plength_observation2node > 0 && seqP_region->plength_observation2root < 0) { + branch_length_to_observation += seqP_region->plength_observation2node; + } + if(seqC_region->plength_observation2node > 0 && seqC_region->plength_observation2root < 0) { + branch_length_to_observation += seqC_region->plength_observation2node; + } + + RealNumType blength_weight = blength / branch_length_to_observation; + if(seqP_region->type != seqC_region->type && + seqP_region->type <= TYPE_R && + seqC_region->type <= TYPE_R) + { + StateType stateA = seqP_region->type; + StateType stateB = seqC_region->type; + if(seqP_region->type == TYPE_R) + { + stateA = aln->ref_seq[static_cast::size_type>(pos)]; + } + if(seqC_region->type == TYPE_R) + { + stateB = aln->ref_seq[static_cast::size_type>(pos)]; + } + if(seqP_region->plength_observation2root < 0) + { + if(blength_weight >= min_mutation_support) { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(stateB, seq_type) + + ":" + std::to_string(blength_weight) + ","; + } + } else { + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + mutation_string += getMutationStringAcrossRoot(dist_to_root, dist_to_observed, stateA, stateB, pos, blength_weight, seq_type); + } + } + else if(seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) + { + StateType stateA = seqP_region->type; + if(seqP_region->type == TYPE_R) + { + stateA = aln->ref_seq[static_cast::size_type>(pos)]; + } + // Calculate a weight vector giving the relative probabilities of observing + // each state at the O node. + std::vector weight_vector(num_states); + RealNumType sum = 0.0; + for(StateType stateB = 0; stateB < num_states; stateB++) + { + RealNumType likelihoodB = seqC_region->getLH(stateB); + RealNumType prob = 0.; + if(stateB != stateA) + { + prob = likelihoodB * branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateB, pos); + } else { + prob = likelihoodB * (1 - branch_length_to_observation + * model->getMutationMatrixEntry(stateB, stateB, pos)); + } + weight_vector[stateB] = prob; + sum += prob; + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states, sum); + + // write out mutations + if(seqP_region->plength_observation2root < 0) { + for(StateType stateB = 0; stateB < num_states; stateB++) + { + if (stateB != stateA) + { + RealNumType mutation_support = weight_vector[stateB] * blength_weight; + if(mutation_support >= min_mutation_support) + { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos+1) + + aln->convertState2Char(stateB, seq_type) + ":" + + std::to_string(mutation_support) + "," ; + } + } + } + } else { + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType stateB = 0; stateB < num_states; stateB++) + { + if(stateA != stateB) { + RealNumType relative_stateB = weight_vector[stateB]; + mutation_string += getMutationStringAcrossRoot(dist_to_root, dist_to_observed, stateA, stateB, pos, relative_stateB * blength_weight, seq_type); + } + } + } + } + else if(seqP_region->type == TYPE_O && seqC_region->type <= TYPE_R) + { + StateType stateB = seqC_region->type; + if(seqC_region->type == TYPE_R) + { + stateB = aln->ref_seq[static_cast::size_type>(pos)]; + } + // Calculate a weight vector giving the relative probabilities of observing + // each state at the O node. + std::vector weight_vector(num_states); + RealNumType sum = 0.0; + for(StateType stateA = 0; stateA < num_states; stateA++) + { + RealNumType likelihoodA = seqP_region->getLH(stateA); + RealNumType prob = 0; + if(stateA != stateB) + { + prob = likelihoodA * branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateB, pos); + } else { + prob = likelihoodA * (1 - branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateA, pos)); + } + weight_vector[stateA] = prob; + sum += prob; + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states, sum); + + // write out mutations + if(seqP_region->plength_observation2root < 0) + { + for(StateType stateA = 0; stateA < num_states; stateA++) + { + if (stateA != stateB) + { + RealNumType mutation_support = weight_vector[stateA] * blength_weight; + if(mutation_support >= min_mutation_support) + { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos + 1) + + aln->convertState2Char(stateB, seq_type) + ":" + + std::to_string(mutation_support) + "," ; + } + } + } + } else { + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType stateA = 0; stateA < num_states; stateA++) + { + if(stateA != stateB) { + RealNumType relative_stateA = weight_vector[stateA]; + mutation_string += getMutationStringAcrossRoot(dist_to_root, dist_to_observed, stateA, stateB, pos, relative_stateA * blength_weight, seq_type); + } + } + } + } + else if(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) + { + // Calculate a weight vector giving the relative probabilities of observing + // each state at each of the O nodes. + std::vector weight_vector(num_states * num_states); + RealNumType sum = 0.0; + for(StateType stateA = 0; stateA < num_states; stateA++) + { + RealNumType likelihoodA = seqP_region->getLH(stateA); + for(StateType stateB = 0; stateB < num_states; stateB++) + { + RealNumType likelihoodB = seqC_region->getLH(stateB); + RealNumType prob = 0.; + if(stateA != stateB) { + prob = likelihoodA * likelihoodB * branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateB, pos); + } else { + prob = likelihoodA * likelihoodB * (1 - branch_length_to_observation + * model->getMutationMatrixEntry(stateA, stateA, pos)); + } + weight_vector[model->row_index[stateA] + stateB] = prob; + sum += prob; + } + } + // Normalise weight vector + normalize_arr(weight_vector.data(), num_states * num_states, sum); + + // write out mutations + if(seqP_region->plength_observation2root < 0) { + for(StateType stateA = 0; stateA < num_states; stateA++) + { + for(StateType stateB = 0; stateB < num_states; stateB++) + { + if (stateA != stateB) + { + RealNumType mutation_support = weight_vector[model->row_index[stateA] + stateB] * blength_weight; + if(mutation_support >= min_mutation_support) + { + mutation_string += aln->convertState2Char(stateA, seq_type) + + std::to_string(pos + 1) + + aln->convertState2Char(stateB, seq_type) + ":" + + std::to_string(mutation_support) + "," ; + } + } + } + } + } else { + RealNumType dist_to_root = seqP_region->plength_observation2root + blength; + RealNumType dist_to_observed = seqP_region->plength_observation2node; + for(StateType stateA = 0; stateA < num_states; stateA++) + { + for(StateType stateB = 0; stateB < num_states; stateB++) + { + if(stateA != stateB) { + RealNumType relative_stateAB = weight_vector[model->row_index[stateA] + stateB]; + mutation_string += getMutationStringAcrossRoot(dist_to_root, dist_to_observed, stateA, stateB, pos, relative_stateAB * blength_weight, seq_type); + } + } + } + } + } + pos = end_pos + 1; + } + + if(mutation_string.size() > 0) + { + // remove trailing comma + mutation_string = mutation_string.substr(0, mutation_string.size()-1); + } + return mutation_string; +} + std::string cmaple::Tree::exportNewick(const bool binary, const bool print_internal_id, const bool show_branch_supports) { @@ -1611,7 +1924,8 @@ std::string cmaple::Tree::exportNewick(const bool binary, } std::string cmaple::Tree::exportNexus(const bool binary, - const bool show_branch_supports) { + const bool show_branch_supports, + const bool show_mutations) { assert(annotations.size() == nodes.size()); // make sure tree is not empty @@ -1690,7 +2004,7 @@ std::string cmaple::Tree::exportNexus(const bool binary, return pre_output + convertIntToString(seq_names.size()) + mid_output_1 + list_leaf_names + mid_output_2 - + exportNodeString(false, binary, root_vector_index, true, show_branch_supports) + + exportNodeString(false, binary, root_vector_index, true, show_branch_supports, show_mutations) + ";" + post_output; } @@ -6753,7 +7067,6 @@ void calculateSampleCost_R_O(const SeqRegion& seq1_region, RealNumType& total_factor, const ModelBase* model) { PositionType pos = seq2_region.position; - assert(seq1_region.position == seq2_region.position); if (seq1_region.plength_observation2root >= 0) { RealNumType total_blength = seq1_region.plength_observation2root + blength; diff --git a/tree/tree.h b/tree/tree.h index 09401434..be5d8e5c 100644 --- a/tree/tree.h +++ b/tree/tree.h @@ -293,6 +293,12 @@ class Tree { */ std::unique_ptr& getPartialLhAtNode(const cmaple::Index index); + /** + Compute cumulative rate of the ref genome + @throw std::logic\_error if the reference genome is empty + */ + void computeCumulativeRate(); + // ----------------- END OF PUBLIC APIs ------------------------------------ // // @@ -446,12 +452,14 @@ class Tree { * (bifurcating tree), MUL_TREE (multifurcating tree) * @param[in] show_branch_supports TRUE to output the branch supports (aLRT-SH * values) + * @param[in] show_mutations TRUE to output estimated mutations along tree * @return A tree string in NEXUS format * @throw std::invalid\_argument if any of the following situations occur. * - tree\_type is unknown */ std::string exportNexus(const TreeType tree_type = BIN_TREE, - const bool show_branch_supports = true); + const bool show_branch_supports = true, + const bool show_mutations = false); /** Export a TSV file that contains useful information from SPRTA @@ -461,6 +469,26 @@ class Tree { /*! \endcond */ private: + /** + * Get mutation string for MATs + */ + std::string getMutationStringForNode(cmaple::PhyloNode& node); + + /** + * Minimum support for writing a mutation to MAT + */ + RealNumType min_mutation_support = 0.01; + /** + * Helper function when last observation of state goes across the root. + */ + std::string getMutationStringAcrossRoot( + RealNumType dist_to_root, + RealNumType dist_to_observed, + StateType parent_state, + StateType child_state, + PositionType pos, + RealNumType weight, + SeqRegion::SeqType seq_type ); /** Pointer to LoadTree method */ @@ -621,12 +649,6 @@ class Tree { Model* model, std::unique_ptr&& params); - /** - Compute cumulative rate of the ref genome - @throw std::logic\_error if the reference genome is empty - */ - void computeCumulativeRate(); - /*! Optimize the tree topology @throw std::logic\_error if unexpected values/behaviors found during the operations @@ -1698,7 +1720,8 @@ bool isDiffFromOrigPlacement( const bool binary, const cmaple::NumSeqsType node_vec_index, const bool print_internal_id, - const bool show_branch_supports); + const bool show_branch_supports, + const bool show_mutations = false); /** Export string of an alternative branch (for SPRTA) @@ -1799,7 +1822,8 @@ bool isDiffFromOrigPlacement( Export tree std::string in NEXUS format */ std::string exportNexus(const bool binary, - const bool show_branch_supports); + const bool show_branch_supports, + const bool show_mutations); /** Traverse the tree to export TSV content diff --git a/utils/tools.cpp b/utils/tools.cpp index 65c334be..b044de03 100644 --- a/utils/tools.cpp +++ b/utils/tools.cpp @@ -614,6 +614,7 @@ cmaple::Params::Params() { make_consistent = false; print_internal_ids = false; output_NEXUS = false; + output_MAT = false; ignore_input_annotations = false; allow_rerooting = true; compute_SPRTA = false; @@ -623,7 +624,9 @@ cmaple::Params::Params() { min_support_alt_branches = 0.01; thresh_loglh_optimal_diff_fac = 1.0; rate_variation = false; - site_specific_rates = false; + site_specific_rate_matrix = false; + rate_variation_max_num_EM_steps = 20; + estimate_rates_during_SPR = false; wt_pseudocount = 0.1; rates_filename = ""; @@ -822,6 +825,13 @@ void cmaple::parseArg(int argc, char* argv[], Params& params) { continue; } + if (strcmp(argv[cnt], "--estimate-MAT") == 0 || + strcmp(argv[cnt], "-estimate-MAT") == 0) { + + params.output_MAT = true; + + continue; + } if (strcmp(argv[cnt], "--out-internal") == 0 || strcmp(argv[cnt], "-out-int") == 0) { @@ -1276,8 +1286,10 @@ void cmaple::parseArg(int argc, char* argv[], Params& params) { continue; } if (strcmp(argv[cnt], "--site-specific-rates") == 0 || + strcmp(argv[cnt], "--site-specific-rate-matrix") == 0 || + strcmp(argv[cnt], "--site-specific-matrix") == 0 || strcmp(argv[cnt], "-ssr") == 0) { - params.site_specific_rates = true; + params.site_specific_rate_matrix = true; continue; } @@ -1294,6 +1306,24 @@ void cmaple::parseArg(int argc, char* argv[], Params& params) { continue; } + if (strcmp(argv[cnt], "--rv-max-EM-steps") == 0) { + cnt++; + if (cnt >= argc) { + outError("Use --rv-max-EM-steps "); + } + try { + params.rate_variation_max_num_EM_steps = convert_int(argv[cnt]); + } catch (std::invalid_argument e) { + outError(e.what()); + } + continue; + } + + if (strcmp(argv[cnt], "--estimate-rates-during-SPR") == 0) { + params.estimate_rates_during_SPR = true; + continue; + } + if (strcmp(argv[cnt], "--rates-filename") == 0) { cnt++; if (cnt >= argc) { @@ -1370,8 +1400,8 @@ void cmaple::parseArg(int argc, char* argv[], Params& params) { "if SPRTA is not computed. Please use " "`--sprta` if you want to compute SPRTA."); } - if(params.rate_variation && params.site_specific_rates) { - outError("Unable to use rate variation and site-specific rate matrices.\n" + if(params.rate_variation && params.site_specific_rate_matrix) { + outError("Unable to use rate-variation and site-specific rate matrices.\n" "Please choose either:\n\t \"--rate-variation\" for a rate multiplier at each genomic site, or \n" "\t\"--site-specific-rates\" for an independent rate matrix at each genomic site."); } @@ -1467,6 +1497,9 @@ void cmaple::usage_cmaple() { << " --mean-subs Specify the mean #substitutions per site" << endl << " that CMAPLE is effective. Default: 0.02." << endl + << " --estimate-MAT Write a mutation-annotated tree (MAT) to " << endl + << " nexus file." + << endl << " --seed Set a seed number for random generators." << endl << " -v Set the verbose mode " @@ -1500,6 +1533,17 @@ void cmaple::usage_cmaple() { << endl << " alternative SPRs." << endl + << "RATE VARIATION MODELS:" << endl + << " --rate-variation Use a model of rate variation where each site " << endl + << " has an independent scalar rate multiplier." << endl + << " --site-specific-rate-matrix Use a model of rate variation where each site " << endl + << " has an independent rate matrix." << endl + << " --estimate-rates-during-SPR Re-estimate rates after every SPR tree traversal " << endl + << " (default: only after initial tree construction)." << endl + << " --waiting-time-pseudocount Set the waiting-time pseudocount (default: 1)." << endl + << " --rv-max-EM-steps . Maximum number of steps to attempt for EM " << endl + << " convergence when estimating rates with " << endl + << " --site-specific-rate-matrix (default: 20)." << endl << endl; exit(0); diff --git a/utils/tools.h b/utils/tools.h index 2cb98ff8..42f8e2de 100644 --- a/utils/tools.h +++ b/utils/tools.h @@ -665,7 +665,20 @@ class Params { /** * TRUE to allow an independent rate matrix for each genomic site. */ - bool site_specific_rates; + bool site_specific_rate_matrix; + + /** + * Maximum number of EM iterations for estimating rates. + */ + int rate_variation_max_num_EM_steps; + + /** + * When to perform rate estimation for rate variation and + * site-specific matrix models. + * If true then estimate rates after each SPR tree traversal. + * Otherwise only estimate rates after initial sample placement. + */ + bool estimate_rates_during_SPR; /** * Name of file containing rates for each genomic site. @@ -689,6 +702,11 @@ class Params { */ bool output_NEXUS; + /** + * TRUE to also output MAT in nexus format + */ + bool output_MAT; + /** * TRUE to compute the SPRTA branch supports */