Add convergence criterion, full-rank and IAF approximations, ADVI stability notebooks #108
Open
christiaanjs wants to merge 13 commits into
Implements a custom TFP ConvergenceCriterion that tracks an EWMA of the per-step loss decrease normalised by |ELBO|, making the threshold invariant to dataset scale and starting conditions. A min_consecutive parameter (default 10) requires the condition to hold for N consecutive steps, preventing spurious early stopping from transient single-step dips in rel_rate.

Also adds:
- VIResults.convergence_criterion_state field (backward-compat default None) so criterion state (ewma, rel_rate, consecutive_below) is available in traces
- 20 unit tests covering EWMA update, NaN handling, and convergence logic
- experiments/advi_stability.ipynb: 5-run YFV ADVI stability study; the conclusion notes that clock_rate / root_height variation reflects the mean-field approximation's inability to capture the clock-rate × root-height ridge

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
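The stopping rule described in this commit can be sketched in plain Python. This is a toy stand-in, not the TFP class from the PR: the class name `RelativeLossEwma`, the `threshold` and `decay` defaults, the NaN policy, and the choice to normalise after smoothing are all assumptions made for illustration.

```python
import math


class RelativeLossEwma:
    """Toy stand-in for the criterion described above (not the TFP class)."""

    def __init__(self, threshold=1e-4, decay=0.9, min_consecutive=10):
        self.threshold = threshold            # hypothetical default
        self.decay = decay                    # hypothetical EWMA decay
        self.min_consecutive = min_consecutive
        self.prev_loss = None
        self.ewma = None
        self.consecutive_below = 0

    def step(self, loss):
        """Feed one loss value; return True once convergence is declared."""
        if math.isnan(loss):
            # Assumed NaN policy: skip the update and break the streak.
            self.consecutive_below = 0
            return False
        if self.prev_loss is not None:
            decrease = self.prev_loss - loss
            if self.ewma is None:
                self.ewma = decrease
            else:
                self.ewma = self.decay * self.ewma + (1.0 - self.decay) * decrease
            # loss = -ELBO, so |loss| = |ELBO|; the guard avoids division by zero.
            rel_rate = self.ewma / max(abs(loss), 1e-12)
            if rel_rate < self.threshold:
                self.consecutive_below += 1
            else:
                self.consecutive_below = 0
        self.prev_loss = loss
        return self.consecutive_below >= self.min_consecutive


def first_converged_step(losses, criterion):
    """Index of the first step at which the criterion fires, or None."""
    for i, loss in enumerate(losses):
        if criterion.step(loss):
            return i
    return None
```

Because the rate is divided by |ELBO|, multiplying every loss by a constant leaves the stopping step unchanged, which is the scale invariance the commit message refers to.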
Adds cells to run NUM_IAF_RUNS independent IAF fits alongside the existing mean-field runs, then compares:
- Within-IAF stability (inter_run/post_sd table)
- Loss traces: mean-field (dashed) vs IAF (solid) on shared axes
- Pooled posterior marginals: MF (blue) vs IAF (orange) histograms
- Side-by-side MF vs IAF summary table

Conclusion has a placeholder IAF findings section with the key questions to fill in once the cells have been run.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
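The inter_run/post_sd column is not defined in this commit message. One plausible reading, sketched below purely as an assumption, is the standard deviation of the per-run posterior means divided by the average within-run posterior standard deviation. The function name `stability_ratio` is hypothetical.

```python
import statistics


def stability_ratio(runs):
    """runs: one list of posterior samples per independent fit, for a single parameter.

    Assumed definition: sd of the run means / mean of the run sds.
    """
    run_means = [statistics.fmean(r) for r in runs]
    run_sds = [statistics.stdev(r) for r in runs]
    inter_run = statistics.stdev(run_means)   # spread between fits
    post_sd = statistics.fmean(run_sds)       # typical posterior width
    return inter_run / post_sd
```

Under this reading, a ratio well below 1 would mean the fits disagree by much less than the posterior width, while a ratio near or above 1 would mean run-to-run variation is comparable to the posterior uncertainty itself.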
Implements get_fixed_topology_full_rank_approximation: a multivariate Normal in the joint unconstrained parameter space, capable of capturing correlations (e.g. the clock rate × root height ridge that degrades mean-field stability).

Key design:
- _FullRankAffineBijector stores loc (D,) and raw_scale (D,D) as tf.Variables
- Lower-triangular extraction via band_part, diagonal positivity via softplus
- Composed with the existing split/restructure/event-space bijector chain (same pattern as the IAF approximation)
- Initialised to loc = prior medians (unconstrained), scale_tril ≈ identity

Updates advi_stability.ipynb to run 3 full-rank fits alongside the 5 mean-field runs, with loss-trace comparison, within-FR stability summary, and pooled MF vs FR posterior marginal histograms.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
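The scale parameterisation in the design list can be illustrated without TensorFlow. The sketch below uses nested lists in place of tensors; the helper names (`scale_tril_from_raw`, `identity_init`, `affine_forward`, `n_variables`) are hypothetical and are not taken from full_rank.py.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def scale_tril_from_raw(raw):
    """Lower-triangular factor with a strictly positive diagonal."""
    d = len(raw)
    tril = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(i + 1):  # keep only j <= i, the band_part analogue
            tril[i][j] = softplus(raw[i][j]) if i == j else raw[i][j]
    return tril


def identity_init(d):
    """raw_scale whose transformed factor is exactly the identity."""
    diag = math.log(math.expm1(1.0))  # softplus inverse of 1
    return [[diag if i == j else 0.0 for j in range(d)] for i in range(d)]


def affine_forward(loc, tril, z):
    """x = loc + L z for a standard-normal draw z."""
    return [loc[i] + sum(tril[i][j] * z[j] for j in range(i + 1))
            for i in range(len(z))]


def n_variables(d):
    # loc (D,) plus the full raw_scale (D, D), as stored in the description;
    # the upper triangle is stored but masked out.
    return d + d * d
```

If mean-field keeps one location and one scale per dimension, the 154 mean-field parameters quoted in the PR summary would imply D = 77, and `n_variables(77)` gives 6006, which matches the "~6k" figure there. That value of D is an inference from the quoted numbers, not something the PR states.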
Full-rank ELBO ~2700-3700 nats better than mean-field (FR: -6099 to -6123 vs MF: -8845 to -9790), confirming mean-field loses significant posterior mass on the strict-clock model. Root height posterior std expands 2.5x (86 -> 213) and pop_size is 1.4x wider under full-rank, reflecting correct representation of the clock rate × root height ridge. Inter-run stability for clock_rate is similar between approximations (0.60 vs 0.62), showing the degeneracy is inherent in the posterior geometry, not an artifact of the approximation family.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Quantify the clock-rate × root-height ridge: Pearson corr = -0.48, log-log slope = -0.75 (partial non-identifiability), quadratic coeff = 0.08
- Test IAF as a more flexible approximation family: all three IAF runs fail (run 1 stalls at ELBO ≈ -14700; runs 2-3 crash with NaN gradients), showing optimisation difficulty is the limiting factor, not expressiveness
- Update conclusion with geometry diagnostics and IAF findings
- Refresh native_vs_tf_vi_validation.ipynb outputs

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
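Two of the ridge diagnostics listed above, the Pearson correlation and the log-log slope, can be computed as follows. This is a minimal sketch that assumes the slope comes from regressing log clock rate on log root height; the notebook may compute it differently, and the function names are hypothetical. The quadratic coefficient is left out because the commit does not say what it is fitted to.

```python
import math


def pearson(xs, ys):
    """Sample Pearson correlation of two equal-length sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    return sxy / math.sqrt(sxx * syy)


def loglog_slope(root_heights, clock_rates):
    """Least-squares slope of log(clock_rate) against log(root_height)."""
    lx = [math.log(h) for h in root_heights]
    ly = [math.log(r) for r in clock_rates]
    n = len(lx)
    mx, my = sum(lx) / n, sum(ly) / n
    sxy = sum((x - mx) * (y - my) for x, y in zip(lx, ly))
    sxx = sum((x - mx) ** 2 for x in lx)
    return sxy / sxx
```

A slope of exactly -1 would correspond to clock_rate × root_height being constant along the ridge, so only the product is identified. A slope of -0.75 sits between that case and independence, which is consistent with the "partial non-identifiability" reading.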
- Move IAF cells from advi_stability.ipynb to new advi_iaf.ipynb, which includes self-contained MF/FR reference runs for comparison
- advi_stability.ipynb now covers mean-field and full-rank only
- iaf.py: add trainable affine base (loc_var + log_scale_var) so the IAF is Normal(init_loc, 1) at network init rather than a random scramble
- iaf.py: use DeferredTensor for softplus(log_scale_var) so gradients flow back to log_scale_var during training
- iaf.py: use TruncatedNormal(stddev=0.01) kernel init for the autoregressive network so IAF bijectors start near-identity, preventing NaN warm-up losses from extreme initial samples

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
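The effect of the small kernel initialisation can be shown with a toy autoregressive affine layer. The sketch below is not the network in iaf.py: it uses a single linear layer with a strictly lower-triangular mask, and it emulates the truncated normal by redrawing values that fall more than two standard deviations from zero. All function names are hypothetical.

```python
import math
import random


def truncated_normal(rng, stddev):
    """Redraw until the value lies within two standard deviations of zero."""
    while True:
        x = rng.gauss(0.0, stddev)
        if abs(x) <= 2.0 * stddev:
            return x


def init_masked_weights(rng, d, stddev=0.01):
    """Strictly lower-triangular weights: output i sees only inputs j < i."""
    return [[truncated_normal(rng, stddev) if j < i else 0.0 for j in range(d)]
            for i in range(d)]


def autoregressive_affine(z, w_shift, w_log_scale):
    """y_i = z_i * exp(s_i) + m_i, with m_i and s_i linear in z_j for j < i."""
    y = []
    for i in range(len(z)):
        m = sum(w_shift[i][j] * z[j] for j in range(i))
        s = sum(w_log_scale[i][j] * z[j] for j in range(i))
        y.append(z[i] * math.exp(s) + m)
    return y
```

With weights bounded by 0.02 in magnitude, the shift and log-scale stay close to zero for moderate inputs, so the layer returns nearly its input and samples from the affine base pass through almost unchanged at the start of training.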
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
- RelativeLossNotDecreasing convergence criterion: tracks an EWMA of the per-step ELBO decrease normalised by |ELBO|, with a min_consecutive threshold to avoid spurious early stops from single-step dips in the convergence rate
- Full-rank approximation (treeflow/model/approximation/full_rank.py): multivariate Normal with full lower-triangular covariance in joint unconstrained space (~6k parameters for YFV vs 154 for mean-field); achieves ~2700–3700 nat ELBO improvement over mean-field on YFV
- IAF approximation (iaf.py): trainable affine base (loc_var, log_scale_var) so the IAF starts as a mean-field approximation at network init; DeferredTensor fix so log_scale_var receives gradients; small kernel init (stddev=0.01) to prevent NaN warm-up losses from extreme initial samples; surrogate warm-up against a fitted mean-field target before ELBO optimisation
- advi_stability.ipynb: experiment notebook covering mean-field stability across seeds, full-rank comparison, and posterior geometry analysis of the clock rate × root height non-identifiability
- advi_iaf.ipynb: self-contained IAF experiment notebook with MF/FR reference runs and the surrogate warm-up strategy; IAF achieves ~−6010 ELBO vs full-rank ~−6110
- Unit tests for RelativeLossNotDecreasing; phylo likelihood test parametrized over all three unroll modes (unrolled, tensorarray, while_loop)
- vi/util.py: VIResults namedtuple and default_vi_trace_fn to capture convergence criterion state in traces

Test plan
- pytest test/vi/test_relative_loss_not_decreasing.py: all 20 convergence criterion tests pass
- pytest test/traversal/test_phylo_likelihood.py: parametrized over unroll modes
- Run advi_stability.ipynb end-to-end (MF + FR cells)
- Run advi_iaf.ipynb end-to-end (verifies warm-up + ELBO convergence)

🤖 Generated with Claude Code