diff --git a/c/tskit/core.c b/c/tskit/core.c index 5e5f828943..103bf041a8 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -853,7 +853,7 @@ tsk_blkalloc_free(tsk_blkalloc_t *self) } } -/* Mirrors the semantics of numpy's searchsorted function. Uses binary +/* Mirrors the semantics of numpy's searchsorted function (side='left'). Uses binary * search to find the index of the closest value in the array. */ tsk_size_t tsk_search_sorted(const double *restrict array, tsk_size_t size, double value) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index f00bb83d28..e1b4ff6efb 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3461,6 +3461,319 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl return ret; } +// TODO: breaks at 0,1/0,1 -- span is into negative, no need to 1/2 stat. +// TODO: support is bound or as written in test? +static double +integrate_stat_over_bin( + const interval_t i1, const interval_t i2, double bl, double br, double stat) +{ + interval_t support = { i2.left - i1.right, i2.right - i1.left }; + double r2_len = fmin(i1.right - i1.left, i2.right - i2.left); + // Size of the center region is determined by the larger of the two + // intervals. It is zero if they are equal (triangle). + double r2_l_bound = fmin(i2.left - i1.left, i2.right - i1.right); + double r2_r_bound = support.right - r2_len; + // left and right values for each of the 3 regions to integrate over + // variable names are: r{region}_{left|right} + double r1_l = fmin(fmax(bl, support.left), r2_l_bound); + double r1_r = fmax(fmin(br, r2_l_bound), support.left); + double r2_l = fmin(fmax(bl, r2_l_bound), r2_r_bound); + double r2_r = fmax(fmin(br, r2_r_bound), r2_l_bound); + double r3_l = fmin(fmax(bl, r2_r_bound), support.right); + double r3_r = fmax(fmin(br, support.right), r2_r_bound); + double i1_span = i1.right - i1.left; + double i2_span = i2.right - i2.left; + + return stat / (i1_span * i2_span) + * (-1. 
/ 2 * (r1_l - r1_r) * (2 * i1.right - 2 * i2.left + r1_l + r1_r) + + (r2_r - r2_l) * r2_len + + 1. / 2 * (r3_l - r3_r) * (2 * i1.left - 2 * i2.right + r3_l + r3_r)); +} + +static int +tsk_treeseq_two_locus_branch_decay_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, const double *bins, tsk_size_t num_bins, + const double *restrict breakpoints, double *result, tsk_size_t *bincount) +{ + int ret = 0; + const tsk_size_t num_nodes = self->tables->nodes.num_rows; + interval_t bounds, ivl_l, ivl_r; + iter_state l_state, r_state; + double *result_tmp = NULL, *result_row; + tsk_bitset_t node_samples, sample_sets_bits; + tsk_size_t i, j, k, bin_l, bin_r, *bincount_row; + + tsk_memset(&sample_sets_bits, 0, sizeof(sample_sets_bits)); + tsk_memset(&node_samples, 0, sizeof(node_samples)); + tsk_memset(&l_state, 0, sizeof(l_state)); + tsk_memset(&r_state, 0, sizeof(r_state)); + result_tmp = tsk_malloc(result_dim * sizeof(*result_tmp)); + if (result_tmp == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + ret = iter_state_init(&l_state, self, state_dim); + if (ret != 0) { + goto out; + } + ret = iter_state_init(&r_state, self, state_dim); + if (ret != 0) { + goto out; + } + ret = sample_sets_to_bitset( + self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits); + if (ret != 0) { + goto out; + } + ret = get_node_samples(self, state_dim, &sample_sets_bits, &node_samples); + if (ret != 0) { + goto out; + } + iter_state_clear(&l_state, state_dim, num_nodes, &node_samples); + for (i = 0; i < self->num_trees; i++) { + tsk_memset(result_tmp, 0, result_dim * sizeof(*result_tmp)); + iter_state_clear(&r_state, state_dim, num_nodes, &node_samples); + ret = advance_collect_edges(&l_state, (tsk_id_t) i); + if (ret != 0) { + goto out; + } + ivl_l = 
l_state.tree.tree_pos.interval; + ret = compute_two_tree_branch_stat( + self, &r_state, &l_state, f, f_params, result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; + } + for (j = i; j < self->num_trees; j++) { + ivl_r = (interval_t){ breakpoints[j], breakpoints[j + 1] }; + bounds = (interval_t){ fmax(bins[0], ivl_r.left - ivl_l.right), + fmin(bins[num_bins - 1], ivl_r.right - ivl_l.left) }; + if (bounds.left > bins[num_bins - 1] || bounds.right < bins[0]) { + continue; + } + bin_l = tsk_search_sorted(bins + 1, num_bins - 1, bounds.left); + bin_r = tsk_search_sorted(bins + 1, num_bins - 1, bounds.right); + if (bin_l + 1 <= bin_r && bins[bin_l + 1] == bounds.left) { + bin_l += 1; + } + ret = advance_collect_edges(&r_state, (tsk_id_t) j); + if (ret != 0) { + goto out; + } + ret = compute_two_tree_branch_stat(self, &l_state, &r_state, f, f_params, + result_dim, state_dim, result_tmp); + if (ret != 0) { + goto out; + } + do { + result_row = GET_2D_ROW(result, result_dim, bin_l); + bincount_row = GET_2D_ROW(bincount, result_dim, bin_l); + for (k = 0; k < result_dim; k++) { + double val = integrate_stat_over_bin( + ivl_l, ivl_r, bins[bin_l], bins[bin_l + 1], result_tmp[k]); + if (tsk_isnan(val)) { + continue; + } + result_row[k] += val; + bincount_row[k] += 1; + } + bin_l++; + } while (bin_l <= bin_r); + } + } +out: + tsk_safe_free(result_tmp); + iter_state_free(&l_state); + iter_state_free(&r_state); + tsk_bitset_free(&node_samples); + tsk_bitset_free(&sample_sets_bits); + return ret; +} + +static int +tsk_treeseq_two_locus_site_decay_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + sample_count_stat_params_t *f_params, norm_func_t *norm_f, const double *bins, + tsk_size_t num_bins, const double *restrict sites_position, tsk_flags_t options, + double *result, tsk_size_t *bincount) +{ + int ret = 0; + 
tsk_bitset_t allele_samples, allele_sample_sets; + bool polarised = options & TSK_STAT_POLARISED; + tsk_id_t *sites; + tsk_size_t i, j, k, bin, n_sites, *bincount_row; + double dist, *result_row, *result_tmp = NULL; + const tsk_size_t num_samples = self->num_samples; + tsk_size_t *num_alleles = NULL, *site_offsets = NULL, *allele_counts = NULL; + tsk_size_t max_ss_size = 0, max_alleles = 0, n_alleles = 0; + two_locus_work_t work; + + tsk_memset(&allele_samples, 0, sizeof(allele_samples)); + n_sites = self->tables->sites.num_rows; + sites = tsk_malloc(n_sites * sizeof(*sites)); + if (sites == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + for (i = 0; i < n_sites; i++) { + sites[i] = (tsk_id_t) i; + } + // depends on n_sites + num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles)); + site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets)); + result_tmp = tsk_malloc(result_dim * sizeof(*result_tmp)); + if (num_alleles == NULL || site_offsets == NULL || result_tmp == NULL + || bincount == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + for (i = 0; i < n_sites; i++) { + site_offsets[i] = n_alleles * num_sample_sets; + n_alleles += self->site_mutations_length[sites[i]] + 1; + max_alleles = TSK_MAX(self->site_mutations_length[sites[i]], max_alleles); + } + max_alleles++; // add 1 for the ancestral allele + // depends on n_alleles + ret = tsk_bitset_init(&allele_samples, num_samples, n_alleles); + if (ret != 0) { + goto out; + } + for (i = 0; i < num_sample_sets; i++) { + max_ss_size = TSK_MAX(sample_set_sizes[i], max_ss_size); + } + // depend on n_alleles and max_ss_size + ret = tsk_bitset_init(&allele_sample_sets, max_ss_size, n_alleles * num_sample_sets); + if (ret != 0) { + goto out; + } + allele_counts = tsk_calloc(n_alleles * num_sample_sets, sizeof(*allele_counts)); + if (allele_counts == NULL) { + ret = tsk_trace_error(TSK_ERR_NO_MEMORY); + goto out; + } + // depends on max_ss_size and max_alleles + ret = 
two_locus_work_init(max_alleles, max_ss_size, result_dim, state_dim, &work); + if (ret != 0) { + goto out; + } + // we track the number of alleles to account for backmutations + ret = get_mutation_samples(self, sites, n_sites, num_alleles, &allele_samples); + if (ret != 0) { + goto out; + } + get_mutation_sample_sets(&allele_samples, num_sample_sets, sample_set_sizes, + sample_sets, self->sample_index_map, &allele_sample_sets, allele_counts); + for (i = 0; i < n_sites; i++) { + for (j = i + 1; j < n_sites; j++) { + dist = sites_position[j] - sites_position[i]; + if (dist >= bins[num_bins - 1]) { // right open + break; + } + if (dist < bins[0]) { // left closed + continue; + } + bin = tsk_search_sorted(bins, num_bins, dist); + bin = bins[bin] > dist ? bin - 1 : bin; // left closed intervals + result_row = GET_2D_ROW(result, result_dim, bin); + bincount_row = GET_2D_ROW(bincount, result_dim, bin); + if (num_alleles[i] == 2 && num_alleles[j] == 2) { + // both sites are biallelic + ret = compute_general_two_site_stat_result(&allele_sample_sets, + allele_counts, site_offsets[i], site_offsets[j], state_dim, + result_dim, f, f_params, &work, result_tmp); + } else { + // at least one site is multiallelic + ret = compute_general_normed_two_site_stat_result(&allele_sample_sets, + allele_counts, site_offsets[i], site_offsets[j], num_alleles[i], + num_alleles[j], state_dim, result_dim, f, f_params, norm_f, + polarised, &work, result_tmp); + } + if (ret != 0) { + goto out; + } + for (k = 0; k < result_dim; k++) { + if (tsk_isnan(result_tmp[k])) { + continue; + } + result_row[k] += result_tmp[k]; + bincount_row[k] += 1; + } + tsk_memset(result_tmp, 0, sizeof(*result_tmp) * result_dim); + } + } +out: + tsk_safe_free(sites); + tsk_safe_free(result_tmp); + tsk_safe_free(num_alleles); + tsk_safe_free(site_offsets); + tsk_safe_free(allele_counts); + two_locus_work_free(&work); + tsk_bitset_free(&allele_samples); + tsk_bitset_free(&allele_sample_sets); + return ret; +} + +// In 
two_locus_decay_stat, we specify positions. These can be site positions or tree +// breakpoints. We pass them in at this level so that we can convert their overall +// positions using a recombination map if we'd like. +static int +tsk_treeseq_two_locus_decay_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, + norm_func_t *norm_f, const double *bins, tsk_size_t num_bins, + const double *positions, tsk_flags_t options, double *result, tsk_size_t *bincount) +{ + int ret = 0; + bool stat_site = !!(options & TSK_STAT_SITE); + bool stat_branch = !!(options & TSK_STAT_BRANCH); + tsk_size_t state_dim = num_sample_sets; + sample_count_stat_params_t f_params = { .sample_sets = sample_sets, + .num_sample_sets = num_sample_sets, + .sample_set_sizes = sample_set_sizes, + .set_indexes = set_indexes }; + + // We do not support two-locus node stats + if (!!(options & TSK_STAT_NODE)) { + ret = tsk_trace_error(TSK_ERR_UNSUPPORTED_STAT_MODE); + goto out; + } + // If no mode is specified, we default to site mode + if (!(stat_site || stat_branch)) { + stat_site = true; + } + // It's an error to specify more than one mode + if (stat_site + stat_branch > 1) { + ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES); + goto out; + } + ret = tsk_treeseq_check_sample_sets( + self, num_sample_sets, sample_set_sizes, sample_sets); + if (ret != 0) { + goto out; + } + if (result_dim < 1) { + ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS); + goto out; + } + if (stat_site) { + ret = tsk_treeseq_two_locus_site_decay_stat(self, state_dim, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, bins, + num_bins, positions, options, result, bincount); + } else if (stat_branch) { + ret = tsk_treeseq_two_locus_branch_decay_stat(self, state_dim, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, bins, 
num_bins, + positions, result, bincount); + goto out; + } else { + ret = TSK_ERR_UNSUPPORTED_STAT_MODE; + goto out; + } +out: + return ret; +} + /*********************************** * Allele frequency spectrum ***********************************/ @@ -4303,13 +4616,25 @@ tsk_treeseq_D(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { - options |= TSK_STAT_POLARISED; // TODO: allow user to pick? + options |= TSK_STAT_POLARISED; return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } +int +tsk_treeseq_D_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + options |= TSK_STAT_POLARISED; // TODO: allow user to pick? 
+ return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D_summary_func, norm_total_weighted, bins, + num_bins, positions, options, result, bincount); +} + static int D2_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4348,6 +4673,17 @@ tsk_treeseq_D2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, result); } +int +tsk_treeseq_D2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D2_summary_func, norm_total_weighted, bins, + num_bins, positions, options, result, bincount); +} + static int r2_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4387,6 +4723,17 @@ tsk_treeseq_r2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } +int +tsk_treeseq_r2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, r2_summary_func, norm_hap_weighted, bins, + num_bins, positions, options, result, bincount); +} + static int D_prime_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4424,13 +4771,25 @@ tsk_treeseq_D_prime(const tsk_treeseq_t *self, tsk_size_t 
num_sample_sets, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { - options |= TSK_STAT_POLARISED; // TODO: allow user to pick? + options |= TSK_STAT_POLARISED; return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } +int +tsk_treeseq_D_prime_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + options |= TSK_STAT_POLARISED; + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D_prime_summary_func, norm_total_weighted, + bins, num_bins, positions, options, result, bincount); +} + static int r_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4465,13 +4824,25 @@ tsk_treeseq_r(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { - options |= TSK_STAT_POLARISED; // TODO: allow user to pick? 
+ options |= TSK_STAT_POLARISED; return tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, col_positions, options, result); } +int +tsk_treeseq_r_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + options |= TSK_STAT_POLARISED; + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, r_summary_func, norm_total_weighted, bins, + num_bins, positions, options, result, bincount); +} + static int Dz_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4511,6 +4882,17 @@ tsk_treeseq_Dz(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, result); } +int +tsk_treeseq_Dz_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, Dz_summary_func, norm_total_weighted, bins, + num_bins, positions, options, result, bincount); +} + static int pi2_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4547,6 +4929,17 @@ tsk_treeseq_pi2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, result); } +int +tsk_treeseq_pi2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double 
*positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, pi2_summary_func, norm_total_weighted, bins, + num_bins, positions, options, result, bincount); +} + static int D2_unbiased_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4584,6 +4977,17 @@ tsk_treeseq_D2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, col_positions, options, result); } +int +tsk_treeseq_D2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, D2_unbiased_summary_func, + norm_total_weighted, bins, num_bins, positions, options, result, bincount); +} + static int Dz_unbiased_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t TSK_UNUSED(result_dim), double *result, void *params) @@ -4622,6 +5026,17 @@ tsk_treeseq_Dz_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, col_positions, options, result); } +int +tsk_treeseq_Dz_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, Dz_unbiased_summary_func, + norm_total_weighted, bins, num_bins, positions, options, result, bincount); +} + static int pi2_unbiased_summary_func(tsk_size_t state_dim, const double *state, tsk_size_t 
TSK_UNUSED(result_dim), double *result, void *params) @@ -4660,6 +5075,17 @@ tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, col_positions, options, result); } +int +tsk_treeseq_pi2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + return tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_sample_sets, NULL, pi2_unbiased_summary_func, + norm_total_weighted, bins, num_bins, positions, options, result, bincount); +} + /*********************************** * Two way stats ***********************************/ @@ -5036,6 +5462,25 @@ tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, return ret; } +int +tsk_treeseq_D2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, D2_ij_summary_func, + norm_total_weighted, bins, num_bins, positions, options, result, bincount); +out: + return ret; +} + static int D2_ij_unbiased_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, tsk_size_t result_dim, double *result, void *params) @@ -5110,6 +5555,25 @@ tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets return ret; } +int +tsk_treeseq_D2_ij_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t 
num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, D2_ij_unbiased_summary_func, + norm_total_weighted, bins, num_bins, positions, options, result, bincount); +out: + return ret; +} + static int r2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, tsk_size_t result_dim, double *result, void *params) @@ -5170,6 +5634,25 @@ tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, return ret; } +int +tsk_treeseq_r2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_decay_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, r2_ij_summary_func, + norm_hap_weighted_ij, bins, num_bins, positions, options, result, bincount); +out: + return ret; +} + /*********************************** * Three way stats ***********************************/ diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 21495edbf7..7bf6ff3fc2 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -118,12 +118,14 @@ typedef struct { tsk_table_collection_t *tables; } tsk_treeseq_t; +typedef struct { + double left; + 
/*
 * Function type shared by the one-way LD decay entry points declared in this
 * header (tsk_treeseq_D_decay, tsk_treeseq_D2_decay, tsk_treeseq_r2_decay,
 * and the other *_decay functions without index tuples).
 *
 * `bins` points to `num_bins` values and `positions` to the coordinates used
 * to measure distances. `result` receives the statistic and `bincount` a
 * matching array of counts. Returns an int status.
 * NOTE(review): array shapes, and whether the outputs must be zeroed by the
 * caller, are not stated in this header -- confirm against the
 * implementation.
 */
typedef int two_locus_decay_stat_method(const tsk_treeseq_t *self,
    tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
    const tsk_id_t *sample_sets, const double *bins, tsk_size_t num_bins,
    const double *positions, tsk_flags_t options, double *result, tsk_size_t *bincount);
+ const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); +int tsk_treeseq_r_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); +int tsk_treeseq_Dz_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); +int tsk_treeseq_pi2_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); +int tsk_treeseq_D2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); +int tsk_treeseq_Dz_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); +int tsk_treeseq_pi2_unbiased_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); + +int tsk_treeseq_D2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t 
*sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); +int tsk_treeseq_D2_ij_unbiased_decay(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, const double *bins, tsk_size_t num_bins, + const double *positions, tsk_flags_t options, double *result, tsk_size_t *bincount); +int tsk_treeseq_r2_ij_decay(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, const double *bins, + tsk_size_t num_bins, const double *positions, tsk_flags_t options, double *result, + tsk_size_t *bincount); + /* Two way sample set stats */ int tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 78cb9f7c8e..b436cc00dc 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -990,6 +990,22 @@ bool_array_converter(PyObject *py_obj, PyArrayObject **array_out) return array_converter(NPY_BOOL, py_obj, array_out); } +static int +float64_array_converter(PyObject *py_obj, PyArrayObject **array_out) +{ + return array_converter(NPY_FLOAT64, py_obj, array_out); +} + +static int +optional_float64_array_converter(PyObject *py_obj, PyArrayObject **array_out) +{ + if (py_obj != Py_None) { + return array_converter(NPY_FLOAT64, py_obj, array_out); + } + *array_out = (PyArrayObject *) Py_None; + return 1; +} + /* Note: it doesn't seem to be possible to cast pointers to the actual * table functions to this type because the first argument must be a * void *, so the simplest option is to put in a small shim that @@ -8105,6 +8121,311 @@ TreeSequence_r2_ij_matrix(TreeSequence 
*self, PyObject *args, PyObject *kwds) return TreeSequence_k_way_ld_matrix(self, args, kwds, 2, tsk_treeseq_r2_ij); } +static int +parse_decay_positions(const tsk_treeseq_t *ts, PyArrayObject *positions_array, + tsk_flags_t options, const double **out) +{ + bool stat_site, stat_branch; + tsk_size_t positions_len; + stat_site = !!(options & TSK_STAT_SITE); + stat_branch = !!(options & TSK_STAT_BRANCH); + + if (!(stat_site || stat_branch)) { + return 0; // mode validation happens later, out as NULL will be fine + } + + if ((PyObject *) positions_array == Py_None) { + *out = stat_site ? ts->tables->sites.position : ts->breakpoints; + } else { + if (PyArray_NDIM(positions_array) != 1) { + PyErr_Format(PyExc_ValueError, "positions must be a 1d array."); + return 1; + } + positions_len = PyArray_DIM(positions_array, 0); + if (stat_site && (ts->tables->sites.num_rows != positions_len)) { + PyErr_Format(PyExc_ValueError, + "site positions must contain one element per site " + "(want a length %lu array).", + ts->tables->sites.num_rows); + return 1; + } else if (stat_branch && (ts->num_trees + 1 != positions_len)) { + PyErr_Format(PyExc_ValueError, + "site positions must contain one element per tree breakpoint" + "(want a length %lu array).", + ts->num_trees + 1); + return 1; + } + *out = PyArray_DATA(positions_array); + } + return 0; +} + +static PyObject * +TreeSequence_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, + two_locus_decay_stat_method *method) +{ + PyObject *ret = NULL; + static char *kwlist[] + = { "sample_set_sizes", "sample_sets", "bins", "positions", "mode", NULL }; + PyObject *sample_sets = NULL; + PyObject *sample_set_sizes = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + PyArrayObject *bins = NULL; + PyArrayObject *positions_array = NULL; + PyArrayObject *result_stat_matrix = NULL; + PyArrayObject *result_bincount_matrix = NULL; + npy_intp num_bins; + npy_intp result_dim[2]; + tsk_size_t 
num_sample_sets; + const double *positions = NULL; + char *mode = NULL; + tsk_flags_t options = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO&|O&s", kwlist, &sample_set_sizes, + &sample_sets, &float64_array_converter, &bins, + &optional_float64_array_converter, &positions_array, &mode)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + if (parse_decay_positions(self->tree_sequence, positions_array, options, &positions) + != 0) { + goto out; + } + num_bins = PyArray_DIM(bins, 0); + result_dim[0] = num_bins - 1; + result_dim[1] = num_sample_sets; + result_stat_matrix = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_FLOAT64, 0); + if (result_stat_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + result_bincount_matrix + = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_UINT64, 0); + if (result_bincount_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = method(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), + PyArray_DATA(bins), num_bins, positions, options, PyArray_DATA(result_stat_matrix), PyArray_DATA(result_bincount_matrix)); + Py_END_ALLOW_THREADS + // clang-format on + if (err != 0) + { + handle_library_error(err); + goto out; + } + // Return a tuple of matrixes: (stat, count) + ret = PyTuple_New(2); + PyTuple_SET_ITEM(ret, 0, (PyObject *) result_stat_matrix); + PyTuple_SET_ITEM(ret, 1, (PyObject *) result_bincount_matrix); + result_stat_matrix = NULL; + result_bincount_matrix = NULL; +out: + Py_XDECREF(bins); + Py_XDECREF(positions_array); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(result_stat_matrix); + 
Py_XDECREF(result_bincount_matrix); + return ret; +} + +static PyObject * +TreeSequence_D_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_D_decay); +} + +static PyObject * +TreeSequence_D2_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_D2_decay); +} + +static PyObject * +TreeSequence_r2_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_r2_decay); +} + +static PyObject * +TreeSequence_D_prime_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_D_prime_decay); +} + +static PyObject * +TreeSequence_r_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_r_decay); +} + +static PyObject * +TreeSequence_Dz_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_Dz_decay); +} + +static PyObject * +TreeSequence_pi2_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_pi2_decay); +} + +static PyObject * +TreeSequence_pi2_unbiased_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_pi2_unbiased_decay); +} + +static PyObject * +TreeSequence_D2_unbiased_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_D2_unbiased_decay); +} + +static PyObject * +TreeSequence_Dz_unbiased_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_ld_decay(self, args, kwds, tsk_treeseq_Dz_unbiased_decay); +} + +static PyObject * +TreeSequence_k_way_ld_decay(TreeSequence *self, PyObject *args, PyObject *kwds, + npy_intp tuple_size, 
k_way_two_locus_decay_stat_method *method) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "indexes", "bins", + "positions", "mode", NULL }; + PyObject *sample_sets = NULL; + PyObject *sample_set_sizes = NULL; + PyObject *indexes = NULL; + PyArrayObject *bins = NULL; + PyArrayObject *positions_array = NULL; + PyArrayObject *indexes_array = NULL; + PyArrayObject *result_stat_matrix = NULL; + PyArrayObject *result_bincount_matrix = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + npy_intp num_bins; + npy_intp *shape; + npy_intp result_dim[2]; + tsk_size_t num_sample_sets; + tsk_size_t num_set_index_tuples; + const double *positions = NULL; + char *mode = NULL; + tsk_flags_t options = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOO&|O&s", kwlist, &sample_set_sizes, + &sample_sets, &indexes, &float64_array_converter, &bins, + &optional_float64_array_converter, &positions_array, &mode)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + if (parse_decay_positions(self->tree_sequence, positions_array, options, &positions) + != 0) { + goto out; + } + indexes_array = (PyArrayObject *) PyArray_FROMANY( + indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY); + if (indexes_array == NULL) { + goto out; + } + shape = PyArray_DIMS(indexes_array); + if (shape[0] < 1 || shape[1] != tuple_size) { + PyErr_Format( + PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size); + goto out; + } + num_set_index_tuples = shape[0]; + num_bins = PyArray_DIM(bins, 0); + result_dim[0] = num_bins - 1; + result_dim[1] = num_set_index_tuples; + result_stat_matrix = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_FLOAT64, 0); + if 
(result_stat_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + result_bincount_matrix + = (PyArrayObject *) PyArray_ZEROS(2, result_dim, NPY_UINT64, 0); + if (result_bincount_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = method(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), num_set_index_tuples, + PyArray_DATA(indexes_array), PyArray_DATA(bins), num_bins, positions, options, PyArray_DATA(result_stat_matrix), PyArray_DATA(result_bincount_matrix)); + Py_END_ALLOW_THREADS + // clang-format on + if (err != 0) + { + handle_library_error(err); + goto out; + } + // Return a tuple of matrixes: (stat, count) + ret = PyTuple_New(2); + PyTuple_SET_ITEM(ret, 0, (PyObject *) result_stat_matrix); + PyTuple_SET_ITEM(ret, 1, (PyObject *) result_bincount_matrix); + result_stat_matrix = NULL; + result_bincount_matrix = NULL; +out: + Py_XDECREF(bins); + Py_XDECREF(indexes_array); + Py_XDECREF(positions_array); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(result_stat_matrix); + Py_XDECREF(result_bincount_matrix); + return ret; +} + +static PyObject * +TreeSequence_D2_ij_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_decay(self, args, kwds, 2, tsk_treeseq_D2_ij_decay); +} + +static PyObject * +TreeSequence_D2_ij_unbiased_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_decay( + self, args, kwds, 2, tsk_treeseq_D2_ij_unbiased_decay); +} + +static PyObject * +TreeSequence_r2_ij_decay(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_decay(self, args, kwds, 2, tsk_treeseq_r2_ij_decay); +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -8800,6 +9121,46 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_pi2_matrix, .ml_flags = 
METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the pi2 matrix." }, + { .ml_name = "D_decay", + .ml_meth = (PyCFunction) TreeSequence_D_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the D decay curve." }, + { .ml_name = "D2_decay", + .ml_meth = (PyCFunction) TreeSequence_D2_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the D2 decay curve." }, + { .ml_name = "r2_decay", + .ml_meth = (PyCFunction) TreeSequence_r2_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the r2 decay curve." }, + { .ml_name = "D_prime_decay", + .ml_meth = (PyCFunction) TreeSequence_D_prime_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the D_prime decay curve." }, + { .ml_name = "r_decay", + .ml_meth = (PyCFunction) TreeSequence_r_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the r decay curve." }, + { .ml_name = "Dz_decay", + .ml_meth = (PyCFunction) TreeSequence_Dz_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the Dz decay curve." }, + { .ml_name = "pi2_decay", + .ml_meth = (PyCFunction) TreeSequence_pi2_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the pi2 decay curve." }, + { .ml_name = "D2_unbiased_decay", + .ml_meth = (PyCFunction) TreeSequence_D2_unbiased_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased D2 decay curve." }, + { .ml_name = "Dz_unbiased_decay", + .ml_meth = (PyCFunction) TreeSequence_Dz_unbiased_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased Dz decay curve." }, + { .ml_name = "pi2_unbiased_decay", + .ml_meth = (PyCFunction) TreeSequence_pi2_unbiased_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the unbiased pi2 decay curve." 
}, { .ml_name = "D2_unbiased_matrix", .ml_meth = (PyCFunction) TreeSequence_D2_unbiased_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, @@ -8824,6 +9185,18 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_r2_ij_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the two-way r^2 matrix." }, + { .ml_name = "r2_ij_decay", + .ml_meth = (PyCFunction) TreeSequence_r2_ij_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way r2 decay curve." }, + { .ml_name = "D2_ij_decay", + .ml_meth = (PyCFunction) TreeSequence_D2_ij_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way D2 decay curve." }, + { .ml_name = "D2_ij_unbiased_decay", + .ml_meth = (PyCFunction) TreeSequence_D2_ij_unbiased_decay, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way unbiased D2 decay curve." }, { NULL } /* Sentinel */ }; @@ -10535,8 +10908,8 @@ static PyMethodDef Tree_methods[] = { { .ml_name = "map_mutations", .ml_meth = (PyCFunction) Tree_map_mutations, .ml_flags = METH_VARARGS | METH_KEYWORDS, - .ml_doc - = "Returns a parsimonious state reconstruction for the specified genotypes." }, + .ml_doc = "Returns a parsimonious state reconstruction for the specified " + "genotypes." 
}, { .ml_name = "equals", .ml_meth = (PyCFunction) Tree_equals, .ml_flags = METH_VARARGS, @@ -11966,9 +12339,8 @@ PyInit__tskit(void) return NULL; } Py_INCREF(&LsHmmType); - PyModule_AddObject(module, "LsHmm", (PyObject *) &LsHmmType); - - /* IdentitySegments type */ + PyModule_AddObject( + module, "LsHmm", (PyObject *) &LsHmmType); /* IdentitySegments type */ if (PyType_Ready(&IdentitySegmentsType) < 0) { return NULL; } diff --git a/python/tests/test_ld_decay.py b/python/tests/test_ld_decay.py new file mode 100644 index 0000000000..3f0d5747f2 --- /dev/null +++ b/python/tests/test_ld_decay.py @@ -0,0 +1,280 @@ +import contextlib +from itertools import combinations_with_replacement +from itertools import product + +import demes +import msprime +import numpy as np +import pytest + +from tskit import Interval + + +@contextlib.contextmanager +def suppress_overflow_div0_warning(): + with np.errstate(over="ignore", invalid="ignore", divide="ignore"): + yield + + +def expand_dims(arr): + """ + Expand the dimensions of the provided array (arrays). This helps to control + the output dimensions of the ld matrix (ie if there's 2 dimensions to + indexes or sample_sets, we'll get a 3D ld matrix back. This will not be + necessary in the C implementation because dimension dropping happens in the + python layer. 
+ """ + arr = np.asarray(arr) + if arr.ndim == 1: + return np.expand_dims(arr, axis=0) + try: + arr = [np.asarray(a) for a in arr] + except Exception as e: + raise ValueError("Must be a list of 1D array-like") from e + for a in arr: + if a.ndim != 1: + raise ValueError("Must be a list of 1D arrays") + return arr + + +def check_bins(bins, seq_len): + try: + bins = np.asarray(bins) + except Exception as e: + raise ValueError("Bins must be coercible to a 1D array") from e + if bins.ndim != 1: + raise ValueError("Bins must be a 1D array") + if not np.all(bins[:-1] <= bins[1:]): + raise ValueError("Bins must be sorted") + if bins[-1] > seq_len: + raise ValueError(f"Last bin is out of bounds, must be <= L: {bins[-1]}") + if len(bins) < 2: + raise ValueError(f"Must have at least 2 bins, got {len(bins)}") + if (bins < 0).any(): + raise ValueError("Bins must be greater than 0") + return bins + + +def construct_ld_matrix(ts, stat, sample_sets, indexes): + """ + Produce an ld matrix with the same error characteristics as the C version + Create an LD matrix by starting at the diagonal of each row. This ensures + that we accumulate error in the same way as we would in the C version. If + we produce an LD matrix starting from tree 0 at each row, we accumulate a + different (likely more) amount of error. 
+ """ + bp = ts.breakpoints(as_array=True)[:-1] + k = len(sample_sets) if indexes is None else len(indexes) + out = np.zeros((k, ts.num_trees, ts.num_trees)) + for i, b in enumerate(bp): + out[0:k, i, i:] = ts.ld_matrix( + sample_sets=sample_sets, + indexes=indexes, + mode="branch", + stat=stat, + positions=[[b], bp[i:]], + )[:, 0, :] # result is for one row + return out + + +def integrate_stat_over_bin(bin, i1, i2, stat): + bl, br = bin + # Integration support + l_support = i2.left - i1.right + r_support = i2.right - i1.left + # length of the middle region + r2_len = min(i1.right - i1.left, i2.right - i2.left) + # bounds of the middle region + r2_l_bound = min(i2.left - i1.left, i2.right - i1.right) + r2_r_bound = r_support - r2_len + + r1_l = min(max(bl, l_support), r2_l_bound) + r1_r = max(min(br, r2_l_bound), l_support) + r2_l = min(max(bl, r2_l_bound), r2_r_bound) + r2_r = max(min(br, r2_r_bound), r2_l_bound) + r3_l = min(max(bl, r2_r_bound), r_support) + r3_r = max(min(br, r_support), r2_r_bound) + return ( + stat + / (i1.span * i2.span) + * ( + -1 / 2 * (r1_l - r1_r) * (2 * i1.right - 2 * i2.left + r1_l + r1_r) + + (r2_r - r2_l) * r2_len + + 1 / 2 * (r3_l - r3_r) * (2 * i1.left - 2 * i2.right + r3_l + r3_r) + ) + ) + + +def isect(l1, r1, l2, r2): + "left closed, right open, left is ivl and right is query" + return max(l1, l2) < min(r1, r2) + + +def get_tree_pair_bounds(ivl_l, ivl_r, bins): + return Interval( + max(bins[0], ivl_r.left - ivl_l.right), + min(bins[-1], ivl_r.right - ivl_l.left), + ) + + +def ld_decay_branch(ts, bins, stat, sample_sets, indexes): + ld = construct_ld_matrix(ts, stat, sample_sets, indexes) + dims = (len(indexes or sample_sets), len(bins) - 1) + result = np.zeros(dims, dtype=float) + bincount = np.zeros(dims, dtype=int) + bp = ts.breakpoints(as_array=True) + bin_ivls = np.fromiter(zip(bins[:-1], bins[1:]), np.dtype((float, 2))) + for i, j in combinations_with_replacement(range(ts.num_trees), 2): # upper tri+diag + ivl_l = 
Interval(bp[i], bp[i + 1]) + ivl_r = Interval(bp[j], bp[j + 1]) + bounds = get_tree_pair_bounds(ivl_l, ivl_r, bins) + for k in range(dims[0]): + result[k] += np.apply_along_axis( + integrate_stat_over_bin, 1, bin_ivls, ivl_l, ivl_r, ld[k, i, j] + ) + bincount[k] += np.fromiter((isect(*bounds, *b) for b in bin_ivls), int) + if dims[0] == 1: # drop dims if first dim is length 1 + return result.reshape(dims[1:]), bincount.reshape(dims[1:]) + return result, bincount + + +def ld_decay_site(ts, bins, stat, sample_sets, indexes): + ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets, indexes=indexes) + dims = (len(indexes or sample_sets), len(bins) - 1) + result = np.zeros(dims, dtype=float) + bincount = np.zeros(dims, dtype=int) + site_pos = ts.sites_position + for i in range(ts.num_sites): + for j in range(i + 1, ts.num_sites): # upper tri (-diag) + dist = site_pos[j] - site_pos[i] + if dist > bins[-1]: + break + bin = np.searchsorted(bins[1:], dist, side="left") + for k in range(dims[0]): + s = ld[k, i, j] + if np.isnan(s): + continue + result[k, bin] += s + bincount[k, bin] += 1 + if dims[0] == 1: # drop dims if first dim is length 1 + return result.reshape(dims[1:]), bincount.reshape(dims[1:]) + return result, bincount + + +def ld_decay( + ts, + bins, + stat="r2", + sample_sets=None, + indexes=None, + mode="site", + return_counts=False, +): + bins = check_bins(bins, ts.sequence_length) + sample_sets = expand_dims(sample_sets or [ts.samples()]) + if indexes is not None: + indexes = expand_dims(indexes) + match mode: + case "site": + result, count = ld_decay_site(ts, bins, stat, sample_sets, indexes) + case "branch": + result, count = ld_decay_branch(ts, bins, stat, sample_sets, indexes) + case _: + raise ValueError(f"Unknown Stats Mode: {mode}") + + if return_counts: + return result, count + with suppress_overflow_div0_warning(): + return result / count + + +ONE_WAY_STATS = [ + "r", + "r2", + "D", + "D2", + "D_prime", + "pi2", + "Dz", + "D2_unbiased", + 
"Dz_unbiased", + "pi2_unbiased", +] + +TWO_WAY_STATS = ["r2", "D2", "D2_unbiased"] + +TS = msprime.sim_mutations( + msprime.sim_ancestry( + samples=100, + sequence_length=1e5, + recombination_rate=1e-8, + demography=msprime.Demography.from_demes( + demes.loads(""" + time_units: generations + demes: + - name: A + epochs: + - {start_size: 5000, end_time: 1000} + - {start_size: 1000, end_time: 400} + - {start_size: 5000, end_time: 0} + """) + ), + random_seed=23, + ), + rate=1e-7, + random_seed=23, +) + + +@pytest.mark.parametrize("stat,mode", product(ONE_WAY_STATS, ["site", "branch"])) +def test_ld_decay(stat, mode): + bins = np.logspace(0, np.log10(TS.sequence_length), num=35) + bins[0] = 0 + decay, counts = ld_decay(TS, bins, stat=stat, mode=mode, return_counts=True) + c = TS.ld_decay(bins, stat=stat, mode=mode) + with suppress_overflow_div0_warning(): + np.testing.assert_array_equal(decay / counts, c) + # Verify that the sum of all LD in our bins is equal to the sum of the LD + # matrix entries from which they originated. 
+ if mode == "branch": + tu = np.triu( + construct_ld_matrix( + TS, sample_sets=expand_dims(TS.samples()), indexes=None, stat=stat + ).squeeze() + ) + dmask = np.diag_indices_from(tu) + tu[dmask] = tu[dmask] / 2 # we take half the density on the diagonal + np.testing.assert_allclose(np.nansum(decay), np.nansum(tu)) + # all but r2 D2 Dz are within 1 ulp, likely due to numerical precision + np.testing.assert_array_almost_equal_nulp( + np.nansum(decay), np.nansum(tu), nulp=2 + ) + elif mode == "site": + tu = TS.ld_matrix(stat=stat)[np.triu_indices(TS.num_sites, k=1)] + np.testing.assert_allclose(decay.sum(), np.nansum(tu)) + + +@pytest.mark.parametrize("stat,mode", product(ONE_WAY_STATS, ["site", "branch"])) +def test_ld_decay_sample_sets(stat, mode): + bins = np.logspace(0, np.log10(TS.sequence_length), num=35) + bins[0] = 0 + sample_sets = [TS.samples(), TS.samples(), TS.samples()] + decay = TS.ld_decay(bins, sample_sets=sample_sets, stat=stat, mode=mode) + np.testing.assert_array_equal(decay[0], decay[1]) + np.testing.assert_array_equal(decay[1], decay[2]) + + +@pytest.mark.slow +@pytest.mark.parametrize("stat,mode", product(TWO_WAY_STATS, ["site", "branch"])) +def test_two_way_ld_decay(stat, mode): + bins = np.logspace(0, np.log10(TS.sequence_length), num=35) + np.testing.assert_array_almost_equal( + ld_decay(TS, bins, stat=stat, mode=mode), + TS.ld_decay(bins, stat=stat, mode=mode), + ) + ss = [TS.samples()] * 3 + indexes = [(0, 0), (0, 1), (1, 1)] + np.testing.assert_array_almost_equal( + ld_decay(TS, bins, stat=stat, mode=mode, sample_sets=ss, indexes=indexes), + TS.ld_decay(bins, stat=stat, mode=mode, sample_sets=ss, indexes=indexes), + ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 9904b60a98..09aa89cf2d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -23,6 +23,7 @@ """ Module responsible for managing trees and tree sequences. 
""" + from __future__ import annotations import base64 @@ -696,8 +697,7 @@ def __init__( options = 0 if sample_counts is not None: warnings.warn( - "The sample_counts option is not supported since 0.2.4 " - "and is ignored", + "The sample_counts option is not supported since 0.2.4 and is ignored", RuntimeWarning, stacklevel=4, ) @@ -6889,7 +6889,7 @@ def to_macs(self): bytes_genotypes[:] = lookup[variant.genotypes] genotypes = bytes_genotypes.tobytes().decode() output.append( - f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" f"{genotypes}" + f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t{genotypes}" ) return "\n".join(output) + "\n" @@ -8331,6 +8331,119 @@ def __k_way_sample_set_stat( stat = stat[()] return stat + def _try_drop_dimension(self, sample_sets): + # First try to convert to a 1D numpy array. If we succeed, then we strip off + # the corresponding dimension from the output. + drop_dimension = False + try: + sample_sets = np.array(sample_sets, dtype=np.uint64) + except ValueError: + pass + else: + # If we've successfully converted sample_sets to a 1D numpy array + # of integers then drop the dimension + if len(sample_sets.shape) == 1: + sample_sets = [sample_sets] + drop_dimension = True + return sample_sets, drop_dimension + + def __two_locus_sample_set_decay_stat( + self, + ll_method, + sample_sets, + bins, + return_counts, + ratemap=None, + mode=None, + ): + if sample_sets is None: + sample_sets = self.samples() + + sample_sets, drop_dimension = self._try_drop_dimension(sample_sets) + sample_set_sizes = np.array( + [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 + ) + if np.any(sample_set_sizes == 0): + raise ValueError("Sample sets must contain at least one element") + flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + positions = None + if ratemap is not None: + # rate in cM + if mode is None or mode == "site": + positions = ratemap.get_cumulative_mass(self.sites_position) * 100 + elif mode == 
"branch": + positions = ( + ratemap.get_cumulative_mass(self.breakpoints(as_array=True)) * 100 + ) + result, counts = ll_method(sample_set_sizes, flattened, bins, positions, mode) + if drop_dimension: + result = result.reshape(result.shape[0]) + counts = counts.reshape(counts.shape[0]) + else: + # Orient the data so that the first dimension is the sample set. + result = result.swapaxes(0, 1) + counts = counts.swapaxes(0, 1) + if return_counts: + return result, counts + with np.errstate(divide="ignore", invalid="ignore"): + return result / counts + + def __k_way_two_locus_sample_set_decay_stat( + self, + ll_method, + k, + sample_sets, + bins, + return_counts, + indexes=None, + ratemap=None, + mode=None, + ): + sample_set_sizes = np.array( + [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 + ) + if np.any(sample_set_sizes == 0): + raise ValueError("Sample sets must contain at least one element") + flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + drop_dimension = False + indexes = util.safe_np_int_cast(indexes, np.int32) + if len(indexes.shape) == 1: + indexes = indexes.reshape((1, indexes.shape[0])) + drop_dimension = True + if len(indexes.shape) != 2 or indexes.shape[1] != k: + raise ValueError( + "Indexes must be convertable to a 2D numpy array with {} " + "columns".format(k) + ) + positions = None + if ratemap is not None: + # rate in cM + if mode is None or mode == "site": + positions = ratemap.get_cumulative_mass(self.sites_position) * 100 + elif mode == "branch": + positions = ( + ratemap.get_cumulative_mass(self.breakpoints(as_array=True)) * 100 + ) + result, counts = ll_method( + sample_set_sizes, + flattened, + indexes, + bins, + positions, + mode, + ) + if drop_dimension: + result = result.reshape(result.shape[0]) + counts = counts.reshape(counts.shape[0]) + else: + # Orient the data so that the first dimension is the sample set. 
+ result = result.swapaxes(0, 1) + counts = counts.swapaxes(0, 1) + if return_counts: + return result, counts + with np.errstate(divide="ignore", invalid="ignore"): + return result / counts + def __k_way_weighted_stat( self, ll_method, @@ -9281,9 +9394,9 @@ def pca( if time_windows is None: tree_sequence_low, tree_sequence_high = None, self else: - assert ( - time_windows[0] < time_windows[1] - ), "The second argument should be larger." + assert time_windows[0] < time_windows[1], ( + "The second argument should be larger." + ) tree_sequence_low, tree_sequence_high = ( self.decapitate(time_windows[0]), self.decapitate(time_windows[1]), @@ -9351,9 +9464,9 @@ def _rand_pow_range_finder( """ Algorithm 9 in https://arxiv.org/pdf/2002.01387 """ - assert ( - num_vectors >= rank > 0 - ), "num_vectors should not be smaller than rank" + assert num_vectors >= rank > 0, ( + "num_vectors should not be smaller than rank" + ) for _ in range(depth): Q = np.linalg.qr(Q)[0] Q = operator(Q) @@ -10880,6 +10993,60 @@ def ld_matrix( stat_func, sample_sets, sites=sites, positions=positions, mode=mode ) + def ld_decay( + self, + bins, + sample_sets=None, + mode="site", + stat="r2", + indexes=None, + ratemap=None, + return_counts=False, + ): + one_way_stats = { + "D": self._ll_tree_sequence.D_decay, + "D2": self._ll_tree_sequence.D2_decay, + "r2": self._ll_tree_sequence.r2_decay, + "D_prime": self._ll_tree_sequence.D_prime_decay, + "r": self._ll_tree_sequence.r_decay, + "Dz": self._ll_tree_sequence.Dz_decay, + "pi2": self._ll_tree_sequence.pi2_decay, + "Dz_unbiased": self._ll_tree_sequence.Dz_unbiased_decay, + "D2_unbiased": self._ll_tree_sequence.D2_unbiased_decay, + "pi2_unbiased": self._ll_tree_sequence.pi2_unbiased_decay, + } + two_way_stats = { + "D2": self._ll_tree_sequence.D2_ij_decay, + "D2_unbiased": self._ll_tree_sequence.D2_ij_unbiased_decay, + "r2": self._ll_tree_sequence.r2_ij_decay, + } + stats = one_way_stats if indexes is None else two_way_stats + try: + stat_func = 
stats[stat] + except KeyError: + raise ValueError( + f"Unknown two-locus statistic '{stat}', we support: {list(stats.keys())}" + ) + if indexes is not None: + return self.__k_way_two_locus_sample_set_decay_stat( + stat_func, + 2, + sample_sets, + bins, + return_counts, + indexes=indexes, + ratemap=ratemap, + mode=mode, + ) + return self.__two_locus_sample_set_decay_stat( + stat_func, + sample_sets, + bins, + return_counts, + ratemap=ratemap, + mode=mode, + ) + def sample_nodes_by_ploidy(self, ploidy): """ Returns an 2D array of node IDs, where each row has length `ploidy`.