From e9bb5f22db6b7f05e19955334a402cd0b4b7ea2a Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Tue, 23 Aug 2022 04:01:25 +0100 Subject: [PATCH 01/14] setup.py fix for stable diffusion --- setup.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 setup.py diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..35dca52 --- /dev/null +++ b/setup.py @@ -0,0 +1,24 @@ +from setuptools import setup, find_packages + +setup( + name='k-diffusion', + version='0.0.1', + description='Karras et al. (2022) diffusion models for PyTorch', + packages=find_packages(), + install_requires=[ + 'accelerate', + 'clean-fid', + 'einops', + 'jsonmerge', + 'kornia', + 'Pillow', + 'resize-right', + 'scikit-image', + 'scipy', + 'torch', + 'torchdiffeq', + 'torchvision', + 'tqdm', + 'wandb', + ], +) \ No newline at end of file From 7dd460995c589e6c5b348b61c1d45607b638731a Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Tue, 23 Aug 2022 04:18:02 +0100 Subject: [PATCH 02/14] Stable diffusion fix --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 54c5a95..e04de9d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ torchdiffeq torchvision tqdm wandb -git+https://github.com/openai/CLIP +git+https://github.com/openai/CLIP#egg=clip From 833b0916712ade1ee25848a630a246522a452a01 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 27 Aug 2022 14:42:08 +0100 Subject: [PATCH 03/14] implement torch.frac() on-GPU. 
https://github.com/pytorch/pytorch/issues/77764#issuecomment-1229193859 --- k_diffusion/external.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k_diffusion/external.py b/k_diffusion/external.py index 71c5b94..44b037d 100644 --- a/k_diffusion/external.py +++ b/k_diffusion/external.py @@ -68,7 +68,7 @@ def sigma_to_t(self, sigma, quantize=None): def t_to_sigma(self, t): t = t.float() - low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t-(t*t.sgn()).floor()*t.sgn() if t.device.type == 'mps' else t.frac() return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx] From af0641b11e6b1c3d478692801893699425aa428b Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 27 Aug 2022 21:10:25 +0100 Subject: [PATCH 04/14] deliberate fallback-to-CPU --- k_diffusion/external.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/k_diffusion/external.py b/k_diffusion/external.py index 44b037d..acccd73 100644 --- a/k_diffusion/external.py +++ b/k_diffusion/external.py @@ -59,7 +59,21 @@ def sigma_to_t(self, sigma, quantize=None): dists = torch.abs(sigma - self.sigmas[:, None]) if quantize: return torch.argmin(dists, dim=0).view(sigma.shape) - low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] + topk_indices = torch.topk(dists, dim=0, k=2, largest=False).indices + topk_indices_device=topk_indices.device + + # TODO: revert this once MPS supports aten::sort.values_stable. + # we're transferring the topk indices to CPU, sorting them, then transferring the result (sort_values) back to GPU. + # it's fine to sort on-CPU, because it's a wee little 2x2 matrix. + # PYTORCH_ENABLE_MPS_FALLBACK=1 would do the same thing. but I want us to be able to run without that. + # so that we find out any time a fallback is required, and can review whether it's consequential. 
+ must_sort_on_cpu = topk_indices_device.type == 'mps' + topk_indices = topk_indices.cpu() if must_sort_on_cpu else topk_indices + + sort_values = torch.sort(topk_indices, dim=0).values + sort_values = sort_values.to(topk_indices_device) if must_sort_on_cpu else sort_values + + low_idx, high_idx = sort_values low, high = self.sigmas[low_idx], self.sigmas[high_idx] w = (low - sigma) / (low - high) w = w.clamp(0, 1) From 35dc3b7b17684530623da16ad594688bf512b241 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sat, 27 Aug 2022 21:05:19 +0100 Subject: [PATCH 05/14] aten::sgn.out isn't implemented on MPS, so we'll have to use the simpler, faster solution which only works for positive numbers. our inputs are generated by a linspace between two positive numbers, so we should be fine. --- k_diffusion/external.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/k_diffusion/external.py b/k_diffusion/external.py index acccd73..63cc624 100644 --- a/k_diffusion/external.py +++ b/k_diffusion/external.py @@ -82,7 +82,8 @@ def sigma_to_t(self, sigma, quantize=None): def t_to_sigma(self, t): t = t.float() - low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t-(t*t.sgn()).floor()*t.sgn() if t.device.type == 'mps' else t.frac() + t_floor = t.floor() + low_idx, high_idx, w = t_floor.long(), t.ceil().long(), t-t_floor if t.device.type == 'mps' else t.frac() return (1 - w) * self.sigmas[low_idx] + w * self.sigmas[high_idx] From 556ba993ebd9a09f219c34b65e58c2a5c750ece0 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sun, 28 Aug 2022 19:01:16 +0100 Subject: [PATCH 06/14] fix sample_heun on MPS --- k_diffusion/sampling.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index b5a1c39..1a300ef 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -34,9 +34,13 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): return append_zero(sigmas) -def 
to_d(x, sigma, denoised): +def to_d(x, sigma, denoised, clone_please=False): """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / utils.append_dims(sigma, x.ndim) + coeff = utils.append_dims(sigma, x.ndim) + # for some reason, cloning coeff fixes a problem where values were returned as ±inf + # there's probably a better place to do the cloning than here, but this fixes sample_heun on MPS + coeff = coeff.detach().clone() if coeff.device.type == 'mps' and clone_please else coeff + return (x - denoised) / coeff def get_ancestral_step(sigma_from, sigma_to): @@ -109,7 +113,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, # Heun's method x_2 = x + d * dt denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) - d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2, clone_please=True) d_prime = (d + d_2) / 2 x = x + d_prime * dt return x From 8ebcd0098e67823ceee6499c34bb29e20704c1f3 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Sun, 28 Aug 2022 23:59:12 +0100 Subject: [PATCH 07/14] for Karras samplers: add time step discretization (as described in the Elucidating paper arXiv:2206.00364 section C.3.4 "practical challenge" 3) --- k_diffusion/sampling.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 1a300ef..56c96a1 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -1,9 +1,12 @@ import math +import numpy as np +from numpy.typing import _ArrayLikeFloat_co from scipy import integrate import torch from torchdiffeq import odeint from tqdm.auto import trange, tqdm +from typing import Optional from . 
import utils @@ -52,7 +55,7 @@ def get_ancestral_step(sigma_from, sigma_to): @torch.no_grad() -def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., quanta: Optional[_ArrayLikeFloat_co]=None): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -60,6 +63,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) + sigma_hat = sigma_hat if quanta is None else quanta[torch.argmin((quanta-sigma_hat).abs(), dim=0)] if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) @@ -91,7 +95,7 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis @torch.no_grad() -def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., quanta: Optional[_ArrayLikeFloat_co]=None): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -99,6 +103,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) + sigma_hat = sigma_hat if quanta is None else quanta[torch.argmin((quanta-sigma_hat).abs(), dim=0)] if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) @@ -120,7 +125,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): +def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., quanta: Optional[_ArrayLikeFloat_co]=None): """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -128,6 +133,7 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) + sigma_hat = sigma_hat if quanta is None else quanta[torch.argmin((quanta-sigma_hat).abs(), dim=0)] if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) From badc90ec06596368e031ce056fa4edfae8af5dda Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 29 Aug 2022 00:54:08 +0100 Subject: [PATCH 08/14] fix k_lms sampling on MPS backend --- k_diffusion/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 56c96a1..8928ee0 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -194,7 +194,7 @@ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, o ds = [] for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) - d = to_d(x, sigmas[i], denoised) + d = to_d(x, sigmas[i], denoised, clone_please=True) ds.append(d) if len(ds) > order: ds.pop(0) From dcd5fa6e0e89b7c0d6488dd37a454e0b669186ac Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Mon, 29 Aug 2022 14:44:54 +0100 Subject: [PATCH 09/14] when a model (e.g. stable diffusion) only supports discrete sigmas (arXiv:2206.00364 C.3.4), we don't want to generate an out-of-range sigma. make this optional. --- k_diffusion/sampling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 8928ee0..c3cc2db 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -15,13 +15,14 @@ def append_zero(x): return torch.cat([x, x.new_zeros([1])]) -def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'): +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu', concat_zero=True): """Constructs the noise schedule of Karras et al. 
(2022).""" ramp = torch.linspace(0, 1, n) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho - return append_zero(sigmas).to(device) + sigmas = sigmas.to(device) + return append_zero(sigmas) if concat_zero else sigmas def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): From 455d0af822f1f718ba190727cf3e72f25324382b Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 30 Aug 2022 22:13:00 +0100 Subject: [PATCH 10/14] generalize quantizer --- k_diffusion/sampling.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index c3cc2db..06abd1a 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -1,15 +1,25 @@ import math import numpy as np -from numpy.typing import _ArrayLikeFloat_co +from functools import partial from scipy import integrate import torch +from torch import Tensor from torchdiffeq import odeint from tqdm.auto import trange, tqdm -from typing import Optional +from typing import Optional, Callable, TypeAlias, Union from . 
import utils +TensorOperator: TypeAlias = Callable[[Tensor], Tensor] + +def _quantize(quanta: Tensor, candidate: Union[int, float, Tensor]) -> Tensor: + """Rounds `candidate` to the nearest element in `quanta`""" + return quanta[torch.argmin((quanta-candidate).abs(), dim=0)] + +def make_quantizer(quanta: Tensor) -> TensorOperator: + """Returns an monotype operator which accepts a single-element 1-dimensional Tensor, and rounds its element to the nearest element in `quanta`""" + return partial(_quantize, quanta) def append_zero(x): return torch.cat([x, x.new_zeros([1])]) @@ -56,7 +66,7 @@ def get_ancestral_step(sigma_from, sigma_to): @torch.no_grad() -def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., quanta: Optional[_ArrayLikeFloat_co]=None): +def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -64,7 +74,7 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) - sigma_hat = sigma_hat if quanta is None else quanta[torch.argmin((quanta-sigma_hat).abs(), dim=0)] + sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) @@ -96,7 +106,7 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis @torch.no_grad() -def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., quanta: Optional[_ArrayLikeFloat_co]=None): +def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None): """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -104,7 +114,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) - sigma_hat = sigma_hat if quanta is None else quanta[torch.argmin((quanta-sigma_hat).abs(), dim=0)] + sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) @@ -126,7 +136,7 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, @torch.no_grad() -def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., quanta: Optional[_ArrayLikeFloat_co]=None): +def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., decorate_sigma_hat: Optional[TensorOperator] = None): """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" extra_args = {} if extra_args is None else extra_args s_in = x.new_ones([x.shape[0]]) @@ -134,7 +144,7 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. 
eps = torch.randn_like(x) * s_noise sigma_hat = sigmas[i] * (gamma + 1) - sigma_hat = sigma_hat if quanta is None else quanta[torch.argmin((quanta-sigma_hat).abs(), dim=0)] + sigma_hat = decorate_sigma_hat(sigma_hat) if callable(decorate_sigma_hat) else sigma_hat if gamma > 0: x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 denoised = model(x, sigma_hat * s_in, **extra_args) From 7d1611288e1c3d686344a58c6837cfc87792be07 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 30 Aug 2022 22:16:26 +0100 Subject: [PATCH 11/14] move utility part to utils --- k_diffusion/sampling.py | 8 ++------ k_diffusion/utils.py | 7 ++++++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 06abd1a..3c0a1ff 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -7,19 +7,15 @@ from torch import Tensor from torchdiffeq import odeint from tqdm.auto import trange, tqdm -from typing import Optional, Callable, TypeAlias, Union +from typing import Optional, Callable, TypeAlias from . 
import utils TensorOperator: TypeAlias = Callable[[Tensor], Tensor] -def _quantize(quanta: Tensor, candidate: Union[int, float, Tensor]) -> Tensor: - """Rounds `candidate` to the nearest element in `quanta`""" - return quanta[torch.argmin((quanta-candidate).abs(), dim=0)] - def make_quantizer(quanta: Tensor) -> TensorOperator: """Returns an monotype operator which accepts a single-element 1-dimensional Tensor, and rounds its element to the nearest element in `quanta`""" - return partial(_quantize, quanta) + return partial(utils.quantize, quanta) def append_zero(x): return torch.cat([x, x.new_zeros([1])]) diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py index 8d700c2..9cb8b07 100644 --- a/k_diffusion/utils.py +++ b/k_diffusion/utils.py @@ -7,8 +7,9 @@ import warnings import torch -from torch import optim +from torch import optim, Tensor from torchvision.transforms import functional as TF +from typing import Union def from_pil_image(x): @@ -249,3 +250,7 @@ def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.floa min_value = math.log(min_value) max_value = math.log(max_value) return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp() + +def quantize(quanta: Tensor, candidate: Union[int, float, Tensor]) -> Tensor: + """Rounds `candidate` to the nearest element in `quanta`""" + return quanta[torch.argmin((quanta-candidate).abs(), dim=0)] From 72ce8f1bdd92449ce85680f5f1a7289ee9b6dbc3 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Tue, 30 Aug 2022 23:24:40 +0100 Subject: [PATCH 12/14] ramp further if you're not getting a zero for free --- k_diffusion/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 3c0a1ff..4943ec5 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -23,7 +23,7 @@ def append_zero(x): def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu', concat_zero=True): 
"""Constructs the noise schedule of Karras et al. (2022).""" - ramp = torch.linspace(0, 1, n) + ramp = torch.linspace(0, 1, n if concat_zero else n+1) min_inv_rho = sigma_min ** (1 / rho) max_inv_rho = sigma_max ** (1 / rho) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho From 6e5c8a77edc62e75414ad850cb0a6f7ddceea0d4 Mon Sep 17 00:00:00 2001 From: Alex Birch Date: Wed, 31 Aug 2022 12:22:27 +0100 Subject: [PATCH 13/14] fix sample_euler_ancestral for MPS --- k_diffusion/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 4943ec5..3df824a 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -93,7 +93,7 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - d = to_d(x, sigmas[i], denoised) + d = to_d(x, sigmas[i], denoised, clone_please=True) # Euler method dt = sigma_down - sigmas[i] x = x + d * dt From bd00ffefb6e7212806e1653fc2a60a35618e918d Mon Sep 17 00:00:00 2001 From: Chris Hua Date: Thu, 1 Sep 2022 16:50:30 -0700 Subject: [PATCH 14/14] chore: backport TypeAlias for py3.9 compat --- k_diffusion/sampling.py | 6 +++++- requirements.txt | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 3df824a..49df669 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -7,7 +7,11 @@ from torch import Tensor from torchdiffeq import odeint from tqdm.auto import trange, tqdm -from typing import Optional, Callable, TypeAlias +from typing import Optional, Callable +try: + from typing import TypeAlias +except ImportError: + from typing_extensions import TypeAlias from . 
import utils diff --git a/requirements.txt b/requirements.txt index e04de9d..bcdb6e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ torchvision tqdm wandb git+https://github.com/openai/CLIP#egg=clip +typing-extensions \ No newline at end of file