Add conditional clade topology sampling and gradient estimators#110
Open
christiaanjs wants to merge 24 commits into
Open
Add conditional clade topology sampling and gradient estimators#110christiaanjs wants to merge 24 commits into
christiaanjs wants to merge 24 commits into
Conversation
Implements a conditional clade / subsplit Bayesian network (SBN) distribution over rooted tree topologies, plus a comparison of gradient estimators for learning its parameters through discrete topology samples. treeflow/conditional_clade/: - clade.py: bitset-based clade and subsplit representations with canonical ordering, enumeration of a clade's subsplits, and binary-vector views that support an embedding-based parametrisation. - support.py: enumeration of the conditional clade support, flat indexing of (parent clade, subsplit) pairs for a single logit vector, conversion to/from TreeFlow's parent_indices topology encoding, exhaustive topology enumeration, and binary-vector feature matrices. - distribution.py: a differentiable distribution over topologies with sampling, log-probability, exact enumeration, exact KL and clade-visitation DP for small taxon sets; conditional probabilities via a per-parent segmented softmax. - estimators.py: score function (REINFORCE), leave-one-out baseline, VIMCO, straight-through Gumbel-Softmax, and the "1-0 probability gradient" sampler, plus a relaxed recursive sampler giving a differentiable cost. experiments/conditional_clade_gradient_estimators.ipynb: fits a variational CCD to a target CCD by SGD, comparing the estimators against the exact enumerated gradient (variance, bias and KL convergence). Tests in test/conditional_clade/ cover the representation, enumeration, parent_indices round-tripping, normalisation, sampling vs exact pmf, exact KL, estimator unbiasedness against the exact gradient, and optimisation convergence. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
The eager ConditionalCladeDistribution samples with Python recursion, NumPy RNG and .numpy() calls, so it cannot run inside a tf.function. Add a fully tensor-based, graph-compatible implementation and wrap it as a tfp Distribution. treeflow/conditional_clade/tensor_ops.py: graph-compatible operations built on tf.while_loop / tf.cond: - sample_parent_indices: samples a topology via an explicit two-state depth-first traversal, assigning internal node indices on the way out so they come in post-order (matching TreeFlow's parent_indices convention); per-clade lookups use dense length-2**n tensors keyed by clade bitset. - parent_indices_to_child_indices / child_indices_to_preorder: derive the remaining topology arrays with segment ops and a traversal loop. - topology_log_prob: computes node clades, reads off each subsplit and maps it to its flat index via a (parent, canonical child1) hash table; differentiable in the conditional log-probs. treeflow/conditional_clade/tree_distribution.py: ConditionalCladeTreeDistribution, a tfp Distribution whose samples are TensorflowTreeTopology objects. sample() and log_prob() run inside tf.function; it composes the eager distribution for the exact enumeration/KL utilities and shares its logits. Tests verify the tensor child_indices/preorder match the NumPy reference, graph log_prob matches the eager distribution and is normalised, sampling matches the exact pmf under tf.function, and a REINFORCE loop driven entirely in graph mode reduces KL to a target. The notebook gains a section demonstrating the graph-mode distribution. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…ogy transforms
Compiled TensorFlow custom ops mirroring the graph-mode tensor_ops reference,
exposed through the existing treeflow/acceleration/native infrastructure.
cc/conditional_clade_op.cc:
- ConditionalCladeSample: samples parent_indices (one topology per seed row)
from per-subsplit logits via recursive expansion, assigning internal node ids
in post-order. The data-dependent recursion that needs an explicit two-state
tf.while_loop in the TF reference is plain C++ recursion here.
- ConditionalCladeLogProb (+ Grad): log-probability of a topology with an
analytic scatter-add gradient w.r.t. the conditional log-probabilities.
- ParentIndicesToChildIndices / ChildIndicesToPreorder: the topology index
transforms.
treeflow/acceleration/native/conditional_clade.py: Python wrappers and gradient
registration, matching the node_height_ratio op's structure.
ConditionalCladeTreeDistribution gains use_native ("auto" by default): sample and
log_prob route through the native ops when the library is loadable, else the
pure-TensorFlow path. Both run inside tf.function.
Wired into build.sh / build.py (build_conditional_clade) and the native package
exports; README documents the new ops. Tests (marked native, auto-skipped if the
op cannot be built) check the sampler against the exact pmf, child_indices /
preorder against the NumPy reference, log_prob and its gradient against the eager
distribution and autodiff, and that the native and TensorFlow distribution paths
agree.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
Add treeflow/conditional_clade/traversal_estimators.py: graph-op helpers that compute gradient-estimator quantities from a *pre-sampled* topology traversal (the chosen flat subsplit indices of a batch of sampled trees), so every estimator shares one fast, vectorised code path instead of per-node Python. - traversal_log_prob: exact log q(T) (the score-function path for REINFORCE / leave-one-out / VIMCO). - straight_through_traversal_log_prob / straight_through_traversal_cost: the straight-through (optionally Gumbel-Softmax) relaxation whose forward value is the exact log-prob but whose gradient flows through the per-clade softmax, built from a per-decision masked softmax over the flat subsplit space. This replaces the recursive sample_relaxed_cost (a host-device sync per node) in the training path: ~10x faster (≈58 vs ≈705 ms/step at 8 taxa), runs inside tf.function, and gives identical forward values. ConditionalCladeDistribution gains straight_through_log_prob_from_flat_indices and routes log_prob_from_flat_indices through the shared helper. Notebook: the straight-through estimator now uses the vectorised cost; VIMCO is given a fair, dedicated section showing it ascends the K-sample bound L_K (which it drives to ≈0) rather than reverse KL, with a note on the noise floor and the larger-K-weakens-the-proposal-gradient effect. Tests in test_traversal_estimators.py: straight-through forward equals the exact log-prob, the temperature-1 gradient matches the expected-conditional-log-prob reference, the cost matches the recursive reference, and a vectorised-cost training run reduces KL inside tf.function. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
Section 6 now demonstrates VIMCO's leave-one-out control variate cuts the variance of its own bound's (L_K) gradient by ~100x versus plain REINFORCE on the same bound -- the meaningful, fair variance check (comparing to the reverse-KL gradient would be apples-to-oranges). The summary adds a concrete "which estimator to use" recommendation: RLOO as the unbiased default for reverse KL, straight-through when a little bias buys speed, VIMCO/RLOO-on-the-ELBO for the real (unnormalised posterior) inference problem, and Rao-Blackwellised / RELAX / Gumbel-Rao as the next variance-reduction candidates. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
estimators.py: rao_blackwellized_surrogate (+ _cost_to_go DP). At each internal node the categorical over subsplits is summed analytically using an exact cost-to-go V(clade) computed by dynamic programming, so only the sampled tree structure carries Monte-Carlo noise. Unbiased, lower variance than RLOO (verified: cosine 0.999 to the exact gradient, ~2.6x lower variance at 5 taxa). This is tractable only because a CCD target makes the reverse-KL cost decompose additively over subsplits. Notebook: Rao-Blackwellized joins the head-to-head comparison, plus an aside answering when these estimators transfer to real phylogenetic inference: - RB needs a decomposable cost, so it does not transfer to the (global, non-decomposable) phylogenetic likelihood; only the entropy/prior part Rao-Blackwellises cheaply. - VIMCO can be Rao-Blackwellized only partially (its bound is a log-sum-exp of importance weights, not additive over nodes). - straight-through/Gumbel gives no usable gradient for a non-decomposable likelihood (nothing for the relaxed one-hot to dot with; the hard tree's likelihood is constant in theta), besides being biased -- hence VBPI uses multi-sample score-function/VIMCO bounds for the topology. Tests: RB is unbiased against the exact gradient, has lower variance than RLOO, and reduces KL during training. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
The native ConditionalCladeSample op now also emits flat_indices [B, n-1] (the chosen flat subsplit index per internal node, in expansion order -- the order is irrelevant downstream since the estimators sum over the n-1 decisions). tensor_ops gains parent_indices_to_flat_indices for the pure-TF fallback (Route A), factored out of topology_log_prob. ConditionalCladeTreeDistribution.sample_flat_indices(n, seed) is the graph-mode counterpart of the eager NumPy sample_flat_index_batch: it returns an [n, taxon_count-1] int32 tensor entirely from tensor ops (native when available, else the tf.while_loop sampler), so a whole training step -- sampling included -- can run inside tf.function. Verified both paths match the exact pmf and that sample + log_prob compose in a single tf.function. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…h gather) treeflow/conditional_clade/relaxed_likelihood.py: - straight_through_gather(values, selection): the production primitive. Forward is a plain tf.gather (the hard child selection -- no dense one-hot materialised); a tf.custom_gradient backward sends gradients back *as if* the selection had been applied by a one-hot @ matmul. Validated to match autodiff of the dense product exactly (forward and gradient). - relaxed_phylogenetic_likelihood(child_selection, ..., gather=True/False): the Felsenstein combine with the child gather routed through either the efficient straight-through gather (gather=True) or an explicit dense one-hot matmul (gather=False). The dense path exists only to validate the efficient one against autodiff; verified the two agree on value and gradient, and that the relaxed combine equals a direct Felsenstein computation. - child_selection_from_topology: builds the one-hot selection from an ordinary topology (the "switched-off" case, reproducing the standard likelihood). This is a separate, opt-in path; the production LeafCTMC integer-gather likelihood is untouched and stays the default for fixed-topology inference. Routing the selection gradient to the clade categorical (and the alternative-subsplit-partials / materialisation question) is the next step. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
straight_through_gather computes d_selection[o,k] = <d_out[o], values[k]> independently per candidate, and the upstream d_out depends only on the realised gathered row. So a sampled (non-exhaustive) candidate subset still runs and gives gradients that are exact for each candidate in the subset (omitted candidates get no gradient); the only approximation is sampled-softmax normalisation of the selection. Verified with a test (subset gradient == full gradient restricted to the subset, forward unchanged, degenerate single-candidate sets run) and documented on the primitive. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…model clade_likelihood.clade_straight_through_log_likelihood runs Felsenstein pruning over a sampled topology while selecting each internal node's children with a straight-through softmax over the clade model's subsplit probabilities, so d log L / d (clade conditional log-probs) flows back to the clade logits. The forward pass is the exact likelihood of the sampled tree. The candidate alternative subsplits (whose contrast carries the informative gradient) are pluggable: candidate_subsplits_fn selects them (exhaustive for an enumerable support, or the realised one plus a sample), and alternative_partial_fn supplies non-realised child partials (default: sample a subtree from the clade model). Because straight_through_gather differentiates each candidate independently, a sampled candidate set degrades to a sampled-softmax approximation while the routing stays exact. Validated: forward == direct Felsenstein; gradient reaches the clade logits; gather routing == dense one-hot multiply; the forward is unchanged under a sampled candidate set and the gradient still flows. A single fixed transition matrix isolates the topology gradient (per-edge branch lengths are an orthogonal next step). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…hood clade_straight_through_log_likelihood and sampled_subtree_partial_fn now accept `transition` as either a fixed matrix (equal branch lengths, as before) or a callable clade -> [state, state] giving the transition matrix on the edge above each child clade -- the hook for joint branch-length optimisation (e.g. heights from a node-height ratio transform). Verified the forward equals a direct Felsenstein computation with per-edge matrices; existing fixed-matrix tests unchanged. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…nt branches experiments/conditional_clade_phylo_estimators.ipynb fits a CCD topology and branch lengths jointly to a simulated likelihood-based posterior (Jukes-Cantor, 5 taxa). Branch lengths are a root height plus a node-height ratio per split (indexed by clade), mapped to heights by the existing node-height ratio transform (pathwise gradient); the topology gradient is compared across estimators: - straight-through via clade_straight_through_log_likelihood (the likelihood gradient reaches the clade model), - score function + leave-one-out (RLOO), - VIMCO (multi-sample bound over topologies). Each is scored by the probability q places on the true topology and the total-variation distance to the exact enumerated topology posterior. Smoke-tested end to end; straight-through drives q(true tree) from ~0.02 to 0.40 in 30 steps (uniform 1/105) on the simulated data. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…book Add a step-by-step derivation (section 4.1) of the relaxed / straight-through phylogenetic likelihood: Felsenstein pruning, child gather as a one-hot multiply, the relaxed selection, the straight-through gather primitive and its forward/backward, routing the gradient into the clade logits via the straight-through softmax contrast, and why that contrast needs each subsplit's combined child partials (so an independent realized-node gather flips the gradient sign while the sampled-subtree route stays correctly signed). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
Marginalise the topology by Felsenstein sum-product over the full subsplit DAG: Pi[c] = sum_s w_s (P_X . Pi[X]) (x) (P_Y . Pi[Y]), giving the per-site tree-marginal likelihood sum_s log E_q[L_s(T)] exactly, in DAG-size time rather than over all (2n-3)!! trees. This is a deterministic, exact-autodiff topology gradient -- the tractable analogue of the straight-through likelihood for an enumerable DAG. Two passes over the static DAG, both tf.function-compatible: - relaxed_partials_sequential: unrolled child-before-parent loop (simplest, no masking, arbitrary leading dims); - relaxed_partials_vectorized: level-based batched einsum + segment-sum scatter (fewer, larger ops; the shape for a native kernel). Tests check the sequential pass equals a brute-force sum over every enumerated tree in value and exact gradient (n=3,4,5), vectorized == sequential (value and gradient, incl. per-clade transitions), and graph-mode tracing. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
New section 6 marginalises the topology by the subsplit-DAG sum-product (generalized_pruning), an exact, deterministic, zero-variance topology gradient. The cell verifies the gradient is identical across calls and trains it jointly with per-edge branch lengths. Framed honestly: the DAG sum-product computes the per-site tree-marginal sum_s log E_q[L_s], a Jensen upper bound on the ELBO term E_q[log L] (the log cannot pass through the sum-product), so it optimises a different objective and empirically settles short of the exact posterior -- below even the stochastic straight-through estimator. The lesson: "exact, zero-variance" describes the gradient, not the objective. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
The DAG sum-product weights need not be the exact marginal conditional probs: add interchangeable weight samplers (exact / Gumbel-softmax / straight-through / Gumbel-softmax straight-through) selectable via weight_fn on relaxed_log_likelihood_from_distribution. Key property (tested): with HARD one-hot-per-clade weights the sum-product collapses onto a single tree, so the objective becomes the proper across-sites likelihood sum_s log L_s(T) = log L(T) -- no Jensen gap. A Gumbel-max draw per clade is an exact CCD tree sample and supplies every clade's partial at once, so a straight-through weighting turns generalized pruning into a Gumbel-softmax / straight-through topology estimator on the correct joint objective. Soft (Gumbel-)softmax weights interpolate between the per-site marginal and the single-tree objective via temperature. Tests: hard weights == direct Felsenstein on the selected tree (Gumbel-max and argmax); one-hot-per-segment forward; exact_weights == default; sampled-weight gradients finite and nonzero; low temperature hardens the relaxation. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
Section 6 now drops the decision-weight samplers into one generalized-pruning DAG and compares the resulting objectives: exact marginal (deterministic, per-site tree-marginal), Gumbel-softmax, and straight-through. It checks the marginal gradient is zero-variance, and shows that hardening the per-clade decision (straight-through) collapses the DAG onto a single tree -- recovering the proper across-sites objective and lifting q(true) from ~0.04 to ~0.3, on par with the straight-through likelihood. Markdown and summary updated to frame objective bias vs gradient bias as the knob the weighting controls. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
… metrics Section 7 makes the score-function / VIMCO comparison fair and graph-mode: - A hard one-hot weight vector is a single sampled tree, so the generalized-pruning DAG likelihood (relaxed_log_likelihood) gives that tree's across-sites log- likelihood as a tensor op. Feeding the batch's flat subsplit indices as a tensor input lets the whole score+LOO / VIMCO step run inside @tf.function, reusing TreeFlow's tensor pruning instead of the eager Python Felsenstein loop. (bito native isn't built here, so the DAG sum-product is the in-graph likelihood.) - All graph-mode estimators (score+LOO, VIMCO, GP straight-through) use the same per-edge branch model as generalized pruning, against one shared per-edge exact posterior, for a fair comparison. - Adds importance-weighted diagnostics: population ESS (1/sum p^2/q, proposal quality) and a finite-K self-normalised IS posterior estimate -- VIMCO-type methods can show modest q(true) yet a near-exact IS-corrected posterior. - Adds a GP straight-through variant that prices the concrete sampled tree with the section-4 ratio-transform branch model. Also fixes a gradient bug: the per-edge transition matrices must be built inside the GradientTape, otherwise gp_branch_raw receives no gradient and the branch lengths silently never train. Answers the Gumbel-softmax question in the section-7 markdown: soft GS weights put a soft mixture into the forward pass, which marginalises the topology per site (the Jensen-biased objective), so GS-soft tracks the marginal and lands low at any temperature; only the hard straight-through forward gives the across-sites objective. Not a misuse -- the wrong variant for this objective. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Implements a conditional clade / subsplit Bayesian network (SBN) distribution
over rooted tree topologies, plus a comparison of gradient estimators for
learning its parameters through discrete topology samples.
treeflow/conditional_clade/:
ordering, enumeration of a clade's subsplits, and binary-vector views that
support an embedding-based parametrisation.
(parent clade, subsplit) pairs for a single logit vector, conversion to/from
TreeFlow's parent_indices topology encoding, exhaustive topology enumeration,
and binary-vector feature matrices.
log-probability, exact enumeration, exact KL and clade-visitation DP for small
taxon sets; conditional probabilities via a per-parent segmented softmax.
straight-through Gumbel-Softmax, and the "1-0 probability gradient" sampler,
plus a relaxed recursive sampler giving a differentiable cost.
experiments/conditional_clade_gradient_estimators.ipynb: fits a variational CCD
to a target CCD by SGD, comparing the estimators against the exact enumerated
gradient (variance, bias and KL convergence).
Tests in test/conditional_clade/ cover the representation, enumeration,
parent_indices round-tripping, normalisation, sampling vs exact pmf, exact KL,
estimator unbiasedness against the exact gradient, and optimisation convergence.
Co-Authored-By: Claude Opus 4.8 noreply@anthropic.com
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa