From 556063417e140ed8dff5a81f07fbb0f8b9ec8405 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 19:26:01 +0200 Subject: [PATCH 01/14] add jax to the toml --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index adfcb10..7586916 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,9 @@ output = [ "weasyprint", "tabulate" ] +gpu = [ + "jax" +] [project.urls] Homepage = "https://github.com/CausalInference/pySEQTarget" From 0665436f810dea3960fe9a39551e435b9b89add0 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 19:31:56 +0200 Subject: [PATCH 02/14] clean up some dispatch --- pySEQTarget/analysis/_outcome_fit.py | 31 ++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index f63e1e3..98ccdb7 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -124,24 +124,29 @@ def _outcome_fit( ) full_formula = f"{outcome} ~ {formula}" - - if getattr(self, "glm_package", "statsmodels") == "glum": - from ..helpers._glum_fit import _fit_glum - - return _fit_glum( - full_formula, - df_pd, - var_weights=df_pd[weight_col] if weighted else None, - ) - + var_weights = df_pd[weight_col] if weighted else None + + match getattr(self, "glm_package", "statsmodels"): + case "glum": + from ..helpers._glum_fit import _fit_glum + return _fit_glum(full_formula, df_pd, var_weights=var_weights) + + case "jax": + from ..helpers._jax_fit import _fit_jax + return _fit_jax( + full_formula, + df_pd, + var_weights=var_weights, + start_params=start_params, + ) + # default glm_kwargs = { "formula": full_formula, "data": df_pd, "family": sm.families.Binomial(), } - - if weighted: - glm_kwargs["var_weights"] = df_pd[weight_col] + if var_weights is not None: + glm_kwargs["var_weights"] = var_weights model = smf.glm(**glm_kwargs) From b5ab7ca52dbbccf5b15cb1dd5c0ba97a090bd4a5 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 19:32:34 +0200 Subject: [PATCH 03/14] glum fit summary2 -> summary. Maybe can be hidden in main class --- pySEQTarget/helpers/_glum_fit.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py index f8894b1..b375194 100644 --- a/pySEQTarget/helpers/_glum_fit.py +++ b/pySEQTarget/helpers/_glum_fit.py @@ -12,7 +12,7 @@ class _GlumFit: the codebase (and users) expect: .params (Series), .model.exog_names, .model.data.design_info, .predict(df) / .predict(X_numpy, transform=False), - .bse, .summary(), .summary2(). + .bse, .summary(). Standard errors are derived lazily from the stored design matrix using the GLM asymptotic covariance (X' W X)^-1, which matches statsmodels for the @@ -81,7 +81,7 @@ def _coef_table(self): index=list(self.params.index), ) - def summary2(self): + def summary(self): from statsmodels.iolib.summary2 import Summary info = pd.DataFrame( @@ -101,9 +101,6 @@ def summary2(self): smry.add_df(self._coef_table()) return smry - # statsmodels exposes both; the codebase/practical use either, so alias them. - summary = summary2 - def _fit_glum(formula, data, var_weights=None): """Fit a binomial GLM with glum and return a _GlumFit wrapper.""" From 46359a14cc900427062e26ea3d768db2a65e1eea Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 19:32:54 +0200 Subject: [PATCH 04/14] add jax to opts --- pySEQTarget/SEQopts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 3f12a91..e463af0 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -39,6 +39,7 @@ class SEQopts: :param expand_only: If True, ``SEQuential.expand()`` returns the expanded dataset and skips weighting, modelling, and survival steps :type expand_only: bool + :param glm_package: Backend for fitting logistic (outcome/competing-event) models ["statsmodels", "glum", or "jax"], default "statsmodels". :param followup_class: Boolean to force followup values to be treated as classes :type followup_class: bool :param followup_include: Boolean to force regular followup values into model covariates From 096f4704ef532438ffab8b16b1af7d6482edb984 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 19:34:40 +0200 Subject: [PATCH 05/14] removed explicit typing (should be fine with the sphinx autodoc) --- pySEQTarget/SEQopts.py | 57 +++--------------------------------------- 1 file changed, 3 insertions(+), 54 deletions(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index e463af0..ae0fec9 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -10,114 +10,63 @@ class SEQopts: Parameter builder for ``pySEQTarget.SEQuential`` analysis :param bootstrap_nboot: Number of bootstraps to perform - :type bootstrap_nboot: int :param bootstrap_sample: Subsampling proportion of ID-Trials gathered for each bootstrapping iteration - :type bootstrap_sample: float :param bootstrap_CI: If bootstrapped, confidence interval level - :type bootstrap_CI: float :param bootstrap_CI_method: If bootstrapped, confidence interval method ['SE' or 'percentile'] - :type bootstrap_CI_method: str :param cense_colname: Column name for censoring effect (LTFU, etc.) - :type cense_colname: str :param cense_denominator: Override to specify denominator patsy formula for censoring models; "1" or "" indicate intercept only model - :type cense_denominator: Optional[str] or None :param cense_numerator: Override to specify numerator patsy formula for censoring models - :type cense_numerator: Optional[str] or None :param cense_eligible_colname: Column name to identify which rows are eligible for censoring model fitting - :type cense_eligible_colname: Optional[str] or None :param compevent_colname: Column name specifying a competing event to the outcome - :type compevent_colname: str :param covariates: Override to specify the outcome patsy formula for outcome model fitting - :type covariates: Optional[str] or None :param denominator: Override to specify the outcome patsy formula for denominator model fitting - :type denominator: Optional[str] or None :param excused: Boolean to allow excused conditions when method is censoring - :type excused: bool :param excused_colnames: Column names (at the same length of treatment_level) specifying excused conditions, default ``[]`` - :type excused_colnames: List[str] :param expand_only: If True, ``SEQuential.expand()`` returns the expanded dataset and skips weighting, modelling, and survival steps - :type expand_only: bool :param glm_package: Backend for fitting logistic (outcome/competing-event) models ["statsmodels", "glum", or "jax"], default "statsmodels". :param followup_class: Boolean to force followup values to be treated as classes - :type followup_class: bool :param followup_include: Boolean to force regular followup values into model covariates - :type followup_include: bool :param followup_spline: Boolean to force followup values to be fit to cubic spline - :type followup_spline: bool :param followup_spline_df: Degrees of freedom for the followup cubic spline, default ``4`` - :type followup_spline_df: int :param followup_max: Maximum allowed followup in analysis - :type followup_max: int or None :param followup_min: Minimum allowed followup in analysis - :type followup_min: int :param hazard_estimate: Boolean to create hazard estimates - :type hazard_estimate: bool :param indicator_baseline: How to indicate baseline columns in models - :type indicator_baseline: str :param indicator_squared: How to indicate squared columns in models - :type indicator_squared: str :param km_curves: Boolean to create survival, risk, and incidence (if applicable) estimates - :type km_curves: bool :param ncores: Number of cores to use if running in parallel, default ``max(1, cpu_count() - 1)`` - :type ncores: int :param numerator: Override to specify the outcome patsy formula for numerator models; "1" or "" indicate intercept only model - :type numerator: str :param offload: Boolean to offload intermediate model data to disk - :type offload: bool :param offload_dir: Directory to offload intermediate model data - :type offload_dir: str :param parallel: Boolean to run model fitting in parallel - :type parallel: bool :param plot_colors: List of colors for KM plots, if applicable, default ``["#F8766D", "#00BFC4", "#555555"]`` - :type plot_colors: List[str] :param plot_labels: List of length treat_level to specify treatment labeling, default ``[]`` - :type plot_labels: List[str] :param plot_title: Plot title - :type plot_title: str :param plot_type: Type of plot to show ["risk", "survival" or "incidence" if compevent is specified] - :type plot_type: str :param risk_times: Followup times at which to report risk difference and risk ratio when ``km_curves = True``. Each requested time is snapped to the latest available followup at or before it, and the maximum followup is always included. Defaults to ``None`` (report at the maximum followup only). - :type risk_times: Optional[List[float]] or None :param seed: RNG seed - :type seed: int :param selection_first_trial: Boolean to only use first trial for analysis (similar to non-expanded) - :type selection_first_trial: bool :param selection_sample: Subsampling proportion of ID-trials which did not initiate a treatment - :type selection_sample: float :param selection_random: Boolean to randomly downsample ID-trials which did not initiate a treatment - :type selection_random: bool :param subgroup_colname: Column name for subgroups to share the same weighting but different outcome model fits - :type subgroup_colname: str :param treatment_level: List of eligible treatment levels within treatment_col, default ``[0, 1]`` - :type treatment_level: List[int] :param trial_include: Boolean to force trial values into model covariates - :type trial_include: bool :param visit_colname: Column name specifying visit number - :type visit_colname: str :param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting, default ``[]`` - :type weight_eligible_colnames: List[str] :param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton" - :type weight_fit_method: str :param weight_min: Minimum weight - :type weight_min: float :param weight_max: Maximum weight - :type weight_max: float or None :param weight_lag_condition: Boolean to fit weights based on their treatment lag - :type weight_lag_condition: bool :param weight_p99: Boolean to force weight min and max to be 1st and 99th percentile respectively - :type weight_p99: bool :param weight_preexpansion: Boolean to fit weights on preexpanded data - :type weight_preexpansion: bool :param verbose: Boolean to print dataset size summaries and bootstrap information - :type verbose: bool :param weighted: Boolean to weight analysis - :type weighted: bool """ bootstrap_nboot: int = 0 @@ -135,7 +84,7 @@ class SEQopts: excused: bool = False excused_colnames: List[str] = field(default_factory=lambda: []) expand_only: bool = False - glm_package: Literal["statsmodels", "glum"] = "statsmodels" + glm_package: Literal["statsmodels", "glum", "jax"] = "statsmodels" followup_class: bool = False followup_include: bool = True followup_max: int = None @@ -234,8 +183,8 @@ def _validate_choices(self): ) if self.bootstrap_CI_method not in ["se", "percentile"]: raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'") - if self.glm_package not in ["statsmodels", "glum"]: - raise ValueError("glm_package must be 'statsmodels' or 'glum'") + if self.glm_package not in ["statsmodels", "glum", "jax"]: + raise ValueError("glm_package must be 'statsmodels', 'glum', or 'jax'") if self.cox_package not in ["lifelines", "scikit-survival"]: raise ValueError("cox_package must be 'lifelines' or 'scikit-survival'") From 64ca5fc8d2d873c880a8888f4c15175276e52782 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 19:35:02 +0200 Subject: [PATCH 06/14] remove defined typing - let sphinx infer from type hints --- pySEQTarget/SEQuential.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index c46cb61..190e60d 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -29,23 +29,14 @@ class SEQuential: Primary class initializer for SEQuentially nested target trial emulation :param data: Data for analysis - :type data: pl.DataFrame :param id_col: Column name for unique patient IDs - :type id_col: str :param time_col: Column name for observational time points - :type time_col: str :param eligible_col: Column name for analytical eligibility - :type eligible_col: str :param treatment_col: Column name specifying treatment per time_col - :type treatment_col: str :param outcome_col: Column name specifying outcome per time_col - :type outcome_col: str :param time_varying_cols: Time-varying column names as covariates (BMI, Age, etc.) - :type time_varying_cols: Optional[List[str]] or None :param fixed_cols: Fixed column names as covariates (Sex, YOB, etc.) - :type fixed_cols: Optional[List[str]] or None :param method: Method for analysis ['ITT', 'dose-response', or 'censoring'] - :type method: str :param parameters: Parameters to augment analysis, specified with ``pySEQTarget.SEQopts`` :type parameters: Optional[SEQopts] or None """ From 3bc677b1d0a90a72c45597cc371a010a638eb958 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 20:01:59 +0200 Subject: [PATCH 07/14] multinomial logistic via ADAM --- pySEQTarget/helpers/_jax_logistic.py | 98 ++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 pySEQTarget/helpers/_jax_logistic.py diff --git a/pySEQTarget/helpers/_jax_logistic.py b/pySEQTarget/helpers/_jax_logistic.py new file mode 100644 index 0000000..f2315d0 --- /dev/null +++ b/pySEQTarget/helpers/_jax_logistic.py @@ -0,0 +1,98 @@ +from functools import partial + +import jax +import jax.numpy as jnp + + +class MultinomialLogisticRegression: + """ + Multinomial logistic regression in JAX. + """ + + def __init__(self, learning_rate=0.1, num_epochs=2000, eps=1e-7, n_classes=None): + self.learning_rate = learning_rate + self.num_epochs = num_epochs + self.eps = eps + self.n_classes = n_classes + self.params = None + self.loss_history_ = None + self.n_classes_ = None + self.mean_ = None + self.std_ = None + + def _prep(self, X): + X = jnp.asarray(X) + if self.mean_ is not None: + X = (X - self.mean_) / self.std_ + return X + + @staticmethod + def _predict(params, X): + W, b = params + return jax.nn.softmax(jnp.dot(X, W) + b, axis=-1) + + @staticmethod + @jax.jit + def _loss(params, X, Y, w): + """Sample-weighted mean softmax cross-entropy; ``w`` is ``(n_samples,)``.""" + W, b = params + logp = jax.nn.log_softmax(jnp.dot(X, W) + b, axis=-1) + ce = -jnp.sum(Y * logp, axis=-1) + return jnp.sum(w * ce) / jnp.sum(w) + + @staticmethod + @partial(jax.jit, static_argnums=(4, 5)) + def _run(params, X, Y, w, num_epochs, learning_rate): + b1, b2, eps = 0.9, 0.999, 1e-8 + W0, b0 = params + zeros = (jnp.zeros_like(W0), jnp.zeros_like(b0)) + + def step(carry, t): + # ADAM + params, (mW, mb), (vW, vb) = carry + dW, db = jax.grad(MultinomialLogisticRegression._loss)(params, X, Y, w) + mW, vW = b1 * mW + (1 - b1) * dW, b2 * vW + (1 - b2) * dW ** 2 + mb, vb = b1 * mb + (1 - b1) * db, b2 * vb + (1 - b2) * db ** 2 + mWh, vWh = mW / (1 - b1 ** t), vW / (1 - b2 ** t) + mbh, vbh = mb / (1 - b1 ** t), vb / (1 - b2 ** t) + W, b = params + W = W - learning_rate * mWh / (jnp.sqrt(vWh) + eps) + b = b - learning_rate * mbh / (jnp.sqrt(vbh) + eps) + params = (W, b) + loss = MultinomialLogisticRegression._loss(params, X, Y, w) + return (params, (mW, mb), (vW, vb)), loss + + ts = jnp.arange(1, num_epochs + 1, dtype=float) + (params, _, _), loss_history = jax.lax.scan(step, (params, zeros, zeros), ts) + return params, loss_history + + def fit(self, X, y, sample_weight=None, init_params=None): + X = self._prep(X) + y = jnp.asarray(y) + K = self.n_classes if self.n_classes is not None else int(y.max()) + 1 + self.n_classes_ = K + Y = jax.nn.one_hot(y, K) + if sample_weight is None: + w = jnp.ones(X.shape[0]) + else: + w = jnp.asarray(sample_weight, dtype=float) + if init_params is None: + params = (jnp.zeros((X.shape[1], K)), jnp.zeros(K)) + else: + W, b = init_params + W = jnp.asarray(W, dtype=float) + b = jnp.asarray(b, dtype=float) + if W.shape != (X.shape[1], K): + raise ValueError( + f"init_params W has shape {W.shape}, expected ({X.shape[1]}, {K})." + ) + params = (W, b) + self.params, self.loss_history_ = self._run( + params, X, Y, w, self.num_epochs, self.learning_rate + ) + return self + + def predict(self, X): + if self.params is None: + raise RuntimeError("Model is not fitted yet; call `fit` first.") + return self._predict(self.params, self._prep(X)) From 57b9ca884625cacf38f465bff7578c1ce6db9dd7 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 20:14:24 +0200 Subject: [PATCH 08/14] update the glum test --- tests/test_glum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_glum.py b/tests/test_glum.py index 48a0b48..76ba434 100644 --- a/tests/test_glum.py +++ b/tests/test_glum.py @@ -92,7 +92,7 @@ def test_glum_summary_is_printable_and_consistent(): smry = model.summary() assert str(smry) # renders without error - coef_col = model.summary2().tables[1]["Coef."].to_list() + coef_col = model.summary().tables[1]["Coef."].to_list() assert coef_col == approx(list(model.params), rel=1e-9, abs=1e-9) From c5b35ca33b10c716f97e143766daa21f7be5a3b7 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 21:46:42 +0200 Subject: [PATCH 09/14] make a dev optional --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7586916..c9f9469 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,10 @@ output = [ gpu = [ "jax" ] +dev = [ + "pySEQTarget[output,gpu]", + "pytest", +] [project.urls] Homepage = "https://github.com/CausalInference/pySEQTarget" From cd4d5a72bc5b1992808e595cf26f94802773c3e0 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 21:47:11 +0200 Subject: [PATCH 10/14] clone _glum_fit but with jax --- pySEQTarget/helpers/_jax_fit.py | 180 ++++++++++++++++++++++++++++++++ 1 file changed, 180 insertions(+) create mode 100644 pySEQTarget/helpers/_jax_fit.py diff --git a/pySEQTarget/helpers/_jax_fit.py b/pySEQTarget/helpers/_jax_fit.py new file mode 100644 index 0000000..d54fdb3 --- /dev/null +++ b/pySEQTarget/helpers/_jax_fit.py @@ -0,0 +1,180 @@ +import types + +import numpy as np +import pandas as pd +import patsy +import polars as pl + +from ._jax_logistic import MultinomialLogisticRegression + + +class _JaxFit: + def __init__( + self, + formula, + df, + var_weights=None, + learning_rate=0.1, + num_epochs=2000, + start_params=None, + ): + df_pd = df.to_pandas() if isinstance(df, pl.DataFrame) else df + y_mat, X_mat = patsy.dmatrices(formula, df_pd, return_type="dataframe") + + design_info = X_mat.design_info + feature_names = list(X_mat.columns) + self._design_info = design_info + self._feature_names = feature_names + self._X_design = X_mat.values + X_arr = X_mat.drop(columns=["Intercept"], errors="ignore").values + + y_raw = y_mat.values.ravel() + self.classes_ = np.unique(y_raw) + self._n_classes = len(self.classes_) + y_idx = np.searchsorted(self.classes_, y_raw) + + self._sample_weight = None + if var_weights is not None: + self._sample_weight = np.asarray(var_weights, dtype=float) + + jax_model = MultinomialLogisticRegression( + learning_rate=learning_rate, + num_epochs=num_epochs, + n_classes=self._n_classes, + ) + + jax_model.mean_ = X_arr.mean(axis=0) + jax_model.std_ = X_arr.std(axis=0) + jax_model.eps # guard constants + self._jax = jax_model + + init_params = self._warm_init(start_params, X_arr.shape[1]) + jax_model.fit( + X_arr, y_idx, sample_weight=self._sample_weight, init_params=init_params + ) + + # statsmodels 'like' exposure + self.model = types.SimpleNamespace( + exog_names=feature_names, + data=types.SimpleNamespace(design_info=design_info), + ) + self.exog_names = feature_names + self.params = self._build_params() + + def _coef_components(self): + W, b = self._jax.params + W = np.asarray(W) + b = np.asarray(b) + coef = W[:, 1:] - W[:, :1] + intercept = b[1:] - b[0] + mean = np.asarray(self._jax.mean_) + std = np.asarray(self._jax.std_) + intercept = intercept - (coef * (mean / std)[:, None]).sum(axis=0) + coef = coef / std[:, None] + + return intercept, coef + + def _build_params(self): + intercept, coef = self._coef_components() + if self._n_classes == 2: + return pd.Series( + np.concatenate([intercept, coef[:, 0]]), index=self._feature_names + ) + data = np.vstack([intercept[None, :], coef]) + return pd.DataFrame( + data, + index=self._feature_names, + columns=[f"class_{c}" for c in self.classes_[1:]], + ) + + def _warm_init(self, start_params, n_features): + if start_params is None or self._n_classes != 2: + return None + sp_values, sp_names = start_params + if list(sp_names) != self._feature_names: + return None + sp_values = np.asarray(sp_values, dtype=float) + intercept = float(sp_values[0]) + coef = sp_values[1:] + if coef.shape[0] != n_features: + return None + W = np.zeros((n_features, 2)) + b = np.zeros(2) + mean = np.asarray(self._jax.mean_) + std = np.asarray(self._jax.std_) + W[:, 1] = coef * std + b[1] = intercept + float(np.sum(coef * mean)) + + return (W, b) + + def predict(self, data, transform=True): + if transform: + data_pd = data.to_pandas() if isinstance(data, pl.DataFrame) else data + X = patsy.build_design_matrices( + [self._design_info], data_pd, return_type="dataframe" + )[0] + X_arr = X.drop(columns=["Intercept"], errors="ignore").values + else: + X_arr = np.asarray(data)[:, 1:] + probs = np.asarray(self._jax.predict(X_arr), dtype=np.float64) + return probs[:, 1] if self._n_classes == 2 else probs + + def cov_params(self): + if self._n_classes != 2: + raise NotImplementedError( + "Standard errors are only implemented for binary jax fits." + ) + X = self._X_design + mu = np.asarray(self._jax.predict(X[:, 1:]))[:, 1] + w = mu * (1.0 - mu) + if self._sample_weight is not None: + w = w * self._sample_weight + return np.linalg.pinv(X.T @ (w[:, None] * X)) + + @property + def bse(self): + return pd.Series(np.sqrt(np.diag(self.cov_params())), index=self.params.index) + + def _coef_table(self): + from scipy import stats + + coef = self.params.values + se = self.bse.values + with np.errstate(divide="ignore", invalid="ignore"): + z = coef / se + pvals = 2.0 * stats.norm.sf(np.abs(z)) + crit = stats.norm.ppf(0.975) + return pd.DataFrame( + { + "Coef.": coef, + "Std.Err.": se, + "z": z, + "P>|z|": pvals, + "[0.025": coef - crit * se, + "0.975]": coef + crit * se, + }, + index=list(self.params.index), + ) + + def summary(self): + from statsmodels.iolib.summary2 import Summary + + info = pd.DataFrame( + { + " ": [ + "GLM (jax backend)", + "Binomial", + "logit", + str(self._X_design.shape[0]), + ] + }, + index=["Model:", "Family:", "Link:", "No. Observations:"], + ) + smry = Summary() + smry.add_title("Generalized Linear Model Regression Results") + smry.add_df(info, header=False) + smry.add_df(self._coef_table()) + return smry + + +def _fit_jax(formula, data, var_weights=None, start_params=None): + return _JaxFit(formula, data, var_weights=var_weights, start_params=start_params) From e0c7973df95d39e45f5f99b26a35063b5becdb5f Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 21:51:09 +0200 Subject: [PATCH 11/14] tests --- tests/test_jax.py | 118 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 tests/test_jax.py diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 0000000..47faf9b --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,118 @@ +import numpy as np +import pandas as pd +import pytest +from pytest import approx + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data +from pySEQTarget.helpers._jax_fit import _JaxFit + + +def _fit(method, glm_package, dataset="SEQdata", **opts): + data = load_data(dataset) + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method=method, + parameters=SEQopts(glm_package=glm_package, **opts), + ) + s.expand() + s.fit() + return s + + +def _outcome_coefs(s): + return list(s.outcome_model[0]["outcome"].params) + + +def test_jax_matches_statsmodels_ITT(): + sm = _outcome_coefs(_fit("ITT", "statsmodels")) + jx = _outcome_coefs(_fit("ITT", "jax")) + assert jx == approx(sm, rel=1e-2, abs=2e-3) + + +def test_jax_matches_statsmodels_censoring_preexpansion(): + opts = dict(weighted=True, weight_preexpansion=True) + sm = _outcome_coefs(_fit("censoring", "statsmodels", **opts)) + jx = _outcome_coefs(_fit("censoring", "jax", **opts)) + assert jx == approx(sm, rel=1e-2, abs=2e-3) + + +def test_jax_matches_statsmodels_censoring_postexpansion(): + opts = dict(weighted=True, weight_preexpansion=False) + sm = _outcome_coefs(_fit("censoring", "statsmodels", **opts)) + jx = _outcome_coefs(_fit("censoring", "jax", **opts)) + assert jx == approx(sm, rel=1e-2, abs=2e-3) + + +def test_jax_standard_errors_match_statsmodels(): + sm_model = _fit("ITT", "statsmodels").outcome_model[0]["outcome"] + jx_model = _fit("ITT", "jax").outcome_model[0]["outcome"] + assert list(jx_model.bse) == approx(list(sm_model.bse), rel=1e-2, abs=1e-3) + + +def test_jax_bootstrap_survival_matches_statsmodels(): + common = dict(bootstrap_nboot=2, seed=1636, km_curves=True) + + def risk_diff(pkg): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(glm_package=pkg, **common), + ) + s.expand() + s.bootstrap() + s.fit() + s.survival() + rd = s.risk_estimates["risk_difference"] + assert rd["RD 95% LCI"].null_count() == 0 + return rd["Risk Difference"].to_list() + + assert risk_diff("jax") == approx(risk_diff("statsmodels"), rel=1e-2, abs=2e-3) + + +def _binary_frame(seed=1, n=2000): + rng = np.random.default_rng(seed) + x1, x2 = rng.normal(size=n), rng.normal(size=n) + eta = -1.0 + 0.8 * x1 - 0.5 * x2 + y = rng.binomial(1, 1 / (1 + np.exp(-eta))) + return pd.DataFrame({"y": y, "x1": x1, "x2": x2}) + + +def test_jax_multinomial_label_detection(): + rng = np.random.default_rng(0) + n = 900 + x = rng.normal(size=n) + df = pd.DataFrame({"y": rng.integers(0, 3, size=n), "x": x}) + + model = _JaxFit("y ~ x", df, num_epochs=300) + assert list(model.classes_) == [0, 1, 2] + + probs = model.predict(df) + assert probs.shape == (n, 3) + assert probs.sum(axis=1) == approx(np.ones(n), abs=1e-5) + + +def test_jax_warm_start_reaches_same_optimum(): + df = _binary_frame() + cold = _JaxFit("y ~ x1 + x2", df) + warm = _JaxFit( + "y ~ x1 + x2", + df, + start_params=(cold.params.values, list(cold.model.exog_names)), + ) + assert list(warm.params) == approx(list(cold.params), rel=1e-3, abs=1e-3) From f2e06581cce514c680a682f287418ed4caf0245d Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 22:05:46 +0200 Subject: [PATCH 12/14] Update pyproject.toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c9f9469..41c78d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,8 @@ output = [ "tabulate" ] gpu = [ - "jax" + "jax[cuda12]; sys_platform == 'linux'", + "jax; sys_platform != 'linux'", ] dev = [ "pySEQTarget[output,gpu]", From c0e2d3513292dc3b0022dfaa3cece50278ed9c5d Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 23:10:33 +0200 Subject: [PATCH 13/14] pytree implementation working? + test --- pySEQTarget/helpers/_jax_fit.py | 21 ++++----- pySEQTarget/helpers/_jax_logistic.py | 69 ++++++++++++++-------------- tests/test_jax.py | 2 +- 3 files changed, 45 insertions(+), 47 deletions(-) diff --git a/pySEQTarget/helpers/_jax_fit.py b/pySEQTarget/helpers/_jax_fit.py index d54fdb3..f1dfe55 100644 --- a/pySEQTarget/helpers/_jax_fit.py +++ b/pySEQTarget/helpers/_jax_fit.py @@ -14,8 +14,7 @@ def __init__( formula, df, var_weights=None, - learning_rate=0.1, - num_epochs=2000, + max_iter=25, start_params=None, ): df_pd = df.to_pandas() if isinstance(df, pl.DataFrame) else df @@ -38,8 +37,7 @@ def __init__( self._sample_weight = np.asarray(var_weights, dtype=float) jax_model = MultinomialLogisticRegression( - learning_rate=learning_rate, - num_epochs=num_epochs, + max_iter=max_iter, n_classes=self._n_classes, ) @@ -62,10 +60,8 @@ def __init__( def _coef_components(self): W, b = self._jax.params - W = np.asarray(W) - b = np.asarray(b) - coef = W[:, 1:] - W[:, :1] - intercept = b[1:] - b[0] + coef = np.asarray(W) + intercept = np.asarray(b) mean = np.asarray(self._jax.mean_) std = np.asarray(self._jax.std_) intercept = intercept - (coef * (mean / std)[:, None]).sum(axis=0) @@ -97,12 +93,13 @@ def _warm_init(self, start_params, n_features): coef = sp_values[1:] if coef.shape[0] != n_features: return None - W = np.zeros((n_features, 2)) - b = np.zeros(2) + + W = np.zeros((n_features, 1)) + b = np.zeros(1) mean = np.asarray(self._jax.mean_) std = np.asarray(self._jax.std_) - W[:, 1] = coef * std - b[1] = intercept + float(np.sum(coef * mean)) + W[:, 0] = coef * std + b[0] = intercept + float(np.sum(coef * mean)) return (W, b) diff --git a/pySEQTarget/helpers/_jax_logistic.py b/pySEQTarget/helpers/_jax_logistic.py index f2315d0..ce8ebb6 100644 --- a/pySEQTarget/helpers/_jax_logistic.py +++ b/pySEQTarget/helpers/_jax_logistic.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp +from jax.flatten_util import ravel_pytree class MultinomialLogisticRegression: @@ -9,9 +10,9 @@ class MultinomialLogisticRegression: Multinomial logistic regression in JAX. """ - def __init__(self, learning_rate=0.1, num_epochs=2000, eps=1e-7, n_classes=None): - self.learning_rate = learning_rate - self.num_epochs = num_epochs + def __init__(self, max_iter=25, ridge=1e-8, eps=1e-7, n_classes=None): + self.max_iter = max_iter + self.ridge = ridge # tiny Hessian jitter for numerical stability self.eps = eps self.n_classes = n_classes self.params = None @@ -27,44 +28,44 @@ def _prep(self, X): return X @staticmethod - def _predict(params, X): + def _logits(params, X): W, b = params - return jax.nn.softmax(jnp.dot(X, W) + b, axis=-1) + z = jnp.dot(X, W) + b + return jnp.concatenate([jnp.zeros((z.shape[0], 1)), z], axis=1) + + @staticmethod + def _predict(params, X): + return jax.nn.softmax(MultinomialLogisticRegression._logits(params, X), axis=-1) @staticmethod @jax.jit def _loss(params, X, Y, w): - """Sample-weighted mean softmax cross-entropy; ``w`` is ``(n_samples,)``.""" - W, b = params - logp = jax.nn.log_softmax(jnp.dot(X, W) + b, axis=-1) + logp = jax.nn.log_softmax( + MultinomialLogisticRegression._logits(params, X), axis=-1 + ) ce = -jnp.sum(Y * logp, axis=-1) return jnp.sum(w * ce) / jnp.sum(w) @staticmethod - @partial(jax.jit, static_argnums=(4, 5)) - def _run(params, X, Y, w, num_epochs, learning_rate): - b1, b2, eps = 0.9, 0.999, 1e-8 - W0, b0 = params - zeros = (jnp.zeros_like(W0), jnp.zeros_like(b0)) + @partial(jax.jit, static_argnums=(4,)) + def _run(params, X, Y, w, max_iter, ridge): + flat0, unravel = ravel_pytree(params) + eye = jnp.eye(flat0.shape[0]) - def step(carry, t): - # ADAM - params, (mW, mb), (vW, vb) = carry - dW, db = jax.grad(MultinomialLogisticRegression._loss)(params, X, Y, w) - mW, vW = b1 * mW + (1 - b1) * dW, b2 * vW + (1 - b2) * dW ** 2 - mb, vb = b1 * mb + (1 - b1) * db, b2 * vb + (1 - b2) * db ** 2 - mWh, vWh = mW / (1 - b1 ** t), vW / (1 - b2 ** t) - mbh, vbh = mb / (1 - b1 ** t), vb / (1 - b2 ** t) - W, b = params - W = W - learning_rate * mWh / (jnp.sqrt(vWh) + eps) - b = b - learning_rate * mbh / (jnp.sqrt(vbh) + eps) - params = (W, b) - loss = MultinomialLogisticRegression._loss(params, X, Y, w) - return (params, (mW, mb), (vW, vb)), loss + def loss_flat(f): + return MultinomialLogisticRegression._loss(unravel(f), X, Y, w) + + grad_fn = jax.grad(loss_flat) + hess_fn = jax.hessian(loss_flat) + + def step(f, _): + g = grad_fn(f) + H = hess_fn(f) + f = f - jnp.linalg.solve(H + ridge * eye, g) + return f, loss_flat(f) - ts = jnp.arange(1, num_epochs + 1, dtype=float) - (params, _, _), loss_history = jax.lax.scan(step, (params, zeros, zeros), ts) - return params, loss_history + flat, loss_history = jax.lax.scan(step, flat0, xs=None, length=max_iter) + return unravel(flat), loss_history def fit(self, X, y, sample_weight=None, init_params=None): X = self._prep(X) @@ -77,18 +78,18 @@ def fit(self, X, y, sample_weight=None, init_params=None): else: w = jnp.asarray(sample_weight, dtype=float) if init_params is None: - params = (jnp.zeros((X.shape[1], K)), jnp.zeros(K)) + params = (jnp.zeros((X.shape[1], K - 1)), jnp.zeros(K - 1)) else: W, b = init_params W = jnp.asarray(W, dtype=float) b = jnp.asarray(b, dtype=float) - if W.shape != (X.shape[1], K): + if W.shape != (X.shape[1], K - 1): raise ValueError( - f"init_params W has shape {W.shape}, expected ({X.shape[1]}, {K})." + f"init_params W has shape {W.shape}, expected ({X.shape[1]}, {K - 1})." ) params = (W, b) self.params, self.loss_history_ = self._run( - params, X, Y, w, self.num_epochs, self.learning_rate + params, X, Y, w, self.max_iter, self.ridge ) return self diff --git a/tests/test_jax.py b/tests/test_jax.py index 47faf9b..98acfdd 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -99,7 +99,7 @@ def test_jax_multinomial_label_detection(): x = rng.normal(size=n) df = pd.DataFrame({"y": rng.integers(0, 3, size=n), "x": x}) - model = _JaxFit("y ~ x", df, num_epochs=300) + model = _JaxFit("y ~ x", df) assert list(model.classes_) == [0, 1, 2] probs = model.predict(df) From b855e590742a25a41ad39878c9391ba75b41dc8d Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Mon, 8 Jun 2026 23:17:55 +0200 Subject: [PATCH 14/14] update to install devel options --- .github/workflows/python-app.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 9f8c443..00d6d36 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -30,7 +30,7 @@ jobs: run: | uv venv --python ${{ matrix.python-version }} uv pip install flake8 pytest pytest-cov - uv pip install -e . + uv pip install -e ".[dev]" - name: Lint with flake8 run: |