From 953b5e2b000e034b94c2d740072dd1e5c211fd2e Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Fri, 13 Mar 2026 01:33:18 +0000 Subject: [PATCH 01/51] Trigger lint on merge queue --- .github/workflows/lint.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 92e47e5b39..67fd8e1042 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -2,6 +2,7 @@ name: Lint on: pull_request: + merge_group: jobs: Lint: From 17bacdaa6b240b69fb15434a0470f17848f861b0 Mon Sep 17 00:00:00 2001 From: peter Date: Wed, 11 Mar 2026 20:13:18 -0700 Subject: [PATCH 02/51] remove consts from tsk_json_struct_metadata_get_blob (closes #3425) --- c/tests/test_core.c | 16 ++++++++-------- c/tskit/core.c | 17 ++++++++--------- c/tskit/core.h | 5 ++--- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/c/tests/test_core.c b/c/tests/test_core.c index dad4a81bf2..9d8d9f5a7d 100644 --- a/c/tests/test_core.c +++ b/c/tests/test_core.c @@ -101,9 +101,9 @@ test_json_struct_metadata_get_blob(void) { int ret; char metadata[128]; - const char *json; + char *json; tsk_size_t json_buffer_length; - const char *blob; + char *blob; tsk_size_t blob_length; uint8_t *bytes; tsk_size_t metadata_length; @@ -111,9 +111,9 @@ test_json_struct_metadata_get_blob(void) size_t json_length; size_t payload_length; size_t total_length; - const char json_payload[] = "{\"a\":1}"; - const uint8_t binary_payload[] = { 0x01, 0x02, 0x03, 0x04 }; - const uint8_t empty_payload[] = { 0 }; + char json_payload[] = "{\"a\":1}"; + uint8_t binary_payload[] = { 0x01, 0x02, 0x03, 0x04 }; + uint8_t empty_payload[] = { 0 }; bytes = (uint8_t *) metadata; header_length = 4 + 1 + 8 + 8; @@ -135,7 +135,7 @@ test_json_struct_metadata_get_blob(void) ret = tsk_json_struct_metadata_get_blob( metadata, metadata_length, &json, &json_buffer_length, &blob, &blob_length); CU_ASSERT_EQUAL(ret, 0); - CU_ASSERT_PTR_EQUAL(json, (const char *) bytes + header_length); + 
CU_ASSERT_PTR_EQUAL(json, (char *) bytes + header_length); CU_ASSERT_EQUAL(json_buffer_length, (tsk_size_t) json_length); if (json_length > 0) { CU_ASSERT_EQUAL(memcmp(json, json_payload, json_length), 0); @@ -152,7 +152,7 @@ test_json_struct_metadata_get_blob(void) ret = tsk_json_struct_metadata_get_blob( metadata, metadata_length, &json, &json_buffer_length, &blob, &blob_length); CU_ASSERT_EQUAL(ret, 0); - CU_ASSERT_PTR_EQUAL(json, (const char *) bytes + header_length); + CU_ASSERT_PTR_EQUAL(json, (char *) bytes + header_length); CU_ASSERT_EQUAL(json_buffer_length, (tsk_size_t) json_length); CU_ASSERT_EQUAL(blob_length, (tsk_size_t) payload_length); CU_ASSERT_PTR_EQUAL(blob, bytes + header_length + json_length); @@ -168,7 +168,7 @@ test_json_struct_metadata_get_blob(void) ret = tsk_json_struct_metadata_get_blob( metadata, metadata_length, &json, &json_buffer_length, &blob, &blob_length); CU_ASSERT_EQUAL(ret, 0); - CU_ASSERT_PTR_EQUAL(json, (const char *) bytes + header_length); + CU_ASSERT_PTR_EQUAL(json, (char *) bytes + header_length); CU_ASSERT_EQUAL(json_buffer_length, (tsk_size_t) json_length); CU_ASSERT_EQUAL(blob_length, (tsk_size_t) payload_length); CU_ASSERT_PTR_EQUAL(blob, bytes + header_length + json_length); diff --git a/c/tskit/core.c b/c/tskit/core.c index 574424c144..66f8cbd0ef 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -142,9 +142,8 @@ tsk_generate_uuid(char *dest, int TSK_UNUSED(flags)) } int -tsk_json_struct_metadata_get_blob(const char *metadata, tsk_size_t metadata_length, - const char **json, tsk_size_t *json_length, const char **blob, - tsk_size_t *blob_length) +tsk_json_struct_metadata_get_blob(char *metadata, tsk_size_t metadata_length, + char **json, tsk_size_t *json_length, char **blob, tsk_size_t *blob_length) { int ret; uint8_t version; @@ -152,16 +151,16 @@ tsk_json_struct_metadata_get_blob(const char *metadata, tsk_size_t metadata_leng uint64_t binary_length_u64; uint64_t header_and_json_length; uint64_t total_length; - 
const uint8_t *bytes; - const char *blob_start; - const char *json_start; + uint8_t *bytes; + char *blob_start; + char *json_start; if (metadata == NULL || json == NULL || json_length == NULL || blob == NULL || blob_length == NULL) { ret = tsk_trace_error(TSK_ERR_BAD_PARAM_VALUE); goto out; } - bytes = (const uint8_t *) metadata; + bytes = (uint8_t *) metadata; if (metadata_length < TSK_JSON_BINARY_HEADER_SIZE) { ret = tsk_trace_error(TSK_ERR_JSON_STRUCT_METADATA_TRUNCATED); goto out; @@ -191,8 +190,8 @@ tsk_json_struct_metadata_get_blob(const char *metadata, tsk_size_t metadata_leng ret = tsk_trace_error(TSK_ERR_JSON_STRUCT_METADATA_TRUNCATED); goto out; } - json_start = (const char *) bytes + TSK_JSON_BINARY_HEADER_SIZE; - blob_start = (const char *) bytes + TSK_JSON_BINARY_HEADER_SIZE + json_length_u64; + json_start = (char *) bytes + TSK_JSON_BINARY_HEADER_SIZE; + blob_start = (char *) bytes + TSK_JSON_BINARY_HEADER_SIZE + json_length_u64; *json = json_start; *json_length = (tsk_size_t) json_length_u64; *blob = blob_start; diff --git a/c/tskit/core.h b/c/tskit/core.h index 9fe0643975..2964e3d8f1 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -1153,9 +1153,8 @@ the original metadata buffer is alive. @param[out] blob_length On success, set to the payload length in bytes. @return Return 0 on success or a negative value on failure. */ -int tsk_json_struct_metadata_get_blob(const char *metadata, tsk_size_t metadata_length, - const char **json, tsk_size_t *json_length, const char **blob, - tsk_size_t *blob_length); +int tsk_json_struct_metadata_get_blob(char *metadata, tsk_size_t metadata_length, + char **json, tsk_size_t *json_length, char **blob, tsk_size_t *blob_length); /* TODO most of these can probably be macros so they compile out as no-ops. * Lets do the 64 bit tsk_size_t switch first though. 
*/ From 8d4ab2dce9fa4d4065f17b6efbb337c6a9a78c01 Mon Sep 17 00:00:00 2001 From: peter Date: Wed, 11 Mar 2026 20:38:31 -0700 Subject: [PATCH 03/51] edits to dev docs --- docs/development.md | 47 +++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/docs/development.md b/docs/development.md index b38cc292ce..7a203654c1 100644 --- a/docs/development.md +++ b/docs/development.md @@ -98,14 +98,14 @@ the development workflows of all tskit-dev packages are organised around using uv, and therefore we strongly recommend using it. Uv is straightforward to install, and not invasive (existing Python installations can be completely isolated if you don't use features like ``uv tool`` etc which update your -$HOME/.local/bin). Uv manages an isolated local environment per project +``$HOME/.local/bin``). Uv manages an isolated local environment per project and allows us to deterministically pin package versions and easily switch between Python versions, so that CI environments can be replicated exactly locally. The packages needed for development are specified as dependency groups in ``python/pyproject.toml`` and managed with [uv](https://docs.astral.sh/uv/). -Install all development dependencies using: +Install all development dependencies by running, from the `python/` directory: ```bash $ uv sync @@ -129,7 +129,8 @@ To get a local git development environment, please follow these steps: ```bash $ git clone git@github.com:YOUR_GITHUB_USERNAME/tskit.git ``` -- Install the {ref}`sec_development_workflow_prek` pre-commit hook: +- Install the {ref}`sec_development_workflow_prek` pre-commit hook + (again from the ``python/`` subdirectory): ```bash $ uv run prek install ``` @@ -201,7 +202,7 @@ skip to {ref}`sec_development_workflow_anothers_commit`. to document any breaking changes separately in a "breaking changes" section. 8. 
Push your changes to your topic branch and either open the PR or, if you - opened a draft PR above change it to a non-draft PR by clicking "Ready to + already opened a draft PR change it to a non-draft PR by clicking "Ready to Review". 9. The tskit community will review the code, asking you to make changes where appropriate. @@ -258,10 +259,10 @@ subdirectory. To test out changes to the *code*, you can change to the `python/` subdirectory, and run `make` to compile the C code. -If you then execute `python` from this subdirectory (and only this one!), +If you then execute python commands from this subdirectory (and only this one!), it will use the modified version of the package. -(For instance, you might want to -open an interactive `python` shell from the `python/` subdirectory, +(For instance, you might want to open an interactive python shell by running +`uv run python` in the `python/` subdirectory, or running `uv run pytest` from this subdirectory.) After you're done, you should do: @@ -272,7 +273,8 @@ $ git checkout main to get your repository back to the "main" branch of development. If the pull request is changed and you want to do the same thing again, -then first *delete* your local copy (by doing `git branch -d my_pr_copy`) +then to avoid conflicts with any changes you might have made, +first *delete* your local copy (by doing `git branch -d my_pr_copy`) and repeat the steps again. @@ -285,13 +287,7 @@ On each commit a [prek](https://prek.j178.dev) hook will run checks for code style (see the {ref}`sec_development_python_style` section for details) and other common problems. -To install the hook: - -```bash -$ uv run prek install -``` - -To run checks manually without committing: +To run checks manually without committing, from the `python/` subdirectory: ```bash $ uv run prek --all-files @@ -467,6 +463,9 @@ See :ref:`sec_development_documentation_cross_referencing` for details. The :meth:`.TreeSequence.trees` method returns an iterator. 
```` +Some errors may occur because of out-of-date cached results, +which can be cleared by running `make clean`. + (sec_development_python)= @@ -544,6 +543,10 @@ To run a specific test case in this class (say, `test_copy`) use: $ uv run pytest tests/test_tables.py::TestNodeTable::test_copy ``` +In general, you can copy-paste the string describing a failed test from the +output of pytest to re-run just that test (including specific parametrized +arguments present as `[args]`). + You can also run tests with a keyword expression search. For example this will run all tests that have `TestNodeTable` but not `copy` in their name: @@ -793,6 +796,13 @@ this test name as a command line argument, e.g.: $ ./build/test_tables test_node_table ``` +After making sure tests pass, you should next run the tests through valgrind, +to check for memory leaks, for instance: + +```bash +$ valgrind ./build/test_tables test_node_table +``` + While 100% test coverage is not feasible for C code, we aim to cover all code that can be reached. (Some classes of error such as malloc failures and IO errors are difficult to simulate in C.) Code coverage statistics are @@ -1029,7 +1039,7 @@ Continuous integration is handled by [GitHub Actions](https://help.github.com/en tskit uses shared workflows defined in the [tskit-dev/.github](https://github.com/tskit-dev/.github) repository: -- **lint** — runs prek against all files +- **lint** — runs ruff and clang (using prek) against all files - **python-tests** — runs the pytest suite with coverage on Linux, macOS and Windows - **python-c-tests** — builds the C extension with coverage and runs low-level tests - **c-tests** — runs C unit tests under gcc, clang, and valgrind @@ -1050,6 +1060,9 @@ tskit codebase. Note that this guide covers the most complex case of adding a new function to both the C and Python APIs. +0. 
Draft a docstring for your function, that describes exactly what the function + takes as arguments and what it returns under what conditions. Update this + docstring as you go along and make modifications. 1. Write your function in Python: in `python/tests/` find the test module that pertains to the functionality you wish to add. For instance, the kc_distance metric was added to @@ -1085,7 +1098,7 @@ the C and Python APIs. the example of other tests, you might need to only add a single line of code here. In this case, the tests are well factored so that we can easily compare the results from both the Python and C versions. -9. Write a docstring for your function in the Python API: for instance, the kc_distance +9. Finalize your docstring and insert it into the Python API: for instance, the kc_distance docstring is in [tskit/python/tskit/trees.py](https://github.com/tskit-dev/tskit/blob/main/python/tskit/trees.py). Ensure that your docstring renders correctly by building the documentation From 2a02e7f825512305928182e0b05b5e5fa9a2f5d9 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 12 Mar 2026 09:29:49 -0700 Subject: [PATCH 04/51] no detached head; closes #3431 --- docs/development.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/development.md b/docs/development.md index 7a203654c1..358bee9777 100644 --- a/docs/development.md +++ b/docs/development.md @@ -175,8 +175,7 @@ skip to {ref}`sec_development_workflow_anothers_commit`. is to follow this recipe: ```bash $ git fetch upstream - $ git checkout upstream/main - $ git checkout -b topic_branch_name + $ git checkout -b topic_branch_name upstream/main ``` 4. Write your code following the outline in {ref}`sec_development_best_practices`. 
From 924c2f0205bd248725c6f506d41baa63ad55f91c Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 12 Mar 2026 09:48:33 -0700 Subject: [PATCH 05/51] remove prompts from bash commands for copyability; closes #3430 --- docs/development.md | 92 ++++++++++++++++++++++---------------------- docs/export.md | 6 +-- docs/installation.md | 4 +- docs/provenance.md | 2 +- 4 files changed, 52 insertions(+), 52 deletions(-) diff --git a/docs/development.md b/docs/development.md index 358bee9777..887fb428b0 100644 --- a/docs/development.md +++ b/docs/development.md @@ -89,7 +89,7 @@ is required for building the C API documentation. On Debian/Ubuntu we can install these with: ```bash -$ sudo apt install build-essential doxygen +sudo apt install build-essential doxygen ``` All Python development is managed using [uv](https://docs.astral.sh/uv/). @@ -108,7 +108,7 @@ in ``python/pyproject.toml`` and managed with [uv](https://docs.astral.sh/uv/). Install all development dependencies by running, from the `python/` directory: ```bash -$ uv sync +uv sync ``` The lock file lives at `python/uv.lock` and must be kept up to date. Run @@ -127,12 +127,12 @@ To get a local git development environment, please follow these steps: - Make a fork of the tskit repo on [GitHub](http://github.com/tskit-dev/tskit) - Clone your fork into a local directory: ```bash - $ git clone git@github.com:YOUR_GITHUB_USERNAME/tskit.git + git clone git@github.com:YOUR_GITHUB_USERNAME/tskit.git ``` - Install the {ref}`sec_development_workflow_prek` pre-commit hook (again from the ``python/`` subdirectory): ```bash - $ uv run prek install + uv run prek install ``` See the {ref}`sec_development_workflow_git` section for detailed information @@ -168,14 +168,14 @@ skip to {ref}`sec_development_workflow_anothers_commit`. 
[upstream remote]( https://help.github.com/articles/configuring-a-remote-for-a-fork/): ```bash - $ git remote add upstream https://github.com/tskit-dev/tskit.git + git remote add upstream https://github.com/tskit-dev/tskit.git ``` 3. Create a "topic branch" to work on. One reliable way to do it is to follow this recipe: ```bash - $ git fetch upstream - $ git checkout -b topic_branch_name upstream/main + git fetch upstream + git checkout -b topic_branch_name upstream/main ``` 4. Write your code following the outline in {ref}`sec_development_best_practices`. @@ -235,7 +235,7 @@ Then, continuing from above: 3. Fetch the pull request, and store it as a local branch. For instance, to name the local branch `my_pr_copy`: ```bash - $ git fetch upstream pull/854/head:my_pr_copy + git fetch upstream pull/854/head:my_pr_copy ``` You should probably call the branch something more descriptive, though. (Also note that you might need to put `origin` instead @@ -244,7 +244,7 @@ Then, continuing from above: 4. Check out the pull request's local branch: ```bash - $ git checkout my_pr_copy + git checkout my_pr_copy ``` Now, your repository will be in exactly the same state as @@ -267,7 +267,7 @@ or running `uv run pytest` from this subdirectory.) After you're done, you should do: ```bash -$ git checkout main +git checkout main ``` to get your repository back to the "main" branch of development. @@ -289,7 +289,7 @@ and other common problems. To run checks manually without committing, from the `python/` subdirectory: ```bash -$ uv run prek --all-files +uv run prek --all-files ``` If local results differ from CI, run `uv run prek cache clean` to clear the cache. 
@@ -526,20 +526,20 @@ The tests are defined in the `tests` directory, and run using If you want to run the tests in a particular module (say, `test_tables.py`), use: ```bash -$ uv run pytest tests/test_tables.py +uv run pytest tests/test_tables.py ``` To run all the tests in a particular class in this module (say, `TestNodeTable`) use: ```bash -$ uv run pytest tests/test_tables.py::TestNodeTable +uv run pytest tests/test_tables.py::TestNodeTable ``` To run a specific test case in this class (say, `test_copy`) use: ```bash -$ uv run pytest tests/test_tables.py::TestNodeTable::test_copy +uv run pytest tests/test_tables.py::TestNodeTable::test_copy ``` In general, you can copy-paste the string describing a failed test from the @@ -550,7 +550,7 @@ You can also run tests with a keyword expression search. For example this will run all tests that have `TestNodeTable` but not `copy` in their name: ```bash -$ uv run pytest -k "TestNodeTable and not copy" +uv run pytest -k "TestNodeTable and not copy" ``` When developing your own tests, it is much quicker to run the specific tests @@ -560,41 +560,41 @@ suite each time. To run all of the tests, we can use: ```bash -$ uv run pytest +uv run pytest ``` By default the tests are run on 4 cores, if you have more you can specify: ```bash -$ uv run pytest -n8 +uv run pytest -n8 ``` A few of the tests take most of the time, we can skip the slow tests to get the test run under 20 seconds on an modern workstation: ```bash -$ uv run pytest --skip-slow +uv run pytest --skip-slow ``` If you have an agent running the tests in a sandboxed environment, you may need to skip tests thsat require network access or FIFOs: ```bash -$ uv run pytest --skip-network +uv run pytest --skip-network ``` If you have a lot of failing tests it can be useful to have a shorter summary of the failing lines: ```bash -$ uv run pytest --tb=line +uv run pytest --tb=line ``` If you need to see the output of tests (e.g. 
`print` statements) then you need to use these flags to run a single thread and capture output: ```bash -$ uv run pytest -n0 -vs +uv run pytest -n0 -vs ``` All new code must have high test coverage, which will be checked as part of the @@ -644,7 +644,7 @@ However, if you really need to be on the bleeding edge, you can use the following command to install: ```bash -$ python3 -m pip install git+https://github.com/tskit-dev/tskit.git#subdirectory=python +python3 -m pip install git+https://github.com/tskit-dev/tskit.git#subdirectory=python ``` (Because the Python package is not defined in the project root directory, using pip to @@ -689,13 +689,13 @@ to automatically format code. On Debian/Ubuntu, install the system dependencies with: ```bash -$ sudo apt install libcunit1-dev ninja-build +sudo apt install libcunit1-dev ninja-build ``` Install meson using uv: ```bash -$ uv tool install meson +uv tool install meson ``` An exact version of clang-format is required because formatting rules @@ -718,7 +718,7 @@ with a custom configuration. This is checked as part of the {ref}`prek checks `. To manually format all files run: ```bash -$ uv run prek --all-files +uv run prek --all-files ``` If you are doing this in the ``c`` directory, use @@ -730,7 +730,7 @@ prek searching for configuration within subdirectories. To avoid this, tell prek where to find its config explicitly: ```bash -$ uv run prek --all-files -c prek.toml +uv run prek --all-files -c prek.toml ``` @@ -743,8 +743,8 @@ is defined in `meson.build`. To set up the initial build directory, run ```bash -$ cd c -$ meson setup build +cd c +meson setup build ``` To setup a debug build add `--buildtype=debug` to the above command. This will set the `TSK_TRACE_ERRORS` @@ -753,7 +753,7 @@ flag, which will print error messages to `stderr` when errors occur which is use To compile the code run ```bash -$ ninja -C build +ninja -C build ``` All the tests and other artefacts are in the build directory. 
Individual test @@ -761,7 +761,7 @@ suites can be run, via (e.g.) `./build/test_trees`. To run all of the tests, run ```bash -$ ninja -C build test +ninja -C build test ``` For vim users, the [mesonic](https://www.vim.org/scripts/script.php?script_id=5378) plugin @@ -792,14 +792,14 @@ To just run a specific test on its own, provide this test name as a command line argument, e.g.: ```bash -$ ./build/test_tables test_node_table +./build/test_tables test_node_table ``` After making sure tests pass, you should next run the tests through valgrind, to check for memory leaks, for instance: ```bash -$ valgrind ./build/test_tables test_node_table +valgrind ./build/test_tables test_node_table ``` While 100% test coverage is not feasible for C code, we aim to cover all code @@ -814,20 +814,20 @@ To generate and view coverage reports for the C tests locally: Compile with coverage enabled: ```bash - $ cd c - $ meson build -D b_coverage=true - $ ninja -C build + cd c + meson build -D b_coverage=true + ninja -C build ``` Run the tests: ```bash - $ ninja -C build test + ninja -C build test ``` Generate coverage data: ```bash - $ cd build - $ find ../tskit/*.c -type f -printf "%f\n" | xargs -i gcov -pb libtskit.a.p/tskit_{}.gcno ../tskit/{} + cd build + find ../tskit/*.c -type f -printf "%f\n" | xargs -i gcov -pb libtskit.a.p/tskit_{}.gcno ../tskit/{} ``` The generated `.gcov` files can then be viewed directly with `cat filename.c.gcov`. 
@@ -835,10 +835,10 @@ Lines prefixed with `#####` were never executed, lines with numbers show executi `lcov` can be used to create browsable HTML coverage reports: ```bash - $ sudo apt-get install lcov # if needed - $ lcov --capture --directory build-gcc --output-file coverage.info - $ genhtml coverage.info --output-directory coverage_html - $ firefox coverage_html/index.html + sudo apt-get install lcov # if needed + lcov --capture --directory build-gcc --output-file coverage.info + genhtml coverage.info --output-directory coverage_html + firefox coverage_html/index.html ``` ### Coding conventions @@ -972,20 +972,20 @@ module and how it is built from source. The module is built automatically by The simplest way to do this is to run `make` in the `python` directory: ```bash -$ make +make ``` If `make` is not available, you can run the same command manually: ```bash -$ uv run python setup.py build_ext --inplace +uv run python setup.py build_ext --inplace ``` It is sometimes useful to specify compiler flags when building the low level module. For example, to make a debug build you can use: ```bash -$ CFLAGS='-Wall -O0 -g' make +CFLAGS='-Wall -O0 -g' make ``` If you need to track down a segfault etc, running some code through gdb can @@ -993,7 +993,7 @@ be very useful. 
For example, to run a particular test case, we can do: ```bash -$ gdb python +gdb python (gdb) run -m pytest tests/test_python_c.py diff --git a/docs/export.md b/docs/export.md index 6dddb73a7c..f5f039e0d1 100644 --- a/docs/export.md +++ b/docs/export.md @@ -41,7 +41,7 @@ If we have a tree sequence file the convenient way to convert to VCF: :::{code-block} bash -$ tskit vcf example.trees > example.vcf +tskit vcf example.trees > example.vcf ::: See the {ref}`sec_export_vcf_compression` section for information @@ -137,14 +137,14 @@ The simplest way to compress the VCF output is to use the and pipe the output to `bgzip`: :::{code-block} bash -$ tskit vcf example.trees | bgzip -c > example.vcf.gz +tskit vcf example.trees | bgzip -c > example.vcf.gz ::: A general way to convert VCF data to various formats is to pipe the text produced by ``tskit`` into ``bcftools`` using the command line interface: :::{code-block} bash -$ tskit vcf example.trees | bcftools view -O b > example.bcf +tskit vcf example.trees | bcftools view -O b > example.bcf ::: If you need more control over the form of the output (or want to work diff --git a/docs/installation.md b/docs/installation.md index 6da52e12a8..ef557c1d02 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -46,7 +46,7 @@ Packages for recent version of Python are available for Linux, OSX and Windows. using: ```bash -$ conda install -c conda-forge tskit +conda install -c conda-forge tskit ``` ### Quick Start @@ -75,7 +75,7 @@ may result in code that is (slightly) faster on your specific hardware. installations. Installation is straightforward: ```bash -$ python3 -m pip install tskit +python3 -m pip install tskit ``` (sec_installation_development_versions)= diff --git a/docs/provenance.md b/docs/provenance.md index d07da38b11..330ad859d4 100644 --- a/docs/provenance.md +++ b/docs/provenance.md @@ -251,4 +251,4 @@ should validate the output JSON against this schema. ```{eval-rst} .. 
literalinclude:: ../python/tskit/provenance.schema.json :language: json -``` \ No newline at end of file +``` From 1656dc4273872123c3ffc2be45098ed55845ad70 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 12 Mar 2026 09:56:46 -0700 Subject: [PATCH 06/51] further clarify uv subdir; closes #3429 --- docs/development.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/development.md b/docs/development.md index 887fb428b0..f1dabe2cf3 100644 --- a/docs/development.md +++ b/docs/development.md @@ -105,12 +105,16 @@ locally. The packages needed for development are specified as dependency groups in ``python/pyproject.toml`` and managed with [uv](https://docs.astral.sh/uv/). -Install all development dependencies by running, from the `python/` directory: +Install all development dependencies by running: ```bash +cd python uv sync ``` +Since `uv` operates from the `python/` subdirectory, +**all `uv` commands below must be run from within that subdirectory**; +otherwise errors like "No such file or directory" will occur. The lock file lives at `python/uv.lock` and must be kept up to date. Run `uv lock` after any change to the dependencies in `python/pyproject.toml`. From 01de52358ea073bd663cbda1107e9d9cea7113b4 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 12 Mar 2026 10:07:46 -0700 Subject: [PATCH 07/51] installation on macos; closes #3432 --- docs/development.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/development.md b/docs/development.md index f1dabe2cf3..f5ab90c793 100644 --- a/docs/development.md +++ b/docs/development.md @@ -92,6 +92,11 @@ On Debian/Ubuntu we can install these with: sudo apt install build-essential doxygen ``` +On macOS, either `brew install doxygen` or +`sudo port install doxygen` should get doxygen. +You'll also need a "essential build" tools: +a compiler (`gcc`) and a few other things (e.g., `make`). + All Python development is managed using [uv](https://docs.astral.sh/uv/). 
It is not strictly necessary to use uv in order to make small changes, but the development workflows of all tskit-dev packages are organised around @@ -696,7 +701,10 @@ On Debian/Ubuntu, install the system dependencies with: sudo apt install libcunit1-dev ninja-build ``` -Install meson using uv: +On macOS, you can run `brew install cunit ninja` +or `sudo port install cunit ninja`. + +You can install meson using uv: ```bash uv tool install meson From 0071e2d31498af8f0beea3b3c6a502201e13df23 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 12 Mar 2026 12:16:26 -0700 Subject: [PATCH 08/51] note about debug build; closes #3433 --- docs/development.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/development.md b/docs/development.md index f5ab90c793..f5cf47d2e1 100644 --- a/docs/development.md +++ b/docs/development.md @@ -97,9 +97,12 @@ On macOS, either `brew install doxygen` or You'll also need a "essential build" tools: a compiler (`gcc`) and a few other things (e.g., `make`). -All Python development is managed using [uv](https://docs.astral.sh/uv/). +All Python development is managed using [uv](https://docs.astral.sh/uv/), +which takes the place of virtual/conda environments. It is not strictly necessary to use uv in order to make small changes, but -the development workflows of all tskit-dev packages are organised around +if you don't use it, you'll need to figure out how to install python +dependencies on your own, +and the development workflows of all tskit-dev packages are organised around using uv, and therefore we strongly recommend using it. Uv is straightforward to install, and not invasive (existing Python installations can be completely isolated if you don't use features like ``uv tool`` etc which update your @@ -759,7 +762,9 @@ cd c meson setup build ``` -To setup a debug build add `--buildtype=debug` to the above command. 
This will set the `TSK_TRACE_ERRORS` +To setup a debug build add `--buildtype=debug` to the above command. +(Re-running the command with this argument will have the desired effect.) +This will set the `TSK_TRACE_ERRORS` flag, which will print error messages to `stderr` when errors occur which is useful for debugging. To compile the code run From afaf3b926f3472fdfb036dee21d539a65f5a3d41 Mon Sep 17 00:00:00 2001 From: peter Date: Mon, 16 Mar 2026 08:15:06 -0700 Subject: [PATCH 09/51] uv run gdb in dev docs --- docs/development.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/development.md b/docs/development.md index f5cf47d2e1..87c92024f5 100644 --- a/docs/development.md +++ b/docs/development.md @@ -1010,10 +1010,7 @@ be very useful. For example, to run a particular test case, we can do: ```bash -gdb python -(gdb) run -m pytest tests/test_python_c.py - - +uv run gdb python (gdb) run -m pytest -vs tests/test_tables.py::TestNodeTable::test_copy Starting program: /usr/bin/python3 run -m pytest tests/test_tables.py::TestNodeTable::test_copy [Thread debugging using libthread_db enabled] From 2f26dc6f0d033cdf7d6fb1adfbab96468dde8831 Mon Sep 17 00:00:00 2001 From: Aaron Ragsdale Date: Thu, 5 Mar 2026 11:05:46 -0600 Subject: [PATCH 10/51] Add documentation for ts.ld_matrix() --- docs/stats.md | 390 +++++++++++++++++++++++++++++++++++++++++- python/CHANGELOG.rst | 6 + python/tskit/trees.py | 98 ++++++++++- 3 files changed, 487 insertions(+), 7 deletions(-) diff --git a/docs/stats.md b/docs/stats.md index 87a2ed5f83..fe042246b9 100644 --- a/docs/stats.md +++ b/docs/stats.md @@ -684,15 +684,395 @@ and boolean expressions (e.g., {math}`(x > 0)`) are interpreted as 0/1. and {math}`v_j` is the covariance of the trait with the j-th covariate. +(sec_stats_multi_site)= + ## Multi site statistics -:::{todo} -Document statistics that use information about correlation between sites, such as -LdCalculator (and perhaps reference {ref}`sec_identity`). 
Note that if we have a general -framework which has the same calling conventions as the single site stats, -we can rework the sections above. +(sec_stats_two_locus)= + +### Two-locus statistics + +The {meth}`~TreeSequence.ld_matrix` method provides an interface to +a collection of two-locus statistics with predefined summary functions (see +{ref}`sec_stats_two_locus_summary_functions`). +The LD matrix method differs from other +statistics methods in that it provides a unified API with an argument to +specify different two-locus summaries of the data. It otherwise behaves +similarly to most other functions with respect to `sample_sets` and `indexes`. + +Two-locus statistics can be computed using two {ref}`modes `, +either `site` or `branch`, and these should be interpreted in the same way as +these modes in the single-site statistics. That is, the `site` mode computes LD +over observed alleles at pairs of sites, while the `branch` model computes +expected LD conditioned on pairs of trees. + +(sec_stats_two_locus_site)= + +#### Site mode + +The `"site"` mode computes two-locus statistics summarized over alleles between +all pairs of specified sites. The default behavior, leaving `sites` +unspecified, will compute a matrix for all pairs of sites, with one row and +column for each site in the tree sequence (i.e., an {math}`n \times n` matrix +where {math}`n` is the number of sites in the tree sequence). We can also +restrict the output to a subset of sites, either by specifying a single vector +of site indexes for both rows and columns or a pair of vectors for the row +sites and column sites separately. + +The following computes a matrix of the {math}`r^2` measure of linkage +disequilibrium (LD) computed pairwise between the first 4 sites in the tree +sequence among all samples. The `sites` must be given as a list of lists, and +with a single list of sites specified, we obtain a symmetric square matrix. 
+ +```{code-cell} ipython3 +ld = ts.ld_matrix(sites=[[0, 1, 2, 3]]) +print(ld) +``` + +If a list of two lists of site indexes is provided, these specify the row and +column sites. For instance, here we specify 2 rows and 3 columns, which +computes a subset of the matrix shown above. + +```{code-cell} ipython3 +ld = ts.ld_matrix(sites=[[1, 2], [1, 2, 3]]) +print(ld) +``` + +#### Computational details + +Because we allow for two-locus statistics to be computed for multi-allelic +data, we need to be able to combine statistical results from each pair of +alleles into one summary for a pair of sites. This does not affect biallelic +data (and so this section can be skipped on first reading). +We use two implementations for +combining results from multiple alleles: `hap_weighted` and `total_weighted`. +These are statistic-specific and not chosen by the user, with choices motivated +by [Zhao (2007)](https://doi.org/10.1017/S0016672307008634). + +Briefly, consider a pair of sites with {math}`n` alleles at the first locus and +{math}`m` alleles at the second. (Whether this includes the ancestral allele +depends on whether the statistic is polarised.) Write {math}`f_{ij}` as the +statistic computed for focal alleles {math}`A_i` and {math}`B_j`. Then the +weighting schemes are defined as: + +- `hap_weighted`: {math}`\sum_{i=1}^{n}\sum_{j=1}^{m}p(A_{i}B_{j})f_{ij}`, + where {math}`p(A_{i}B_{j})` is the frequency of haplotype {math}`A_{i}B_{j}`. + This method was first introduced in [Karlin + (1981)](https://doi.org/10.1111/j.1469-1809.1981.tb00308.x) and reviewed in + [Zhao (2007)](https://doi.org/10.1017/S0016672307008634). + +- `total_weighted`: {math}`\frac{1}{n m}\sum_{i=1}^{n}\sum_{j=1}^{m}f_{ij}`. + This method assigns equal weight to each of the possible pairs of focal + alleles at the two sites, taking the arithmetic mean of statistics over + focal haplotypes. 
Out of all of the available summary functions, only {math}`r^2` uses +`hap_weighted` normalisation, with the remainder using uniform weighting +(`total_weighted`). + +Within this framework, statistics may be either polarised or unpolarised. For +statistics that are polarised, we compute statistic values for pairs of derived +alleles. (For this purpose, the "derived" alleles at a site are all alleles +except that stored as the ``ancestral_state`` for the site.) Unpolarised +statistics compute statistics over all pairs of alleles, derived and ancestral. +In either case, the result is averaged over these values, using one of the +weighting schemes (described below for each statistic). The option for +polarisation is not exposed to the user, and we list which statistics are +polarised below. + +(sec_stats_two_locus_branch)= + +#### Branch mode + +The `"branch"` mode computes expected two-locus statistics between pairs of +trees, conditioned on the marginal topologies and branch lengths of those +trees. The trees for which we compute statistics are specified by positions, +and for a pair of positions we consider all possible haplotypes that could be +generated by a single mutation occurring on each of the two trees. + +For two trees, one with {math}`n` branches and the other with {math}`m` +branches, there are {math}`nm` possible pairs of branches that may carry the +pair of mutations. For each pair, we compute the two-locus statistic, and then +sum these values weighted by the product of the two branch lengths. Given that +the two mutations occur, this accounts for the relative probability that the +two mutations fall on any pair of branches. + +In other words, imagine we place two mutations uniformly, one on each tree, and +then compute the statistic. The branch mode computes the expected value of the +statistic over this process, multiplied by the product of the total branch +lengths of each tree.
This weighting accounts for mutational opportunity, so that +the sum of the branch-mode statistic over all positions in a genomic region, +multiplied by a mutation rate, is equal to the expected sum of the two-locus site +statistic over all mutations falling in that region under an infinite-sites model. + +The time complexity of this method is quadratic in the number of samples, +due to the pairwise comparisons of branches from each pair of trees. +By default, this method computes +a symmetric matrix for all pairs of trees, with rows and columns representing +each tree in the tree sequence. Similar to the site method, we can restrict the +output to a subset of trees, either by specifying a vector of positions or +a pair of vectors for row and column positions separately. To select a specific +tree, the specified positions must land in the tree span (`[start, end)`). + +In the following, we compute a matrix of expected {math}`r^2` within and +between the first 4 trees in the tree sequence. The tree breakpoints are +a convenient way to specify those first four trees. + +```{code-cell} ipython3 +ld = ts.ld_matrix( + mode="branch", + positions=[ts.breakpoints(as_array=True)[0:4]] +) +print(ld) +``` + +We note that these values are quite large: as described above, the statistic is +scaled by the product of the total branch lengths of each pair of trees. 
To +compute the expected {math}`r^2` value for a pair of mutations that each land +uniformly on the pair of trees, we can divide by the product of the total +branch lengths: + +```{code-cell} ipython3 +total_branch_lengths = [tree.total_branch_length for tree in ts.trees()] +prod_branch_lengths = np.outer(total_branch_lengths, total_branch_lengths) +print(ld / prod_branch_lengths[0:4, 0:4]) +``` + +To compute the average {math}`r^2` for a uniformly chosen pair of mutations, we also +weight by tree span: + +```{code-cell} ipython3 +tree_spans = np.array([t.span for t in ts.trees()]) +total_opportunity = np.sum(tree_spans * total_branch_lengths) +all_ld = ts.ld_matrix(mode="branch") +mean_ld = np.sum(all_ld * np.outer(tree_spans, tree_spans)) / total_opportunity ** 2 +print("mean infinite-sites LD:", mean_ld) +``` + +As with the `"site"` mode above, we can specify the row and column trees +separately. + +```{code-cell} ipython3 +breakpoints = ts.breakpoints(as_array=True) +ld = ts.ld_matrix( + mode="branch", + positions=[breakpoints[[0]], breakpoints[0:4]] +) +print(ld) +``` + +(sec_stats_two_locus_sample_sets)= + +#### Sample Sets + +Without specifying `sample_sets` or `indexes`, the `ld_matrix()` method +computes statistics over a single sample set that includes all samples in the +tree sequence. The API allows for the specification of a subset or multiple +subsets of samples, so that a separate LD matrix can be computed for each. If +`sample_sets` is specified as a single list of samples, then a single LD matrix +is returned. A list of lists of samples will return a 3D array containing an LD +matrix for each list of samples. + +Some LD statistics can be computed between sample sets (two-way statistics are +specified below), in which case `indexes` must be specified that reference the +indexes of the `sample_sets`, which must be a list of lists of sample nodes. +This results in an LD matrix computed for each list of indexes. 
The statistics +are selected in the same way (with the `stat` argument), and these are limited +to a handful of statistics (see +{ref}`sec_stats_two_locus_summary_functions_two_way`). The dimension-dropping +rules for the result follow the rest of the tskit stats API in that a single +list or tuple will produce a single two-dimensional matrix, while a list of +these will produce a three-dimensional array, with the first dimension of +length equal to the length of the list. + +For example, to compute the {math}`r^2` LD matrix over a subset of samples in +the tree sequence (such as sample nodes 0 through 7), we would specify the +samples as follows: + +```{code-cell} ipython3 +ts = msprime.sim_ancestry( + 20, + population_size=10000, + sequence_length=1000, + recombination_rate=2e-8, + random_seed=12) +ts = msprime.sim_mutations(ts, rate=2e-8, random_seed=12) + +ld = ts.ld_matrix(mode="site", sample_sets=range(8)) +print(ld) +``` + +We would get the following dimensions with the specified +`sample_sets` and `indexes` arguments. + +``` +# one-way +ts.ld_matrix(sample_sets=None) # -> 2 dimensions +ts.ld_matrix(sample_sets=[0, 1, 2, 3]) # -> 2 dimensions +ts.ld_matrix(sample_sets=[[0, 1, 2, 3]]) # -> 3 dimensions +# two-way +ts.ld_matrix(sample_sets=[[0, 1, 2, 3], [4, 5, 6, 7]], indexes=(0, 1)) # -> 2 dimensions +ts.ld_matrix(sample_sets=[[0, 1, 2, 3], [4, 5, 6, 7]], indexes=[(0, 1)]) # -> 3 dimensions +``` + +#### Why are there `nan` values in the LD matrix? + +For some statistics, it is possible to observe `nan` entries in the LD matrix, +which can be surprising and may numerically impact downstream analyses. A `nan` +entry occurs if the denominator of a ratio statistic (including {math}`r` and +{math}`r^2`) is zero, indicating that one or both of the alleles in the pair is +fixed or absent in the given sample set(s). 
This can happen for +a number of reasons: + +- Some mutation models allow for reversible mutations, so a back mutation at + a site can result in a single allele despite multiple mutations in the + history of the sample. +- LD is computed for a subsample of individuals, and some sites are not + variable among the sample nodes in the subsample. +- A mutation exists above the root of the local tree, so that all samples carry + the mutation, and one or more sites are not variable. + +The `branch` mode will also return `nan` values for ratio statistics if there +are branches in either tree on which a mutation would not result in +a polymorphism within a sample set. + +:::{warning} +This means there are two common situations in which many or all LD values will be `nan`. +These are: + +1. A branch-mode ratio statistic computed on less than the full set of samples + will always be `nan`, since part of the trees are ancestral to none of the samples. +2. A site-mode ratio statistic will be `nan` at any sites at which there are alleles found + in the entire set of samples that are not seen in the provided sample set. + +This behavior **may change in the future**, +because possibly more natural behavior not currently implemented +would be to ignore the branches/alleles not ancestral +to any of the provided samples. ::: +(sec_stats_two_locus_sample_one_way_stats)= + +#### One-way Statistics + +One-way statistics are summaries of two loci in a single sample set, using +a triple of haplotype counts {math}`\{n_{AB}, n_{Ab}, n_{aB}\}` and the size of +the sample set {math}`n`, where the capitalized and lowercase letters in our +notation represent alternate alleles. + +(sec_stats_two_locus_sample_two_way_stats)= + +#### Two-way Statistics + +Two-way statistics are summaries of haplotype counts between two sample sets, +which operate on the three haplotype counts (as in one-way stats, above) +computed from each sample set, indexed by `(i, j)`. 
These statistics take on +a different meaning from their one-way counterparts. For instance `stat="D2"` +over a pair of sample sets computes {math}`D_i D_j`, which is the product of +the covariance measure of LD within each sample set and is related to the +covariance of {math}`D` between sample sets. + +Only a subset of our summary functions are two-way statistics (see +{ref}`sec_two_locus_summary_functions_two_way`). Note that the unbiased two-way +statistics expect non-overlapping sample sets (see [Ragsdale and Gravel +(2020)](https://doi.org/10.1093/molbev/msz265)), and we do not make any +assertions about the sample sets and assume that `i` and `j` represent disjoint +sets of samples (see also the note in {meth}`~TreeSequence.divergence`). + +(sec_stats_two_locus_summary_functions)= + +#### Summary Functions + +(sec_stats_two_locus_summary_functions_one_way)= + +##### One-way + +The two-locus summary functions all take haplotype counts and sample set size +as input. Each of our summary functions has the signature +{math}`f(n_{AB}, n_{Ab}, n_{aB}, n)`, converting to haplotype frequencies +{math}`\{p_{AB}, p_{Ab}, p_{aB}\}` by dividing by {math}`n`. Below, +{math}`n_{ab} = n - n_{AB} - n_{Ab} - n_{aB}`, {math}`n_A = n_{AB} + n_{Ab}` +and {math}`n_B = n_{AB} + n_{aB}`, with frequencies {math}`p` found by dividing +by {math}`n`. + +Our convention is to use {math}`A,B` to denote derived alleles, and {math}`a,b` +ancestral alleles (or other alleles, if the site is multi-allelic). For +polarised statistics, we average statistics over all non-ancestral alleles. For +unpolarised statistics, the labeling is arbitrary as we average over all +alleles (derived and ancestral). + +`D` +: {math}`f(n_{AB}, n_{Ab}, n_{aB}, n) = p_{AB}p_{ab} - p_{Ab}p_{aB} \, (=p_{AB} - p_A p_B)` + + This statistic is polarised, as the unpolarised result, which averages over + allele labelings, is zero. Uses the `total` weighting method. 
+ +`D_prime` +: {math}`f(n_{AB}, n_{Ab}, n_{aB}, n) = \frac{D}{D_{\max}}`, + + where {math}`D_{\max} = \begin{cases} + \min\{p_A (1-p_B), (1-p_A) p_B\} & \textrm{if }D \geq 0 \\ + \min\{p_A p_B, (1-p_A) (1-p_B)\} & \textrm{if }D<0 + \end{cases}` + + and {math}`D` is defined above. Polarised, `total` weighted. + +`D2` +: {math}`f(n_{AB}, n_{Ab}, n_{aB}, n) = D^2` + + and {math}`D` is defined above. Unpolarised, `total` weighted. + +`Dz` +: {math}`f(n_{AB}, n_{Ab}, n_{aB}, n) = D (1 - 2 p_A) (1 - 2 p_B)`, + + where {math}`D` is defined above. Unpolarised, `total` weighted. + +`pi2` +: {math}`f(n_{AB}, n_{Ab}, n_{aB}, n) = p_A (1-p_A) p_B (1-p_B)` + + Unpolarised, `total` weighted. + +`r` +: {math}`f(n_{AB}, n_{Ab}, n_{aB}, n) = \frac{D}{\sqrt{p_A (1-p_A) p_B (1-p_B)}}`, + + where {math}`D` is defined above. Polarised, `total` weighted. + +`r2` +: {math}`f(n_{AB}, n_{Ab}, n_{aB}, n) = \frac{D^{2}}{p_A (1-p_A) p_B (1-p_B)}`, + + where {math}`D` is defined above. Unpolarised, `haplotype` weighted. + +Unbiased two-locus statistics from the Hill-Robertson (1968) system are +computed from haplotype counts. Definitions of these unbiased estimators can +be found in [Ragsdale and Gravel +(2020)](https://doi.org/10.1093/molbev/msz265). They require at least 4 samples +to be valid and are specified as `stat="D2_unbiased"`, `"Dz_unbiased"`, or +`"pi2_unbiased"`. + +(sec_two_locus_summary_functions_two_way)= + +(sec_stats_two_locus_summary_functions_two_way)= + +##### Two-way + +Two-way statistics are indexed by sample sets {math}`i, j` and compute values +using haplotype counts within pairs of sample sets. + +`D2` +: {math}`f(n_{AB}, n_{Ab}, n_{aB}, n) = D_i D_j`, + + where {math}`D_i` denotes {math}`D` computed within sample set {math}`i`, + and {math}`D` is defined above. Unpolarised, `total` weighted. + +`r2` +: {math}`f(n_{AB}, n_{Ab}, n_{aB}, n) = r_i r_j`, + + where {math}`r_i` denotes {math}`r` computed within sample set {math}`i`, + and {math}`r` is defined above.
Unpolarised, `haplotype` weighted. + +And `D2_unbiased`, which can be found in [Ragsdale and Gravel +(2020)](https://doi.org/10.1093/molbev/msz265). + (sec_stats_notes)= diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 916412cc81..bc9b6217da 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -7,6 +7,12 @@ In development - Add ``json+struct`` metadata codec that allows storing binary data using a struct schema alongside JSON metadata. (:user:`benjeffery`, :pr:`3306`) +**Features** + +- Add ``TreeSequence.ld_matrix`` stats method and documentation, for computing + two-locus statistics in site and branch mode. + (:user:`lkirk`, :user:`apragsdale`, :pr:`3416`) + -------------------- [1.0.2] - 2026-03-06 -------------------- diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 45d2da59e0..a370daf17f 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10930,12 +10930,106 @@ def impute_unknown_mutations_time( def ld_matrix( self, sample_sets=None, - sites=None, - positions=None, mode="site", stat="r2", + sites=None, + positions=None, indexes=None, ): + r""" + + Returns a matrix of the specified two-locus statistic (default + :math:`r^2`) computed from sample allelic states or branch lengths. + The resulting linkage disequilibrium (LD) matrix represents either the + two-locus statistic as computed between all pairs of specified + ``sites`` (``"site"`` mode, producing a + ``len(sites)``-by-``len(sites)`` sized matrix), or as computed from the + branch structures at marginal trees between pairs of trees at all + specified ``positions`` (``"branch"`` mode, producing a + ``len(positions)``-by-``len(positions)`` sized matrix). + + The sites considered for ``"site"`` mode defaults to all sites (which may + result in a very large matrix!), but can be restricted using + the ``sites`` argument. 
Sites must be passed as a list of lists, + specifying the ``[row_sites, col_sites]``, resulting in a + rectangular matrix, or by specifying a single list of ``[sites]``, in + which case a square matrix will be produced (see + :ref:`sec_stats_two_locus_site` for examples). Here, ``sites``, + ``row_sites``, and ``col_sites`` are each lists of site indexes. + + Similarly, in the ``"branch"`` mode, the ``positions`` argument specifies + genomic coordinates at which the expectation for the two-locus statistic + is computed, given the local tree structure. + (See :ref:`sec_stats_two_locus_branch` for an explanation of the + sense in which this is an expectation.) This defaults to computing + the LD for each pair of distinct trees (this is equivalent to passing in + the leftmost coordinates of each tree's span, since intervals are closed on + the left and open on the right). Similar to the site mode, a nested list + of row and column positions can be specified separately (resulting in a + rectangular matrix) or a single list of specified positions results + in a square matrix (see :ref:`sec_stats_two_locus_branch` for + examples). Like ``sites``, the ``positions`` must be specified as a list + of lists. + + Some LD statistics are defined both within a single set of samples + and between two sample sets. If the ``indexes`` argument is specified, then + ``indexes`` specifies the indexes of the sample sets in the + ``sample_sets`` list between which to compute LD. For instance, this + results in a 3D array whose ``[k,:,:]``-th slice contains LD values + between ``sample_sets[i]`` and ``sample_sets[j]``, where ``(i, j)`` is + the ``k``-th element of ``indexes``. + + For more on how the ``indexes`` and ``sample_sets`` interact with the + output dimensions, see the :ref:`sec_stats_two_locus_sample_sets` + section. Statistics are defined in the + :ref:`sec_stats_two_locus_summary_functions_two_way` section. + + **Available Stats** (use ``Stat Name`` in the ``stat`` keyword + argument).
Statistics marked as "multi sample set" allow + (but do not require) computation from two sample sets + via the ``indexes`` argument. + + ======================= ========== ================ ============== + Stat Polarised Multi Sample Set Stat Name + ======================= ========== ================ ============== + :math:`r^2` n y "r2" + :math:`r` y n "r" + :math:`D^2` n y "D2" + :math:`D` y n "D" + :math:`D'` y n "D_prime" + :math:`D_z` n n "Dz" + :math:`\pi_2` n n "pi2" + :math:`\widehat{D^2}` n y "D2_unbiased" + :math:`\widehat{D_z}` n n "Dz_unbiased" + :math:`\widehat{\pi_2}` n n "pi2_unbiased" + ======================= ========== ================ ============== + + :param list sample_sets: A list, or a list of lists of sample node IDs, + specifying the groups of nodes to compute the statistic with. Defaults + to all samples. + :param str mode: A string giving the "type" of the statistic to be + computed. Defaults to "site", can be "site" or "branch". + :param str stat: A string giving the selected two-locus statistic to + compute. Defaults to "r2". + :param list sites: A list of lists of sites over which to compute an + LD matrix. Can be specified as a list of lists to control the row + and column sites. Only available in "site" mode. Specify as + ``[row_sites, col_sites]`` or ``[all_sites]``. + Defaults to all sites. + :param list positions: A list of lists of genomic positions where + expected LD is computed based on tree topologies and branch + lengths. Only applicable in "branch" mode. Specify as a list of + two lists to control the row and column positions, as + ``[row_positions, col_positions]``, or ``[all_positions]``. + Defaults to the leftmost coordinates of all trees and computes + LD between all pairs of trees. + :param list indexes: A list of 2-tuples or a single 2-tuple, specifying + the indexes of two sample sets over which to compute a two-way LD + statistic. 
Only :math:`r^2`, :math:`D^2`, and :math:`\widehat{D^2}` + are implemented for two-way statistics. + :return: A 2D or 3D array of LD matrices. + :rtype: numpy.ndarray + """ one_way_stats = { "D": self._ll_tree_sequence.D_matrix, "D2": self._ll_tree_sequence.D2_matrix, From cc36ef731f8951c4f8560090f82fb96fa0f793e1 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 26 Oct 2025 18:57:49 -0500 Subject: [PATCH 11/51] initial stab at a general matrix (no normalisation) --- c/tskit/trees.c | 66 ++++++++------ c/tskit/trees.h | 7 ++ python/_tskitmodule.c | 202 ++++++++++++++++++++++++++++++++++++++++++ python/tskit/trees.py | 65 +++++++++++--- 4 files changed, 299 insertions(+), 41 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 1aa06e5b03..f7805857ca 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2411,8 +2411,8 @@ static int compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, tsk_size_t num_a_alleles, tsk_size_t num_b_alleles, tsk_size_t state_dim, - tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, - norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result) + tsk_size_t result_dim, general_stat_func_t *f, void *f_params, norm_func_t *norm_f, + bool polarised, two_locus_work_t *restrict work, double *result) { int ret = 0; // Sample sets and b sites are rows, a sites are columns @@ -2463,9 +2463,8 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, static int compute_general_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, - tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double 
*result) { int ret = 0; tsk_size_t k; @@ -2653,9 +2652,8 @@ static int tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows, - const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + void *f_params, norm_func_t *norm_f, tsk_size_t n_rows, const tsk_id_t *row_sites, + tsk_size_t n_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) { int ret = 0; tsk_bitset_t allele_samples, allele_sample_sets; @@ -3089,9 +3087,8 @@ advance_collect_edges(iter_state *s, tsk_id_t index) static int compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim, - tsk_size_t result_dim, int sign, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t result_dim, int sign, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double *result) { int ret = 0; double a_len, b_len; @@ -3141,8 +3138,8 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, static int compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, - iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params, - tsk_size_t result_dim, tsk_size_t state_dim, double *result) + iter_state *r_state, general_stat_func_t *f, void *f_params, tsk_size_t result_dim, + tsk_size_t state_dim, double *result) { int ret = 0; tsk_id_t e, c, ec, p, *updated_nodes = NULL; @@ -3243,9 +3240,9 @@ static int tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, 
tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f), - tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols, - const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result) + void *f_params, norm_func_t *TSK_UNUSED(norm_f), tsk_size_t n_rows, + const double *row_positions, tsk_size_t n_cols, const double *col_positions, + tsk_flags_t TSK_UNUSED(options), double *result) { int ret = 0; int r, c; @@ -3385,10 +3382,10 @@ check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s } int -tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, - norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, +tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { @@ -3398,10 +3395,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); tsk_size_t state_dim = num_sample_sets; - sample_count_stat_params_t f_params = { .sample_sets = sample_sets, - .num_sample_sets = num_sample_sets, - .sample_set_sizes = sample_set_sizes, - .set_indexes = set_indexes }; // We do not support two-locus node stats if (!!(options & TSK_STAT_NODE)) { @@ -3441,7 +3434,7 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = 
tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); } else if (stat_branch) { ret = check_positions( @@ -3455,13 +3448,30 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_positions, out_cols, col_positions, options, result); } out: return ret; } +int +tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, + norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result) +{ + sample_count_stat_params_t f_params = { .sample_sets = sample_sets, + .num_sample_sets = num_sample_sets, + .sample_set_sizes = sample_set_sizes, + .set_indexes = set_indexes }; + return tsk_treeseq_two_locus_count_general_stat(self, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + row_sites, row_positions, out_cols, col_sites, col_positions, options, result); +} + /*********************************** * Allele frequency spectrum ***********************************/ @@ -8697,8 +8707,8 @@ update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, for (k = offsets[b]; k < offsets[b + 1]; k++) { u = A[j]; v = A[k]; - /* Only increment the upper triangle to (hopefully) improve memory - * access patterns */ + /* Only increment 
the upper triangle to (hopefully) improve + * memory access patterns */ if (u > v) { u = A[k]; v = A[j]; diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 84480ed96e..acc15c9aac 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1120,6 +1120,13 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); + typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 0e0c1c5ed5..afde032847 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7946,6 +7946,203 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) return array; } +typedef struct { + PyArrayObject *sample_set_sizes; + PyObject *callable; +} two_locus_general_stat_params; + +static int +general_two_locus_count_stat_func( + tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *params) +{ + int ret = TSK_PYTHON_CALLBACK_ERROR; + two_locus_general_stat_params *tl_params = params; + PyObject *callable = tl_params->callable; + PyArrayObject *sample_set_sizes = tl_params->sample_set_sizes; + PyObject *arglist = NULL; + PyObject *result = NULL; + PyArrayObject *X_array = NULL; + PyArrayObject *Y_array = NULL; + npy_intp X_dims[2] = { K, 3 
}; + // Convert "n" to a column array + PyArray_Dims n_dims = { (npy_intp[2]){ PyArray_DIMS(sample_set_sizes)[0], 1 }, 2 }; + npy_intp *Y_dims; + + // Create a read only view of X as a numpy array + X_array = (PyArrayObject *) PyArray_SimpleNewFromData( + 2, X_dims, NPY_FLOAT64, (void *) X); + if (X_array == NULL) { + goto out; + } + sample_set_sizes + = (PyArrayObject *) PyArray_Newshape(sample_set_sizes, &n_dims, NPY_CORDER); + + PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); + arglist = Py_BuildValue("OO", X_array, sample_set_sizes); + if (arglist == NULL) { + goto out; + } + result = PyObject_CallObject(callable, arglist); + if (result == NULL) { + goto out; + } + Y_array = (PyArrayObject *) PyArray_FromAny( + result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL); + if (Y_array == NULL) { + goto out; + } + if (PyArray_NDIM(Y_array) != 1) { + PyErr_Format(PyExc_ValueError, + "Array returned by general_stat callback is %d dimensional; " + "must be 1D", + (int) PyArray_NDIM(Y_array)); + goto out; + } + Y_dims = PyArray_DIMS(Y_array); + if (Y_dims[0] != (npy_intp) M) { + PyErr_Format(PyExc_ValueError, + "Array returned by general_stat callback is of length %d; " + "must be %d", + Y_dims[0], M); + goto out; + } + /* Copy the contents of the return Y array into Y */ + memcpy(Y, PyArray_DATA(Y_array), M * sizeof(*Y)); + ret = 0; +out: + Py_XDECREF(X_array); + Py_XDECREF(arglist); + Py_XDECREF(result); + Py_XDECREF(Y_array); + return ret; +} + +static PyObject * +TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", + "output_dim", "polarised", "row_sites", "col_sites", "row_positions", + "column_positions", "mode", NULL }; + two_locus_general_stat_params *params; + PyObject *summary_func = NULL; + unsigned int output_dim; + PyObject *sample_set_sizes = NULL; + PyObject *sample_sets = NULL; + PyObject 
*row_sites = NULL; + PyObject *col_sites = NULL; + PyObject *row_positions = NULL; + PyObject *col_positions = NULL; + char *mode = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *row_sites_array = NULL; + PyArrayObject *col_sites_array = NULL; + PyArrayObject *row_positions_array = NULL; + PyArrayObject *col_positions_array = NULL; + PyArrayObject *result_matrix = NULL; + tsk_id_t *row_sites_parsed = NULL; + tsk_id_t *col_sites_parsed = NULL; + double *row_positions_parsed = NULL; + double *col_positions_parsed = NULL; + npy_intp result_dim[3] = { 0, 0, 0 }; + tsk_size_t num_sample_sets; + tsk_flags_t options = 0; + int polarised = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|s", kwlist, + &sample_set_sizes, &sample_sets, &summary_func, &output_dim, &polarised, + &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { + Py_XINCREF(summary_func); + goto out; + } + Py_INCREF(summary_func); + if (!PyCallable_Check(summary_func)) { + PyErr_SetString(PyExc_TypeError, "summary_func must be callable"); + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (polarised) { + options |= TSK_STAT_POLARISED; + } + + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + PyArray_CLEARFLAGS(sample_set_sizes_array, NPY_ARRAY_WRITEABLE); + + if (options & TSK_STAT_SITE) { + if (row_positions != Py_None || col_positions != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify positions in site mode"); + goto out; + } + row_sites_array = parse_sites(self, row_sites, &(result_dim[0])); + col_sites_array = parse_sites(self, col_sites, &(result_dim[1])); + if (row_sites_array == NULL || col_sites_array == NULL) { + goto out; + } + row_sites_parsed = PyArray_DATA(row_sites_array); + 
col_sites_parsed = PyArray_DATA(col_sites_array); + } else if (options & TSK_STAT_BRANCH) { + if (row_sites != Py_None || col_sites != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify sites in branch mode"); + goto out; + } + row_positions_array = parse_positions(self, row_positions, &(result_dim[0])); + col_positions_array = parse_positions(self, col_positions, &(result_dim[1])); + if (col_positions_array == NULL || row_positions_array == NULL) { + goto out; + } + row_positions_parsed = PyArray_DATA(row_positions_array); + col_positions_parsed = PyArray_DATA(col_positions_array); + } + + result_dim[2] = num_sample_sets; + result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); + if (result_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + + params = &(two_locus_general_stat_params){ + .sample_set_sizes = sample_set_sizes_array, + .callable = summary_func, + }; + // TODO: deal with null norm func, need general stat. + err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), + output_dim, general_two_locus_count_stat_func, params, NULL, result_dim[0], + row_sites_parsed, row_positions_parsed, result_dim[1], col_sites_parsed, + col_positions_parsed, options, PyArray_DATA(result_matrix)); + + if (err == TSK_PYTHON_CALLBACK_ERROR) { + goto out; + } else if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_matrix; + result_matrix = NULL; +out: + Py_XDECREF(summary_func); + Py_XDECREF(row_sites_array); + Py_XDECREF(col_sites_array); + Py_XDECREF(row_positions_array); + Py_XDECREF(col_positions_array); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(result_matrix); + return ret; +} + static PyObject * TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, two_locus_count_stat_method *method) @@ -8831,6 +9028,11 @@ static PyMethodDef 
TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_general_stat, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Runs the general stats algorithm for a given summary function." }, + { .ml_name = "two_locus_count_stat", + .ml_meth = (PyCFunction) TreeSequence_two_locus_count_stat, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc + = "Runs the general two locus stats algorithm for a given summary function." }, { .ml_name = "diversity", .ml_meth = (PyCFunction) TreeSequence_diversity, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index a370daf17f..909fb57e93 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8249,19 +8249,7 @@ def parse_positions(self, positions): ) return row_positions, col_positions - def __two_locus_sample_set_stat( - self, - ll_method, - sample_sets, - sites=None, - positions=None, - mode=None, - ): - if sample_sets is None: - sample_sets = self.samples() - row_sites, col_sites = self.parse_sites(sites) - row_positions, col_positions = self.parse_positions(positions) - + def __convert_sample_sets(self, sample_sets): # First try to convert to a 1D numpy array. If we succeed, then we strip off # the corresponding dimension from the output. 
drop_dimension = False @@ -8283,7 +8271,23 @@ def __two_locus_sample_set_stat( raise ValueError("Sample sets must contain at least one element") flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + return drop_dimension, flattened, sample_set_sizes + def __two_locus_sample_set_stat( + self, + ll_method, + sample_sets, + sites=None, + positions=None, + mode=None, + ): + if sample_sets is None: + sample_sets = self.samples() + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( + sample_sets + ) result = ll_method( sample_set_sizes, flattened, @@ -10927,6 +10931,41 @@ def impute_unknown_mutations_time( mutations_time[unknown] = self.nodes_time[self.mutations_node[unknown]] return mutations_time + def two_locus_count_stat( + self, + sample_sets, + f, + result_dim, + polarised=False, + sites=None, + positions=None, + mode="site", + ): + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( + sample_sets + ) + result = self._ll_tree_sequence.two_locus_count_stat( + sample_set_sizes, + sample_sets, + f, + result_dim, + polarised, + row_sites, + col_sites, + row_positions, + col_positions, + mode, + ) + if drop_dimension: + result = result.reshape(result.shape[:2]) + else: + # Orient the data so that the first dimension is the sample set. + # With this orientation, we get one LD matrix per sample set. 
+ result = result.swapaxes(0, 2).swapaxes(1, 2) + return result + def ld_matrix( self, sample_sets=None, From 1c086a279bfd73deed192ca7b66b322e0297dbf7 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 26 Oct 2025 18:59:39 -0500 Subject: [PATCH 12/51] added dimension dropping, but I think transposing is better -- we don't have to add a dimension at the end for scalar operations --- python/_tskitmodule.c | 37 +++++++++++++++++++++++-------------- python/tskit/trees.py | 4 +++- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index afde032847..ed50577ee8 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7947,6 +7947,7 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) } typedef struct { + bool drop_dimensions; PyArrayObject *sample_set_sizes; PyObject *callable; } two_locus_general_stat_params; @@ -7956,29 +7957,33 @@ general_two_locus_count_stat_func( tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *params) { int ret = TSK_PYTHON_CALLBACK_ERROR; - two_locus_general_stat_params *tl_params = params; - PyObject *callable = tl_params->callable; - PyArrayObject *sample_set_sizes = tl_params->sample_set_sizes; PyObject *arglist = NULL; PyObject *result = NULL; PyArrayObject *X_array = NULL; PyArrayObject *Y_array = NULL; - npy_intp X_dims[2] = { K, 3 }; - // Convert "n" to a column array - PyArray_Dims n_dims = { (npy_intp[2]){ PyArray_DIMS(sample_set_sizes)[0], 1 }, 2 }; + two_locus_general_stat_params *tl_params = params; + PyObject *callable = tl_params->callable; + PyArrayObject *ss_sizes = tl_params->sample_set_sizes; + bool drop = (K == 1 && tl_params->drop_dimensions); + // Convert "n" to a column array -- reshape(-1, K) or a scalar if K=1 and drop=True + PyArray_Dims ss_sizes_dims = (drop ? (PyArray_Dims){ (npy_intp[1]){ 1 }, 0 } + : (PyArray_Dims){ (npy_intp[2]){ K, 1 }, 2 }); + int X_ndims = drop ? 1 : 2; + npy_intp *X_dims = drop ? 
(npy_intp[1]){ 3 } : (npy_intp[2]){ K, 3 }; npy_intp *Y_dims; // Create a read only view of X as a numpy array X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - 2, X_dims, NPY_FLOAT64, (void *) X); + X_ndims, X_dims, NPY_FLOAT64, (void *) X); if (X_array == NULL) { goto out; } - sample_set_sizes - = (PyArrayObject *) PyArray_Newshape(sample_set_sizes, &n_dims, NPY_CORDER); - + ss_sizes = (PyArrayObject *) PyArray_Newshape(ss_sizes, &ss_sizes_dims, NPY_CORDER); + if (ss_sizes == NULL) { + goto out; + } PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); - arglist = Py_BuildValue("OO", X_array, sample_set_sizes); + arglist = Py_BuildValue("OO", X_array, ss_sizes); if (arglist == NULL) { goto out; } @@ -8014,6 +8019,7 @@ general_two_locus_count_stat_func( Py_XDECREF(arglist); Py_XDECREF(result); Py_XDECREF(Y_array); + Py_XDECREF(ss_sizes); return ret; } @@ -8023,7 +8029,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", "output_dim", "polarised", "row_sites", "col_sites", "row_positions", - "column_positions", "mode", NULL }; + "column_positions", "mode", "drop_dimensions", NULL }; two_locus_general_stat_params *params; PyObject *summary_func = NULL; unsigned int output_dim; @@ -8048,15 +8054,17 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * npy_intp result_dim[3] = { 0, 0, 0 }; tsk_size_t num_sample_sets; tsk_flags_t options = 0; + int drop_dimensions = 0; int polarised = 0; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|s", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|si", kwlist, &sample_set_sizes, &sample_sets, &summary_func, &output_dim, &polarised, - &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { + &row_sites, &col_sites, &row_positions, &col_positions, &mode, + 
&drop_dimensions)) { Py_XINCREF(summary_func); goto out; } @@ -8115,6 +8123,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * params = &(two_locus_general_stat_params){ .sample_set_sizes = sample_set_sizes_array, .callable = summary_func, + .drop_dimensions = drop_dimensions, }; // TODO: deal with null norm func, need general stat. err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 909fb57e93..655281ef6d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10940,6 +10940,7 @@ def two_locus_count_stat( sites=None, positions=None, mode="site", + drop_dimensions=True, ): row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) @@ -10948,7 +10949,7 @@ def two_locus_count_stat( ) result = self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, - sample_sets, + flattened, f, result_dim, polarised, @@ -10957,6 +10958,7 @@ def two_locus_count_stat( row_positions, col_positions, mode, + drop_dimensions, ) if drop_dimension: result = result.reshape(result.shape[:2]) From 1e30d0d8708b028dc43b0e09f1d6b5b5e7651486 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 12:31:28 -0600 Subject: [PATCH 13/51] finalize and add tests for single and multipop --- c/tskit/trees.c | 4 +- python/_tskitmodule.c | 153 +++++++++++++++----- python/tests/test_ld_matrix.py | 255 ++++++++++++++++++++++++++++++++- python/tskit/trees.py | 33 ++--- 4 files changed, 387 insertions(+), 58 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index f7805857ca..a8f6e168f9 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -8707,8 +8707,8 @@ update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, for (k = offsets[b]; k < offsets[b + 1]; k++) { u = A[j]; v = A[k]; - /* Only increment the upper triangle to (hopefully) improve - * memory access patterns */ + /* Only 
increment the upper triangle to (hopefully) improve memory + * access patterns */ if (u > v) { u = A[k]; v = A[j]; diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index ed50577ee8..c6bca34ca9 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7947,47 +7947,123 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) } typedef struct { - bool drop_dimensions; PyArrayObject *sample_set_sizes; - PyObject *callable; + PyObject *summary_func; + PyObject *norm_func; } two_locus_general_stat_params; static int -general_two_locus_count_stat_func( - tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *params) +general_two_locus_norm_func(tsk_size_t result_dim, const double *X, tsk_size_t n_a, + tsk_size_t n_b, double *Y, void *params) { int ret = TSK_PYTHON_CALLBACK_ERROR; PyObject *arglist = NULL; PyObject *result = NULL; + PyArrayObject *n_a_scalar = NULL; + PyArrayObject *n_b_scalar = NULL; PyArrayObject *X_array = NULL; PyArrayObject *Y_array = NULL; two_locus_general_stat_params *tl_params = params; - PyObject *callable = tl_params->callable; + PyObject *summary_func = tl_params->norm_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - bool drop = (K == 1 && tl_params->drop_dimensions); - // Convert "n" to a column array -- reshape(-1, K) or a scalar if K=1 and drop=True - PyArray_Dims ss_sizes_dims = (drop ? (PyArray_Dims){ (npy_intp[1]){ 1 }, 0 } - : (PyArray_Dims){ (npy_intp[2]){ K, 1 }, 2 }); - int X_ndims = drop ? 1 : 2; - npy_intp *X_dims = drop ? 
(npy_intp[1]){ 3 } : (npy_intp[2]){ K, 3 }; - npy_intp *Y_dims; + npy_intp X_dims[2] = { result_dim, 3 }; // Create a read only view of X as a numpy array X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - X_ndims, X_dims, NPY_FLOAT64, (void *) X); + 2, X_dims, NPY_FLOAT64, (void *) X); if (X_array == NULL) { goto out; } - ss_sizes = (PyArrayObject *) PyArray_Newshape(ss_sizes, &ss_sizes_dims, NPY_CORDER); - if (ss_sizes == NULL) { + PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); + // Transpose into column arrays, so that we can easily decompose the results + X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + if (X_array == NULL) { + goto out; + } + n_a_scalar + = (PyArrayObject *) PyArray_Scalar(&n_a, PyArray_DescrFromType(NPY_INT64), NULL); + if (n_a_scalar == NULL) { + goto out; + } + n_b_scalar + = (PyArrayObject *) PyArray_Scalar(&n_b, PyArray_DescrFromType(NPY_INT64), NULL); + if (n_b_scalar == NULL) { + goto out; + } + arglist = Py_BuildValue("OOOO", X_array, ss_sizes, n_a_scalar, n_b_scalar); + if (arglist == NULL) { + goto out; + } + result = PyObject_CallObject(summary_func, arglist); + if (result == NULL) { + goto out; + } + Y_array = (PyArrayObject *) PyArray_FromAny( + result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL); + if (Y_array == NULL) { + goto out; + } + if (PyArray_NDIM(Y_array) != 1) { + PyErr_Format(PyExc_ValueError, + "Array returned by norm function callback is %d dimensional; " + "must be 1D", + (int) PyArray_NDIM(Y_array)); + goto out; + } + if (PyArray_DIM(Y_array, 0) != (npy_intp) result_dim) { + PyErr_Format(PyExc_ValueError, + "Array returned by norm function callback is of length %d; must be %d", + PyArray_DIM(Y_array, 0), result_dim); + goto out; + } + /* Copy the contents of the return Y array into Y */ + memcpy(Y, PyArray_DATA(Y_array), result_dim * sizeof(*Y)); + ret = 0; +out: + Py_XDECREF(X_array); + Py_XDECREF(arglist); + Py_XDECREF(result); + Py_XDECREF(Y_array); + 
Py_XDECREF(n_a_scalar); + Py_XDECREF(n_b_scalar); + return ret; +} + +static int +general_two_locus_count_stat_func( + tsk_size_t K, const double *X, tsk_size_t result_dim, double *Y, void *params) +{ + int ret = TSK_PYTHON_CALLBACK_ERROR; + PyObject *arglist = NULL; + PyObject *result = NULL; + PyArrayObject *X_array = NULL; + PyArrayObject *Y_array = NULL; + two_locus_general_stat_params *tl_params = params; + PyObject *summary_func = tl_params->summary_func; + PyArrayObject *ss_sizes = tl_params->sample_set_sizes; + npy_intp X_dims[2] = { K, 3 }; + + // Create a read only view of X as a numpy array + X_array = (PyArrayObject *) PyArray_SimpleNewFromData( + 2, X_dims, NPY_FLOAT64, (void *) X); + if (X_array == NULL) { goto out; } PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); + // Transpose into column arrays, so that we can easily decompose the results + // For example: pAB, pAb, paB = X / n + // which works with K>1. In addition, the data is not reordered, meaning + // that the data is still oriented where samples are rows, meaning that + // we'll preserve data locality in ops over samples. 
+ X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + if (X_array == NULL) { + goto out; + } arglist = Py_BuildValue("OO", X_array, ss_sizes); if (arglist == NULL) { goto out; } - result = PyObject_CallObject(callable, arglist); + result = PyObject_CallObject(summary_func, arglist); if (result == NULL) { goto out; } @@ -7998,28 +8074,25 @@ general_two_locus_count_stat_func( } if (PyArray_NDIM(Y_array) != 1) { PyErr_Format(PyExc_ValueError, - "Array returned by general_stat callback is %d dimensional; " + "Array returned by summary function callback is %d dimensional; " "must be 1D", (int) PyArray_NDIM(Y_array)); goto out; } - Y_dims = PyArray_DIMS(Y_array); - if (Y_dims[0] != (npy_intp) M) { + if (PyArray_DIM(Y_array, 0) != (npy_intp) result_dim) { PyErr_Format(PyExc_ValueError, - "Array returned by general_stat callback is of length %d; " - "must be %d", - Y_dims[0], M); + "Array returned by summary function callback is of length %d; must be %d", + PyArray_DIM(Y_array, 0), result_dim); goto out; } /* Copy the contents of the return Y array into Y */ - memcpy(Y, PyArray_DATA(Y_array), M * sizeof(*Y)); + memcpy(Y, PyArray_DATA(Y_array), result_dim * sizeof(*Y)); ret = 0; out: Py_XDECREF(X_array); Py_XDECREF(arglist); Py_XDECREF(result); Py_XDECREF(Y_array); - Py_XDECREF(ss_sizes); return ret; } @@ -8028,10 +8101,11 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * { PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", - "output_dim", "polarised", "row_sites", "col_sites", "row_positions", - "column_positions", "mode", "drop_dimensions", NULL }; + "norm_func", "output_dim", "polarised", "row_sites", "col_sites", + "row_positions", "column_positions", "mode", NULL }; two_locus_general_stat_params *params; PyObject *summary_func = NULL; + PyObject *norm_func = NULL; unsigned int output_dim; PyObject *sample_set_sizes = NULL; PyObject *sample_sets = NULL; @@ -8054,25 +8128,29 @@ 
TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * npy_intp result_dim[3] = { 0, 0, 0 }; tsk_size_t num_sample_sets; tsk_flags_t options = 0; - int drop_dimensions = 0; int polarised = 0; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|si", kwlist, - &sample_set_sizes, &sample_sets, &summary_func, &output_dim, &polarised, - &row_sites, &col_sites, &row_positions, &col_positions, &mode, - &drop_dimensions)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOIiOOOO|s", kwlist, + &sample_set_sizes, &sample_sets, &summary_func, &norm_func, &output_dim, + &polarised, &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { Py_XINCREF(summary_func); + Py_XINCREF(norm_func); goto out; } Py_INCREF(summary_func); + Py_INCREF(norm_func); if (!PyCallable_Check(summary_func)) { PyErr_SetString(PyExc_TypeError, "summary_func must be callable"); goto out; } + if (!PyCallable_Check(norm_func)) { + PyErr_SetString(PyExc_TypeError, "norm_func must be callable"); + goto out; + } if (parse_stats_mode(mode, &options) != 0) { goto out; } @@ -8113,7 +8191,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * col_positions_parsed = PyArray_DATA(col_positions_array); } - result_dim[2] = num_sample_sets; + result_dim[2] = output_dim; result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); if (result_matrix == NULL) { PyErr_NoMemory(); @@ -8122,15 +8200,16 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * params = &(two_locus_general_stat_params){ .sample_set_sizes = sample_set_sizes_array, - .callable = summary_func, - .drop_dimensions = drop_dimensions, + .summary_func = summary_func, + .norm_func = norm_func, }; // TODO: deal with null norm func, need general stat. 
err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), - output_dim, general_two_locus_count_stat_func, params, NULL, result_dim[0], - row_sites_parsed, row_positions_parsed, result_dim[1], col_sites_parsed, - col_positions_parsed, options, PyArray_DATA(result_matrix)); + output_dim, general_two_locus_count_stat_func, params, + general_two_locus_norm_func, result_dim[0], row_sites_parsed, + row_positions_parsed, result_dim[1], col_sites_parsed, col_positions_parsed, + options, PyArray_DATA(result_matrix)); if (err == TSK_PYTHON_CALLBACK_ERROR) { goto out; diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 4d6e47ddcc..e784a9628b 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -22,7 +22,6 @@ """ Test cases for two-locus statistics """ - import contextlib import io from collections.abc import Callable, Generator @@ -2398,3 +2397,257 @@ def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, ex norm_hap_weighted_ij(1, state, max(a) + 1, max(b) + 1, norm[i, j], params) np.testing.assert_allclose((result * norm).sum(), expected) + + +class GeneralStatFuncs: + """ + functions take X, n as parameters where + + X: shape=(3, #ss) + sample sets + count AB [[ ] + count Ab [ ] + count aB [ ]] + + n: shape=(#ss, ) + [ ] + """ + + @staticmethod + def D(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + @staticmethod + def D2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return (pAB - (pA * pB)) ** 2 + + @staticmethod + def r2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): + return D**2 / denom + + @staticmethod + def r(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + 
denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): + return D / np.sqrt(denom) + + @staticmethod + def D_prime(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = np.vstack( + [ + np.min([pA * (1 - pB), (1 - pA) * pB], axis=0), + np.min([pA * pB, (1 - pA) * (1 - pB)], axis=0), + ] + ) + with suppress_overflow_div0_warning(): + return D / denom[(D < 0).astype(int), range(len(D))] + + @staticmethod + def Dz(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + return D * (1 - 2 * pA) * (1 - 2 * pB) + + @staticmethod + def pi2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pA * (1 - pA) * pB * (1 - pB) + + @staticmethod + def D2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((aB**2) * (Ab - 1) * Ab) + + ((ab - 1) * ab * (AB - 1) * AB) + - (aB * Ab * (Ab + (2 * ab * AB) - 1)) + ) + + @staticmethod + def Dz_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + (((AB * ab) - (Ab * aB)) * (aB + ab - AB - Ab) * (Ab + ab - AB - aB)) + - ((AB * ab) * (AB + ab - Ab - aB - 2)) + - ((Ab * aB) * (Ab + aB - AB - ab - 2)) + ) + + @staticmethod + def pi2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((AB + Ab) * (aB + ab) * (AB + aB) * (Ab + ab)) + - ((AB * ab) * (AB + ab + (3 * Ab) + (3 * aB) - 1)) + - ((Ab * aB) * (Ab + aB + (3 * AB) + (3 * ab) - 1)) + ) + + @staticmethod + def r2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = np.prod(pAB - (pA * pB)) + denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB))) + with suppress_overflow_div0_warning(): + return np.expand_dims(D / denom, axis=0) + + @staticmethod + def D2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + return np.expand_dims(np.prod(D), 
axis=0) + + @staticmethod + def D2_ij_unbiased(X, n): + """ + NB: the two sample sets must be disjoint + we have no way for testing equality + """ + AB, Ab, aB = X + ab = n - X.sum(0) + return np.expand_dims( + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / n[0] + / (n[0] - 1) + / n[1] + / (n[1] - 1), + axis=0, + ) + + +@pytest.mark.parametrize( + "ts,stat", + [ + ( + ts := [ + p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" + ][0].values[0], + "D", + ), + (ts, "D2"), + (ts, "r2"), + (ts, "r"), + (ts, "D_prime"), + (ts, "Dz"), + (ts, "pi2"), + (ts, "D2_unbiased"), + (ts, "Dz_unbiased"), + (ts, "pi2_unbiased"), + ], +) +def test_general_two_locus_site_stat(ts, stat): + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) + ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) + np.testing.assert_equal(ldg, ld) + + +@pytest.mark.parametrize( + "ts,stat", + [ + ( + ts := [ + p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" + ][0].values[0], + "r2_ij", + ), + (ts, "D2_ij"), + (ts, "D2_ij_unbiased"), + ], +) +def test_general_two_locus_two_way_site_stat(ts, stat): + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) + ld = ts.ld_matrix( + sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=(0, 1) + ) + np.testing.assert_allclose(ldg, ld) + + +@pytest.mark.parametrize( + "stat", + [ + "D", + "D2", + "r2", + "r", + "D_prime", + "Dz", + "pi2", + "D2_unbiased", + "Dz_unbiased", + "pi2_unbiased", + ], +) +def test_general_one_way_two_locus_stat_multiallelic(stat): + (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + func = getattr(GeneralStatFuncs, stat) + if stat == "r2": + result = ts.two_locus_count_stat( + [ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n + ) + elif stat in {"D", "r", 
"D_prime"}: + result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + else: + # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + result = ts.two_locus_count_stat([ts.samples()], func, 1) + np.testing.assert_allclose(ts.ld_matrix(stat=stat), result) + + +@pytest.mark.parametrize( + "stat", + [ + "r2_ij", + "D2_ij", + "D2_ij_unbiased", + ], +) +def test_general_two_way_two_locus_stat_multiallelic(stat): + (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + func = getattr(GeneralStatFuncs, stat) + if stat == "r2_ij": + result = ts.two_locus_count_stat( + [ts.samples(), ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n + ) + elif stat in {"D", "r", "D_prime"}: + result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + else: + # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) + np.testing.assert_allclose( + ts.ld_matrix( + stat=stat.replace("_ij", ""), + indexes=(0, 1), + sample_sets=[ts.samples(), ts.samples()], + ), + result, + ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 655281ef6d..f380324cce 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -696,7 +696,8 @@ def __init__( options = 0 if sample_counts is not None: warnings.warn( - "The sample_counts option is not supported since 0.2.4 and is ignored", + "The sample_counts option is not supported since 0.2.4 " + "and is ignored", RuntimeWarning, stacklevel=4, ) @@ -6945,7 +6946,7 @@ def to_macs(self): bytes_genotypes[:] = lookup[variant.genotypes] genotypes = bytes_genotypes.tobytes().decode() output.append( - f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t{genotypes}" + f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" f"{genotypes}" ) return "\n".join(output) + "\n" @@ -9391,9 +9392,9 @@ def pca( if time_windows is None: tree_sequence_low, tree_sequence_high 
= None, self else: - assert time_windows[0] < time_windows[1], ( - "The second argument should be larger." - ) + assert ( + time_windows[0] < time_windows[1] + ), "The second argument should be larger." tree_sequence_low, tree_sequence_high = ( self.decapitate(time_windows[0]), self.decapitate(time_windows[1]), @@ -10936,21 +10937,20 @@ def two_locus_count_stat( sample_sets, f, result_dim, + norm_f=lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0), polarised=False, sites=None, positions=None, mode="site", - drop_dimensions=True, ): row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) - drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( - sample_sets - ) + _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) result = self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, - flattened, + sample_sets, f, + norm_f, result_dim, polarised, row_sites, @@ -10958,15 +10958,12 @@ def two_locus_count_stat( row_positions, col_positions, mode, - drop_dimensions, ) - if drop_dimension: - result = result.reshape(result.shape[:2]) - else: - # Orient the data so that the first dimension is the sample set. - # With this orientation, we get one LD matrix per sample set. - result = result.swapaxes(0, 2).swapaxes(1, 2) - return result + if result_dim == 1: # drop dimension + return result.reshape(result.shape[:2]) + # Orient the data so that the first dimension is the sample set so that + # we get one LD matrix per sample set. 
+ return result.swapaxes(0, 2).swapaxes(1, 2) def ld_matrix( self, From 4b948a0d8dcb810ee58493c314d4687902aee7ae Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:16:12 -0600 Subject: [PATCH 14/51] turns out, the general norm function needs to know the state_dims --- c/tskit/trees.c | 17 ++++++++++------- c/tskit/trees.h | 4 ++-- python/_tskitmodule.c | 6 +++--- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index a8f6e168f9..cccf56a8be 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2298,8 +2298,9 @@ get_allele_samples(const tsk_site_t *site, tsk_size_t site_offset, } static int -norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2315,8 +2316,9 @@ norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, } static int -norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted_ij(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2341,8 +2343,9 @@ norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, } static int -norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights), - tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) +norm_total_weighted(tsk_size_t TSK_UNUSED(state_dim), + const 
double *TSK_UNUSED(hap_weights), tsk_size_t result_dim, tsk_size_t n_a, + tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) { tsk_size_t k; double norm = 1 / (double) (n_a * n_b); @@ -2445,7 +2448,7 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, if (ret != 0) { goto out; } - ret = norm_f(result_dim, weights, num_a_alleles - is_polarised, + ret = norm_f(state_dim, weights, result_dim, num_a_alleles - is_polarised, num_b_alleles - is_polarised, norm, f_params); if (ret != 0) { goto out; diff --git a/c/tskit/trees.h b/c/tskit/trees.h index acc15c9aac..2bf1a26cc9 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1036,8 +1036,8 @@ int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const doub tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a, - tsk_size_t n_b, double *result, void *params); +typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, + tsk_size_t result_dim, tsk_size_t n_a, tsk_size_t n_b, double *result, void *params); int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index c6bca34ca9..3ad4229186 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7953,8 +7953,8 @@ typedef struct { } two_locus_general_stat_params; static int -general_two_locus_norm_func(tsk_size_t result_dim, const double *X, tsk_size_t n_a, - tsk_size_t n_b, double *Y, void *params) +general_two_locus_norm_func(tsk_size_t K, const double *X, tsk_size_t result_dim, + tsk_size_t n_a, tsk_size_t n_b, double *Y, void *params) { int ret = TSK_PYTHON_CALLBACK_ERROR; PyObject *arglist = NULL; @@ -7966,7 +7966,7 @@ general_two_locus_norm_func(tsk_size_t result_dim, const double *X, tsk_size_t n 
two_locus_general_stat_params *tl_params = params; PyObject *summary_func = tl_params->norm_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - npy_intp X_dims[2] = { result_dim, 3 }; + npy_intp X_dims[2] = { K, 3 }; // Create a read only view of X as a numpy array X_array = (PyArrayObject *) PyArray_SimpleNewFromData( From 7e8264aacb92ef2f9cef4f6c353aec8159f1d77f Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:18:40 -0600 Subject: [PATCH 15/51] fix up a bit of naming in general test funcs, remove unneeded branch, fix norm func for r2_ij --- python/tests/test_ld_matrix.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index e784a9628b..6ba04ceb41 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2512,18 +2512,17 @@ def r2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - D = np.prod(pAB - (pA * pB)) + D2_ij = np.prod(pAB - (pA * pB)) denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB))) with suppress_overflow_div0_warning(): - return np.expand_dims(D / denom, axis=0) + return np.expand_dims(D2_ij / denom, axis=0) @staticmethod def D2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - D = pAB - (pA * pB) - return np.expand_dims(np.prod(D), axis=0) + return np.expand_dims(np.prod(pAB - (pA * pB)), axis=0) @staticmethod def D2_ij_unbiased(X, n): @@ -2635,11 +2634,8 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": - result = ts.two_locus_count_stat( - [ts.samples(), ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n - ) - elif stat in {"D", "r", "D_prime"}: - result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + norm_f = lambda X, n, nA, nB: np.expand_dims(X[0].sum() / n.sum(), axis=0) + result = 
ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) From 7ba6f745ce0dc1a43f720142b24bf8709fbf20d3 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:21:12 -0600 Subject: [PATCH 16/51] flake8 does not like assigning lambdas to variables --- python/tests/test_ld_matrix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 6ba04ceb41..e230bce601 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2634,7 +2634,8 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": - norm_f = lambda X, n, nA, nB: np.expand_dims(X[0].sum() / n.sum(), axis=0) + def norm_f(X, n, nA, nB): + return np.expand_dims(X[0].sum() / n.sum(), axis=0) result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) From 25d46330c2e59ad77bdb2fa7ce6196be90efb6cc Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:23:53 -0600 Subject: [PATCH 17/51] and black doesn't like that --- python/tests/test_ld_matrix.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index e230bce601..524287e9be 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2634,8 +2634,10 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": + def norm_f(X, n, nA, nB): return np.expand_dims(X[0].sum() / n.sum(), axis=0) + result = 
ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) From 49a42789df5be1d5ea1133fac75c89862e0deba1 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sat, 6 Dec 2025 18:29:12 -0600 Subject: [PATCH 18/51] do not test equality, this was useful on my local machine but is problematic in practice --- python/tests/test_ld_matrix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 524287e9be..b1b2cd29b6 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -22,6 +22,7 @@ """ Test cases for two-locus statistics """ + import contextlib import io from collections.abc import Callable, Generator @@ -2567,7 +2568,7 @@ def test_general_two_locus_site_stat(ts, stat): sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) - np.testing.assert_equal(ldg, ld) + np.testing.assert_allclose(ldg, ld) @pytest.mark.parametrize( From c75e20aa04381061d84019d0b1b50a4f385391e4 Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 9 Mar 2026 16:33:19 -0500 Subject: [PATCH 19/51] lowlevel tests --- python/tests/test_python_c.py | 165 ++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 15f9967f3f..80861f83d5 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -1987,6 +1987,171 @@ def test_ld_matrix_multipop(self, stat_method_name): with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): stat_method(ss_sizes, ss, indexes, col_sites, row_sites, None, None, "node") + def test_two_locus_count_stat(self): + ts = self.get_example_tree_sequence(10) + ss = ts.get_samples() # sample sets + ss_sizes = 
np.array([len(ss)], dtype=np.uint32) + row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) + col_sites = row_sites + row_pos = ts.get_breakpoints()[:-1] + col_pos = row_pos + row_sites_list = list(range(ts.get_num_sites())) + col_sites_list = row_sites_list + row_pos_list = list(map(float, ts.get_breakpoints()[:-1])) + col_pos_list = row_pos_list + + def stat_func(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + def norm_func(X, n, nA, nB): + return np.expand_dims(X[0].sum() / n.sum(), axis=0) + + method = ts.two_locus_count_stat + + site_args = row_sites, col_sites, None, None, "site" + branch_args = None, None, row_pos, col_pos, "branch" + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) + assert a.shape == (10, 10, 1) + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_args) + assert a.shape == (2, 2, 1) + site_list_args = row_sites_list, col_sites_list, None, None, "site" + branch_list_args = None, None, row_pos_list, col_pos_list, "branch" + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args) + assert a.shape == (10, 10, 1) + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) + assert a.shape == (2, 2, 1) + # CPython API errors + with pytest.raises(ValueError, match="Sum of sample_set_sizes"): + bad_ss = np.array([], dtype=np.int32) + method(ss_sizes, bad_ss, stat_func, norm_func, 1, True, *site_args) + with pytest.raises(TypeError, match="cast array data"): + bad_ss = np.array(ts.get_samples(), dtype=np.uint32) + method(ss_sizes, bad_ss, stat_func, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="Unrecognised stats mode"): + bad_args = row_sites, col_sites, None, None, "bla" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_args) + with pytest.raises(TypeError, match="at most"): + method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args, "extraarg") + with pytest.raises(ValueError, 
match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0.1, 0.2, 2.0] + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError): + bad_pos = [{}, 0.1, 0.2, 2.0] + bad_branch_args = None, None, bad_pos, col_pos, "branch" + 
method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0, 3, 2] + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError): + bad_pos = [{}, 0, 3, 2] + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(ValueError, match="Cannot specify positions in site mode"): + bad_site_args = None, None, row_pos, col_pos, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError, match="Cannot specify sites in branch mode"): + bad_branch_args = row_sites, col_sites, None, None, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError, match="summary_func must be callable"): + method(ss_sizes, ss, "uncallable", norm_func, 1, True, *site_args) + with pytest.raises(TypeError, match="norm_func must be callable"): + method(ss_sizes, ss, stat_func, "uncallable", 1, True, *site_args) + with pytest.raises(ValueError, match="summary function.*must be 1D"): + method(ss_sizes, ss, lambda a, b: 1, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="length 2; must be 1"): + method(ss_sizes, ss, lambda a, b: [1, 2], norm_func, 1, True, *site_args) + # TODO: Cannot test without multiallelic sites + # with pytest.raises(ValueError, match="summary function.*must be 1D"): + # method(ss_sizes, ss, stat_func, lambda a, b, c, d: 1, 1, True, *site_args) + # with pytest.raises(ValueError, match="length 2; must be 1"): + # method(ss_sizes, ss, stat_func, lambda a, b, c, d: [1, 2], 1, True, *site_args) + # C API errors + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + 
method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS"): + bad_pos = 
np.array([0.7, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS"): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): From 58f36476ac3203f026169ca25676dc82ea930920 Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 9 Mar 2026 16:47:41 -0500 Subject: [PATCH 20/51] relax diff requirements (macos failure) --- python/tests/test_ld_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index b1b2cd29b6..28c4ad3ff5 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2568,7 +2568,7 @@ def test_general_two_locus_site_stat(ts, stat): sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) - np.testing.assert_allclose(ldg, ld) + np.testing.assert_array_almost_equal(ldg, ld) @pytest.mark.parametrize( From 3f5bbbb3235c427fab289538266843e63de1ef3a Mon Sep 17 00:00:00 2001 
From: lkirk Date: Mon, 9 Mar 2026 17:08:05 -0500 Subject: [PATCH 21/51] relax diff requirements (macos failure) -- previous commit fixed one --- python/tests/test_ld_matrix.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 28c4ad3ff5..953f542f30 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2590,7 +2590,7 @@ def test_general_two_locus_two_way_site_stat(ts, stat): ld = ts.ld_matrix( sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=(0, 1) ) - np.testing.assert_allclose(ldg, ld) + np.testing.assert_array_almost_equal(ldg, ld) @pytest.mark.parametrize( @@ -2620,7 +2620,7 @@ def test_general_one_way_two_locus_stat_multiallelic(stat): else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) result = ts.two_locus_count_stat([ts.samples()], func, 1) - np.testing.assert_allclose(ts.ld_matrix(stat=stat), result) + np.testing.assert_array_almost_equal(ts.ld_matrix(stat=stat), result) @pytest.mark.parametrize( @@ -2643,7 +2643,7 @@ def norm_f(X, n, nA, nB): else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) - np.testing.assert_allclose( + np.testing.assert_array_almost_equal( ts.ld_matrix( stat=stat.replace("_ij", ""), indexes=(0, 1), From c89341d09338b7bbec64421a9b43a412439b4ad0 Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 9 Mar 2026 17:09:14 -0500 Subject: [PATCH 22/51] new formatting tools, fix lint --- python/_tskitmodule.c | 2 +- python/tests/test_python_c.py | 4 +++- python/tskit/trees.py | 11 +++++------ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 3ad4229186..09f29e6de3 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -8198,7 +8198,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject 
*args, PyObject * goto out; } - params = &(two_locus_general_stat_params){ + params = &(two_locus_general_stat_params) { .sample_set_sizes = sample_set_sizes_array, .summary_func = summary_func, .norm_func = norm_func, diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 80861f83d5..7181964509 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2101,7 +2101,9 @@ def norm_func(X, n, nA, nB): # with pytest.raises(ValueError, match="summary function.*must be 1D"): # method(ss_sizes, ss, stat_func, lambda a, b, c, d: 1, 1, True, *site_args) # with pytest.raises(ValueError, match="length 2; must be 1"): - # method(ss_sizes, ss, stat_func, lambda a, b, c, d: [1, 2], 1, True, *site_args) + # method( + # ss_sizes, ss, stat_func, lambda a, b, c, d: [1, 2], 1, True, *site_args + # ) # C API errors with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 2], dtype=np.int32) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index f380324cce..7e9f20df5b 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -696,8 +696,7 @@ def __init__( options = 0 if sample_counts is not None: warnings.warn( - "The sample_counts option is not supported since 0.2.4 " - "and is ignored", + "The sample_counts option is not supported since 0.2.4 and is ignored", RuntimeWarning, stacklevel=4, ) @@ -6946,7 +6945,7 @@ def to_macs(self): bytes_genotypes[:] = lookup[variant.genotypes] genotypes = bytes_genotypes.tobytes().decode() output.append( - f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" f"{genotypes}" + f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t{genotypes}" ) return "\n".join(output) + "\n" @@ -9392,9 +9391,9 @@ def pca( if time_windows is None: tree_sequence_low, tree_sequence_high = None, self else: - assert ( - time_windows[0] < time_windows[1] - ), "The second argument should be larger." 
+ assert time_windows[0] < time_windows[1], ( + "The second argument should be larger." + ) tree_sequence_low, tree_sequence_high = ( self.decapitate(time_windows[0]), self.decapitate(time_windows[1]), From 23eed87cedc89a3828e30139e2eb94c9a217dc3c Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 10 Mar 2026 10:16:26 -0500 Subject: [PATCH 23/51] remove TODOs, old comment and tested elsewhere --- python/_tskitmodule.c | 1 - python/tests/test_python_c.py | 7 ------- 2 files changed, 8 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 09f29e6de3..c772270de6 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -8203,7 +8203,6 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * .summary_func = summary_func, .norm_func = norm_func, }; - // TODO: deal with null norm func, need general stat. err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), output_dim, general_two_locus_count_stat_func, params, diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 7181964509..625d8f9bcb 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2097,13 +2097,6 @@ def norm_func(X, n, nA, nB): method(ss_sizes, ss, lambda a, b: 1, norm_func, 1, True, *site_args) with pytest.raises(ValueError, match="length 2; must be 1"): method(ss_sizes, ss, lambda a, b: [1, 2], norm_func, 1, True, *site_args) - # TODO: Cannot test without multiallelic sites - # with pytest.raises(ValueError, match="summary function.*must be 1D"): - # method(ss_sizes, ss, stat_func, lambda a, b, c, d: 1, 1, True, *site_args) - # with pytest.raises(ValueError, match="length 2; must be 1"): - # method( - # ss_sizes, ss, stat_func, lambda a, b, c, d: [1, 2], 1, True, *site_args - # ) # C API errors with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 
2], dtype=np.int32) From b412eceeabd5da02d278d24e1d06a21f9c2eddc7 Mon Sep 17 00:00:00 2001 From: peter Date: Sun, 15 Mar 2026 08:05:10 -0700 Subject: [PATCH 24/51] make testing more clear --- python/tests/test_ld_matrix.py | 84 ++++++++++++++++++---------------- python/tests/tsutil.py | 42 +++++++++++------ 2 files changed, 73 insertions(+), 53 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 953f542f30..2160a360c0 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2063,38 +2063,36 @@ def compute_branch_stat( ts for ts in get_example_tree_sequences() if ts.id - not in { - "no_samples", - "empty_ts", - # We must skip these cases so that tests run in a reasonable - # amount of time. To get more complete testing, these filters - # can be commented out. (runtime ~1hr) - "gap_0", - "gap_0.1", - "gap_0.5", - "gap_0.75", - "n=2_m=32_rho=0", - "n=10_m=1_rho=0", - "n=10_m=1_rho=0.1", - "n=10_m=2_rho=0", - "n=10_m=2_rho=0.1", - "n=10_m=32_rho=0", - "n=10_m=32_rho=0.1", - "n=10_m=32_rho=0.5", + in { + # We run only these cases so that tests run in a reasonable + # amount of time. All examples takes ~1hr. 
+ "decapitate_recomb", + "gap_at_end", + "all_nodes_samples", + "internal_nodes_samples", + "mixed_internal_leaf_samples", + "bottleneck_n=3_mutated", + "bottleneck_n=10_mutated", + "rev_node_order", + "empty_tree", + "n=3_m=2_rho=0.5", + "n=3_m=32_rho=0", + "n=3_m=32_rho=0.1", + "n=2_m=1_rho=0", + "n=2_m=1_rho=0.1", + "n=2_m=1_rho=0.5", + "n=2_m=2_rho=0", + "n=2_m=2_rho=0.1", + "n=2_m=2_rho=0.5", + "n=2_m=32_rho=0.1", + "n=2_m=32_rho=0.5", + "n=3_m=1_rho=0", + "n=3_m=1_rho=0.5", + "n=3_m=2_rho=0", + "n=10_m=1_rho=0.5", + "n=10_m=2_rho=0.5", # we keep one n=100 case to ensure bit arrays are working - "n=100_m=1_rho=0.1", - "n=100_m=1_rho=0.5", - "n=100_m=2_rho=0", - "n=100_m=2_rho=0.1", - "n=100_m=2_rho=0.5", - "n=100_m=32_rho=0", - "n=100_m=32_rho=0.1", - "n=100_m=32_rho=0.5", - "all_fields", - "back_mutations", - "multichar", - "multichar_no_metadata", - "bottleneck_n=100_mutated", + "n=100_m=1_rho=0", } ], ) @@ -2548,9 +2546,13 @@ def D2_ij_unbiased(X, n): "ts,stat", [ ( - ts := [ - p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" - ][0].values[0], + ts := tsutil.get_sim_example( + sample_size=100, + sequence_length=32, + recombination_rate=0.5, + mutation_rate=0.1, + seed=123, + ), "D", ), (ts, "D2"), @@ -2575,9 +2577,13 @@ def test_general_two_locus_site_stat(ts, stat): "ts,stat", [ ( - ts := [ - p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" - ][0].values[0], + ts := tsutil.get_sim_example( + sample_size=100, + sequence_length=32, + recombination_rate=0.5, + mutation_rate=0.1, + seed=123, + ), "r2_ij", ), (ts, "D2_ij"), @@ -2609,7 +2615,7 @@ def test_general_two_locus_two_way_site_stat(ts, stat): ], ) def test_general_one_way_two_locus_stat_multiallelic(stat): - (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + ts = tsutil.all_fields_ts() func = getattr(GeneralStatFuncs, stat) if stat == "r2": result = ts.two_locus_count_stat( @@ -2632,7 +2638,7 @@ def 
test_general_one_way_two_locus_stat_multiallelic(stat): ], ) def test_general_two_way_two_locus_stat_multiallelic(stat): - (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + ts = tsutil.all_fields_ts() func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 0037e06391..48ff72d0d7 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -2463,6 +2463,28 @@ def get_back_mutation_examples(): yield insert_branch_mutations(ts) +@functools.lru_cache +def get_sim_example( + sample_size, sequence_length, recombination_rate, mutation_rate, seed +): + recomb_map = msprime.RecombinationMap.uniform_map( + sequence_length, recombination_rate + ) + ts = msprime.simulate( + recombination_map=recomb_map, + mutation_rate=mutation_rate, + random_seed=seed, + population_configurations=[ + msprime.PopulationConfiguration(sample_size), + msprime.PopulationConfiguration(0), + ], + migration_matrix=[[0, 1], [1, 0]], + ) + ts = insert_random_ploidy_individuals(ts, 4, seed=seed) + ts = add_random_metadata(ts, seed=seed) + return ts + + def make_example_tree_sequences(custom_max=None): yield from get_decapitated_examples(custom_max=custom_max) yield from get_gap_examples(custom_max=custom_max) @@ -2475,22 +2497,14 @@ def make_example_tree_sequences(custom_max=None): for n in n_list: for m in [1, 2, 32]: for rho in [0, 0.1, 0.5]: - recomb_map = msprime.RecombinationMap.uniform_map(m, rho, num_loci=m) - ts = msprime.simulate( - recombination_map=recomb_map, + ts = get_sim_example( + sample_size=n, + sequence_length=m, + recombination_rate=rho, mutation_rate=0.1, - random_seed=seed, - population_configurations=[ - msprime.PopulationConfiguration(n), - msprime.PopulationConfiguration(0), - ], - migration_matrix=[[0, 1], [1, 0]], - ) - ts = insert_random_ploidy_individuals(ts, 4, seed=seed) - yield ( - f"n={n}_m={m}_rho={rho}", - add_random_metadata(ts, seed=seed), + seed=seed, ) + yield 
(f"n={n}_m={m}_rho={rho}", ts) seed += 1 for name, ts in get_bottleneck_examples(custom_max=custom_max): yield ( From 3bc699df8801603fdd2f0a85062155002719d3a4 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 15 Mar 2026 21:45:28 -0500 Subject: [PATCH 25/51] preserve native dimensions instead of expanding at the end --- python/tests/test_ld_matrix.py | 43 ++++++++++++++++------------------ python/tskit/trees.py | 2 +- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 2160a360c0..e7133af0b9 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2511,34 +2511,30 @@ def r2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - D2_ij = np.prod(pAB - (pA * pB)) - denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB))) + D2_ij = np.prod(pAB - (pA * pB), keepdims=True) + denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB)), keepdims=True) with suppress_overflow_div0_warning(): - return np.expand_dims(D2_ij / denom, axis=0) + return D2_ij / denom @staticmethod def D2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - return np.expand_dims(np.prod(pAB - (pA * pB)), axis=0) + return np.prod(pAB - (pA * pB), keepdims=True) @staticmethod def D2_ij_unbiased(X, n): - """ - NB: the two sample sets must be disjoint - we have no way for testing equality - """ + """NB: We use double brackets here to preserve the output shape of (1,)""" AB, Ab, aB = X ab = n - X.sum(0) - return np.expand_dims( - (Ab[0] * aB[0] - AB[0] * ab[0]) - * (Ab[1] * aB[1] - AB[1] * ab[1]) - / n[0] - / (n[0] - 1) - / n[1] - / (n[1] - 1), - axis=0, + return ( + (Ab[[0]] * aB[[0]] - AB[[0]] * ab[[0]]) + * (Ab[[1]] * aB[[1]] - AB[[1]] * ab[[1]]) + / n[[0]] + / (n[[0]] - 1) + / n[[1]] + / (n[[1]] - 1) ) @@ -2624,7 +2620,7 @@ def test_general_one_way_two_locus_stat_multiallelic(stat): elif stat in {"D", "r", "D_prime"}: result = ts.two_locus_count_stat([ts.samples()], func, 1, 
polarised=True) else: - # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + # default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]` result = ts.two_locus_count_stat([ts.samples()], func, 1) np.testing.assert_array_almost_equal(ts.ld_matrix(stat=stat), result) @@ -2641,13 +2637,14 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): ts = tsutil.all_fields_ts() func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": - - def norm_f(X, n, nA, nB): - return np.expand_dims(X[0].sum() / n.sum(), axis=0) - - result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) + result = ts.two_locus_count_stat( + [ts.samples(), ts.samples()], + func, + 1, + lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum(), + ) else: - # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + # default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]` result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) np.testing.assert_array_almost_equal( ts.ld_matrix( diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 7e9f20df5b..09b6224c55 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10936,7 +10936,7 @@ def two_locus_count_stat( sample_sets, f, result_dim, - norm_f=lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0), + norm_f=lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,], polarised=False, sites=None, positions=None, From f172c3ed0d5b0f7ba51b39397dfbc27a366a786d Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 16 Mar 2026 17:33:00 -0500 Subject: [PATCH 26/51] Update tests according to Peters's feedback *Python tests* Overhaul python testing of the general stat functions. Remove the dependence on the example tree sequences, opting instead to simulate a couple of examples directly. Use these simulated trees in test fixtures, scoped at the module level. This streamlines the test parameterization a lot. 
Use the single stat site names from the summary function definitions. *CPython tests* Add a multiallelic tree sequence to test normalisation function validation and errors. Remove one more occurrence of `np.expand_dims`. *trees.c* Remove the unnecessary branch in tsk_treeseq_two_locus_count_general_stat, improving the code coverage. *trees.py* Default normalisation function can be None, applying default at runtime. Simplifies calling code and is more in line with the rest of the API. --- c/tskit/trees.c | 29 +++--- python/tests/test_ld_matrix.py | 163 +++++++++++++-------------------- python/tests/test_python_c.py | 93 +++++++++++++++++-- python/tskit/trees.py | 4 +- 4 files changed, 165 insertions(+), 124 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index cccf56a8be..0f8a10c182 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3439,21 +3439,22 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); - } else if (stat_branch) { - ret = check_positions( - row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); - if (ret != 0) { - goto out; - } - ret = check_positions( - col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); - if (ret != 0) { - goto out; - } - ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, - row_positions, out_cols, col_positions, options, result); + goto out; + } + tsk_bug_assert(stat_branch); + ret = check_positions( + row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = check_positions( + col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; } + ret = tsk_treeseq_two_branch_count_stat(self, state_dim, 
num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, + row_positions, out_cols, col_positions, options, result); out: return ret; } diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index e7133af0b9..bb459ee7ca 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2538,55 +2538,49 @@ def D2_ij_unbiased(X, n): ) -@pytest.mark.parametrize( - "ts,stat", - [ - ( - ts := tsutil.get_sim_example( - sample_size=100, - sequence_length=32, - recombination_rate=0.5, - mutation_rate=0.1, - seed=123, - ), - "D", +@pytest.fixture(scope="module") +def ts_100_samp_with_sites_fixture(): + ts = tsutil.get_sim_example( + sample_size=100, + sequence_length=32, + recombination_rate=0.5, + mutation_rate=0.1, + seed=123, + ) + assert ts.num_sites > 0, "sites are required" + assert ts.num_samples == 100, "100 samples are required" + return ts + + +@pytest.fixture(scope="module") +def ts_multiallelic_fixture(): + ts = msprime.sim_mutations( + msprime.sim_ancestry( + 2, recombination_rate=0.1, sequence_length=100, random_seed=123 ), - (ts, "D2"), - (ts, "r2"), - (ts, "r"), - (ts, "D_prime"), - (ts, "Dz"), - (ts, "pi2"), - (ts, "D2_unbiased"), - (ts, "Dz_unbiased"), - (ts, "pi2_unbiased"), - ], -) -def test_general_two_locus_site_stat(ts, stat): + rate=0.1, + random_seed=123, + ) + # Need at least 4 samples to test unbiased statistics + assert ts.num_samples >= 4, "At least 4 samples required" + assert max({len(s.mutations) for s in ts.sites()}) > 2, ( + "At least one multiallelic site required" + ) + return ts + + +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): + ts = ts_100_samp_with_sites_fixture sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) 
np.testing.assert_array_almost_equal(ldg, ld) -@pytest.mark.parametrize( - "ts,stat", - [ - ( - ts := tsutil.get_sim_example( - sample_size=100, - sequence_length=32, - recombination_rate=0.5, - mutation_rate=0.1, - seed=123, - ), - "r2_ij", - ), - (ts, "D2_ij"), - (ts, "D2_ij_unbiased"), - ], -) -def test_general_two_locus_two_way_site_stat(ts, stat): +@pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) +def test_general_two_locus_two_way_site_stat(stat, ts_100_samp_with_sites_fixture): + ts = ts_100_samp_with_sites_fixture sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) ld = ts.ld_matrix( @@ -2595,62 +2589,31 @@ def test_general_two_locus_two_way_site_stat(ts, stat): np.testing.assert_array_almost_equal(ldg, ld) -@pytest.mark.parametrize( - "stat", - [ - "D", - "D2", - "r2", - "r", - "D_prime", - "Dz", - "pi2", - "D2_unbiased", - "Dz_unbiased", - "pi2_unbiased", - ], -) -def test_general_one_way_two_locus_stat_multiallelic(stat): - ts = tsutil.all_fields_ts() - func = getattr(GeneralStatFuncs, stat) - if stat == "r2": - result = ts.two_locus_count_stat( - [ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n - ) - elif stat in {"D", "r", "D_prime"}: - result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) - else: - # default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]` - result = ts.two_locus_count_stat([ts.samples()], func, 1) - np.testing.assert_array_almost_equal(ts.ld_matrix(stat=stat), result) - - -@pytest.mark.parametrize( - "stat", - [ - "r2_ij", - "D2_ij", - "D2_ij_unbiased", - ], -) -def test_general_two_way_two_locus_stat_multiallelic(stat): - ts = tsutil.all_fields_ts() - func = getattr(GeneralStatFuncs, stat) - if stat == "r2_ij": - result = ts.two_locus_count_stat( - [ts.samples(), ts.samples()], - func, - 1, - lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum(), - ) - else: - # default norm func is 
`lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]` - result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) - np.testing.assert_array_almost_equal( - ts.ld_matrix( - stat=stat.replace("_ij", ""), - indexes=(0, 1), - sample_sets=[ts.samples(), ts.samples()], - ), - result, +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None + polarised = POLARIZATION[SUMMARY_FUNCS[stat]] + ldg = ts.two_locus_count_stat( + [ts.samples()], general_func, 1, norm_f=norm_func, polarised=polarised + ) + ld = ts.ld_matrix(stat=stat) + np.testing.assert_array_almost_equal(ld, ldg) + + +@pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) +def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_func = ( + (lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum()) + if stat == "r2_ij" + else None + ) + sample_sets = [ts.samples(), ts.samples()] + ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_func) + ld = ts.ld_matrix( + stat=stat.replace("_ij", ""), indexes=(0, 1), sample_sets=sample_sets ) + np.testing.assert_array_almost_equal(ld, ldg) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 625d8f9bcb..d9a5f79be7 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -138,6 +138,23 @@ def get_example_migration_tree_sequence(self): ) return ts.ll_tree_sequence + def get_example_tree_sequence_multiallelic(self, sample_size=10): + ts = msprime.sim_mutations( + msprime.sim_ancestry( + sample_size, + recombination_rate=0.1, + sequence_length=100, + ploidy=1, + random_seed=123, + ), + rate=0.1, + random_seed=123, + ) + assert 
max({len(s.mutations) for s in ts.sites()}) > 2, ( + "At least one multiallelic site required" + ) + return ts.ll_tree_sequence + def verify_iterator(self, iterator): """ Checks that the specified non-empty iterator implements the @@ -1989,6 +2006,12 @@ def test_ld_matrix_multipop(self, stat_method_name): def test_two_locus_count_stat(self): ts = self.get_example_tree_sequence(10) + # Multiallelic test case to test norm function + ts_multi = self.get_example_tree_sequence_multiallelic() + assert (ts.get_samples() == ts_multi.get_samples()).all(), ( + "biallelic and multiallelic test case are expected " + "to have the same sample nodes" + ) ss = ts.get_samples() # sample sets ss_sizes = np.array([len(ss)], dtype=np.uint32) row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) @@ -2007,10 +2030,9 @@ def stat_func(X, n): return pAB - (pA * pB) def norm_func(X, n, nA, nB): - return np.expand_dims(X[0].sum() / n.sum(), axis=0) - - method = ts.two_locus_count_stat + return X[0].sum(keepdims=True) / n.sum() + method = ts.two_locus_count_stat # most tests on biallelic site_args = row_sites, col_sites, None, None, "site" branch_args = None, None, row_pos, col_pos, "branch" a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) @@ -2019,10 +2041,20 @@ def norm_func(X, n, nA, nB): assert a.shape == (2, 2, 1) site_list_args = row_sites_list, col_sites_list, None, None, "site" branch_list_args = None, None, row_pos_list, col_pos_list, "branch" + + # happy path a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args) - assert a.shape == (10, 10, 1) + assert a.shape == (10, 10, 1) # ts has 10 sites a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) - assert a.shape == (2, 2, 1) + assert a.shape == (2, 2, 1) # ts has 2 trees + a = ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site" + ) + assert a.shape == (56, 56, 1) # ts has 56 sites + a = 
ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "branch" + ) + assert a.shape == (48, 48, 1) # ts has 48 trees # CPython API errors with pytest.raises(ValueError, match="Sum of sample_set_sizes"): bad_ss = np.array([], dtype=np.int32) @@ -2094,10 +2126,55 @@ def norm_func(X, n, nA, nB): with pytest.raises(TypeError, match="norm_func must be callable"): method(ss_sizes, ss, stat_func, "uncallable", 1, True, *site_args) with pytest.raises(ValueError, match="summary function.*must be 1D"): - method(ss_sizes, ss, lambda a, b: 1, norm_func, 1, True, *site_args) - with pytest.raises(ValueError, match="length 2; must be 1"): - method(ss_sizes, ss, lambda a, b: [1, 2], norm_func, 1, True, *site_args) + method(ss_sizes, ss, lambda *_: 1, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="summary function.*length 2; must be 1"): + method(ss_sizes, ss, lambda *_: [1, 2], norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="could not convert string to float"): + method(ss_sizes, ss, lambda *_: ["nonfloat"], norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="norm function.*must be 1D"): + ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args + ) + with pytest.raises( + TypeError, match="takes 1 positional argument but 2 were given" + ): + ts_multi.two_locus_count_stat( + ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args + ) + with pytest.raises(ValueError, match="norm function.*length 2; must be 1"): + ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args + ) + with pytest.raises( + TypeError, match="takes 1 positional argument but 4 were given" + ): + ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args + ) + with pytest.raises(ValueError, match="could not convert string to float"): + ts_multi.two_locus_count_stat( + ss_sizes, ss, 
stat_func, lambda *_: ["nonfloat"], 1, True, *site_args + ) + # Exceptions within stat_func and norm_func are correctly raised. + for exception in [ValueError, TypeError]: + + def stat_func_except(*_): + raise exception("test") + + def norm_func_except(*_): + raise exception("test") + + with pytest.raises(exception, match="test"): + method( + ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_list_args + ) + with pytest.raises(exception, match="test"): + ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_list_args + ) # C API errors + with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_RESULT_DIMS"): + method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_list_args) with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 2], dtype=np.int32) bad_site_args = bad_sites, col_sites, None, None, "site" diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 09b6224c55..31e309f56e 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10936,7 +10936,7 @@ def two_locus_count_stat( sample_sets, f, result_dim, - norm_f=lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,], + norm_f=None, polarised=False, sites=None, positions=None, @@ -10949,7 +10949,7 @@ def two_locus_count_stat( sample_set_sizes, sample_sets, f, - norm_f, + norm_f or (lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]), result_dim, polarised, row_sites, From eaa6c111ce68b47adff8ed4cb2e29e228fcb8ed6 Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 16 Mar 2026 17:52:04 -0500 Subject: [PATCH 27/51] msprime produces different trees on macos (same seed) --- python/tests/test_python_c.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index d9a5f79be7..ccd73ac3d4 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2050,7 +2050,12 @@ def norm_func(X, n, nA, nB): a = 
ts_multi.two_locus_count_stat( ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site" ) - assert a.shape == (56, 56, 1) # ts has 56 sites + import platform + + if platform.system() == "Darwin": + assert a.shape == (54, 54, 1) # ts has 54 sites on macos? + else: + assert a.shape == (56, 56, 1) # ts has 56 sites a = ts_multi.two_locus_count_stat( ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "branch" ) @@ -2159,10 +2164,10 @@ def norm_func(X, n, nA, nB): for exception in [ValueError, TypeError]: def stat_func_except(*_): - raise exception("test") + raise exception("test") # noqa: B023 def norm_func_except(*_): - raise exception("test") + raise exception("test") # noqa: B023 with pytest.raises(exception, match="test"): method( From 478681b62286c21a37cde449440c9f1d97ab73aa Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 13:07:12 -0500 Subject: [PATCH 28/51] Clean up python C tests Use the number of sites and trees reported by the tree sequence instead of hard coded values. This has the benefit of being more readable, communicating intent (review comment from Peter). Split the multiallelic and biallelic test cases, they're getting messy. Now I can explicitly assert that the norm_func is not run for biallelic sites and for branch stats. Also gets rid of awkward assertions about sample sets. 
--- python/tests/test_python_c.py | 128 ++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 60 deletions(-) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index ccd73ac3d4..9f8822b2f6 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2005,13 +2005,8 @@ def test_ld_matrix_multipop(self, stat_method_name): stat_method(ss_sizes, ss, indexes, col_sites, row_sites, None, None, "node") def test_two_locus_count_stat(self): + """Test two_locus_count_stat on biallelic data (no norm function)""" ts = self.get_example_tree_sequence(10) - # Multiallelic test case to test norm function - ts_multi = self.get_example_tree_sequence_multiallelic() - assert (ts.get_samples() == ts_multi.get_samples()).all(), ( - "biallelic and multiallelic test case are expected " - "to have the same sample nodes" - ) ss = ts.get_samples() # sample sets ss_sizes = np.array([len(ss)], dtype=np.uint32) row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) @@ -2029,37 +2024,33 @@ def stat_func(X, n): pB = paB + pAB return pAB - (pA * pB) - def norm_func(X, n, nA, nB): - return X[0].sum(keepdims=True) / n.sum() + def norm_func(*_): + raise Exception # norm function will not be used - method = ts.two_locus_count_stat # most tests on biallelic + method = ts.two_locus_count_stat site_args = row_sites, col_sites, None, None, "site" branch_args = None, None, row_pos, col_pos, "branch" + # happy path a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) - assert a.shape == (10, 10, 1) + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_args) - assert a.shape == (2, 2, 1) + assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) + # happy path - sample sets as lists are also valid site_list_args = row_sites_list, col_sites_list, None, None, "site" branch_list_args = None, None, row_pos_list, col_pos_list, "branch" - - # happy path a = 
method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args) - assert a.shape == (10, 10, 1) # ts has 10 sites + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) - assert a.shape == (2, 2, 1) # ts has 2 trees - a = ts_multi.two_locus_count_stat( + assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) + # happy path - default array filling + a = method( ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site" ) - import platform - - if platform.system() == "Darwin": - assert a.shape == (54, 54, 1) # ts has 54 sites on macos? - else: - assert a.shape == (56, 56, 1) # ts has 56 sites - a = ts_multi.two_locus_count_stat( + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) + a = method( ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "branch" ) - assert a.shape == (48, 48, 1) # ts has 48 trees + assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) # CPython API errors with pytest.raises(ValueError, match="Sum of sample_set_sizes"): bad_ss = np.array([], dtype=np.int32) @@ -2136,50 +2127,17 @@ def norm_func(X, n, nA, nB): method(ss_sizes, ss, lambda *_: [1, 2], norm_func, 1, True, *site_args) with pytest.raises(ValueError, match="could not convert string to float"): method(ss_sizes, ss, lambda *_: ["nonfloat"], norm_func, 1, True, *site_args) - with pytest.raises(ValueError, match="norm function.*must be 1D"): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args - ) - with pytest.raises( - TypeError, match="takes 1 positional argument but 2 were given" - ): - ts_multi.two_locus_count_stat( - ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args - ) - with pytest.raises(ValueError, match="norm function.*length 2; must be 1"): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args - ) - with pytest.raises( - TypeError, 
match="takes 1 positional argument but 4 were given" - ): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args - ) - with pytest.raises(ValueError, match="could not convert string to float"): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, lambda *_: ["nonfloat"], 1, True, *site_args - ) - # Exceptions within stat_func and norm_func are correctly raised. + # Exceptions within stat_func are correctly raised. for exception in [ValueError, TypeError]: def stat_func_except(*_): raise exception("test") # noqa: B023 - def norm_func_except(*_): - raise exception("test") # noqa: B023 - - with pytest.raises(exception, match="test"): - method( - ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_list_args - ) with pytest.raises(exception, match="test"): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_list_args - ) + method(ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_args) # C API errors with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_RESULT_DIMS"): - method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_list_args) + method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_args) with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 2], dtype=np.int32) bad_site_args = bad_sites, col_sites, None, None, "site" @@ -2229,6 +2187,56 @@ def norm_func_except(*_): bad_branch_args = None, None, row_pos, bad_pos, "branch" method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + def test_two_locus_count_stat_multialleliic(self): + """ + Test two_locus_count_stat on multiallelic sites to test the behavior of + the norm function. 
+ """ + ts = self.get_example_tree_sequence_multiallelic() + + def stat_func(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + def norm_func(X, n, nA, nB): + return X[0].sum(keepdims=True) / n.sum() + + ss = ts.get_samples() # sample sets + ss_sizes = np.array([len(ss)], dtype=np.uint32) + row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) + col_sites = row_sites + method = ts.two_locus_count_stat + site_args = row_sites, col_sites, None, None, "site" + + # happy path + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) + # CPython API errors + with pytest.raises(ValueError, match="norm function.*must be 1D"): + method(ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args) + with pytest.raises( + TypeError, match="takes 1 positional argument but 2 were given" + ): + method(ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="norm function.*length 2; must be 1"): + method(ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args) + with pytest.raises( + TypeError, match="takes 1 positional argument but 4 were given" + ): + method(ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args) + with pytest.raises(ValueError, match="could not convert string to float"): + method(ss_sizes, ss, stat_func, lambda *_: ["nonfloat"], 1, True, *site_args) + # Exceptions within stat_func are correctly raised. 
+ for exception in [ValueError, TypeError]: + + def norm_func_except(*_): + raise exception("test") # noqa: B023 + + with pytest.raises(exception, match="test"): + method(ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_args) + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): From 3b61f67fbab3826db1fcb04d29eb6ebf608f8ce9 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 17:42:26 -0500 Subject: [PATCH 29/51] Add/refine tests, draft docstring Clean up dimension handling around summary functions and normalisation. There is a slight speed advantage (according to a microbenchmark) and a huge readability advantage to simply returning [value]. I keep all computations specifying `keepdims`, but remove list indexing (i.e. `AB[[0]]`) in favor of returning a list with a single scalar. It turns out that vectorised numpy functions are actually slower in some cases because the data we're operating on is so small. Finally, fix the default normalisation function so that it works both on one-way and two-way statistics. Users will still need to specify `hap_norm` when appropriate (and a special case of `hap_norm` for two-way stats). Per Peter's comment, I investigated dimension dropping and indeed, general stats don't drop dimensions so I removed the dimension dropping code. However, we return a matrix of `(m, m, k)` and we want `(k, m, m)`, so `np.moveaxis` is still needed. Added tests: * Multiallelic multi sample-set. This tests operations on two sample sets for multiallelic data (which exercises the norm function with multiple sample sets). This test highlighted the slight changes needed to the default normalisation function. * Multi outputs. This test mimics a two-way stat called on multiple indexes.
It shows and tests the ability to compute multiple statistics from the same haplotype counts matrix (which is especially useful with the explosion of possible summary functions in three-way, four-way stats). In our biallelic test case, I also assert that the normalisation function is never called and add a note about polarisation. Finally, I add a draft docstring, but to complete this I think that the two-locus docs are required. Also, I'd like to add some general documentation. --- python/tests/test_ld_matrix.py | 123 ++++++++++++++++++++++++++------- python/tskit/trees.py | 76 ++++++++++++++++++-- 2 files changed, 170 insertions(+), 29 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index bb459ee7ca..411149aa35 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2400,16 +2400,19 @@ def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, ex class GeneralStatFuncs: """ - functions take X, n as parameters where + Summary functions take X, n as parameters where X is a matrix of haplotype + counts per sample set and n is a vector of sample set sizes. X has shape (3, k) + and n has shape (k, ), where k is the number of sample sets. The rows of X + contain haplotype counts for AB, Ab, aB (capitalized == derived). - X: shape=(3, #ss) + X: shape=(3, k) sample sets - count AB [[ ] - count Ab [ ] - count aB [ ]] + count AB [[ #ss1, #ss2, ... ] + count Ab [ #ss1, #ss2, ... ] + count aB [ #ss1, #ss2, ... ]] - n: shape=(#ss, ) - [ ] + n: shape=(k, ) + [ #ss1, #ss2, ... 
] """ @staticmethod @@ -2480,37 +2483,39 @@ def pi2(X, n): def D2_unbiased(X, n): AB, Ab, aB = X ab = n - X.sum(0) - return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + return ( ((aB**2) * (Ab - 1) * Ab) + ((ab - 1) * ab * (AB - 1) * AB) - (aB * Ab * (Ab + (2 * ab * AB) - 1)) - ) + ) / (n * (n - 1) * (n - 2) * (n - 3)) @staticmethod def Dz_unbiased(X, n): AB, Ab, aB = X ab = n - X.sum(0) - return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + return ( (((AB * ab) - (Ab * aB)) * (aB + ab - AB - Ab) * (Ab + ab - AB - aB)) - ((AB * ab) * (AB + ab - Ab - aB - 2)) - ((Ab * aB) * (Ab + aB - AB - ab - 2)) - ) + ) / (n * (n - 1) * (n - 2) * (n - 3)) @staticmethod def pi2_unbiased(X, n): AB, Ab, aB = X ab = n - X.sum(0) - return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + return ( ((AB + Ab) * (aB + ab) * (AB + aB) * (Ab + ab)) - ((AB * ab) * (AB + ab + (3 * Ab) + (3 * aB) - 1)) - ((Ab * aB) * (Ab + aB + (3 * AB) + (3 * ab) - 1)) - ) + ) / (n * (n - 1) * (n - 2) * (n - 3)) + # Two-way statistics have the _ij suffix. @staticmethod def r2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB + # keepdims preserves the output shape of (1, ) D2_ij = np.prod(pAB - (pA * pB), keepdims=True) denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB)), keepdims=True) with suppress_overflow_div0_warning(): @@ -2525,17 +2530,37 @@ def D2_ij(X, n): @staticmethod def D2_ij_unbiased(X, n): - """NB: We use double brackets here to preserve the output shape of (1,)""" + """The identity of the sample sets is up to the user.""" AB, Ab, aB = X ab = n - X.sum(0) - return ( - (Ab[[0]] * aB[[0]] - AB[[0]] * ab[[0]]) - * (Ab[[1]] * aB[[1]] - AB[[1]] * ab[[1]]) - / n[[0]] - / (n[[0]] - 1) - / n[[1]] - / (n[[1]] - 1) + return [ + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / (n[0] * (n[0] - 1) * n[1] * (n[1] - 1)) + ] + + @staticmethod + def D2_ii_ij_jj_unbiased(X, n): + """ + Multiple stats can be computed from the same data. 
The identity of the + sample sets is up to the user. This function assumes two sample sets. + """ + AB, Ab, aB = X + ab = n - X.sum(0) + + # unbiased estimator for equal sample sets + ii, jj = ( + AB * (AB - 1) * ab * (ab - 1) + + Ab * (Ab - 1) * aB * (aB - 1) + - 2 * AB * Ab * aB * ab + ) / (n * (n - 1) * (n - 2) * (n - 3)) + # unbiased estimator for disjoint sample sets + ij = ( + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / (n[0] * (n[0] - 1) * n[1] * (n[1] - 1)) ) + return [ii, ij, jj] @pytest.fixture(scope="module") @@ -2573,7 +2598,17 @@ def ts_multiallelic_fixture(): def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): ts = ts_100_samp_with_sites_fixture sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] - ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) + + # In addition to not needing a normalisation function, normalisation is also + # not required because these sites are biallelic. + def assert_no_norm_func(*_): + raise Exception( + "Normalisation function should not be called for biallelic sites" + ) + + ldg = ts.two_locus_count_stat( + sample_sets, getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func + ) ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) np.testing.assert_array_almost_equal(ldg, ld) @@ -2584,7 +2619,7 @@ def test_general_two_locus_two_way_site_stat(stat, ts_100_samp_with_sites_fixtur sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) ld = ts.ld_matrix( - sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=(0, 1) + sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=[(0, 1)] ) np.testing.assert_array_almost_equal(ldg, ld) @@ -2599,7 +2634,24 @@ def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu [ts.samples()], general_func, 1, norm_f=norm_func, polarised=polarised ) ld = ts.ld_matrix(stat=stat) - 
np.testing.assert_array_almost_equal(ld, ldg) + # ld_matrix drops dims, expand for comparison + np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0)) + + +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_one_way_two_locus_stat_multiallelic_multi_sample_set( + stat, ts_multiallelic_fixture +): + ts = ts_multiallelic_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None + polarised = POLARIZATION[SUMMARY_FUNCS[stat]] + sample_sets = [ts.samples(), ts.samples()] + ldg = ts.two_locus_count_stat( + sample_sets, general_func, 2, norm_f=norm_func, polarised=polarised + ) + ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets) + np.testing.assert_array_almost_equal(ldg, ld) @pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) @@ -2616,4 +2668,25 @@ def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu ld = ts.ld_matrix( stat=stat.replace("_ij", ""), indexes=(0, 1), sample_sets=sample_sets ) - np.testing.assert_array_almost_equal(ld, ldg) + # ld_matrix drops dims, expand for comparison + np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0)) + + +def test_general_two_locus_multi_outputs(): + ts = msprime.sim_mutations( + msprime.sim_ancestry( + 4, recombination_rate=0.1, sequence_length=100, random_seed=123 + ), + rate=0.1, + random_seed=123, + ) + assert ts.num_samples == 8, "8 samples are required" + assert max({len(s.mutations) for s in ts.sites()}) > 2, ( + "At least one multiallelic site required" + ) + A = ts.samples()[0:4] + B = ts.samples()[4:] + + ldg = ts.two_locus_count_stat([A, B], GeneralStatFuncs.D2_ii_ij_jj_unbiased, 3) + ld = ts.ld_matrix([A, B], stat="D2_unbiased", indexes=[(0, 0), (0, 1), (1, 1)]) + np.testing.assert_array_almost_equal(ldg, ld) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 31e309f56e..408203834d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py 
@@ -10942,6 +10942,75 @@ def two_locus_count_stat( positions=None, mode="site", ): + """ + Compute two-locus statistics with a user-defined python function that + operates on haplotype counts. TODO: reference modes in two-locus docs. + On each pair of sites or trees, the summary function is provided with + ``X``, a matrix with shape (3, k) and ``n``, a vector with shape (k,), + where k is the number of sample sets provided. ``X`` is a read-only + matrix whose rows contain haplotype counts per sample set (counts of AB, + Ab, aB) and ``n`` is a vector of sample set sizes. + + .. note:: + Because we are operating on very small matrices/vectors, vectorised + operations are often times slower than operations on scalars. Simply + returning ``[value]`` can be faster than returning + ``value[np.newaxis,]`` or ``np.expand_dims(value, 0)``. + + What follows is an example of computing ``D`` from a tree sequence. Many + more examples can be found in the test suite + ``test_ld_matrix.py::GeneralStatsFuncs``. Let's begin with our summary + function, ``D``. We convert counts to proportions, then compute ``D``, + returning a numpy array with length equal to the number of sample sets. + + .. code-block:: python + def D(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + ``norm_f`` is a normalisation function used to combine all computed + statistics for multiallelic allele pairs (TODO: see two-locus + docs). Biallelic sites do not require any normalisation (in fact, the + normalisation function is never called for biallelic sites). If one of + either site A or site B is multiallelic, then the normalisation function + will be called. The default normalisation function is identical to + ``total_norm`` shown in the example below. ``hap_norm`` is required for + normalising :math:`r^2`. Both of these examples return a numpy array + with length equal to the number of sample sets (for one-way stats). + + .. 
code-block:: python + def total_norm(X, n, nA, nB): + [1 / (nA * nB)] * result_dim + + def hap_norm(X, n, nA, nB): + X[0] / n + + A simple call (without specifying normalisation) would look like this + + .. code-block::python + ts.two_locus_count_stat([ts.samples()], D, 1, polarised=True) + + :param list sample_sets: A list of lists of Node IDs, specifying the + groups of nodes to compute the statistic with. + :param f: A function that takes two arguments - a two-dimensional array + with shape (3, k) and a one-dimensional array with shape (k, ) where + k is the number of sample sets. + :param int result_dim: The length of ``f`` and ``norm_f``'s return value. + :param norm_f: A function that takes four arguments - the first two are + the same as ``f``, the second two are scalars representing the + number of A and B alleles, respectively. + :param bool polarised: Whether to leave the ancestral state out of + computations: see :ref:`sec_stats` for more details. + :param list sites: TODO: two-locus docs + :param list positions: TODO: two-locus docs + :param str mode: A string giving the "type" of the statistic to be + computed (defaults to "site"). + :return: A ndarray with shape equal to (TODO: reference two-locus docs, + no dimension dropping shape=(k, m, m) where k=num_sample_sets, + m=num_sites or num_trees). 
+ """ row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) @@ -10949,7 +11018,8 @@ def two_locus_count_stat( sample_set_sizes, sample_sets, f, - norm_f or (lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]), + # produce the same number of dims as output dimensions + norm_f or (lambda X, n, nA, nB: [1 / (nA * nB)] * result_dim), result_dim, polarised, row_sites, @@ -10958,11 +11028,9 @@ def two_locus_count_stat( col_positions, mode, ) - if result_dim == 1: # drop dimension - return result.reshape(result.shape[:2]) # Orient the data so that the first dimension is the sample set so that # we get one LD matrix per sample set. - return result.swapaxes(0, 2).swapaxes(1, 2) + return np.moveaxis(result, -1, 0) def ld_matrix( self, From 1b1a58e8f06b62fb41c1b3af738cec38c3e0c241 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 17:57:00 -0500 Subject: [PATCH 30/51] regain test coverage for default sample sets --- python/tests/test_ld_matrix.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 411149aa35..abb7eea3cd 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2594,18 +2594,33 @@ def ts_multiallelic_fixture(): return ts +def assert_no_norm_func(*_): + """Used in biallelic tests""" + raise Exception("Normalisation function should not be called for biallelic sites") + + @pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): +def test_general_two_locus_site_stat_default_sample_sets( + stat, ts_100_samp_with_sites_fixture +): ts = ts_100_samp_with_sites_fixture - sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] - # In addition to not needing a normalisation function, normalisation is also # not required 
because these sites are biallelic. - def assert_no_norm_func(*_): - raise Exception( - "Normalisation function should not be called for biallelic sites" - ) + ldg = ts.two_locus_count_stat( + [ts.samples()], getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func + ) + ld = ts.ld_matrix(stat=stat) # use default sample sets + np.testing.assert_array_almost_equal(ldg, ld) + +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_two_locus_site_stat_two_sample_sets( + stat, ts_100_samp_with_sites_fixture +): + ts = ts_100_samp_with_sites_fixture + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + # In addition to not needing a normalisation function, normalisation is also + # not required because these sites are biallelic. ldg = ts.two_locus_count_stat( sample_sets, getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func ) From 51b358e321f1fa8c40ee71f4bdea888d0e6f52c5 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 18:07:09 -0500 Subject: [PATCH 31/51] Revert "regain test coverage for default sample sets" This reverts commit 6685399f83ef7da4a83f279530b3e00647b11014. --- python/tests/test_ld_matrix.py | 29 +++++++---------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index abb7eea3cd..411149aa35 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2594,33 +2594,18 @@ def ts_multiallelic_fixture(): return ts -def assert_no_norm_func(*_): - """Used in biallelic tests""" - raise Exception("Normalisation function should not be called for biallelic sites") - - -@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_two_locus_site_stat_default_sample_sets( - stat, ts_100_samp_with_sites_fixture -): - ts = ts_100_samp_with_sites_fixture - # In addition to not needing a normalisation function, normalisation is also - # not required because these sites are biallelic. 
- ldg = ts.two_locus_count_stat( - [ts.samples()], getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func - ) - ld = ts.ld_matrix(stat=stat) # use default sample sets - np.testing.assert_array_almost_equal(ldg, ld) - - @pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_two_locus_site_stat_two_sample_sets( - stat, ts_100_samp_with_sites_fixture -): +def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): ts = ts_100_samp_with_sites_fixture sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + # In addition to not needing a normalisation function, normalisation is also # not required because these sites are biallelic. + def assert_no_norm_func(*_): + raise Exception( + "Normalisation function should not be called for biallelic sites" + ) + ldg = ts.two_locus_count_stat( sample_sets, getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func ) From 807bc30cd5ee3cd3e8b7ec64108693ec88a4d000 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 18:33:30 -0500 Subject: [PATCH 32/51] update comment about result dimension --- python/tskit/trees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 408203834d..d63687b11c 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -11029,7 +11029,7 @@ def hap_norm(X, n, nA, nB): mode, ) # Orient the data so that the first dimension is the sample set so that - # we get one LD matrix per sample set. 
+ # we get one LD matrix per result dimension return np.moveaxis(result, -1, 0) def ld_matrix( From f5931eca0265a85c0e7045c50d7a86c17dd89d80 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 18:39:36 -0500 Subject: [PATCH 33/51] be more explicit about setting the default norm function --- python/tskit/trees.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index d63687b11c..ee792fd84c 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -11014,12 +11014,14 @@ def hap_norm(X, n, nA, nB): row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) + if norm_f is None: + # produce the same number of dims as output dimensions with [val] * dim + norm_f = lambda X, n, nA, nB: [1 / (nA * nB)] * result_dim result = self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, sample_sets, f, - # produce the same number of dims as output dimensions - norm_f or (lambda X, n, nA, nB: [1 / (nA * nB)] * result_dim), + norm_f, result_dim, polarised, row_sites, From b7236525b88cb5677c7ad7264a908cba8fb88b21 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 18:41:37 -0500 Subject: [PATCH 34/51] linting does not like assigning lambdas to variables --- python/tskit/trees.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index ee792fd84c..81ea4052e6 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -11016,7 +11016,9 @@ def hap_norm(X, n, nA, nB): _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) if norm_f is None: # produce the same number of dims as output dimensions with [val] * dim - norm_f = lambda X, n, nA, nB: [1 / (nA * nB)] * result_dim + def norm_f(X, n, nA, nB): + return [1 / (nA * nB)] * result_dim + result = 
self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, sample_sets, From ea88f373f05373911025c80f5a78b17de4fe3fb5 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 15:03:11 -0500 Subject: [PATCH 35/51] add an else statement to improve readability (review) --- c/tskit/trees.c | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 0f8a10c182..e567ef1acd 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3440,21 +3440,22 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); goto out; + } else { + tsk_bug_assert(stat_branch); + ret = check_positions( + row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = check_positions( + col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, + row_positions, out_cols, col_positions, options, result); } - tsk_bug_assert(stat_branch); - ret = check_positions( - row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); - if (ret != 0) { - goto out; - } - ret = check_positions( - col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); - if (ret != 0) { - goto out; - } - ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, - row_positions, out_cols, col_positions, options, result); out: return ret; } From dfd5736d1448318334cce2e39d6a684d6991340b Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 16:25:34 -0500 Subject: [PATCH 36/51] add a few more tests (review) Add a test for behavior on empty tree sequences (no samples, no edges, 
no sites). Add a "no sites" fixture. Include branch stat testing. Tune branch stat test runtime by reducing the size of `ts_100_samp_with_sites_fixture`, now named `ts_10_samp_with_sites_fixture`. Add explicit testing for output dimensions and assert that the norm func is not called on trees with only biallelic sites and in branch mode. Add a GeneralStatNormFuncs class to explicitly document possible normalisation functions and in what situations they will be used. Tune size of tree sequence in `test_general_multi_outputs` so that the test runs in a reasonable amount of time in branch mode. --- python/tests/test_ld_matrix.py | 178 +++++++++++++++++++++++++-------- 1 file changed, 136 insertions(+), 42 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 411149aa35..4a071ab00f 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2563,17 +2563,53 @@ def D2_ii_ij_jj_unbiased(X, n): return [ii, ij, jj] +class GeneralStatNormFuncs: + @staticmethod + def hap_norm(X, n, nA, nB): + """Stat from 1 sample set -> 1 result""" + return X[0] / n + + @staticmethod + def k_way_hap_norm(X, n, nA, nB): + """Stat from k sample sets -> 1 result""" + return X[0].sum(keepdims=True) / n.sum() + + @staticmethod + def assert_no_norm_func(*_): + """Normalisation is not required in branch mode and with biallelic sites.""" + raise Exception("Normalisation function should not be called") + + @classmethod + def choose(cls, stat, mode, ts): + """ + Choose norm function based on stat, mode, presence of multiallelic sites + """ + is_multiallelic = max({len(s.mutations) for s in ts.sites()}) > 1 + match (stat, mode, is_multiallelic): + case ("r2", "site", True): + return cls.hap_norm + case ("r2_ij", "site", True): + return cls.k_way_hap_norm + case (_, "branch", _): # branch stats do not need a norm func + return cls.assert_no_norm_func + case (_, _, False): # biallelic sites do not need a norm func + return 
cls.assert_no_norm_func + case _: # total_norm is default (1 / (nA * nB)). handles multi-way stats + return None + + @pytest.fixture(scope="module") -def ts_100_samp_with_sites_fixture(): +def ts_10_samp_with_sites_fixture(): ts = tsutil.get_sim_example( - sample_size=100, - sequence_length=32, - recombination_rate=0.5, + sample_size=10, + sequence_length=15, + recombination_rate=0.1, mutation_rate=0.1, seed=123, ) assert ts.num_sites > 0, "sites are required" - assert ts.num_samples == 100, "100 samples are required" + assert ts.num_samples == 10 # Samples directly indexed in tests below + assert max({len(s.mutations) for s in ts.sites()}) == 1, "sites must be biallelic" return ts @@ -2588,50 +2624,109 @@ def ts_multiallelic_fixture(): ) # Need at least 4 samples to test unbiased statistics assert ts.num_samples >= 4, "At least 4 samples required" - assert max({len(s.mutations) for s in ts.sites()}) > 2, ( + assert max({len(s.mutations) for s in ts.sites()}) > 1, ( "At least one multiallelic site required" ) return ts -@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): - ts = ts_100_samp_with_sites_fixture - sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] +@pytest.fixture(scope="module") +def ts_no_sites_fixture(): + ts = msprime.sim_ancestry( + 2, recombination_rate=0.1, sequence_length=100, random_seed=123 + ) + assert ts.num_sites == 0 + return ts + + +@pytest.mark.parametrize("mode", ["site", "branch"]) +@pytest.mark.parametrize( + "ts", + [ts for ts in get_example_tree_sequences() if ts.id in {"no_samples", "empty_ts"}], +) +def test_general_empty_ts(mode, ts): + with pytest.raises(ValueError, match="at least one element"): + ts.two_locus_count_stat([ts.samples()], GeneralStatFuncs.D, 1, mode=mode) + + +def test_general_no_sites(ts_no_sites_fixture): + ts = ts_no_sites_fixture + ldg = ts.two_locus_count_stat([ts.samples()], GeneralStatFuncs.D, 1) + 
np.testing.assert_array_equal(ldg, np.zeros((1, 0, 0), np.float64)) + + +@pytest.mark.parametrize("mode", ["site", "branch"]) +def test_general_output_dimensions(mode, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + norm_f = GeneralStatNormFuncs.choose("D", mode, ts) + samples = ts.samples() + expected_dims = dict( + site=(1, ts.num_sites, ts.num_sites), branch=(1, ts.num_trees, ts.num_trees) + )[mode] + result = ts.two_locus_count_stat( + samples, GeneralStatFuncs.D, 1, mode=mode, norm_f=norm_f + ) + assert result.shape == expected_dims + # we expect that dims are the same with `samples` or `[samples]` + result = ts.two_locus_count_stat( + [samples], GeneralStatFuncs.D, 1, mode=mode, norm_f=norm_f + ) + assert result.shape == expected_dims + + expected_dims = dict( + site=(2, ts.num_sites, ts.num_sites), branch=(2, ts.num_trees, ts.num_trees) + )[mode] + result = ts.two_locus_count_stat( + [samples, samples], GeneralStatFuncs.D, 2, mode=mode, norm_f=norm_f + ) + assert result.shape == expected_dims - # In addition to not needing a normalisation function, normalisation is also - # not required because these sites are biallelic. 
- def assert_no_norm_func(*_): - raise Exception( - "Normalisation function should not be called for biallelic sites" - ) +@pytest.mark.parametrize("mode", ["site", "branch"]) +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_one_way_multi_sample_set(mode, stat, ts_10_samp_with_sites_fixture): + ts = ts_10_samp_with_sites_fixture + norm_f = GeneralStatNormFuncs.choose(stat, mode, ts) + sample_sets = [ts.samples()[0:5], ts.samples()[5:10]] ldg = ts.two_locus_count_stat( - sample_sets, getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func + sample_sets, + getattr(GeneralStatFuncs, stat), + 2, + norm_f=norm_f, + mode=mode, ) - ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) + ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat, mode=mode) np.testing.assert_array_almost_equal(ldg, ld) +@pytest.mark.parametrize("mode", ["site", "branch"]) @pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) -def test_general_two_locus_two_way_site_stat(stat, ts_100_samp_with_sites_fixture): - ts = ts_100_samp_with_sites_fixture - sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] - ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) +def test_general_two_way(mode, stat, ts_10_samp_with_sites_fixture): + ts = ts_10_samp_with_sites_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_f = GeneralStatNormFuncs.choose(stat, mode, ts) + sample_sets = [ts.samples()[0:5], ts.samples()[5:10]] + ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_f, mode=mode) ld = ts.ld_matrix( - sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=[(0, 1)] + sample_sets=sample_sets, + stat=stat.replace("_ij", ""), + indexes=[(0, 1)], + mode=mode, ) np.testing.assert_array_almost_equal(ldg, ld) +# NB: multiallelic testing only needed for sites. branches are biallelic. 
+ + @pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture): +def test_general_one_way_multiallelic(stat, ts_multiallelic_fixture): ts = ts_multiallelic_fixture general_func = getattr(GeneralStatFuncs, stat) - norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None + norm_f = GeneralStatNormFuncs.choose(stat, "site", ts) polarised = POLARIZATION[SUMMARY_FUNCS[stat]] ldg = ts.two_locus_count_stat( - [ts.samples()], general_func, 1, norm_f=norm_func, polarised=polarised + [ts.samples()], general_func, 1, norm_f=norm_f, polarised=polarised ) ld = ts.ld_matrix(stat=stat) # ld_matrix drops dims, expand for comparison @@ -2639,32 +2734,26 @@ def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu @pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_one_way_two_locus_stat_multiallelic_multi_sample_set( - stat, ts_multiallelic_fixture -): +def test_general_one_way_multiallelic_multi_sample_set(stat, ts_multiallelic_fixture): ts = ts_multiallelic_fixture general_func = getattr(GeneralStatFuncs, stat) - norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None + norm_f = GeneralStatNormFuncs.choose(stat, "site", ts) polarised = POLARIZATION[SUMMARY_FUNCS[stat]] sample_sets = [ts.samples(), ts.samples()] ldg = ts.two_locus_count_stat( - sample_sets, general_func, 2, norm_f=norm_func, polarised=polarised + sample_sets, general_func, 2, norm_f=norm_f, polarised=polarised ) ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets) np.testing.assert_array_almost_equal(ldg, ld) @pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) -def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture): +def test_general_two_way_multiallelic(stat, ts_multiallelic_fixture): ts = ts_multiallelic_fixture general_func = getattr(GeneralStatFuncs, stat) - norm_func = ( - (lambda X, n, nA, nB: 
X[0].sum(keepdims=True) / n.sum()) - if stat == "r2_ij" - else None - ) + norm_f = GeneralStatNormFuncs.choose(stat, "site", ts) sample_sets = [ts.samples(), ts.samples()] - ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_func) + ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_f) ld = ts.ld_matrix( stat=stat.replace("_ij", ""), indexes=(0, 1), sample_sets=sample_sets ) @@ -2672,10 +2761,11 @@ def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0)) -def test_general_two_locus_multi_outputs(): +@pytest.mark.parametrize("mode", ["site", "branch"]) +def test_general_multi_outputs(mode): ts = msprime.sim_mutations( msprime.sim_ancestry( - 4, recombination_rate=0.1, sequence_length=100, random_seed=123 + 4, recombination_rate=0.1, sequence_length=35, random_seed=123 ), rate=0.1, random_seed=123, @@ -2687,6 +2777,10 @@ def test_general_two_locus_multi_outputs(): A = ts.samples()[0:4] B = ts.samples()[4:] - ldg = ts.two_locus_count_stat([A, B], GeneralStatFuncs.D2_ii_ij_jj_unbiased, 3) - ld = ts.ld_matrix([A, B], stat="D2_unbiased", indexes=[(0, 0), (0, 1), (1, 1)]) + norm_f = GeneralStatNormFuncs.choose("D2_unbiased", mode, ts) + general_func = GeneralStatFuncs.D2_ii_ij_jj_unbiased + ldg = ts.two_locus_count_stat([A, B], general_func, 3, mode=mode, norm_f=norm_f) + ld = ts.ld_matrix( + [A, B], stat="D2_unbiased", indexes=[(0, 0), (0, 1), (1, 1)], mode=mode + ) np.testing.assert_array_almost_equal(ldg, ld) From c9ad6338f3a87189814788d42ddc5243e9073c67 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 17:13:01 -0500 Subject: [PATCH 37/51] Add some minimal documentation about the purpose of the two entrypoints --- c/tskit/trees.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index e567ef1acd..515162c157 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3384,6 +3384,7 @@ 
check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s return ret; } +/* Called directly by C python interface `two_locus_count_stat` */ int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, @@ -3439,7 +3440,6 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); - goto out; } else { tsk_bug_assert(stat_branch); ret = check_positions( @@ -3460,6 +3460,7 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, return ret; } +/* Called by summary functions implemented in C */ int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, From b9384fc723224bb9307283307abba187ca091207 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 17:13:41 -0500 Subject: [PATCH 38/51] Test explicitly that our internal data is read only --- python/tests/test_python_c.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 9f8822b2f6..3f7546e4c8 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2127,6 +2127,13 @@ def norm_func(*_): method(ss_sizes, ss, lambda *_: [1, 2], norm_func, 1, True, *site_args) with pytest.raises(ValueError, match="could not convert string to float"): method(ss_sizes, ss, lambda *_: ["nonfloat"], norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="assignment destination is read-only"): + + def bad_stat_func(X, n): + X[0] = [1] + return [1] + + method(ss_sizes, ss, bad_stat_func, norm_func, 1, True, *site_args) # Exceptions within stat_func are correctly raised. 
for exception in [ValueError, TypeError]: From 242734a205cea518d8f6139f54316653a0c59e07 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 17:20:21 -0500 Subject: [PATCH 39/51] Fix memory leak; more readonly arrays The transpose operation was creating intermediate data that was not being garbage collected, resulting in a rather obvious memory leak. To mitigate this, I opt to wrap the data in a numpy array that is already transposed. The original data is natively laid out with shape (K,3); by creating a numpy array with shape (3,K) and strides (8,8*3), we can avoid an intermediate transpose operation altogether. After leak-checking again, the memory leak is gone. I also add a Py_XDECREF to remove the reference to `norm_func`, though in my leak checking I don't actually see any difference in RSS or heap size. Finally, mark a few more arrays as read-only, since the C functions that accept them as input annotate these arrays as `const`. --- python/_tskitmodule.c | 48 +++++++++++++++++++------------------------ 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index c772270de6..fbb148097f 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7916,6 +7916,7 @@ parse_sites(TreeSequence *self, PyObject *sites, npy_intp *out_dim) if (array == NULL) { goto out; } + PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE); *out_dim = PyArray_DIM(array, 0); } @@ -7940,6 +7941,7 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) if (array == NULL) { goto out; } + PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE); *out_dim = PyArray_DIM(array, 0); } out: @@ -7966,17 +7968,13 @@ general_two_locus_norm_func(tsk_size_t K, const double *X, tsk_size_t result_dim two_locus_general_stat_params *tl_params = params; PyObject *summary_func = tl_params->norm_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - npy_intp X_dims[2] = { K, 3 }; + npy_intp X_dims[2] = { 3, K }; + npy_intp
X_strides[2] = { sizeof(double), sizeof(double) * 3 }; - // Create a read only view of X as a numpy array - X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - 2, X_dims, NPY_FLOAT64, (void *) X); - if (X_array == NULL) { - goto out; - } - PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); - // Transpose into column arrays, so that we can easily decompose the results - X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + // Create a read only view of X as a numpy array. X is transposed from its + // native memory layout (K, 3) -> (3, K). More detailed comment below. + X_array = (PyArrayObject *) PyArray_New(&PyArray_Type, 2, X_dims, NPY_FLOAT64, + X_strides, (void *) X, -1, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL); if (X_array == NULL) { goto out; } @@ -8041,21 +8039,17 @@ general_two_locus_count_stat_func( two_locus_general_stat_params *tl_params = params; PyObject *summary_func = tl_params->summary_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - npy_intp X_dims[2] = { K, 3 }; - - // Create a read only view of X as a numpy array - X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - 2, X_dims, NPY_FLOAT64, (void *) X); - if (X_array == NULL) { - goto out; - } - PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); - // Transpose into column arrays, so that we can easily decompose the results - // For example: pAB, pAb, paB = X / n - // which works with K>1. In addition, the data is not reordered, meaning - // that the data is still oriented where samples are rows, meaning that - // we'll preserve data locality in ops over samples. - X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + npy_intp X_dims[2] = { 3, K }; + npy_intp X_strides[2] = { sizeof(double), sizeof(double) * 3 }; + + // Create a transposed, read only view of X as a numpy array. 
The native + // memory layout of X is (K, 3), we wrap it in a numpy array with dimensions + // (3, K), creating row arrays of haplotype counts so that we can easily + // decompose the results. For example: `pAB, pAb, paB = X / n` which works + // with K>1. Itemsize is -1 because we specify the dtype. NB: we do not set + // NPY_ARRAY_WRITEABLE, so X_array is read only. + X_array = (PyArrayObject *) PyArray_New(&PyArray_Type, 2, X_dims, NPY_FLOAT64, + X_strides, (void *) X, -1, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL); if (X_array == NULL) { goto out; } @@ -8074,8 +8068,7 @@ general_two_locus_count_stat_func( } if (PyArray_NDIM(Y_array) != 1) { PyErr_Format(PyExc_ValueError, - "Array returned by summary function callback is %d dimensional; " - "must be 1D", + "Array returned by summary function callback is %d dimensional; must be 1D", (int) PyArray_NDIM(Y_array)); goto out; } @@ -8220,6 +8213,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * result_matrix = NULL; out: Py_XDECREF(summary_func); + Py_XDECREF(norm_func); Py_XDECREF(row_sites_array); Py_XDECREF(col_sites_array); Py_XDECREF(row_positions_array); From 17182bde5785b24a8305f890117440c5334b87f0 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sat, 21 Mar 2026 11:11:06 -0500 Subject: [PATCH 40/51] Return on summary function error (bug) `compute_two_tree_branch_stat` did not check the error returned by the summary function (which is the return value of `compute_two_tree_branch_state_update`.). I caught this because the python callback was setting an exception in the summary function and the python runtime was complaining about an exception being set, despite a successful return status. This also means that failing summary functions would (eventually) be caught, but the code would continue to run. The C python tests did not catch this because we would eventually raise the correct exception. 
--- c/tskit/trees.c | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 515162c157..3be4073971 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3188,8 +3188,11 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + ret = compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, result_dim, -1, f, f_params, &work, result); + if (ret != 0) { + goto out; + } } // Remove samples under nodes from removed edges to parent nodes for (j = 0; j < r_state->n_edges_out; j++) { @@ -3229,8 +3232,11 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + ret = compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, result_dim, +1, f, f_params, &work, result); + if (ret != 0) { + goto out; + } } out: tsk_safe_free(updated_nodes); From 287405804a2b1c0b1859f76380751ae1cca0fc56 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 22 Mar 2026 15:15:29 -0500 Subject: [PATCH 41/51] Incorporate Peter's improvement to the test comments --- python/tests/test_python_c.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 3f7546e4c8..09370a9a78 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2042,7 +2042,7 @@ def norm_func(*_): assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) - # happy path - default array filling + # happy path - default values for site and position lists a = method( 
ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site" ) From 555b67998d409fdf04bfe96c521598f79780ce67 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 22 Mar 2026 16:27:57 -0500 Subject: [PATCH 42/51] Update docstring (feedback from Peter) Provide more precise requirements for `f` and `norm_f` and give some basic understanding of what these functions are and when normalisation will be required. Attempt to fix syntax errors by adding a newline in code blocks. Clarify output dimensions in the code comments (though this might need to change since I think we'll remove the `np.moveaxis` call at the end. --- python/tskit/trees.py | 67 ++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 81ea4052e6..e8490f5f84 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10945,42 +10945,47 @@ def two_locus_count_stat( """ Compute two-locus statistics with a user-defined python function that operates on haplotype counts. TODO: reference modes in two-locus docs. - On each pair of sites or trees, the summary function is provided with - ``X``, a matrix with shape (3, k) and ``n``, a vector with shape (k,), - where k is the number of sample sets provided. ``X`` is a read-only - matrix whose rows contain haplotype counts per sample set (counts of AB, - Ab, aB) and ``n`` is a vector of sample set sizes. - - .. note:: - Because we are operating on very small matrices/vectors, vectorised - operations are often times slower than operations on scalars. Simply - returning ``[value]`` can be faster than returning - ``value[np.newaxis,]`` or ``np.expand_dims(value, 0)``. - - What follows is an example of computing ``D`` from a tree sequence. Many - more examples can be found in the test suite - ``test_ld_matrix.py::GeneralStatsFuncs``. Let's begin with our summary - function, ``D``. 
We convert counts to proportions, then compute ``D``, - returning a numpy array with length equal to the number of sample sets. + On each pair of sites or trees, the summary function is called with + haplotype counts for all provided sample sets. The summary function + (``f``) must accept two parameters: ``X``, a matrix with shape (3, k) + and ``n``, a vector with shape (k,), where k is the number of sample + sets provided. ``X`` is a read-only matrix whose rows contain haplotype + counts (AB, Ab, aB) per sample set and ``n`` is a vector of sample set + sizes. ``f`` must return a list of results with length ``result_dim``. + + What follows is an example of computing ``D`` from a tree sequence + (TODO: cite two-locus docs for more details). We convert counts to + proportions, then compute ``D``, returning a numpy array with length + equal to the number of ``result_dim``s. .. code-block:: python + def D(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB return pAB - (pA * pB) - ``norm_f`` is a normalisation function used to combine all computed - statistics for multiallelic allele pairs (TODO: see two-locus - docs). Biallelic sites do not require any normalisation (in fact, the - normalisation function is never called for biallelic sites). If one of - either site A or site B is multiallelic, then the normalisation function - will be called. The default normalisation function is identical to - ``total_norm`` shown in the example below. ``hap_norm`` is required for - normalising :math:`r^2`. Both of these examples return a numpy array - with length equal to the number of sample sets (for one-way stats). + The summary function is called for each pair of sites or trees, + producing results that must be combined when multiallelic sites are + present (``site`` mode only), summary function results must + need to be normalised in order to be aggragated for all pairs of alleles + between both sites. 
Branch statistics and biallelic sites do not require + any normalisation, ``norm_f`` is only called if one of the two sites + under consideration is multiallelic. TODO: reference two-locus docs for + further information about normalisation. ``norm_f`` is a normalisation + function that must accept four parameters: ``X`` and ``n`` are the same + inputs that ``f`` accepts, along with ``nA`` and ``nB``, which hold the + count of ``A`` alleles and ``B`` alleles. For example, if ``A`` is + biallelic and ``B`` is triallelic, ``nA=2`` and ``nB=3``. ``f`` must + return a list of results with length ``result_dim``. The default + normalisation function is identical to ``total_norm`` shown in the + example below. ``hap_norm`` is required for normalising + :math:`r^2`. Both of these examples return a numpy array with length + equal to the number of ``result_dim``s. .. code-block:: python + def total_norm(X, n, nA, nB): [1 / (nA * nB)] * result_dim @@ -10990,6 +10995,7 @@ def hap_norm(X, n, nA, nB): A simple call (without specifying normalisation) would look like this .. code-block::python + ts.two_locus_count_stat([ts.samples()], D, 1, polarised=True) :param list sample_sets: A list of lists of Node IDs, specifying the @@ -11000,7 +11006,8 @@ def hap_norm(X, n, nA, nB): :param int result_dim: The length of ``f`` and ``norm_f``'s return value. :param norm_f: A function that takes four arguments - the first two are the same as ``f``, the second two are scalars representing the - number of A and B alleles, respectively. + number of A and B alleles, respectively. If ``None``, then defaults + to the "total" normalization described above. :param bool polarised: Whether to leave the ancestral state out of computations: see :ref:`sec_stats` for more details. :param list sites: TODO: two-locus docs @@ -11008,14 +11015,14 @@ def hap_norm(X, n, nA, nB): :param str mode: A string giving the "type" of the statistic to be computed (defaults to "site"). 
:return: A ndarray with shape equal to (TODO: reference two-locus docs, - no dimension dropping shape=(k, m, m) where k=num_sample_sets, + no dimension dropping shape=(k, m, m) where k=result_dim, m=num_sites or num_trees). """ row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) if norm_f is None: - # produce the same number of dims as output dimensions with [val] * dim + # produce the same number of dims as result dimensions with [val] * dim def norm_f(X, n, nA, nB): return [1 / (nA * nB)] * result_dim @@ -11032,7 +11039,7 @@ def norm_f(X, n, nA, nB): col_positions, mode, ) - # Orient the data so that the first dimension is the sample set so that + # Orient the data so that the first dimension is the result_dim so that # we get one LD matrix per result dimension return np.moveaxis(result, -1, 0) From 17e8d438cad99302d791be9f219856e2cb208fa8 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 22 Mar 2026 17:48:49 -0500 Subject: [PATCH 43/51] turns out the documentation build doesn't like ``result_dim``s changing to ``result_dim`` fixes the docs. --- python/tskit/trees.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e8490f5f84..0bf81ccf03 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10956,7 +10956,7 @@ def two_locus_count_stat( What follows is an example of computing ``D`` from a tree sequence (TODO: cite two-locus docs for more details). We convert counts to proportions, then compute ``D``, returning a numpy array with length - equal to the number of ``result_dim``s. + equal to the number of ``result_dim``. .. code-block:: python @@ -10982,7 +10982,7 @@ def D(X, n): normalisation function is identical to ``total_norm`` shown in the example below. ``hap_norm`` is required for normalising :math:`r^2`. 
Both of these examples return a numpy array with length - equal to the number of ``result_dim``s. + equal to the number of ``result_dim``. .. code-block:: python From dd8e8e8f27071be1169a0ff7c0585d52f4e398d9 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 31 Mar 2026 11:11:08 -0500 Subject: [PATCH 44/51] make tsk_treeseq_two_locus_count_stat private --- c/tskit/trees.c | 2 +- c/tskit/trees.h | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 3be4073971..60b180bdd8 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3467,7 +3467,7 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, } /* Called by summary functions implemented in C */ -int +static int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 2bf1a26cc9..878746a954 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1039,14 +1039,6 @@ int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const doub typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, tsk_size_t result_dim, tsk_size_t n_a, tsk_size_t n_b, double *result, void *params); -int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, - tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, - const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, - general_stat_func_t *f, norm_func_t *norm_f, tsk_size_t out_rows, - const tsk_id_t *row_sites, const double *row_positions, tsk_size_t out_cols, - const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, - double *result); - /* One way weighted stats */ typedef int one_way_weighted_method(const tsk_treeseq_t *self, tsk_size_t num_weights, From f27ee8251d1e8b6d4211135d345880a6b0160974 Mon Sep 17 
00:00:00 2001 From: lkirk Date: Tue, 31 Mar 2026 11:19:28 -0500 Subject: [PATCH 45/51] clarify comment a bit --- c/tskit/trees.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 60b180bdd8..780add9afb 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3466,7 +3466,7 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, return ret; } -/* Called by summary functions implemented in C */ +/* Wrapper of `tsk_treeseq_two_locus_count_general_stat` for C summary Functions */ static int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, From 44c5922daed7522efa99ff1e5045f48d975b2554 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 31 Mar 2026 11:19:53 -0500 Subject: [PATCH 46/51] fix one C test that relied on tsk_treeseq_two_locus_count_stat --- c/tests/test_stats.c | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 50a1c84417..f7c51787c8 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -3680,9 +3680,8 @@ test_two_locus_stat_input_errors(void) row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_two_locus_count_stat(&ts, num_sample_sets, sample_set_sizes, - sample_sets, 0, NULL, NULL, NULL, num_sites, row_sites, NULL, num_sites, - col_sites, NULL, 0, result); + ret = tsk_treeseq_r2(&ts, num_sample_sets, 0, sample_sets, num_sites, row_sites, + NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_RESULT_DIMS); ret = tsk_treeseq_r2(&ts, 1, sample_set_sizes, sample_sets, num_sites, row_sites, From 9bba0bf4aa9c3da7740177fd320ea044a8a830ba Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 31 Mar 2026 11:34:47 -0500 Subject: [PATCH 47/51] use general stat function in C test It's impossible to test that error mode with the C summary 
functions. Also, rename `tsk_treeseq_two_locus_count_general_stat` to `tsk_treeseq_two_locus_general_count_stat`. It makes more sense this way. --- c/tests/test_stats.c | 5 +++-- c/tskit/trees.c | 6 +++--- c/tskit/trees.h | 2 +- python/_tskitmodule.c | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index f7c51787c8..249409dbd1 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -3680,8 +3680,9 @@ test_two_locus_stat_input_errors(void) row_sites, NULL, num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_r2(&ts, num_sample_sets, 0, sample_sets, num_sites, row_sites, - NULL, num_sites, col_sites, NULL, 0, result); + ret = tsk_treeseq_two_locus_general_count_stat(&ts, num_sample_sets, + sample_set_sizes, sample_sets, 0, NULL, NULL, NULL, num_sites, row_sites, NULL, + num_sites, col_sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_RESULT_DIMS); ret = tsk_treeseq_r2(&ts, 1, sample_set_sizes, sample_sets, num_sites, row_sites, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 780add9afb..7773e97f97 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3392,7 +3392,7 @@ check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s /* Called directly by C python interface `two_locus_count_stat` */ int -tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, +tsk_treeseq_two_locus_general_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, @@ -3466,7 +3466,7 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, return ret; } -/* Wrapper of `tsk_treeseq_two_locus_count_general_stat` for C summary Functions */ +/* Wrapper of `tsk_treeseq_two_locus_general_count_stat` for C summary Functions 
*/ static int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, @@ -3479,7 +3479,7 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl .num_sample_sets = num_sample_sets, .sample_set_sizes = sample_set_sizes, .set_indexes = set_indexes }; - return tsk_treeseq_two_locus_count_general_stat(self, num_sample_sets, + return tsk_treeseq_two_locus_general_count_stat(self, num_sample_sets, sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, row_sites, row_positions, out_cols, col_sites, col_positions, options, result); } diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 878746a954..c056b1478e 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1112,7 +1112,7 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, +int tsk_treeseq_two_locus_general_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index fbb148097f..d61776c74d 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -8196,7 +8196,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * .summary_func = summary_func, .norm_func = norm_func, }; - err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, + err = tsk_treeseq_two_locus_general_count_stat(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), output_dim, 
general_two_locus_count_stat_func, params, general_two_locus_norm_func, result_dim[0], row_sites_parsed, From 88fee030baae48f8220bdc2ddef225ac26b8ddd8 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 31 Mar 2026 11:40:40 -0500 Subject: [PATCH 48/51] clarify comment about norm func Co-authored-by: Peter Ralph --- python/tests/test_ld_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 4a071ab00f..cf69d572a0 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2592,7 +2592,7 @@ def choose(cls, stat, mode, ts): return cls.k_way_hap_norm case (_, "branch", _): # branch stats do not need a norm func return cls.assert_no_norm_func - case (_, _, False): # biallelic sites do not need a norm func + case (_, _, False): # biallelic sites should not use the norm func return cls.assert_no_norm_func case _: # total_norm is default (1 / (nA * nB)). handles multi-way stats return None From 9b581b7a4f2d642e36b2150adf36b68ff39e0912 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 7 Apr 2026 13:44:17 -0500 Subject: [PATCH 49/51] remove superfluous docstrings (comments from peter) --- python/tests/test_ld_matrix.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index cf69d572a0..584e1a4791 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2530,7 +2530,6 @@ def D2_ij(X, n): @staticmethod def D2_ij_unbiased(X, n): - """The identity of the sample sets is up to the user.""" AB, Ab, aB = X ab = n - X.sum(0) return [ @@ -2541,10 +2540,6 @@ def D2_ij_unbiased(X, n): @staticmethod def D2_ii_ij_jj_unbiased(X, n): - """ - Multiple stats can be computed from the same data. The identity of the - sample sets is up to the user. This function assumes two sample sets. 
- """ AB, Ab, aB = X ab = n - X.sum(0) From b66dd867c2a3012e5436ded0b7e6c94214768f83 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 7 Apr 2026 13:44:48 -0500 Subject: [PATCH 50/51] incorporate Aaron's comments, add links to two-locus docs, a bit of tidying --- python/tskit/trees.py | 85 ++++++++++++++++++++++++++----------------- 1 file changed, 51 insertions(+), 34 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 0bf81ccf03..d5359b4f1a 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10944,19 +10944,24 @@ def two_locus_count_stat( ): """ Compute two-locus statistics with a user-defined python function that - operates on haplotype counts. TODO: reference modes in two-locus docs. - On each pair of sites or trees, the summary function is called with - haplotype counts for all provided sample sets. The summary function - (``f``) must accept two parameters: ``X``, a matrix with shape (3, k) - and ``n``, a vector with shape (k,), where k is the number of sample - sets provided. ``X`` is a read-only matrix whose rows contain haplotype - counts (AB, Ab, aB) per sample set and ``n`` is a vector of sample set - sizes. ``f`` must return a list of results with length ``result_dim``. - - What follows is an example of computing ``D`` from a tree sequence - (TODO: cite two-locus docs for more details). We convert counts to - proportions, then compute ``D``, returning a numpy array with length - equal to the number of ``result_dim``. + operates on haplotype counts. Statistics can be computed in ``site`` + mode (see :ref:`sec_stats_two_locus_site`) or ``branch`` mode (see + :ref:`sec_stats_two_locus_branch`). On each pair of sites or trees, the + summary function is called with haplotype counts for all provided sample + sets. The summary function (``f``) must accept two parameters: ``X``, a + matrix with shape (3, k) and ``n``, a vector with shape (k,), where k is + the number of sample sets provided. 
``X`` is a read-only matrix whose + rows contain haplotype counts (AB, Ab, aB) per sample set and ``n`` is a + read-only vector of sample set sizes. ``f`` and ``norm_f`` must return a + list of results with length ``result_dim``. + + What follows is an example of computing ``D`` from a tree sequence. The + result will be equivalent to using ``ts.ld_matrix(stat="D")`` (see + :ref:`sec_stats_two_locus` for usage of the built-in LD matrix + calculation, and :ref:`sec_stats_two_locus_summary_functions` for + available statistics). In the example summary function, we convert + counts to proportions, then compute ``D``, returning a numpy array with + length equal to the number of sample sets. .. code-block:: python @@ -10968,21 +10973,21 @@ def D(X, n): The summary function is called for each pair of sites or trees, producing results that must be combined when multiallelic sites are - present (``site`` mode only), summary function results must - need to be normalised in order to be aggragated for all pairs of alleles - between both sites. Branch statistics and biallelic sites do not require - any normalisation, ``norm_f`` is only called if one of the two sites - under consideration is multiallelic. TODO: reference two-locus docs for - further information about normalisation. ``norm_f`` is a normalisation - function that must accept four parameters: ``X`` and ``n`` are the same - inputs that ``f`` accepts, along with ``nA`` and ``nB``, which hold the - count of ``A`` alleles and ``B`` alleles. For example, if ``A`` is - biallelic and ``B`` is triallelic, ``nA=2`` and ``nB=3``. ``f`` must - return a list of results with length ``result_dim``. The default - normalisation function is identical to ``total_norm`` shown in the - example below. ``hap_norm`` is required for normalising - :math:`r^2`. Both of these examples return a numpy array with length - equal to the number of ``result_dim``. 
+ present (``site`` mode only), so summary function results need to + be normalised in order to be aggregated for all pairs of alleles between + both sites. Branch statistics and biallelic sites do not require any + normalisation, and ``norm_f`` is only called if one of the two sites + under consideration is multiallelic. See + :ref:`sec_stats_two_locus_computational_details` for further information + about normalisation. ``norm_f`` is a normalisation function that must + accept four parameters: ``X`` and ``n`` are the same inputs that ``f`` + accepts, along with ``nA`` and ``nB``, which hold the count of ``A`` + alleles and ``B`` alleles. For example, if ``A`` is biallelic and ``B`` + is triallelic, ``nA=2`` and ``nB=3``. ``norm_f`` must return a list of + results with length ``result_dim``. The default normalisation function + is identical to ``total_norm`` shown in the example below. ``hap_norm`` + is required for normalising :math:`r^2`. Both of these examples return a + numpy array with length equal to the number of sample sets. .. code-block:: python @@ -10994,7 +10999,7 @@ def hap_norm(X, n, nA, nB): A simple call (without specifying normalisation) would look like this - .. code-block::python + .. code-block:: python ts.two_locus_count_stat([ts.samples()], D, 1, polarised=True) @@ -11010,13 +11015,25 @@ def hap_norm(X, n, nA, nB): to the "total" normalization described above. :param bool polarised: Whether to leave the ancestral state out of computations: see :ref:`sec_stats` for more details. - :param list sites: TODO: two-locus docs - :param list positions: TODO: two-locus docs + :param list sites: A list of lists of sites over which to compute an + LD matrix. Can be specified as a list of lists to control the row + and column sites. Only available in "site" mode. Specify as + ``[row_sites, col_sites]`` or ``[all_sites]``. + Defaults to all sites.
More information can be found in the + docstring of :meth:`.ld_matrix`. + :param list positions: A list of lists of genomic positions where + expected LD is computed based on tree topologies and branch + lengths. Only applicable in "branch" mode. Specify as a list of + two lists to control the row and column positions, as + ``[row_positions, col_positions]``, or ``[all_positions]``. More + information can be found in the docstring of :meth:`.ld_matrix`. + Defaults to the leftmost coordinates of all trees and computes + LD between all pairs of trees. :param str mode: A string giving the "type" of the statistic to be computed (defaults to "site"). - :return: A ndarray with shape equal to (TODO: reference two-locus docs, - no dimension dropping shape=(k, m, m) where k=result_dim, - m=num_sites or num_trees). + :return: An ndarray with shape (k, m, m) where + k=result_dim, m=(num_sites or num_trees, restricted by ``sites`` or + ``positions``). """ row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) From fbcedfe2087b1078eb8eea5f4d9a0d7ac868e92f Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 7 Apr 2026 13:46:41 -0500 Subject: [PATCH 51/51] add computational details section --- docs/stats.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/stats.md b/docs/stats.md index fe042246b9..b0f8b7b55f 100644 --- a/docs/stats.md +++ b/docs/stats.md @@ -738,6 +738,8 @@ ld = ts.ld_matrix(sites=[[1, 2], [1, 2, 3]]) print(ld) ``` +(sec_stats_two_locus_computational_details)= + #### Computational details Because we allow for two-locus statistics to be computed for multi-allelic