From 983d6a1566ec9205a130dd7a8ea7fdce89ecf557 Mon Sep 17 00:00:00 2001 From: PR Author Date: Fri, 19 Jun 2026 18:40:47 +0800 Subject: [PATCH 1/2] Support bounded feedback loops in the DAG execution engine Allow sampler nodes' internal iteration variables (e.g. step_index) to flow back upstream through ComfyMathExpression nodes to control per-step parameters (cfg, s_noise, eta, r) without triggering a dependency cycle error. Architecture: Two-level cycle handling - Static validation: _is_bounded_feedback_cycle() allows cycles where any node declares BOUNDED_FEEDBACK - Graph building: _is_feedback_output() skips strong links for declared feedback sockets, records them in feedback_links Multi-hop chain walking via _build_feedback_fns() resolves expression->CFGGuider/Sampler chains with simple_eval + MATH_FUNCTIONS, composing per-step fn(step, total_steps) callables. Sampler functions now re-read s_noise/eta/r each iteration via _init_dynamic_options() / _refresh_dynamic_params() / _apply_dynamic_s_noise(). KSAMPLER.sample() conditionally injects mutable extra_options ref. Safety: _dynamic_sampler_options popped at function top before model() calls. One-line opt-in: BOUNDED_FEEDBACK = {'step_index'} on any node. --- comfy/k_diffusion/sampling.py | 2133 +++++++++++++++++++++----- comfy/samplers.py | 6 + comfy_execution/graph.py | 44 + comfy_extras/nodes_custom_sampler.py | 31 +- execution.py | 320 +++- 5 files changed, 2185 insertions(+), 349 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 11db46d94c9b..afdc963f9b1e 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,26 +1,25 @@ import math from functools import partial -from scipy import integrate import torch -from torch import nn import torchsde +from scipy import integrate +from torch import nn from tqdm.auto import tqdm -from . import utils -from . import deis -from . import sa_solver +import comfy.memory_management import comfy.model_patcher import comfy.model_sampling - -import comfy.memory_management from comfy.utils import model_trange as trange +from . import deis, sa_solver, utils + + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) -def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"): """Constructs the noise schedule of Karras et al. (2022).""" ramp = torch.linspace(0, 1, n, device=device) min_inv_rho = sigma_min ** (1 / rho) @@ -29,49 +28,57 @@ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): return append_zero(sigmas).to(device) -def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): +def get_sigmas_exponential(n, sigma_min, sigma_max, device="cpu"): """Constructs an exponential noise schedule.""" - sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + sigmas = torch.linspace( + math.log(sigma_max), math.log(sigma_min), n, device=device + ).exp() return append_zero(sigmas) -def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'): +def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1.0, device="cpu"): """Constructs an polynomial in log sigma noise schedule.""" ramp = torch.linspace(1, 0, n, device=device) ** rho - sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)) + sigmas = torch.exp( + ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min) + ) return append_zero(sigmas) -def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): +def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device="cpu"): """Constructs a continuous VP noise schedule.""" t = torch.linspace(1, eps_s, n, device=device) - sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t)) + sigmas = torch.sqrt(torch.special.expm1(beta_d * t**2 / 2 + beta_min * t)) return append_zero(sigmas) -def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'): - """Constructs the noise schedule proposed by Tiankai et al. (2024). """ - epsilon = 1e-5 # avoid log(0) +def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0.0, beta=0.5, device="cpu"): + """Constructs the noise schedule proposed by Tiankai et al. (2024).""" + epsilon = 1e-5 # avoid log(0) x = torch.linspace(0, 1, n, device=device) clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max) - lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon) + lmb = mu - beta * torch.sign(0.5 - x) * torch.log( + 1 - 2 * torch.abs(0.5 - x) + epsilon + ) sigmas = clamp(torch.exp(lmb)) return sigmas - def to_d(x, sigma, denoised): """Converts a denoiser output to a Karras ODE derivative.""" return (x - denoised) / utils.append_dims(sigma, x.ndim) -def get_ancestral_step(sigma_from, sigma_to, eta=1.): +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): """Calculates the noise level (sigma_down) to step down to and the amount of noise to add (sigma_up) when doing an ancestral sampling step.""" if not eta: - return sigma_to, 0. - sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) - sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_to, 0.0 + sigma_up = min( + sigma_to, + eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 return sigma_down, sigma_up @@ -85,7 +92,9 @@ def default_noise_sampler(x, seed=None): else: generator = None - return lambda sigma, sigma_next: torch.randn(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) + return lambda sigma, sigma_next: torch.randn( + x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator + ) class BatchedBrownianTree: @@ -94,22 +103,26 @@ class BatchedBrownianTree: def __init__(self, x, t0, t1, seed=None, **kwargs): self.cpu_tree = kwargs.pop("cpu", True) t0, t1, self.sign = self.sort(t0, t1) - w0 = kwargs.pop('w0', None) + w0 = kwargs.pop("w0", None) if w0 is None: w0 = torch.zeros_like(x) self.batched = False if seed is None: - seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),) + seed = (torch.randint(0, 2**63 - 1, ()).item(),) elif isinstance(seed, (tuple, list)): if len(seed) != x.shape[0]: - raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.") + raise ValueError( + "Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size." + ) self.batched = True w0 = w0[0] else: seed = (seed,) if self.cpu_tree: t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu() - self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed) + self.trees = tuple( + torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed + ) @staticmethod def sort(a, b): @@ -120,7 +133,9 @@ def __call__(self, t0, t1): device, dtype = t0.device, t0.dtype if self.cpu_tree: t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float() - w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign) + w = torch.stack([tree(t0, t1) for tree in self.trees]).to( + device=device, dtype=dtype + ) * (self.sign * sign) return w if self.batched else w[0] @@ -139,13 +154,21 @@ class BrownianTreeNoiseSampler: internal timestep. """ - def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): + def __init__( + self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False + ): self.transform = transform - t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + t0, t1 = ( + self.transform(torch.as_tensor(sigma_min)), + self.transform(torch.as_tensor(sigma_max)), + ) self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) def __call__(self, sigma, sigma_next): - t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + t0, t1 = ( + self.transform(torch.as_tensor(sigma)), + self.transform(torch.as_tensor(sigma_next)), + ) return self.tree(t0, t1) / (t1 - t0).abs().sqrt() @@ -186,14 +209,68 @@ def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor: return (torch.expm1(h) - h) / h +def _apply_dynamic_s_noise(dynamic_opts, model_sampling, current_s_noise): + """Re-read s_noise from the mutable extra_options dict if available. + + Bounded-feedback callbacks update ``extra_options["s_noise"]`` at each step. + Call this at the top of each sampler loop iteration with *dynamic_opts* + set to the KSAMPLER.extra_options reference (popped from extra_args so it + never reaches the model). + """ + if dynamic_opts is None: + return current_s_noise + new_val = dynamic_opts.get("s_noise") + if new_val is None: + return current_s_noise + noise_scale = getattr(model_sampling, "noise_scale", 1.0) + return new_val * noise_scale + + +def _init_dynamic_options(extra_args): + """Pop and return the mutable extra_options dict for per-step re-reading, + or None if no bounded-feedback is active on this sampler.""" + if extra_args is None: + return None + return extra_args.pop("_dynamic_sampler_options", None) + + +def _refresh_dynamic_params(dynamic_opts, model_sampling, s_noise, eta): + """Re-read s_noise and eta from mutable dynamic_opts if available. + Returns (s_noise, eta) tuple with updated values. + """ + if dynamic_opts is None: + return s_noise, eta + ns = getattr(model_sampling, "noise_scale", 1.0) + if "s_noise" in dynamic_opts: + s_noise = dynamic_opts["s_noise"] * ns + if "eta" in dynamic_opts: + eta = dynamic_opts["eta"] + return s_noise, eta + + @torch.no_grad() -def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_euler( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) sigma_hat = sigmas[i] * (gamma + 1) else: gamma = 0 @@ -201,11 +278,19 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, if gamma > 0: eps = torch.randn_like(x) * s_noise - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigma_hat # Euler method x = x + d * dt @@ -213,19 +298,47 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST): - return sample_euler_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) +def sample_euler_ancestral( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + if isinstance( + model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST + ): + return sample_euler_ancestral_RF( + model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler + ) """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigma_down == 0: x = denoised @@ -233,22 +346,52 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis d = to_d(x, sigmas[i], denoised) # Euler method dt = sigma_down - sigmas[i] - x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + x = ( + x + + d * dt + + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + ) return x + @torch.no_grad() -def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1., noise_sampler=None): +def sample_euler_ancestral_RF( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with Euler method steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: x = denoised @@ -257,22 +400,42 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, sigma_down = sigmas[i + 1] * downstep_ratio alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down - renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5 + renoise_coeff = ( + sigmas[i + 1] ** 2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2 + ) ** 0.5 # Euler method sigma_down_i_ratio = sigma_down / sigmas[i] x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised if eta > 0: - x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + x = (alpha_ip1 / alpha_down) * x + noise_sampler( + sigmas[i], sigmas[i + 1] + ) * s_noise * renoise_coeff return x + @torch.no_grad() -def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_heun( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) sigma_hat = sigmas[i] * (gamma + 1) else: gamma = 0 @@ -281,11 +444,19 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: eps = torch.randn_like(x) * s_noise - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == 0: # Euler method @@ -301,13 +472,28 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_dpm_2( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): if s_churn > 0: - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) sigma_hat = sigmas[i] * (gamma + 1) else: gamma = 0 @@ -315,11 +501,19 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, if gamma > 0: eps = torch.randn_like(x) * s_noise - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Euler method dt = sigmas[i + 1] - sigma_hat @@ -337,20 +531,48 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST): - return sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) +def sample_dpm_2_ancestral( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + if isinstance( + model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST + ): + return sample_dpm_2_ancestral_RF( + model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler + ) """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d = to_d(x, sigmas[i], denoised) if sigma_down == 0: # Euler method @@ -368,24 +590,52 @@ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, dis x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x + @torch.no_grad() -def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpm_2_ancestral_RF( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with DPM-Solver second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) - downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta - sigma_down = sigmas[i+1] * downstep_ratio - alpha_ip1 = 1 - sigmas[i+1] + downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta + sigma_down = sigmas[i + 1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down - renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 + renoise_coeff = ( + sigmas[i + 1] ** 2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2 + ) ** 0.5 if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d = to_d(x, sigmas[i], denoised) if sigma_down == 0: # Euler method @@ -400,19 +650,24 @@ def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) d_2 = to_d(x_2, sigma_mid, denoised_2) x = x + d_2 * dt_2 - x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + x = (alpha_ip1 / alpha_down) * x + noise_sampler( + sigmas[i], sigmas[i + 1] + ) * s_noise * renoise_coeff return x + def linear_multistep_coeff(order, t, i, j): if order - 1 > i: - raise ValueError(f'Order {order} too high for step {i}') + raise ValueError(f"Order {order} too high for step {i}") + def fn(tau): - prod = 1. + prod = 1.0 for k in range(order): if j == k: continue prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) return prod + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] @@ -429,20 +684,34 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o if len(ds) > order: ds.pop(0) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised else: cur_order = min(i + 1, order) - coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + coeffs = [ + linear_multistep_coeff(cur_order, sigmas_cpu, i, j) + for j in range(cur_order) + ] x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) return x class PIDStepSizeController: """A PID controller for ODE adaptive step size control.""" - def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): + + def __init__( + self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8 + ): self.h = h self.b1 = (pcoeff + icoeff + dcoeff) / order self.b2 = -(pcoeff + 2 * dcoeff) / order @@ -459,7 +728,9 @@ def propose_step(self, error): if not self.errs: self.errs = [inv_error, inv_error, inv_error] self.errs[0] = inv_error - factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + factor = ( + self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + ) factor = self.limiter(factor) accept = factor >= self.accept_safety if accept: @@ -489,7 +760,9 @@ def eps(self, eps_cache, key, x, t, *args, **kwargs): if key in eps_cache: return eps_cache[key], eps_cache sigma = self.sigma(t) * x.new_ones([x.shape[0]]) - eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) + eps = ( + x - self.model(x, sigma, *args, **self.extra_args, **kwargs) + ) / self.sigma(t) if self.eps_callback is not None: self.eps_callback() return eps, {key: eps, **eps_cache} @@ -497,37 +770,58 @@ def eps(self, eps_cache, key, x, t, *args, **kwargs): def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) x_1 = x - self.sigma(t_next) * h.expm1() * eps return x_1, eps_cache def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) s1 = t + r1 * h u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps - eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) - x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + eps_r1, eps_cache = self.eps(eps_cache, "eps_r1", u1, s1) + x_2 = ( + x + - self.sigma(t_next) * h.expm1() * eps + - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + ) return x_2, eps_cache def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): eps_cache = {} if eps_cache is None else eps_cache h = t_next - t - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) s1 = t + r1 * h s2 = t + r2 * h u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps - eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) - u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) - eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) - x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + eps_r1, eps_cache = self.eps(eps_cache, "eps_r1", u1, s1) + u2 = ( + x + - self.sigma(s2) * (r2 * h).expm1() * eps + - self.sigma(s2) + * (r2 / r1) + * ((r2 * h).expm1() / (r2 * h) - 1) + * (eps_r1 - eps) + ) + eps_r2, eps_cache = self.eps(eps_cache, "eps_r2", u2, s2) + x_3 = ( + x + - self.sigma(t_next) * h.expm1() * eps + - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + ) return x_3, eps_cache - def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None): - noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler + def dpm_solver_fast( + self, x, t_start, t_end, nfe, eta=0.0, s_noise=1.0, noise_sampler=None + ): + noise_sampler = ( + default_noise_sampler(x, seed=self.extra_args.get("seed", None)) + if noise_sampler is None + else noise_sampler + ) if not t_end > t_start and eta: - raise ValueError('eta must be 0 for reverse sampling') + raise ValueError("eta must be 0 for reverse sampling") m = math.floor(nfe / 3) + 1 ts = torch.linspace(t_start, t_end, m + 1, device=x.device) @@ -545,59 +839,99 @@ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_samp t_next_ = torch.minimum(t_end, self.t(sd)) su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 else: - t_next_, su = t_next, 0. + t_next_, su = t_next, 0.0 - eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + eps, eps_cache = self.eps(eps_cache, "eps", x, t) denoised = x - self.sigma(t) * eps if self.info_callback is not None: - self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) + self.info_callback( + {"x": x, "i": i, "t": ts[i], "t_up": t, "denoised": denoised} + ) if orders[i] == 1: - x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) + x, eps_cache = self.dpm_solver_1_step( + x, t, t_next_, eps_cache=eps_cache + ) elif orders[i] == 2: - x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) + x, eps_cache = self.dpm_solver_2_step( + x, t, t_next_, eps_cache=eps_cache + ) else: - x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) + x, eps_cache = self.dpm_solver_3_step( + x, t, t_next_, eps_cache=eps_cache + ) x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next)) return x - def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None): - noise_sampler = default_noise_sampler(x, seed=self.extra_args.get("seed", None)) if noise_sampler is None else noise_sampler + def dpm_solver_adaptive( + self, + x, + t_start, + t_end, + order=3, + rtol=0.05, + atol=0.0078, + h_init=0.05, + pcoeff=0.0, + icoeff=1.0, + dcoeff=0.0, + accept_safety=0.81, + eta=0.0, + s_noise=1.0, + noise_sampler=None, + ): + noise_sampler = ( + default_noise_sampler(x, seed=self.extra_args.get("seed", None)) + if noise_sampler is None + else noise_sampler + ) if order not in {2, 3}: - raise ValueError('order should be 2 or 3') + raise ValueError("order should be 2 or 3") forward = t_end > t_start if not forward and eta: - raise ValueError('eta must be 0 for reverse sampling') + raise ValueError("eta must be 0 for reverse sampling") h_init = abs(h_init) * (1 if forward else -1) atol = torch.tensor(atol) rtol = torch.tensor(rtol) s = t_start x_prev = x accept = True - pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) - info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} + pid = PIDStepSizeController( + h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety + ) + info = {"steps": 0, "nfe": 0, "n_accept": 0, "n_reject": 0} while s < t_end - 1e-5 if forward else s > t_end + 1e-5: eps_cache = {} - t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) + t = ( + torch.minimum(t_end, s + pid.h) + if forward + else torch.maximum(t_end, s + pid.h) + ) if eta: sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) t_ = torch.minimum(t_end, self.t(sd)) su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 else: - t_, su = t, 0. + t_, su = t, 0.0 - eps, eps_cache = self.eps(eps_cache, 'eps', x, s) + eps, eps_cache = self.eps(eps_cache, "eps", x, s) denoised = x - self.sigma(s) * eps if order == 2: x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) - x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_2_step( + x, s, t_, eps_cache=eps_cache + ) else: - x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) - x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) + x_low, eps_cache = self.dpm_solver_2_step( + x, s, t_, r1=1 / 3, eps_cache=eps_cache + ) + x_high, eps_cache = self.dpm_solver_3_step( + x, s, t_, eps_cache=eps_cache + ) delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 accept = pid.propose_step(error) @@ -605,63 +939,173 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078 x_prev = x_low x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t)) s = t - info['n_accept'] += 1 + info["n_accept"] += 1 else: - info['n_reject'] += 1 - info['nfe'] += order - info['steps'] += 1 + info["n_reject"] += 1 + info["nfe"] += order + info["steps"] += 1 if self.info_callback is not None: - self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info}) + self.info_callback( + { + "x": x, + "i": info["steps"] - 1, + "t": s, + "t_up": s, + "denoised": denoised, + "error": error, + "h": pid.h, + **info, + } + ) return x, info @torch.no_grad() -def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None): +def sample_dpm_fast( + model, + x, + sigma_min, + sigma_max, + n, + extra_args=None, + callback=None, + disable=None, + eta=0.0, + s_noise=1.0, + noise_sampler=None, +): """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: - raise ValueError('sigma_min and sigma_max must not be 0') + raise ValueError("sigma_min and sigma_max must not be 0") + extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) with tqdm(total=n, disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: - dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) - return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler) + dpm_solver.info_callback = lambda info: callback( + { + "sigma": dpm_solver.sigma(info["t"]), + "sigma_hat": dpm_solver.sigma(info["t_up"]), + **info, + } + ) + return dpm_solver.dpm_solver_fast( + x, + dpm_solver.t(torch.tensor(sigma_max)), + dpm_solver.t(torch.tensor(sigma_min)), + n, + eta, + s_noise, + noise_sampler, + ) @torch.no_grad() -def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False): +def sample_dpm_adaptive( + model, + x, + sigma_min, + sigma_max, + extra_args=None, + callback=None, + disable=None, + order=3, + rtol=0.05, + atol=0.0078, + h_init=0.05, + pcoeff=0.0, + icoeff=1.0, + dcoeff=0.0, + accept_safety=0.81, + eta=0.0, + s_noise=1.0, + noise_sampler=None, + return_info=False, +): """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" if sigma_min <= 0 or sigma_max <= 0: - raise ValueError('sigma_min and sigma_max must not be 0') + raise ValueError("sigma_min and sigma_max must not be 0") + extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) with tqdm(disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: - dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info}) - x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler) + dpm_solver.info_callback = lambda info: callback( + { + "sigma": dpm_solver.sigma(info["t"]), + "sigma_hat": dpm_solver.sigma(info["t_up"]), + **info, + } + ) + x, info = dpm_solver.dpm_solver_adaptive( + x, + dpm_solver.t(torch.tensor(sigma_max)), + dpm_solver.t(torch.tensor(sigma_min)), + order, + rtol, + atol, + h_init, + pcoeff, + icoeff, + dcoeff, + accept_safety, + eta, + s_noise, + noise_sampler, + ) if return_info: return x, info return x @torch.no_grad() -def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST): - return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler) +def sample_dpmpp_2s_ancestral( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + if isinstance( + model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST + ): + return sample_dpmpp_2s_ancestral_RF( + model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler + ) """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigma_down == 0: # Euler method d = to_d(x, sigmas[i], denoised) @@ -683,28 +1127,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, @torch.no_grad() -def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpmpp_2s_ancestral_RF( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") + s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1 - lambda_fn = lambda sigma: ((1-sigma)/sigma).log() + lambda_fn = lambda sigma: ((1 - sigma) / sigma).log() # logged_x = x.unsqueeze(0) for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) - downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta - sigma_down = sigmas[i+1] * downstep_ratio - alpha_ip1 = 1 - sigmas[i+1] + downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta + sigma_down = sigmas[i + 1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i + 1] alpha_down = 1 - sigma_down - renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 + renoise_coeff = ( + sigmas[i + 1] ** 2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2 + ) ** 0.5 # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Euler method d = to_d(x, sigmas[i], denoised) @@ -729,33 +1200,64 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff) # Noise addition if sigmas[i + 1] > 0 and eta > 0: - x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + x = (alpha_ip1 / alpha_down) * x + noise_sampler( + sigmas[i], sigmas[i + 1] + ) * s_noise * renoise_coeff # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0) return x @torch.no_grad() -def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): +def sample_dpmpp_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + r=1 / 2, +): """DPM-Solver++ (stochastic).""" if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() seed = extra_args.get("seed", None) - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler + noise_sampler = ( + BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) + if noise_sampler is None + else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) + if _dynamic_opts is not None and "r" in _dynamic_opts: + r = _dynamic_opts["r"] denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised @@ -773,12 +1275,18 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N alpha_t = sigmas[i + 1] * lambda_t.exp() # Step 1 - sd, su = get_ancestral_step(lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta) + sd, su = get_ancestral_step( + lambda_s.neg().exp(), lambda_s_1.neg().exp(), eta + ) lambda_s_1_ = sd.log().neg() h_ = lambda_s_1_ - lambda_s - x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * (-h_).expm1() * denoised + x_2 = (alpha_s_1 / alpha_s) * (-h_).exp() * x - alpha_s_1 * ( + -h_ + ).expm1() * denoised if eta > 0 and s_noise > 0: - x_2 = x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su + x_2 = ( + x_2 + alpha_s_1 * noise_sampler(sigmas[i], sigma_s_1) * s_noise * su + ) denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 @@ -786,7 +1294,9 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N lambda_t_ = sd.log().neg() h_ = lambda_t_ - lambda_s denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * (-h_).expm1() * denoised_d + x = (alpha_t / alpha_s) * (-h_).exp() * x - alpha_t * ( + -h_ + ).expm1() * denoised_d if eta > 0 and s_noise > 0: x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * su return x @@ -804,7 +1314,15 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) h = t_next - t if old_denoised is None or sigmas[i + 1] == 0: @@ -819,21 +1337,37 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No @torch.no_grad() -def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): +def sample_dpmpp_2m_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="midpoint", +): """DPM-Solver++(2M) SDE.""" if len(sigmas) <= 1: return x - if solver_type not in {'heun', 'midpoint'}: - raise ValueError('solver_type must be \'heun\' or \'midpoint\'') + if solver_type not in {"heun", "midpoint"}: + raise ValueError("solver_type must be 'heun' or 'midpoint'") extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler + noise_sampler = ( + BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) + if noise_sampler is None + else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) @@ -842,9 +1376,20 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl h, h_last = None, None for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised @@ -856,17 +1401,30 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl alpha_t = sigmas[i + 1] * lambda_t.exp() - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised + x = ( + sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + + alpha_t * (-h_eta).expm1().neg() * denoised + ) if old_denoised is not None: r = h_last / h - if solver_type == 'heun': - x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * (1 / r) * (denoised - old_denoised) - elif solver_type == 'midpoint': - x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * (denoised - old_denoised) + if solver_type == "heun": + x = x + alpha_t * ((-h_eta).expm1().neg() / (-h_eta) + 1) * ( + 1 / r + ) * (denoised - old_denoised) + elif solver_type == "midpoint": + x = x + 0.5 * alpha_t * (-h_eta).expm1().neg() * (1 / r) * ( + denoised - old_denoised + ) if eta > 0 and s_noise > 0: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise + x = ( + x + + noise_sampler(sigmas[i], sigmas[i + 1]) + * sigmas[i + 1] + * (-2 * h * eta).expm1().neg().sqrt() + * s_noise + ) old_denoised = denoised h_last = h @@ -874,24 +1432,61 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl @torch.no_grad() -def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): - return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) +def sample_dpmpp_2m_sde_heun( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="heun", +): + return sample_dpmpp_2m_sde( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + solver_type=solver_type, + ) @torch.no_grad() -def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpmpp_3m_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """DPM-Solver++(3M) SDE.""" if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler + noise_sampler = ( + BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) + if noise_sampler is None + else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) @@ -900,9 +1495,20 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl h, h_1, h_2 = None, None, None for i in trange(len(sigmas) - 1, disable=disable): + s_noise, eta = _refresh_dynamic_params( + _dynamic_opts, model_sampling, s_noise, eta + ) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised @@ -913,7 +1519,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl alpha_t = sigmas[i + 1] * lambda_t.exp() - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + alpha_t * (-h_eta).expm1().neg() * denoised + x = ( + sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + + alpha_t * (-h_eta).expm1().neg() * denoised + ) if h_2 is not None: # DPM-Solver++(3M) SDE @@ -934,7 +1543,13 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl x = x + (alpha_t * phi_2) * d if eta > 0 and s_noise > 0: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise + x = ( + x + + noise_sampler(sigmas[i], sigmas[i + 1]) + * sigmas[i + 1] + * (-2 * h * eta).expm1().neg().sqrt() + * s_noise + ) denoised_1, denoised_2 = denoised, denoised_1 h_1, h_2 = h, h_1 @@ -942,94 +1557,263 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl @torch.no_grad() -def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpmpp_3m_sde_gpu( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler - return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler) + noise_sampler = ( + BrownianTreeNoiseSampler( + x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False + ) + if noise_sampler is None + else noise_sampler + ) + return sample_dpmpp_3m_sde( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + ) @torch.no_grad() -def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'): +def sample_dpmpp_2m_sde_heun_gpu( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="heun", +): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler - return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + noise_sampler = ( + BrownianTreeNoiseSampler( + x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False + ) + if noise_sampler is None + else noise_sampler + ) + return sample_dpmpp_2m_sde_heun( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + solver_type=solver_type, + ) @torch.no_grad() -def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'): +def sample_dpmpp_2m_sde_gpu( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="midpoint", +): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler - return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type) + noise_sampler = ( + BrownianTreeNoiseSampler( + x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False + ) + if noise_sampler is None + else noise_sampler + ) + return sample_dpmpp_2m_sde( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + solver_type=solver_type, + ) @torch.no_grad() -def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2): +def sample_dpmpp_sde_gpu( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + r=1 / 2, +): if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler - return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r) + noise_sampler = ( + BrownianTreeNoiseSampler( + x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False + ) + if noise_sampler is None + else noise_sampler + ) + return sample_dpmpp_sde( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + r=r, + ) def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler): alpha_cumprod = 1 / ((sigma * sigma) + 1) alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1) - alpha = (alpha_cumprod / alpha_cumprod_prev) + alpha = alpha_cumprod / alpha_cumprod_prev mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt()) if sigma_prev > 0: - mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) + mu += ( + (1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod) + ).sqrt() * noise_sampler(sigma, sigma_prev) return mu -def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): + +def generic_step_sampler( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + noise_sampler=None, + step_function=None, +): extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) + x = step_function( + x / torch.sqrt(1.0 + sigmas[i] ** 2.0), + sigmas[i], + sigmas[i + 1], + (x - denoised) / sigmas[i], + noise_sampler, + ) if sigmas[i + 1] != 0: x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0) return x @torch.no_grad() -def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): - return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) +def sample_ddpm( + model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None +): + return generic_step_sampler( + model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step + ) + @torch.no_grad() -def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, s_noise=1.0, s_noise_end=None, noise_clip_std=0.0): +def sample_lcm( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + noise_sampler=None, + s_noise=1.0, + s_noise_end=None, + noise_clip_std=0.0, +): # s_noise / s_noise_end: per-step noise multiplier, linearly interpolated across steps # noise_clip_std: clamp injected noise to +/- N stddevs (0 disables). extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) n_steps = max(1, len(sigmas) - 1) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") s_start = float(s_noise) s_end = s_start if s_noise_end is None else float(s_noise_end) for i in trange(n_steps, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) x = denoised if sigmas[i + 1] > 0: @@ -1046,39 +1830,63 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n @torch.no_grad() -def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_heunpp2( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float("inf"), + s_noise=1.0, +): # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) s_in = x.new_ones([x.shape[0]]) s_end = sigmas[-1] for i in trange(len(sigmas) - 1, disable=disable): - gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + gamma = ( + min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigmas[i] <= s_tmax + else 0.0 + ) eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) if gamma > 0: - x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) d = to_d(x, sigma_hat, denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigma_hat, + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigma_hat if sigmas[i + 1] == s_end: # Euler method x = x + d * dt elif sigmas[i + 2] == s_end: - # Heun's method x_2 = x + d * dt denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) d_2 = to_d(x_2, sigmas[i + 1], denoised_2) w = 2 * sigmas[0] - w2 = sigmas[i+1]/w + w2 = sigmas[i + 1] / w w1 = 1 - w2 d_prime = d * w1 + d_2 * w2 - x = x + d_prime * dt else: @@ -1102,9 +1910,11 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non return x -#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py -#under Apache 2 license -def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): +# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +# under Apache 2 license +def sample_ipndm( + model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4 +): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -1119,25 +1929,48 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, denoised = model(x_cur, t_cur * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d_cur = (x_cur - denoised) / t_cur - order = min(max_order, i+1) - if t_next == 0: # Denoising step + order = min(max_order, i + 1) + if t_next == 0: # Denoising step x_next = denoised - elif order == 1: # First Euler step. + elif order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur - elif order == 2: # Use one history point. + elif order == 2: # Use one history point. x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2 - elif order == 3: # Use two history points. - x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12 - elif order == 4: # Use three history points. - x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24 + elif order == 3: # Use two history points. + x_next = ( + x_cur + + (t_next - t_cur) + * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) + / 12 + ) + elif order == 4: # Use three history points. + x_next = ( + x_cur + + (t_next - t_cur) + * ( + 55 * d_cur + - 59 * buffer_model[-1] + + 37 * buffer_model[-2] + - 9 * buffer_model[-3] + ) + / 24 + ) if len(buffer_model) == max_order - 1: for k in range(max_order - 2): - buffer_model[k] = buffer_model[k+1] + buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur else: buffer_model.append(d_cur) @@ -1145,9 +1978,11 @@ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, return x_next -#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py -#under Apache 2 license -def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4): +# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +# under Apache 2 license +def sample_ipndm_v( + model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4 +): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -1163,47 +1998,106 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non denoised = model(x_cur, t_cur * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d_cur = (x_cur - denoised) / t_cur - order = min(max_order, i+1) - if t_next == 0: # Denoising step + order = min(max_order, i + 1) + if t_next == 0: # Denoising step x_next = denoised - elif order == 1: # First Euler step. + elif order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur - elif order == 2: # Use one history point. - h_n = (t_next - t_cur) - h_n_1 = (t_cur - t_steps[i-1]) + elif order == 2: # Use one history point. + h_n = t_next - t_cur + h_n_1 = t_cur - t_steps[i - 1] coeff1 = (2 + (h_n / h_n_1)) / 2 coeff2 = -(h_n / h_n_1) / 2 - x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1]) - elif order == 3: # Use two history points. - h_n = (t_next - t_cur) - h_n_1 = (t_cur - t_steps[i-1]) - h_n_2 = (t_steps[i-1] - t_steps[i-2]) - temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 + x_next = x_cur + (t_next - t_cur) * ( + coeff1 * d_cur + coeff2 * buffer_model[-1] + ) + elif order == 3: # Use two history points. + h_n = t_next - t_cur + h_n_1 = t_cur - t_steps[i - 1] + h_n_2 = t_steps[i - 1] - t_steps[i - 2] + temp = ( + 1 + - h_n + / (3 * (h_n + h_n_1)) + * (h_n * (h_n + h_n_1)) + / (h_n_1 * (h_n_1 + h_n_2)) + ) / 2 coeff1 = (2 + (h_n / h_n_1)) / 2 + temp coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp coeff3 = temp * h_n_1 / h_n_2 - x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2]) - elif order == 4: # Use three history points. - h_n = (t_next - t_cur) - h_n_1 = (t_cur - t_steps[i-1]) - h_n_2 = (t_steps[i-1] - t_steps[i-2]) - h_n_3 = (t_steps[i-2] - t_steps[i-3]) - temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2 - temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \ - * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3)) + x_next = x_cur + (t_next - t_cur) * ( + coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + ) + elif order == 4: # Use three history points. + h_n = t_next - t_cur + h_n_1 = t_cur - t_steps[i - 1] + h_n_2 = t_steps[i - 1] - t_steps[i - 2] + h_n_3 = t_steps[i - 2] - t_steps[i - 3] + temp1 = ( + 1 + - h_n + / (3 * (h_n + h_n_1)) + * (h_n * (h_n + h_n_1)) + / (h_n_1 * (h_n_1 + h_n_2)) + ) / 2 + temp2 = ( + ( + (1 - h_n / (3 * (h_n + h_n_1))) / 2 + + (1 - h_n / (2 * (h_n + h_n_1))) + * h_n + / (6 * (h_n + h_n_1 + h_n_2)) + ) + * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) + / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3)) + ) coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2 - coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2 - coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2 - coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2 - x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3]) + coeff2 = ( + -(h_n / h_n_1) / 2 + - (1 + h_n_1 / h_n_2) * temp1 + - ( + 1 + + (h_n_1 / h_n_2) + + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) + ) + * temp2 + ) + coeff3 = ( + temp1 * h_n_1 / h_n_2 + + ( + (h_n_1 / h_n_2) + + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) + * (1 + h_n_2 / h_n_3) + ) + * temp2 + ) + coeff4 = ( + -temp2 + * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) + * h_n_1 + / h_n_2 + ) + x_next = x_cur + (t_next - t_cur) * ( + coeff1 * d_cur + + coeff2 * buffer_model[-1] + + coeff3 * buffer_model[-2] + + coeff4 * buffer_model[-3] + ) if len(buffer_model) == max_order - 1: for k in range(max_order - 2): - buffer_model[k] = buffer_model[k+1] + buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur.detach() else: buffer_model.append(d_cur.detach()) @@ -1211,10 +2105,19 @@ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=Non return x_next -#From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py -#under Apache 2 license +# From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py +# under Apache 2 license @torch.no_grad() -def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'): +def sample_deis( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + max_order=3, + deis_mode="tab", +): extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -1232,29 +2135,48 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, denoised = model(x_cur, t_cur * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) d_cur = (x_cur - denoised) / t_cur - order = min(max_order, i+1) + order = min(max_order, i + 1) if t_next <= 0: order = 1 - if order == 1: # First Euler step. + if order == 1: # First Euler step. x_next = x_cur + (t_next - t_cur) * d_cur - elif order == 2: # Use one history point. + elif order == 2: # Use one history point. coeff_cur, coeff_prev1 = coeff_list[i] x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] - elif order == 3: # Use two history points. + elif order == 3: # Use two history points. coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i] - x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] - elif order == 4: # Use three history points. + x_next = ( + x_cur + + coeff_cur * d_cur + + coeff_prev1 * buffer_model[-1] + + coeff_prev2 * buffer_model[-2] + ) + elif order == 4: # Use three history points. coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i] - x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3] + x_next = ( + x_cur + + coeff_cur * d_cur + + coeff_prev1 * buffer_model[-1] + + coeff_prev2 * buffer_model[-2] + + coeff_prev3 * buffer_model[-3] + ) if len(buffer_model) == max_order - 1: for k in range(max_order - 2): - buffer_model[k] = buffer_model[k+1] + buffer_model[k] = buffer_model[k + 1] buffer_model[-1] = d_cur.detach() else: buffer_model.append(d_cur.detach()) @@ -1263,11 +2185,24 @@ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_euler_ancestral_cfg_pp( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with Euler method steps (CFG++).""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) @@ -1281,63 +2216,129 @@ def post_cfg_function(args): return args["denoised"] model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) s_in = x.new_ones([x.shape[0]]) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: # Denoising step x = denoised else: alpha_s = sigmas[i] * lambda_fn(sigmas[i]).exp() alpha_t = sigmas[i + 1] * lambda_fn(sigmas[i + 1]).exp() - d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise + d = to_d(x, sigmas[i], alpha_s * uncond_denoised) # to noise # DDIM stochastic sampling - sigma_down, sigma_up = get_ancestral_step(sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta) + sigma_down, sigma_up = get_ancestral_step( + sigmas[i] / alpha_s, sigmas[i + 1] / alpha_t, eta=eta + ) sigma_down = alpha_t * sigma_down # Euler method x = alpha_t * denoised + sigma_down * d if eta > 0 and s_noise > 0: - x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + x = ( + x + + alpha_t + * noise_sampler(sigmas[i], sigmas[i + 1]) + * s_noise + * sigma_up + ) return x @torch.no_grad() def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): """Euler method steps (CFG++).""" - return sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None) + return sample_euler_ancestral_cfg_pp( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=0.0, + s_noise=0.0, + noise_sampler=None, + ) @torch.no_grad() -def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): +def sample_dpmpp_2s_ancestral_cfg_pp( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + s_noise = s_noise * getattr( + model.inner_model.model_patcher.get_model_object("model_sampling"), + "noise_scale", + 1.0, + ) temp = [0] + def post_cfg_function(args): temp[0] = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigma_down == 0: # Euler method d = to_d(x, sigmas[i], temp[0]) @@ -1349,16 +2350,23 @@ def post_cfg_function(args): r = 1 / 2 h = t_next - t s = t + r * h - x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised + x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - ( + -h * r + ).expm1() * denoised denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) - x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2 + x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - ( + -h + ).expm1() * denoised_2 # Noise addition if sigmas[i + 1] > 0: x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x + @torch.no_grad() -def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): +def sample_dpmpp_2m_cfg_pp( + model, x, sigmas, extra_args=None, callback=None, disable=None +): """DPM-Solver++(2M).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -1366,18 +2374,31 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis old_uncond_denoised = None uncond_denoised = None + def post_cfg_function(args): nonlocal uncond_denoised uncond_denoised = args["uncond_denoised"] return args["denoised"] model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) h = t_next - t if old_uncond_denoised is None or sigmas[i + 1] == 0: @@ -1385,17 +2406,38 @@ def post_cfg_function(args): else: h_last = t - t_fn(sigmas[i - 1]) r = h_last / h - denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * (1 / (2 * r)) * (denoised - old_uncond_denoised) + denoised_mix = -torch.exp(-h) * uncond_denoised - torch.expm1(-h) * ( + 1 / (2 * r) + ) * (denoised - old_uncond_denoised) x = denoised + denoised_mix + torch.exp(-h) * x old_uncond_denoised = uncond_denoised return x + @torch.no_grad() -def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False): +def res_multistep( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_noise=1.0, + noise_sampler=None, + eta=1.0, + cfg_pp=False, +): extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + s_noise = s_noise * getattr( + model.inner_model.model_patcher.get_model_object("model_sampling"), + "noise_scale", + 1.0, + ) s_in = x.new_ones([x.shape[0]]) sigma_fn = lambda t: t.neg().exp() t_fn = lambda sigma: sigma.log().neg() @@ -1405,6 +2447,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None old_sigma_down = None old_denoised = None uncond_denoised = None + def post_cfg_function(args): nonlocal uncond_denoised uncond_denoised = args["uncond_denoised"] @@ -1412,13 +2455,28 @@ def post_cfg_function(args): if cfg_pp: model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: - callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigma_down == 0 or old_denoised is None: # Euler method if cfg_pp: @@ -1430,7 +2488,12 @@ def post_cfg_function(args): x = x + d * dt else: # Second order multistep method in https://arxiv.org/pdf/2308.02157 - t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1]) + t, t_old, t_next, t_prev = ( + t_fn(sigmas[i]), + t_fn(old_sigma_down), + t_fn(sigma_down), + t_fn(sigmas[i - 1]), + ) h = t_next - t c2 = (t_prev - t_old) / h @@ -1455,31 +2518,128 @@ def post_cfg_function(args): old_sigma_down = sigma_down return x + @torch.no_grad() -def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None): - return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False) +def sample_res_multistep( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_noise=1.0, + noise_sampler=None, +): + return res_multistep( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + s_noise=s_noise, + noise_sampler=noise_sampler, + eta=0.0, + cfg_pp=False, + ) + @torch.no_grad() -def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None): - return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True) +def sample_res_multistep_cfg_pp( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_noise=1.0, + noise_sampler=None, +): + return res_multistep( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + s_noise=s_noise, + noise_sampler=noise_sampler, + eta=0.0, + cfg_pp=True, + ) + @torch.no_grad() -def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False) +def sample_res_multistep_ancestral( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + return res_multistep( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + s_noise=s_noise, + noise_sampler=noise_sampler, + eta=eta, + cfg_pp=False, + ) + @torch.no_grad() -def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None): - return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True) +def sample_res_multistep_ancestral_cfg_pp( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, +): + return res_multistep( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + s_noise=s_noise, + noise_sampler=noise_sampler, + eta=eta, + cfg_pp=True, + ) @torch.no_grad() -def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False): +def sample_gradient_estimation( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + ge_gamma=2.0, + cfg_pp=False, +): """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK""" extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) s_in = x.new_ones([x.shape[0]]) old_d = None uncond_denoised = None + def post_cfg_function(args): nonlocal uncond_denoised uncond_denoised = args["uncond_denoised"] @@ -1487,7 +2647,11 @@ def post_cfg_function(args): if cfg_pp: model_options = extra_args.get("model_options", {}).copy() - extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + extra_args["model_options"] = ( + comfy.model_patcher.set_model_options_post_cfg_function( + model_options, post_cfg_function, disable_cfg1_optimization=True + ) + ) for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) @@ -1496,7 +2660,15 @@ def post_cfg_function(args): else: d = to_d(x, sigmas[i], denoised) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) dt = sigmas[i + 1] - sigmas[i] if sigmas[i + 1] == 0: # Denoising step @@ -1517,27 +2689,61 @@ def post_cfg_function(args): @torch.no_grad() -def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.): - return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True) +def sample_gradient_estimation_cfg_pp( + model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.0 +): + return sample_gradient_estimation( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + ge_gamma=ge_gamma, + cfg_pp=True, + ) @torch.no_grad() -def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1.0, noise_sampler=None, noise_scaler=None, max_stage=3): +def sample_er_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_noise=1.0, + noise_sampler=None, + noise_scaler=None, + max_stage=3, +): """Extended Reverse-Time SDE solver (VP ER-SDE-Solver-3). arXiv: https://arxiv.org/abs/2309.06169. Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py. """ extra_args = {} if extra_args is None else extra_args + # Pop bounded-feedback dynamic options so they never reach the model call. + # We keep a reference to the same mutable dict — the callback updates it + # in-place and we re-read at each loop iteration. + _dynamic_opts = extra_args.pop("_dynamic_sampler_options", None) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler - s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0) + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) + s_noise = s_noise * getattr( + model.inner_model.model_patcher.get_model_object("model_sampling"), + "noise_scale", + 1.0, + ) s_in = x.new_ones([x.shape[0]]) def default_er_sde_noise_scaler(x): - return x * ((x ** 0.3).exp() + 10.0) + return x * ((x**0.3).exp() + 10.0) noise_scaler = default_er_sde_noise_scaler if noise_scaler is None else noise_scaler num_integration_points = 200.0 - point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device) + point_indice = torch.arange( + 0, num_integration_points, dtype=torch.float32, device=x.device + ) model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) @@ -1548,9 +2754,18 @@ def default_er_sde_noise_scaler(x): old_denoised_d = None for i in trange(len(sigmas) - 1, disable=disable): + s_noise = _apply_dynamic_s_noise(_dynamic_opts, model_sampling, s_noise) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) stage_used = min(max_stage, i + 1) if sigmas[i + 1] == 0: x = denoised @@ -1572,24 +2787,50 @@ def default_er_sde_noise_scaler(x): # Stage 2 s = torch.sum(1 / scaled_pos) * lambda_step_size - denoised_d = (denoised - old_denoised) / (er_lambda_s - er_lambdas[i - 1]) + denoised_d = (denoised - old_denoised) / ( + er_lambda_s - er_lambdas[i - 1] + ) x = x + alpha_t * (dt + s * noise_scaler(er_lambda_t)) * denoised_d if stage_used >= 3: # Stage 3 - s_u = torch.sum((lambda_pos - er_lambda_s) / scaled_pos) * lambda_step_size - denoised_u = (denoised_d - old_denoised_d) / ((er_lambda_s - er_lambdas[i - 2]) / 2) - x = x + alpha_t * ((dt ** 2) / 2 + s_u * noise_scaler(er_lambda_t)) * denoised_u + s_u = ( + torch.sum((lambda_pos - er_lambda_s) / scaled_pos) + * lambda_step_size + ) + denoised_u = (denoised_d - old_denoised_d) / ( + (er_lambda_s - er_lambdas[i - 2]) / 2 + ) + x = ( + x + + alpha_t + * ((dt**2) / 2 + s_u * noise_scaler(er_lambda_t)) + * denoised_u + ) old_denoised_d = denoised_d if s_noise > 0: - x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (er_lambda_t ** 2 - er_lambda_s ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0) + x = x + alpha_t * noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * ( + er_lambda_t**2 - er_lambda_s**2 * r**2 + ).sqrt().nan_to_num(nan=0.0) old_denoised = denoised return x @torch.no_grad() -def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"): +def sample_seeds_2( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + r=0.5, + solver_type="phi_1", +): """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ @@ -1597,11 +2838,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non raise ValueError("solver_type must be 'phi_1' or 'phi_2'") extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) inject_noise = eta > 0 and s_noise > 0 sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) @@ -1611,9 +2855,21 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non fac = 1 / (2 * r) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) + r = _dynamic_opts.get("r", r) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: x = denoised @@ -1629,51 +2885,116 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non alpha_t = sigmas[i + 1] * lambda_t.exp() # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised + x_2 = ( + sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x + - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised + ) if inject_noise: - sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler( + sigmas[i], sigma_s_1 + ) x_2 = x_2 + sde_noise * sigma_s_1 * s_noise denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 if solver_type == "phi_1": denoised_d = torch.lerp(denoised, denoised_2, fac) - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d + x = ( + sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x + - alpha_t * ei_h_phi_1(-h_eta) * denoised_d + ) elif solver_type == "phi_2": b2 = ei_h_phi_2(-h_eta) / r b1 = ei_h_phi_1(-h_eta) - b2 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ( + b1 * denoised + b2 * denoised_2 + ) if inject_noise: segment_factor = (r - 1) * h * eta sde_noise = sde_noise * segment_factor.exp() - sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1]) + sde_noise = sde_noise + segment_factor.mul( + 2 + ).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1]) x = x + sde_noise * sigmas[i + 1] * s_noise return x + @torch.no_grad() -def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"): +def sample_exp_heun_2_x0( + model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2" +): """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time.""" - return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type) + return sample_seeds_2( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=0.0, + s_noise=0.0, + noise_sampler=None, + r=1.0, + solver_type=solver_type, + ) @torch.no_grad() -def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"): +def sample_exp_heun_2_x0_sde( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + solver_type="phi_2", +): """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time.""" - return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type) + return sample_seeds_2( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + eta=eta, + s_noise=s_noise, + noise_sampler=noise_sampler, + r=1.0, + solver_type=solver_type, + ) @torch.no_grad() -def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): +def sample_seeds_3( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + eta=1.0, + s_noise=1.0, + noise_sampler=None, + r_1=1.0 / 3, + r_2=2.0 / 3, +): """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ extra_args = {} if extra_args is None else extra_args + _dynamic_opts = _init_dynamic_options(extra_args) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) - model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') + model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0) inject_noise = eta > 0 and s_noise > 0 sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling) @@ -1681,9 +3002,20 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) for i in trange(len(sigmas) - 1, disable=disable): + if _dynamic_opts is not None: + s_noise = _dynamic_opts.get("s_noise", s_noise) + eta = _dynamic_opts.get("eta", eta) denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + callback( + { + "x": x, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: x = denoised @@ -1701,43 +3033,76 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non alpha_t = sigmas[i + 1] * lambda_t.exp() # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised + x_2 = ( + sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x + - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised + ) if inject_noise: - sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler( + sigmas[i], sigma_s_1 + ) x_2 = x_2 + sde_noise * sigma_s_1 * s_noise denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) # Step 2 a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta) a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2 - x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2) + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * ( + a3_1 * denoised + a3_2 * denoised_2 + ) if inject_noise: segment_factor = (r_1 - r_2) * h * eta sde_noise = sde_noise * segment_factor.exp() - sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2) + sde_noise = sde_noise + segment_factor.mul( + 2 + ).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2) x_3 = x_3 + sde_noise * sigma_s_2 * s_noise denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) # Step 3 b3 = ei_h_phi_2(-h_eta) / r_2 b1 = ei_h_phi_1(-h_eta) - b3 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ( + b1 * denoised + b3 * denoised_3 + ) if inject_noise: segment_factor = (r_2 - 1) * h * eta sde_noise = sde_noise * segment_factor.exp() - sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1]) + sde_noise = sde_noise + segment_factor.mul( + 2 + ).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1]) x = x + sde_noise * sigmas[i + 1] * s_noise return x @torch.no_grad() -def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, use_pece=False, simple_order_2=False): +def sample_sa_solver( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=False, + tau_func=None, + s_noise=1.0, + noise_sampler=None, + predictor_order=3, + corrector_order=4, + use_pece=False, + simple_order_2=False, +): """Stochastic Adams Solver with predictor-corrector method (NeurIPS 2023).""" if len(sigmas) <= 1: return x extra_args = {} if extra_args is None else extra_args + # Pop bounded-feedback dynamic options so they never reach the model call. + # We keep a reference to the same mutable dict — the callback updates it + # in-place and we re-read at each loop iteration. + _dynamic_opts = extra_args.pop("_dynamic_sampler_options", None) seed = extra_args.get("seed", None) - noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + noise_sampler = ( + default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler + ) s_in = x.new_ones([x.shape[0]]) model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling") @@ -1763,10 +3128,21 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F lower_order_to_end = sigmas[-1].item() == 0 for i in trange(len(sigmas) - 1, disable=disable): + # Re-read dynamic s_noise updated per-step by bounded-feedback. + s_noise = _apply_dynamic_s_noise(_dynamic_opts, model_sampling, s_noise) + # Evaluation denoised = model(x_pred, sigmas[i] * s_in, **extra_args) if callback is not None: - callback({"x": x_pred, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised}) + callback( + { + "x": x_pred, + "i": i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) pred_list.append(denoised) pred_list = pred_list[-max_used_order:] @@ -1785,7 +3161,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F # Update by the predicted state x = x_pred else: - curr_lambdas = lambdas[i - corrector_order_used + 1:i + 1] + curr_lambdas = lambdas[i - corrector_order_used + 1 : i + 1] b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs( sigmas[i], curr_lambdas, @@ -1795,9 +3171,11 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F simple_order_2, is_corrector_step=True, ) - pred_mat = torch.stack(pred_list[-corrector_order_used:], dim=1) # (B, K, ...) + pred_mat = torch.stack( + pred_list[-corrector_order_used:], dim=1 + ) # (B, K, ...) corr_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...) - x = sigmas[i] / sigmas[i - 1] * (-(tau_t ** 2) * h).exp() * x + corr_res + x = sigmas[i] / sigmas[i - 1] * (-(tau_t**2) * h).exp() * x + corr_res if tau_t > 0 and s_noise > 0: # The noise from the previous predictor step @@ -1814,7 +3192,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F x_pred = denoised else: tau_t = tau_func(sigmas[i + 1]) - curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1] + curr_lambdas = lambdas[i - predictor_order_used + 1 : i + 1] b_coeffs = sa_solver.compute_stochastic_adams_b_coeffs( sigmas[i + 1], curr_lambdas, @@ -1824,26 +3202,67 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F simple_order_2, is_corrector_step=False, ) - pred_mat = torch.stack(pred_list[-predictor_order_used:], dim=1) # (B, K, ...) + pred_mat = torch.stack( + pred_list[-predictor_order_used:], dim=1 + ) # (B, K, ...) pred_res = torch.tensordot(pred_mat, b_coeffs, dims=([1], [0])) # (B, ...) h = lambdas[i + 1] - lambdas[i] - x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t ** 2) * h).exp() * x + pred_res + x_pred = sigmas[i + 1] / sigmas[i] * (-(tau_t**2) * h).exp() * x + pred_res if tau_t > 0 and s_noise > 0: - noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise + noise = ( + noise_sampler(sigmas[i], sigmas[i + 1]) + * sigmas[i + 1] + * (-2 * tau_t**2 * h).expm1().neg().sqrt() + * s_noise + ) x_pred = x_pred + noise return x_pred @torch.no_grad() -def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): +def sample_sa_solver_pece( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=False, + tau_func=None, + s_noise=1.0, + noise_sampler=None, + predictor_order=3, + corrector_order=4, + simple_order_2=False, +): """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023).""" - return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) + return sample_sa_solver( + model, + x, + sigmas, + extra_args=extra_args, + callback=callback, + disable=disable, + tau_func=tau_func, + s_noise=s_noise, + noise_sampler=noise_sampler, + predictor_order=predictor_order, + corrector_order=corrector_order, + use_pece=True, + simple_order_2=simple_order_2, + ) @torch.no_grad() -def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None, - num_frame_per_block=1): +def sample_ar_video( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + num_frame_per_block=1, +): """ Autoregressive video sampler: block-by-block denoising with KV cache and flow-match re-noising for Causal Forcing / Self-Forcing models. @@ -1867,7 +3286,10 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No inner_model = model.inner_model.inner_model causal_model = inner_model.diffusion_model - if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")): + if not ( + hasattr(causal_model, "init_kv_caches") + and hasattr(causal_model, "init_crossattn_caches") + ): raise TypeError( "ar_video sampler requires a Causal-WAN compatible model whose diffusion_model " "exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint " @@ -1877,12 +3299,14 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No seed = extra_args.get("seed", 0) bs, c, lat_t, lat_h, lat_w = x.shape - frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division - num_blocks = -(-lat_t // num_frame_per_block) # ceiling division + frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division + num_blocks = -(-lat_t // num_frame_per_block) # ceiling division device = x.device model_dtype = inner_model.get_dtype() - kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype) + kv_caches = causal_model.init_kv_caches( + bs, lat_t * frame_seq_len, device, model_dtype + ) crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype) output = torch.zeros_like(x) @@ -1890,13 +3314,21 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No current_start_frame = 0 # I2V: seed KV cache with the initial image latent before the denoising loop - initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None) + initial_latent = transformer_options.get("ar_config", {}).get( + "initial_latent", None + ) if initial_latent is not None: - initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype) + initial_latent = inner_model.process_latent_in(initial_latent).to( + device=device, dtype=model_dtype + ) n_init = initial_latent.shape[2] output[:, :, :n_init] = initial_latent - ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches} + ar_state = { + "start_frame": 0, + "kv_caches": kv_caches, + "crossattn_caches": crossattn_caches, + } transformer_options["ar_state"] = ar_state zero_sigma = sigmas.new_zeros([1]) _ = model(initial_latent, zero_sigma * s_in, **extra_args) @@ -1927,8 +3359,15 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No if callback is not None: scaled_i = step_count * num_sigma_steps // total_real_steps - callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], - "sigma_hat": sigmas[i], "denoised": denoised}) + callback( + { + "x": noisy_input, + "i": scaled_i, + "sigma": sigmas[i], + "sigma_hat": sigmas[i], + "denoised": denoised, + } + ) if sigmas[i + 1] == 0: noisy_input = denoised @@ -1936,7 +3375,9 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No sigma_next = sigmas[i + 1] torch.manual_seed(seed + block_idx * 1000 + i) fresh_noise = torch.randn_like(denoised) - noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + noisy_input = ( + 1.0 - sigma_next + ) * denoised + sigma_next * fresh_noise for cache in kv_caches: cache["end"] -= bf * frame_seq_len diff --git a/comfy/samplers.py b/comfy/samplers.py index 25c5a855fd04..48a700cf5b9d 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -996,6 +996,12 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N if callback is not None: k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + # Expose mutable extra_options so sampler functions can re-read + # updated values at each step (e.g. s_noise varied by feedback). + # Only inject when the sampler has per-step feedback param functions, + # otherwise _dynamic_sampler_options would leak to the model call. + if hasattr(self, '_feedback_param_fns') and self._feedback_param_fns: + extra_args["_dynamic_sampler_options"] = self.extra_options samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) samples = model_wrap.inner_model.model_sampling.inverse_noise_scaling(sigmas[-1], samples) return samples diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 479ee8a53b87..c593197c5dae 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -111,6 +111,32 @@ def __init__(self, dynprompt): self.blocking = {} # Which nodes are blocked by this node self.externalBlocks = 0 self.unblockedEvent = asyncio.Event() + # Tracks bounded-feedback edges that were intentionally excluded from + # strong (blocking) links. Maps to_node_id -> list of (from_node_id, + # from_socket) so the execution layer can inject initial values for the + # iteration output that closes the cycle. + self.feedback_links = {} + + def _is_feedback_output(self, from_node_id, from_socket): + """Return True when *from_socket* of *from_node_id* is a declared + bounded-iteration output (``BOUNDED_FEEDBACK``).""" + try: + class_type = self.dynprompt.get_node(from_node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS.get(class_type) + except (NodeNotFoundError, KeyError): + return False + if class_def is None: + return False + bounded = getattr(class_def, 'BOUNDED_FEEDBACK', None) + if not bounded: + return False + # Map socket index to name via RETURN_NAMES, falling back to the raw index. + return_names = getattr(class_def, 'RETURN_NAMES', None) + idx = int(from_socket) + if return_names is not None and 0 <= idx < len(return_names): + return return_names[idx] in bounded + # If the socket is already a string (uncommon), check directly. + return str(from_socket) in bounded def get_input_info(self, unique_id, input_name): class_type = self.dynprompt.get_node(unique_id)["class_type"] @@ -163,6 +189,24 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None): links.append((from_node_id, from_socket, unique_id)) for link in links: + from_node_id, from_socket, to_node_id = link + if self._is_feedback_output(from_node_id, from_socket): + # This edge carries an iteration variable (e.g. step_index) + # back upstream to close a bounded feedback cycle. Don't + # create a strong (blocking) link — that would deadlock the + # topological dissolve. Instead record it so the execution + # layer can seed the iteration output with an initial value. + if to_node_id not in self.feedback_links: + self.feedback_links[to_node_id] = [] + self.feedback_links[to_node_id].append((from_node_id, from_socket)) + # Still ensure the source node is in the graph. + self.add_node(from_node_id) + # Create a cache link so the downstream node can read the + # placeholder value injected into the output cache by the + # execution bootstrap (only available on ExecutionList). + if hasattr(self, 'cache_link'): + self.cache_link(from_node_id, to_node_id) + continue self.add_strong_link(*link) def add_external_block(self, node_id): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index c9d7e06fc3a0..8cb758f8794c 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -1011,6 +1011,10 @@ def execute(cls, noise_seed) -> io.NodeOutput: class SamplerCustomAdvanced(io.ComfyNode): + # Declare which outputs are bounded iteration variables that may feed back + # through the graph to control upstream parameters (e.g. step_index -> cfg). + BOUNDED_FEEDBACK = {"step_index"} + @classmethod def define_schema(cls): return io.Schema( @@ -1026,6 +1030,7 @@ def define_schema(cls): outputs=[ io.Latent.Output(display_name="output"), io.Latent.Output(display_name="denoised_output"), + io.Int.Output(display_name="step_index"), ] ) @@ -1041,8 +1046,30 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput: if "noise_mask" in latent: noise_mask = latent["noise_mask"] + total_steps = sigmas.shape[-1] - 1 x0_output = {} - callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + callback = latent_preview.prepare_callback(guider.model_patcher, total_steps, x0_output) + + # ---- bounded-feedback per-step updates ---- + # The execution engine may have injected per-step update functions + # onto the guider and/or sampler objects. Wrap the callback to + # apply them before the *next* sampling step. The k-diffusion + # callback fires *after* the model call for step i, so we pass + # i+1 so that step N uses parameters computed with a=N. + cfg_fn = getattr(guider, '_feedback_cfg_fn', None) + param_fns = getattr(sampler, '_feedback_param_fns', None) + _has_feedback = cfg_fn is not None or param_fns + if _has_feedback: + _orig_callback = callback + def _feedback_callback(step, x0, x, total_steps): + if cfg_fn is not None: + guider.cfg = cfg_fn(step + 1, total_steps) + if param_fns is not None: + for key, fn in param_fns.items(): + sampler.extra_options[key] = fn(step + 1, total_steps) + _orig_callback(step, x0, x, total_steps) + callback = _feedback_callback + # ---------------------------------------------------- disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) @@ -1061,7 +1088,7 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput: out_denoised["samples"] = x0_out else: out_denoised = out - return io.NodeOutput(out, out_denoised) + return io.NodeOutput(out, out_denoised, total_steps) sample = execute diff --git a/execution.py b/execution.py index 9e16e451d8a7..c265cbb499f2 100644 --- a/execution.py +++ b/execution.py @@ -110,6 +110,21 @@ class CacheType(Enum): RAM_PRESSURE = 3 +# Initial values for bounded-feedback iteration outputs keyed by ComfyUI type +# string. When the DAG contains a feedback loop (e.g. step_index → … → cfg +# → guider → sampler) the execution engine seeds the iteration output with +# the default listed here so the downstream chain can evaluate before the +# iteration-producing node runs. +_FEEDBACK_DEFAULTS = { + "INT": 0, + "FLOAT": 0.0, + "BOOLEAN": False, + "STRING": "", + "NUMBER": 0, + "PRIMITIVE": 0, +} + + class CacheSet: def __init__(self, cache_type=None, cache_args={}): if cache_type == CacheType.NONE: @@ -176,12 +191,28 @@ def mark_missing(): continue # This might be a lazily-evaluated input cached = execution_list.get_cache(input_unique_id, unique_id) if cached is None or cached.outputs is None: - mark_missing() + # If this is a bounded-feedback link whose source hasn't + # executed yet, supply the type-appropriate initial value + # (e.g. step_index=0) so the feedback chain can evaluate + # before the iteration-producing node runs. + if _is_feedback_link(execution_list, unique_id, input_unique_id, output_index): + default_val = _get_feedback_default(dynprompt, input_unique_id, output_index) + obj = default_val + if isinstance(obj, (int, float, bool, str)): + obj = (obj,) + input_data_all[x] = obj + else: + mark_missing() continue if output_index >= len(cached.outputs): mark_missing() continue obj = cached.outputs[output_index] + # Wrap atomic types (int, float, bool, str) in a tuple so + # _async_map_node_over_list can call len() on every input. + # The slice_dict helper then unwraps: (val,)[0] == val. + if isinstance(obj, (int, float, bool, str)): + obj = (obj,) input_data_all[x] = obj elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS): input_data_all[x] = [input_data] @@ -658,6 +689,209 @@ async def await_completion(): return (ExecutionResult.SUCCESS, None, None) + +def _is_feedback_link(execution_list, to_node_id, from_node_id, from_socket): + """Return True when *to_node_id* receives *from_node_id*:*from_socket* + through a bounded-feedback edge (recorded during graph construction).""" + edges = execution_list.feedback_links.get(to_node_id, []) + return (from_node_id, from_socket) in edges + + +def _get_feedback_default(dynprompt, from_node_id, from_socket): + """Return the type-appropriate initial value for a feedback iteration + output (e.g. 0 for INT, 0.0 for FLOAT).""" + try: + class_type = dynprompt.get_node(from_node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return_types = class_def.RETURN_TYPES + except Exception: + return 0 + if from_socket < len(return_types): + return _FEEDBACK_DEFAULTS.get(return_types[from_socket], 0) + return 0 + + +def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, + cfg_injections, sampler_injections): + """Try to build per-step update functions from a feedback edge. + + Walks forward from the feedback-receiving node through intermediate + ComfyMathExpression nodes to find targets that need per-step callables. + Handles two target types: + + * **CFGGuider** — populates *cfg_injections* keyed by guider node id + with a ``cfg_fn(step, total_steps)`` callable. + * **Sampler-producing nodes** (any node whose class_type starts with + "Sampler" except the iteration node itself) — populates + *sampler_injections* keyed by (sampler_node_id, param_name) with a + ``param_fn(step, total_steps)`` callable. + + Supports multi-hop chains like:: + + iteration_node ──(step_index)──→ MathExpr_A ──→ MathExpr_B ──→ CFGGuider + ├─→ SamplerXXX + └─→ ... + """ + try: + prompt = dynamic_prompt.original_prompt + except Exception: + return + + from simpleeval import simple_eval + from comfy_extras.nodes_math import MATH_FUNCTIONS + + # ---- helpers ---- + def _find_consumers(source_id): + consumers = [] + for nid, n in prompt.items(): + for iname, ival in n.get("inputs", {}).items(): + if isinstance(ival, list) and len(ival) == 2 \ + and ival[0] == source_id and ival[1] == 0: + consumers.append((nid, n.get("class_type"), iname)) + return consumers + + def _is_sampler_target(class_type): + # Sampler-producing nodes whose parameters can be updated per-step + # via KSAMPLER.extra_options. + return (class_type is not None + and "Sampler" in class_type + and class_type != "SamplerCustomAdvanced") + + def _resolve_input_value(source_node_id, source_socket): + """Try to resolve a non-feedback linked input to a static value. + + First checks the source node's ``inputs`` dict (API format) for a + direct scalar value at the socket. Falls back to ``widgets_values`` + positional mapping (workflow-file format). Returns the resolved + value, or None if unresolvable. + """ + try: + snode = prompt.get(str(source_node_id)) + if snode is None: + return None + class_type = snode.get("class_type", "") + inputs = snode.get("inputs", {}) + + # API format: inputs are named — find the name that maps to + # *source_socket* via the class's INPUT_TYPES ordering. + cls = nodes.NODE_CLASS_MAPPINGS.get(class_type) + if cls is not None: + try: + input_types = cls.INPUT_TYPES() + except Exception: + input_types = {} + required = input_types.get("required", {}) + req_names = list(required.keys()) + if source_socket < len(req_names): + name = req_names[source_socket] + val = inputs.get(name) + if val is not None and not isinstance(val, list): + return val + + # Fallback: widgets_values positional mapping (workflow-file format) + wv = snode.get("widgets_values", []) + if wv: + if class_type in ("PrimitiveInt", "PrimitiveFloat", "PrimitiveBool"): + if source_socket == 0 and len(wv) > 0: + return wv[0] + if cls is not None and source_socket < len(req_names) and source_socket < len(wv): + return wv[source_socket] + return None + except Exception: + return None + + def _collect_extra_names(node_id, feedback_from_node, feedback_from_socket, + feedback_var_name): + """Collect non-feedback linked inputs from a MathExpression node + and resolve them to values. Returns dict of name→value.""" + extra = {} + try: + snode = prompt.get(str(node_id)) + if snode is None: + return extra + for inp_name, inp_val in snode.get("inputs", {}).items(): + if not isinstance(inp_val, list) or len(inp_val) != 2: + continue + src_id, src_socket = inp_val[0], inp_val[1] + # Skip the feedback-linked input — that's the iteration variable + if (src_id == str(feedback_from_node) + and int(src_socket) == int(feedback_from_socket)): + continue + # This is an additional linked input — try to resolve it + val = _resolve_input_value(src_id, src_socket) + if val is not None: + var_name = inp_name.rsplit(".", 1)[-1] + extra[var_name] = val + except Exception: + pass + return extra + + # Each chain element is now (expression, feedback_var, extra_names_dict) + # ---- depth-first search ---- + def _dfs(start_id, from_node, from_socket, chain): + """Walk the MathExpr chain looking for any target node that needs + per-step updates. Returns a list of (target_type, target_id, + input_name, full_chain) tuples, where target_type is 'guider' + or 'sampler'.""" + try: + node = dynamic_prompt.get_node(start_id) + except Exception: + return [] + if node.get("class_type") != "ComfyMathExpression": + return [] + + expression = node.get("inputs", {}).get("expression", "") + if not expression or not expression.strip(): + return [] + + var_name = None + for input_name, input_val in node.get("inputs", {}).items(): + if isinstance(input_val, list) and len(input_val) == 2 \ + and input_val[0] == from_node and input_val[1] == from_socket: + var_name = input_name.rsplit(".", 1)[-1] + break + if var_name is None: + return [] + + # Collect additional (non-feedback) input values for this node + extra_names = _collect_extra_names(start_id, from_node, from_socket, + var_name) + + new_chain = chain + [(expression, var_name, extra_names)] + results = [] + + for cid, ctype, ciname in _find_consumers(start_id): + if ctype == "CFGGuider": + results.append(("guider", cid, None, new_chain)) + elif _is_sampler_target(ctype): + results.append(("sampler", cid, ciname, new_chain)) + elif ctype == "ComfyMathExpression": + results.extend(_dfs(cid, start_id, 0, new_chain)) + return results + + # ---- compose functions from discovered chains ---- + for target_type, target_id, param_name, chain in \ + _dfs(to_node_id, from_node_id, from_socket, []): + if not chain: + continue + + def _make_fn(_chain): + def _fn(step, total_steps): + val = step + for expr_str, var, extra_names in _chain: + ctx = dict(extra_names) if extra_names else {} + ctx[var] = val + val = float(simple_eval(expr_str, names=ctx, functions=MATH_FUNCTIONS)) + return val + return _fn + + if target_type == "guider": + cfg_injections[target_id] = _make_fn(chain) + elif target_type == "sampler" and param_name: + sampler_injections[target_id] = sampler_injections.get(target_id, {}) + sampler_injections[target_id][param_name] = _make_fn(chain) + + class PromptExecutor: def __init__(self, server, cache_type=False, cache_args=None): self.cache_args = cache_args @@ -774,6 +1008,26 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= for node_id in list(execute_outputs): execution_list.add_node(node_id) + # ---- bounded-feedback bootstrap --------------------------------- + # Build per-step update functions for feedback chains that + # pass through ComfyMathExpression → CFGGuider / SamplerXXX. + # These are injected into the guider / sampler after the + # target node executes so the sampler can vary parameters + # (cfg, s_noise, ...) with step_index. + _feedback_cfg_injections = {} # guider_node_id → cfg_fn + _feedback_sampler_injections = {} # sampler_node_id → {param: fn} + for to_node_id, edges in execution_list.feedback_links.items(): + for from_node_id, from_socket in edges: + try: + _build_feedback_fns( + dynamic_prompt, from_node_id, from_socket, + to_node_id, _feedback_cfg_injections, + _feedback_sampler_injections, + ) + except Exception: + pass # non-critical – feedback just wonʼt vary per step + # ----------------------------------------------------------------- + while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() if error is not None: @@ -789,6 +1043,29 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= elif result == ExecutionResult.PENDING: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: + # ---- bounded-feedback injection ---- + # If this node just produced a guider or sampler + # that is part of a feedback cycle, inject per-step + # update function(s). + if node_id in _feedback_cfg_injections: + try: + output = self.caches.outputs.get_local(node_id) + if output is not None and output.outputs is not None \ + and len(output.outputs) > 0 and len(output.outputs[0]) > 0: + guider = output.outputs[0][0] + guider._feedback_cfg_fn = _feedback_cfg_injections[node_id] + except Exception: + pass + if node_id in _feedback_sampler_injections: + try: + output = self.caches.outputs.get_local(node_id) + if output is not None and output.outputs is not None \ + and len(output.outputs) > 0 and len(output.outputs[0]) > 0: + sampler_obj = output.outputs[0][0] + sampler_obj._feedback_param_fns = _feedback_sampler_injections[node_id] + except Exception: + pass + # --------------------------------------- execution_list.complete_node_execution() if self.cache_type == CacheType.RAM_PRESSURE: @@ -831,6 +1108,34 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= self._notify_prompt_lifecycle("end", prompt_id) +def _is_bounded_feedback_cycle(prompt, visiting, unique_id): + """Check whether a detected dependency cycle is a *bounded* feedback loop. + + A cycle is bounded when at least one node in it declares ``BOUNDED_FEEDBACK``, + i.e. the node has a finite internal iteration whose step / index variable + feeds back upstream to control its own parameters (e.g. a sampler's + ``step_index`` flowing through a math expression to set ``cfg``). + + Because the iteration is bounded (N steps, then terminates) this isn't an + infinite cycle — the DAG can safely allow it and the execution engine will + break the feedback edge by seeding the iteration output with an initial value. + """ + cycle_nodes = visiting[visiting.index(unique_id):] + [unique_id] + for node_id in cycle_nodes: + if node_id not in prompt: + continue + class_type = prompt[node_id].get('class_type') + if class_type is None: + continue + obj_class = nodes.NODE_CLASS_MAPPINGS.get(class_type) + if obj_class is None: + continue + bounded = getattr(obj_class, 'BOUNDED_FEEDBACK', None) + if bounded: + return True + return False + + async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): if visiting is None: visiting = [] @@ -842,6 +1147,19 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): if unique_id in visiting: cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id] cycle_nodes = list(dict.fromkeys(cycle_path_nodes)) + + # A bounded feedback cycle is one where at least one node in the cycle + # declares BOUNDED_FEEDBACK — meaning its internal iteration is finite + # and its iteration output(s) can safely flow back upstream without + # causing an infinite loop (e.g. a sampler's step_index controlling cfg). + if _is_bounded_feedback_cycle(prompt, visiting, unique_id): + # Mark the repeated node as valid and continue the traversal on + # other branches. The execution layer handles the feedback edge + # by breaking it and seeding the iteration output with an initial + # value (e.g. step_index = 0). + validated[unique_id] = (True, [], unique_id) + return validated[unique_id] + cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes) for node_id in cycle_nodes: validated[node_id] = (False, [{ From 4fece5b6b59e6263e91a933a5cd764709610b93c Mon Sep 17 00:00:00 2001 From: PR Author Date: Fri, 19 Jun 2026 19:43:41 +0800 Subject: [PATCH 2/2] Add docstrings to feedback helpers; parameterize _find_consumers socket - _FEEDBACK_DEFAULTS: add module-level docstring comment - _find_consumers: add source_socket parameter (default 0), add docstring - _is_sampler_target: add docstring --- execution.py | 783 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 568 insertions(+), 215 deletions(-) diff --git a/execution.py b/execution.py index c265cbb499f2..69e21cc4bba0 100644 --- a/execution.py +++ b/execution.py @@ -1,33 +1,42 @@ +import asyncio import copy import heapq import inspect import logging -import psutil import sys import threading import time import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union -import asyncio +import comfy_aimdo.model_vbar +import psutil import torch -from comfy.cli_args import args import comfy.memory_management import comfy.model_management import comfy.model_prefetch -import comfy_aimdo.model_vbar - -from latent_preview import set_preview_method import nodes +from comfy.cli_args import args +from comfy_api.internal import ( + _ComfyNodeInternal, + _NodeOutputInternal, + first_real_override, + is_class, + make_locked_method_func, +) +from comfy_api.latest import _io, io +from comfy_execution.asset_enrichment import enrich_output_with_assets +from comfy_execution.cache_provider import _get_cache_providers, _has_cache_providers +from comfy_execution.cache_provider import _logger as _cache_logger from comfy_execution.caching import ( BasicCache, CacheKeySetID, CacheKeySetInputSignature, - NullCache, HierarchicalCache, LRUCache, + NullCache, RAMPressureCache, ) from comfy_execution.graph import ( @@ -37,13 +46,15 @@ get_input_info, ) from comfy_execution.graph_utils import GraphBuilder, is_link -from comfy_execution.validation import validate_node_input -from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler +from comfy_execution.progress import ( + WebUIProgressHandler, + add_progress_handler, + get_progress_state, + reset_progress_state, +) from comfy_execution.utils import CurrentNodeContext -from comfy_execution.asset_enrichment import enrich_output_with_assets -from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func -from comfy_api.latest import io, _io -from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger +from comfy_execution.validation import validate_node_input +from latent_preview import set_preview_method class ExecutionResult(Enum): @@ -51,11 +62,15 @@ class ExecutionResult(Enum): FAILURE = 1 PENDING = 2 + class DuplicateNodeError(Exception): pass + class IsChangedCache: - def __init__(self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache): + def __init__( + self, prompt_id: str, dynprompt: DynamicPrompt, outputs_cache: BasicCache + ): self.prompt_id = prompt_id self.dynprompt = dynprompt self.outputs_cache = outputs_cache @@ -70,7 +85,10 @@ async def get(self, node_id): class_def = nodes.NODE_CLASS_MAPPINGS[class_type] has_is_changed = False is_changed_name = None - if issubclass(class_def, _ComfyNodeInternal) and first_real_override(class_def, "fingerprint_inputs") is not None: + if ( + issubclass(class_def, _ComfyNodeInternal) + and first_real_override(class_def, "fingerprint_inputs") is not None + ): has_is_changed = True is_changed_name = "fingerprint_inputs" elif hasattr(class_def, "IS_CHANGED"): @@ -85,11 +103,22 @@ async def get(self, node_id): return self.is_changed[node_id] # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED - input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data = get_input_data( + node["inputs"], class_def, node_id, None + ) try: - is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data) + is_changed = await _async_map_node_over_list( + self.prompt_id, + node_id, + class_def, + input_data_all, + is_changed_name, + v3_data=v3_data, + ) is_changed = await resolve_map_node_over_list_results(is_changed) - node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed] + node["is_changed"] = [ + None if isinstance(x, ExecutionBlocker) else x for x in is_changed + ] except Exception as e: logging.warning("WARNING: {}".format(e)) node["is_changed"] = float("NaN") @@ -145,15 +174,21 @@ def __init__(self, cache_type=None, cache_args={}): # Performs like the old cache -- dump data ASAP def init_classic_cache(self): - self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True) + self.outputs = HierarchicalCache( + CacheKeySetInputSignature, enable_providers=True + ) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): - self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True) + self.outputs = LRUCache( + CacheKeySetInputSignature, max_size=cache_size, enable_providers=True + ) self.objects = HierarchicalCache(CacheKeySetID) def init_ram_cache(self, min_headroom): - self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True) + self.outputs = RAMPressureCache( + CacheKeySetInputSignature, enable_providers=True + ) self.objects = HierarchicalCache(CacheKeySetID) def init_null_cache(self): @@ -166,37 +201,51 @@ def recursive_debug_dump(self): } return result + SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") -def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): + +def get_input_data( + inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={} +): is_v3 = issubclass(class_def, _ComfyNodeInternal) v3_data: io.V3Data = {} hidden_inputs_v3 = {} valid_inputs = class_def.INPUT_TYPES() if is_v3: - valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs(valid_inputs, inputs) + valid_inputs, hidden, v3_data = _io.get_finalized_class_inputs( + valid_inputs, inputs + ) input_data_all = {} missing_keys = {} for x in inputs: input_data = inputs[x] _, input_category, input_info = get_input_info(class_def, x, valid_inputs) + def mark_missing(): missing_keys[x] = True input_data_all[x] = (None,) - if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): + + if is_link(input_data) and ( + not input_info or not input_info.get("rawLink", False) + ): input_unique_id = input_data[0] output_index = input_data[1] if execution_list is None: mark_missing() - continue # This might be a lazily-evaluated input + continue # This might be a lazily-evaluated input cached = execution_list.get_cache(input_unique_id, unique_id) if cached is None or cached.outputs is None: # If this is a bounded-feedback link whose source hasn't # executed yet, supply the type-appropriate initial value # (e.g. step_index=0) so the feedback chain can evaluate # before the iteration-producing node runs. - if _is_feedback_link(execution_list, unique_id, input_unique_id, output_index): - default_val = _get_feedback_default(dynprompt, input_unique_id, output_index) + if _is_feedback_link( + execution_list, unique_id, input_unique_id, output_index + ): + default_val = _get_feedback_default( + dynprompt, input_unique_id, output_index + ) obj = default_val if isinstance(obj, (int, float, bool, str)): obj = (obj,) @@ -220,29 +269,41 @@ def mark_missing(): if is_v3: if hidden is not None: if io.Hidden.prompt.name in hidden: - hidden_inputs_v3[io.Hidden.prompt] = dynprompt.get_original_prompt() if dynprompt is not None else {} + hidden_inputs_v3[io.Hidden.prompt] = ( + dynprompt.get_original_prompt() if dynprompt is not None else {} + ) if io.Hidden.dynprompt.name in hidden: hidden_inputs_v3[io.Hidden.dynprompt] = dynprompt if io.Hidden.extra_pnginfo.name in hidden: - hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get('extra_pnginfo', None) + hidden_inputs_v3[io.Hidden.extra_pnginfo] = extra_data.get( + "extra_pnginfo", None + ) if io.Hidden.unique_id.name in hidden: hidden_inputs_v3[io.Hidden.unique_id] = unique_id if io.Hidden.auth_token_comfy_org.name in hidden: - hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) + hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get( + "auth_token_comfy_org", None + ) if io.Hidden.api_key_comfy_org.name in hidden: - hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) + hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get( + "api_key_comfy_org", None + ) if io.Hidden.comfy_usage_source.name in hidden: - hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get("comfy_usage_source", None) + hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get( + "comfy_usage_source", None + ) else: if "hidden" in valid_inputs: h = valid_inputs["hidden"] for x in h: if h[x] == "PROMPT": - input_data_all[x] = [dynprompt.get_original_prompt() if dynprompt is not None else {}] + input_data_all[x] = [ + dynprompt.get_original_prompt() if dynprompt is not None else {} + ] if h[x] == "DYNPROMPT": input_data_all[x] = [dynprompt] if h[x] == "EXTRA_PNGINFO": - input_data_all[x] = [extra_data.get('extra_pnginfo', None)] + input_data_all[x] = [extra_data.get("extra_pnginfo", None)] if h[x] == "UNIQUE_ID": input_data_all[x] = [unique_id] if h[x] == "AUTH_TOKEN_COMFY_ORG": @@ -254,7 +315,9 @@ def mark_missing(): v3_data["hidden_inputs"] = hidden_inputs_v3 return input_data_all, missing_keys, v3_data -map_node_over_list = None #Don't hook this please + +map_node_over_list = None # Don't hook this please + async def resolve_map_node_over_list_results(results): remaining = [x for x in results if isinstance(x, asyncio.Task) and not x.done()] @@ -268,7 +331,18 @@ async def resolve_map_node_over_list_results(results): raise exc return [x.result() if isinstance(x, asyncio.Task) else x for x in results] -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None): + +async def _async_map_node_over_list( + prompt_id, + unique_id, + obj, + input_data_all, + func, + allow_interrupt=False, + execution_block_cb=None, + pre_execute_cb=None, + v3_data=None, +): # check if node wants the lists input_is_list = getattr(obj, "INPUT_IS_LIST", False) @@ -282,6 +356,7 @@ def slice_dict(d, i): return {k: v[i if len(v) > i else -1] for k, v in d.items()} results = [] + async def process_inputs(inputs, index=None, input_is_list=False): if allow_interrupt: nodes.before_node_execution() @@ -299,7 +374,9 @@ async def process_inputs(inputs, index=None, input_is_list=False): if pre_execute_cb is not None and index is not None: pre_execute_cb(index) # V3 - if isinstance(obj, _ComfyNodeInternal) or (is_class(obj) and issubclass(obj, _ComfyNodeInternal)): + if isinstance(obj, _ComfyNodeInternal) or ( + is_class(obj) and issubclass(obj, _ComfyNodeInternal) + ): # if is just a class, then assign no state, just create clone if is_class(obj): type_obj = obj @@ -318,10 +395,14 @@ async def process_inputs(inputs, index=None, input_is_list=False): else: f = getattr(obj, func) if inspect.iscoroutinefunction(f): + async def async_wrapper(f, prompt_id, unique_id, list_index, args): with CurrentNodeContext(prompt_id, unique_id, list_index): return await f(**args) - task = asyncio.create_task(async_wrapper(f, prompt_id, unique_id, index, args=inputs)) + + task = asyncio.create_task( + async_wrapper(f, prompt_id, unique_id, index, args=inputs) + ) # Give the task a chance to execute without yielding await asyncio.sleep(0) if task.done(): @@ -368,14 +449,36 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) - has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) + +async def get_output_data( + prompt_id, + unique_id, + obj, + input_data_all, + execution_block_cb=None, + pre_execute_cb=None, + v3_data=None, +): + return_values = await _async_map_node_over_list( + prompt_id, + unique_id, + obj, + input_data_all, + obj.FUNCTION, + allow_interrupt=True, + execution_block_cb=execution_block_cb, + pre_execute_cb=pre_execute_cb, + v3_data=v3_data, + ) + has_pending_task = any( + isinstance(r, asyncio.Task) and not r.done() for r in return_values + ) if has_pending_task: return return_values, {}, False, has_pending_task output, ui, has_subgraph = get_output_from_returns(return_values, obj) return output, ui, has_subgraph, False + def get_output_from_returns(return_values, obj): results = [] uis = [] @@ -384,17 +487,17 @@ def get_output_from_returns(return_values, obj): for i in range(len(return_values)): r = return_values[i] if isinstance(r, dict): - if 'ui' in r: - uis.append(r['ui']) - if 'expand' in r: + if "ui" in r: + uis.append(r["ui"]) + if "expand" in r: # Perform an expansion, but do not append results has_subgraph = True - new_graph = r['expand'] + new_graph = r["expand"] result = r.get("result", None) if isinstance(result, ExecutionBlocker): result = tuple([result] * len(obj.RETURN_TYPES)) subgraph_results.append((new_graph, result)) - elif 'result' in r: + elif "result" in r: result = r.get("result", None) if isinstance(result, ExecutionBlocker): result = tuple([result] * len(obj.RETURN_TYPES)) @@ -412,12 +515,16 @@ def get_output_from_returns(return_values, obj): new_graph = r.expand result = r.result if r.block_execution is not None: - result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + result = tuple( + [ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES) + ) subgraph_results.append((new_graph, result)) elif r.result is not None: result = r.result if r.block_execution is not None: - result = tuple([ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES)) + result = tuple( + [ExecutionBlocker(r.block_execution)] * len(obj.RETURN_TYPES) + ) results.append(result) subgraph_results.append((None, result)) else: @@ -441,6 +548,7 @@ def get_output_from_returns(return_values, obj): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui, has_subgraph + def format_value(x): if x is None: return None @@ -449,31 +557,56 @@ def format_value(x): else: return str(x) + def _is_intermediate_output(dynprompt, node_id): class_type = dynprompt.get_node(node_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False) + return getattr(class_def, "HAS_INTERMEDIATE_OUTPUT", False) def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs): if server.client_id is None: return cached_ui = cached.ui or {} - server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id) + server.send_sync( + "executed", + { + "node": node_id, + "display_node": display_node_id, + "output": cached_ui.get("output", None), + "prompt_id": prompt_id, + }, + server.client_id, + ) if cached.ui is not None: ui_outputs[node_id] = cached.ui -async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): + +async def execute( + server, + dynprompt, + caches, + current_item, + extra_data, + executed, + prompt_id, + execution_list, + pending_subgraph_results, + pending_async_nodes, + ui_outputs, +): unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) parent_node_id = dynprompt.get_parent_node_id(unique_id) - inputs = dynprompt.get_node(unique_id)['inputs'] - class_type = dynprompt.get_node(unique_id)['class_type'] + inputs = dynprompt.get_node(unique_id)["inputs"] + class_type = dynprompt.get_node(unique_id)["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] cached = await caches.outputs.get(unique_id) if cached is not None: - _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs) + _send_cached_ui( + server, unique_id, display_node_id, cached, prompt_id, ui_outputs + ) get_progress_state().finish_progress(unique_id) execution_list.cache_update(unique_id, cached) return (ExecutionResult.SUCCESS, None, None) @@ -493,7 +626,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: results.append(r) del pending_async_nodes[unique_id] - output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def) + output_data, output_ui, has_subgraph = get_output_from_returns( + results, class_def + ) elif unique_id in pending_subgraph_results: cached_results = pending_subgraph_results[unique_id] resolved_outputs = [] @@ -505,7 +640,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_cached = execution_list.get_cache(source_node, unique_id) + node_cached = execution_list.get_cache( + source_node, unique_id + ) for o in node_cached.outputs[source_output]: resolved_output.append(o) @@ -518,10 +655,20 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, v3_data = get_input_data( + inputs, class_def, unique_id, execution_list, dynprompt, extra_data + ) if server.client_id is not None: server.last_node_id = display_node_id - server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) + server.send_sync( + "executing", + { + "node": unique_id, + "display_node": display_node_id, + "prompt_id": prompt_id, + }, + server.client_id, + ) obj = await caches.objects.get(unique_id) if obj is None: @@ -529,19 +676,38 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, await caches.objects.set(unique_id, obj) if issubclass(class_def, _ComfyNodeInternal): - lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None + lazy_status_present = ( + first_real_override(class_def, "check_lazy_status") is not None + ) else: - lazy_status_present = getattr(obj, "check_lazy_status", None) is not None + lazy_status_present = ( + getattr(obj, "check_lazy_status", None) is not None + ) if lazy_status_present: # for check_lazy_status, the returned data should include the original key of the input v3_data_lazy = v3_data.copy() v3_data_lazy["create_dynamic_tuple"] = True - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data_lazy) - required_inputs = await resolve_map_node_over_list_results(required_inputs) - required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) - required_inputs = [x for x in required_inputs if isinstance(x,str) and ( - x not in input_data_all or x in missing_keys - )] + required_inputs = await _async_map_node_over_list( + prompt_id, + unique_id, + obj, + input_data_all, + "check_lazy_status", + allow_interrupt=True, + v3_data=v3_data_lazy, + ) + required_inputs = await resolve_map_node_over_list_results( + required_inputs + ) + required_inputs = set( + sum([r for r in required_inputs if isinstance(r, list)], []) + ) + required_inputs = [ + x + for x in required_inputs + if isinstance(x, str) + and (x not in input_data_all or x in missing_keys) + ] if len(required_inputs) > 0: for i in required_inputs: execution_list.make_input_strong_link(unique_id, i) @@ -554,7 +720,6 @@ def execution_block_cb(block): "node_id": unique_id, "node_type": class_type, "executed": list(executed), - "exception_message": f"Execution Blocked: {block.message}", "exception_type": "ExecutionBlocked", "traceback": [], @@ -565,12 +730,26 @@ def execution_block_cb(block): return ExecutionBlocker(None) else: return block + def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) try: - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + ( + output_data, + output_ui, + has_subgraph, + has_pending_tasks, + ) = await get_output_data( + prompt_id, + unique_id, + obj, + input_data_all, + execution_block_cb=execution_block_cb, + pre_execute_cb=pre_execute_cb, + v3_data=v3_data, + ) finally: if comfy.memory_management.aimdo_enabled: if args.verbose == "DEBUG": @@ -582,10 +761,12 @@ def pre_execute_cb(call_index): if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) + async def await_completion(): tasks = [x for x in output_data if isinstance(x, asyncio.Task)] await asyncio.gather(*tasks, return_exceptions=True) unblock() + asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: @@ -600,10 +781,19 @@ async def await_completion(): "parent_node": parent_node_id, "real_node_id": real_node_id, }, - "output": output_ui + "output": output_ui, } if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.send_sync( + "executed", + { + "node": unique_id, + "display_node": display_node_id, + "output": output_ui, + "prompt_id": prompt_id, + }, + server.client_id, + ) if has_subgraph: cached_outputs = [] new_node_ids = [] @@ -617,15 +807,23 @@ async def await_completion(): for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) - dynprompt.add_ephemeral_node(node_id, node_info, unique_id, display_id) + dynprompt.add_ephemeral_node( + node_id, node_info, unique_id, display_id + ) # Figure out if the newly created node is an output node class_type = node_info["class_type"] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + if ( + hasattr(class_def, "OUTPUT_NODE") + and class_def.OUTPUT_NODE == True + ): new_output_ids.append(node_id) for i in range(len(node_outputs)): if is_link(node_outputs[i]): - from_node_id, from_socket = node_outputs[i][0], node_outputs[i][1] + from_node_id, from_socket = ( + node_outputs[i][0], + node_outputs[i][1], + ) new_output_links.append((from_node_id, from_socket)) cached_outputs.append((True, node_outputs)) new_node_ids = set(new_node_ids) @@ -668,18 +866,26 @@ async def await_completion(): if comfy.model_management.is_oom(ex): tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." - logging.info("Memory summary:\n{}".format(comfy.model_management.debug_memory_summary())) + logging.info( + "Memory summary:\n{}".format( + comfy.model_management.debug_memory_summary() + ) + ) logging.error("Got an OOM, unloading all loaded models.") comfy.model_management.unload_all_models() - elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type: - tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected." + elif ( + isinstance(ex, RuntimeError) + and ("mat1 and mat2 shapes" in str(ex)) + and "Sampler" in class_type + ): + tips = '\n\nTIPS: If you have any "Load CLIP" or "*CLIP Loader" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected.' error_details = { "node_id": real_node_id, "exception_message": "{}\n{}".format(ex, tips), "exception_type": exception_type, "traceback": traceback.format_tb(tb), - "current_inputs": input_data_formatted + "current_inputs": input_data_formatted, } return (ExecutionResult.FAILURE, error_details, ex) @@ -711,8 +917,14 @@ def _get_feedback_default(dynprompt, from_node_id, from_socket): return 0 -def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, - cfg_injections, sampler_injections): +def _build_feedback_fns( + dynamic_prompt, + from_node_id, + from_socket, + to_node_id, + cfg_injections, + sampler_injections, +): """Try to build per-step update functions from a feedback edge. Walks forward from the feedback-receiving node through intermediate @@ -738,24 +950,35 @@ def _build_feedback_fns(dynamic_prompt, from_node_id, from_socket, to_node_id, return from simpleeval import simple_eval + from comfy_extras.nodes_math import MATH_FUNCTIONS # ---- helpers ---- - def _find_consumers(source_id): + def _find_consumers(source_id, source_socket=0): + """Return all nodes consuming output *source_socket* of *source_id*. + + Each result is ``(consumer_id, class_type, input_name)``. + """ consumers = [] for nid, n in prompt.items(): for iname, ival in n.get("inputs", {}).items(): - if isinstance(ival, list) and len(ival) == 2 \ - and ival[0] == source_id and ival[1] == 0: + if ( + isinstance(ival, list) + and len(ival) == 2 + and ival[0] == source_id + and ival[1] == source_socket + ): consumers.append((nid, n.get("class_type"), iname)) return consumers def _is_sampler_target(class_type): - # Sampler-producing nodes whose parameters can be updated per-step - # via KSAMPLER.extra_options. - return (class_type is not None - and "Sampler" in class_type - and class_type != "SamplerCustomAdvanced") + """Return True if *class_type* is a sampler that accepts per-step + parameter updates via :class:`KSAMPLER.extra_options`.""" + return ( + class_type is not None + and "Sampler" in class_type + and class_type != "SamplerCustomAdvanced" + ) def _resolve_input_value(source_node_id, source_socket): """Try to resolve a non-feedback linked input to a static value. @@ -794,14 +1017,19 @@ def _resolve_input_value(source_node_id, source_socket): if class_type in ("PrimitiveInt", "PrimitiveFloat", "PrimitiveBool"): if source_socket == 0 and len(wv) > 0: return wv[0] - if cls is not None and source_socket < len(req_names) and source_socket < len(wv): + if ( + cls is not None + and source_socket < len(req_names) + and source_socket < len(wv) + ): return wv[source_socket] return None except Exception: return None - def _collect_extra_names(node_id, feedback_from_node, feedback_from_socket, - feedback_var_name): + def _collect_extra_names( + node_id, feedback_from_node, feedback_from_socket, feedback_var_name + ): """Collect non-feedback linked inputs from a MathExpression node and resolve them to values. Returns dict of name→value.""" extra = {} @@ -814,8 +1042,9 @@ def _collect_extra_names(node_id, feedback_from_node, feedback_from_socket, continue src_id, src_socket = inp_val[0], inp_val[1] # Skip the feedback-linked input — that's the iteration variable - if (src_id == str(feedback_from_node) - and int(src_socket) == int(feedback_from_socket)): + if src_id == str(feedback_from_node) and int(src_socket) == int( + feedback_from_socket + ): continue # This is an additional linked input — try to resolve it val = _resolve_input_value(src_id, src_socket) @@ -846,16 +1075,19 @@ def _dfs(start_id, from_node, from_socket, chain): var_name = None for input_name, input_val in node.get("inputs", {}).items(): - if isinstance(input_val, list) and len(input_val) == 2 \ - and input_val[0] == from_node and input_val[1] == from_socket: + if ( + isinstance(input_val, list) + and len(input_val) == 2 + and input_val[0] == from_node + and input_val[1] == from_socket + ): var_name = input_name.rsplit(".", 1)[-1] break if var_name is None: return [] # Collect additional (non-feedback) input values for this node - extra_names = _collect_extra_names(start_id, from_node, from_socket, - var_name) + extra_names = _collect_extra_names(start_id, from_node, from_socket, var_name) new_chain = chain + [(expression, var_name, extra_names)] results = [] @@ -870,8 +1102,9 @@ def _dfs(start_id, from_node, from_socket, chain): return results # ---- compose functions from discovered chains ---- - for target_type, target_id, param_name, chain in \ - _dfs(to_node_id, from_node_id, from_socket, []): + for target_type, target_id, param_name, chain in _dfs( + to_node_id, from_node_id, from_socket, [] + ): if not chain: continue @@ -881,8 +1114,11 @@ def _fn(step, total_steps): for expr_str, var, extra_names in _chain: ctx = dict(extra_names) if extra_names else {} ctx[var] = val - val = float(simple_eval(expr_str, names=ctx, functions=MATH_FUNCTIONS)) + val = float( + simple_eval(expr_str, names=ctx, functions=MATH_FUNCTIONS) + ) return val + return _fn if target_type == "guider": @@ -913,7 +1149,9 @@ def add_message(self, event, data: dict, broadcast: bool): if self.server.client_id is not None or broadcast: self.server.send_sync(event, data, self.server.client_id) - def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): + def handle_execution_error( + self, prompt_id, prompt, current_outputs, executed, error, ex + ): node_id = error["node_id"] class_type = prompt[node_id]["class_type"] @@ -952,7 +1190,9 @@ def _notify_prompt_lifecycle(self, event: str, prompt_id: str): elif event == "end": provider.on_prompt_end(prompt_id) except Exception as e: - _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}") + _cache_logger.warning( + f"Cache provider {provider.__class__.__name__} error on {event}: {e}" + ) def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) @@ -968,22 +1208,32 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= self.server.client_id = None self.status_messages = [] - self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) + self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False) self._notify_prompt_lifecycle("start", prompt_id) - ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) - ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3)) - ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None - comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom) + ram_headroom = int(self.cache_args["ram"] * (1024**3)) + ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024**3)) + ram_release_callback = ( + self.caches.outputs.ram_release + if self.cache_type == CacheType.RAM_PRESSURE + else None + ) + comfy.memory_management.set_ram_cache_release_state( + ram_release_callback, ram_headroom + ) try: with torch.inference_mode(): dynamic_prompt = DynamicPrompt(prompt) reset_progress_state(prompt_id, dynamic_prompt) add_progress_handler(WebUIProgressHandler(self.server)) - is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs) + is_changed_cache = IsChangedCache( + prompt_id, dynamic_prompt, self.caches.outputs + ) for cache in self.caches.all: - await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) + await cache.set_prompt( + dynamic_prompt, prompt.keys(), is_changed_cache + ) cache.clean_unused() node_ids = list(prompt.keys()) @@ -991,16 +1241,19 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= *(self.caches.outputs.get(node_id) for node_id in node_ids) ) cached_nodes = [ - node_id for node_id, result in zip(node_ids, cache_results) + node_id + for node_id, result in zip(node_ids, cache_results) if result is not None ] comfy.model_management.cleanup_models_gc() - self.add_message("execution_cached", - { "nodes": cached_nodes, "prompt_id": prompt_id}, - broadcast=False) + self.add_message( + "execution_cached", + {"nodes": cached_nodes, "prompt_id": prompt_id}, + broadcast=False, + ) pending_subgraph_results = {} - pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results ui_node_outputs = {} executed = set() execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) @@ -1014,35 +1267,66 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= # These are injected into the guider / sampler after the # target node executes so the sampler can vary parameters # (cfg, s_noise, ...) with step_index. - _feedback_cfg_injections = {} # guider_node_id → cfg_fn - _feedback_sampler_injections = {} # sampler_node_id → {param: fn} + _feedback_cfg_injections = {} # guider_node_id → cfg_fn + _feedback_sampler_injections = {} # sampler_node_id → {param: fn} for to_node_id, edges in execution_list.feedback_links.items(): for from_node_id, from_socket in edges: try: _build_feedback_fns( - dynamic_prompt, from_node_id, from_socket, - to_node_id, _feedback_cfg_injections, + dynamic_prompt, + from_node_id, + from_socket, + to_node_id, + _feedback_cfg_injections, _feedback_sampler_injections, ) except Exception: - pass # non-critical – feedback just wonʼt vary per step + pass # non-critical – feedback just wonʼt vary per step # ----------------------------------------------------------------- while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() if error is not None: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + self.handle_execution_error( + prompt_id, + dynamic_prompt.original_prompt, + current_outputs, + executed, + error, + ex, + ) break - assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) + assert node_id is not None, ( + "Node ID should not be None at this point" + ) + result, error, ex = await execute( + self.server, + dynamic_prompt, + self.caches, + node_id, + extra_data, + executed, + prompt_id, + execution_list, + pending_subgraph_results, + pending_async_nodes, + ui_node_outputs, + ) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: - self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) + self.handle_execution_error( + prompt_id, + dynamic_prompt.original_prompt, + current_outputs, + executed, + error, + ex, + ) break elif result == ExecutionResult.PENDING: execution_list.unstage_node_execution() - else: # result == ExecutionResult.SUCCESS: + else: # result == ExecutionResult.SUCCESS: # ---- bounded-feedback injection ---- # If this node just produced a guider or sampler # that is part of a feedback cycle, inject per-step @@ -1050,19 +1334,31 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= if node_id in _feedback_cfg_injections: try: output = self.caches.outputs.get_local(node_id) - if output is not None and output.outputs is not None \ - and len(output.outputs) > 0 and len(output.outputs[0]) > 0: + if ( + output is not None + and output.outputs is not None + and len(output.outputs) > 0 + and len(output.outputs[0]) > 0 + ): guider = output.outputs[0][0] - guider._feedback_cfg_fn = _feedback_cfg_injections[node_id] + guider._feedback_cfg_fn = _feedback_cfg_injections[ + node_id + ] except Exception: pass if node_id in _feedback_sampler_injections: try: output = self.caches.outputs.get_local(node_id) - if output is not None and output.outputs is not None \ - and len(output.outputs) > 0 and len(output.outputs[0]) > 0: + if ( + output is not None + and output.outputs is not None + and len(output.outputs) > 0 + and len(output.outputs[0]) > 0 + ): sampler_obj = output.outputs[0][0] - sampler_obj._feedback_param_fns = _feedback_sampler_injections[node_id] + sampler_obj._feedback_param_fns = ( + _feedback_sampler_injections[node_id] + ) except Exception: pass # --------------------------------------- @@ -1071,9 +1367,11 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= if self.cache_type == CacheType.RAM_PRESSURE: ram_release_callback(ram_inactive_headroom) ram_shortfall = ram_headroom - psutil.virtual_memory().available - freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2)) + freed = comfy.model_management.free_pins( + ram_shortfall + 512 * (1024**2) + ) if freed < ram_shortfall: - if freed > 64 * (1024 ** 2): + if freed > 64 * (1024**2): # AIMDO MEM_DECOMMIT can outrun psutil.available catching up. time.sleep(0.05) ram_release_callback(ram_headroom, free_active=True) @@ -1087,9 +1385,20 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= continue cached = await self.caches.outputs.get(node_id) if cached is not None: - display_node_id = dynamic_prompt.get_display_node_id(node_id) - _send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs) - self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) + display_node_id = dynamic_prompt.get_display_node_id( + node_id + ) + _send_cached_ui( + self.server, + node_id, + display_node_id, + cached, + prompt_id, + ui_node_outputs, + ) + self.add_message( + "execution_success", {"prompt_id": prompt_id}, broadcast=False + ) ui_outputs = {} meta_outputs = {} @@ -1120,17 +1429,17 @@ def _is_bounded_feedback_cycle(prompt, visiting, unique_id): infinite cycle — the DAG can safely allow it and the execution engine will break the feedback edge by seeding the iteration output with an initial value. """ - cycle_nodes = visiting[visiting.index(unique_id):] + [unique_id] + cycle_nodes = visiting[visiting.index(unique_id) :] + [unique_id] for node_id in cycle_nodes: if node_id not in prompt: continue - class_type = prompt[node_id].get('class_type') + class_type = prompt[node_id].get("class_type") if class_type is None: continue obj_class = nodes.NODE_CLASS_MAPPINGS.get(class_type) if obj_class is None: continue - bounded = getattr(obj_class, 'BOUNDED_FEEDBACK', None) + bounded = getattr(obj_class, "BOUNDED_FEEDBACK", None) if bounded: return True return False @@ -1145,7 +1454,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): return validated[unique_id] if unique_id in visiting: - cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id] + cycle_path_nodes = visiting[visiting.index(unique_id) :] + [unique_id] cycle_nodes = list(dict.fromkeys(cycle_path_nodes)) # A bounded feedback cycle is one where at least one node in the cycle @@ -1160,21 +1469,30 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): validated[unique_id] = (True, [], unique_id) return validated[unique_id] - cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes) + cycle_path = " -> ".join( + f"{node_id} ({prompt[node_id]['class_type']})" + for node_id in cycle_path_nodes + ) for node_id in cycle_nodes: - validated[node_id] = (False, [{ - "type": "dependency_cycle", - "message": "Dependency cycle detected", - "details": cycle_path, - "extra_info": { - "node_id": node_id, - "cycle_nodes": cycle_nodes, - } - }], node_id) + validated[node_id] = ( + False, + [ + { + "type": "dependency_cycle", + "message": "Dependency cycle detected", + "details": cycle_path, + "extra_info": { + "node_id": node_id, + "cycle_nodes": cycle_nodes, + }, + } + ], + node_id, + ) return validated[unique_id] - inputs = prompt[unique_id]['inputs'] - class_type = prompt[unique_id]['class_type'] + inputs = prompt[unique_id]["inputs"] + class_type = prompt[unique_id]["class_type"] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] errors = [] @@ -1199,10 +1517,14 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): validate_has_kwargs = argspec.varkw is not None received_types = {} - valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) + valid_inputs = set(class_inputs.get("required", {})).union( + set(class_inputs.get("optional", {})) + ) for x in valid_inputs: - input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) + input_type, input_category, extra_info = get_input_info( + obj_class, x, class_inputs + ) assert extra_info is not None if x not in inputs: if input_category == "required": @@ -1211,9 +1533,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "type": "required_input_missing", "message": "Required input is missing", "details": details, - "extra_info": { - "input_name": x - } + "extra_info": {"input_name": x}, } errors.append(error) continue @@ -1229,18 +1549,21 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "extra_info": { "input_name": x, "input_config": info, - "received_value": val - } + "received_value": val, + }, } errors.append(error) continue o_id = val[0] - o_class_type = prompt[o_id]['class_type'] + o_class_type = prompt[o_id]["class_type"] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type): + if ( + "input_types" not in validate_function_inputs + and not validate_node_input(received_type, input_type) + ): details = f"{x}, received_type({received_type}) mismatch input_type({input_type})" error = { "type": "return_type_mismatch", @@ -1250,15 +1573,17 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "input_name": x, "input_config": info, "received_type": received_type, - "linked_node": val - } + "linked_node": val, + }, } errors.append(error) continue try: visiting.append(unique_id) try: - r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting) + r = await validate_inputs( + prompt_id, prompt, o_id, validated, visiting + ) finally: visiting.pop() if r[0] is False: @@ -1269,19 +1594,21 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): typ, _, tb = sys.exc_info() valid = False exception_type = full_type_name(typ) - reasons = [{ - "type": "exception_during_inner_validation", - "message": "Exception when validating inner node", - "details": str(ex), - "extra_info": { - "input_name": x, - "input_config": info, - "exception_message": str(ex), - "exception_type": exception_type, - "traceback": traceback.format_tb(tb), - "linked_node": val + reasons = [ + { + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", + "details": str(ex), + "extra_info": { + "input_name": x, + "input_config": info, + "exception_message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "linked_node": val, + }, } - }] + ] validated[o_id] = (False, reasons, o_id) continue else: @@ -1316,8 +1643,8 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "input_name": x, "input_config": info, "received_value": val, - "exception_message": str(ex) - } + "exception_message": str(ex), + }, } errors.append(error) continue @@ -1326,26 +1653,30 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): if "min" in extra_info and val < extra_info["min"]: error = { "type": "value_smaller_than_min", - "message": "Value {} smaller than min of {}".format(val, extra_info["min"]), + "message": "Value {} smaller than min of {}".format( + val, extra_info["min"] + ), "details": f"{x}", "extra_info": { "input_name": x, "input_config": info, "received_value": val, - } + }, } errors.append(error) continue if "max" in extra_info and val > extra_info["max"]: error = { "type": "value_bigger_than_max", - "message": "Value {} bigger than max of {}".format(val, extra_info["max"]), + "message": "Value {} bigger than max of {}".format( + val, extra_info["max"] + ), "details": f"{x}", "extra_info": { "input_name": x, "input_config": info, "received_value": val, - } + }, } errors.append(error) continue @@ -1380,7 +1711,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "input_name": x, "input_config": input_config, "received_value": val, - } + }, } errors.append(error) continue @@ -1391,10 +1722,17 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: input_filtered[x] = input_data_all[x] - if 'input_types' in validate_function_inputs: - input_filtered['input_types'] = [received_types] - - ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data) + if "input_types" in validate_function_inputs: + input_filtered["input_types"] = [received_types] + + ret = await _async_map_node_over_list( + prompt_id, + unique_id, + obj_class, + input_filtered, + validate_function_name, + v3_data=v3_data, + ) ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): @@ -1409,7 +1747,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): "details": details, "extra_info": { "input_name": x, - } + }, } errors.append(error) continue @@ -1425,18 +1763,22 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): validated[unique_id] = ret return ret + def full_type_name(klass): module = klass.__module__ - if module == 'builtins': + if module == "builtins": return klass.__qualname__ - return module + '.' + klass.__qualname__ + return module + "." + klass.__qualname__ + -async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]): +async def validate_prompt( + prompt_id, prompt, partial_execution_list: Union[list[str], None] +): outputs = set() for x in prompt: - if 'class_type' not in prompt[x]: + if "class_type" not in prompt[x]: node_data = prompt[x] - node_title = node_data.get('_meta', {}).get('title') + node_title = node_data.get("_meta", {}).get("title") error = { "type": "missing_node_type", "message": f"Node '{node_title or f'ID #{x}'}' has no class_type. The workflow may be corrupted or a custom node is missing.", @@ -1444,16 +1786,16 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ "extra_info": { "node_id": x, "class_type": None, - "node_title": node_title - } + "node_title": node_title, + }, } return (False, error, [], {}) - class_type = prompt[x]['class_type'] + class_type = prompt[x]["class_type"] class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) if class_ is None: node_data = prompt[x] - node_title = node_data.get('_meta', {}).get('title', class_type) + node_title = node_data.get("_meta", {}).get("title", class_type) error = { "type": "missing_node_type", "message": f"Node '{node_title}' not found. The custom node may not be installed.", @@ -1461,12 +1803,12 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ "extra_info": { "node_id": x, "class_type": class_type, - "node_title": node_title - } + "node_title": node_title, + }, } return (False, error, [], {}) - if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True: + if hasattr(class_, "OUTPUT_NODE") and class_.OUTPUT_NODE is True: if partial_execution_list is None or x in partial_execution_list: outputs.add(x) @@ -1475,7 +1817,7 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ "type": "prompt_no_outputs", "message": "Prompt has no outputs", "details": "", - "extra_info": {} + "extra_info": {}, } return (False, error, [], {}) @@ -1494,15 +1836,17 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ typ, _, tb = sys.exc_info() valid = False exception_type = full_type_name(typ) - reasons = [{ - "type": "exception_during_validation", - "message": "Exception when validating node", - "details": str(ex), - "extra_info": { - "exception_type": exception_type, - "traceback": traceback.format_tb(tb) + reasons = [ + { + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + }, } - }] + ] validated[o] = (False, reasons, o) if valid is True: @@ -1522,15 +1866,17 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ # So don't return those nodes as having errors in the response. if valid is not True and len(reasons) > 0: if node_id not in node_errors: - class_type = prompt[node_id]['class_type'] + class_type = prompt[node_id]["class_type"] node_errors[node_id] = { "errors": reasons, "dependent_outputs": [], - "class_type": class_type + "class_type": class_type, } logging.error(f"* {class_type} {node_id}:") for reason in reasons: - logging.error(f" - {reason['message']}: {reason['details']}") + logging.error( + f" - {reason['message']}: {reason['details']}" + ) node_errors[node_id]["dependent_outputs"].append(o) logging.error("Output will be ignored") @@ -1545,15 +1891,17 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ "type": "prompt_outputs_failed_validation", "message": "Prompt outputs failed validation", "details": errors_list, - "extra_info": {} + "extra_info": {}, } return (False, error, list(good_outputs), node_errors) return (True, None, list(good_outputs), node_errors) + MAXIMUM_HISTORY_SIZE = 10000 + class PromptQueue: def __init__(self, server): self.server = server @@ -1585,12 +1933,17 @@ def get(self, timeout=None): return (item, i) class ExecutionStatus(NamedTuple): - status_str: Literal['success', 'error'] + status_str: Literal["success", "error"] completed: bool messages: List[str] - def task_done(self, item_id, history_result, - status: Optional['PromptQueue.ExecutionStatus'], process_item=None): + def task_done( + self, + item_id, + history_result, + status: Optional["PromptQueue.ExecutionStatus"], + process_item=None, + ): with self.mutex: prompt = self.currently_running.pop(item_id) if len(self.history) > MAXIMUM_HISTORY_SIZE: @@ -1606,7 +1959,7 @@ def task_done(self, item_id, history_result, self.history[prompt[1]] = { "prompt": prompt, "outputs": {}, - 'status': status_dict, + "status": status_dict, } self.history[prompt[1]].update(history_result) self.server.queue_updated()