Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ Specifies where to read the ancestral allele for each variant position.
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `path` | string | (required) | Path to VCZ containing ancestral alleles |
| `field` | string | (required) | Array name in the store (e.g. `"variant_AA"`) |
| `field` | string | — | Array name in the store (e.g. `"variant_AA"`). Required unless `is_reference` is set. |
| `is_reference` | bool | `false` | Use the REF allele (`variant_allele[:, 0]`) as the ancestral state. Useful for simulations. `field` must not be set when this is `true`. |


## `[[ancestors]]`
Expand Down
5 changes: 3 additions & 2 deletions docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ vcf2zarr convert mydata.vcf.gz mydata.vcz
Each site used for inference requires a known **ancestral allele**. If your VCF
has an `AA` INFO field, `vcf2zarr` stores it as `variant_AA` in the `.vcz`
store and you can reference it directly in the config. Alternatively, ancestral
alleles can come from a separate VCZ store. See the
{ref}`config reference <sec_config_reference>` for details.
alleles can come from a separate VCZ store. For simulated data where the REF
allele is the ancestral allele, set `is_reference = true` instead of specifying
a field. See the {ref}`config reference <sec_config_reference>` for details.


## Writing the config
Expand Down
5 changes: 5 additions & 0 deletions example_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ path = "data/1kgp_chr20.vcz"
path = "data/homo_sapiens-chr20.vcz"
field = "variant_AA"

# For simulated data where REF is the ancestral allele, use:
# [ancestral_state]
# path = "data/simulated.vcz"
# is_reference = true


# ============================================================================
# Ancestors
Expand Down
19 changes: 19 additions & 0 deletions tests/test_arg_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest
import tskit

from tsinfer import arg_ops


class TestAddVestigialRoot:
    """Input-validation checks for ``arg_ops.add_vestigial_root``."""

    def test_non_discrete_genome(self):
        """A sequence length of 1.5 triggers the "discrete genome" error."""
        fractional = tskit.TableCollection(sequence_length=1.5)
        fractional.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        fractional_ts = fractional.tree_sequence()
        with pytest.raises(ValueError, match="discrete genome"):
            arg_ops.add_vestigial_root(fractional_ts)

    def test_empty_tree_sequence(self):
        """A tree sequence without any nodes is rejected."""
        # The match string mirrors the (misspelled) runtime error message.
        node_free_ts = tskit.TableCollection(sequence_length=1).tree_sequence()
        with pytest.raises(ValueError, match="Emtpy trees"):
            arg_ops.add_vestigial_root(node_free_ts)
15 changes: 8 additions & 7 deletions tests/test_genotype_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@
# ---------------------------------------------------------------------------


def _anc_state(store):
return config.AncestralState(path=store, field="variant_ancestral_allele")


def _run_pipeline(sample_store):
"""Build config, run full pipeline, return output tree sequence."""
"""Build config, run full pipeline, return output tree sequence.

Uses ``is_reference=True`` so the REF allele (variant_allele[:, 0])
is treated as the ancestral allele — the natural choice when input
data originates from a tree sequence.
"""
src = config.Source(path=sample_store, name="test")
anc_src = config.Source(path=None, name="ancestors", sample_time="sample_time")
cfg = config.Config(
Expand All @@ -62,7 +63,7 @@ def _run_pipeline(sample_store):
output="output.trees",
),
post_process=config.PostProcessConfig(),
ancestral_state=_anc_state(sample_store),
ancestral_state=config.AncestralState(path=sample_store, is_reference=True),
)
return pipeline.run(cfg)

Expand Down Expand Up @@ -103,7 +104,7 @@ def _run_pipeline_with_augment(sample_store, augment_store=None, ann_store=None)
),
post_process=config.PostProcessConfig(),
augment_sites=config.AugmentSitesConfig(sources=["augment"]),
ancestral_state=_anc_state(ann_store),
ancestral_state=config.AncestralState(path=ann_store, is_reference=True),
)
return pipeline.run(cfg)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_lshmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import tskit

import _tsinfer
from tsinfer import matching
from tsinfer import arg_ops, matching


@dataclasses.dataclass
Expand Down Expand Up @@ -78,7 +78,7 @@ class MatcherIndexes:
def __init__(self, in_tables, *, vestigial_root=True, num_alleles=None):
ts = in_tables.tree_sequence()
if vestigial_root:
ts = matching.add_vestigial_root(ts)
ts = arg_ops.add_vestigial_root(ts)
tables = ts.dump_tables()

self.sequence_length = tables.sequence_length
Expand Down
21 changes: 0 additions & 21 deletions tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from __future__ import annotations

import numpy as np
import pytest
import tskit

from tsinfer import grouping, matching, vcz
Expand Down Expand Up @@ -761,26 +760,6 @@ def test_metadata_survives_multiple_cycles(self):
assert ts2.metadata["sequence_intervals"] == [[10, 51]]


# ---------------------------------------------------------------------------
# TestAddVestigialRoot
# ---------------------------------------------------------------------------


class TestAddVestigialRoot:
def test_non_discrete_genome(self):
tables = tskit.TableCollection(sequence_length=1.5)
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
ts = tables.tree_sequence()
with pytest.raises(ValueError, match="discrete genome"):
matching.add_vestigial_root(ts)

def test_empty_tree_sequence(self):
tables = tskit.TableCollection(sequence_length=1)
ts = tables.tree_sequence()
with pytest.raises(ValueError, match="Emtpy trees"):
matching.add_vestigial_root(ts)


# ---------------------------------------------------------------------------
# TestAncestorMatcherWrapper
# ---------------------------------------------------------------------------
Expand Down
29 changes: 29 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,35 @@ def test_hand_constructed(self):
assert out_ts.num_sites == 2


# ---------------------------------------------------------------------------
# TestIsReferenceConfig
# ---------------------------------------------------------------------------


class TestIsReferenceConfig:
    """Checks how ``config.AncestralState`` validates ``field`` / ``is_reference``."""

    def test_is_reference_with_field_raises(self):
        """Passing both ``field`` and ``is_reference=True`` raises ValueError."""
        conflicting = {"path": "x.vcz", "field": "variant_AA", "is_reference": True}
        with pytest.raises(ValueError, match="field must not be set"):
            config.AncestralState(**conflicting)

    def test_neither_field_nor_is_reference_raises(self):
        """Passing only ``path`` raises ValueError."""
        path_only = {"path": "x.vcz"}
        with pytest.raises(ValueError, match="requires either"):
            config.AncestralState(**path_only)

    def test_valid_field(self):
        """``field`` alone is accepted and ``is_reference`` stays unset."""
        state = config.AncestralState(path="x.vcz", field="variant_AA")
        assert state.is_reference is None
        assert state.field == "variant_AA"

    def test_valid_is_reference(self):
        """``is_reference=True`` alone is accepted and ``field`` stays unset."""
        state = config.AncestralState(path="x.vcz", is_reference=True)
        assert state.field is None
        assert state.is_reference is True


# ---------------------------------------------------------------------------
# TestNodeMetadata
# ---------------------------------------------------------------------------
Expand Down
12 changes: 6 additions & 6 deletions tests/test_python_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import tskit

import _tsinfer
from tsinfer import matching
from tsinfer import arg_ops

IS_WINDOWS = sys.platform == "win32"

Expand Down Expand Up @@ -155,7 +155,7 @@ def make_matcher_indexes_and_matcher(num_samples=4):
tables.sites.add_row(position=1, ancestral_state="A")
tables.mutations.add_row(site=0, node=1, derived_state="T")
ts = tables.tree_sequence()
ts = matching.add_vestigial_root(ts)
ts = arg_ops.add_vestigial_root(ts)
ll_tables = _tsinfer.LightweightTableCollection(ts.sequence_length)
ll_tables.fromdict(ts.dump_tables().asdict())
mi = _tsinfer.MatcherIndexes(ll_tables)
Expand Down Expand Up @@ -267,7 +267,7 @@ def test_find_path_match_impossible(self):
tables.sites.add_row(position=1, ancestral_state="A")
# No mutations: all nodes carry allele 0
ts = tables.tree_sequence()
ts = matching.add_vestigial_root(ts)
ts = arg_ops.add_vestigial_root(ts)
ll_tables = _tsinfer.LightweightTableCollection(ts.sequence_length)
ll_tables.fromdict(ts.dump_tables().asdict())
mi = _tsinfer.MatcherIndexes(ll_tables)
Expand Down Expand Up @@ -299,7 +299,7 @@ def test_get_traceback_bad_site(self):
class TestMatcherIndexes:
def test_single_tree(self):
ts = tskit.Tree.generate_balanced(4).tree_sequence
ts = matching.add_vestigial_root(ts)
ts = arg_ops.add_vestigial_root(ts)
tables = ts.dump_tables()
ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length)
ll_tables.fromdict(tables.asdict())
Expand All @@ -323,7 +323,7 @@ def test_num_alleles(self):

def test_print_state(self, tmpdir):
ts = tskit.Tree.generate_balanced(4).tree_sequence
ts = matching.add_vestigial_root(ts)
ts = arg_ops.add_vestigial_root(ts)
tables = ts.dump_tables()
ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length)
ll_tables.fromdict(tables.asdict())
Expand All @@ -341,7 +341,7 @@ def test_print_state(self, tmpdir):

def test_print_state_bad_file(self):
ts = tskit.Tree.generate_balanced(4).tree_sequence
ts = matching.add_vestigial_root(ts)
ts = arg_ops.add_vestigial_root(ts)
tables = ts.dump_tables()
ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length)
ll_tables.fromdict(tables.asdict())
Expand Down
61 changes: 61 additions & 0 deletions tsinfer/arg_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#
# Copyright (C) 2018-2026 University of Oxford
#
# This file is part of tsinfer.
#
# tsinfer is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tsinfer is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tsinfer. If not, see <http://www.gnu.org/licenses/>.
#
"""
Operations on ARG (Ancestral Recombination Graph) topology.
"""

import logging

logger = logging.getLogger(__name__)


def add_vestigial_root(ts):
    """
    Adds the nodes and edges required by tsinfer to the specified tree sequence
    and returns it.

    A single new node is inserted with ID 0, one time unit older than the
    oldest node already present. All existing node IDs (and the node
    references held by edges and mutations) are shifted up by one. If the
    tree sequence has any edges, an edge from the new node to the root of
    each tree is added over that tree's interval.

    :param ts: The tree sequence to extend (the callers visible alongside
        this function pass a ``tskit.TreeSequence``).
    :return: A new tree sequence built from the modified tables.
    :raises ValueError: If ``ts`` does not use discrete genome coordinates,
        or if it contains no nodes.
    """
    if not ts.discrete_genome:
        raise ValueError("Only discrete genome coords supported")
    if ts.num_nodes == 0:
        # The misspelling is part of the runtime message that tests match on;
        # do not "fix" it without updating those tests.
        raise ValueError("Emtpy trees not supported")

    base_tables = ts.dump_tables()
    tables = base_tables.copy()
    tables.nodes.clear()

    # The new node goes in first so that it receives ID 0.
    oldest_time = ts.nodes_time.max()
    tables.nodes.add_row(time=oldest_time + 1)
    num_additional_nodes = 1

    # Existing nodes are re-appended after the new one, so every stored node
    # reference has to move up by the number of nodes inserted in front.
    tables.mutations.node += num_additional_nodes
    tables.edges.child += num_additional_nodes
    tables.edges.parent += num_additional_nodes
    for node in base_tables.nodes:
        tables.nodes.append(node)

    if ts.num_edges > 0:
        for tree in ts.trees():
            # NOTE(review): ``tree.root`` is used here, so each tree is
            # presumably expected to have exactly one root - confirm.
            root = tree.root + num_additional_nodes
            tables.edges.add_row(
                tree.interval.left, tree.interval.right, parent=0, child=root
            )
        tables.edges.squash()
        # TODO: a full sort is probably unnecessary most of the time; sorting
        # only the tail of the edge table may be enough.
        tables.sort()
    return tables.tree_sequence()
38 changes: 28 additions & 10 deletions tsinfer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,28 @@ def __post_init__(self):

@dataclasses.dataclass
class AncestralState:
"""Specifies where to read the ancestral allele for each variant position."""
"""Specifies where to read the ancestral allele for each variant position.

Exactly one of *field* or *is_reference* must be set.

If *is_reference* is ``True``, the REF allele (``variant_allele[:, 0]``)
from the store at *path* is used as the ancestral allele. Otherwise
*field* names the array to read.
"""

path: str | pathlib.Path
field: str
field: str | None = None
is_reference: bool | None = None

def __post_init__(self):
if self.is_reference is None and self.field is None:
raise ValueError(
"[ancestral_state] requires either 'field' or 'is_reference = true'"
)
if self.is_reference is True and self.field is not None:
raise ValueError(
"[ancestral_state] field must not be set when is_reference is true"
)


@dataclasses.dataclass
Expand Down Expand Up @@ -356,7 +374,7 @@ def from_toml(cls, path: str | pathlib.Path) -> Config:
"sample_time",
}

_KNOWN_ANCESTRAL_STATE_KEYS = {"path", "field"}
_KNOWN_ANCESTRAL_STATE_KEYS = {"path", "field", "is_reference"}

_KNOWN_ANCESTORS_KEYS = {
"name",
Expand Down Expand Up @@ -429,13 +447,13 @@ def _parse_ancestral_state(raw: dict) -> AncestralState:
if entry is None:
raise ValueError("Config must contain an [ancestral_state] section")
_check_unknown_keys("ancestral_state", entry, _KNOWN_ANCESTRAL_STATE_KEYS)
try:
return AncestralState(
path=_resolve_path(entry["path"]),
field=entry["field"],
)
except KeyError as e:
raise ValueError(f"[ancestral_state] missing required key: {e}") from e
if "path" not in entry:
raise ValueError("[ancestral_state] missing required key: 'path'")
return AncestralState(
path=_resolve_path(entry["path"]),
field=entry.get("field"),
is_reference=entry.get("is_reference"),
)


def _parse_one_ancestor(entry: dict) -> AncestorsConfig:
Expand Down
Loading
Loading