Skip to content

Add conditional clade topology sampling and gradient estimators#110

Open
christiaanjs wants to merge 24 commits into
masterfrom
claude/conditional-clade-sampling-ijnl3u
Open

Add conditional clade topology sampling and gradient estimators#110
christiaanjs wants to merge 24 commits into
masterfrom
claude/conditional-clade-sampling-ijnl3u

Conversation

@christiaanjs

Copy link
Copy Markdown
Owner

Implements a conditional clade / subsplit Bayesian network (SBN) distribution
over rooted tree topologies, plus a comparison of gradient estimators for
learning its parameters through discrete topology samples.

treeflow/conditional_clade/:

  • clade.py: bitset-based clade and subsplit representations with canonical
    ordering, enumeration of a clade's subsplits, and binary-vector views that
    support an embedding-based parametrisation.
  • support.py: enumeration of the conditional clade support, flat indexing of
    (parent clade, subsplit) pairs for a single logit vector, conversion to/from
    TreeFlow's parent_indices topology encoding, exhaustive topology enumeration,
    and binary-vector feature matrices.
  • distribution.py: a differentiable distribution over topologies with sampling,
    log-probability, exact enumeration, exact KL and clade-visitation DP for small
    taxon sets; conditional probabilities via a per-parent segmented softmax.
  • estimators.py: score function (REINFORCE), leave-one-out baseline, VIMCO,
    straight-through Gumbel-Softmax, and the "1-0 probability gradient" sampler,
    plus a relaxed recursive sampler giving a differentiable cost.

experiments/conditional_clade_gradient_estimators.ipynb: fits a variational CCD
to a target CCD by SGD, comparing the estimators against the exact enumerated
gradient (variance, bias and KL convergence).

Tests in test/conditional_clade/ cover the representation, enumeration,
parent_indices round-tripping, normalisation, sampling vs exact pmf, exact KL,
estimator unbiasedness against the exact gradient, and optimisation convergence.

Co-Authored-By: Claude Opus 4.8 noreply@anthropic.com
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa

claude and others added 24 commits June 23, 2026 07:57
Implements a conditional clade / subsplit Bayesian network (SBN) distribution
over rooted tree topologies, plus a comparison of gradient estimators for
learning its parameters through discrete topology samples.

treeflow/conditional_clade/:
- clade.py: bitset-based clade and subsplit representations with canonical
  ordering, enumeration of a clade's subsplits, and binary-vector views that
  support an embedding-based parametrisation.
- support.py: enumeration of the conditional clade support, flat indexing of
  (parent clade, subsplit) pairs for a single logit vector, conversion to/from
  TreeFlow's parent_indices topology encoding, exhaustive topology enumeration,
  and binary-vector feature matrices.
- distribution.py: a differentiable distribution over topologies with sampling,
  log-probability, exact enumeration, exact KL and clade-visitation DP for small
  taxon sets; conditional probabilities via a per-parent segmented softmax.
- estimators.py: score function (REINFORCE), leave-one-out baseline, VIMCO,
  straight-through Gumbel-Softmax, and the "1-0 probability gradient" sampler,
  plus a relaxed recursive sampler giving a differentiable cost.

experiments/conditional_clade_gradient_estimators.ipynb: fits a variational CCD
to a target CCD by SGD, comparing the estimators against the exact enumerated
gradient (variance, bias and KL convergence).

Tests in test/conditional_clade/ cover the representation, enumeration,
parent_indices round-tripping, normalisation, sampling vs exact pmf, exact KL,
estimator unbiasedness against the exact gradient, and optimisation convergence.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
The eager ConditionalCladeDistribution samples with Python recursion, NumPy RNG
and .numpy() calls, so it cannot run inside a tf.function. Add a fully
tensor-based, graph-compatible implementation and wrap it as a tfp Distribution.

treeflow/conditional_clade/tensor_ops.py: graph-compatible operations built on
tf.while_loop / tf.cond:
- sample_parent_indices: samples a topology via an explicit two-state depth-first
  traversal, assigning internal node indices on the way out so they come in
  post-order (matching TreeFlow's parent_indices convention); per-clade lookups
  use dense length-2**n tensors keyed by clade bitset.
- parent_indices_to_child_indices / child_indices_to_preorder: derive the
  remaining topology arrays with segment ops and a traversal loop.
- topology_log_prob: computes node clades, reads off each subsplit and maps it to
  its flat index via a (parent, canonical child1) hash table; differentiable in
  the conditional log-probs.

treeflow/conditional_clade/tree_distribution.py: ConditionalCladeTreeDistribution,
a tfp Distribution whose samples are TensorflowTreeTopology objects. sample() and
log_prob() run inside tf.function; it composes the eager distribution for the
exact enumeration/KL utilities and shares its logits.

Tests verify the tensor child_indices/preorder match the NumPy reference, graph
log_prob matches the eager distribution and is normalised, sampling matches the
exact pmf under tf.function, and a REINFORCE loop driven entirely in graph mode
reduces KL to a target. The notebook gains a section demonstrating the graph-mode
distribution.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…ogy transforms

Compiled TensorFlow custom ops mirroring the graph-mode tensor_ops reference,
exposed through the existing treeflow/acceleration/native infrastructure.

cc/conditional_clade_op.cc:
- ConditionalCladeSample: samples parent_indices (one topology per seed row)
  from per-subsplit logits via recursive expansion, assigning internal node ids
  in post-order. The data-dependent recursion that needs an explicit two-state
  tf.while_loop in the TF reference is plain C++ recursion here.
- ConditionalCladeLogProb (+ Grad): log-probability of a topology with an
  analytic scatter-add gradient w.r.t. the conditional log-probabilities.
- ParentIndicesToChildIndices / ChildIndicesToPreorder: the topology index
  transforms.

treeflow/acceleration/native/conditional_clade.py: Python wrappers and gradient
registration, matching the node_height_ratio op's structure.

ConditionalCladeTreeDistribution gains use_native ("auto" by default): sample and
log_prob route through the native ops when the library is loadable, else the
pure-TensorFlow path. Both run inside tf.function.

Wired into build.sh / build.py (build_conditional_clade) and the native package
exports; README documents the new ops. Tests (marked native, auto-skipped if the
op cannot be built) check the sampler against the exact pmf, child_indices /
preorder against the NumPy reference, log_prob and its gradient against the eager
distribution and autodiff, and that the native and TensorFlow distribution paths
agree.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
Add treeflow/conditional_clade/traversal_estimators.py: graph-op helpers that
compute gradient-estimator quantities from a *pre-sampled* topology traversal
(the chosen flat subsplit indices of a batch of sampled trees), so every
estimator shares one fast, vectorised code path instead of per-node Python.

- traversal_log_prob: exact log q(T) (the score-function path for REINFORCE /
  leave-one-out / VIMCO).
- straight_through_traversal_log_prob / straight_through_traversal_cost: the
  straight-through (optionally Gumbel-Softmax) relaxation whose forward value is
  the exact log-prob but whose gradient flows through the per-clade softmax,
  built from a per-decision masked softmax over the flat subsplit space.

This replaces the recursive sample_relaxed_cost (a host-device sync per node) in
the training path: ~10x faster (≈58 vs ≈705 ms/step at 8 taxa), runs inside
tf.function, and gives identical forward values. ConditionalCladeDistribution
gains straight_through_log_prob_from_flat_indices and routes
log_prob_from_flat_indices through the shared helper.

Notebook: the straight-through estimator now uses the vectorised cost; VIMCO is
given a fair, dedicated section showing it ascends the K-sample bound L_K (which
it drives to ≈0) rather than reverse KL, with a note on the noise floor and the
larger-K-weakens-the-proposal-gradient effect.

Tests in test_traversal_estimators.py: straight-through forward equals the exact
log-prob, the temperature-1 gradient matches the expected-conditional-log-prob
reference, the cost matches the recursive reference, and a vectorised-cost
training run reduces KL inside tf.function.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
Section 6 now demonstrates VIMCO's leave-one-out control variate cuts the variance
of its own bound's (L_K) gradient by ~100x versus plain REINFORCE on the same
bound -- the meaningful, fair variance check (comparing to the reverse-KL gradient
would be apples-to-oranges). The summary adds a concrete "which estimator to use"
recommendation: RLOO as the unbiased default for reverse KL, straight-through when
a little bias buys speed, VIMCO/RLOO-on-the-ELBO for the real (unnormalised
posterior) inference problem, and Rao-Blackwellised / RELAX / Gumbel-Rao as the
next variance-reduction candidates.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
estimators.py: rao_blackwellized_surrogate (+ _cost_to_go DP). At each internal
node the categorical over subsplits is summed analytically using an exact
cost-to-go V(clade) computed by dynamic programming, so only the sampled tree
structure carries Monte-Carlo noise. Unbiased, lower variance than RLOO (verified:
cosine 0.999 to the exact gradient, ~2.6x lower variance at 5 taxa). This is
tractable only because a CCD target makes the reverse-KL cost decompose additively
over subsplits.

Notebook: Rao-Blackwellized joins the head-to-head comparison, plus an aside
answering when these estimators transfer to real phylogenetic inference:
- RB needs a decomposable cost, so it does not transfer to the (global,
  non-decomposable) phylogenetic likelihood; only the entropy/prior part
  Rao-Blackwellises cheaply.
- VIMCO can be Rao-Blackwellized only partially (its bound is a log-sum-exp of
  importance weights, not additive over nodes).
- straight-through/Gumbel gives no usable gradient for a non-decomposable
  likelihood (nothing for the relaxed one-hot to dot with; the hard tree's
  likelihood is constant in theta), besides being biased -- hence VBPI uses
  multi-sample score-function/VIMCO bounds for the topology.

Tests: RB is unbiased against the exact gradient, has lower variance than RLOO,
and reduces KL during training.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
The native ConditionalCladeSample op now also emits flat_indices [B, n-1] (the
chosen flat subsplit index per internal node, in expansion order -- the order is
irrelevant downstream since the estimators sum over the n-1 decisions). tensor_ops
gains parent_indices_to_flat_indices for the pure-TF fallback (Route A), factored
out of topology_log_prob.

ConditionalCladeTreeDistribution.sample_flat_indices(n, seed) is the graph-mode
counterpart of the eager NumPy sample_flat_index_batch: it returns an
[n, taxon_count-1] int32 tensor entirely from tensor ops (native when available,
else the tf.while_loop sampler), so a whole training step -- sampling included --
can run inside tf.function. Verified both paths match the exact pmf and that
sample + log_prob compose in a single tf.function.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…h gather)

treeflow/conditional_clade/relaxed_likelihood.py:

- straight_through_gather(values, selection): the production primitive. Forward is
  a plain tf.gather (the hard child selection -- no dense one-hot materialised);
  a tf.custom_gradient backward sends gradients back *as if* the selection had
  been applied by a one-hot @ matmul. Validated to match autodiff of the dense
  product exactly (forward and gradient).

- relaxed_phylogenetic_likelihood(child_selection, ..., gather=True/False): the
  Felsenstein combine with the child gather routed through either the efficient
  straight-through gather (gather=True) or an explicit dense one-hot matmul
  (gather=False). The dense path exists only to validate the efficient one against
  autodiff; verified the two agree on value and gradient, and that the relaxed
  combine equals a direct Felsenstein computation.

- child_selection_from_topology: builds the one-hot selection from an ordinary
  topology (the "switched-off" case, reproducing the standard likelihood).

This is a separate, opt-in path; the production LeafCTMC integer-gather likelihood
is untouched and stays the default for fixed-topology inference. Routing the
selection gradient to the clade categorical (and the alternative-subsplit-partials
/ materialisation question) is the next step.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
straight_through_gather computes d_selection[o,k] = <d_out[o], values[k]>
independently per candidate, and the upstream d_out depends only on the realised
gathered row. So a sampled (non-exhaustive) candidate subset still runs and gives
gradients that are exact for each candidate in the subset (omitted candidates get
no gradient); the only approximation is sampled-softmax normalisation of the
selection. Verified with a test (subset gradient == full gradient restricted to
the subset, forward unchanged, degenerate single-candidate sets run) and
documented on the primitive.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…model

clade_likelihood.clade_straight_through_log_likelihood runs Felsenstein pruning
over a sampled topology while selecting each internal node's children with a
straight-through softmax over the clade model's subsplit probabilities, so
d log L / d (clade conditional log-probs) flows back to the clade logits. The
forward pass is the exact likelihood of the sampled tree.

The candidate alternative subsplits (whose contrast carries the informative
gradient) are pluggable: candidate_subsplits_fn selects them (exhaustive for an
enumerable support, or the realised one plus a sample), and alternative_partial_fn
supplies non-realised child partials (default: sample a subtree from the clade
model). Because straight_through_gather differentiates each candidate
independently, a sampled candidate set degrades to a sampled-softmax approximation
while the routing stays exact.

Validated: forward == direct Felsenstein; gradient reaches the clade logits;
gather routing == dense one-hot multiply; the forward is unchanged under a sampled
candidate set and the gradient still flows. A single fixed transition matrix
isolates the topology gradient (per-edge branch lengths are an orthogonal next
step).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…hood

clade_straight_through_log_likelihood and sampled_subtree_partial_fn now accept
`transition` as either a fixed matrix (equal branch lengths, as before) or a
callable clade -> [state, state] giving the transition matrix on the edge above
each child clade -- the hook for joint branch-length optimisation (e.g. heights
from a node-height ratio transform). Verified the forward equals a direct
Felsenstein computation with per-edge matrices; existing fixed-matrix tests
unchanged.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…nt branches

experiments/conditional_clade_phylo_estimators.ipynb fits a CCD topology and
branch lengths jointly to a simulated likelihood-based posterior (Jukes-Cantor,
5 taxa). Branch lengths are a root height plus a node-height ratio per split
(indexed by clade), mapped to heights by the existing node-height ratio transform
(pathwise gradient); the topology gradient is compared across estimators:

- straight-through via clade_straight_through_log_likelihood (the likelihood
  gradient reaches the clade model),
- score function + leave-one-out (RLOO),
- VIMCO (multi-sample bound over topologies).

Each is scored by the probability q places on the true topology and the
total-variation distance to the exact enumerated topology posterior. Smoke-tested
end to end; straight-through drives q(true tree) from ~0.02 to 0.40 in 30 steps
(uniform 1/105) on the simulated data.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
…book

Add a step-by-step derivation (section 4.1) of the relaxed / straight-through
phylogenetic likelihood: Felsenstein pruning, child gather as a one-hot
multiply, the relaxed selection, the straight-through gather primitive and its
forward/backward, routing the gradient into the clade logits via the
straight-through softmax contrast, and why that contrast needs each subsplit's
combined child partials (so an independent realized-node gather flips the
gradient sign while the sampled-subtree route stays correctly signed).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
Marginalise the topology by Felsenstein sum-product over the full subsplit DAG:
Pi[c] = sum_s w_s (P_X . Pi[X]) (x) (P_Y . Pi[Y]), giving the per-site
tree-marginal likelihood sum_s log E_q[L_s(T)] exactly, in DAG-size time rather
than over all (2n-3)!! trees. This is a deterministic, exact-autodiff topology
gradient -- the tractable analogue of the straight-through likelihood for an
enumerable DAG.

Two passes over the static DAG, both tf.function-compatible:
- relaxed_partials_sequential: unrolled child-before-parent loop (simplest, no
  masking, arbitrary leading dims);
- relaxed_partials_vectorized: level-based batched einsum + segment-sum scatter
  (fewer, larger ops; the shape for a native kernel).

Tests check the sequential pass equals a brute-force sum over every enumerated
tree in value and exact gradient (n=3,4,5), vectorized == sequential (value and
gradient, incl. per-clade transitions), and graph-mode tracing.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
New section 6 marginalises the topology by the subsplit-DAG sum-product
(generalized_pruning), an exact, deterministic, zero-variance topology gradient.
The cell verifies the gradient is identical across calls and trains it jointly
with per-edge branch lengths.

Framed honestly: the DAG sum-product computes the per-site tree-marginal
sum_s log E_q[L_s], a Jensen upper bound on the ELBO term E_q[log L] (the log
cannot pass through the sum-product), so it optimises a different objective and
empirically settles short of the exact posterior -- below even the stochastic
straight-through estimator. The lesson: "exact, zero-variance" describes the
gradient, not the objective.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
The DAG sum-product weights need not be the exact marginal conditional probs:
add interchangeable weight samplers (exact / Gumbel-softmax / straight-through /
Gumbel-softmax straight-through) selectable via weight_fn on
relaxed_log_likelihood_from_distribution.

Key property (tested): with HARD one-hot-per-clade weights the sum-product
collapses onto a single tree, so the objective becomes the proper across-sites
likelihood sum_s log L_s(T) = log L(T) -- no Jensen gap. A Gumbel-max draw per
clade is an exact CCD tree sample and supplies every clade's partial at once, so
a straight-through weighting turns generalized pruning into a
Gumbel-softmax / straight-through topology estimator on the correct joint
objective. Soft (Gumbel-)softmax weights interpolate between the per-site
marginal and the single-tree objective via temperature.

Tests: hard weights == direct Felsenstein on the selected tree (Gumbel-max and
argmax); one-hot-per-segment forward; exact_weights == default; sampled-weight
gradients finite and nonzero; low temperature hardens the relaxation.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
Section 6 now drops the decision-weight samplers into one generalized-pruning DAG
and compares the resulting objectives: exact marginal (deterministic, per-site
tree-marginal), Gumbel-softmax, and straight-through. It checks the marginal
gradient is zero-variance, and shows that hardening the per-clade decision
(straight-through) collapses the DAG onto a single tree -- recovering the proper
across-sites objective and lifting q(true) from ~0.04 to ~0.3, on par with the
straight-through likelihood. Markdown and summary updated to frame objective bias
vs gradient bias as the knob the weighting controls.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
… metrics

Section 7 makes the score-function / VIMCO comparison fair and graph-mode:
- A hard one-hot weight vector is a single sampled tree, so the generalized-pruning
  DAG likelihood (relaxed_log_likelihood) gives that tree's across-sites log-
  likelihood as a tensor op. Feeding the batch's flat subsplit indices as a tensor
  input lets the whole score+LOO / VIMCO step run inside @tf.function, reusing
  TreeFlow's tensor pruning instead of the eager Python Felsenstein loop. (bito
  native isn't built here, so the DAG sum-product is the in-graph likelihood.)
- All graph-mode estimators (score+LOO, VIMCO, GP straight-through) use the same
  per-edge branch model as generalized pruning, against one shared per-edge exact
  posterior, for a fair comparison.
- Adds importance-weighted diagnostics: population ESS (1/sum p^2/q, proposal
  quality) and a finite-K self-normalised IS posterior estimate -- VIMCO-type
  methods can show modest q(true) yet a near-exact IS-corrected posterior.
- Adds a GP straight-through variant that prices the concrete sampled tree with the
  section-4 ratio-transform branch model.

Also fixes a gradient bug: the per-edge transition matrices must be built inside the
GradientTape, otherwise gp_branch_raw receives no gradient and the branch lengths
silently never train.

Answers the Gumbel-softmax question in the section-7 markdown: soft GS weights put a
soft mixture into the forward pass, which marginalises the topology per site (the
Jensen-biased objective), so GS-soft tracks the marginal and lands low at any
temperature; only the hard straight-through forward gives the across-sites
objective. Not a misuse -- the wrong variant for this objective.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012eCSDgvCNka4pWJmSNG1Wa
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants