Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 106 additions & 16 deletions tutorials/plugandplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
import pylops
from pylops.config import set_ndarray_multiplication
from pylops.utils.metrics import snr

import pyproximal

Expand Down Expand Up @@ -72,7 +73,7 @@

###############################################################################
# At this point we create a denoiser instance using the BM3D algorithm and use
# as Plug-and-Play Prior to the ADMM, PG and HQS algorithms


def callback(x, xtrue, errhist):
Expand All @@ -84,14 +85,31 @@ def callback(x, xtrue, errhist):
# Proximal step size (1 / L, with L computed earlier in the tutorial) and
# base noise level used to scale the denoiser strength
tau = 1.0 / L
sigma = 0.05

# L2 data-fidelity term with warm-started inner iterations
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)

# BM3D denoiser used as the Plug-and-Play prior: its strength (sigma_psd)
# is proportional to the step size tau passed in by the solver at each call
denoiser = lambda x, tau: bm3d.bm3d(
    np.real(x), sigma_psd=sigma * tau, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING
)

# ADMM-PnP
# A fresh L2 term is created for each solver run (presumably to reset the
# warm-start state enabled by warm=True — confirm against pyproximal docs)
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)

errhistadmm = []  # error-norm history filled in by the callback
xpnpadmm = pyproximal.optimization.pnp.PlugAndPlay(
    l2,
    denoiser,
    x.shape,
    solver=pyproximal.optimization.primal.ADMM,
    tau=tau,
    x0=np.zeros(x.size),
    niter=40,
    show=True,
    callback=lambda xx: callback(xx, x.ravel(), errhistadmm),
)[0]
# the solver works on flattened arrays; reshape and keep the real part
xpnpadmm = np.real(xpnpadmm.reshape(x.shape))

# PG-Pnp
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)

errhistpg = []
xpnppg = pyproximal.optimization.pnp.PlugAndPlay(
l2,
Expand All @@ -107,39 +125,111 @@ def callback(x, xtrue, errhist):
)
xpnppg = np.real(xpnppg.reshape(x.shape))

# ADMM-PnP
errhistadmm = []
xpnpadmm = pyproximal.optimization.pnp.PlugAndPlay(
# HQS-PnP
# A fresh L2 term is created for this run (presumably to reset the
# warm-start state enabled by warm=True — confirm against pyproximal docs)
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)

# Per-iteration step sizes: geometric decay (continuation strategy), one
# value per iteration — progressively weakens the denoiser
tau_hqs = 1.0 / L * 0.99 ** (np.arange(40))
errhisthqs = []  # error-norm history filled in by the callback
xpnphqs = pyproximal.optimization.pnp.PlugAndPlay(
    l2,
    denoiser,
    x.shape,
    solver=pyproximal.optimization.primal.HQS,
    tau=tau_hqs,
    x0=np.zeros(x.size),
    niter=40,
    show=True,
    callback=lambda xx: callback(xx, x.ravel(), errhisthqs),
)[0]
# the solver works on flattened arrays; reshape and keep the real part
xpnphqs = np.real(xpnphqs.reshape(x.shape))

# Display the true model next to the three PnP reconstructions, each panel
# titled with its SNR with respect to the model
fig, axs = plt.subplots(1, 4, sharey=True, figsize=(15, 5))
panels = [
    (x, "Model"),
    (xpnpadmm, f"ADMM-PnP (SNR={snr(x, xpnpadmm):.2f} dB)"),
    (xpnppg, f"PG-PnP (SNR={snr(x, xpnppg):.2f} dB)"),
    (xpnphqs, f"HQS-PnP (SNR={snr(x, xpnphqs):.2f} dB)"),
]
for ax, (img, title) in zip(axs, panels):
    ax.imshow(img, vmin=0, vmax=1, cmap="gray")
    ax.set_title(title)
    ax.axis("tight")
plt.tight_layout()

###############################################################################
# Finally, the attentive reader may have noticed that in the HQS solver a
# continuation strategy was used for the `tau` parameter; whilst this is
# strictly needed for HQS to converge, there is a consensus in the literature
# that other solvers can also benefit from adopting the same strategy
# when used with a PnP prior. This can in fact be interpreted as reducing
# the strength of the denoiser as iterations progress and the estimate comes
# closer to the true solution.
#
# While our :func:`pyproximal.optimization.primal.ADMM` solver does not
# currently offer relaxation out-of-the-box, this can be achieved quite easily
# by creating an auxiliary `Denoiser` class with a `decay` parameter as
# shown below.


class Denoiser:
    """BM3D denoiser with per-iteration strength decay (relaxation).

    Wraps :func:`bm3d.bm3d` so that the effective denoising strength shrinks
    as iterations progress, emulating a continuation strategy for solvers
    that do not support a varying ``tau`` natively.

    Parameters
    ----------
    sigma : float
        Base noise standard deviation passed to BM3D.
    decay : array-like
        Multiplicative decay factors, one per expected solver iteration.
    """

    def __init__(self, sigma, decay):
        self.sigma = sigma
        self.decay = decay
        self.iiter = 0  # internal call counter, bumped at each denoise call

    def denoise(self, x, tau):
        """Denoise ``x`` with strength ``decay[iiter] * sigma * tau``."""
        # Clamp the counter so extra calls beyond len(decay) (e.g. a final
        # evaluation by the solver) reuse the last decay factor instead of
        # raising IndexError
        step = min(self.iiter, len(self.decay) - 1)
        xden = bm3d.bm3d(
            np.real(x),
            sigma_psd=self.decay[step] * self.sigma * tau,
            stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING,
        )
        self.iiter += 1
        return xden


# ADMM-PnP with relaxation: the Denoiser wrapper applies a geometric decay
# (0.99 ** iteration) to the denoising strength, emulating the continuation
# strategy used natively by HQS above.
# NOTE(review): the original span carried both pre- and post-change diff
# lines (a duplicated `callback=` keyword argument and a duplicated reshape
# line); resolved here to the post-change version.
denoiser = Denoiser(sigma, decay=0.99 ** (np.arange(40)))
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)

errhistadmm1 = []  # error-norm history filled in by the callback
xpnpadmm1 = pyproximal.optimization.pnp.PlugAndPlay(
    l2,
    denoiser.denoise,
    x.shape,
    solver=pyproximal.optimization.primal.ADMM,
    tau=tau,
    x0=np.zeros(x.size),
    niter=40,
    show=True,
    callback=lambda xx: callback(xx, x.ravel(), errhistadmm1),
)[0]
# the solver works on flattened arrays; reshape and keep the real part
xpnpadmm1 = np.real(xpnpadmm1.reshape(x.shape))

# Compare the model with the ADMM-PnP reconstructions without and with
# relaxation (the original span kept the superseded diff lines, drawing the
# old PG panels before overwriting them; resolved to the post-change version)
fig, axs = plt.subplots(1, 3, sharey=True, figsize=(15, 5))
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
axs[0].set_title("Model")
axs[0].axis("tight")
axs[1].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray")
axs[1].set_title(f"ADMM-PnP (SNR={snr(x, xpnpadmm):.2f} dB)")
axs[1].axis("tight")
axs[2].imshow(xpnpadmm1, vmin=0, vmax=1, cmap="gray")
axs[2].set_title(f"ADMM-PnP with rel. (SNR={snr(x, xpnpadmm1):.2f} dB)")
axs[2].axis("tight")
plt.tight_layout()

###############################################################################
# Let's finally compare the error convergence of the four variations of PnP
# (the original span kept the superseded linear-scale `plt.plot` diff lines
# alongside the new semilogy ones; resolved to the post-change version)

plt.figure(figsize=(12, 3))
plt.semilogy(errhistadmm, "k", lw=2, label="ADMM")
plt.semilogy(errhistpg, "r", lw=2, label="PG")
plt.semilogy(errhisthqs, "b", lw=2, label="HQS")
plt.semilogy(errhistadmm1, "--b", lw=2, label="ADMM with rel.")
plt.title("Error norm")
plt.legend()
plt.tight_layout()

###############################################################################
# This final result clearly shows the importance of relaxation also for ADMM.
Loading