From 988d1ae340fddec13c0844ebd533c70539f1eb84 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 31 Mar 2026 10:35:15 +0100 Subject: [PATCH 1/2] Move add_vestigial_root to new arg_ops module Create tsinfer/arg_ops.py for ARG topology operations and move add_vestigial_root from matching.py. Create tests/test_arg_ops.py with the corresponding tests. Update all call sites in matching.py, test_python_c.py, and test_lshmm.py. --- tests/test_arg_ops.py | 19 +++++++++++++ tests/test_lshmm.py | 4 +-- tests/test_matching.py | 21 --------------- tests/test_python_c.py | 12 ++++----- tsinfer/arg_ops.py | 61 ++++++++++++++++++++++++++++++++++++++++++ tsinfer/matching.py | 40 ++------------------------- 6 files changed, 90 insertions(+), 67 deletions(-) create mode 100644 tests/test_arg_ops.py create mode 100644 tsinfer/arg_ops.py diff --git a/tests/test_arg_ops.py b/tests/test_arg_ops.py new file mode 100644 index 00000000..904ff407 --- /dev/null +++ b/tests/test_arg_ops.py @@ -0,0 +1,19 @@ +import pytest +import tskit + +from tsinfer import arg_ops + + +class TestAddVestigialRoot: + def test_non_discrete_genome(self): + tables = tskit.TableCollection(sequence_length=1.5) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + ts = tables.tree_sequence() + with pytest.raises(ValueError, match="discrete genome"): + arg_ops.add_vestigial_root(ts) + + def test_empty_tree_sequence(self): + tables = tskit.TableCollection(sequence_length=1) + ts = tables.tree_sequence() + with pytest.raises(ValueError, match="Emtpy trees"): + arg_ops.add_vestigial_root(ts) diff --git a/tests/test_lshmm.py b/tests/test_lshmm.py index 02bbb146..24b09147 100644 --- a/tests/test_lshmm.py +++ b/tests/test_lshmm.py @@ -14,7 +14,7 @@ import tskit import _tsinfer -from tsinfer import matching +from tsinfer import arg_ops, matching @dataclasses.dataclass @@ -78,7 +78,7 @@ class MatcherIndexes: def __init__(self, in_tables, *, vestigial_root=True, num_alleles=None): ts = in_tables.tree_sequence() if vestigial_root: - ts = matching.add_vestigial_root(ts) + ts = arg_ops.add_vestigial_root(ts) tables = ts.dump_tables() self.sequence_length = tables.sequence_length diff --git a/tests/test_matching.py b/tests/test_matching.py index 1028604a..041993a0 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -23,7 +23,6 @@ from __future__ import annotations import numpy as np -import pytest import tskit from tsinfer import grouping, matching, vcz @@ -761,26 +760,6 @@ def test_metadata_survives_multiple_cycles(self): assert ts2.metadata["sequence_intervals"] == [[10, 51]] -# --------------------------------------------------------------------------- -# TestAddVestigialRoot -# --------------------------------------------------------------------------- - - -class TestAddVestigialRoot: - def test_non_discrete_genome(self): - tables = tskit.TableCollection(sequence_length=1.5) - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - ts = tables.tree_sequence() - with pytest.raises(ValueError, match="discrete genome"): - matching.add_vestigial_root(ts) - - def test_empty_tree_sequence(self): - tables = tskit.TableCollection(sequence_length=1) - ts = tables.tree_sequence() - with pytest.raises(ValueError, match="Emtpy trees"): - matching.add_vestigial_root(ts) - - # --------------------------------------------------------------------------- # TestAncestorMatcherWrapper # --------------------------------------------------------------------------- diff --git a/tests/test_python_c.py b/tests/test_python_c.py index b4237334..d118ddc5 100644 --- a/tests/test_python_c.py +++ b/tests/test_python_c.py @@ -28,7 +28,7 @@ import tskit import _tsinfer -from tsinfer import matching +from tsinfer import arg_ops IS_WINDOWS = sys.platform == "win32" @@ -155,7 +155,7 @@ def make_matcher_indexes_and_matcher(num_samples=4): tables.sites.add_row(position=1, ancestral_state="A") tables.mutations.add_row(site=0, node=1, derived_state="T") ts = tables.tree_sequence() - ts = matching.add_vestigial_root(ts) + ts = arg_ops.add_vestigial_root(ts) ll_tables = _tsinfer.LightweightTableCollection(ts.sequence_length) ll_tables.fromdict(ts.dump_tables().asdict()) mi = _tsinfer.MatcherIndexes(ll_tables) @@ -267,7 +267,7 @@ def test_find_path_match_impossible(self): tables.sites.add_row(position=1, ancestral_state="A") # No mutations: all nodes carry allele 0 ts = tables.tree_sequence() - ts = matching.add_vestigial_root(ts) + ts = arg_ops.add_vestigial_root(ts) ll_tables = _tsinfer.LightweightTableCollection(ts.sequence_length) ll_tables.fromdict(ts.dump_tables().asdict()) mi = _tsinfer.MatcherIndexes(ll_tables) @@ -299,7 +299,7 @@ def test_get_traceback_bad_site(self): class TestMatcherIndexes: def test_single_tree(self): ts = tskit.Tree.generate_balanced(4).tree_sequence - ts = matching.add_vestigial_root(ts) + ts = arg_ops.add_vestigial_root(ts) tables = ts.dump_tables() ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) ll_tables.fromdict(tables.asdict()) @@ -323,7 +323,7 @@ def test_num_alleles(self): def test_print_state(self, tmpdir): ts = tskit.Tree.generate_balanced(4).tree_sequence - ts = matching.add_vestigial_root(ts) + ts = arg_ops.add_vestigial_root(ts) tables = ts.dump_tables() ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) ll_tables.fromdict(tables.asdict()) @@ -341,7 +341,7 @@ def test_print_state(self, tmpdir): def test_print_state_bad_file(self): ts = tskit.Tree.generate_balanced(4).tree_sequence - ts = matching.add_vestigial_root(ts) + ts = arg_ops.add_vestigial_root(ts) tables = ts.dump_tables() ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) ll_tables.fromdict(tables.asdict()) diff --git a/tsinfer/arg_ops.py b/tsinfer/arg_ops.py new file mode 100644 index 00000000..05e337bc --- /dev/null +++ b/tsinfer/arg_ops.py @@ -0,0 +1,61 @@ +# +# Copyright (C) 2018-2026 University of Oxford +# +# This file is part of tsinfer. +# +# tsinfer is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# tsinfer is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with tsinfer. If not, see . +# +""" +Operations on ARG (Ancestral Recombination Graph) topology. +""" + +import logging + +logger = logging.getLogger(__name__) + + +def add_vestigial_root(ts): + """ + Adds the nodes and edges required by tsinfer to the specified tree sequence + and returns it. + """ + if not ts.discrete_genome: + raise ValueError("Only discrete genome coords supported") + if ts.num_nodes == 0: + raise ValueError("Emtpy trees not supported") + + base_tables = ts.dump_tables() + tables = base_tables.copy() + tables.nodes.clear() + t = max(ts.nodes_time) + tables.nodes.add_row(time=t + 1) + num_additonal_nodes = 1 + tables.mutations.node += num_additonal_nodes + tables.edges.child += num_additonal_nodes + tables.edges.parent += num_additonal_nodes + for node in base_tables.nodes: + tables.nodes.append(node) + if ts.num_edges > 0: + for tree in ts.trees(): + # if tree.num_roots > 1: + # print(ts.draw_text()) + root = tree.root + num_additonal_nodes + tables.edges.add_row( + tree.interval.left, tree.interval.right, parent=0, child=root + ) + tables.edges.squash() + # FIXME probably don't need to sort here most of the time, or at least + # we can just sort almost the end of the table. + tables.sort() + return tables.tree_sequence() diff --git a/tsinfer/matching.py b/tsinfer/matching.py index 7845e750..9c002216 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -33,7 +33,7 @@ import _tsinfer -from . import grouping, vcz +from . import arg_ops, grouping, vcz logger = logging.getLogger(__name__) @@ -380,48 +380,12 @@ def extend_ts( return result_ts -def add_vestigial_root(ts): - """ - Adds the nodes and edges required by tsinfer to the specified tree sequence - and returns it. - """ - if not ts.discrete_genome: - raise ValueError("Only discrete genome coords supported") - if ts.num_nodes == 0: - raise ValueError("Emtpy trees not supported") - - base_tables = ts.dump_tables() - tables = base_tables.copy() - tables.nodes.clear() - t = max(ts.nodes_time) - tables.nodes.add_row(time=t + 1) - num_additonal_nodes = 1 - tables.mutations.node += num_additonal_nodes - tables.edges.child += num_additonal_nodes - tables.edges.parent += num_additonal_nodes - for node in base_tables.nodes: - tables.nodes.append(node) - if ts.num_edges > 0: - for tree in ts.trees(): - # if tree.num_roots > 1: - # print(ts.draw_text()) - root = tree.root + num_additonal_nodes - tables.edges.add_row( - tree.interval.left, tree.interval.right, parent=0, child=root - ) - tables.edges.squash() - # FIXME probably don't need to sort here most of the time, or at least we - # can just sort almost the end of the table. - tables.sort() - return tables.tree_sequence() - - class MatcherIndexes(_tsinfer.MatcherIndexes): """Wrapper around the C MatcherIndexes, built from a tree sequence.""" def __init__(self, ts, *, vestigial_root=True, num_alleles=None): if vestigial_root: - ts = add_vestigial_root(ts) + ts = arg_ops.add_vestigial_root(ts) tables = ts.dump_tables() ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) ll_tables.fromdict(tables.asdict()) From 1452bb0d638bf3b054ad139f98f1d4d684c1f09d Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 31 Mar 2026 11:45:46 +0100 Subject: [PATCH 2/2] Add is_reference option to AncestralState config When is_reference is true, the REF allele (variant_allele[:, 0]) from the VCZ store is used as the ancestral allele instead of reading a named field. Useful for simulated data where REF equals the ancestral allele. Exactly one of field or is_reference must be set; both default to None. The genotype roundtrip tests now use is_reference=True throughout, since all test inputs originate from tree sequences where REF is ancestral. --- docs/config.md | 3 ++- docs/quickstart.md | 5 +++-- example_config.toml | 5 +++++ tests/test_genotype_roundtrip.py | 15 +++++++------ tests/test_pipeline.py | 29 ++++++++++++++++++++++++ tsinfer/config.py | 38 +++++++++++++++++++++++--------- tsinfer/vcz.py | 5 ++++- 7 files changed, 79 insertions(+), 21 deletions(-) diff --git a/docs/config.md b/docs/config.md index a272798c..265a0577 100644 --- a/docs/config.md +++ b/docs/config.md @@ -33,7 +33,8 @@ Specifies where to read the ancestral allele for each variant position. | Field | Type | Default | Description | |-------|------|---------|-------------| | `path` | string | (required) | Path to VCZ containing ancestral alleles | -| `field` | string | (required) | Array name in the store (e.g. `"variant_AA"`) | +| `field` | string | — | Array name in the store (e.g. `"variant_AA"`). Required unless `is_reference` is set. | +| `is_reference` | bool | `false` | Use the REF allele (`variant_allele[:, 0]`) as the ancestral state. Useful for simulations. `field` must not be set when this is `true`. | ## `[[ancestors]]` diff --git a/docs/quickstart.md b/docs/quickstart.md index 4cddd38e..13eb6df2 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -28,8 +28,9 @@ vcf2zarr convert mydata.vcf.gz mydata.vcz Each site used for inference requires a known **ancestral allele**. If your VCF has an `AA` INFO field, `vcf2zarr` stores it as `variant_AA` in the `.vcz` store and you can reference it directly in the config. Alternatively, ancestral -alleles can come from a separate VCZ store. See the -{ref}`config reference ` for details. +alleles can come from a separate VCZ store. For simulated data where the REF +allele is the ancestral allele, set `is_reference = true` instead of specifying +a field. See the {ref}`config reference ` for details. ## Writing the config diff --git a/example_config.toml b/example_config.toml index d197cedd..04026a95 100644 --- a/example_config.toml +++ b/example_config.toml @@ -59,6 +59,11 @@ path = "data/1kgp_chr20.vcz" path = "data/homo_sapiens-chr20.vcz" field = "variant_AA" +# For simulated data where REF is the ancestral allele, use: +# [ancestral_state] +# path = "data/simulated.vcz" +# is_reference = true + # ============================================================================ # Ancestors diff --git a/tests/test_genotype_roundtrip.py b/tests/test_genotype_roundtrip.py index 5642666a..8ae75e2a 100644 --- a/tests/test_genotype_roundtrip.py +++ b/tests/test_genotype_roundtrip.py @@ -39,12 +39,13 @@ # --------------------------------------------------------------------------- -def _anc_state(store): - return config.AncestralState(path=store, field="variant_ancestral_allele") - - def _run_pipeline(sample_store): - """Build config, run full pipeline, return output tree sequence.""" + """Build config, run full pipeline, return output tree sequence. + + Uses ``is_reference=True`` so the REF allele (variant_allele[:, 0]) + is treated as the ancestral allele — the natural choice when input + data originates from a tree sequence. + """ src = config.Source(path=sample_store, name="test") anc_src = config.Source(path=None, name="ancestors", sample_time="sample_time") cfg = config.Config( @@ -62,7 +63,7 @@ def _run_pipeline(sample_store): output="output.trees", ), post_process=config.PostProcessConfig(), - ancestral_state=_anc_state(sample_store), + ancestral_state=config.AncestralState(path=sample_store, is_reference=True), ) return pipeline.run(cfg) @@ -103,7 +104,7 @@ def _run_pipeline_with_augment(sample_store, augment_store=None, ann_store=None) ), post_process=config.PostProcessConfig(), augment_sites=config.AugmentSitesConfig(sources=["augment"]), - ancestral_state=_anc_state(ann_store), + ancestral_state=config.AncestralState(path=ann_store, is_reference=True), ) return pipeline.run(cfg) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 33c29de7..854774e0 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -355,6 +355,35 @@ def test_hand_constructed(self): assert out_ts.num_sites == 2 +# --------------------------------------------------------------------------- +# TestIsReferenceConfig +# --------------------------------------------------------------------------- + + +class TestIsReferenceConfig: + def test_is_reference_with_field_raises(self): + """is_reference=True with field set raises ValueError.""" + with pytest.raises(ValueError, match="field must not be set"): + config.AncestralState(path="x.vcz", field="variant_AA", is_reference=True) + + def test_neither_field_nor_is_reference_raises(self): + """Neither field nor is_reference raises ValueError.""" + with pytest.raises(ValueError, match="requires either"): + config.AncestralState(path="x.vcz") + + def test_valid_field(self): + """field without is_reference is valid.""" + a = config.AncestralState(path="x.vcz", field="variant_AA") + assert a.field == "variant_AA" + assert a.is_reference is None + + def test_valid_is_reference(self): + """is_reference=True without field is valid.""" + a = config.AncestralState(path="x.vcz", is_reference=True) + assert a.is_reference is True + assert a.field is None + + # --------------------------------------------------------------------------- # TestNodeMetadata # --------------------------------------------------------------------------- diff --git a/tsinfer/config.py b/tsinfer/config.py index f6b6b990..d7b48898 100644 --- a/tsinfer/config.py +++ b/tsinfer/config.py @@ -113,10 +113,28 @@ def __post_init__(self): @dataclasses.dataclass class AncestralState: - """Specifies where to read the ancestral allele for each variant position.""" + """Specifies where to read the ancestral allele for each variant position. + + Exactly one of *field* or *is_reference* must be set. + + If *is_reference* is ``True``, the REF allele (``variant_allele[:, 0]``) + from the store at *path* is used as the ancestral allele. Otherwise + *field* names the array to read. + """ path: str | pathlib.Path - field: str + field: str | None = None + is_reference: bool | None = None + + def __post_init__(self): + if self.is_reference is None and self.field is None: + raise ValueError( + "[ancestral_state] requires either 'field' or 'is_reference = true'" + ) + if self.is_reference is True and self.field is not None: + raise ValueError( + "[ancestral_state] field must not be set when is_reference is true" + ) @dataclasses.dataclass @@ -356,7 +374,7 @@ def from_toml(cls, path: str | pathlib.Path) -> Config: "sample_time", } -_KNOWN_ANCESTRAL_STATE_KEYS = {"path", "field"} +_KNOWN_ANCESTRAL_STATE_KEYS = {"path", "field", "is_reference"} _KNOWN_ANCESTORS_KEYS = { "name", @@ -429,13 +447,13 @@ def _parse_ancestral_state(raw: dict) -> AncestralState: if entry is None: raise ValueError("Config must contain an [ancestral_state] section") _check_unknown_keys("ancestral_state", entry, _KNOWN_ANCESTRAL_STATE_KEYS) - try: - return AncestralState( - path=_resolve_path(entry["path"]), - field=entry["field"], - ) - except KeyError as e: - raise ValueError(f"[ancestral_state] missing required key: {e}") from e + if "path" not in entry: + raise ValueError("[ancestral_state] missing required key: 'path'") + return AncestralState( + path=_resolve_path(entry["path"]), + field=entry.get("field"), + is_reference=entry.get("is_reference"), + ) def _parse_one_ancestor(entry: dict) -> AncestorsConfig: diff --git a/tsinfer/vcz.py b/tsinfer/vcz.py index 1d2b8e1d..350304ee 100644 --- a/tsinfer/vcz.py +++ b/tsinfer/vcz.py @@ -2368,7 +2368,10 @@ def __init__( # --- Ancestral state --- ann_store = open_store(ancestral_state.path) ann_positions = np.asarray(ann_store["variant_position"][:], dtype=np.int32) - ann_values = np.asarray(ann_store[ancestral_state.field][:]) + if ancestral_state.is_reference: + ann_values = np.asarray(ann_store["variant_allele"][:, 0]) + else: + ann_values = np.asarray(ann_store[ancestral_state.field][:]) # --- Unified site set (all numpy) --- valid_per_source = []