Find k-nearest neighbors using a k-d tree in JAX!
This is an implementation of two GPU-friendly tree algorithms [1, 2] using only JAX primitives. The core build_tree, query_neighbors, and count_neighbors operations are compatible with JIT and automatic differentiation. They are reasonably fast when vectorized on GPU/TPU, but will be slower than SciPy's KDTree on CPU. For small problems where a pairwise distance matrix fits in memory, check whether brute force is faster (see jaxkd.extras).
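For reference, a brute-force search takes only a few lines of pure JAX. The sketch below is a minimal illustration written for this README, not the code in jaxkd.extras. The helper names sq_dists, brute_force_knn and brute_force_count are hypothetical. It builds the full (Q, N) squared-distance matrix, so memory grows as Q × N. With the sizes in the example further down (10,000 queries, 100,000 points), that matrix alone is 10,000 × 100,000 × 4 bytes ≈ 4 GB in float32, which is roughly where a tree starts to pay off.

```python
from functools import partial

import jax
import jax.numpy as jnp


def sq_dists(queries, points):
    # Broadcast (Q, 1, D) against (1, N, D), then sum over D: a (Q, N) matrix.
    diff = queries[:, None, :] - points[None, :, :]
    return jnp.sum(diff**2, axis=-1)


@partial(jax.jit, static_argnames="k")
def brute_force_knn(points, queries, k):
    # lax.top_k returns the largest entries, so negate to get the k smallest
    # squared distances, already sorted nearest-first.
    neg_d2, idx = jax.lax.top_k(-sq_dists(queries, points), k)
    return idx, jnp.sqrt(-neg_d2)


@jax.jit
def brute_force_count(points, queries, r):
    # Count points within distance r of each query (the boundary counts in this sketch).
    return jnp.sum(sq_dists(queries, points) <= r**2, axis=-1)


points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
queries = jnp.array([[0.1, 0.0]])
idx, dist = brute_force_knn(points, queries, k=2)  # idx == [[0, 1]]
counts = brute_force_count(points, queries, 1.0)  # counts == [2]
```

Broadcasting the difference keeps small distances accurate. The alternative expansion |q|² - 2q·p + |p|² maps onto a matrix multiply and is often faster on accelerators, but it can lose precision for nearby points.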
If query speed is your performance bottleneck and you only use NVIDIA GPUs, the jaxkd-cuda extension can be installed as an optional dependency (see below). It is intended to match the pure-JAX behavior and to integrate seamlessly via the cuda=True argument. However, this library still does not offer the maximum possible performance. Consider binding the original cudaKDTree library to JAX, or check out the new jz-tree code, which uses a different tree structure. The advantage of the pure-JAX jaxkd is that it is portable and easy to use, and it can scale up to larger problems without the complexity of integrating non-JAX libraries. Try it out!
import jax
import jaxkd as jk
kp, kq = jax.random.split(jax.random.key(83))
points = jax.random.normal(kp, shape=(100_000, 3))
queries = jax.random.normal(kq, shape=(10_000, 3))
tree = jk.build_tree(points)
counts = jk.count_neighbors(tree, queries, r=0.1)
neighbors, distances = jk.query_neighbors(tree, queries, k=10)
There is also a one-step build_and_query function for convenience, and all of these functions accept cuda=True to use the CUDA extension if it is installed.
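Neighbor indices are piecewise constant in the inputs, so one general JAX pattern for getting gradients is to find the indices first and then recompute distances with ordinary jnp operations, letting gradients flow through the gather. This sketch assumes nothing about how jaxkd differentiates query_neighbors internally. The helper names are hypothetical, and it uses brute-force indices so that it runs without jaxkd installed. In practice the neighbors array from jk.query_neighbors could be used instead, assuming it holds point indices.

```python
import jax
import jax.numpy as jnp


def knn_indices(points, queries, k):
    # Hypothetical brute-force stand-in; the neighbors array from
    # jk.query_neighbors could be used here instead.
    d2 = jnp.sum((queries[:, None, :] - points[None, :, :]) ** 2, axis=-1)
    _, idx = jax.lax.top_k(-d2, k)
    return idx


def mean_neighbor_sq_dist(queries, points, idx):
    # Gather each query's k neighbors into shape (Q, k, D).
    nbrs = points[idx]
    # Squared distances avoid the infinite derivative of sqrt at zero.
    return jnp.mean(jnp.sum((nbrs - queries[:, None, :]) ** 2, axis=-1))


points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
queries = jnp.array([[0.1, 0.0]])
idx = knn_indices(points, queries, k=2)  # integer indices, held fixed
# Differentiate with respect to the queries (argument 0) only.
grad_q = jax.grad(mean_neighbor_sq_dist)(queries, points, idx)  # ≈ [[-0.8, 0.0]]
```

Here the loss averages (q - p)² over the two neighbors, so its gradient is (q - p0) + (q - p1) = 0.1 - 0.9 = -0.8 along the first axis.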
Additional helpful functionality can be found in jaxkd.extras.
- query_neighbors_pairwise and count_neighbors_pairwise for brute-force neighbor searches
- k_means for clustering using k-means++ initialization, thanks to @NeilGirdhar for contributions
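k-means++ seeding picks the first center uniformly at random and each later center with probability proportional to its squared distance from the nearest center already chosen. A minimal pure-JAX sketch of that seeding step follows. It is written for illustration and is not the implementation of k_means in jaxkd.extras; the function name kmeans_plus_plus_init is hypothetical. It assumes points has at least k distinct rows, since otherwise the sampling weights can all become zero.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="k")
def kmeans_plus_plus_init(key, points, k):
    """Pick k initial centers from points (N, D) by k-means++ seeding.

    Illustrative sketch only; assumes points has at least k distinct rows.
    """
    n = points.shape[0]
    keys = jax.random.split(key, k)
    # First center: uniform over the data.
    first = jax.random.randint(keys[0], (), 0, n)
    centers = jnp.zeros((k, points.shape[1]), points.dtype).at[0].set(points[first])
    # Squared distance from every point to its nearest chosen center so far.
    min_d2 = jnp.sum((points - points[first]) ** 2, axis=-1)

    def add_center(i, state):
        centers, min_d2 = state
        # Sample the next center with probability proportional to min_d2;
        # points already chosen have weight zero.
        j = jax.random.choice(keys[i], n, p=min_d2 / jnp.sum(min_d2))
        centers = centers.at[i].set(points[j])
        min_d2 = jnp.minimum(min_d2, jnp.sum((points - points[j]) ** 2, axis=-1))
        return centers, min_d2

    centers, _ = jax.lax.fori_loop(1, k, add_center, (centers, min_d2))
    return centers


pts = jnp.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0], [10.0, 10.0]])
centers = kmeans_plus_plus_init(jax.random.key(0), pts, k=4)
```

Running Lloyd iterations from these centers would complete k-means; each assignment step is a 1-nearest-neighbor query of the points against the current centers.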
Suggestions and contributions for other extras are always welcome!
To install, use pip. The only dependency is jax.
python -m pip install jaxkd
Or, with the CUDA extension (this requires CMake and NVCC to be installed on your system):
python -m pip install "jaxkd[cuda]"
Or just grab tree.py.
If you use jaxkd in your research, please cite:
@software{jaxkd,
author = {Dodge, Benjamin},
title = {jaxkd: k-d trees in pure JAX},
url = {https://github.com/dodgebc/jaxkd},
year = {2025}
}