diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 00d6d36..cdf2d43 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -12,16 +12,18 @@ permissions: jobs: test: - runs-on: macos-26 + runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: + os: [macos-26, ubuntu-latest] python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v6 - name: Install uv - uses: astral-sh/setup-uv@v8.1.0 + uses: astral-sh/setup-uv@v8.2.0 - name: Set up Python ${{ matrix.python-version }} run: uv python install ${{ matrix.python-version }} @@ -42,7 +44,7 @@ jobs: uv run pytest tests/ -v --cov=pySEQTarget --cov-report=xml - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v6 + uses: codecov/codecov-action@v7 with: token: ${{ secrets.CODECOV_TOKEN }} slug: CausalInference/pySEQTarget diff --git a/docs/_static/custom.css b/docs/_static/custom.css index 23a7d8b..394c984 100644 --- a/docs/_static/custom.css +++ b/docs/_static/custom.css @@ -6,3 +6,12 @@ div#top_nav nav { padding: 0.7rem 1rem; } + +/* The navbar height is mostly content-driven (h1 + search box), so the padding + override alone leaves it tall. Shrink the title so the row collapses to the + intended height. */ +div#top_nav nav h1 { + margin: 0; + font-size: 1rem; + line-height: 1.2; +} diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 190e60d..111e92b 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -23,6 +23,10 @@ _fit_visit, _offload_weights, _weight_bind, _weight_predict, _weight_setup, _weight_stats) +# Default seed used when the user supplies none, so an unseeded run is +# deterministic across processes (matching SEQTaRget's capture of .Random.seed). +_DEFAULT_SEED = 0 + class SEQuential: """ @@ -72,9 +76,18 @@ def __init__( for name, value in asdict(parameters).items(): setattr(self, name, value) - self._rng = ( - np.random.RandomState(self.seed) if self.seed is not None else np.random - ) + # Mirror SEQTaRget (R): always pin a concrete seed so the Monte-Carlo + # hazard simulation is reseeded before each run and is reproducible. + # R captures .Random.seed when none is given, which is fixed in a fresh + # process, so an unseeded R run is deterministic across runs. We match + # that with a fixed default seed rather than falling back to the global, + # never-reseeded np.random — which let hazard estimates change silently + # between otherwise identical runs. + if self.seed is None: + self.seed = _DEFAULT_SEED + if self.verbose: + print(f"No seed supplied; using default seed {self.seed}") + self._rng = np.random.RandomState(self.seed) self._offloader = Offloader(enabled=self.offload, dir=self.offload_dir) @@ -154,10 +167,6 @@ def expand(self): self.cense_denominator, ] ).union(kept), - ).with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)) - - self.data = self.data.with_columns( - pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col) ) if self.verbose: @@ -241,25 +250,30 @@ def fit(self) -> None: boot_idx = self._current_boot_idx if self.weighted: - WDT = _weight_setup(self) + WDT_pl = _weight_setup(self) if not self.weight_preexpansion and not self.excused: - WDT = WDT.filter(pl.col("followup") > 0) - - WDT = WDT.to_pandas() + WDT_pl = WDT_pl.filter(pl.col("followup") > 0) + + # The weight-fit helpers (_fit_LTFU etc.) use pandas-style indexing + # and pass pandas frames to glum/statsmodels, so we convert once. + # The fits don't mutate WDT_pd - they store models on `self` - so + # we keep the original polars frame for the downstream steps + # rather than paying a pl.from_pandas() round-trip per replicate. + WDT_pd = WDT_pl.to_pandas() for col in self.fixed_cols: - if col in WDT.columns: - WDT[col] = WDT[col].astype("category") + if col in WDT_pd.columns: + WDT_pd[col] = WDT_pd[col].astype("category") - _fit_LTFU(self, WDT) - _fit_visit(self, WDT) - _fit_numerator(self, WDT) - _fit_denominator(self, WDT) + _fit_LTFU(self, WDT_pd) + _fit_visit(self, WDT_pd) + _fit_numerator(self, WDT_pd) + _fit_denominator(self, WDT_pd) if self.offload: _offload_weights(self, boot_idx) - WDT = pl.from_pandas(WDT) - WDT = _weight_predict(self, WDT) + del WDT_pd + WDT = _weight_predict(self, WDT_pl) _weight_bind(self, WDT) self.weight_stats = _weight_stats(self) diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index b4e7981..9c13656 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -130,7 +130,22 @@ def _outcome_fit( case "glum": from ..helpers._glum_fit import _fit_glum - return _fit_glum(full_formula, df_pd, var_weights=var_weights) + # Per-instance cache of patsy design infos, keyed by formula. The + # main fit populates it; bootstrap replicates skip the formula + # parse and reuse the cached column structure. Cleared on the + # main-fit pass so a second fit() call with a different formula + # doesn't hit stale entries. + if getattr(self, "_current_boot_idx", None) is None: + self._patsy_design_cache = {} + cache = self.__dict__.setdefault("_patsy_design_cache", {}) + + return _fit_glum( + full_formula, + df_pd, + var_weights=var_weights, + start_params=start_params, + design_cache=cache, + ) case "jax": from ..helpers._jax_fit import _fit_jax diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index fcacad4..14b0b94 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -59,6 +59,24 @@ def _param_checker(self): "For weighted ITT analyses, cense_colname or visit_colname must be provided." ) + if ( + self.weighted + and self.method != "ITT" + and self.numerator is not None + and self.denominator is not None + and self.numerator == self.denominator + ): + warnings.warn( + f"Numerator and denominator weight models use identical " + f"covariates ('{self.numerator}'); the stabilized weights " + "will all equal 1 (i.e., no weighting). The denominator " + "should typically include the time-varying confounders " + "that the numerator omits — check for a typo in either or " + "both of 'numerator' and 'denominator'.", + UserWarning, + stacklevel=2, + ) + if self.excused: _, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames) _, self.weight_eligible_colnames = _pad( diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index a8dd844..733be5e 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -18,18 +18,34 @@ def _prepare_boot_data(self, data, boot_id): {self.id_col: list(id_counts.keys()), "count": list(id_counts.values())} ) + # Build a per-row unique ID for the resampled frame. With an integer-typed + # ID column do this in integer arithmetic (orig_id * id_mult + replicate) + # rather than the old string-concat form ("{orig_id}_{replicate}"), since + # string IDs make every downstream join/groupby ~3-5x slower. id_mult is + # the max count seen this iteration plus one, so the (orig_id, replicate) + # pair maps to a unique int64 with room. Falls back to string concat for + # non-integer ID columns (e.g. user-supplied string IDs). + id_is_int = data.schema[self.id_col].is_integer() + if id_is_int: + id_mult = (max(id_counts.values()) if id_counts else 1) + 1 + # _weight_bind recovers orig_id with `id // id_mult` to join the + # bootstrap-resampled self.DT back to the un-resampled WDT, so the + # multiplier has to be discoverable downstream. + self._boot_id_mult = id_mult + new_id = ( + pl.col(self.id_col).cast(pl.Int64) * id_mult + pl.col("replicate") + ).alias(self.id_col) + else: + new_id = ( + pl.col(self.id_col).cast(pl.Utf8) + "_" + pl.col("replicate").cast(pl.Utf8) + ).alias(self.id_col) + bootstrapped = ( data.lazy() .join(counts.lazy(), on=self.id_col, how="inner") .with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate")) .explode("replicate") - .with_columns( - ( - pl.col(self.id_col).cast(pl.Utf8) - + "_" - + pl.col("replicate").cast(pl.Utf8) - ).alias(self.id_col) - ) + .with_columns(new_id) .drop("count", "replicate") .collect() ) diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py index b375194..b9f2a73 100644 --- a/pySEQTarget/helpers/_glum_fit.py +++ b/pySEQTarget/helpers/_glum_fit.py @@ -5,6 +5,26 @@ import patsy from glum import GeneralizedLinearRegressor +from ._fix_categories import _fix_categories_for_predict + + +def _align_categories(design_info, data): + """ + Re-align ``data``'s categorical columns to the level set and ORDER frozen + in ``design_info``. Wraps ``_fix_categories_for_predict`` (which expects a + model-like object) so the cached design info can be re-applied to a + bootstrap resample whose categoricals materialised in a different order. + """ + + class _Stub: + class model: + class data: + pass + + stub = _Stub() + stub.model.data.design_info = design_info + return _fix_categories_for_predict(stub, data) + class _GlumFit: """ @@ -102,16 +122,61 @@ def summary(self): return smry -def _fit_glum(formula, data, var_weights=None): - """Fit a binomial GLM with glum and return a _GlumFit wrapper.""" - y_mat, X_mat = patsy.dmatrices(formula, data, return_type="dataframe") +def _fit_glum(formula, data, var_weights=None, start_params=None, design_cache=None): + """Fit a binomial GLM with glum and return a _GlumFit wrapper. + + ``start_params`` is the cached ``(values, names)`` tuple from a previous fit, + used as a warm-start in the bootstrap loop. It is only honoured when the + design matrix columns line up exactly with ``names`` - a bootstrap resample + can drop a categorical level and shift the column structure, in which case + the cached coefs are meaningless and using them as init would derail the + coordinate-descent solver. + + ``design_cache`` is an optional ``dict`` keyed by ``formula``. On a hit, the + formula parse and patsy model.frame construction are skipped and the cached + ``(y_design_info, X_design_info)`` are re-applied to ``data`` via + ``patsy.build_design_matrices``. On a miss, ``patsy.dmatrices`` parses the + formula and the result is stored. Caching freezes the categorical encoding + to the main fit's column structure, which also makes the warm-start + guarantee trivially satisfied for every replicate. + """ + if design_cache is not None and formula in design_cache: + y_dinfo, x_dinfo = design_cache[formula] + try: + y_mat, X_mat = patsy.build_design_matrices( + [y_dinfo, x_dinfo], data, return_type="dataframe" + ) + except patsy.PatsyError as e: + if "mismatching levels" not in str(e): + raise + # A bootstrap resample can realise the same categorical levels in a + # different ORDER than the cached design_info froze. Re-align the + # categories to the cached structure and retry, so the cached column + # layout (and the warm-start that relies on it) stays valid. + data = _align_categories(x_dinfo, data.copy()) + y_mat, X_mat = patsy.build_design_matrices( + [y_dinfo, x_dinfo], data, return_type="dataframe" + ) + else: + y_mat, X_mat = patsy.dmatrices(formula, data, return_type="dataframe") + if design_cache is not None: + design_cache[formula] = (y_mat.design_info, X_mat.design_info) + y_arr = y_mat.values.ravel() design_info = X_mat.design_info feature_names = list(X_mat.columns) # "Intercept" first, then predictors X_design = X_mat.values # includes intercept column (for covariance) X_arr = X_mat.drop(columns=["Intercept"]).values - glm = GeneralizedLinearRegressor(family="binomial", fit_intercept=True) + init = None + if start_params is not None: + sp_values, sp_names = start_params + if list(sp_names) == feature_names: + init = np.asarray(sp_values, dtype=float) + + glm = GeneralizedLinearRegressor( + family="binomial", fit_intercept=True, start_params=init + ) sample_weight = None fit_kwargs = {} diff --git a/pySEQTarget/weighting/_weight_bind.py b/pySEQTarget/weighting/_weight_bind.py index 307c426..d159af5 100644 --- a/pySEQTarget/weighting/_weight_bind.py +++ b/pySEQTarget/weighting/_weight_bind.py @@ -6,9 +6,21 @@ def _weight_bind(self, WDT): join = "inner" on = [self.id_col, "period"] WDT = WDT.rename({self.time_col: "period"}) - self.DT = self.DT.with_columns( - pl.col(self.id_col).str.replace(r"_\d+$", "").alias(self.id_col) - ) + # On a bootstrap pass _prepare_boot_data transformed id_col so that + # each replicate has a unique value -- integer math (orig_id * id_mult + # + replicate) for int IDs, "{orig_id}_{replicate}" for string IDs. + # Recover the original ID here so the join to WDT (which still carries + # un-resampled originals) lines up. No-op on the main fit pass. + is_boot = getattr(self, "_current_boot_idx", None) is not None + if is_boot: + if self.DT.schema[self.id_col].is_integer(): + self.DT = self.DT.with_columns( + (pl.col(self.id_col) // self._boot_id_mult).alias(self.id_col) + ) + else: + self.DT = self.DT.with_columns( + pl.col(self.id_col).str.replace(r"_\d+$", "").alias(self.id_col) + ) else: join = "left" on = [self.id_col, "trial", "followup"] diff --git a/pyproject.toml b/pyproject.toml index 41c78d3..046f356 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.13.6" +version = "0.13.7" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} diff --git a/tests/test_bootstrap_ids.py b/tests/test_bootstrap_ids.py new file mode 100644 index 0000000..a84d5cd --- /dev/null +++ b/tests/test_bootstrap_ids.py @@ -0,0 +1,77 @@ +"""Behavioural tests for the integer-ID bootstrap path.""" + +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data +from pySEQTarget.helpers._bootstrap import _prepare_boot_data + + +def _build(**opts): + opts.setdefault("seed", 42) + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(**opts), + ) + s.expand() + return s + + +def test_expand_preserves_int_id_dtype(): + # SEQdata's ID is Int64. expand() must not coerce it to Utf8 - string IDs + # make every downstream join/groupby ~3-5x slower than the int path. + s = _build() + assert s.DT.schema["ID"] == pl.Int64 + assert s.data.schema["ID"] == pl.Int64 + + +def test_bootstrap_id_uses_integer_arithmetic_for_int_ids(): + # Resampled IDs are built as orig_id * id_mult + replicate. Each + # (orig_id, replicate) pair has to map to a unique new ID and the original + # ID must be recoverable via integer division. + s = _build(bootstrap_nboot=3) + s.bootstrap() + + boot = _prepare_boot_data(s, s.DT, boot_id=0) + assert boot.schema["ID"] == pl.Int64 + id_mult = s._boot_id_mult + assert id_mult >= 2 + + orig_ids = set(s.DT["ID"].to_list()) + # Recovered IDs are all from the original set + recovered = boot["ID"] // id_mult + assert set(recovered.to_list()) <= orig_ids + # The replicate component is bounded by id_mult, so the (orig, rep) pair is + # uniquely encoded + rep = boot["ID"] - recovered * id_mult + assert rep.min() >= 0 + assert rep.max() < id_mult + + +def test_bootstrap_id_falls_back_to_string_concat_for_non_int_ids(): + # User-supplied non-integer IDs still work via the original "{id}_{rep}" + # string-concat path. Build a String-keyed DT manually since pySEQTarget no + # longer casts to Utf8. + s = _build(bootstrap_nboot=3) + s.bootstrap() + # Coerce id_col to Utf8 in both the DT and the boot_samples Counter keys so + # the join lines up + s.DT = s.DT.with_columns(pl.col("ID").cast(pl.Utf8)) + from collections import Counter + + s._boot_samples = [ + Counter({str(k): v for k, v in c.items()}) for c in s._boot_samples + ] + + boot = _prepare_boot_data(s, s.DT, boot_id=0) + assert boot.schema["ID"] == pl.Utf8 + # IDs follow the "{orig}_{rep}" pattern + assert all("_" in v for v in boot["ID"].unique().to_list()) diff --git a/tests/test_followup_class.py b/tests/test_followup_class.py new file mode 100644 index 0000000..c59182d --- /dev/null +++ b/tests/test_followup_class.py @@ -0,0 +1,36 @@ +import re + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def test_followup_class_encodes_followup_as_a_factor(): + # followup_class=True makes the outcome model treat follow-up as categorical + # (cast to category in _cast_categories), so it gains one patsy dummy + # 'followup[T.]' per non-reference follow-up level and loses the linear + # followup / followup_sq pair. It is exclusive with followup_include, so that + # is switched off here. + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(followup_class=True, followup_include=False), + ) + s.expand() + s.fit() + names = list(s.outcome_model[0]["outcome"].params.index) + + # Categorical, not continuous: no linear follow-up terms + assert "followup" not in names + assert "followup_sq" not in names + + # One dummy per non-reference follow-up level + dummies = [n for n in names if re.fullmatch(r"followup\[T\.\d+\]", n)] + assert len(dummies) > 2 + assert len(dummies) == s.DT["followup"].n_unique() - 1 diff --git a/tests/test_followup_min_max.py b/tests/test_followup_min_max.py new file mode 100644 index 0000000..172a211 --- /dev/null +++ b/tests/test_followup_min_max.py @@ -0,0 +1,35 @@ +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _expand(**opts): + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(expand_only=True, **opts), + ) + s.expand() + return s.DT + + +def test_followup_min_max_restrict_the_expanded_followup_range(): + # Expansion filters rows to followup in [followup_min, followup_max]. With + # expand_only=True the DT is returned without any fit step, so the clamp is + # directly visible. + full = _expand() + lim = _expand(followup_min=3, followup_max=10) + + # Unrestricted expansion genuinely extends past the requested window + assert full["followup"].min() < 3 + assert full["followup"].max() > 10 + # Restricted expansion is clamped to exactly [3, 10] and has fewer rows + assert lim["followup"].min() == 3 + assert lim["followup"].max() == 10 + assert lim.height < full.height diff --git a/tests/test_followup_trial_include.py b/tests/test_followup_trial_include.py new file mode 100644 index 0000000..3a631db --- /dev/null +++ b/tests/test_followup_trial_include.py @@ -0,0 +1,44 @@ +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _coef_names(**opts): + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(**opts), + ) + s.expand() + s.fit() + return set(s.outcome_model[0]["outcome"].params.index) + + +def test_followup_include_and_trial_include_add_or_drop_outcome_terms(): + # These flags control whether the follow-up and trial terms (and their + # squares) enter the outcome-model formula, so the effect is visible in the + # fitted coefficient names. + both = _coef_names() + no_fup = _coef_names(followup_include=False) + no_trial = _coef_names(trial_include=False) + + fup_terms = {"followup", "followup_sq"} + trial_terms = {"trial", "trial_sq"} + + # Baseline: all four terms present + assert fup_terms <= both + assert trial_terms <= both + + # followup_include=False drops the follow-up terms but keeps the trial terms + assert fup_terms.isdisjoint(no_fup) + assert trial_terms <= no_fup + + # trial_include=False drops the trial terms but keeps the follow-up terms + assert trial_terms.isdisjoint(no_trial) + assert fup_terms <= no_trial diff --git a/tests/test_glum.py b/tests/test_glum.py index 76ba434..cb3b3c5 100644 --- a/tests/test_glum.py +++ b/tests/test_glum.py @@ -141,3 +141,246 @@ def risk_diff(pkg): return rd["Risk Difference"].to_list() assert risk_diff("glum") == approx(risk_diff("statsmodels"), rel=1e-2, abs=2e-3) + + +def _run_bootstrap_glum_outcome_coefs(monkeypatch, disable_warm_start): + """Run a small bootstrap fit with glum and return each replicate's outcome coefs.""" + import pySEQTarget.helpers._glum_fit as glm_mod + + if disable_warm_start: + original = glm_mod._fit_glum + + def patched(*args, **kwargs): + kwargs["start_params"] = None + return original(*args, **kwargs) + + monkeypatch.setattr(glm_mod, "_fit_glum", patched) + + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(glm_package="glum", bootstrap_nboot=3, seed=42), + ) + s.expand() + s.bootstrap() + s.fit() + return [list(m["outcome"].params.values) for m in s.outcome_model] + + +def test_glum_warm_start_is_wired_through_bootstrap(monkeypatch): + # The main fit gets start_params=None (cold); every bootstrap replicate gets + # the cached (values, names) tuple. Capture start_params on each call to + # _fit_glum and assert the wiring. + import pySEQTarget.helpers._glum_fit as glm_mod + + seen = [] + original = glm_mod._fit_glum + + def spy(*args, **kwargs): + seen.append(kwargs.get("start_params")) + return original(*args, **kwargs) + + monkeypatch.setattr(glm_mod, "_fit_glum", spy) + + _run_bootstrap_glum_outcome_coefs(monkeypatch, disable_warm_start=False) + + assert len(seen) >= 4 # main + 3 bootstrap replicates + assert seen[0] is None # main fit: cold start + for sp in seen[1:]: + assert sp is not None + values, names = sp + assert len(values) == len(names) + assert names[0] == "Intercept" + + +def test_glum_warm_start_matches_cold_start_outcome_coefs(monkeypatch): + # Warm-start is a pure convergence optimisation: it must reach the same + # optimum as a cold start. Compare per-replicate outcome coefficients with + # warm-start enabled (default) vs forcibly disabled. + warm = _run_bootstrap_glum_outcome_coefs(monkeypatch, disable_warm_start=False) + cold = _run_bootstrap_glum_outcome_coefs(monkeypatch, disable_warm_start=True) + + # Both runs reach an optimum within glum's coordinate-descent tolerance, but + # the warm-start path stops at a point a few iterations earlier and so + # differs from the cold-start path within that tolerance. A genuine wiring + # regression would shift coefficients by several percent or more. + assert len(warm) == len(cold) + for w, c in zip(warm, cold): + assert w == approx(c, rel=1e-2, abs=1e-4) + + +def test_glum_design_cache_avoids_reparsing_on_bootstrap(monkeypatch): + # The main fit calls patsy.dmatrices to build the design info; every + # bootstrap replicate should hit the cache and reuse it via + # patsy.build_design_matrices instead of re-parsing the formula. + import patsy + + import pySEQTarget.helpers._glum_fit as glm_mod + + real_dmatrices = patsy.dmatrices + real_build = patsy.build_design_matrices + parsed_formulas = [] + build_calls = [0] + + def spy_dmatrices(formula, *a, **kw): + parsed_formulas.append(formula) + return real_dmatrices(formula, *a, **kw) + + def spy_build(*a, **kw): + build_calls[0] += 1 + return real_build(*a, **kw) + + monkeypatch.setattr(glm_mod.patsy, "dmatrices", spy_dmatrices) + monkeypatch.setattr(glm_mod.patsy, "build_design_matrices", spy_build) + + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(glm_package="glum", bootstrap_nboot=3, seed=42), + ) + s.expand() + s.bootstrap() + s.fit() + + # ITT has no weight models, so the only formula reaching _fit_glum is the + # outcome formula. It must be parsed exactly once (the main fit); the three + # bootstrap replicates must reuse the cached dinfo via build_design_matrices. + outcome_formula = f"{s.outcome_col} ~ {s.covariates}" + assert parsed_formulas.count(outcome_formula) == 1 + assert build_calls[0] >= 3 + # The cache survives onto self for inspection / future replicates. + assert outcome_formula in s._patsy_design_cache + + +def _run_bootstrap_outcome_coefs_with_cache_disabled(monkeypatch): + """Run a small bootstrap fit with the design cache forcibly disabled.""" + import pySEQTarget.helpers._glum_fit as glm_mod + + original = glm_mod._fit_glum + + def patched(*args, **kwargs): + kwargs["design_cache"] = None + return original(*args, **kwargs) + + monkeypatch.setattr(glm_mod, "_fit_glum", patched) + + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(glm_package="glum", bootstrap_nboot=3, seed=42), + ) + s.expand() + s.bootstrap() + s.fit() + return [list(m["outcome"].params.values) for m in s.outcome_model] + + +def test_glum_design_cache_matches_no_cache_outcome_coefs(monkeypatch): + # The design cache freezes the categorical column structure to the main + # fit's columns. On SEQdata both arms appear in every bootstrap resample, so + # a no-cache run encodes the same columns from scratch each time and must + # converge to the same coefficients (within glum's tolerance). + cached = _run_bootstrap_glum_outcome_coefs(monkeypatch, disable_warm_start=False) + no_cache = _run_bootstrap_outcome_coefs_with_cache_disabled(monkeypatch) + + assert len(cached) == len(no_cache) + for c, nc in zip(cached, no_cache): + assert c == approx(nc, rel=1e-2, abs=1e-4) + + +def test_glum_design_cache_handles_categorical_level_reordering(): + # A bootstrap resample can realise the same categorical levels in a + # different ORDER than the cached design_info froze (e.g. polars->pandas + # appearance order on the full data vs sorted order on a resample). The + # cached build_design_matrices path must re-align the categories instead of + # raising "mismatching levels". Regression for the short-course render. + import numpy as np + import pandas as pd + + from pySEQTarget.helpers._glum_fit import _fit_glum + + rng = np.random.default_rng(0) + n = 2000 + levels = ["16-29", "30-39", "40-49", "50+"] + formula = "y ~ age_grp + x" + cache = {} + + # Main fit: categories in a NON-sorted order; this freezes the cache. + main = pd.DataFrame( + { + "age_grp": pd.Categorical( + rng.choice(levels, n), categories=["30-39", "16-29", "40-49", "50+"] + ), + "x": rng.standard_normal(n), + "y": (rng.random(n) < 0.4).astype(int), + } + ) + m = _fit_glum(formula, main, design_cache=cache) + + # Bootstrap: same levels, sorted order — the crash trigger. + boot = pd.DataFrame( + { + "age_grp": pd.Categorical( + rng.choice(levels, n), categories=["16-29", "30-39", "40-49", "50+"] + ), + "x": rng.standard_normal(n), + "y": (rng.random(n) < 0.4).astype(int), + } + ) + mb = _fit_glum(formula, boot, design_cache=cache) + + assert np.all(np.isfinite(mb.params.values)) + # The cached column structure is preserved (categories re-aligned, not reparsed). + assert list(mb.params.index) == list(m.params.index) + + +def test_glum_warm_start_dropped_when_design_columns_mismatch(): + # The defensive guard in _fit_glum: a (values, names) tuple whose names + # don't line up with the patsy design matrix must be ignored, falling back + # to the cold-start init and producing the same fit as start_params=None. + import numpy as np + import pandas as pd + + from pySEQTarget.helpers._glum_fit import _fit_glum + + rng = np.random.default_rng(0) + n = 1000 + df = pd.DataFrame( + { + "x1": rng.standard_normal(n), + "x2": rng.standard_normal(n), + "y": (rng.random(n) < 0.4).astype(int), + } + ) + + ref = _fit_glum("y ~ x1 + x2", df) + bogus = (np.zeros(5), ["Intercept", "wrong", "names", "here", "extra"]) + bogus_fit = _fit_glum("y ~ x1 + x2", df, start_params=bogus) + + assert list(bogus_fit.params.values) == approx( + list(ref.params.values), rel=1e-8, abs=1e-12 + ) diff --git a/tests/test_jax.py b/tests/test_jax.py index 98acfdd..7461437 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -3,6 +3,10 @@ import pytest from pytest import approx +# jax is an optional dependency (the ``gpu`` extra) and is not installed on +# every platform — skip the whole module rather than erroring at collection. +pytest.importorskip("jax") + from pySEQTarget import SEQopts, SEQuential from pySEQTarget.data import load_data from pySEQTarget.helpers._jax_fit import _JaxFit diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 4f0969e..c60a42f 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -23,6 +23,57 @@ def _make_seq(seed, **extra_opts): ) +def test_unseeded_run_assigns_stable_concrete_seed(): + # With no seed, the hazard simulation must not fall back to the global, + # never-reseeded np.random. Instead a concrete seed is drawn once, recorded + # on self.seed, and held fixed for the life of the object so the seed is the + # same before and after a run (and can be reported to reproduce it). + s = _make_seq(seed=None, hazard_estimate=True) + + before = s.seed + assert before is not None + assert isinstance(before, int) + assert 0 <= before < 2**32 + assert isinstance(s._rng, np.random.RandomState) + + s.expand() + s.fit() + s.hazard() + + assert s.seed == before + + +def test_two_unseeded_runs_are_deterministic(): + # With no seed supplied, runs use a fixed default seed (mirroring R), so two + # otherwise identical unseeded runs produce the same hazard ratio. + results = [] + for _ in range(2): + s = _make_seq(seed=None, hazard_estimate=True) + s.expand() + s.fit() + s.hazard() + results.append(s.hazard_ratio["Hazard ratio"][0]) + + assert results[0] == results[1] + + +def test_unseeded_captured_seed_reproduces_hazard(): + # The seed recorded on an unseeded run is the one actually used, so feeding + # it back as an explicit seed reproduces the hazard ratio exactly. + s1 = _make_seq(seed=None, hazard_estimate=True) + captured = s1.seed + s1.expand() + s1.fit() + s1.hazard() + + s2 = _make_seq(seed=captured, hazard_estimate=True) + s2.expand() + s2.fit() + s2.hazard() + + assert s1.hazard_ratio["Hazard ratio"][0] == s2.hazard_ratio["Hazard ratio"][0] + + def test_hazard_reproducible_with_seed(): results = [] for _ in range(2): diff --git a/tests/test_selection_random.py b/tests/test_selection_random.py new file mode 100644 index 0000000..6dbe2b8 --- /dev/null +++ b/tests/test_selection_random.py @@ -0,0 +1,56 @@ +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _build(**opts): + opts.setdefault("seed", 1) + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(**opts), + ) + s.expand() + return s + + +def _arm_trial_starts(dt): + """Trial-starts (followup == 0) per baseline-treatment arm.""" + counts = ( + dt.filter(pl.col("followup") == 0) + .group_by("tx_init_bas") + .len() + .sort("tx_init_bas") + ) + return dict(zip(counts["tx_init_bas"].to_list(), counts["len"].to_list())) + + +def test_selection_random_keeps_all_treated_and_subsamples_controls(): + # With selection_random=True, treated trial-starts (tx_init_bas == 1) are + # all retained, while control trial-starts (tx_init_bas == 0) are + # subsampled to int(selection_sample * N_controls). + prob = 0.5 + + base = _build() + sel = _build(selection_random=True, selection_sample=prob) + + base_c = _arm_trial_starts(base.DT) + sel_c = _arm_trial_starts(sel.DT) + + assert sel_c[1] == base_c[1] + assert sel_c[0] < base_c[0] + assert sel_c[0] == int(prob * base_c[0]) + + +def test_selection_random_is_reproducible_with_fixed_seed(): + a = _build(selection_random=True, selection_sample=0.5, seed=7) + b = _build(selection_random=True, selection_sample=0.5, seed=7) + assert a.DT.equals(b.DT) diff --git a/tests/test_warnings.py b/tests/test_warnings.py new file mode 100644 index 0000000..6fd7ef5 --- /dev/null +++ b/tests/test_warnings.py @@ -0,0 +1,64 @@ +import warnings + +import pytest + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _build(**opts): + return SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method=opts.pop("method", "censoring"), + parameters=SEQopts(**opts), + ) + + +def test_warns_when_numerator_and_denominator_are_identical(): + # Identical num/denom -> stabilized weights all 1 -> usually a typo. + formula = "sex" + with pytest.warns(UserWarning, match="identical"): + _build(weighted=True, numerator=formula, denominator=formula) + + +def test_no_warning_when_numerator_and_denominator_differ(): + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + _build(weighted=True, numerator="sex", denominator="sex+N+L+P") + + +def test_no_warning_under_ITT_even_if_identical(): + # ITT doesn't fit treatment-weight models, so the warning is gated on method. + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + # weighted ITT requires LTFU/visit; use LTFU dataset + SEQuential( + load_data("SEQdata_LTFU"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts( + weighted=True, + cense_colname="LTFU", + numerator="sex", + denominator="sex", + ), + ) + + +def test_no_warning_when_not_weighted(): + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + _build(weighted=False, numerator="sex", denominator="sex") diff --git a/tests/test_weight_block_no_roundtrip.py b/tests/test_weight_block_no_roundtrip.py new file mode 100644 index 0000000..75e877b --- /dev/null +++ b/tests/test_weight_block_no_roundtrip.py @@ -0,0 +1,47 @@ +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def test_weight_predict_receives_polars_frame(monkeypatch): + # Guard against re-introducing the pl.from_pandas() round-trip in the + # weighted fit block: _weight_predict must receive the original polars + # frame (not one that was just rebuilt from pandas), since the weight-fit + # helpers store models on `self` and don't mutate WDT. + import importlib + + seq_mod = importlib.import_module("pySEQTarget.SEQuential") + + original = seq_mod._weight_predict + seen_types = [] + + def spy(self, WDT): + seen_types.append(type(WDT)) + return original(self, WDT) + + monkeypatch.setattr(seq_mod, "_weight_predict", spy) + + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_preexpansion=True, + bootstrap_nboot=2, + seed=42, + ), + ) + s.expand() + s.bootstrap() + s.fit() + + assert len(seen_types) >= 1 + assert all(t is pl.DataFrame for t in seen_types) diff --git a/tests/test_weight_eligible_colnames.py b/tests/test_weight_eligible_colnames.py new file mode 100644 index 0000000..cb05a48 --- /dev/null +++ b/tests/test_weight_eligible_colnames.py @@ -0,0 +1,41 @@ +import polars as pl + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _nobs_per_arm(weight_eligible_colnames): + data = load_data("SEQdata") + # Balanced 0/1 eligibility indicator carried through expansion. + median_n = data["N"].median() + data = data.with_columns((pl.col("N") > median_n).cast(pl.Int32).alias("welig")) + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, + weight_eligible_colnames=weight_eligible_colnames, + seed=1, + ), + ) + s.expand() + s.fit() + return [int(m.nobs) for m in s.denominator_model] + + +def test_weight_eligible_colnames_restricts_weight_models_to_eligible_rows(): + # Each arm's weight model is fit only on rows where its + # weight_eligible_colnames indicator == 1 (_get_subset_for_level). With a + # roughly half-on indicator the per-arm denominator nobs drops below the + # unfiltered baseline. + base = _nobs_per_arm(weight_eligible_colnames=[]) + elig = _nobs_per_arm(weight_eligible_colnames=["welig", "welig"]) + + assert all(e < b for e, b in zip(elig, base)) diff --git a/tests/test_weight_lag_condition.py b/tests/test_weight_lag_condition.py new file mode 100644 index 0000000..e225ed9 --- /dev/null +++ b/tests/test_weight_lag_condition.py @@ -0,0 +1,36 @@ +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _nobs_per_arm(weight_lag_condition): + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts( + weighted=True, weight_lag_condition=weight_lag_condition, seed=1 + ), + ) + s.expand() + s.fit() + return [int(m.nobs) for m in s.denominator_model] + + +def test_weight_lag_condition_conditions_each_arm_on_its_treatment_lag_stratum(): + # weight_lag_condition=True (default): each arm's weight model is fit only on + # rows where tx_lag matches that arm (per-arm row counts differ but partition + # the full data). =False: both arms fit on the full data (equal counts). + on = _nobs_per_arm(weight_lag_condition=True) + off = _nobs_per_arm(weight_lag_condition=False) + + # FALSE: both arms see the full data -> equal observation counts + assert off[0] == off[1] + # TRUE: arms fit on disjoint treatment-lag strata that partition that full data + assert on[0] != on[1] + assert on[0] + on[1] == off[0] diff --git a/tests/test_weight_p99.py b/tests/test_weight_p99.py new file mode 100644 index 0000000..be856f2 --- /dev/null +++ b/tests/test_weight_p99.py @@ -0,0 +1,44 @@ +import numpy as np + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _fit(**opts): + opts.setdefault("seed", 1) + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts(weighted=True, **opts), + ) + s.expand() + s.fit() + return s + + +def _coefs(s): + return s.outcome_model[0]["outcome"].params.values + + +def test_weight_p99_truncates_at_p01_p99_percentile_weights(): + # weight_p99=True overrides weight_min/weight_max with the p01/p99 of the + # (untruncated) weights -- these are reported in weight_stats. So it must be + # equivalent to passing those percentile values as explicit bounds, and must + # differ from an untruncated weighted fit. + p99 = _fit(weight_p99=True) + ws = p99.weight_stats + p01_val = float(ws["weight_p01"][0]) + p99_val = float(ws["weight_p99"][0]) + + explicit = _fit(weight_min=p01_val, weight_max=p99_val) + untruncated = _fit() + + assert np.allclose(_coefs(p99), _coefs(explicit), atol=1e-8) + assert not np.allclose(_coefs(p99), _coefs(untruncated), atol=1e-6) diff --git a/tests/test_weight_truncation.py b/tests/test_weight_truncation.py new file mode 100644 index 0000000..5c3f6d8 --- /dev/null +++ b/tests/test_weight_truncation.py @@ -0,0 +1,41 @@ +import numpy as np + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +def _fit(**opts): + opts.setdefault("seed", 1) + s = SEQuential( + load_data("SEQdata"), + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="censoring", + parameters=SEQopts(weighted=True, **opts), + ) + s.expand() + s.fit() + return s.outcome_model[0]["outcome"].params.values + + +def test_weight_min_max_truncate_the_weights_used_in_the_outcome_fit(): + # Truncation is applied to the weight vector passed to the outcome GLM (in + # _outcome_fit.py via pl.col(weight_col).clip(weight_min, weight_max)). It + # doesn't change self.DT or weight_stats, so we check it via the fitted + # coefficients. SEQdata weights span ~0.5-2, so a band entirely above that + # range collapses every weight to the lower bound. A GLM is invariant to a + # uniform scaling of its weights, so two all-constant clamps must give + # identical coefficients, while a genuinely varying-weight fit must differ. + varying = _fit() + clamp3 = _fit(weight_min=3, weight_max=4) + clamp10 = _fit(weight_min=10, weight_max=11) + + # Both clamps collapse weights to a constant => identical fit (scale-invariant) + assert np.allclose(clamp3, clamp10, atol=1e-6) + # Clamping away the real weight variation changes the fit + assert not np.allclose(clamp3, varying, atol=1e-6)