diff --git a/tutorials/plugandplay.py b/tutorials/plugandplay.py index 02db781..b0186d7 100644 --- a/tutorials/plugandplay.py +++ b/tutorials/plugandplay.py @@ -27,6 +27,7 @@ import numpy as np import pylops from pylops.config import set_ndarray_multiplication +from pylops.utils.metrics import snr import pyproximal @@ -72,7 +73,7 @@ ############################################################################### # At this point we create a denoiser instance using the BM3D algorithm and use -# as Plug-and-Play Prior to the PG and ADMM algorithms +# as Plug-and-Play Prior to the ADMM, PG and HQS algorithms def callback(x, xtrue, errhist): @@ -84,14 +85,31 @@ def callback(x, xtrue, errhist): tau = 1.0 / L sigma = 0.05 -l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True) - # BM3D denoiser denoiser = lambda x, tau: bm3d.bm3d( np.real(x), sigma_psd=sigma * tau, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING ) +# ADMM-PnP +l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True) + +errhistadmm = [] +xpnpadmm = pyproximal.optimization.pnp.PlugAndPlay( + l2, + denoiser, + x.shape, + solver=pyproximal.optimization.primal.ADMM, + tau=tau, + x0=np.zeros(x.size), + niter=40, + show=True, + callback=lambda xx: callback(xx, x.ravel(), errhistadmm), +)[0] +xpnpadmm = np.real(xpnpadmm.reshape(x.shape)) + # PG-Pnp +l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True) + errhistpg = [] xpnppg = pyproximal.optimization.pnp.PlugAndPlay( l2, @@ -107,39 +125,111 @@ def callback(x, xtrue, errhist): ) xpnppg = np.real(xpnppg.reshape(x.shape)) -# ADMM-PnP -errhistadmm = [] -xpnpadmm = pyproximal.optimization.pnp.PlugAndPlay( +# HQS-PnP +l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True) + +tau_hqs = 1.0 / L * 0.99 ** (np.arange(40)) +errhisthqs = [] +xpnphqs = pyproximal.optimization.pnp.PlugAndPlay( l2, denoiser, x.shape, + solver=pyproximal.optimization.primal.HQS, + tau=tau_hqs, + x0=np.zeros(x.size), + niter=40, + show=True, + callback=lambda xx: callback(xx, x.ravel(), errhisthqs), +)[0] +xpnphqs = np.real(xpnphqs.reshape(x.shape)) + +fig, axs = plt.subplots(1, 4, sharey=True, figsize=(15, 5)) +axs[0].imshow(x, vmin=0, vmax=1, cmap="gray") +axs[0].set_title("Model") +axs[0].axis("tight") +axs[1].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray") +axs[1].set_title(f"ADMM-PnP (SNR={snr(x, xpnpadmm):.2f} dB)") +axs[1].axis("tight") +axs[2].imshow(xpnppg, vmin=0, vmax=1, cmap="gray") +axs[2].set_title(f"PG-PnP (SNR={snr(x, xpnppg):.2f} dB)") +axs[2].axis("tight") +axs[3].imshow(xpnphqs, vmin=0, vmax=1, cmap="gray") +axs[3].set_title(f"HQS-PnP (SNR={snr(x, xpnphqs):.2f} dB)") +axs[3].axis("tight") +plt.tight_layout() + +############################################################################### +# Finally, the attentive reader may have noticed that in the HQS server a +# continuation strategy was used for the `tau` parameter; whilst this is +# strictly needed for HQS to converge, there is a consensus in the literature +# that also other solvers should benefit from adopting the same strategy +# when used with a PnP prior. This can be in fact interpreted as reducing +# the strength of the denoiser as iterations progress and the estimate comes +# closer to the true solution. +# +# While our :func:`pyproximal.optimization.primal.ADMM` solver does currently +# not offer relaxation out-of-the-box, this can be achieved pretty easily +# by creating an auxiliary `Denoiser` class with a `decay` parameter as +# shown below. + + +class Denoiser: + def __init__(self, sigma, decay): + self.sigma = sigma + self.decay = decay + self.iiter = 0 + + def denoise(self, x, tau): + xden = bm3d.bm3d( + np.real(x), + sigma_psd=self.decay[self.iiter] * self.sigma * tau, + stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING, + ) + self.iiter += 1 + return xden + + +# ADMM-PnP with relaxation +denoiser = Denoiser(sigma, decay=0.99 ** (np.arange(40))) +l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True) + +errhistadmm1 = [] +xpnpadmm1 = pyproximal.optimization.pnp.PlugAndPlay( + l2, + denoiser.denoise, + x.shape, solver=pyproximal.optimization.primal.ADMM, tau=tau, x0=np.zeros(x.size), niter=40, show=True, - callback=lambda xx: callback(xx, x.ravel(), errhistadmm), + callback=lambda xx: callback(xx, x.ravel(), errhistadmm1), )[0] -xpnpadmm = np.real(xpnpadmm.reshape(x.shape)) +xpnpadmm1 = np.real(xpnpadmm1.reshape(x.shape)) -fig, axs = plt.subplots(1, 3, figsize=(14, 5)) +fig, axs = plt.subplots(1, 3, sharey=True, figsize=(15, 5)) axs[0].imshow(x, vmin=0, vmax=1, cmap="gray") axs[0].set_title("Model") axs[0].axis("tight") -axs[1].imshow(xpnppg, vmin=0, vmax=1, cmap="gray") -axs[1].set_title("PG-PnP Inversion") +axs[1].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray") +axs[1].set_title(f"ADMM-PnP (SNR={snr(x, xpnpadmm):.2f} dB)") axs[1].axis("tight") -axs[2].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray") -axs[2].set_title("ADMM-PnP Inversion") +axs[2].imshow(xpnpadmm1, vmin=0, vmax=1, cmap="gray") +axs[2].set_title(f"ADMM-PnP with rel. (SNR={snr(x, xpnpadmm1):.2f} dB)") axs[2].axis("tight") plt.tight_layout() ############################################################################### -# Finally, let's compare the error convergence of the two variations of PnP +# Let's finally compare the error convergence of the four variations of PnP plt.figure(figsize=(12, 3)) -plt.plot(errhistpg, "k", lw=2, label="PG") -plt.plot(errhistadmm, "r", lw=2, label="ADMM") +plt.semilogy(errhistadmm, "k", lw=2, label="ADMM") +plt.semilogy(errhistpg, "r", lw=2, label="PG") +plt.semilogy(errhisthqs, "b", lw=2, label="HQS") +plt.semilogy(errhistadmm1, "--b", lw=2, label="ADMM with rel.") plt.title("Error norm") plt.legend() plt.tight_layout() + +############################################################################### +# This final results clearly shows the importance of relaxation also for ADMM.