diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 9f8c443..00d6d36 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -30,7 +30,7 @@ jobs: run: | uv venv --python ${{ matrix.python-version }} uv pip install flake8 pytest pytest-cov - uv pip install -e . + uv pip install -e ".[dev]" - name: Lint with flake8 run: | diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 3f12a91..ae0fec9 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -10,113 +10,63 @@ class SEQopts: Parameter builder for ``pySEQTarget.SEQuential`` analysis :param bootstrap_nboot: Number of bootstraps to perform - :type bootstrap_nboot: int :param bootstrap_sample: Subsampling proportion of ID-Trials gathered for each bootstrapping iteration - :type bootstrap_sample: float :param bootstrap_CI: If bootstrapped, confidence interval level - :type bootstrap_CI: float :param bootstrap_CI_method: If bootstrapped, confidence interval method ['SE' or 'percentile'] - :type bootstrap_CI_method: str :param cense_colname: Column name for censoring effect (LTFU, etc.) - :type cense_colname: str :param cense_denominator: Override to specify denominator patsy formula for censoring models; "1" or "" indicate intercept only model - :type cense_denominator: Optional[str] or None :param cense_numerator: Override to specify numerator patsy formula for censoring models - :type cense_numerator: Optional[str] or None :param cense_eligible_colname: Column name to identify which rows are eligible for censoring model fitting - :type cense_eligible_colname: Optional[str] or None :param compevent_colname: Column name specifying a competing event to the outcome - :type compevent_colname: str :param covariates: Override to specify the outcome patsy formula for outcome model fitting - :type covariates: Optional[str] or None :param denominator: Override to specify the outcome patsy formula for denominator model fitting - :type denominator: Optional[str] or None :param excused: Boolean to allow excused conditions when method is censoring - :type excused: bool :param excused_colnames: Column names (at the same length of treatment_level) specifying excused conditions, default ``[]`` - :type excused_colnames: List[str] :param expand_only: If True, ``SEQuential.expand()`` returns the expanded dataset and skips weighting, modelling, and survival steps - :type expand_only: bool + :param glm_package: Backend for fitting logistic (outcome/competing-event) models ["statsmodels", "glum", or "jax"], default "statsmodels". :param followup_class: Boolean to force followup values to be treated as classes - :type followup_class: bool :param followup_include: Boolean to force regular followup values into model covariates - :type followup_include: bool :param followup_spline: Boolean to force followup values to be fit to cubic spline - :type followup_spline: bool :param followup_spline_df: Degrees of freedom for the followup cubic spline, default ``4`` - :type followup_spline_df: int :param followup_max: Maximum allowed followup in analysis - :type followup_max: int or None :param followup_min: Minimum allowed followup in analysis - :type followup_min: int :param hazard_estimate: Boolean to create hazard estimates - :type hazard_estimate: bool :param indicator_baseline: How to indicate baseline columns in models - :type indicator_baseline: str :param indicator_squared: How to indicate squared columns in models - :type indicator_squared: str :param km_curves: Boolean to create survival, risk, and incidence (if applicable) estimates - :type km_curves: bool :param ncores: Number of cores to use if running in parallel, default ``max(1, cpu_count() - 1)`` - :type ncores: int :param numerator: Override to specify the outcome patsy formula for numerator models; "1" or "" indicate intercept only model - :type numerator: str :param offload: Boolean to offload intermediate model data to disk - :type offload: bool :param offload_dir: Directory to offload intermediate model data - :type offload_dir: str :param parallel: Boolean to run model fitting in parallel - :type parallel: bool :param plot_colors: List of colors for KM plots, if applicable, default ``["#F8766D", "#00BFC4", "#555555"]`` - :type plot_colors: List[str] :param plot_labels: List of length treat_level to specify treatment labeling, default ``[]`` - :type plot_labels: List[str] :param plot_title: Plot title - :type plot_title: str :param plot_type: Type of plot to show ["risk", "survival" or "incidence" if compevent is specified] - :type plot_type: str :param risk_times: Followup times at which to report risk difference and risk ratio when ``km_curves = True``. Each requested time is snapped to the latest available followup at or before it, and the maximum followup is always included. Defaults to ``None`` (report at the maximum followup only). - :type risk_times: Optional[List[float]] or None :param seed: RNG seed - :type seed: int :param selection_first_trial: Boolean to only use first trial for analysis (similar to non-expanded) - :type selection_first_trial: bool :param selection_sample: Subsampling proportion of ID-trials which did not initiate a treatment - :type selection_sample: float :param selection_random: Boolean to randomly downsample ID-trials which did not initiate a treatment - :type selection_random: bool :param subgroup_colname: Column name for subgroups to share the same weighting but different outcome model fits - :type subgroup_colname: str :param treatment_level: List of eligible treatment levels within treatment_col, default ``[0, 1]`` - :type treatment_level: List[int] :param trial_include: Boolean to force trial values into model covariates - :type trial_include: bool :param visit_colname: Column name specifying visit number - :type visit_colname: str :param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting, default ``[]`` - :type weight_eligible_colnames: List[str] :param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton" - :type weight_fit_method: str :param weight_min: Minimum weight - :type weight_min: float :param weight_max: Maximum weight - :type weight_max: float or None :param weight_lag_condition: Boolean to fit weights based on their treatment lag - :type weight_lag_condition: bool :param weight_p99: Boolean to force weight min and max to be 1st and 99th percentile respectively - :type weight_p99: bool :param weight_preexpansion: Boolean to fit weights on preexpanded data - :type weight_preexpansion: bool :param verbose: Boolean to print dataset size summaries and bootstrap information - :type verbose: bool :param weighted: Boolean to weight analysis - :type weighted: bool """ bootstrap_nboot: int = 0 @@ -134,7 +84,7 @@ class SEQopts: excused: bool = False excused_colnames: List[str] = field(default_factory=lambda: []) expand_only: bool = False - glm_package: Literal["statsmodels", "glum"] = "statsmodels" + glm_package: Literal["statsmodels", "glum", "jax"] = "statsmodels" followup_class: bool = False followup_include: bool = True followup_max: int = None @@ -233,8 +183,8 @@ def _validate_choices(self): ) if self.bootstrap_CI_method not in ["se", "percentile"]: raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'") - if self.glm_package not in ["statsmodels", "glum"]: - raise ValueError("glm_package must be 'statsmodels' or 'glum'") + if self.glm_package not in ["statsmodels", "glum", "jax"]: + raise ValueError("glm_package must be 'statsmodels', 'glum', or 'jax'") if self.cox_package not in ["lifelines", "scikit-survival"]: raise ValueError("cox_package must be 'lifelines' or 'scikit-survival'") diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index c46cb61..190e60d 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -29,23 +29,14 @@ class SEQuential: Primary class initializer for SEQuentially nested target trial emulation :param data: Data for analysis - :type data: pl.DataFrame :param id_col: Column name for unique patient IDs - :type id_col: str :param time_col: Column name for observational time points - :type time_col: str :param eligible_col: Column name for analytical eligibility - :type eligible_col: str :param treatment_col: Column name specifying treatment per time_col - :type treatment_col: str :param outcome_col: Column name specifying outcome per time_col - :type outcome_col: str :param time_varying_cols: Time-varying column names as covariates (BMI, Age, etc.) - :type time_varying_cols: Optional[List[str]] or None :param fixed_cols: Fixed column names as covariates (Sex, YOB, etc.) - :type fixed_cols: Optional[List[str]] or None :param method: Method for analysis ['ITT', 'dose-response', or 'censoring'] - :type method: str :param parameters: Parameters to augment analysis, specified with ``pySEQTarget.SEQopts`` :type parameters: Optional[SEQopts] or None """ diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index f63e1e3..98ccdb7 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -124,24 +124,29 @@ def _outcome_fit( ) full_formula = f"{outcome} ~ {formula}" - - if getattr(self, "glm_package", "statsmodels") == "glum": - from ..helpers._glum_fit import _fit_glum - - return _fit_glum( - full_formula, - df_pd, - var_weights=df_pd[weight_col] if weighted else None, - ) - + var_weights = df_pd[weight_col] if weighted else None + + match getattr(self, "glm_package", "statsmodels"): + case "glum": + from ..helpers._glum_fit import _fit_glum + return _fit_glum(full_formula, df_pd, var_weights=var_weights) + + case "jax": + from ..helpers._jax_fit import _fit_jax + return _fit_jax( + full_formula, + df_pd, + var_weights=var_weights, + start_params=start_params, + ) + # default glm_kwargs = { "formula": full_formula, "data": df_pd, "family": sm.families.Binomial(), } - - if weighted: - glm_kwargs["var_weights"] = df_pd[weight_col] + if var_weights is not None: + glm_kwargs["var_weights"] = var_weights model = smf.glm(**glm_kwargs) diff --git a/pySEQTarget/helpers/_glum_fit.py b/pySEQTarget/helpers/_glum_fit.py index f8894b1..b375194 100644 --- a/pySEQTarget/helpers/_glum_fit.py +++ b/pySEQTarget/helpers/_glum_fit.py @@ -12,7 +12,7 @@ class _GlumFit: the codebase (and users) expect: .params (Series), .model.exog_names, .model.data.design_info, .predict(df) / .predict(X_numpy, transform=False), - .bse, .summary(), .summary2(). + .bse, .summary(). Standard errors are derived lazily from the stored design matrix using the GLM asymptotic covariance (X' W X)^-1, which matches statsmodels for the @@ -81,7 +81,7 @@ def _coef_table(self): index=list(self.params.index), ) - def summary2(self): + def summary(self): from statsmodels.iolib.summary2 import Summary info = pd.DataFrame( @@ -101,9 +101,6 @@ def summary2(self): smry.add_df(self._coef_table()) return smry - # statsmodels exposes both; the codebase/practical use either, so alias them. - summary = summary2 - def _fit_glum(formula, data, var_weights=None): """Fit a binomial GLM with glum and return a _GlumFit wrapper.""" diff --git a/pySEQTarget/helpers/_jax_fit.py b/pySEQTarget/helpers/_jax_fit.py new file mode 100644 index 0000000..f1dfe55 --- /dev/null +++ b/pySEQTarget/helpers/_jax_fit.py @@ -0,0 +1,177 @@ +import types + +import numpy as np +import pandas as pd +import patsy +import polars as pl + +from ._jax_logistic import MultinomialLogisticRegression + + +class _JaxFit: + def __init__( + self, + formula, + df, + var_weights=None, + max_iter=25, + start_params=None, + ): + df_pd = df.to_pandas() if isinstance(df, pl.DataFrame) else df + y_mat, X_mat = patsy.dmatrices(formula, df_pd, return_type="dataframe") + + design_info = X_mat.design_info + feature_names = list(X_mat.columns) + self._design_info = design_info + self._feature_names = feature_names + self._X_design = X_mat.values + X_arr = X_mat.drop(columns=["Intercept"], errors="ignore").values + + y_raw = y_mat.values.ravel() + self.classes_ = np.unique(y_raw) + self._n_classes = len(self.classes_) + y_idx = np.searchsorted(self.classes_, y_raw) + + self._sample_weight = None + if var_weights is not None: + self._sample_weight = np.asarray(var_weights, dtype=float) + + jax_model = MultinomialLogisticRegression( + max_iter=max_iter, + n_classes=self._n_classes, + ) + + jax_model.mean_ = X_arr.mean(axis=0) + jax_model.std_ = X_arr.std(axis=0) + jax_model.eps # guard constants + self._jax = jax_model + + init_params = self._warm_init(start_params, X_arr.shape[1]) + jax_model.fit( + X_arr, y_idx, sample_weight=self._sample_weight, init_params=init_params + ) + + # statsmodels 'like' exposure + self.model = types.SimpleNamespace( + exog_names=feature_names, + data=types.SimpleNamespace(design_info=design_info), + ) + self.exog_names = feature_names + self.params = self._build_params() + + def _coef_components(self): + W, b = self._jax.params + coef = np.asarray(W) + intercept = np.asarray(b) + mean = np.asarray(self._jax.mean_) + std = np.asarray(self._jax.std_) + intercept = intercept - (coef * (mean / std)[:, None]).sum(axis=0) + coef = coef / std[:, None] + + return intercept, coef + + def _build_params(self): + intercept, coef = self._coef_components() + if self._n_classes == 2: + return pd.Series( + np.concatenate([intercept, coef[:, 0]]), index=self._feature_names + ) + data = np.vstack([intercept[None, :], coef]) + return pd.DataFrame( + data, + index=self._feature_names, + columns=[f"class_{c}" for c in self.classes_[1:]], + ) + + def _warm_init(self, start_params, n_features): + if start_params is None or self._n_classes != 2: + return None + sp_values, sp_names = start_params + if list(sp_names) != self._feature_names: + return None + sp_values = np.asarray(sp_values, dtype=float) + intercept = float(sp_values[0]) + coef = sp_values[1:] + if coef.shape[0] != n_features: + return None + + W = np.zeros((n_features, 1)) + b = np.zeros(1) + mean = np.asarray(self._jax.mean_) + std = np.asarray(self._jax.std_) + W[:, 0] = coef * std + b[0] = intercept + float(np.sum(coef * mean)) + + return (W, b) + + def predict(self, data, transform=True): + if transform: + data_pd = data.to_pandas() if isinstance(data, pl.DataFrame) else data + X = patsy.build_design_matrices( + [self._design_info], data_pd, return_type="dataframe" + )[0] + X_arr = X.drop(columns=["Intercept"], errors="ignore").values + else: + X_arr = np.asarray(data)[:, 1:] + probs = np.asarray(self._jax.predict(X_arr), dtype=np.float64) + return probs[:, 1] if self._n_classes == 2 else probs + + def cov_params(self): + if self._n_classes != 2: + raise NotImplementedError( + "Standard errors are only implemented for binary jax fits." + ) + X = self._X_design + mu = np.asarray(self._jax.predict(X[:, 1:]))[:, 1] + w = mu * (1.0 - mu) + if self._sample_weight is not None: + w = w * self._sample_weight + return np.linalg.pinv(X.T @ (w[:, None] * X)) + + @property + def bse(self): + return pd.Series(np.sqrt(np.diag(self.cov_params())), index=self.params.index) + + def _coef_table(self): + from scipy import stats + + coef = self.params.values + se = self.bse.values + with np.errstate(divide="ignore", invalid="ignore"): + z = coef / se + pvals = 2.0 * stats.norm.sf(np.abs(z)) + crit = stats.norm.ppf(0.975) + return pd.DataFrame( + { + "Coef.": coef, + "Std.Err.": se, + "z": z, + "P>|z|": pvals, + "[0.025": coef - crit * se, + "0.975]": coef + crit * se, + }, + index=list(self.params.index), + ) + + def summary(self): + from statsmodels.iolib.summary2 import Summary + + info = pd.DataFrame( + { + " ": [ + "GLM (jax backend)", + "Binomial", + "logit", + str(self._X_design.shape[0]), + ] + }, + index=["Model:", "Family:", "Link:", "No. Observations:"], + ) + smry = Summary() + smry.add_title("Generalized Linear Model Regression Results") + smry.add_df(info, header=False) + smry.add_df(self._coef_table()) + return smry + + +def _fit_jax(formula, data, var_weights=None, start_params=None): + return _JaxFit(formula, data, var_weights=var_weights, start_params=start_params) diff --git a/pySEQTarget/helpers/_jax_logistic.py b/pySEQTarget/helpers/_jax_logistic.py new file mode 100644 index 0000000..ce8ebb6 --- /dev/null +++ b/pySEQTarget/helpers/_jax_logistic.py @@ -0,0 +1,99 @@ +from functools import partial + +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + + +class MultinomialLogisticRegression: + """ + Multinomial logistic regression in JAX. + """ + + def __init__(self, max_iter=25, ridge=1e-8, eps=1e-7, n_classes=None): + self.max_iter = max_iter + self.ridge = ridge # tiny Hessian jitter for numerical stability + self.eps = eps + self.n_classes = n_classes + self.params = None + self.loss_history_ = None + self.n_classes_ = None + self.mean_ = None + self.std_ = None + + def _prep(self, X): + X = jnp.asarray(X) + if self.mean_ is not None: + X = (X - self.mean_) / self.std_ + return X + + @staticmethod + def _logits(params, X): + W, b = params + z = jnp.dot(X, W) + b + return jnp.concatenate([jnp.zeros((z.shape[0], 1)), z], axis=1) + + @staticmethod + def _predict(params, X): + return jax.nn.softmax(MultinomialLogisticRegression._logits(params, X), axis=-1) + + @staticmethod + @jax.jit + def _loss(params, X, Y, w): + logp = jax.nn.log_softmax( + MultinomialLogisticRegression._logits(params, X), axis=-1 + ) + ce = -jnp.sum(Y * logp, axis=-1) + return jnp.sum(w * ce) / jnp.sum(w) + + @staticmethod + @partial(jax.jit, static_argnums=(4,)) + def _run(params, X, Y, w, max_iter, ridge): + flat0, unravel = ravel_pytree(params) + eye = jnp.eye(flat0.shape[0]) + + def loss_flat(f): + return MultinomialLogisticRegression._loss(unravel(f), X, Y, w) + + grad_fn = jax.grad(loss_flat) + hess_fn = jax.hessian(loss_flat) + + def step(f, _): + g = grad_fn(f) + H = hess_fn(f) + f = f - jnp.linalg.solve(H + ridge * eye, g) + return f, loss_flat(f) + + flat, loss_history = jax.lax.scan(step, flat0, xs=None, length=max_iter) + return unravel(flat), loss_history + + def fit(self, X, y, sample_weight=None, init_params=None): + X = self._prep(X) + y = jnp.asarray(y) + K = self.n_classes if self.n_classes is not None else int(y.max()) + 1 + self.n_classes_ = K + Y = jax.nn.one_hot(y, K) + if sample_weight is None: + w = jnp.ones(X.shape[0]) + else: + w = jnp.asarray(sample_weight, dtype=float) + if init_params is None: + params = (jnp.zeros((X.shape[1], K - 1)), jnp.zeros(K - 1)) + else: + W, b = init_params + W = jnp.asarray(W, dtype=float) + b = jnp.asarray(b, dtype=float) + if W.shape != (X.shape[1], K - 1): + raise ValueError( + f"init_params W has shape {W.shape}, expected ({X.shape[1]}, {K - 1})." + ) + params = (W, b) + self.params, self.loss_history_ = self._run( + params, X, Y, w, self.max_iter, self.ridge + ) + return self + + def predict(self, X): + if self.params is None: + raise RuntimeError("Model is not fitted yet; call `fit` first.") + return self._predict(self.params, self._prep(X)) diff --git a/pyproject.toml b/pyproject.toml index adfcb10..41c78d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,14 @@ output = [ "weasyprint", "tabulate" ] +gpu = [ + "jax[cuda12]; sys_platform == 'linux'", + "jax; sys_platform != 'linux'", +] +dev = [ + "pySEQTarget[output,gpu]", + "pytest", +] [project.urls] Homepage = "https://github.com/CausalInference/pySEQTarget" diff --git a/tests/test_glum.py b/tests/test_glum.py index 48a0b48..76ba434 100644 --- a/tests/test_glum.py +++ b/tests/test_glum.py @@ -92,7 +92,7 @@ def test_glum_summary_is_printable_and_consistent(): smry = model.summary() assert str(smry) # renders without error - coef_col = model.summary2().tables[1]["Coef."].to_list() + coef_col = model.summary().tables[1]["Coef."].to_list() assert coef_col == approx(list(model.params), rel=1e-9, abs=1e-9) diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 0000000..98acfdd --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,118 @@ +import numpy as np +import pandas as pd +import pytest +from pytest import approx + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data +from pySEQTarget.helpers._jax_fit import _JaxFit + + +def _fit(method, glm_package, dataset="SEQdata", **opts): + data = load_data(dataset) + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method=method, + parameters=SEQopts(glm_package=glm_package, **opts), + ) + s.expand() + s.fit() + return s + + +def _outcome_coefs(s): + return list(s.outcome_model[0]["outcome"].params) + + +def test_jax_matches_statsmodels_ITT(): + sm = _outcome_coefs(_fit("ITT", "statsmodels")) + jx = _outcome_coefs(_fit("ITT", "jax")) + assert jx == approx(sm, rel=1e-2, abs=2e-3) + + +def test_jax_matches_statsmodels_censoring_preexpansion(): + opts = dict(weighted=True, weight_preexpansion=True) + sm = _outcome_coefs(_fit("censoring", "statsmodels", **opts)) + jx = _outcome_coefs(_fit("censoring", "jax", **opts)) + assert jx == approx(sm, rel=1e-2, abs=2e-3) + + +def test_jax_matches_statsmodels_censoring_postexpansion(): + opts = dict(weighted=True, weight_preexpansion=False) + sm = _outcome_coefs(_fit("censoring", "statsmodels", **opts)) + jx = _outcome_coefs(_fit("censoring", "jax", **opts)) + assert jx == approx(sm, rel=1e-2, abs=2e-3) + + +def test_jax_standard_errors_match_statsmodels(): + sm_model = _fit("ITT", "statsmodels").outcome_model[0]["outcome"] + jx_model = _fit("ITT", "jax").outcome_model[0]["outcome"] + assert list(jx_model.bse) == approx(list(sm_model.bse), rel=1e-2, abs=1e-3) + + +def test_jax_bootstrap_survival_matches_statsmodels(): + common = dict(bootstrap_nboot=2, seed=1636, km_curves=True) + + def risk_diff(pkg): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(glm_package=pkg, **common), + ) + s.expand() + s.bootstrap() + s.fit() + s.survival() + rd = s.risk_estimates["risk_difference"] + assert rd["RD 95% LCI"].null_count() == 0 + return rd["Risk Difference"].to_list() + + assert risk_diff("jax") == approx(risk_diff("statsmodels"), rel=1e-2, abs=2e-3) + + +def _binary_frame(seed=1, n=2000): + rng = np.random.default_rng(seed) + x1, x2 = rng.normal(size=n), rng.normal(size=n) + eta = -1.0 + 0.8 * x1 - 0.5 * x2 + y = rng.binomial(1, 1 / (1 + np.exp(-eta))) + return pd.DataFrame({"y": y, "x1": x1, "x2": x2}) + + +def test_jax_multinomial_label_detection(): + rng = np.random.default_rng(0) + n = 900 + x = rng.normal(size=n) + df = pd.DataFrame({"y": rng.integers(0, 3, size=n), "x": x}) + + model = _JaxFit("y ~ x", df) + assert list(model.classes_) == [0, 1, 2] + + probs = model.predict(df) + assert probs.shape == (n, 3) + assert probs.sum(axis=1) == approx(np.ones(n), abs=1e-5) + + +def test_jax_warm_start_reaches_same_optimum(): + df = _binary_frame() + cold = _JaxFit("y ~ x1 + x2", df) + warm = _JaxFit( + "y ~ x1 + x2", + df, + start_params=(cold.params.values, list(cold.model.exog_names)), + ) + assert list(warm.params) == approx(list(cold.params), rel=1e-3, abs=1e-3)