dodgebc/jaxkd

JAX k-D

Find k-nearest neighbors using a k-d tree in JAX!

This is an implementation of two GPU-friendly tree algorithms [1, 2] using only JAX primitives. The core build_tree, query_neighbors, and count_neighbors operations are compatible with JIT and automatic differentiation. They are reasonably fast when vectorized on GPU/TPU, but will be slower than SciPy's KDTree on CPU. For small problems where a pairwise distance matrix fits in memory, check whether brute force is faster (see jaxkd.extras).
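
As a rough baseline for that comparison, a brute-force search can be sketched in plain JAX. This is an illustrative stand-in, not the code in jaxkd.extras: the names pairwise_sq_dists, brute_force_neighbors, and brute_force_counts are hypothetical, and treating the radius as inclusive (distance <= r) is an assumption.

```python
from functools import partial

import jax
import jax.numpy as jnp


def pairwise_sq_dists(points, queries):
    # (Q, N) matrix of squared Euclidean distances; needs O(Q * N * d) memory.
    diff = queries[:, None, :] - points[None, :, :]
    return jnp.sum(diff**2, axis=-1)


@partial(jax.jit, static_argnames="k")
def brute_force_neighbors(points, queries, k):
    d2 = pairwise_sq_dists(points, queries)
    # top_k selects the largest values, so negate to get the k smallest distances.
    neg_d2, idx = jax.lax.top_k(-d2, k)
    return idx, jnp.sqrt(-neg_d2)


@jax.jit
def brute_force_counts(points, queries, r):
    # Assumption: a point exactly at distance r counts as a neighbor.
    return jnp.sum(pairwise_sq_dists(points, queries) <= r**2, axis=-1)


kp, kq = jax.random.split(jax.random.key(0))
points = jax.random.normal(kp, shape=(2_000, 3))
queries = jax.random.normal(kq, shape=(100, 3))

neighbors, distances = brute_force_neighbors(points, queries, k=5)
counts = brute_force_counts(points, queries, 0.1)
```

Timing this against the tree-based functions on your own hardware and problem size is the most reliable way to decide which to use.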

If query speed is your performance bottleneck and you only use NVIDIA GPUs, the jaxkd-cuda extension can be installed as an optional dependency (see below). It is intended to match the pure-JAX behavior and integrate seamlessly via the cuda=True argument. However, this library still does not offer the maximum possible performance. Consider binding the original cudaKDTree library to JAX, or check out the new jz-tree code, which uses a different tree structure. The advantage of the pure-JAX jaxkd is that it is portable and easy to use, with the ability to scale up to larger problems without the complexity of integrating non-JAX libraries. Try it out!

Usage

import jax
import jaxkd as jk

kp, kq = jax.random.split(jax.random.key(83))
points = jax.random.normal(kp, shape=(100_000, 3))
queries = jax.random.normal(kq, shape=(10_000, 3))

tree = jk.build_tree(points)
counts = jk.count_neighbors(tree, queries, r=0.1)
neighbors, distances = jk.query_neighbors(tree, queries, k=10)

There is also a one-step build_and_query for convenience, and all these functions accept cuda=True to use the CUDA extension if it is installed.
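
Since neighbor indices are integers, gradients can only flow through distances. The sketch below illustrates that pattern in plain JAX: it recomputes distances from a fixed index array and differentiates their mean with respect to the queries. The helpers knn_distances, mean_knn_distance, and stand_in_neighbors are hypothetical names; the stand-in produces indices by brute force so the example runs without jaxkd, and in practice the indices would come from query_neighbors. It does not show how jaxkd itself handles differentiation.

```python
import jax
import jax.numpy as jnp


def stand_in_neighbors(points, queries, k):
    # Brute-force placeholder for a tree query: indices of the k nearest points.
    d2 = jnp.sum((queries[:, None, :] - points[None, :, :]) ** 2, axis=-1)
    return jax.lax.top_k(-d2, k)[1]


def knn_distances(points, queries, neighbors):
    # neighbors has shape (Q, k) and indexes rows of points.
    diff = queries[:, None, :] - points[neighbors]  # (Q, k, d)
    return jnp.sqrt(jnp.sum(diff**2, axis=-1))


def mean_knn_distance(queries, points, neighbors):
    return jnp.mean(knn_distances(points, queries, neighbors))


# Gradient with respect to the first argument (queries), compiled with JIT.
grad_fn = jax.jit(jax.grad(mean_knn_distance))

kp, kq = jax.random.split(jax.random.key(1))
points = jax.random.normal(kp, shape=(1_000, 3))
queries = jax.random.normal(kq, shape=(50, 3))

neighbors = stand_in_neighbors(points, queries, k=4)
grads = grad_fn(queries, points, neighbors)  # same shape as queries
```

Note that the square root has an undefined gradient at zero distance, so a query that coincides exactly with one of its neighbors yields NaN in this sketch.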

Additional helpful functionality can be found in jaxkd.extras.

  • query_neighbors_pairwise and count_neighbors_pairwise for brute-force neighbor searches
  • k_means for clustering using k-means++ initialization, thanks to @NeilGirdhar for contributions
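
For readers unfamiliar with k-means++ initialization, the idea can be sketched in plain JAX: the first center is drawn uniformly, and each later center is drawn with probability proportional to its squared distance from the nearest center chosen so far. The function kmeans_pp_init is a hypothetical name for illustration and is not the implementation in jaxkd.extras; it assumes the points are not all identical and uses a Python loop, which is only reasonable for small k.

```python
import jax
import jax.numpy as jnp


def kmeans_pp_init(key, points, k):
    n = points.shape[0]
    key, sub = jax.random.split(key)
    # First center: a point drawn uniformly at random.
    first = jax.random.randint(sub, (), 0, n)
    centers = [points[first]]
    # Squared distance from every point to its nearest chosen center.
    d2 = jnp.sum((points - points[first]) ** 2, axis=-1)
    for _ in range(k - 1):
        key, sub = jax.random.split(key)
        # Sample the next center proportionally to squared distance.
        idx = jax.random.choice(sub, n, p=d2 / jnp.sum(d2))
        centers.append(points[idx])
        d2 = jnp.minimum(d2, jnp.sum((points - points[idx]) ** 2, axis=-1))
    return jnp.stack(centers)


points = jax.random.normal(jax.random.key(2), shape=(500, 2))
centers = kmeans_pp_init(jax.random.key(3), points, k=4)  # shape (4, 2)
```

Spreading the initial centers apart in this way usually gives the subsequent k-means iterations a better starting point than uniform sampling.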

Suggestions and contributions for other extras are always welcome!

Installation

To install, use pip. The only dependency is jax.

python -m pip install jaxkd

Or install with the CUDA extension, which requires CMake and NVCC to be installed on your system.

python -m pip install "jaxkd[cuda]"

Or just grab tree.py.

Citation

If you use jaxkd in your research, please cite:

@software{jaxkd,
  author = {Dodge, Benjamin},
  title = {jaxkd: k-d trees in pure JAX},
  url = {https://github.com/dodgebc/jaxkd},
  year = {2025}
}
