Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
827ba75
Bump astral-sh/setup-uv to 8.2.0
remlapmot Jun 4, 2026
1126eac
Warn when numerator and denominator weight models use identical covar…
remlapmot Jun 4, 2026
7ea10ab
Bump version
remlapmot Jun 4, 2026
b44c399
Add behavioural tests for selection_random
remlapmot Jun 4, 2026
bbef15e
Add behavioural test for weight_min/weight_max truncation
remlapmot Jun 4, 2026
f3c8271
Add behavioural test for weight_p99 truncation
remlapmot Jun 4, 2026
94a52b6
Add behavioural test for followup_include and trial_include
remlapmot Jun 4, 2026
b491a1d
Add behavioural test for followup_class
remlapmot Jun 4, 2026
6fbcfa8
Add behavioural test for weight_lag_condition
remlapmot Jun 4, 2026
5cb3987
Add behavioural test for followup_min/followup_max
remlapmot Jun 4, 2026
a723f34
Add behavioural test for weight_eligible_colnames
remlapmot Jun 4, 2026
620fd6f
Auto-format code
github-actions[bot] Jun 4, 2026
f0d912d
Second attempt at fixing the RTD navbar height on deployed site
remlapmot Jun 4, 2026
324cf96
Wire glum warm-start through bootstrap outcome fits
remlapmot Jun 6, 2026
527c714
Cache the patsy design_info across bootstrap outcome fits
remlapmot Jun 6, 2026
b963cd9
Auto-format code
github-actions[bot] Jun 6, 2026
acd80cb
Use integer IDs through the bootstrap path
remlapmot Jun 6, 2026
4544bd9
Skip the pl.from_pandas round-trip in the weighted fit block
remlapmot Jun 6, 2026
a2d9788
Auto-format code
github-actions[bot] Jun 6, 2026
b72d04b
Re-align reordered categoricals on the glum design-cache bootstrap path
remlapmot Jun 6, 2026
7d6a465
Use a fixed default seed when none is supplied
remlapmot Jun 8, 2026
bf4914b
Auto-format code
github-actions[bot] Jun 8, 2026
8c5bf4d
Skip jax tests if jax not installed
remlapmot Jun 9, 2026
bbfbbb9
Additionally run CI tests on Linux
remlapmot Jun 9, 2026
a50f9be
Move numerator==denominator weight warning out of __init__ into _para…
remlapmot Jun 9, 2026
1a1a0f5
Bump codecov/codecov-action to v7
remlapmot Jun 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@ permissions:

jobs:
test:
runs-on: macos-26
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macos-26, ubuntu-latest]
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]

steps:
- uses: actions/checkout@v6

- name: Install uv
uses: astral-sh/setup-uv@v8.1.0
uses: astral-sh/setup-uv@v8.2.0

- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
Expand All @@ -42,7 +44,7 @@ jobs:
uv run pytest tests/ -v --cov=pySEQTarget --cov-report=xml

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v6
uses: codecov/codecov-action@v7
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: CausalInference/pySEQTarget
9 changes: 9 additions & 0 deletions docs/_static/custom.css
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,12 @@
div#top_nav nav {
padding: 0.7rem 1rem;
}

/* The navbar height is mostly content-driven (h1 + search box), so the padding
override alone leaves it tall. Shrink the title so the row collapses to the
intended height. */
div#top_nav nav h1 {
margin: 0;
font-size: 1rem;
line-height: 1.2;
}
52 changes: 33 additions & 19 deletions pySEQTarget/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
_fit_visit, _offload_weights, _weight_bind,
_weight_predict, _weight_setup, _weight_stats)

# Default seed used when the user supplies none, so an unseeded run is
# deterministic across processes (matching SEQTaRget's capture of .Random.seed).
_DEFAULT_SEED = 0


class SEQuential:
"""
Expand Down Expand Up @@ -72,9 +76,18 @@ def __init__(
for name, value in asdict(parameters).items():
setattr(self, name, value)

self._rng = (
np.random.RandomState(self.seed) if self.seed is not None else np.random
)
# Mirror SEQTaRget (R): always pin a concrete seed so the Monte-Carlo
# hazard simulation is reseeded before each run and is reproducible.
# R captures .Random.seed when none is given, which is fixed in a fresh
# process, so an unseeded R run is deterministic across runs. We match
# that with a fixed default seed rather than falling back to the global,
# never-reseeded np.random — which let hazard estimates change silently
# between otherwise identical runs.
if self.seed is None:
self.seed = _DEFAULT_SEED
if self.verbose:
print(f"No seed supplied; using default seed {self.seed}")
self._rng = np.random.RandomState(self.seed)

self._offloader = Offloader(enabled=self.offload, dir=self.offload_dir)

Expand Down Expand Up @@ -154,10 +167,6 @@ def expand(self):
self.cense_denominator,
]
).union(kept),
).with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))

self.data = self.data.with_columns(
pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
)

if self.verbose:
Expand Down Expand Up @@ -241,25 +250,30 @@ def fit(self) -> None:
boot_idx = self._current_boot_idx

if self.weighted:
WDT = _weight_setup(self)
WDT_pl = _weight_setup(self)
if not self.weight_preexpansion and not self.excused:
WDT = WDT.filter(pl.col("followup") > 0)

WDT = WDT.to_pandas()
WDT_pl = WDT_pl.filter(pl.col("followup") > 0)

# The weight-fit helpers (_fit_LTFU etc.) use pandas-style indexing
# and pass pandas frames to glum/statsmodels, so we convert once.
# The fits don't mutate WDT_pd - they store models on `self` - so
# we keep the original polars frame for the downstream steps
# rather than paying a pl.from_pandas() round-trip per replicate.
WDT_pd = WDT_pl.to_pandas()
for col in self.fixed_cols:
if col in WDT.columns:
WDT[col] = WDT[col].astype("category")
if col in WDT_pd.columns:
WDT_pd[col] = WDT_pd[col].astype("category")

_fit_LTFU(self, WDT)
_fit_visit(self, WDT)
_fit_numerator(self, WDT)
_fit_denominator(self, WDT)
_fit_LTFU(self, WDT_pd)
_fit_visit(self, WDT_pd)
_fit_numerator(self, WDT_pd)
_fit_denominator(self, WDT_pd)

if self.offload:
_offload_weights(self, boot_idx)

WDT = pl.from_pandas(WDT)
WDT = _weight_predict(self, WDT)
del WDT_pd
WDT = _weight_predict(self, WDT_pl)
_weight_bind(self, WDT)
self.weight_stats = _weight_stats(self)

Expand Down
17 changes: 16 additions & 1 deletion pySEQTarget/analysis/_outcome_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,22 @@ def _outcome_fit(
case "glum":
from ..helpers._glum_fit import _fit_glum

return _fit_glum(full_formula, df_pd, var_weights=var_weights)
# Per-instance cache of patsy design infos, keyed by formula. The
# main fit populates it; bootstrap replicates skip the formula
# parse and reuse the cached column structure. Cleared on the
# main-fit pass so a second fit() call with a different formula
# doesn't hit stale entries.
if getattr(self, "_current_boot_idx", None) is None:
self._patsy_design_cache = {}
cache = self.__dict__.setdefault("_patsy_design_cache", {})

return _fit_glum(
full_formula,
df_pd,
var_weights=var_weights,
start_params=start_params,
design_cache=cache,
)

case "jax":
from ..helpers._jax_fit import _fit_jax
Expand Down
18 changes: 18 additions & 0 deletions pySEQTarget/error/_param_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,24 @@ def _param_checker(self):
"For weighted ITT analyses, cense_colname or visit_colname must be provided."
)

if (
self.weighted
and self.method != "ITT"
and self.numerator is not None
and self.denominator is not None
and self.numerator == self.denominator
):
warnings.warn(
f"Numerator and denominator weight models use identical "
f"covariates ('{self.numerator}'); the stabilized weights "
"will all equal 1 (i.e., no weighting). The denominator "
"should typically include the time-varying confounders "
"that the numerator omits — check for a typo in either or "
"both of 'numerator' and 'denominator'.",
UserWarning,
stacklevel=2,
)

if self.excused:
_, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames)
_, self.weight_eligible_colnames = _pad(
Expand Down
30 changes: 23 additions & 7 deletions pySEQTarget/helpers/_bootstrap.py
Comment thread
remlapmot marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,34 @@ def _prepare_boot_data(self, data, boot_id):
{self.id_col: list(id_counts.keys()), "count": list(id_counts.values())}
)

# Build a per-row unique ID for the resampled frame. With an integer-typed
# ID column do this in integer arithmetic (orig_id * id_mult + replicate)
# rather than the old string-concat form ("{orig_id}_{replicate}"), since
# string IDs make every downstream join/groupby ~3-5x slower. id_mult is
# the max count seen this iteration plus one, so the (orig_id, replicate)
# pair maps to a unique int64 with room. Falls back to string concat for
# non-integer ID columns (e.g. user-supplied string IDs).
id_is_int = data.schema[self.id_col].is_integer()
if id_is_int:
id_mult = (max(id_counts.values()) if id_counts else 1) + 1
# _weight_bind recovers orig_id with `id // id_mult` to join the
# bootstrap-resampled self.DT back to the un-resampled WDT, so the
# multiplier has to be discoverable downstream.
self._boot_id_mult = id_mult
new_id = (
pl.col(self.id_col).cast(pl.Int64) * id_mult + pl.col("replicate")
).alias(self.id_col)
else:
new_id = (
pl.col(self.id_col).cast(pl.Utf8) + "_" + pl.col("replicate").cast(pl.Utf8)
).alias(self.id_col)

bootstrapped = (
data.lazy()
.join(counts.lazy(), on=self.id_col, how="inner")
.with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate"))
.explode("replicate")
.with_columns(
(
pl.col(self.id_col).cast(pl.Utf8)
+ "_"
+ pl.col("replicate").cast(pl.Utf8)
).alias(self.id_col)
)
.with_columns(new_id)
.drop("count", "replicate")
.collect()
)
Expand Down
73 changes: 69 additions & 4 deletions pySEQTarget/helpers/_glum_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,26 @@
import patsy
from glum import GeneralizedLinearRegressor

from ._fix_categories import _fix_categories_for_predict


def _align_categories(design_info, data):
"""
Re-align ``data``'s categorical columns to the level set and ORDER frozen
in ``design_info``. Wraps ``_fix_categories_for_predict`` (which expects a
model-like object) so the cached design info can be re-applied to a
bootstrap resample whose categoricals materialised in a different order.
"""

class _Stub:
class model:
class data:
pass

stub = _Stub()
stub.model.data.design_info = design_info
return _fix_categories_for_predict(stub, data)


class _GlumFit:
"""
Expand Down Expand Up @@ -102,16 +122,61 @@ def summary(self):
return smry


def _fit_glum(formula, data, var_weights=None):
"""Fit a binomial GLM with glum and return a _GlumFit wrapper."""
y_mat, X_mat = patsy.dmatrices(formula, data, return_type="dataframe")
def _fit_glum(formula, data, var_weights=None, start_params=None, design_cache=None):
"""Fit a binomial GLM with glum and return a _GlumFit wrapper.

``start_params`` is the cached ``(values, names)`` tuple from a previous fit,
used as a warm-start in the bootstrap loop. It is only honoured when the
design matrix columns line up exactly with ``names`` - a bootstrap resample
can drop a categorical level and shift the column structure, in which case
the cached coefs are meaningless and using them as init would derail the
coordinate-descent solver.

``design_cache`` is an optional ``dict`` keyed by ``formula``. On a hit, the
formula parse and patsy model.frame construction are skipped and the cached
``(y_design_info, X_design_info)`` are re-applied to ``data`` via
``patsy.build_design_matrices``. On a miss, ``patsy.dmatrices`` parses the
formula and the result is stored. Caching freezes the categorical encoding
to the main fit's column structure, which also makes the warm-start
guarantee trivially satisfied for every replicate.
"""
if design_cache is not None and formula in design_cache:
y_dinfo, x_dinfo = design_cache[formula]
try:
y_mat, X_mat = patsy.build_design_matrices(
[y_dinfo, x_dinfo], data, return_type="dataframe"
)
except patsy.PatsyError as e:
if "mismatching levels" not in str(e):
raise
# A bootstrap resample can realise the same categorical levels in a
# different ORDER than the cached design_info froze. Re-align the
# categories to the cached structure and retry, so the cached column
# layout (and the warm-start that relies on it) stays valid.
data = _align_categories(x_dinfo, data.copy())
y_mat, X_mat = patsy.build_design_matrices(
[y_dinfo, x_dinfo], data, return_type="dataframe"
)
else:
y_mat, X_mat = patsy.dmatrices(formula, data, return_type="dataframe")
if design_cache is not None:
design_cache[formula] = (y_mat.design_info, X_mat.design_info)

y_arr = y_mat.values.ravel()
design_info = X_mat.design_info
feature_names = list(X_mat.columns) # "Intercept" first, then predictors
X_design = X_mat.values # includes intercept column (for covariance)
X_arr = X_mat.drop(columns=["Intercept"]).values

glm = GeneralizedLinearRegressor(family="binomial", fit_intercept=True)
init = None
if start_params is not None:
sp_values, sp_names = start_params
if list(sp_names) == feature_names:
init = np.asarray(sp_values, dtype=float)

glm = GeneralizedLinearRegressor(
family="binomial", fit_intercept=True, start_params=init
)

sample_weight = None
fit_kwargs = {}
Expand Down
18 changes: 15 additions & 3 deletions pySEQTarget/weighting/_weight_bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,21 @@ def _weight_bind(self, WDT):
join = "inner"
on = [self.id_col, "period"]
WDT = WDT.rename({self.time_col: "period"})
self.DT = self.DT.with_columns(
pl.col(self.id_col).str.replace(r"_\d+$", "").alias(self.id_col)
)
# On a bootstrap pass _prepare_boot_data transformed id_col so that
# each replicate has a unique value -- integer math (orig_id * id_mult
# + replicate) for int IDs, "{orig_id}_{replicate}" for string IDs.
# Recover the original ID here so the join to WDT (which still carries
# un-resampled originals) lines up. No-op on the main fit pass.
is_boot = getattr(self, "_current_boot_idx", None) is not None
if is_boot:
if self.DT.schema[self.id_col].is_integer():
self.DT = self.DT.with_columns(
(pl.col(self.id_col) // self._boot_id_mult).alias(self.id_col)
)
else:
self.DT = self.DT.with_columns(
pl.col(self.id_col).str.replace(r"_\d+$", "").alias(self.id_col)
)
else:
join = "left"
on = [self.id_col, "trial", "followup"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.13.6"
version = "0.13.7"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down
Loading