diff --git a/NAMESPACE b/NAMESPACE index dd30007a..4a82e845 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -196,6 +196,7 @@ export(mashRandNullSample) export(matchRefPanel) export(mcpRssWeights) export(mcpWeights) +export(mergeCtwasBoundaryRegions) export(mergeMashData) export(mergeVariantInfo) export(metaAnalysisPerCell) diff --git a/R/causalInferencePipeline.R b/R/causalInferencePipeline.R index 63df8d93..ae34c843 100644 --- a/R/causalInferencePipeline.R +++ b/R/causalInferencePipeline.R @@ -49,9 +49,33 @@ #' When supplied, drives the MR computation and (when #' \code{twasWeights = NULL}) the TWAS-Z weights via the SuSiE-style #' coefficients on each entry's \code{topLoci}. -#' @param mrPipCutoff Numeric (length 1). PIP threshold for an entry's -#' \code{topLoci} variant to be used as an instrumental variable. -#' Used only when \code{mrMethod = "ivwPerVariant"}. Default \code{0.5}. +#' @param rsqCutoff Numeric (length 1). When \code{> 0}, performs CV weight +#' selection (ports the legacy \code{twas_pipeline} \code{pick_best_model} + +#' \code{update_twas_method}): per \code{(study, context, trait, gwasStudy)} +#' keep only the method whose \code{cvPerformance} \code{rsqOption} metric is +#' highest among methods that clear both \code{rsqCutoff} and the +#' \code{rsqPvalCutoff} gate AND that produced a finite TWAS Z (the NA/Inf +#' re-selection); groups where no method clears the cutoffs are dropped. A +#' group whose methods carry no usable \code{cvPerformance} (the SS-TWAS +#' path) keeps all methods. Needs the \code{twasWeights} \code{cvPerformance}, +#' so selection is a no-op on the fineMappingResult-only path. Default +#' \code{0} (no selection; score every method). +#' @param rsqPvalCutoff Numeric (length 1). CV-p-value gate for weight +#' selection (ports legacy \code{rsq_pval_cutoff}): a method is eligible only +#' when its \code{cvPerformance} \code{rsqPvalOption} metric is +#' \code{< rsqPvalCutoff}. Default \code{Inf} (no p-value gate). A finite +#' value activates selection even when \code{rsqCutoff = 0}. +#' @param rsqOption Character. Which \code{cvPerformance} metric is the +#' "r-squared" used for the cutoff and ranking (ports legacy +#' \code{rsq_option}); typically \code{"rsq"} or \code{"adj_rsq"}. +#' Default \code{"rsq"}. +#' @param rsqPvalOption Character vector of candidate \code{cvPerformance} +#' metric names for the p-value gate (ports legacy \code{rsq_pval_option}); +#' the first one present in a tuple's metrics is used. Default +#' \code{c("adj_rsq_pval", "pval")}. +#' @param mrPipCutoff Numeric (length 1). PIP threshold for a \code{topLoci} +#' variant to be used as an instrumental variable. Used only when +#' \code{mrMethod = "ivwPerVariant"}. Default \code{0.5}. #' @param mrMethod One of \code{"ivwPerVariant"} (default) or #' \code{"csAware"}. The IVW-per-variant method filters topLoci #' variants by \code{pip > mrPipCutoff} and IVW-pools Wald ratios @@ -63,6 +87,12 @@ #' @param mrCpipCutoff Numeric (length 1). Cumulative-PIP cutoff for #' retaining a credible set. Used only when #' \code{mrMethod = "csAware"}. Default \code{0.5}. +#' @param mrPvalCutoff Numeric (length 1). TWAS-p-value gate for running MR +#' (ports the legacy \code{twas_pipeline} \code{mr_pval_cutoff}): MR is +#' computed for a \code{(qtl tuple, gwas)} only when its \code{twasPval < +#' mrPvalCutoff}; otherwise the MR output columns are \code{NA}. Default +#' \code{1} (no gate; MR runs wherever a \code{fineMappingResult} entry +#' exists). #' @param combineMethods Optional character vector forwarded to #' \code{\link{combinePValues}} for cross-method combination per #' \code{(qtlStudy, context, trait, gwasStudy)} group. \code{NULL} @@ -73,9 +103,14 @@ causalInferencePipeline <- function(gwasSumStats, twasWeights = NULL, fineMappingResult = NULL, + rsqCutoff = 0, + rsqPvalCutoff = Inf, + rsqOption = "rsq", + rsqPvalOption = c("adj_rsq_pval", "pval"), mrPipCutoff = 0.5, mrMethod = c("ivwPerVariant", "csAware"), mrCpipCutoff = 0.5, + mrPvalCutoff = 1, combineMethods = NULL, ...) { mrMethod <- match.arg(mrMethod) @@ -119,6 +154,28 @@ causalInferencePipeline <- function(gwasSumStats, } qtlRows$useFmrForWeights <- is.null(twasWeights) + # --- Optional CV weight selection (legacy pick_best_model + ----------- + # update_twas_method): filter to eligible methods now, but defer the + # final best-method pick to AFTER the TWAS Z so the NA/Inf re-selection + # can see which methods actually produced a finite Z. + selectionActive <- !is.null(twasWeights) && + (rsqCutoff > 0 || is.finite(rsqPvalCutoff)) + rsqLookup <- NULL + if (selectionActive) { + metricTab <- .cipMethodMetrics(qtlRows, twasWeights, rsqOption, rsqPvalOption) + rsqLookup <- stats::setNames( + metricTab$rsq, + paste(metricTab$qtlStudy, metricTab$context, metricTab$trait, + metricTab$method, sep = "\r")) + qtlRows <- .cipFilterEligibleMethods(qtlRows, metricTab, + rsqCutoff, rsqPvalCutoff) + if (nrow(qtlRows) == 0L) { + stop("causalInferencePipeline: every QTL tuple was filtered out by ", + "rsqCutoff = ", rsqCutoff, " / rsqPvalCutoff = ", rsqPvalCutoff, + " (no method cleared the CV cutoffs).") + } + } + # --- Per-tuple loop: compute TWAS Z + (optional) MR -------------------- outRows <- list() for (qi in seq_len(nrow(qtlRows))) { @@ -156,7 +213,11 @@ causalInferencePipeline <- function(gwasSumStats, gwasDf = gdf, gwasLd = gwasLd) if (is.null(twasOut)) next - mrOut <- if (!is.null(fmrEntry)) { + # Gate MR on the TWAS p-value (legacy mr_pval_cutoff): only run MR where + # the TWAS association is significant. mrPvalCutoff >= 1 disables the gate. + mrGateOpen <- mrPvalCutoff >= 1 || + (!is.na(twasOut$pval) && twasOut$pval < mrPvalCutoff) + mrOut <- if (!is.null(fmrEntry) && mrGateOpen) { if (mrMethod == "csAware") { .cipComputeMrCsAware(fmrEntry = fmrEntry, gwasDf = gdf, cpipCutoff = mrCpipCutoff) @@ -195,7 +256,15 @@ causalInferencePipeline <- function(gwasSumStats, stop("causalInferencePipeline: no (qtl, gwas) tuples produced a result.") } - out <- .cipRowsToGranges(outRows) + resultDf <- .cipRowsToDf(outRows) + # Final best-method pick + NA/Inf re-selection (legacy update_twas_method): + # per (qtlStudy, context, trait, gwasStudy) keep the highest-rsqOption + # eligible method whose TWAS Z is finite, falling back to the top-rsq method + # when none is finite. SS-TWAS groups (no usable rsq) keep all methods. + if (selectionActive) { + resultDf <- .cipSelectBestMethod(resultDf, rsqLookup) + } + out <- .cipDfToGranges(resultDf) if (!is.null(combineMethods)) { out <- .cipCombineAcrossMethods(out, methods = combineMethods) @@ -238,6 +307,92 @@ causalInferencePipeline <- function(gwasSumStats, df } +# Resolve one CV metric (rsqOption / rsqPvalOption) for a single tuple from the +# TwasWeights cvPerformance, which the individual-level CV path stores as a list +# with a named $metrics vector (corr, rsq, adj_rsq, pval, RMSE, MAE); a bare +# metrics vector / data frame is tolerated too. `which` is a vector of candidate +# metric names; the first present is used. Returns NA when no usable metric. +# @noRd +.cipCvMetric <- function(twasWeights, study, context, trait, method, which) { + perf <- tryCatch( + getCvPerformance(twasWeights, study = study, context = context, + trait = trait, method = method), + error = function(e) NULL) + if (is.null(perf)) return(NA_real_) + metrics <- if (is.list(perf) && !is.null(perf[["metrics"]])) + perf[["metrics"]] else perf + nm <- intersect(which, names(metrics)) + if (length(nm) == 0L) return(NA_real_) + val <- suppressWarnings(as.numeric(metrics[[nm[[1L]]]])) + if (length(val) == 0L) NA_real_ else val[[1L]] +} + +# Tabulate the rsqOption (rsq) and rsqPvalOption (pval) CV metrics for every +# tuple in the work-list. Returns the identity columns plus `rsq`, `pval`. +# @noRd +.cipMethodMetrics <- function(qtlRows, twasWeights, rsqOption, rsqPvalOption) { + n <- nrow(qtlRows) + rsq <- vapply(seq_len(n), function(i) + .cipCvMetric(twasWeights, qtlRows$qtlStudy[[i]], qtlRows$context[[i]], + qtlRows$trait[[i]], qtlRows$method[[i]], which = rsqOption), + numeric(1)) + pval <- vapply(seq_len(n), function(i) + .cipCvMetric(twasWeights, qtlRows$qtlStudy[[i]], qtlRows$context[[i]], + qtlRows$trait[[i]], qtlRows$method[[i]], which = rsqPvalOption), + numeric(1)) + data.frame( + qtlStudy = qtlRows$qtlStudy, context = qtlRows$context, + trait = qtlRows$trait, method = qtlRows$method, + rsq = rsq, pval = pval, stringsAsFactors = FALSE) +} + +# Eligibility filter (legacy pick_best_model gate): per (study, context, trait) +# keep methods whose rsq >= rsqCutoff and (when the gate is finite) whose CV +# p-value < rsqPvalCutoff. A group whose methods carry no usable rsq (all NA) is +# the SS-TWAS path and keeps all its methods. Groups where no method clears the +# cutoffs contribute nothing. Returns the filtered work-list (same columns). +# @noRd +.cipFilterEligibleMethods <- function(qtlRows, metricTab, rsqCutoff, rsqPvalCutoff) { + grp <- paste(metricTab$qtlStudy, metricTab$context, metricTab$trait, sep = "\r") + keep <- logical(nrow(qtlRows)) + pvalGate <- is.finite(rsqPvalCutoff) + for (g in unique(grp)) { + idx <- which(grp == g) + rsq <- metricTab$rsq[idx] + if (all(is.na(rsq))) { keep[idx] <- TRUE; next } # SS-TWAS: keep all + elig <- !is.na(rsq) & rsq >= rsqCutoff + if (pvalGate) + elig <- elig & !is.na(metricTab$pval[idx]) & metricTab$pval[idx] < rsqPvalCutoff + keep[idx[elig]] <- TRUE + } + qtlRows[keep, , drop = FALSE] +} + +# Final best-method pick + NA/Inf re-selection (legacy update_twas_method): per +# (qtlStudy, context, trait, gwasStudy) rank the (already-eligible) methods by +# rsqLookup descending and keep the first whose twasZ is finite; if none is +# finite, keep the top-rsq method. A group with no usable rsq (SS-TWAS) keeps +# all its rows. `rsqLookup` is keyed by study\rcontext\rtrait\rmethod. +# @noRd +.cipSelectBestMethod <- function(df, rsqLookup) { + if (nrow(df) == 0L) return(df) + key <- paste(df$qtlStudy, df$context, df$trait, df$method, sep = "\r") + rsq <- unname(rsqLookup[key]) + grp <- paste(df$qtlStudy, df$context, df$trait, df$gwasStudy, sep = "\r") + keepRow <- logical(nrow(df)) + for (g in unique(grp)) { + idx <- which(grp == g) + r <- rsq[idx] + if (all(is.na(r))) { keepRow[idx] <- TRUE; next } # SS-TWAS: keep all + ord <- idx[order(r, decreasing = TRUE)] # NA sorts last + z <- suppressWarnings(as.numeric(df$twasZ[ord])) + fin <- which(is.finite(z)) + sel <- if (length(fin) > 0L) ord[[fin[[1L]]]] else ord[[1L]] + keepRow[sel] <- TRUE + } + df[keepRow, , drop = FALSE] +} + .cipFmrHasTuple <- function(fmr, study, context, trait, method) { length(.matchTupleRows(fmr, list(study = study, context = context, @@ -497,10 +652,14 @@ causalInferencePipeline <- function(gwasSumStats, 1 / sqrt(2 * n * maf * (1 - maf)) } -# Convert the accumulated list of row records to a GRanges with mcols. -.cipRowsToGranges <- function(rows) { - df <- do.call(rbind.data.frame, lapply(rows, as.data.frame, - stringsAsFactors = FALSE)) +# Convert the accumulated list of row records to a flat data.frame. +.cipRowsToDf <- function(rows) { + do.call(rbind.data.frame, lapply(rows, as.data.frame, + stringsAsFactors = FALSE)) +} + +# Convert the assembled result data.frame to a GRanges with mcols. +.cipDfToGranges <- function(df) { chr <- paste0("chr", sub("^chr", "", as.character(df$chrom), ignore.case = TRUE)) gr <- GenomicRanges::GRanges( diff --git a/R/ctwasPipeline.R b/R/ctwasPipeline.R index 487395a1..dd8bf142 100644 --- a/R/ctwasPipeline.R +++ b/R/ctwasPipeline.R @@ -239,7 +239,7 @@ assembleCtwasInputs <- function(gwasSumStats, twasWeights, ldFileByRegion[[rid]] <- ldKey zSnpPieces[[rid]] <- .ctwasBuildZSnp(gss) - regionInfoPieces[[rid]] <- .ctwasBuildSingleRegionInfo(rid, gwasLd) + regionInfoPieces[[rid]] <- .ctwasBuildSingleRegionInfo(rid, gss) snpMap[[rid]] <- .ctwasSnpInfoForGwasBlock(gss, ldPanel$snpInfo) } @@ -312,11 +312,15 @@ assembleCtwasInputs <- function(gwasSumStats, twasWeights, #' @param groupPriorVarStructure Pass-through. #' @param ncore Number of cores. #' @param fallbackToPrefit Logical (length 1). When \code{TRUE} (default -#' \code{FALSE}), if \code{ctwas::est_param}'s accurate EM diverges to -#' NaN and throws \code{"Estimated group_prior(_var)? contains NAs"}, -#' re-run only the prefit step via \code{ctwas:::fit_EM} and return -#' those (typically finite) priors as the param. Mirrors the legacy -#' ctwas_2 workaround on toy data where the accurate EM saturates. +#' \code{FALSE}), if \code{ctwas::est_param}'s accurate EM fails for ANY +#' reason on a degenerate input, re-run only the prefit step via +#' \code{ctwas:::fit_EM} and return those (typically finite) priors as the +#' param. The accurate-EM failure mode is version-dependent (ctwas <= 0.4.x: +#' \code{"contains NAs"}; ctwas >= 0.6.0: \code{"No regions selected!"} or a +#' NaN-loglik \code{"missing value where TRUE/FALSE needed"}), so the catch is +#' deliberately broad; a genuinely broken input still surfaces because the +#' prefit re-run will itself error. Mirrors the legacy ctwas_2 workaround on +#' toy data where the accurate EM cannot be estimated. #' @param ... Additional arguments forwarded to \code{ctwas::est_param} #' (e.g. \code{min_p_single_effect}, \code{min_group_size}). #' @return The \code{inputs} list augmented with \code{region_data}, @@ -376,8 +380,17 @@ estCtwasParam <- function(inputs, group_prior_var_structure = groupPriorVarStructure, ncore = as.integer(ncore)), extra = list(...)), error = function(e) { - if (fallbackToPrefit && grepl("contains NAs", conditionMessage(e))) { - message("estCtwasParam: accurate EM diverged (", + # The accurate EM fails on degenerate (e.g. single-gene) inputs in + # several version-dependent ways: ctwas <= 0.4.x throws "contains NAs"; + # ctwas >= 0.6.0 throws "No regions selected!" (zero regions clear the + # accurate pass) or "missing value where TRUE/FALSE needed" (NaN + # log-likelihood in the EM convergence test). Rather than enumerate + # brittle, version-specific messages, fall back on ANY accurate-EM error + # when fallbackToPrefit is set: re-run the prefit EM only, which scores + # every region and skips the p(single effect) selection gate. A genuinely + # broken input still surfaces, because the prefit re-run will itself error. + if (fallbackToPrefit) { + message("estCtwasParam: accurate EM unusable (", conditionMessage(e), "); falling back to prefit estimates.") .ctwasFitPrefitEm(regionData, niterPrefit = as.integer(niterPrefit), @@ -504,7 +517,116 @@ finemapCtwasRegions <- function(screenResult, susie_alpha_res = fmRes$susie_alpha_res, region_data = screenResult$region_data, boundary_genes = screenResult$boundary_genes, - screen_res = screenResult$screen_res) + screen_res = screenResult$screen_res, + # Carried forward so mergeCtwasBoundaryRegions() can re-finemap the merged + # boundary regions without re-deriving the assembled inputs. + region_info = screenResult$region_info, + z_snp = screenResult$z_snp, + weights = screenResult$weights, + snp_map = screenResult$snp_map, + LD_map = screenResult$LD_map, + LD_loader_fun = screenResult$LD_loader_fun, + snpinfo_loader_fun = screenResult$snpinfo_loader_fun) +} + +#' Merge boundary cTWAS regions and re-fine-map +#' +#' @description Optional step 4 of the cTWAS pipeline (default-off region +#' merging). A gene whose cis window straddles an LD-block boundary +#' (a \code{boundary_genes} member) is split across two regions in the +#' first-pass fine-mapping. This step selects the high-PIP boundary genes, +#' merges each one's adjacent regions into a single region, re-runs +#' fine-mapping on the merged regions, and splices the updated results back +#' into the \code{\link{finemapCtwasRegions}} output. Thin wrapper over +#' \code{ctwas::postprocess_region_merging()} (or +#' \code{ctwas::postprocess_region_merging_noLD()} when the inputs carry no +#' LD loaders). +#' +#' @param finemapResult A list returned by \code{\link{finemapCtwasRegions}}. +#' Must carry \code{finemap_res}, \code{susie_alpha_res}, +#' \code{region_data}, \code{region_info}, \code{z_snp}, \code{z_gene}, +#' \code{weights}, \code{snp_map}, \code{param}, and — on the LD path — +#' \code{LD_map} plus the \code{LD_loader_fun} / \code{snpinfo_loader_fun} +#' closures (all retained by \code{finemapCtwasRegions}). +#' @param pipThresh Numeric (length 1). PIP threshold for selecting which +#' boundary genes to merge (\code{select_boundary_genes} \code{pip_thresh}). +#' Default \code{0.5}. +#' @param filterCs Logical (length 1). Require the gene to be in a credible set +#' to be selected (\code{select_boundary_genes} \code{filter_cs}). Default +#' \code{FALSE}. +#' @param maxSNP Numeric (length 1). Per-merged-region SNP cap. Default +#' \code{Inf}. +#' @param L Integer. Max number of single effects for the merged-region +#' re-fine-mapping (LD path only). Default \code{5}. +#' @param ncore Number of cores. Default \code{1}. +#' @param ... Forwarded to the underlying ctwas postprocess function. +#' @return The \code{finemapResult} list with \code{finemap_res}, +#' \code{susie_alpha_res}, \code{region_data}, \code{region_info}, +#' \code{LD_map}, and \code{snp_map} replaced by the post-merge ("updated") +#' values, plus a \code{merge_res} element carrying the full ctwas postprocess +#' output. When no boundary gene clears \code{pipThresh}, ctwas returns the +#' inputs as the "updated" values, so the result is effectively unchanged. +#' @export +mergeCtwasBoundaryRegions <- function(finemapResult, + pipThresh = 0.5, + filterCs = FALSE, + maxSNP = Inf, + L = 5L, + ncore = 1L, + ...) { + if (!requireNamespace("ctwas", quietly = TRUE)) + stop("Package 'ctwas' is required for mergeCtwasBoundaryRegions.") + fmRes <- finemapResult$finemap_res + if (is.null(fmRes) || nrow(fmRes) == 0L) { + message("mergeCtwasBoundaryRegions: no first-pass finemap result; ", + "returning unchanged.") + return(finemapResult) + } + + hasLd <- !is.null(finemapResult$LD_loader_fun) + common <- list( + region_info = finemapResult$region_info, + region_data = finemapResult$region_data, + z_snp = finemapResult$z_snp, + z_gene = finemapResult$z_gene, + weights = finemapResult$weights, + snp_map = finemapResult$snp_map, + finemap_res = fmRes, + susie_alpha_res = finemapResult$susie_alpha_res, + group_prior = finemapResult$param$group_prior, + group_prior_var = finemapResult$param$group_prior_var, + pip_thresh = pipThresh, + filter_cs = filterCs, + maxSNP = maxSNP, + ncore = as.integer(ncore)) + + # ctwas's postprocess_*() forward `...` into finemap_regions, so the LD + # loader closures must ride in the explicit arg list (not through + # .ctwasInvoke, which would filter them to postprocess's own formals). + if (hasLd) { + fn <- ctwas::postprocess_region_merging + args <- c(common, list( + LD_map = finemapResult$LD_map, + L = as.integer(L), + LD_format = "custom", + LD_loader_fun = finemapResult$LD_loader_fun, + snpinfo_loader_fun = finemapResult$snpinfo_loader_fun)) + } else { + fn <- ctwas::postprocess_region_merging_noLD + args <- common + } + userExtra <- list(...) + userExtra <- userExtra[setdiff(names(userExtra), names(args))] + res <- do.call(fn, c(args, userExtra)) + + finemapResult$finemap_res <- res$updated_finemap_res + finemapResult$susie_alpha_res <- res$updated_susie_alpha_res + if (!is.null(res$updated_region_data)) finemapResult$region_data <- res$updated_region_data + if (!is.null(res$updated_region_info)) finemapResult$region_info <- res$updated_region_info + if (!is.null(res$updated_LD_map)) finemapResult$LD_map <- res$updated_LD_map + if (!is.null(res$updated_snp_map)) finemapResult$snp_map <- res$updated_snp_map + finemapResult$merge_res <- res + finemapResult } # Invoke a ctwas function with a fixed `args` list plus optional `extra` @@ -665,19 +787,32 @@ finemapCtwasRegions <- function(screenResult, # (min/max BP per chromosome). The sketch is assumed to cover exactly # one block. # @noRd -.ctwasBuildSingleRegionInfo <- function(regionId, gwasLd) { - snpInfo <- getSnpInfo(gwasLd) - chr <- unique(as.integer(sub("^chr", "", as.character(snpInfo$CHR), - ignore.case = TRUE))) +.ctwasBuildSingleRegionInfo <- function(regionId, gss) { + # Derive the block's [start, stop] from the GWAS variants actually in this + # block (the GwasSumStats entry GRanges) — NOT the LD sketch. When many + # blocks share one whole-chromosome LD payload (the common one-file-per-chr + # layout), getSnpInfo(ldSketch) spans the entire chromosome, so every region + # would collapse to the same whole-chromosome [start, stop] and every SNP + # would be assigned to every region (inflating SNP group_size N-fold and + # diluting the gene prior to ~0). + pos <- integer(0); chrs <- character(0) + for (i in seq_len(nrow(gss))) { + gr <- gss$entry[[i]] + pos <- c(pos, as.integer(GenomicRanges::start(gr))) + chrs <- c(chrs, as.character(GenomicRanges::seqnames(gr))) + } + chr <- unique(as.integer(sub("^chr", "", chrs, ignore.case = TRUE))) if (length(chr) != 1L) - stop("ctwasPipeline: gwasSumStats LD sketch spans multiple ", - "chromosomes (", paste(chr, collapse = ", "), - "). ctwasPipeline assumes a single LD block per call.") + stop("ctwasPipeline: GwasSumStats block '", regionId, "' spans multiple ", + "chromosomes (", paste(chr, collapse = ", "), ").") + if (length(pos) == 0L) + stop("ctwasPipeline: GwasSumStats block '", regionId, + "' has no variants to define region bounds.") data.frame( region_id = regionId, chrom = chr, - start = min(as.integer(snpInfo$BP)), - stop = max(as.integer(snpInfo$BP)), + start = min(pos), + stop = max(pos), stringsAsFactors = FALSE) } diff --git a/R/fineMappingPipeline.R b/R/fineMappingPipeline.R index cc828259..fb36c578 100644 --- a/R/fineMappingPipeline.R +++ b/R/fineMappingPipeline.R @@ -178,6 +178,17 @@ #' summary-statistics analog lives in \code{summaryStatsQc()}. \code{0} #' (default) disables the screen; a negative value uses the adaptive #' \code{3 / nVariants} threshold. +#' @param usePCA Logical (length 1). \code{QtlDataset} only. When +#' \code{TRUE} (default \code{FALSE}), each multi-trait context's +#' PCA-reduced phenotype is fine-mapped with univariate SuSiE on its +#' top principal components (ports the legacy \code{fsusie.R} +#' \code{susie_on_top_pc}). Each PC becomes a pseudo-trait row keyed +#' \code{trait = "topPC\{i\}"}, \code{method = "susie"}. Single-trait +#' contexts have no PCA and are skipped. +#' @param nPCs Integer (length 1). \code{QtlDataset} only. Caps the +#' number of top principal components fine-mapped per context when +#' \code{usePCA = TRUE} (default \code{10}). The effective count is +#' \code{min(nPCs, usable traits)}. #' @param jointSpecification Optional joint-fit specification (NULL by #' default). When NULL, the pipeline runs the implicit multi-context / #' multi-trait mvSuSiE / fSuSiE branches as before. When non-NULL, the @@ -697,6 +708,93 @@ setGeneric("fineMappingPipeline", } } +# Rebuild the mvSuSiE data-driven *reweighted* mixture prior + residual variance +# from a stored mr.mash fit -- the lean payload +# (list(dataDrivenPriorMatrices, w0, V)) that mrmashWeights(retainFit = TRUE) +# attaches and twasWeightsPipeline keeps on the mrmash TwasWeightsEntry. Shared +# by the fine-mapping mvsusie consumer and the twas mvsusie_weights consumer. +# +# Reproduces the deleted multivariate_pipeline.R reweighting bit-identically: +# rescaleCovW0(w0) collapses the expanded mr.mash weights onto the original +# data-driven covariance matrices ($U), filters to surviving components, and +# create_mixture_prior() wraps them, restricted to the fit's conditions +# (`conditionNames` = colnames(Y)). `V` becomes mvsusie's residual_variance. +# A NULL fit, NULL matrices, or no surviving component falls back to the +# canonical create_mixture_prior(R), matching the legacy `else` branch. +# Returns list(priorVariance, residualVariance) (residualVariance NULL only +# when no fit was supplied at all). +# @noRd +.buildMvsusieReweightedPrior <- function(fitParts, conditionNames, + weightsTol = 1e-10) { + R <- length(conditionNames) + canonical <- function(V) list( + priorVariance = mvsusieR::create_mixture_prior( + R = R, include_indices = conditionNames), + residualVariance = V) + if (is.null(fitParts)) return(canonical(NULL)) + ddpm <- fitParts$dataDrivenPriorMatrices + if (is.null(ddpm) || is.null(ddpm$U)) return(canonical(fitParts$V)) + w0Updated <- rescaleCovW0(fitParts$w0) + w0Updated <- w0Updated[names(w0Updated) %in% names(ddpm$U)] + if (length(w0Updated) == 0L) return(canonical(fitParts$V)) + mixture <- list(matrices = ddpm$U[names(w0Updated)], weights = w0Updated) + list( + priorVariance = mvsusieR::create_mixture_prior( + mixture_prior = mixture, weights_tol = weightsTol, + include_indices = conditionNames), + residualVariance = fitParts$V) +} + +# Locate the retained mr.mash fit payload {dataDrivenPriorMatrices, w0, V} for +# one (study, trait[, context]) inside a `TwasWeights` collection from a prior +# mr.mash twasWeightsPipeline run (the producer side of the mvSuSiE data-driven +# prior). The joint fit is attached to a single mrmash row of the group (the +# other rows carry fits = NULL), so scan the matching mrmash rows and return the +# first non-NULL payload. The fit may span more conditions than the mvsusie +# block fits -- `.buildMvsusieReweightedPrior(include_indices=)` subsets it. +# +# `context` is optional and disambiguates joint fits, which key differently: +# * per-context / cross-context mvsusie -> (study, trait=tid) [context = NULL] +# * cross-trait joint mvsusie -> (study, trait="joint", context=cx) +# Returns NULL when no TwasWeights is supplied or it carries no matching mr.mash +# fit (caller then falls back to the canonical prior). +# @noRd +.fmLookupMrmashFit <- function(twasWeights, study, trait, context = NULL) { + if (is.null(twasWeights)) return(NULL) + sel <- as.character(twasWeights$study) == study & + as.character(twasWeights$trait) == trait & + as.character(twasWeights$method) == "mrmash" + if (!is.null(context)) + sel <- sel & as.character(twasWeights$context) == context + for (i in which(sel)) { + f <- getFits(twasWeights$entry[[i]]) + if (!is.null(f)) return(f) + } + NULL +} + +# PCA-reduce a (samples x traits) phenotype matrix to its top `nPCs` principal +# component scores, for the `usePCA` top-PC susie path. Centers + scales +# (matching the legacy fsusie.R susie_on_top_pc), dropping incomplete rows and +# zero-variance traits first (prcomp requires complete, non-degenerate columns). +# Returns a (samples x k) score matrix, k = min(nPCs, usable traits), columns +# named topPC1..topPCk and rows keyed by sample; NULL when < 2 usable traits or +# samples (single-trait -> PCA undefined, so the caller skips). +# @noRd +.fmTopPcScores <- function(Y, nPCs) { + if (is.null(dim(Y)) || ncol(Y) < 2L) return(NULL) + Y <- Y[stats::complete.cases(Y), , drop = FALSE] + if (nrow(Y) < 2L) return(NULL) + Y <- Y[, apply(Y, 2L, stats::var) > 0, drop = FALSE] + if (ncol(Y) < 2L) return(NULL) + scores <- stats::prcomp(Y, center = TRUE, scale. = TRUE)$x + k <- min(as.integer(nPCs), ncol(scores)) + if (k < 1L) return(NULL) + scores <- scores[, seq_len(k), drop = FALSE] + colnames(scores) <- paste0("topPC", seq_len(k)) + scores +} + # Single-effect (SER) pre-screen, individual-level. Fits susie with L = 1 on a # residualized (X, y) block and reports whether any PIP clears `cutoff` -- i.e. # whether the block shows any potentially significant variant worth a full fit. @@ -1032,7 +1130,8 @@ setGeneric("fineMappingPipeline", # tokens are fit independently (no chained init) per fold, matching # twasWeightsCv's per-fold refit. Returns NULL on failure (caller skips it). # @noRd -.fmFoldWeights <- function(token, Xtr, Ytr, coverage, userArgs, pos) { +.fmFoldWeights <- function(token, Xtr, Ytr, coverage, userArgs, pos, + mvPrior = NULL) { asMat <- function(w) { if (is.matrix(w)) return(w) matrix(w, ncol = 1L, dimnames = list(names(w), NULL)) @@ -1050,11 +1149,18 @@ setGeneric("fineMappingPipeline", return(asMat(w)) } if (token == "mvsusie") { - pv <- mvsusieR::create_mixture_prior(R = ncol(Ytr)) + # Reuse the data-driven reweighted prior + residual covariance from the + # full-data mr.mash fit on every fold -- the prior is over conditions, which + # are identical across folds (only samples are held out). NULL mvPrior -> + # canonical prior (unchanged behavior). + baseArgs <- list(X = Xtr, Y = Ytr, coverage = coverage, + prior_variance = if (is.null(mvPrior)) + mvsusieR::create_mixture_prior(R = ncol(Ytr)) + else mvPrior$priorVariance) + if (!is.null(mvPrior) && !is.null(mvPrior$residualVariance)) + baseArgs$residual_variance <- mvPrior$residualVariance fit <- do.call(fitMvsusie, - .fmMergeUserArgs(list(X = Xtr, Y = Ytr, prior_variance = pv, - coverage = coverage), - "mvsusie", userArgs)) + .fmMergeUserArgs(baseArgs, "mvsusie", userArgs)) W <- as.matrix(mvsusieWeights(mvsusieFit = fit)) if (is.null(rownames(W))) rownames(W) <- colnames(Xtr) return(W) @@ -1076,7 +1182,7 @@ setGeneric("fineMappingPipeline", # @noRd .fmCrossValidate <- function(X, Y, tokens, methodArgs, fold, samplePartition = NULL, coverage = 0.95, - pos = NULL, verbose = 1) { + pos = NULL, verbose = 1, mvPrior = NULL) { if (length(tokens) == 0L) return(NULL) if (!is.matrix(Y)) { Y <- matrix(Y, ncol = 1L, @@ -1106,7 +1212,7 @@ setGeneric("fineMappingPipeline", XtrK <- Xtr[, keepCol, drop = FALSE] for (tk in tokens) { W <- tryCatch( - .fmFoldWeights(tk, XtrK, Ytr, coverage, methodArgs[[tk]], pos), + .fmFoldWeights(tk, XtrK, Ytr, coverage, methodArgs[[tk]], pos, mvPrior), error = function(e) { if (verbose >= 1) message(sprintf(" CV fold %s, method %s failed: %s", @@ -1186,7 +1292,11 @@ setMethod("fineMappingPipeline", "QtlDataset", cvFolds = 0, samplePartition = NULL, pipCutoffToSkip = 0, + usePCA = FALSE, + nPCs = 10L, seed = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10, naAction = c("drop", "impute"), verbose = 1, trim = TRUE, @@ -1222,7 +1332,9 @@ setMethod("fineMappingPipeline", "QtlDataset", parsedJointSpec, data, intersect(tokens, c("mvsusie", "fsusie")), contexts, traitId, cisWindow, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs, xRegions = xRegions) + methodArgs = methodArgs, xRegions = xRegions, + twasWeights = twasWeights, + dataDrivenPriorWeightsCutoff = dataDrivenPriorWeightsCutoff) tokens <- setdiff(tokens, c("mvsusie", "fsusie")) methodArgs <- methodArgs[tokens] if (length(tokens) == 0L) { @@ -1372,6 +1484,54 @@ setMethod("fineMappingPipeline", "QtlDataset", } } + # ---- usePCA dispatch: PCA-reduce each multi-trait context's phenotype and + # fine-map the top PCs with univariate susie (ports the legacy fsusie.R + # susie_on_top_pc). Each PC is a pseudo-trait row (trait = topPC{i}, + # method = "susie"); a single-trait context has no PCA and is skipped. + if (isTRUE(usePCA)) { + for (ctx in useCtx) { + traits <- perCtxTraits[[ctx]] + if (length(traits) < 2L) next + Yctx <- .fmResidPheno(data, contexts = ctx, traitId = traits, + naAction = naAction) + scores <- .fmTopPcScores(Yctx, nPCs) + if (is.null(scores)) next + if (verbose >= 1) + message(sprintf( + "usePCA: fine-mapping %d top PC(s) of context='%s' (%d traits) ...", + ncol(scores), ctx, length(traits))) + for (pcName in colnames(scores)) { + cached <- .fmCacheLookup(fineMappingResult, study, ctx, pcName, "susie") + if (!is.null(cached)) { pushRow(study, ctx, pcName, "susie", cached); next } + pcY <- scores[, pcName] + blockEntries <- lapply(xRegions, function(rg) { + X <- if (is.null(rg)) { + .fmResidGeno(data, contexts = ctx, traitId = traits, + cisWindow = cisWindow, samples = rownames(scores)) + } else { + .fmResidGeno(data, contexts = ctx, region = rg, + samples = rownames(scores)) + } + common <- intersect(rownames(X), names(pcY)) + if (length(common) < 2L) return(list()) + Xb <- X[common, , drop = FALSE] + if (!.fmSerScreen(Xb, pcY[common], pipCutoffToSkip)) return(list()) + afVec <- .fmAfForX(data, Xb, traitId = traits, region = rg, + cisWindow = cisWindow) + .fmFitXBlock(Xb, pcY[common], "susie", FALSE, coverage, + secondaryCoverage, signalCutoff, minAbsCorr, + methodArgs, verbose, ctx, pcName, + cvFolds = cvFolds, samplePartition = samplePartition, + af = afVec) + }) + ents <- lapply(blockEntries, function(be) be[["susie"]]) + if (any(vapply(ents, is.null, logical(1)))) next + entry <- if (length(ents) == 1L) ents[[1L]] else .fmMergeEntries(ents) + pushRow(study, ctx, pcName, "susie", entry) + } + } + } + # ---- mvsusie dispatch: joint over selected (contexts, traits). if (length(mvTokens) > 0L) { if (!requireNamespace("mvsusieR", quietly = TRUE)) { @@ -1467,6 +1627,14 @@ setMethod("fineMappingPipeline", "QtlDataset", if (verbose >= 1) message(sprintf("Fitting mvsusie (multi-context) for trait='%s' ...", tid)) + # Data-driven mvSuSiE prior: if a prior mr.mash run was supplied via + # `twasWeights`, reuse its fitted mixture weights + residual covariance + # for this (study, trait) -> reweighted create_mixture_prior; else the + # lookup returns NULL and `.buildMvsusieReweightedPrior` falls back to + # the canonical prior (unchanged behavior). Keyed on (study, trait): + # the fit may span more contexts than survive the SER pre-screen, and + # `include_indices = colnames(Yc)` subsets it to the fitted contexts. + mvFitParts <- .fmLookupMrmashFit(twasWeights, study, tid) fitOneRegion <- function(rg) { X <- if (is.null(rg)) { .fmResidGeno(data, contexts = contextsHere, traitId = tid, @@ -1488,10 +1656,14 @@ setMethod("fineMappingPipeline", "QtlDataset", colnames(ym) <- ctx ym })) + mvPrior <- .buildMvsusieReweightedPrior( + mvFitParts, colnames(Yc), dataDrivenPriorWeightsCutoff) mvBaseArgs <- list( X = Xc, Y = Yc, - prior_variance = mvsusieR::create_mixture_prior(R = ncol(Yc)), + prior_variance = mvPrior$priorVariance, coverage = coverage) + if (!is.null(mvPrior$residualVariance)) + mvBaseArgs$residual_variance <- mvPrior$residualVariance fit <- do.call(fitMvsusie, .fmMergeUserArgs(mvBaseArgs, "mvsusie", methodArgs[["mvsusie"]])) @@ -1504,7 +1676,8 @@ setMethod("fineMappingPipeline", "QtlDataset", if (cvFolds > 1L) { cv <- .fmCrossValidate(Xc, Yc, "mvsusie", methodArgs, cvFolds, samplePartition = samplePartition, - coverage = coverage, verbose = verbose) + coverage = coverage, verbose = verbose, + mvPrior = mvPrior) entry <- .fmAttachCv(entry, .fmSliceCv(cv, "mvsusie")) } entry @@ -1591,6 +1764,10 @@ setMethod("fineMappingPipeline", "QtlDataset", Yc <- Y[common, , drop = FALSE] afVec <- .fmAfForX(data, Xc, traitId = traits, region = rg, cisWindow = cisWindow) + # Multi-trait mvsusie conditions are traits (one context), so there + # is no mr.mash-over-contexts fit to reweight from -- the data-driven + # prior (keyed on a single (study, trait)) does not apply here. Keep + # the canonical prior. mvBaseArgs <- list( X = Xc, Y = Yc, prior_variance = mvsusieR::create_mixture_prior(R = ncol(Yc)), @@ -1750,6 +1927,8 @@ setMethod("fineMappingPipeline", "MultiStudyQtlDataset", minAbsCorr = 0.8, medianAbsCorr = NULL, fineMappingResult = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10, cvFolds = 0, samplePartition = NULL, pipCutoffToSkip = 0, @@ -1783,7 +1962,9 @@ setMethod("fineMappingPipeline", "MultiStudyQtlDataset", parsedJointSpec, data, intersect(tokens, c("mvsusie", "fsusie")), contexts, traitId, cisWindow, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs, xRegions = xRegions) + methodArgs = methodArgs, xRegions = xRegions, + twasWeights = twasWeights, + dataDrivenPriorWeightsCutoff = dataDrivenPriorWeightsCutoff) # Forward the still-pending (non-joint) tokens + their kwargs to the # per-QtlDataset recursion below, preserving the list shape so # methodArgs land on the right tokens. @@ -1897,6 +2078,8 @@ setMethod("fineMappingPipeline", "QtlSumStats", minAbsCorr = 0.8, medianAbsCorr = NULL, fineMappingResult = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10, verbose = 1, trim = TRUE, ...) { @@ -1913,7 +2096,9 @@ setMethod("fineMappingPipeline", "QtlSumStats", parsedJointSpec, data, intersect(tokens, "mvsusie"), contexts, traitId, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs) + methodArgs = methodArgs, + twasWeights = twasWeights, + dataDrivenPriorWeightsCutoff = dataDrivenPriorWeightsCutoff) tokens <- setdiff(tokens, c("mvsusie", "fsusie")) methodArgs <- methodArgs[tokens] if (length(tokens) == 0L) { @@ -2091,10 +2276,19 @@ setMethod("fineMappingPipeline", "QtlSumStats", if (verbose >= 1) message(sprintf("Fitting mvsusie (RSS) for (study='%s', trait='%s', %d contexts) ...", st, tr, length(ctxNames))) + # Data-driven reweighted prior from a prior mr.mash (RSS) run, looked up + # on (study, trait); mvsusie_rss takes the same create_mixture_prior + + # residual_variance (K x K condition residual covariance) as fitMvsusie. + # NULL twasWeights / no fit -> canonical prior (unchanged behavior). + mvPrior <- .buildMvsusieReweightedPrior( + .fmLookupMrmashFit(twasWeights, st, tr), colnames(Z), + dataDrivenPriorWeightsCutoff) mvBaseArgs <- list( Z = Z, R = ldMat, N = as.numeric(stats::median(nVec)), - prior_variance = mvsusieR::create_mixture_prior(R = ncol(Z)), + prior_variance = mvPrior$priorVariance, coverage = coverage) + if (!is.null(mvPrior$residualVariance)) + mvBaseArgs$residual_variance <- mvPrior$residualVariance fit <- do.call(fitMvsusieRss, .fmMergeUserArgs(mvBaseArgs, "mvsusie", methodArgs[["mvsusie"]])) diff --git a/R/jointSpecification.R b/R/jointSpecification.R index 500b069e..6612a94f 100644 --- a/R/jointSpecification.R +++ b/R/jointSpecification.R @@ -826,7 +826,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { signalCutoff, minAbsCorr, verbose, methodArgs = list(), - region = NULL) { + region = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { jointMethods <- intersect(methods, "mvsusie") if (length(jointMethods) == 0L) return(NULL) @@ -858,10 +860,17 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "jointCrossContext: fitting mvsusie for (study='%s', trait='%s') across contexts (%s) ...", study, tid, paste(xy$perTraitContexts, collapse = ", "))) + # Reweighted prior from a cross-context mr.mash joint twas run, keyed on + # (study, trait=tid, context="joint"); conditions are the contexts. + mvPrior <- .buildMvsusieReweightedPrior( + .fmLookupMrmashFit(twasWeights, study, tid, context = "joint"), + colnames(xy$Y), dataDrivenPriorWeightsCutoff) mvBaseArgs <- list( X = xy$X, Y = xy$Y, - prior_variance = mvsusieR::create_mixture_prior(R = ncol(xy$Y)), + prior_variance = mvPrior$priorVariance, coverage = coverage) + if (!is.null(mvPrior$residualVariance)) + mvBaseArgs$residual_variance <- mvPrior$residualVariance fit <- do.call(fitMvsusie, .fmMergeUserArgs(mvBaseArgs, "mvsusie", methodArgs[["mvsusie"]])) @@ -907,7 +916,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { signalCutoff, minAbsCorr, verbose, methodArgs = list(), - region = NULL) { + region = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { jointMethods <- intersect(methods, c("mvsusie", "fsusie")) if (length(jointMethods) == 0L) return(NULL) @@ -934,10 +945,17 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { "jointCrossTrait: fitting %s for (study='%s', context='%s') across traits (%s) ...", mm, study, cx, paste(xy$traitsHere, collapse = ", "))) if (mm == "mvsusie") { + # Reweighted prior from a cross-trait mr.mash joint twas run, keyed on + # (study, trait="joint", context=cx); conditions are the traits. + mvPrior <- .buildMvsusieReweightedPrior( + .fmLookupMrmashFit(twasWeights, study, "joint", context = cx), + colnames(xy$Y), dataDrivenPriorWeightsCutoff) mvBaseArgs <- list( X = xy$X, Y = xy$Y, - prior_variance = mvsusieR::create_mixture_prior(R = ncol(xy$Y)), + prior_variance = mvPrior$priorVariance, coverage = coverage) + if (!is.null(mvPrior$residualVariance)) + mvBaseArgs$residual_variance <- mvPrior$residualVariance fit <- do.call(fitMvsusie, .fmMergeUserArgs(mvBaseArgs, "mvsusie", methodArgs[["mvsusie"]])) @@ -1000,7 +1018,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = list()) { + methodArgs = list(), + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { jointMethods <- intersect(methods, "mvsusie") if (length(jointMethods) == 0L) return(NULL) @@ -1044,10 +1064,17 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "jointCrossContext (QtlSumStats): fitting mvsusie_rss for (study='%s', trait='%s', %d contexts) ...", s, tid, length(ctxNames))) + # Reweighted prior from a cross-context mr.mash joint twas run, keyed on + # (study, trait=tid, context="joint"); conditions are the contexts. + mvPrior <- .buildMvsusieReweightedPrior( + .fmLookupMrmashFit(twasWeights, s, tid, context = "joint"), + colnames(jz$Z), dataDrivenPriorWeightsCutoff) mvBaseArgs <- list( Z = jz$Z, R = ldMat, N = as.numeric(stats::median(jz$nVec)), - prior_variance = mvsusieR::create_mixture_prior(R = ncol(jz$Z)), + prior_variance = mvPrior$priorVariance, coverage = coverage) + if (!is.null(mvPrior$residualVariance)) + mvBaseArgs$residual_variance <- mvPrior$residualVariance fit <- do.call(fitMvsusieRss, .fmMergeUserArgs(mvBaseArgs, "mvsusie", methodArgs[["mvsusie"]])) @@ -1090,7 +1117,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = list()) { + methodArgs = list(), + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { if ("fsusie" %in% methods) stop("jointCrossTrait (QtlSumStats): fsusie has no RSS variant; ", "fsusie cannot participate in sumstats-based joint fits.") @@ -1130,10 +1159,17 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "jointCrossTrait (QtlSumStats): fitting mvsusie_rss for (study='%s', context='%s', %d traits) ...", s, cx, length(trNames))) + # Reweighted prior from a cross-trait mr.mash joint twas run, keyed on + # (study, trait="joint", context=cx); conditions are the traits. + mvPrior <- .buildMvsusieReweightedPrior( + .fmLookupMrmashFit(twasWeights, s, "joint", context = cx), + colnames(jz$Z), dataDrivenPriorWeightsCutoff) mvBaseArgs <- list( Z = jz$Z, R = ldMat, N = as.numeric(stats::median(jz$nVec)), - prior_variance = mvsusieR::create_mixture_prior(R = ncol(jz$Z)), + prior_variance = mvPrior$priorVariance, coverage = coverage) + if (!is.null(mvPrior$residualVariance)) + mvBaseArgs$residual_variance <- mvPrior$residualVariance fit <- do.call(fitMvsusieRss, .fmMergeUserArgs(mvBaseArgs, "mvsusie", methodArgs[["mvsusie"]])) @@ -1177,7 +1213,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = list()) { + methodArgs = list(), + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { if ("fsusie" %in% methods) stop("jointCrossStudy: fsusie cannot participate (no RSS variant).") jointMethods <- intersect(methods, "mvsusie") @@ -1226,6 +1264,7 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { cx, tid, length(stNames))) mvBaseArgs <- list( Z = jz$Z, R = ldMat, N = as.numeric(stats::median(jz$nVec)), + # TODO(mvsusie-prior): cross-study lookup key undecided; canonical prior for now prior_variance = mvsusieR::create_mixture_prior(R = ncol(jz$Z)), coverage = coverage) fit <- do.call(fitMvsusieRss, @@ -1273,7 +1312,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { signalCutoff, minAbsCorr, verbose, methodArgs = list(), - region = NULL) { + region = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { axes <- spec$axes if ("study" %in% axes) stop("composed jointSpecification (QtlDataset): axes including 'study' require sumstats input.") @@ -1298,10 +1339,18 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "composed joint (QtlDataset): fitting mvsusie for study='%s' over %d (context, trait) columns ...", study, ncol(xy$Y))) + # Reweighted prior from a composed mr.mash joint twas run, keyed on + # (study, trait="joint", context="joint"); conditions are the (context,trait) + # columns of the composed design. + mvPrior <- .buildMvsusieReweightedPrior( + .fmLookupMrmashFit(twasWeights, study, "joint", context = "joint"), + colnames(xy$Y), dataDrivenPriorWeightsCutoff) mvBaseArgs <- list( X = xy$X, Y = xy$Y, - prior_variance = mvsusieR::create_mixture_prior(R = ncol(xy$Y)), + prior_variance = mvPrior$priorVariance, coverage = coverage) + if (!is.null(mvPrior$residualVariance)) + mvBaseArgs$residual_variance <- mvPrior$residualVariance fit <- do.call(fitMvsusie, .fmMergeUserArgs(mvBaseArgs, "mvsusie", methodArgs[["mvsusie"]])) @@ -1338,7 +1387,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = list()) { + methodArgs = list(), + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { if ("fsusie" %in% methods) stop("composed jointSpecification (QtlSumStats): fsusie has no RSS variant.") jointMethods <- intersect(methods, "mvsusie") @@ -1380,10 +1431,18 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "composed joint (QtlSumStats): fitting mvsusie_rss for axes=(%s), %d columns ...", paste(axes, collapse = ", "), length(gIdx))) + # Reweighted prior from a composed mr.mash joint twas run, keyed on + # (study, trait="joint", context="joint"); conditions are the joint columns. + mvPrior <- .buildMvsusieReweightedPrior( + .fmLookupMrmashFit(twasWeights, studyCol[gIdx[[1L]]], "joint", + context = "joint"), + colnames(jz$Z), dataDrivenPriorWeightsCutoff) mvBaseArgs <- list( Z = jz$Z, R = ldMat, N = as.numeric(stats::median(jz$nVec)), - prior_variance = mvsusieR::create_mixture_prior(R = ncol(jz$Z)), + prior_variance = mvPrior$priorVariance, coverage = coverage) + if (!is.null(mvPrior$residualVariance)) + mvBaseArgs$residual_variance <- mvPrior$residualVariance fit <- do.call(fitMvsusieRss, .fmMergeUserArgs(mvBaseArgs, "mvsusie", methodArgs[["mvsusie"]])) @@ -1471,7 +1530,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { signalCutoff, minAbsCorr, verbose, methodArgs = list(), - xRegions = list(NULL)) { + xRegions = list(NULL), + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { # Run the joint dispatch once per region block, then merge per # (study, context, trait, method) across regions. A single block (cis or # jointRegions=TRUE concatenated) returns its result directly. @@ -1479,7 +1540,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { .fmDispatchJointSpecsQtlDatasetOneRegion( parsedJointSpec, data, methods, contexts, traitIds, cisWindow, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs, region = rg) + methodArgs = methodArgs, region = rg, + twasWeights = twasWeights, + dataDrivenPriorWeightsCutoff = dataDrivenPriorWeightsCutoff) }) perRegion <- Filter(Negate(is.null), perRegion) if (length(perRegion) == 0L) return(NULL) @@ -1494,16 +1557,22 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { signalCutoff, minAbsCorr, verbose, methodArgs = list(), - region = NULL) { + region = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { + # Bundle the data-driven mvSuSiE prior pass-through args once; every leaf + # dispatcher accepts the same pair. + priorArgs <- list(twasWeights = twasWeights, + dataDrivenPriorWeightsCutoff = dataDrivenPriorWeightsCutoff) out <- NULL for (i in seq_along(parsedJointSpec)) { spec <- parsedJointSpec[[i]] axes <- spec$axes if (length(axes) > 1L) { - res <- .fmDispatchComposedQtlDataset( + res <- do.call(.fmDispatchComposedQtlDataset, c(list( spec, data, methods, contexts, traitIds, cisWindow, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs, region = region) + methodArgs = methodArgs, region = region), priorArgs)) if (!is.null(res)) out <- if (is.null(out)) res else .rbindFineMappingResult(out, res, ldSketch = NULL) @@ -1511,14 +1580,14 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { } axis <- axes[[1L]] res <- switch(axis, - context = .fmDispatchCrossContextQtlDataset( + context = do.call(.fmDispatchCrossContextQtlDataset, c(list( spec, data, methods, contexts, traitIds, cisWindow, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs, region = region), - trait = .fmDispatchCrossTraitQtlDataset( + methodArgs = methodArgs, region = region), priorArgs)), + trait = do.call(.fmDispatchCrossTraitQtlDataset, c(list( spec, data, methods, contexts, traitIds, cisWindow, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs, region = region), + methodArgs = methodArgs, region = region), priorArgs)), study = stop( "fineMappingPipeline(QtlDataset): jointSpecification with axes = 'study' requires sumstats input. ", "QtlDataset represents a single individual-level study; cross-study joints operate on the sumstats slot of MultiStudyQtlDataset or on QtlSumStats directly."), @@ -1538,16 +1607,22 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = list()) { + methodArgs = list(), + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { + # Bundle the data-driven mvSuSiE prior pass-through args once; every leaf + # dispatcher accepts the same pair. + priorArgs <- list(twasWeights = twasWeights, + dataDrivenPriorWeightsCutoff = dataDrivenPriorWeightsCutoff) out <- NULL for (i in seq_along(parsedJointSpec)) { spec <- parsedJointSpec[[i]] axes <- spec$axes if (length(axes) > 1L) { - res <- .fmDispatchComposedQtlSumStats( + res <- do.call(.fmDispatchComposedQtlSumStats, c(list( spec, data, methods, contexts, traitIds, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs) + methodArgs = methodArgs), priorArgs)) if (!is.null(res)) out <- if (is.null(out)) res else .rbindFineMappingResult(out, res, ldSketch = getLdSketch(data)) @@ -1555,18 +1630,18 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { } axis <- axes[[1L]] res <- switch(axis, - context = .fmDispatchCrossContextQtlSumStats( + context = do.call(.fmDispatchCrossContextQtlSumStats, c(list( spec, data, methods, contexts, traitIds, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs), - trait = .fmDispatchCrossTraitQtlSumStats( + methodArgs = methodArgs), priorArgs)), + trait = do.call(.fmDispatchCrossTraitQtlSumStats, c(list( spec, data, methods, contexts, traitIds, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs), - study = .fmDispatchCrossStudyQtlSumStats( + methodArgs = methodArgs), priorArgs)), + study = do.call(.fmDispatchCrossStudyQtlSumStats, c(list( spec, data, methods, contexts, traitIds, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs), + methodArgs = methodArgs), priorArgs)), stop(sprintf("Unsupported axis: %s", axis))) if (!is.null(res)) out <- if (is.null(out)) res @@ -1589,7 +1664,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { signalCutoff, minAbsCorr, verbose, methodArgs = list(), - xRegions = list(NULL)) { + xRegions = list(NULL), + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10) { out <- NULL embeddedLd <- NULL qtlDatasets <- getQtlDatasets(data) @@ -1612,7 +1689,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { qdRes <- .fmDispatchJointSpecsQtlDataset( nonStudyAxisSpecs, qd, methods, contexts, traitIds, cisWindow, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs, xRegions = xRegions) + methodArgs = methodArgs, xRegions = xRegions, + twasWeights = twasWeights, + dataDrivenPriorWeightsCutoff = dataDrivenPriorWeightsCutoff) if (!is.null(qdRes)) out <- if (is.null(out)) qdRes else .rbindFineMappingResult(out, qdRes, ldSketch = NULL) @@ -1623,7 +1702,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { ssRes <- .fmDispatchJointSpecsQtlSumStats( parsedJointSpec, sumStats, methods, contexts, traitIds, coverage, secondaryCoverage, signalCutoff, minAbsCorr, verbose, - methodArgs = methodArgs) + methodArgs = methodArgs, + twasWeights = twasWeights, + dataDrivenPriorWeightsCutoff = dataDrivenPriorWeightsCutoff) if (!is.null(ssRes)) { embeddedLd <- getLdSketch(ssRes) out <- if (is.null(out)) ssRes @@ -1647,7 +1728,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { .twasDispatchCrossContextQtlDataset <- function(spec, data, methods, contexts, traitIds, cisWindow, dataType, - verbose, region = NULL) { + verbose, region = NULL, + retainFit = TRUE, + retainFitDetail = "slim") { jointMethods <- intersect(methods, "mrmash") if (length(jointMethods) == 0L) return(NULL) @@ -1679,11 +1762,13 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "jointCrossContext (twas QtlDataset): fitting mr.mash for (study='%s', trait='%s') across contexts (%s) ...", study, tid, paste(xy$perTraitContexts, collapse = ", "))) - weights <- mrmashWeights(X = xy$X, Y = xy$Y) + weights <- mrmashWeights(X = xy$X, Y = xy$Y, + retainFit = retainFit, fitDetail = retainFitDetail) if (is.null(rownames(weights))) rownames(weights) <- colnames(xy$X) entry <- TwasWeightsEntry( variantIds = rownames(weights), weights = weights, + fits = attr(weights, "fit"), standardized = FALSE, dataType = dataType) rowStudy <- c(rowStudy, study) @@ -1713,7 +1798,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { .twasDispatchCrossTraitQtlDataset <- function(spec, data, methods, contexts, traitIds, cisWindow, dataType, - verbose, region = NULL) { + verbose, region = NULL, + retainFit = TRUE, + retainFitDetail = "slim") { jointMethods <- intersect(methods, "mrmash") if (length(jointMethods) == 0L) return(NULL) @@ -1738,11 +1825,13 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "jointCrossTrait (twas): fitting mr.mash for (study='%s', context='%s') across traits (%s) ...", study, cx, paste(xy$traitsHere, collapse = ", "))) - weights <- mrmashWeights(X = xy$X, Y = xy$Y) + weights <- mrmashWeights(X = xy$X, Y = xy$Y, + retainFit = retainFit, fitDetail = retainFitDetail) if (is.null(rownames(weights))) rownames(weights) <- colnames(xy$X) entry <- TwasWeightsEntry( variantIds = rownames(weights), weights = weights, + fits = attr(weights, "fit"), standardized = FALSE, dataType = dataType) rowStudy <- c(rowStudy, study) @@ -1770,7 +1859,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { # @noRd .twasDispatchCrossContextQtlSumStats <- function(spec, data, methods, contexts, traitIds, - dataType, verbose) { + dataType, verbose, + retainFit = TRUE, + retainFitDetail = "slim") { jointMethods <- intersect(methods, "mrmash") if (length(jointMethods) == 0L) return(NULL) @@ -1815,11 +1906,14 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "jointCrossContext (twas QtlSumStats): fitting mr.mash.rss for (study='%s', trait='%s', %d contexts) ...", s, tid, length(ctxNames))) - weights <- mrmashRssWeights(stat = stat, LD = ldMat) + weights <- mrmashRssWeights(stat = stat, LD = ldMat, + retainFit = retainFit, + fitDetail = retainFitDetail) if (is.null(rownames(weights))) rownames(weights) <- jz$variantIds entry <- TwasWeightsEntry( variantIds = rownames(weights), weights = weights, + fits = attr(weights, "fit"), standardized = TRUE, dataType = dataType) rowStudy <- c(rowStudy, s) @@ -1848,7 +1942,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { # @noRd .twasDispatchCrossTraitQtlSumStats <- function(spec, data, methods, contexts, traitIds, - dataType, verbose) { + dataType, verbose, + retainFit = TRUE, + retainFitDetail = "slim") { jointMethods <- intersect(methods, "mrmash") if (length(jointMethods) == 0L) return(NULL) @@ -1886,11 +1982,14 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "jointCrossTrait (twas QtlSumStats): fitting mr.mash.rss for (study='%s', context='%s', %d traits) ...", s, cx, length(trNames))) - weights <- mrmashRssWeights(stat = stat, LD = ldMat) + weights <- mrmashRssWeights(stat = stat, LD = ldMat, + retainFit = retainFit, + fitDetail = retainFitDetail) if (is.null(rownames(weights))) rownames(weights) <- jz$variantIds entry <- TwasWeightsEntry( variantIds = rownames(weights), weights = weights, + fits = attr(weights, "fit"), standardized = TRUE, dataType = dataType) rowStudy <- c(rowStudy, s) @@ -1919,7 +2018,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { # @noRd .twasDispatchCrossStudyQtlSumStats <- function(spec, data, methods, contexts, traitIds, - dataType, verbose) { + dataType, verbose, + retainFit = TRUE, + retainFitDetail = "slim") { jointMethods <- intersect(methods, "mrmash") if (length(jointMethods) == 0L) return(NULL) @@ -1965,11 +2066,14 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "jointCrossStudy (twas): fitting mr.mash.rss for (context='%s', trait='%s', %d studies) ...", cx, tid, length(stNames))) - weights <- mrmashRssWeights(stat = stat, LD = ldMat) + weights <- mrmashRssWeights(stat = stat, LD = ldMat, + retainFit = retainFit, + fitDetail = retainFitDetail) if (is.null(rownames(weights))) rownames(weights) <- jz$variantIds entry <- TwasWeightsEntry( variantIds = rownames(weights), weights = weights, + fits = attr(weights, "fit"), standardized = TRUE, dataType = dataType) rowStudy <- c(rowStudy, "joint") @@ -1997,7 +2101,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { # @noRd .twasDispatchComposedQtlSumStats <- function(spec, data, methods, contexts, traitIds, - dataType, verbose) { + dataType, verbose, + retainFit = TRUE, + retainFitDetail = "slim") { jointMethods <- intersect(methods, "mrmash") if (length(jointMethods) == 0L) return(NULL) @@ -2038,11 +2144,14 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "composed joint (twas QtlSumStats): fitting mr.mash.rss for axes=(%s), %d columns ...", paste(axes, collapse = ", "), length(gIdx))) - weights <- mrmashRssWeights(stat = stat, LD = ldMat) + weights <- mrmashRssWeights(stat = stat, LD = ldMat, + retainFit = retainFit, + fitDetail = retainFitDetail) if (is.null(rownames(weights))) rownames(weights) <- jz$variantIds entry <- TwasWeightsEntry( variantIds = rownames(weights), weights = weights, + fits = attr(weights, "fit"), standardized = TRUE, dataType = dataType) @@ -2087,7 +2196,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { .twasDispatchComposedQtlDataset <- function(spec, data, methods, contexts, traitIds, cisWindow, dataType, verbose, - region = NULL) { + region = NULL, + retainFit = TRUE, + retainFitDetail = "slim") { axes <- spec$axes if ("study" %in% axes) stop("composed jointSpecification (twas QtlDataset): axes including 'study' require sumstats input.") @@ -2111,11 +2222,13 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { message(sprintf( "composed joint (twas QtlDataset): fitting mr.mash for study='%s' over %d (context, trait) columns ...", study, ncol(xy$Y))) - weights <- mrmashWeights(X = xy$X, Y = xy$Y) + weights <- mrmashWeights(X = xy$X, Y = xy$Y, + retainFit = retainFit, fitDetail = retainFitDetail) if (is.null(rownames(weights))) rownames(weights) <- colnames(xy$X) entry <- TwasWeightsEntry( variantIds = rownames(weights), weights = weights, + fits = attr(weights, "fit"), standardized = FALSE, dataType = dataType) TwasWeights( @@ -2163,14 +2276,17 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { .twasDispatchJointSpecsQtlDataset <- function(parsedJointSpec, data, methods, contexts, traitIds, cisWindow, dataType, - verbose, xRegions = list(NULL)) { + verbose, xRegions = list(NULL), + retainFit = TRUE, + retainFitDetail = "slim") { # Run the joint dispatch once per region block, then merge per # (study, context, trait, method) across regions. A single block (cis or # jointRegions=TRUE concatenated) returns its result directly. perRegion <- lapply(xRegions, function(rg) { .twasDispatchJointSpecsQtlDatasetOneRegion( parsedJointSpec, data, methods, contexts, traitIds, cisWindow, dataType, - verbose, region = rg) + verbose, region = rg, + retainFit = retainFit, retainFitDetail = retainFitDetail) }) labs <- vapply(xRegions, .twasRegionLabel, character(1)) keep <- !vapply(perRegion, is.null, logical(1)) @@ -2183,7 +2299,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { .twasDispatchJointSpecsQtlDatasetOneRegion <- function(parsedJointSpec, data, methods, contexts, traitIds, cisWindow, dataType, - verbose, region = NULL) { + verbose, region = NULL, + retainFit = TRUE, + retainFitDetail = "slim") { out <- NULL for (i in seq_along(parsedJointSpec)) { spec <- parsedJointSpec[[i]] @@ -2191,7 +2309,8 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { if (length(axes) > 1L) { res <- .twasDispatchComposedQtlDataset( spec, data, methods, contexts, traitIds, cisWindow, dataType, verbose, - region = region) + region = region, + retainFit = retainFit, retainFitDetail = retainFitDetail) if (!is.null(res)) out <- if (is.null(out)) res else .rbindTwasWeights(out, res, ldSketch = NULL) @@ -2201,10 +2320,12 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { res <- switch(axis, context = .twasDispatchCrossContextQtlDataset( spec, data, methods, contexts, traitIds, cisWindow, dataType, verbose, - region = region), + region = region, + retainFit = retainFit, retainFitDetail = retainFitDetail), trait = .twasDispatchCrossTraitQtlDataset( spec, data, methods, contexts, traitIds, cisWindow, dataType, verbose, - region = region), + region = region, + retainFit = retainFit, retainFitDetail = retainFitDetail), study = stop( "twasWeightsPipeline(QtlDataset): jointSpecification with axes = 'study' requires sumstats input."), stop(sprintf("Unsupported axis: %s", axis))) @@ -2220,14 +2341,17 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { # @noRd .twasDispatchJointSpecsQtlSumStats <- function(parsedJointSpec, data, methods, contexts, traitIds, - dataType, verbose) { + dataType, verbose, + retainFit = TRUE, + retainFitDetail = "slim") { out <- NULL for (i in seq_along(parsedJointSpec)) { spec <- parsedJointSpec[[i]] axes <- spec$axes if (length(axes) > 1L) { res <- .twasDispatchComposedQtlSumStats( - spec, data, methods, contexts, traitIds, dataType, verbose) + spec, data, methods, contexts, traitIds, dataType, verbose, + retainFit = retainFit, retainFitDetail = retainFitDetail) if (!is.null(res)) out <- if (is.null(out)) res else .rbindTwasWeights(out, res, ldSketch = getLdSketch(data)) @@ -2236,11 +2360,14 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { axis <- axes[[1L]] res <- switch(axis, context = .twasDispatchCrossContextQtlSumStats( - spec, data, methods, contexts, traitIds, dataType, verbose), + spec, data, methods, contexts, traitIds, dataType, verbose, + retainFit = retainFit, retainFitDetail = retainFitDetail), trait = .twasDispatchCrossTraitQtlSumStats( - spec, data, methods, contexts, traitIds, dataType, verbose), + spec, data, methods, contexts, traitIds, dataType, verbose, + retainFit = retainFit, retainFitDetail = retainFitDetail), study = .twasDispatchCrossStudyQtlSumStats( - spec, data, methods, contexts, traitIds, dataType, verbose), + spec, data, methods, contexts, traitIds, dataType, verbose, + retainFit = retainFit, retainFitDetail = retainFitDetail), stop(sprintf("Unsupported axis: %s", axis))) if (!is.null(res)) out <- if (is.null(out)) res @@ -2255,7 +2382,9 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { .twasDispatchJointSpecsMultiStudy <- function(parsedJointSpec, data, methods, contexts, traitIds, cisWindow, dataType, verbose, - xRegions = list(NULL)) { + xRegions = list(NULL), + retainFit = TRUE, + retainFitDetail = "slim") { out <- NULL embeddedLd <- NULL qtlDatasets <- getQtlDatasets(data) @@ -2277,7 +2406,8 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { qd <- qtlDatasets[[qdName]] qdRes <- .twasDispatchJointSpecsQtlDataset( nonStudyAxisSpecs, qd, methods, contexts, traitIds, cisWindow, - dataType, verbose, xRegions = xRegions) + dataType, verbose, xRegions = xRegions, + retainFit = retainFit, retainFitDetail = retainFitDetail) if (!is.null(qdRes)) out <- if (is.null(out)) qdRes else .rbindTwasWeights(out, qdRes, ldSketch = NULL) @@ -2287,7 +2417,8 @@ validateMethodsVsJointSpec <- function(methodsParsed, jointSpecParsed) { if (!is.null(sumStats)) { ssRes <- .twasDispatchJointSpecsQtlSumStats( parsedJointSpec, sumStats, methods, contexts, traitIds, dataType, - verbose) + verbose, + retainFit = retainFit, retainFitDetail = retainFitDetail) if (!is.null(ssRes)) { embeddedLd <- getLdSketch(ssRes) out <- if (is.null(out)) ssRes diff --git a/R/qtlEnrichmentPipeline.R b/R/qtlEnrichmentPipeline.R index 928487b7..d2f99034 100644 --- a/R/qtlEnrichmentPipeline.R +++ b/R/qtlEnrichmentPipeline.R @@ -86,9 +86,52 @@ qtlEnrichmentPipeline <- function(gwasFineMappingResult, stop("qtlEnrichmentPipeline: no (gwasStudy, qtlStudy, qtlContext) ", "triples to compute (one of the inputs has zero rows).") + # Hoist the QTL-side work out of the gwasStudy loop. The variant-name + # alignment is independent of the GWAS study (all studies share one naming + # convention), so the original code re-ran the costly .matchRefPanel pass for + # every (gwasStudy, qtlTuple) pair. Instead: build each GWAS PIP vector once, + # derive the union variant-name panel, and align each QTL tuple's regions + # once against that union (memoised in `alignedByTuple`, reused across every + # gwasStudy). qtlEnrichment is then called with alignNames = FALSE, which + # skips the redundant alignment and only recomputes the cheap per-study + # "unmatched" set. + gwasPipByStudy <- lapply(gwasStudies, function(g) + .enrBuildGwasPipVector(gwasFineMappingResult, g)) + names(gwasPipByStudy) <- gwasStudies + unionGwasNames <- unique(unlist(lapply(gwasPipByStudy, names), + use.names = FALSE)) + + qtlRegionsByTuple <- lapply(seq_len(nrow(qtlTuples)), function(k) + .enrBuildQtlRegionsList(qtlFineMappingResult, + qtlTuples$qtlStudy[[k]], + qtlTuples$qtlContext[[k]])) + + # Align a tuple's regions to the union GWAS panel once and cache the result. + # Kept lazy (rather than a pre-loop lapply) so the alignment runs inside the + # per-tuple tryCatch below: a tuple whose names cannot be aligned is skipped + # with a warning instead of aborting the whole pipeline -- the behaviour + # before this optimization, where alignment lived inside qtlEnrichment. An + # empty union means no GWAS study has usable PIPs, so every study is skipped + # before this is ever reached (and the length guard keeps alignVariantNames + # from treating an empty reference as a convention mismatch). + alignedByTuple <- vector("list", nrow(qtlTuples)) + alignTuple <- function(k) { + if (!is.null(alignedByTuple[[k]])) return(alignedByTuple[[k]]) + aligned <- lapply(qtlRegionsByTuple[[k]], function(x) { + if (!is.null(names(x$pip)) && length(unionGwasNames) > 0L) { + names(x$pip) <- alignVariantNames(names(x$pip), + unionGwasNames)$alignedVariants + } + x + }) + alignedByTuple[[k]] <<- aligned + aligned + } + results <- list() - for (gStudy in gwasStudies) { - gwasPip <- .enrBuildGwasPipVector(gwasFineMappingResult, gStudy) + for (gi in seq_along(gwasStudies)) { + gStudy <- gwasStudies[[gi]] + gwasPip <- gwasPipByStudy[[gi]] if (length(gwasPip) == 0L) { warning(sprintf( "qtlEnrichmentPipeline: no usable PIPs for gwasStudy='%s'; skipping.", @@ -98,9 +141,7 @@ qtlEnrichmentPipeline <- function(gwasFineMappingResult, for (k in seq_len(nrow(qtlTuples))) { qStudy <- qtlTuples$qtlStudy[[k]] qContext <- qtlTuples$qtlContext[[k]] - qtlRegions <- .enrBuildQtlRegionsList(qtlFineMappingResult, - qStudy, qContext) - if (length(qtlRegions) == 0L) { + if (length(qtlRegionsByTuple[[k]]) == 0L) { warning(sprintf( "qtlEnrichmentPipeline: no usable QTL regions for (qtlStudy='%s', qtlContext='%s'); skipping.", qStudy, qContext)) @@ -109,12 +150,13 @@ qtlEnrichmentPipeline <- function(gwasFineMappingResult, enr <- tryCatch( qtlEnrichment( gwasPip = gwasPip, - susieQtlRegions = qtlRegions, + susieQtlRegions = alignTuple(k), numGwas = numGwas, piQtl = piQtl, lambda = lambda, impN = impN, numThreads = numThreads, + alignNames = FALSE, ...), error = function(e) { warning(sprintf( @@ -292,6 +334,12 @@ qtlEnrichmentPipeline <- function(gwasFineMappingResult, #' When it is set to 0, no shrinkage will be applied. A large value indicates strong shrinkage. The default value is set to 1.0. #' @param impN Rounds of multiple imputation to draw QTL from, default is 25. #' @param numThreads Number of Simultaneous running CPU threads for multiple imputation, default is 1. +#' @param alignNames Logical; when TRUE (default) QTL pip names are aligned to +#' the GWAS variant-naming convention via \code{alignVariantNames}. Set FALSE +#' when the caller has already aligned them (e.g. \code{qtlEnrichmentPipeline} +#' aligns each QTL tuple once against the union GWAS panel rather than +#' re-aligning per GWAS study); only the cheap per-study unmatched set is then +#' recomputed, skipping the costly \code{.matchRefPanel} pass. #' @return A list of enrichment parameter estimates #' #' @examples @@ -334,7 +382,8 @@ qtlEnrichment <- function(gwasPip, susieQtlRegions, lambda = 1.0, impN = 25, doubleShrinkage = FALSE, besselCorrection = TRUE, - numThreads = 1, verbose = TRUE) { + numThreads = 1, verbose = TRUE, + alignNames = TRUE) { if (is.null(numGwas)) { warning("numGwas is not provided. Estimating piGwas from the data. Note that this estimate may be biased if the input gwasPip does not contain genome-wide variants.") piGwas <- sum(gwasPip) / length(gwasPip) @@ -370,15 +419,31 @@ qtlEnrichment <- function(gwasPip, susieQtlRegions, stop("Variant names are missing in susieQtlRegions$pip. Please provide susieQtlRegions with named pip data.") } - # Align the names of susieQtlRegions$pip to gwasPip names and document unmatched variants - alignedSusieQtlRegions <- lapply(susieQtlRegions, function(x) { - alignmentResult <- alignVariantNames(names(x$pip), names(gwasPip)) - names(x$pip) <- alignmentResult$alignedVariants - if (length(alignmentResult$unmatchedIndices) > 0) { - x$unmatched_variants <- names(x$pip)[alignmentResult$unmatchedIndices] - } - x - }) + # Align the names of susieQtlRegions$pip to gwasPip names and document + # unmatched variants. With alignNames = FALSE the caller has already aligned + # the pip names to the GWAS naming convention (qtlEnrichmentPipeline aligns + # each QTL tuple once against the union GWAS panel), so the costly + # .matchRefPanel pass is skipped and only the per-study unmatched set is + # recomputed via a cheap set-membership test. + if (alignNames) { + alignedSusieQtlRegions <- lapply(susieQtlRegions, function(x) { + alignmentResult <- alignVariantNames(names(x$pip), names(gwasPip)) + names(x$pip) <- alignmentResult$alignedVariants + if (length(alignmentResult$unmatchedIndices) > 0) { + x$unmatched_variants <- names(x$pip)[alignmentResult$unmatchedIndices] + } + x + }) + } else { + gwasNameSet <- names(gwasPip) + alignedSusieQtlRegions <- lapply(susieQtlRegions, function(x) { + unmatchedIdx <- which(!(names(x$pip) %in% gwasNameSet)) + if (length(unmatchedIdx) > 0) { + x$unmatched_variants <- names(x$pip)[unmatchedIdx] + } + x + }) + } unmatchedVariants <- lapply(alignedSusieQtlRegions, function(x) x$unmatched_variants) # Update susieQtlRegions with the aligned variant names diff --git a/R/regularizedRegressionWrappers.R b/R/regularizedRegressionWrappers.R index 866cfe97..b4fa0136 100644 --- a/R/regularizedRegressionWrappers.R +++ b/R/regularizedRegressionWrappers.R @@ -448,13 +448,19 @@ susieAshRssWeights <- function(stat, LD, susieAshRssFit = NULL, retainFit = TRUE #' weights) the parts of the mr.mash fit that `fineMappingPipeline` needs to #' rebuild the mvSuSiE reweighted mixture prior + residual variance: the #' original data-driven prior matrices (`dataDrivenPriorMatrices`), the fitted -#' mixture weights (`w0`) and the residual covariance (`V`). The heavy -#' coefficient matrix (`mu1`) is intentionally not retained. Default FALSE. +#' mixture weights (`w0`) and the residual covariance (`V`). Default FALSE. +#' @param fitDetail How much of the fit to retain when `retainFit = TRUE`. +#' `"slim"` (default) keeps only the three reconstruction inputs above; the +#' heavy coefficient matrix (`mu1`) is already returned as the weights, so it +#' is not duplicated. `"full"` additionally retains the complete mr.mash fit +#' under `$fit` (consistent with how susie fits are kept), at the cost of a +#' larger payload. #' @param ... Additional arguments passed to `mrmashWrapper()` when fitting. #' @return Matrix of variant weights. #' @export mrmashWeights <- function(mrmashFit = NULL, X = NULL, Y = NULL, - retainFit = FALSE, ...) { + retainFit = FALSE, fitDetail = c("slim", "full"), + ...) { if (!requireNamespace("mr.mashr", quietly = TRUE)) { stop("Package 'mr.mashr' is required. Install with: devtools::install_github('stephenslab/mr.mashr')") } @@ -468,14 +474,18 @@ mrmashWeights <- function(mrmashFit = NULL, X = NULL, Y = NULL, } out <- mr.mashr::coef.mr.mash(mrmashFit)[-1, ] if (isTRUE(retainFit)) { - # Lean payload consumed by fineMappingPipeline to reproduce the legacy - # initializeMvsusiePrior reweighting (see the mvSuSiE-prior-from-mr.mash - # note). The original matrices are required for bit-identical results; - # rescaleCovW0(w0) collapses the expanded weights back onto them. - attr(out, "fit") <- list( + fitDetail <- match.arg(fitDetail) + # Reconstruction inputs for the mvSuSiE data-driven prior (see the + # mvSuSiE-prior-from-mr.mash note): the original matrices (required for + # bit-identical results -- rescaleCovW0(w0) collapses the expanded weights + # back onto them), w0, and V. mu1 is already returned as `out`, so "slim" + # does not duplicate it; "full" additionally keeps the whole fit. + fitList <- list( dataDrivenPriorMatrices = dotArgs$dataDrivenPriorMatrices, w0 = mrmashFit$w0, V = mrmashFit$V) + if (fitDetail == "full") fitList$fit <- mrmashFit + attr(out, "fit") <- fitList } out } @@ -709,8 +719,15 @@ fsusieWeights <- function(fsusieFit = NULL, X = NULL, Y = NULL, #' defaults to the identity matrix of size K. #' @param covY Optional response covariance matrix (K x K). When NULL, #' defaults to the identity matrix of size K. -#' @param retainFit If TRUE, attaches the fitted object as the -#' \code{"fit"} attribute on the returned weights. +#' @param retainFit If TRUE, attaches (as the \code{"fit"} attribute on the +#' returned weights) the inputs \code{fineMappingPipeline} needs to rebuild +#' the mvSuSiE reweighted prior: \code{dataDrivenPriorMatrices}, the fitted +#' \code{w0}, and the fitted \code{V}. Default FALSE. +#' @param fitDetail How much to retain when \code{retainFit = TRUE}. +#' \code{"slim"} (default) keeps only those reconstruction inputs (the +#' coefficients are already the returned weights); \code{"full"} additionally +#' keeps the complete \code{mr.mash.rss} fit under \code{$fit}. Mirrors +#' \code{\link{mrmashWeights}}. #' @param ... Additional arguments forwarded to #' \code{mr.mashr::mr.mash.rss}. #' @@ -721,7 +738,8 @@ mrmashRssWeights <- function(stat, LD, mrmashRssFit = NULL, dataDrivenPriorMatrices = NULL, canonicalPriorMatrices = TRUE, S0 = NULL, w0 = NULL, V = NULL, covY = NULL, - retainFit = FALSE, ...) { + retainFit = FALSE, fitDetail = c("slim", "full"), + ...) { if (!requireNamespace("mr.mashr", quietly = TRUE)) { stop("Package 'mr.mashr' is required. ", "Install with: devtools::install_github('stephenslab/mr.mash.alpha')") @@ -761,7 +779,18 @@ mrmashRssWeights <- function(stat, LD, mrmashRssFit = NULL, } # coef.mr.mash.rss returns nrow(Bhat) rows (no intercept). Do not strip. weights <- mr.mashr::coef.mr.mash.rss(mrmashRssFit) - if (retainFit) attr(weights, "fit") <- mrmashRssFit + if (retainFit) { + fitDetail <- match.arg(fitDetail) + # Mirror mrmashWeights(): the slim payload carries only the mvSuSiE-prior + # reconstruction inputs (the coefficients are already `weights`, so mu1 is + # not duplicated); "full" additionally keeps the whole mr.mash.rss fit. + fitList <- list( + dataDrivenPriorMatrices = dataDrivenPriorMatrices, + w0 = mrmashRssFit$w0, + V = mrmashRssFit$V) + if (fitDetail == "full") fitList$fit <- mrmashRssFit + attr(weights, "fit") <- fitList + } weights } diff --git a/R/twasWeights.R b/R/twasWeights.R index a04b45e4..3ba87a29 100644 --- a/R/twasWeights.R +++ b/R/twasWeights.R @@ -882,6 +882,7 @@ learnTwasWeights <- function(X, Y, weightMethods, numThreads = 1, fittedModels = NULL, retainFits = FALSE, + retainFitDetail = c("slim", "full"), standardized = FALSE, dataType = NULL, ldSketch = NULL, @@ -889,6 +890,7 @@ learnTwasWeights <- function(X, Y, weightMethods, if (!is.matrix(X) || (!is.matrix(Y) && !is.vector(Y))) { stop("X must be a matrix and Y must be a matrix or a vector.") } + retainFitDetail <- match.arg(retainFitDetail) if (is.vector(Y)) { Y <- matrix(Y, ncol = 1) @@ -940,6 +942,11 @@ learnTwasWeights <- function(X, Y, weightMethods, } else if ("retain_fit" %in% fnFormals) { args$retain_fit <- TRUE } + # Propagate the slim/full payload choice to producers that support it + # (mr.mash individual + RSS), unless the caller already set it per-method. + if ("fitDetail" %in% fnFormals && is.null(args$fitDetail)) { + args$fitDetail <- retainFitDetail + } } methodFit <- NULL diff --git a/R/twasWeightsPipeline.R b/R/twasWeightsPipeline.R index 0f712bc0..a34ced24 100644 --- a/R/twasWeightsPipeline.R +++ b/R/twasWeightsPipeline.R @@ -835,6 +835,8 @@ setMethod("twasWeightsPipeline", "QtlDataset", ensembleSolver = "quadprog", ensembleAlpha = 1, estimatePi = TRUE, + retainFit = TRUE, + retainFitDetail = c("slim", "full"), phenotypeCovariatesToResidualize = NULL, genotypeCovariatesToResidualize = NULL, residualizePhenotypeCovariates = TRUE, @@ -844,6 +846,7 @@ setMethod("twasWeightsPipeline", "QtlDataset", verbose = 1, ...) { naAction <- match.arg(naAction) + retainFitDetail <- match.arg(retainFitDetail) # `cisWindow` expands a trait's own coordinates; `region` is literal. # Supplying both signals a misunderstanding -> reject. if (!is.null(region) && !is.null(cisWindow)) { @@ -872,7 +875,8 @@ setMethod("twasWeightsPipeline", "QtlDataset", if (length(parsedJointSpec) > 0L) { jointResult <- .twasDispatchJointSpecsQtlDataset( parsedJointSpec, data, intersect(norm$tokens, "mrmash"), - contexts, traitId, cisWindow, dataType, verbose, xRegions = xRegions) + contexts, traitId, cisWindow, dataType, verbose, xRegions = xRegions, + retainFit = retainFit, retainFitDetail = retainFitDetail) drop <- intersect(norm$tokens, "mrmash") keep <- setdiff(norm$tokens, drop) if (length(keep) == 0L) { @@ -1108,7 +1112,10 @@ setMethod("twasWeightsPipeline", "QtlDataset", # Retain the mr.mash fit parts ({dataDrivenPriorMatrices, w0, V}) on # the entry's `fits` slot so fineMappingPipeline can rebuild the # mvSuSiE reweighted prior + residual variance from this shared fit. + # `retainFitDetail` selects the slim payload (default) or the full + # mr.mash fit. retainFits = TRUE, + retainFitDetail = retainFitDetail, fittedModels = jointFits, cvFolds = cvFolds, samplePartition = samplePartition, @@ -1165,9 +1172,12 @@ setMethod("twasWeightsPipeline", "QtlSumStats", jointSpecification = NULL, fineMappingResult = NULL, twasWeights = NULL, + retainFit = TRUE, + retainFitDetail = c("slim", "full"), dataType = NULL, verbose = 1L, ...) { + retainFitDetail <- match.arg(retainFitDetail) # summaryStatsQc() is mandatory before twasWeightsPipeline for SumStats # input; it also drops variants not present in the ldSketch, so by the # time we reach this method every entry's SNP set is a subset of the @@ -1200,7 +1210,8 @@ setMethod("twasWeightsPipeline", "QtlSumStats", if (length(parsedJointSpec) > 0L) { jointResult <- .twasDispatchJointSpecsQtlSumStats( parsedJointSpec, data, intersect(tokens, "mrmash"), - contexts, traitId, dataType, verbose) + contexts, traitId, dataType, verbose, + retainFit = retainFit, retainFitDetail = retainFitDetail) keep <- setdiff(tokens, "mrmash") if (length(keep) == 0L) { if (is.null(jointResult)) @@ -1374,6 +1385,15 @@ setMethod("twasWeightsPipeline", "QtlSumStats", else .twasMethodCapabilities[[tk]]$sumstatImpl userArgs <- methodArgs[[tk]] if (is.null(userArgs)) userArgs <- list() + # mr.mash (no fine-mapping adapter) is the producer of the mvSuSiE + # data-driven prior: retain its (slim by default) fit so a downstream + # mvsusie_rss fineMappingPipeline run can rebuild the reweighted prior. + # Mirrors the individual-level path, which hardcodes retainFits = TRUE. + # Respect an explicit caller override of either knob. + if (is.null(adapter) && tk == "mrmash") { + if (is.null(userArgs$retainFit)) userArgs$retainFit <- TRUE + if (is.null(userArgs$fitDetail)) userArgs$fitDetail <- retainFitDetail + } # mvsusie is fine-mapping; thread the pre-fit through. mr.mash is # not, so this branch only fires for mvsusie. if (!is.null(adapter)) { @@ -1463,6 +1483,8 @@ setMethod("twasWeightsPipeline", "MultiStudyQtlDataset", jointSpecification = NULL, fineMappingResult = NULL, twasWeights = NULL, + retainFit = TRUE, + retainFitDetail = c("slim", "full"), naAction = c("drop", "impute"), verbose = 1, phenotypeCovariatesToResidualize = NULL, @@ -1471,6 +1493,7 @@ setMethod("twasWeightsPipeline", "MultiStudyQtlDataset", residualizeGenotypeCovariates = TRUE, ...) { naAction <- match.arg(naAction) + retainFitDetail <- match.arg(retainFitDetail) if (!is.null(region) && !is.null(cisWindow)) { stop("twasWeightsPipeline(MultiStudyQtlDataset): specify either ", "`region` or `cisWindow`, not both.") @@ -1498,7 +1521,8 @@ setMethod("twasWeightsPipeline", "MultiStudyQtlDataset", names(methods)), "mrmash") jointResult <- .twasDispatchJointSpecsMultiStudy( parsedJointSpec, data, jointMethods, - contexts, traitId, cisWindow, NULL, verbose, xRegions = xRegions) + contexts, traitId, cisWindow, NULL, verbose, xRegions = xRegions, + retainFit = retainFit, retainFitDetail = retainFitDetail) # Strip mrmash from the methods passed to the per-component recursion. if (is.character(methods)) methods <- setdiff(methods, "mrmash") else if (is.list(methods)) { @@ -1636,7 +1660,9 @@ setMethod("twasWeightsPipeline", "ANY", dataType = NULL, ldSketch = NULL, retainFits = FALSE, + retainFitDetail = c("slim", "full"), verbose = 1) { + retainFitDetail <- match.arg(retainFitDetail) if (is.character(weightMethods)) { weightMethods <- .twasMethodLookup(weightMethods) } @@ -1679,7 +1705,7 @@ setMethod("twasWeightsPipeline", "ANY", learnArgs <- list( study = study, context = context, trait = trait, standardized = standardized, dataType = dataType, - ldSketch = ldSketch) + ldSketch = ldSketch, retainFitDetail = retainFitDetail) if (needsPiEstimation) { # Run mr.ash first to estimate sparsity diff --git a/man/causalInferencePipeline.Rd b/man/causalInferencePipeline.Rd index d6403454..b48bef0b 100644 --- a/man/causalInferencePipeline.Rd +++ b/man/causalInferencePipeline.Rd @@ -8,9 +8,14 @@ causalInferencePipeline( gwasSumStats, twasWeights = NULL, fineMappingResult = NULL, + rsqCutoff = 0, + rsqPvalCutoff = Inf, + rsqOption = "rsq", + rsqPvalOption = c("adj_rsq_pval", "pval"), mrPipCutoff = 0.5, mrMethod = c("ivwPerVariant", "csAware"), mrCpipCutoff = 0.5, + mrPvalCutoff = 1, combineMethods = NULL, ... ) @@ -28,9 +33,37 @@ When supplied, drives the MR computation and (when \code{twasWeights = NULL}) the TWAS-Z weights via the SuSiE-style coefficients on each entry's \code{topLoci}.} -\item{mrPipCutoff}{Numeric (length 1). PIP threshold for an entry's -\code{topLoci} variant to be used as an instrumental variable. -Used only when \code{mrMethod = "ivwPerVariant"}. Default \code{0.5}.} +\item{rsqCutoff}{Numeric (length 1). When \code{> 0}, performs CV weight +selection (ports the legacy \code{twas_pipeline} \code{pick_best_model} + +\code{update_twas_method}): per \code{(study, context, trait, gwasStudy)} +keep only the method whose \code{cvPerformance} \code{rsqOption} metric is +highest among methods that clear both \code{rsqCutoff} and the +\code{rsqPvalCutoff} gate AND that produced a finite TWAS Z (the NA/Inf +re-selection); groups where no method clears the cutoffs are dropped. A +group whose methods carry no usable \code{cvPerformance} (the SS-TWAS +path) keeps all methods. Needs the \code{twasWeights} \code{cvPerformance}, +so selection is a no-op on the fineMappingResult-only path. Default +\code{0} (no selection; score every method).} + +\item{rsqPvalCutoff}{Numeric (length 1). CV-p-value gate for weight +selection (ports legacy \code{rsq_pval_cutoff}): a method is eligible only +when its \code{cvPerformance} \code{rsqPvalOption} metric is +\code{< rsqPvalCutoff}. Default \code{Inf} (no p-value gate). A finite +value activates selection even when \code{rsqCutoff = 0}.} + +\item{rsqOption}{Character. Which \code{cvPerformance} metric is the +"r-squared" used for the cutoff and ranking (ports legacy +\code{rsq_option}); typically \code{"rsq"} or \code{"adj_rsq"}. +Default \code{"rsq"}.} + +\item{rsqPvalOption}{Character vector of candidate \code{cvPerformance} +metric names for the p-value gate (ports legacy \code{rsq_pval_option}); +the first one present in a tuple's metrics is used. Default +\code{c("adj_rsq_pval", "pval")}.} + +\item{mrPipCutoff}{Numeric (length 1). PIP threshold for a \code{topLoci} +variant to be used as an instrumental variable. Used only when +\code{mrMethod = "ivwPerVariant"}. Default \code{0.5}.} \item{mrMethod}{One of \code{"ivwPerVariant"} (default) or \code{"csAware"}. The IVW-per-variant method filters topLoci @@ -45,6 +78,13 @@ Cochran's Q + I-squared in the output columns \code{Q}, \code{I2}.} retaining a credible set. Used only when \code{mrMethod = "csAware"}. Default \code{0.5}.} +\item{mrPvalCutoff}{Numeric (length 1). TWAS-p-value gate for running MR +(ports the legacy \code{twas_pipeline} \code{mr_pval_cutoff}): MR is +computed for a \code{(qtl tuple, gwas)} only when its \code{twasPval < +mrPvalCutoff}; otherwise the MR output columns are \code{NA}. Default +\code{1} (no gate; MR runs wherever a \code{fineMappingResult} entry +exists).} + \item{combineMethods}{Optional character vector forwarded to \code{\link{combinePValues}} for cross-method combination per \code{(qtlStudy, context, trait, gwasStudy)} group. \code{NULL} diff --git a/man/estCtwasParam.Rd b/man/estCtwasParam.Rd index 3c1142f3..c0c1f284 100644 --- a/man/estCtwasParam.Rd +++ b/man/estCtwasParam.Rd @@ -27,11 +27,15 @@ estCtwasParam( \item{ncore}{Number of cores.} \item{fallbackToPrefit}{Logical (length 1). When \code{TRUE} (default -\code{FALSE}), if \code{ctwas::est_param}'s accurate EM diverges to -NaN and throws \code{"Estimated group_prior(_var)? contains NAs"}, -re-run only the prefit step via \code{ctwas:::fit_EM} and return -those (typically finite) priors as the param. Mirrors the legacy -ctwas_2 workaround on toy data where the accurate EM saturates.} +\code{FALSE}), if \code{ctwas::est_param}'s accurate EM fails for ANY +reason on a degenerate input, re-run only the prefit step via +\code{ctwas:::fit_EM} and return those (typically finite) priors as the +param. The accurate-EM failure mode is version-dependent (ctwas <= 0.4.x: +\code{"contains NAs"}; ctwas >= 0.6.0: \code{"No regions selected!"} or a +NaN-loglik \code{"missing value where TRUE/FALSE needed"}), so the catch is +deliberately broad; a genuinely broken input still surfaces because the +prefit re-run will itself error. Mirrors the legacy ctwas_2 workaround on +toy data where the accurate EM cannot be estimated.} \item{...}{Additional arguments forwarded to \code{ctwas::est_param} (e.g. \code{min_p_single_effect}, \code{min_group_size}).} diff --git a/man/fineMappingPipeline.Rd b/man/fineMappingPipeline.Rd index 0ed6530b..dc1f1a77 100644 --- a/man/fineMappingPipeline.Rd +++ b/man/fineMappingPipeline.Rd @@ -30,7 +30,11 @@ fineMappingPipeline(data, ...) cvFolds = 0, samplePartition = NULL, pipCutoffToSkip = 0, + usePCA = FALSE, + nPCs = 10L, seed = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10, naAction = c("drop", "impute"), verbose = 1, trim = TRUE, @@ -57,6 +61,8 @@ fineMappingPipeline(data, ...) minAbsCorr = 0.8, medianAbsCorr = NULL, fineMappingResult = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10, cvFolds = 0, samplePartition = NULL, pipCutoffToSkip = 0, @@ -84,6 +90,8 @@ fineMappingPipeline(data, ...) minAbsCorr = 0.8, medianAbsCorr = NULL, fineMappingResult = NULL, + twasWeights = NULL, + dataDrivenPriorWeightsCutoff = 1e-10, verbose = 1, trim = TRUE, ... @@ -211,6 +219,19 @@ summary-statistics analog lives in \code{summaryStatsQc()}. \code{0} (default) disables the screen; a negative value uses the adaptive \code{3 / nVariants} threshold.} +\item{usePCA}{Logical (length 1). \code{QtlDataset} only. When +\code{TRUE} (default \code{FALSE}), each multi-trait context's +PCA-reduced phenotype is fine-mapped with univariate SuSiE on its +top principal components (ports the legacy \code{fsusie.R} +\code{susie_on_top_pc}). Each PC becomes a pseudo-trait row keyed +\code{trait = "topPC\{i\}"}, \code{method = "susie"}. Single-trait +contexts have no PCA and are skipped.} + +\item{nPCs}{Integer (length 1). \code{QtlDataset} only. Caps the +number of top principal components fine-mapped per context when +\code{usePCA = TRUE} (default \code{10}). The effective count is +\code{min(nPCs, usable traits)}.} + \item{seed}{Optional integer. When non-NULL, \code{set.seed(seed)} is called once at the start of the call for reproducible fits. Default \code{NULL} (no seeding).} diff --git a/man/learnTwasWeights.Rd b/man/learnTwasWeights.Rd index 46360a14..8a5ce8f3 100644 --- a/man/learnTwasWeights.Rd +++ b/man/learnTwasWeights.Rd @@ -14,6 +14,7 @@ learnTwasWeights( numThreads = 1, fittedModels = NULL, retainFits = FALSE, + retainFitDetail = c("slim", "full"), standardized = FALSE, dataType = NULL, ldSketch = NULL, diff --git a/man/mergeCtwasBoundaryRegions.Rd b/man/mergeCtwasBoundaryRegions.Rd new file mode 100644 index 00000000..43f34baa --- /dev/null +++ b/man/mergeCtwasBoundaryRegions.Rd @@ -0,0 +1,62 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ctwasPipeline.R +\name{mergeCtwasBoundaryRegions} +\alias{mergeCtwasBoundaryRegions} +\title{Merge boundary cTWAS regions and re-fine-map} +\usage{ +mergeCtwasBoundaryRegions( + finemapResult, + pipThresh = 0.5, + filterCs = FALSE, + maxSNP = Inf, + L = 5L, + ncore = 1L, + ... +) +} +\arguments{ +\item{finemapResult}{A list returned by \code{\link{finemapCtwasRegions}}. +Must carry \code{finemap_res}, \code{susie_alpha_res}, +\code{region_data}, \code{region_info}, \code{z_snp}, \code{z_gene}, +\code{weights}, \code{snp_map}, \code{param}, and — on the LD path — +\code{LD_map} plus the \code{LD_loader_fun} / \code{snpinfo_loader_fun} +closures (all retained by \code{finemapCtwasRegions}).} + +\item{pipThresh}{Numeric (length 1). PIP threshold for selecting which +boundary genes to merge (\code{select_boundary_genes} \code{pip_thresh}). +Default \code{0.5}.} + +\item{filterCs}{Logical (length 1). Require the gene to be in a credible set +to be selected (\code{select_boundary_genes} \code{filter_cs}). Default +\code{FALSE}.} + +\item{maxSNP}{Numeric (length 1). Per-merged-region SNP cap. Default +\code{Inf}.} + +\item{L}{Integer. Max number of single effects for the merged-region +re-fine-mapping (LD path only). Default \code{5}.} + +\item{ncore}{Number of cores. Default \code{1}.} + +\item{...}{Forwarded to the underlying ctwas postprocess function.} +} +\value{ +The \code{finemapResult} list with \code{finemap_res}, + \code{susie_alpha_res}, \code{region_data}, \code{region_info}, + \code{LD_map}, and \code{snp_map} replaced by the post-merge ("updated") + values, plus a \code{merge_res} element carrying the full ctwas postprocess + output. When no boundary gene clears \code{pipThresh}, ctwas returns the + inputs as the "updated" values, so the result is effectively unchanged. +} +\description{ +Optional step 4 of the cTWAS pipeline (default-off region + merging). A gene whose cis window straddles an LD-block boundary + (a \code{boundary_genes} member) is split across two regions in the + first-pass fine-mapping. This step selects the high-PIP boundary genes, + merges each one's adjacent regions into a single region, re-runs + fine-mapping on the merged regions, and splices the updated results back + into the \code{\link{finemapCtwasRegions}} output. Thin wrapper over + \code{ctwas::postprocess_region_merging()} (or + \code{ctwas::postprocess_region_merging_noLD()} when the inputs carry no + LD loaders). +} diff --git a/man/mrmashRssWeights.Rd b/man/mrmashRssWeights.Rd index 980d206d..da35b432 100644 --- a/man/mrmashRssWeights.Rd +++ b/man/mrmashRssWeights.Rd @@ -15,6 +15,7 @@ mrmashRssWeights( V = NULL, covY = NULL, retainFit = FALSE, + fitDetail = c("slim", "full"), ... ) } @@ -49,8 +50,16 @@ defaults to the identity matrix of size K.} \item{covY}{Optional response covariance matrix (K x K). When NULL, defaults to the identity matrix of size K.} -\item{retainFit}{If TRUE, attaches the fitted object as the -\code{"fit"} attribute on the returned weights.} +\item{retainFit}{If TRUE, attaches (as the \code{"fit"} attribute on the +returned weights) the inputs \code{fineMappingPipeline} needs to rebuild +the mvSuSiE reweighted prior: \code{dataDrivenPriorMatrices}, the fitted +\code{w0}, and the fitted \code{V}. Default FALSE.} + +\item{fitDetail}{How much to retain when \code{retainFit = TRUE}. +\code{"slim"} (default) keeps only those reconstruction inputs (the +coefficients are already the returned weights); \code{"full"} additionally +keeps the complete \code{mr.mash.rss} fit under \code{$fit}. Mirrors +\code{\link{mrmashWeights}}.} \item{...}{Additional arguments forwarded to \code{mr.mashr::mr.mash.rss}.} diff --git a/man/mrmashWeights.Rd b/man/mrmashWeights.Rd index 60714e9c..bd424d71 100644 --- a/man/mrmashWeights.Rd +++ b/man/mrmashWeights.Rd @@ -4,7 +4,14 @@ \alias{mrmashWeights} \title{Compute mr.mash TWAS weights} \usage{ -mrmashWeights(mrmashFit = NULL, X = NULL, Y = NULL, retainFit = FALSE, ...) +mrmashWeights( + mrmashFit = NULL, + X = NULL, + Y = NULL, + retainFit = FALSE, + fitDetail = c("slim", "full"), + ... +) } \arguments{ \item{mrmashFit}{Optional fitted mr.mash object.} @@ -17,8 +24,14 @@ mrmashWeights(mrmashFit = NULL, X = NULL, Y = NULL, retainFit = FALSE, ...) weights) the parts of the mr.mash fit that `fineMappingPipeline` needs to rebuild the mvSuSiE reweighted mixture prior + residual variance: the original data-driven prior matrices (`dataDrivenPriorMatrices`), the fitted -mixture weights (`w0`) and the residual covariance (`V`). The heavy -coefficient matrix (`mu1`) is intentionally not retained. Default FALSE.} +mixture weights (`w0`) and the residual covariance (`V`). Default FALSE.} + +\item{fitDetail}{How much of the fit to retain when `retainFit = TRUE`. +`"slim"` (default) keeps only the three reconstruction inputs above; the +heavy coefficient matrix (`mu1`) is already returned as the weights, so it +is not duplicated. `"full"` additionally retains the complete mr.mash fit +under `$fit` (consistent with how susie fits are kept), at the cost of a +larger payload.} \item{...}{Additional arguments passed to `mrmashWrapper()` when fitting.} } diff --git a/man/qtlEnrichment.Rd b/man/qtlEnrichment.Rd index f1b39997..98365e71 100644 --- a/man/qtlEnrichment.Rd +++ b/man/qtlEnrichment.Rd @@ -14,7 +14,8 @@ qtlEnrichment( doubleShrinkage = FALSE, besselCorrection = TRUE, numThreads = 1, - verbose = TRUE + verbose = TRUE, + alignNames = TRUE ) } \arguments{ @@ -33,6 +34,13 @@ When it is set to 0, no shrinkage will be applied. A large value indicates stron \item{impN}{Rounds of multiple imputation to draw QTL from, default is 25.} \item{numThreads}{Number of Simultaneous running CPU threads for multiple imputation, default is 1.} + +\item{alignNames}{Logical; when TRUE (default) QTL pip names are aligned to +the GWAS variant-naming convention via \code{alignVariantNames}. Set FALSE +when the caller has already aligned them (e.g. \code{qtlEnrichmentPipeline} +aligns each QTL tuple once against the union GWAS panel rather than +re-aligning per GWAS study); only the cheap per-study unmatched set is then +recomputed, skipping the costly \code{.matchRefPanel} pass.} } \value{ A list of enrichment parameter estimates diff --git a/man/twasWeightsPipeline.Rd b/man/twasWeightsPipeline.Rd index 9c44b0d6..4155ea34 100644 --- a/man/twasWeightsPipeline.Rd +++ b/man/twasWeightsPipeline.Rd @@ -33,6 +33,8 @@ twasWeightsPipeline(data, ...) ensembleSolver = "quadprog", ensembleAlpha = 1, estimatePi = TRUE, + retainFit = TRUE, + retainFitDetail = c("slim", "full"), phenotypeCovariatesToResidualize = NULL, genotypeCovariatesToResidualize = NULL, residualizePhenotypeCovariates = TRUE, @@ -51,6 +53,8 @@ twasWeightsPipeline(data, ...) jointSpecification = NULL, fineMappingResult = NULL, twasWeights = NULL, + retainFit = TRUE, + retainFitDetail = c("slim", "full"), dataType = NULL, verbose = 1L, ... @@ -67,6 +71,8 @@ twasWeightsPipeline(data, ...) jointSpecification = NULL, fineMappingResult = NULL, twasWeights = NULL, + retainFit = TRUE, + retainFitDetail = c("slim", "full"), naAction = c("drop", "impute"), verbose = 1, phenotypeCovariatesToResidualize = NULL, diff --git a/tests/testthat/test_causalInferencePipeline.R b/tests/testthat/test_causalInferencePipeline.R index 6c7b2b85..14a85a80 100644 --- a/tests/testthat/test_causalInferencePipeline.R +++ b/tests/testthat/test_causalInferencePipeline.R @@ -273,6 +273,125 @@ test_that(".cipZToSe: falls back to vector of 1 when maf/n are NA", { expect_equal(res, c(1, 1)) }) +test_that(".cipFilterEligibleMethods: rsq+pval gating, drop sub-cutoff groups, SS-TWAS keeps all", { + mkEntry <- function(rsq, pval = 0.01) TwasWeightsEntry( + variantIds = paste0("v", 1:3), weights = rep(0.1, 3), + cvPerformance = list(metrics = c(corr = 0.1, rsq = rsq, pval = pval))) + tw <- TwasWeights( + study = rep("S", 4), + context = rep("c1", 4), + trait = c("G", "G", "G", "G2"), + method = c("susie", "enet", "lasso", "susie"), + entry = list(mkEntry(0.20), mkEntry(0.05), mkEntry(0.50), mkEntry(0.01))) + qtlRows <- pecotmr:::.cipBuildQtlWorkList(tw, NULL) + mt <- pecotmr:::.cipMethodMetrics(qtlRows, tw, "rsq", c("adj_rsq_pval", "pval")) + # rsq gate only: G keeps susie(.20)+lasso(.50); enet(.05) out; G2(.01) dropped. + f1 <- pecotmr:::.cipFilterEligibleMethods(qtlRows, mt, rsqCutoff = 0.1, + rsqPvalCutoff = Inf) + expect_equal(sort(f1$method), c("lasso", "susie")) + expect_true(all(f1$trait == "G")) + # rsqCutoff above every method -> empty work-list. + expect_equal(nrow(pecotmr:::.cipFilterEligibleMethods(qtlRows, mt, 0.99, Inf)), 0L) + # pval gate: a high-rsq method with a bad CV p-value is excluded. + tw2 <- TwasWeights( + study = rep("S", 2), context = rep("c1", 2), trait = rep("G", 2), + method = c("susie", "lasso"), + entry = list(mkEntry(0.20, pval = 0.20), mkEntry(0.50, pval = 0.01))) + q2 <- pecotmr:::.cipBuildQtlWorkList(tw2, NULL) + m2 <- pecotmr:::.cipMethodMetrics(q2, tw2, "rsq", c("adj_rsq_pval", "pval")) + f2 <- pecotmr:::.cipFilterEligibleMethods(q2, m2, rsqCutoff = 0.1, + rsqPvalCutoff = 0.05) + expect_equal(f2$method, "lasso") + # SS-TWAS: no usable cvPerformance -> keep all methods in the group. + twss <- TwasWeights( + study = rep("S", 2), context = rep("c1", 2), trait = rep("G", 2), + method = c("susie", "lasso"), + entry = list( + TwasWeightsEntry(variantIds = paste0("v", 1:3), weights = rep(0.1, 3)), + TwasWeightsEntry(variantIds = paste0("v", 1:3), weights = rep(0.1, 3)))) + qss <- pecotmr:::.cipBuildQtlWorkList(twss, NULL) + mss <- pecotmr:::.cipMethodMetrics(qss, twss, "rsq", c("adj_rsq_pval", "pval")) + expect_equal( + nrow(pecotmr:::.cipFilterEligibleMethods(qss, mss, 0.1, Inf)), 2L) +}) + +test_that(".cipSelectBestMethod: max-rsq finite Z, NA/Inf re-selection, SS-TWAS keeps all", { + lk <- c("S\rc1\rG\rsusie" = 0.5, "S\rc1\rG\rlasso" = 0.2) + df <- data.frame( + qtlStudy = "S", context = "c1", trait = "G", gwasStudy = "X", + method = c("susie", "lasso"), twasZ = c(2.0, 1.0), + stringsAsFactors = FALSE) + expect_equal(pecotmr:::.cipSelectBestMethod(df, lk)$method, "susie") + # top-rsq method has NA Z -> fall back to next-best (lasso). + df2 <- df; df2$twasZ <- c(NA_real_, 1.0) + expect_equal(pecotmr:::.cipSelectBestMethod(df2, lk)$method, "lasso") + # none finite -> keep the top-rsq method anyway. + df3 <- df; df3$twasZ <- c(NA_real_, Inf) + expect_equal(pecotmr:::.cipSelectBestMethod(df3, lk)$method, "susie") + # SS-TWAS group (rsq lookup all NA) -> keep all rows. + lkNA <- c("S\rc1\rG\rsusie" = NA_real_, "S\rc1\rG\rlasso" = NA_real_) + expect_equal(nrow(pecotmr:::.cipSelectBestMethod(df, lkNA)), 2L) +}) + +test_that("causalInferencePipeline: rsqCutoff selects the max-rsq method per group", { + tw <- TwasWeights( + study = rep("Q1", 2), context = rep("c1", 2), trait = rep("t1", 2), + method = c("susie", "lasso"), + entry = list( + TwasWeightsEntry(variantIds = paste0("v", 1:5), + weights = c(0.1, 0.05, -0.2, 0.3, 0.0), + cvPerformance = list(metrics = c(rsq = 0.2, pval = 0.001))), + TwasWeightsEntry(variantIds = paste0("v", 1:5), + weights = c(0.2, 0.1, -0.1, 0.2, 0.1), + cvPerformance = list(metrics = c(rsq = 0.5, pval = 0.001)))), + ldSketch = .cip_makeHandle()) + local_mocked_bindings(extractBlockGenotypes = .cip_mockExtractor(), + .package = "pecotmr") + out <- causalInferencePipeline(gwasSumStats = .cip_makeGwasSumstats(), + twasWeights = tw, rsqCutoff = 0.1) + expect_equal(as.character(S4Vectors::mcols(out)$method), "lasso") +}) + +test_that("causalInferencePipeline: NA/Inf TWAS-Z triggers method re-selection", { + tw <- TwasWeights( + study = rep("Q1", 2), context = rep("c1", 2), trait = rep("t1", 2), + method = c("susie", "lasso"), + entry = list( + # top rsq but all-zero weights -> wᵀRw = 0 -> twasZ NaN + TwasWeightsEntry(variantIds = paste0("v", 1:5), weights = rep(0, 5), + cvPerformance = list(metrics = c(rsq = 0.9, pval = 0.001))), + TwasWeightsEntry(variantIds = paste0("v", 1:5), + weights = c(0.2, 0.1, -0.1, 0.2, 0.1), + cvPerformance = list(metrics = c(rsq = 0.5, pval = 0.001)))), + ldSketch = .cip_makeHandle()) + local_mocked_bindings(extractBlockGenotypes = .cip_mockExtractor(), + .package = "pecotmr") + out <- causalInferencePipeline(gwasSumStats = .cip_makeGwasSumstats(), + twasWeights = tw, rsqCutoff = 0.1) + expect_equal(as.character(S4Vectors::mcols(out)$method), "lasso") + expect_true(is.finite(S4Vectors::mcols(out)$twasZ)) +}) + +test_that("causalInferencePipeline: rsqPvalCutoff gates out high-CV-pval methods", { + tw <- TwasWeights( + study = rep("Q1", 2), context = rep("c1", 2), trait = rep("t1", 2), + method = c("susie", "lasso"), + entry = list( + TwasWeightsEntry(variantIds = paste0("v", 1:5), + weights = c(0.2, 0.1, -0.1, 0.2, 0.1), + cvPerformance = list(metrics = c(rsq = 0.9, pval = 0.5))), + TwasWeightsEntry(variantIds = paste0("v", 1:5), + weights = c(0.2, 0.1, -0.1, 0.2, 0.1), + cvPerformance = list(metrics = c(rsq = 0.5, pval = 0.001)))), + ldSketch = .cip_makeHandle()) + local_mocked_bindings(extractBlockGenotypes = .cip_mockExtractor(), + .package = "pecotmr") + out <- causalInferencePipeline(gwasSumStats = .cip_makeGwasSumstats(), + twasWeights = tw, rsqCutoff = 0.1, + rsqPvalCutoff = 0.05) + expect_equal(as.character(S4Vectors::mcols(out)$method), "lasso") +}) + context("twas: twasZ and harmonize deprecated wrappers") diff --git a/tests/testthat/test_ctwasPipeline.R b/tests/testthat/test_ctwasPipeline.R index 072e07cb..ba36f65a 100644 --- a/tests/testthat/test_ctwasPipeline.R +++ b/tests/testthat/test_ctwasPipeline.R @@ -274,19 +274,44 @@ test_that(".ctwasBuildZSnp: produces a flat data.frame keyed by SNP/study", { expect_setequal(unique(df$study), "G1") }) -test_that(".ctwasBuildSingleRegionInfo: pulls chrom + bp span from the ldSketch", { - ri <- pecotmr:::.ctwasBuildSingleRegionInfo("block1", .ctp_makeHandle()) +test_that(".ctwasBuildSingleRegionInfo: pulls chrom + bp span from the GWAS block entry", { + # Bounds come from the block's GWAS variants (the GwasSumStats entry), NOT + # the LD sketch — many blocks can share one whole-chromosome LD payload. + ri <- pecotmr:::.ctwasBuildSingleRegionInfo("block1", .ctp_makeGwasSumstats()) expect_equal(ri$region_id, "block1") expect_equal(ri$chrom, 1L) expect_equal(ri$start, 100L) expect_equal(ri$stop, 600L) }) -test_that(".ctwasBuildSingleRegionInfo: multi-chromosome sketch errors", { - h <- .ctp_makeHandle() - h@snpInfo$CHR[1:3] <- "2" +test_that(".ctwasBuildSingleRegionInfo: uses the block entry span, not the wider shared LD sketch", { + # Regression: many LD blocks can share one whole-chromosome LD payload, so + # the sketch span (here BP 100-600) is NOT the block's span. The entry here + # covers only 200-400; region bounds must follow the entry, otherwise every + # block collapses to the whole-chromosome span and every SNP is assigned to + # every region (inflating SNP group_size and crushing the gene PIP). + gr <- GenomicRanges::GRanges( + seqnames = "chr1", + ranges = IRanges::IRanges(start = c(200L, 300L, 400L), width = 1L)) + S4Vectors::mcols(gr) <- S4Vectors::DataFrame( + SNP = c("a", "b", "c"), A1 = "A", A2 = "G", Z = 0, N = 1000L) + gss <- GwasSumStats(study = "G1", entry = list(gr), genome = "hg19", + ldSketch = .ctp_makeHandle(), qcInfo = list(step1 = "ok")) + ri <- pecotmr:::.ctwasBuildSingleRegionInfo("blockX", gss) + expect_equal(ri$start, 200L) # entry min, not sketch min (100) + expect_equal(ri$stop, 400L) # entry max, not sketch max (600) +}) + +test_that(".ctwasBuildSingleRegionInfo: multi-chromosome block entry errors", { + gr <- GenomicRanges::GRanges( + seqnames = c("chr1", "chr1", "chr2"), + ranges = IRanges::IRanges(start = c(100L, 200L, 300L), width = 1L)) + S4Vectors::mcols(gr) <- S4Vectors::DataFrame( + SNP = c("a", "b", "c"), A1 = "A", A2 = "G", Z = 0, N = 1000L) + gss <- GwasSumStats(study = "G1", entry = list(gr), genome = "hg19", + ldSketch = .ctp_makeHandle(), qcInfo = list(step1 = "ok")) expect_error( - pecotmr:::.ctwasBuildSingleRegionInfo("block1", h), + pecotmr:::.ctwasBuildSingleRegionInfo("block1", gss), "spans multiple chromosomes" ) }) @@ -781,7 +806,9 @@ test_that("ctwasPipeline: dispatches assemble → est → screen → finemap and expect_setequal( names(out), c("z_gene", "param", "finemap_res", "susie_alpha_res", - "region_data", "boundary_genes", "screen_res")) + "region_data", "boundary_genes", "screen_res", + "region_info", "z_snp", "weights", "snp_map", + "LD_map", "LD_loader_fun", "snpinfo_loader_fun")) }) test_that("estCtwasParam: fallbackToPrefit recovers from accurate-EM NaN divergence", { @@ -856,7 +883,9 @@ test_that("estCtwasParam / screenCtwasRegions / finemapCtwasRegions can be calle expect_setequal( names(final), c("z_gene", "param", "finemap_res", "susie_alpha_res", - "region_data", "boundary_genes", "screen_res")) + "region_data", "boundary_genes", "screen_res", + "region_info", "z_snp", "weights", "snp_map", + "LD_map", "LD_loader_fun", "snpinfo_loader_fun")) }) # =========================================================================== @@ -902,7 +931,9 @@ test_that("ctwasPipeline: real-engine end-to-end on the bundled example panel", # ctwas_sumstats returns these 7 elements on success. expect_named(res, c("z_gene", "param", "finemap_res", "susie_alpha_res", - "region_data", "boundary_genes", "screen_res"), + "region_data", "boundary_genes", "screen_res", + "region_info", "z_snp", "weights", "snp_map", + "LD_map", "LD_loader_fun", "snpinfo_loader_fun"), ignore.order = TRUE) # The gene we passed in came through. expect_true(any(grepl("study1\\|brain\\|ENSG_example\\|susie", res$z_gene$id))) @@ -1020,3 +1051,71 @@ test_that(".ctwasBuildWeights: twasWeightCutoff drops low-magnitude variants", { expect_equal(wl[[1L]]$n_wgt, 3L) expect_setequal(rownames(wl[[1L]]$wgt), vids[c(2L, 4L, 5L)]) }) + +# =========================================================================== +# mergeCtwasBoundaryRegions (step 4: boundary-gene region merging) +# =========================================================================== + +test_that("mergeCtwasBoundaryRegions: no first-pass finemap_res returns unchanged", { + fmr <- list(finemap_res = NULL, region_data = "rd") + expect_identical(mergeCtwasBoundaryRegions(fmr), fmr) + fmr0 <- list(finemap_res = data.frame()[0, ], region_data = "rd") + expect_identical(mergeCtwasBoundaryRegions(fmr0), fmr0) +}) + +test_that("mergeCtwasBoundaryRegions: LD path forwards carried state + splices updated_*", { + captured <- NULL + local_mocked_bindings( + postprocess_region_merging = function(...) { + captured <<- list(...) + list(updated_finemap_res = data.frame(id = "g", susie_pip = 0.9), + updated_susie_alpha_res = "ua_new", + updated_region_data = "rd_new", + updated_region_info = "ri_new", + updated_LD_map = "ld_new", + updated_snp_map = "sm_new", + selected_boundary_genes = data.frame(id = "g")) + }, + .package = "ctwas") + fmr <- list( + finemap_res = data.frame(id = "g", type = "gene", susie_pip = 0.6), + susie_alpha_res = "ua0", region_data = "rd", region_info = "ri", + z_snp = "zs", z_gene = data.frame(id = "g"), weights = "w", snp_map = "sm", + LD_map = "ld", LD_loader_fun = function() NULL, + snpinfo_loader_fun = function() NULL, + param = list(group_prior = 0.1, group_prior_var = 5)) + out <- mergeCtwasBoundaryRegions(fmr, pipThresh = 0.5, maxSNP = 100) + # dispatched to the LD path carrying the loaders + first-pass state + expect_true(all(c("LD_map", "LD_loader_fun", "snpinfo_loader_fun") %in% names(captured))) + expect_equal(captured$pip_thresh, 0.5) + expect_equal(captured$maxSNP, 100) + expect_identical(captured$region_data, "rd") + expect_equal(captured$group_prior, 0.1) + # updated_* spliced back into the result + expect_equal(out$finemap_res$susie_pip, 0.9) + expect_identical(out$region_data, "rd_new") + expect_identical(out$region_info, "ri_new") + expect_identical(out$LD_map, "ld_new") + expect_identical(out$snp_map, "sm_new") + expect_identical(out$susie_alpha_res, "ua_new") + expect_identical(out$merge_res$selected_boundary_genes$id, "g") +}) + +test_that("mergeCtwasBoundaryRegions: no-LD path used when LD loaders are absent", { + called <- NULL + local_mocked_bindings( + postprocess_region_merging_noLD = function(...) { + called <<- "noLD" + list(updated_finemap_res = data.frame(id = "g"), + updated_susie_alpha_res = NULL) + }, + .package = "ctwas") + fmr <- list( + finemap_res = data.frame(id = "g", type = "gene", susie_pip = 0.6), + susie_alpha_res = NULL, region_data = "rd", region_info = "ri", + z_snp = "zs", z_gene = data.frame(id = "g"), weights = "w", snp_map = "sm", + LD_map = NULL, LD_loader_fun = NULL, snpinfo_loader_fun = NULL, + param = list(group_prior = 0.1, group_prior_var = 5)) + out <- mergeCtwasBoundaryRegions(fmr) + expect_equal(called, "noLD") +}) diff --git a/tests/testthat/test_fineMappingPipeline.R b/tests/testthat/test_fineMappingPipeline.R index dbf2a9e7..dfce0a2d 100644 --- a/tests/testthat/test_fineMappingPipeline.R +++ b/tests/testthat/test_fineMappingPipeline.R @@ -547,6 +547,115 @@ test_that(".fmSerScreen: disables on 0, skips no-signal, keeps signal + adaptive expect_true(fn(X, yNull, NA)) # malformed cutoff -> advisory keep }) +test_that(".fmTopPcScores: clean matrix -> samples x min(nPCs, traits) topPC scores", { + set.seed(7) + n <- 30L + Y <- matrix(rnorm(n * 3L), nrow = n, ncol = 3L, + dimnames = list(paste0("s", seq_len(n)), c("ta", "tb", "tc"))) + fn <- function(...) pecotmr:::.fmTopPcScores(...) + # (a) 3 traits, nPCs >= traits -> 3 columns named topPC1..topPC3, rows = samples. + sc <- fn(Y, 10L) + expect_true(is.matrix(sc)) + expect_equal(ncol(sc), 3L) + expect_equal(colnames(sc), c("topPC1", "topPC2", "topPC3")) + expect_equal(nrow(sc), n) + expect_equal(rownames(sc), rownames(Y)) + # (b) nPCs caps the number of returned columns. + sc2 <- fn(Y, 2L) + expect_equal(ncol(sc2), 2L) + expect_equal(colnames(sc2), c("topPC1", "topPC2")) + # (c) single-column Y -> NULL (PCA undefined for a single trait). + expect_null(fn(Y[, 1L, drop = FALSE], 10L)) + # (d) a zero-variance trait is dropped; k reflects only the usable traits. + Yzv <- cbind(Y, td = rep(1, n)) + scz <- fn(Yzv, 10L) + expect_equal(ncol(scz), 3L) # td dropped -> still 3 usable traits + expect_equal(colnames(scz), c("topPC1", "topPC2", "topPC3")) + # (e) rows with any NA are dropped before PCA. + Yna <- Y + Yna[c(1L, 2L), 1L] <- NA + scn <- fn(Yna, 10L) + expect_equal(nrow(scn), n - 2L) + expect_false(any(c("s1", "s2") %in% rownames(scn))) +}) + +test_that("fineMappingPipeline(QtlDataset, usePCA): top-PC susie rows keyed topPC{i}", { + qd <- .fmp_makeQtlDataset(contexts = "brain", + traits = c("ENSG_A", "ENSG_B", "ENSG_C")) + local_mocked_bindings( + extractBlockGenotypes = .fmp_mockExtractor(), + .fmFitSusieIndiv = .fmp_mockFitIndiv(), + .fmPostprocessOne = .fmp_mockPostprocess(), + .package = "pecotmr") + res <- suppressMessages( + fineMappingPipeline(qd, methods = "susie", + cisWindow = 1000L, addSusieInf = FALSE, + usePCA = TRUE, nPCs = 2L)) + expect_s4_class(res, "QtlFineMappingResult") + expect_setequal(getMethodNames(res), "susie") + pcRows <- as.character(res$trait) %in% c("topPC1", "topPC2") + # 3 per-trait univariate susie rows + 2 top-PC rows = 5. + expect_equal(sum(pcRows), 2L) + expect_setequal(as.character(res$trait)[pcRows], c("topPC1", "topPC2")) + expect_setequal(as.character(res$method)[pcRows], "susie") +}) + +test_that(".buildMvsusieReweightedPrior: canonical fallback when no usable fit", { + bp <- function(...) pecotmr:::.buildMvsusieReweightedPrior(...) + # No fit at all -> canonical prior, residualVariance NULL. + p1 <- bp(NULL, c("c1", "c2")) + expect_false(is.null(p1$priorVariance)) + expect_null(p1$residualVariance) + # Fit with no data-driven matrices -> canonical prior, but V carried through. + p2 <- bp(list(dataDrivenPriorMatrices = NULL, V = diag(2)), c("c1", "c2")) + expect_equal(p2$residualVariance, diag(2)) +}) + +test_that(".buildMvsusieReweightedPrior: reweights matrices by rescaleCovW0(w0)", { + ddpm <- list(U = list(compA = diag(2), compB = diag(2) * 2), + w = c(compA = 0.5, compB = 0.5)) + fit <- list(dataDrivenPriorMatrices = ddpm, + w0 = c(compA_grid1 = 0.3, compB_grid1 = 0.7), + V = diag(2) * 3) + captured <- NULL + # rescaleCovW0 collapses expanded w0 onto the original matrix names; mock it + # so the test asserts the wiring, not rescaleCovW0's internals. + local_mocked_bindings( + rescaleCovW0 = function(w0) c(compA = 0.4, compB = 0.6), + .package = "pecotmr") + local_mocked_bindings( + create_mixture_prior = function(...) { captured <<- list(...); "PRIOR" }, + .package = "mvsusieR") + res <- pecotmr:::.buildMvsusieReweightedPrior(fit, c("c1", "c2"), + weightsTol = 1e-8) + expect_identical(res$priorVariance, "PRIOR") + expect_equal(res$residualVariance, diag(2) * 3) + expect_equal(captured$mixture_prior$weights, c(compA = 0.4, compB = 0.6)) + expect_equal(names(captured$mixture_prior$matrices), c("compA", "compB")) + expect_equal(captured$include_indices, c("c1", "c2")) + expect_equal(captured$weights_tol, 1e-8) +}) + +test_that(".fmLookupMrmashFit: finds the mr.mash fit by (study, trait)", { + mkEntry <- function(fits) TwasWeightsEntry( + variantIds = c("v1", "v2"), weights = c(0.1, 0.2), fits = fits) + payload <- list(dataDrivenPriorMatrices = list(U = list(a = diag(2))), + w0 = c(a = 1), V = diag(2)) + # The joint fit lives on the first mrmash row of the (study, trait) group; + # the other context row carries fits = NULL. A non-mrmash row is ignored. + tw <- TwasWeights( + study = c("S", "S", "S"), + context = c("c1", "c2", "c1"), + trait = c("G", "G", "G"), + method = c("mrmash", "mrmash", "enet"), + entry = list(mkEntry(payload), mkEntry(NULL), mkEntry(payload))) + lk <- function(...) pecotmr:::.fmLookupMrmashFit(...) + expect_identical(lk(tw, "S", "G"), payload) # first non-NULL mrmash row + expect_null(lk(tw, "S", "OTHER")) # no such trait + expect_null(lk(tw, "OTHER", "G")) # no such study + expect_null(lk(NULL, "S", "G")) # no TwasWeights supplied +}) + test_that("fineMappingPipeline(QtlDataset): pipCutoffToSkip skips no-signal univariate traits", { qd <- .fmp_makeQtlDataset(contexts = "brain", traits = c("ENSG_A", "ENSG_B")) # Stateful screen: reject the first block (ENSG_A), keep the rest (ENSG_B). diff --git a/tests/testthat/test_regularizedRegressionWrappers.R b/tests/testthat/test_regularizedRegressionWrappers.R index 438c4d72..2aed3b06 100644 --- a/tests/testthat/test_regularizedRegressionWrappers.R +++ b/tests/testthat/test_regularizedRegressionWrappers.R @@ -64,6 +64,30 @@ test_that("susieRssWeights retains fit when retainFit = TRUE", { expect_false(is.null(attr(w, "fit"))) }) +test_that("mrmashWeights fitDetail: slim default omits the full fit, full keeps it", { + skip_if_not_installed("mr.mashr") + fakeFit <- list(w0 = c(a_1 = 0.5, a_2 = 0.5), V = diag(2)) + ddpm <- list(U = list(a = diag(2))) + # Mock coef extraction so we exercise only the retain payload logic, not a + # real mr.mash fit. coef.mr.mash(fit)[-1, ] -> drop the intercept row. + local_mocked_bindings( + coef.mr.mash = function(object, ...) rbind(c(0, 0), c(0.1, 0.2)), + .package = "mr.mashr") + fitSlim <- attr(mrmashWeights(mrmashFit = fakeFit, retainFit = TRUE, + dataDrivenPriorMatrices = ddpm), "fit") + expect_setequal(names(fitSlim), c("dataDrivenPriorMatrices", "w0", "V")) + expect_null(fitSlim$fit) # slim: no full fit + expect_identical(fitSlim$dataDrivenPriorMatrices, ddpm) + expect_identical(fitSlim$w0, fakeFit$w0) + + fitFull <- attr(mrmashWeights(mrmashFit = fakeFit, retainFit = TRUE, + fitDetail = "full", + dataDrivenPriorMatrices = ddpm), "fit") + expect_true("fit" %in% names(fitFull)) # full: the whole fit retained + expect_identical(fitFull$fit, fakeFit) + expect_identical(fitFull$w0, fakeFit$w0) # slim fields still present +}) + test_that("susieInfRssWeights works", { skip_if_not_installed("susieR") set.seed(42)