Skip to content

Steady state kalman filtering#236

Draft
AdrienCorenflos wants to merge 3 commits into
mainfrom
steady-state
Draft

Steady state kalman filtering#236
AdrienCorenflos wants to merge 3 commits into
mainfrom
steady-state

Conversation

@AdrienCorenflos
Copy link
Copy Markdown
Contributor

No description provided.

@AdrienCorenflos AdrienCorenflos requested review from Sahel13 and SamDuffield and removed request for Sahel13 May 13, 2026 14:39
@AdrienCorenflos AdrienCorenflos added documentation Improvements or additions to documentation enhancement New feature or request cuthbertlib Atomic backend for cuthbert labels May 13, 2026
@AdrienCorenflos AdrienCorenflos linked an issue May 13, 2026 that may be closed by this pull request
@AdrienCorenflos
Copy link
Copy Markdown
Contributor Author

The computational improvements exist by are not as good as I expected. I am wondering if perhaps I missed some simple improvement, or maybe it's a feature of testing with a very high dim state which triggers an expensive QR decomposition in the prediction anyway. Not sure, I would appreciate a second look.

get_init_params: GetInitParams,
get_dynamics_params: GetDynamicsParams,
get_observation_params: GetObservationParams,
steady_state_params: SteadyStateFilterParams | None = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of the extra arg could we have get_observation_params possibly return SteadyStateFilterParams if using steady state?

@SamDuffield
Copy link
Copy Markdown
Contributor

SamDuffield commented May 15, 2026

The computational improvements exist by are not as good as I expected. I am wondering if perhaps I missed some simple improvement, or maybe it's a feature of testing with a very high dim state which triggers an expensive QR decomposition in the prediction anyway. Not sure, I would appreciate a second look.

Seems to me that this PR reduces the number of tria calls from 4 or 5 (4 in the example) to 3 per step? On my laptop we get a 1.16x/1.52x speedup (which is more than in your doc). Seems some discrepancy across devices. But I don't have a gut feel if we should expect more than that

Copy link
Copy Markdown
Collaborator

@Sahel13 Sahel13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't handle missing observation dimensions correctly. The following test fails when (and only when) y has a nan in it:

@pytest.mark.parametrize("seed", [0, 42, 99])
def test_update_steady_state_matches_qr_path_with_missing_observations(seed):
    """Steady-state update must preserve the standard NaN-observation behavior."""
    nx, ny = 4, 3
    F, _, chol_Q, H, d, chol_R, *_ = _make_lgssm(seed, nx, ny)
    ss = compute_steady_state_filter_params(
        jnp.array(F),
        jnp.array(chol_Q),
        jnp.array(H),
        jnp.array(chol_R),
    )

    rng = np.random.default_rng(seed + 10)
    m = jnp.array(rng.standard_normal(nx))
    y = jnp.array(rng.standard_normal(ny)).at[1].set(jnp.nan)

    (m_qr, chol_P_qr), ell_qr = update(
        m,
        jnp.array(chol_Q),
        jnp.array(H),
        jnp.array(d),
        jnp.array(chol_R),
        y,
    )
    (m_ss, chol_P_ss), ell_ss = update(
        m,
        jnp.array(chol_Q),
        jnp.array(H),
        jnp.array(d),
        jnp.array(chol_R),
        y,
        steady_state_params=ss,
    )

    chex.assert_trees_all_close(m_ss, m_qr, atol=1e-10)
    chex.assert_trees_all_close(
        chol_P_ss @ chol_P_ss.T,
        chol_P_qr @ chol_P_qr.T,
        atol=1e-10,
    )
    chex.assert_trees_all_close(ell_ss, ell_qr, atol=1e-10)

:func:`update` skips the expensive per-step QR decomposition and reuses the
constant ``A``, ``U``, and ``Z`` blocks.

Fields:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Fields:
Attributes:

Comment on lines +237 to +240
F: Array,
chol_Q: Array,
H: Array,
chol_R: Array,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these can have type ArrayLike

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cuthbertlib Atomic backend for cuthbert documentation Improvements or additions to documentation enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implement the steady-state KF

3 participants