diff --git a/.gitignore b/.gitignore index 6b0eac38..0579fc6f 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ spyrit/drunet/ !spyrit/images/tuto/*.png docs/source/html docs/source/_autosummary +docs/source/_templates docs/source/_static docs/source/api docs/source/gallery diff --git a/docs/source/fig/drunet.png b/docs/source/fig/drunet.png new file mode 100644 index 00000000..903521a4 Binary files /dev/null and b/docs/source/fig/drunet.png differ diff --git a/spyrit/external/drunet.py b/spyrit/external/drunet.py index 7a69b1f3..3612f60e 100644 --- a/spyrit/external/drunet.py +++ b/spyrit/external/drunet.py @@ -116,6 +116,63 @@ def forward(self, x0): return x +class DRUNet(UNetRes): + def __init__( + self, + noise_level=5, + n_channels=1, + nc=[64, 128, 256, 512], + nb=4, + act_mode="R", + downsample_mode="strideconv", + upsample_mode="convtranspose", + ): + super(DRUNet, self).__init__( + n_channels + 1, n_channels, nc, nb, act_mode, downsample_mode, upsample_mode + ) + self.register_buffer("noise_level", torch.FloatTensor([noise_level / 255.0])) + + def forward(self, x): + # Image domain denoising + x = self.concat_noise_map(x) + + # Pass input images through the network + x = super(DRUNet, self).forward(x) + return x + + def concat_noise_map(self, x): + r"""Concatenation of noise level map to reconstructed images + + Args: + :attr:`x`: reconstructed images from the reconstruction layer + + Shape: + :attr:`x`: reconstructed images with shape :math:`(BC,1,H,W)` + + :attr:`output`: reconstructed images with concatenated noise level map with shape :math:`(BC,2,H,W)` + """ + + b, c, h, w = x.shape + x = 0.5 * (x + 1) + x = torch.cat((x, self.noise_level.expand(b, 1, h, w)), dim=1) + return x + + def set_noise_level(self, noise_level): + r"""Reset noise level value + + Args: + :attr:`noise_level`: noise level value in the range [0, 255] + + Shape: + :attr:`noise_level`: float value noise level :math:`(1)` + + :attr:`output`: noise level tensor with shape :math:`(1)` + """ + self.noise_level = torch.FloatTensor([noise_level / 255.0]).to( + self.noise_level.device + ) + + # ---------------------------------------------- # Functions taken from basicblock.py # https://github.com/cszn/DPIR/tree/master/models diff --git a/tutorial/tuto_06_dcnet_split_measurements.py b/tutorial/tuto_06_dcnet_split_measurements.py index 16f8dc06..a4e874d2 100644 --- a/tutorial/tuto_06_dcnet_split_measurements.py +++ b/tutorial/tuto_06_dcnet_split_measurements.py @@ -1,83 +1,81 @@ #!/usr/bin/env python3 r""" -06. DCNet solution for split measurements +========================================= +06. Denoised Completion Network (DCNet) ========================================= .. _tuto_dcnet_split_measurements: -This tutorial shows how to perform image reconstruction using DCNet (denoised -completion network) with -and without a trainable image denoiser. In the previous tutorial -:ref:`Acquisition - split measurements ` -we showed how to handle split measurements for a Hadamard operator -and how to perform a pseudo-inverse reconstruction with PinvNet. +This tutorial shows how to perform image reconstruction using the denoised completion network (DCNet) with a trainable image denoiser. In the next tutorial, we will plug a denoiser into a DCNet, which requires no training. -.. image:: ../fig/tuto6.png +.. figure:: ../fig/tuto6.png :width: 600 :align: center :alt: Reconstruction and neural network denoising architecture sketch using split measurements -These tutorials load image samples from `/images/`. 
""" +###################################################################### +# .. note:: +# +# As in the previous tutorials, we consider a split Hadamard operator and measurements corrupted by Poisson noise (see :ref:`Tutorial 5 `). + # %% # Load a batch of images -# ----------------------------------------------------------------------------- +# ========================================= -############################################################################### -# Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized -# using the :func:`transform_gray_norm` function. - -import os - -import torch -import torchvision -import numpy as np -import matplotlib.pyplot as plt - -from spyrit.misc.disp import imagesc -from spyrit.misc.statistics import transform_gray_norm +###################################################################### +# Update search path # sphinx_gallery_thumbnail_path = 'fig/tuto6.png' +import os -h = 64 # image size hxh -i = 1 # Image index (modify to change the image) spyritPath = os.getcwd() imgs_path = os.path.join(spyritPath, "images/") +###################################################################### +# Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized and resized using the :func:`transform_gray_norm` function. +from spyrit.misc.statistics import transform_gray_norm -# Create a transform for natural images to normalized grayscale image tensors +h = 64 # image is resized to h x h transform = transform_gray_norm(img_size=h) -# Create dataset and loader (expects class folder 'images/test/') +###################################################################### +# Create a data loader from some dataset (images must be in the folder `images/test/`) +import torch +import torchvision + dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) x, _ = next(iter(dataloader)) print(f"Shape of input images: {x.shape}") -# Select image +###################################################################### +# Select the `i`-th image in the batch +i = 1 # Image index (modify to change the image) x = x[i : i + 1, :, :, :] x = x.detach().clone() b, c, h, w = x.shape -# plot +###################################################################### +# Plot the selected image +from spyrit.misc.disp import imagesc + x_plot = x.view(-1, h, h).cpu().numpy() imagesc(x_plot[0, :, :], r"$x$ in [-1, 1]") # %% # Forward operators for split measurements -# ----------------------------------------------------------------------------- +# ========================================= -############################################################################### -# We consider noisy split measurements for a Hadamard operator and a -# "variance subsampling" strategy that preserves the coefficients with the largest variance, -# obtained from a previously estimated covariance matrix (for more details, -# refer to :ref:`Acquisition - split measurements `). +###################################################################### +# We consider noisy measurements obtained from a split Hadamard operator, and a subsampling strategy that retaines the coefficients with the largest variance (for more details, refer to :ref:`Tutorial 5 `). -############################################################################### -# First, we download the covariance matrix and load it. 
+######################################################################
+# First, we download the covariance matrix from our warehouse.

 import girder_client
+import numpy as np

 # api Rest url of the warehouse
 url = "https://pilot-warehouse.creatis.insa-lyon.fr/api/v1"
@@ -109,11 +107,8 @@
     Cov = np.eye(h * h)
     print(f"Cov matrix {cov_name} not found! Set to the identity")

-###############################################################################
-# We define the measurement, noise and preprocessing operators and then
-# simulate a noiseless measurement vector :math:`y`. As in the previous tutorial,
-# we simulate an accelerated acquisition by subsampling the measurement matrix
-# by retaining only the first :math:`M` rows of a Hadamard matrix :math:`\textrm{Perm} H`.
+######################################################################
+# We define the measurement, noise and preprocessing operators and then simulate a measurement vector corrupted by Poisson noise. As in the previous tutorials, we simulate an accelerated acquisition by subsampling the measurement matrix, i.e., by retaining only the first rows of a Hadamard matrix whose rows are permuted according to the diagonal of the covariance matrix.

 from spyrit.core.meas import HadamSplit
 from spyrit.core.noise import Poisson
@@ -125,11 +120,9 @@
 M = 64 * 64 // 4  # Number of measurements (here, 1/4 of the pixels)
 alpha = 100.0  # number of photons

-# Ordering matrix
-Ord = Cov2Var(Cov)
-
 # Measurement and noise operators
-meas_op = HadamSplit(M, h, torch.from_numpy(Ord))
+Ord = Cov2Var(Cov)
+meas_op = HadamSplit(M, h, Ord)
 noise_op = Poisson(meas_op, alpha)
 prep_op = SplitPoisson(alpha, meas_op)
@@ -146,88 +139,99 @@
 imagesc(m_plot, r"Measurements $m$")

 # %%
-# PinvNet network
-# -----------------------------------------------------------------------------
+# Pseudo inverse solution
+# =========================================
+
+######################################################################
+# We compute the pseudo inverse solution using the :class:`spyrit.core.recon.PinvNet` class as in the previous tutorial.

-###############################################################################
-# We reconstruct with the pseudo inverse using :class:`spyrit.core.recon.PinvNet` class
-# as in the previous tutorial. For this, we define the neural network and then perform the reconstruction.
+# Instantiate a PinvNet (with no denoising by default)
 from spyrit.core.recon import PinvNet
-from spyrit.misc.disp import add_colorbar, noaxis

-# Reconstruction with for Core module (linear net)
 pinvnet = PinvNet(noise_op, prep_op)

-# use GPU, if available
+# Use GPU, if available
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-# Pseudo-inverse net
 pinvnet = pinvnet.to(device)
+y = y.to(device)

 # Reconstruction
 with torch.no_grad():
-    z_invnet = pinvnet.reconstruct(y.to(device))  # reconstruct from raw measurements
+    z_invnet = pinvnet.reconstruct(y)

 # %%
-# DCNet network
-# -----------------------------------------------------------------------------
+# Denoised completion network (DCNet)
+# =========================================
+
+######################################################################
+# .. image:: ../fig/dcnet.png
+#    :width: 400
+#    :align: center
+#    :alt: Sketch of the DCNet architecture

-###############################################################################
-# We can improve PinvNet results by using the *denoised* completion network DCNet with the
-# :class:`spyrit.core.recon.DCNet` class.
It has four sequential steps:

+######################################################################
+# The DCNet is based on four sequential steps:
 #
-# i) denoising of the acquired measurements,
+# i) Denoising in the measurement domain.
 #
-# ii) estimation of the missing measurements from the denoised ones,
+# ii) Estimation of the missing measurements from the denoised ones.
 #
-# iii) mapping them to the image domain, and
+# iii) Image-domain mapping.
 #
-# iv) denoising in the image-domain.
+# iv) (Learned) Denoising in the image domain.
 #
-# Only the last step involves learnable parameters.
+# Typically, only the last step involves learnable parameters.

-###############################################################################
-# .. image:: ../fig/dcnet.png
-#    :width: 400
-#    :align: center
-#    :alt: Sketch of the DCNet architecture

-###############################################################################
-# For the denoiser, we compare the default unit matrix (no denoising) with the UNet denoiser
-# with the :class:`spyrit.core.nnet.Unet` class. For the latter, we load the pretrained model
-# weights.
+# %%
+# Denoised completion
+# =========================================
+
+######################################################################
+# The first three steps implement denoised completion, which corresponds to Tikhonov regularization. Considering linear measurements :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is the unknown image, it estimates :math:`x` from :math:`y` by minimizing
+#
+# .. math::
+#    \| y - Hx \|^2_{\Sigma^{-1}_\alpha} + \|x\|^2_{\Sigma^{-1}},
+#
+# where :math:`\Sigma` is a covariance prior and :math:`\Sigma_\alpha` is the noise covariance. Denoised completion can be performed using the :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag` class (see documentation for more details).

-###############################################################################
-# Without *learnable image-domain* denoising
+######################################################################
+# In practice, it is more convenient to use the :class:`spyrit.core.recon.DCNet` class, which relies on a forward operator, a preprocessing operator, and a covariance prior.

 from spyrit.core.recon import DCNet
-from spyrit.core.nnet import Unet
-from torch import nn

-# Reconstruction with for DCNet (linear net)
-dcnet = DCNet(noise_op, prep_op, torch.from_numpy(Cov), denoi=nn.Identity())
+dcnet = DCNet(noise_op, prep_op, torch.from_numpy(Cov))
+
+# Use GPU, if available
 dcnet = dcnet.to(device)
+y = y.to(device)

-# Reconstruction
 with torch.no_grad():
-    z_dcnet = dcnet.reconstruct(y.to(device))  # reconstruct from raw measurements
+    z_dcnet = dcnet.reconstruct(y)

-###############################################################################
-# With a UNet denoising layer, we define the denoising network and
-# then load the pretrained weights.
+######################################################################
+# .. note::
+#    In this tutorial, the covariance matrix used to define subsampling is also used as prior knowledge during reconstruction.
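+
+######################################################################
+# To make the Tikhonov step concrete, here is a minimal NumPy sketch (for illustration only, not spyrit's implementation) of the closed-form minimizer of the cost above, :math:`\hat{x} = \Sigma H^T (H \Sigma H^T + \Sigma_\alpha)^{-1} y`, on a toy problem. The names ``H_toy``, ``Sigma_toy`` and ``Sigma_alpha_toy`` are hypothetical stand-ins for the Hadamard operator, the covariance prior and the noise covariance that are handled internally by :class:`spyrit.core.recon.DCNet`.
+
+rng = np.random.default_rng(0)
+n_toy, m_toy = 32, 8  # toy signal length and number of measurements
+H_toy = rng.standard_normal((m_toy, n_toy))  # toy measurement matrix
+Sigma_toy = np.eye(n_toy)  # toy signal covariance prior
+Sigma_alpha_toy = 0.01 * np.eye(m_toy)  # toy noise covariance (std = 0.1)
+
+x_toy = rng.standard_normal(n_toy)  # toy ground-truth signal
+y_toy = H_toy @ x_toy + 0.1 * rng.standard_normal(m_toy)  # noisy toy measurements
+
+# Closed-form minimizer of ||y - Hx||^2_{Sigma_alpha^{-1}} + ||x||^2_{Sigma^{-1}}
+x_hat_toy = Sigma_toy @ H_toy.T @ np.linalg.solve(
+    H_toy @ Sigma_toy @ H_toy.T + Sigma_alpha_toy, y_toy
+)
+print(f"Toy Tikhonov relative error: {np.linalg.norm(x_hat_toy - x_toy) / np.linalg.norm(x_toy):.2f}")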
-from spyrit.core.train import load_net
-import matplotlib.pyplot as plt
-from spyrit.misc.disp import add_colorbar, noaxis

-# Define UNet denoiser
-denoi = Unet()

+# %%
+# (Learned) Denoising in the image domain
+# =========================================
+
+######################################################################
+# To implement denoising in the image domain, we provide a :class:`spyrit.core.nnet.Unet` denoiser to a :class:`spyrit.core.recon.DCNet`.

-# Define DCNet (with UNet denoising)
+from spyrit.core.nnet import Unet
+
+denoi = Unet()
 dcnet_unet = DCNet(noise_op, prep_op, torch.from_numpy(Cov), denoi)
-dcnet_unet = dcnet_unet.to(device)
+dcnet_unet = dcnet_unet.to(device)  # Use GPU, if available
+
+########################################################################
+# We load pretrained weights for the UNet
+
+from spyrit.core.train import load_net

-# Load previously trained model
 # Download weights
 url_unet = "https://drive.google.com/file/d/15PRRZj5OxKpn1iJw78lGwUUBtTbFco1l/view?usp=drive_link"
 model_path = "./model"
@@ -254,52 +258,54 @@ load_net(model_unet_path, dcnet_unet, device, False)
 # print(f"Model {model_unet_path} loaded.")

-
-# Reconstruction
+######################################################################
+# We reconstruct the image
 with torch.no_grad():
-    z_dcnet_unet = dcnet_unet.reconstruct(
-        y.to(device)
-    )  # reconstruct from raw measurements
+    z_dcnet_unet = dcnet_unet.reconstruct(y)
+
+# %%
+# Results
+# =========================================

-###############################################################################
-# We plot all results
+import matplotlib.pyplot as plt
+from spyrit.misc.disp import add_colorbar, noaxis

-# plot reconstruction side by side
 x_plot = x.view(-1, h, h).cpu().numpy()
 x_plot2 = z_invnet.view(-1, h, h).cpu().numpy()
 x_plot3 = z_dcnet.view(-1, h, h).cpu().numpy()
 x_plot4 = z_dcnet_unet.view(-1, h, h).cpu().numpy()

 f, axs = plt.subplots(2, 2, figsize=(10, 10))
+
+# Plot the ground-truth image
 im1 = axs[0, 0].imshow(x_plot[0, :, :], cmap="gray")
 axs[0, 0].set_title("Ground-truth image", fontsize=16)
 noaxis(axs[0, 0])
 add_colorbar(im1, "bottom")

+# Plot the pseudo inverse solution
 im2 = axs[0, 1].imshow(x_plot2[0, :, :], cmap="gray")
-axs[0, 1].set_title("PinvNet", fontsize=16)
+axs[0, 1].set_title("Pseudo inverse", fontsize=16)
 noaxis(axs[0, 1])
 add_colorbar(im2, "bottom")

+# Plot the solution obtained from denoised completion
 im3 = axs[1, 0].imshow(x_plot3[0, :, :], cmap="gray")
-axs[1, 0].set_title(f"DCNet (without denoising)", fontsize=16)
+axs[1, 0].set_title("Denoised completion", fontsize=16)
 noaxis(axs[1, 0])
 add_colorbar(im3, "bottom")

+# Plot the solution obtained from denoised completion with UNet denoising
 im4 = axs[1, 1].imshow(x_plot4[0, :, :], cmap="gray")
-axs[1, 1].set_title(f"DCNet (UNet denoising)", fontsize=16)
+axs[1, 1].set_title("Denoised completion with UNet denoising", fontsize=16)
 noaxis(axs[1, 1])
 add_colorbar(im4, "bottom")

 plt.show()

-###############################################################################
-# Comparing results, PinvNet provides pixelized reconstruction, DCNet with no denoising
-# leads to a smoother reconstruction, as expected by a Tikonov regularization, and
-# DCNet with UNet denoising provides the best reconstruction.
+######################################################################
+# .. note::
+#    While the pseudo inverse reconstruction is pixelized, the solution obtained by denoised completion is smoother.
+#    DCNet with UNet denoising in the image domain provides the best reconstruction.

-###############################################################################
+######################################################################
 # .. note::
-#
-# In this tutorial, we have used DCNet with a UNet denoising layer for split measurements.
-# We refer to `spyrit-examples tutorials `_
-# for a comparison of different solutions for split measurements (pinvNet, DCNet and DRUNet).
+#    We refer to `spyrit-examples tutorials `_ for a comparison of different solutions (PinvNet, DCNet and DRUNet) that can be run in Colab.
diff --git a/tutorial/tuto_07_drunet_split_measurements.py b/tutorial/tuto_07_drunet_split_measurements.py
new file mode 100644
index 00000000..744b30cf
--- /dev/null
+++ b/tutorial/tuto_07_drunet_split_measurements.py
@@ -0,0 +1,336 @@
+r"""
+======================================================================
+07. DCNet with plug-and-play DRUNet denoising
+======================================================================
+.. _tuto_dcdrunet_split_measurements:
+
+This tutorial shows how to perform image reconstruction using a DCNet (denoised completion network) that includes a `DRUNet denoiser `_. DRUNet is a plug-and-play denoising network that has been pretrained for a wide range of noise levels and takes the noise level as an input. Contrary to the DCNet described in :ref:`Tutorial 6 `, it requires no training.
+"""
+
+######################################################################
+# .. figure:: ../fig/drunet.png
+#    :width: 600
+#    :align: center
+#    :alt: DCNet with DRUNet denoising in the image domain
+
+######################################################################
+# .. note::
+#
+#    As in the previous tutorials, we consider a split Hadamard operator and measurements corrupted by Poisson noise (see :ref:`Tutorial 5 `).
+
+import numpy as np
+import os
+from spyrit.misc.disp import imagesc
+import matplotlib.pyplot as plt
+
+
+# %%
+# Load a batch of images
+# ====================================================================
+
+######################################################################
+# Neural networks expect images :math:`x` with values in [-1, 1]. The images are normalized using the :func:`transform_gray_norm` function.
+
+# sphinx_gallery_thumbnail_path = 'fig/drunet.png'
+
+from spyrit.misc.statistics import transform_gray_norm
+import torchvision
+import torch
+
+h = 64  # image size hxh
+i = 1  # Image index (modify to change the image)
+spyritPath = os.getcwd()
+imgs_path = os.path.join(spyritPath, "images")
+
+
+# Create a transform for natural images to normalized grayscale image tensors
+transform = transform_gray_norm(img_size=h)
+
+# Create dataset and loader (expects class folder 'images/test/')
+dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
+dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
+
+x, _ = next(iter(dataloader))
+print(f"Shape of input images: {x.shape}")
+
+# Select image
+x = x[i : i + 1, :, :, :]
+x = x.detach().clone()
+b, c, h, w = x.shape
+
+# plot
+x_plot = x.view(-1, h, h).cpu().numpy()
+imagesc(x_plot[0, :, :], r"$x$ in [-1, 1]")
+
+# %%
+# Operators for split measurements
+# ====================================================================
+
+######################################################################
+# We consider noisy measurements obtained from a split Hadamard operator, and a subsampling strategy that retains the coefficients with the largest variance (for more details, refer to :ref:`Tutorial 5 `).
+
+######################################################################
+# First, we download the covariance matrix from our warehouse.
+
+import girder_client
+
+# api Rest url of the warehouse
+url = "https://pilot-warehouse.creatis.insa-lyon.fr/api/v1"
+
+# Generate the warehouse client
+gc = girder_client.GirderClient(apiUrl=url)
+
+# Download the covariance matrix and mean image
+data_folder = "./stat/"
+dataId_list = [
+    "63935b624d15dd536f0484a5",  # for reconstruction (imageNet, 64)
+    "63935a224d15dd536f048496",  # for reconstruction (imageNet, 64)
+]
+cov_name = "./stat/Cov_64x64.npy"
+
+try:
+    for dataId in dataId_list:
+        myfile = gc.getFile(dataId)
+        gc.downloadFile(dataId, data_folder + myfile["name"])
+
+    print(f"Created {data_folder}")
+
+    Cov = np.load(cov_name)
+    print(f"Cov matrix {cov_name} loaded")
+except:
+    Cov = np.eye(h * h)
+    print(f"Cov matrix {cov_name} not found! Set to the identity")
+
+######################################################################
+# We define the measurement, noise and preprocessing operators and then simulate a measurement vector corrupted by Poisson noise. As in the previous tutorials, we simulate an accelerated acquisition by subsampling the measurement matrix, i.e., by retaining only the first rows of a Hadamard matrix whose rows are permuted according to the diagonal of the covariance matrix.
+
+
+from spyrit.core.meas import HadamSplit
+from spyrit.core.noise import Poisson
+from spyrit.misc.sampling import meas2img2
+from spyrit.misc.statistics import Cov2Var
+from spyrit.core.prep import SplitPoisson
+
+# Measurement parameters
+M = 64 * 64 // 4  # Number of measurements (here, 1/4 of the pixels)
+alpha = 100.0  # number of photons
+
+# Measurement and noise operators
+Ord = Cov2Var(Cov)
+meas_op = HadamSplit(M, h, Ord)
+noise_op = Poisson(meas_op, alpha)
+prep_op = SplitPoisson(alpha, meas_op)
+
+# Vectorize image
+x = x.view(b * c, h * w)
+print(f"Shape of vectorized image: {x.shape}")
+
+# Measurements
+y = noise_op(x)  # a noisy measurement vector
+m = prep_op(y)  # preprocessed measurement vector
+
+m_plot = m.detach().numpy()
+m_plot = meas2img2(m_plot.T, Ord)
+imagesc(m_plot, r"Measurements $m$")
+
+# %%
+# DRUNet denoising
+# ====================================================================
+
+######################################################################
+# DRUNet is defined by the :class:`spyrit.external.drunet.DRUNet` class. This class inherits from the original :class:`spyrit.external.drunet.UNetRes` class introduced in [ZhLZ21]_, with some modifications to handle different noise levels.
+
+###############################################################################
+# We instantiate the DRUNet by providing the noise level, which is expected to be in [0, 255], and the number of channels. The larger the noise level, the stronger the denoising.
+
+from spyrit.external.drunet import DRUNet
+
+noise_level = 7
+denoi_drunet = DRUNet(noise_level=noise_level, n_channels=1)
+
+# Use GPU, if available
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+denoi_drunet = denoi_drunet.to(device)
+
+###############################################################################
+# We download the pretrained weights of the DRUNet and load them.
+
+# Download pretrained weights
+model_drunet_path = "./model"
+url_drunet = "https://drive.google.com/file/d/1fhnIDJAbh7IRSZ9tgk4JPtfGra4O1ghk/view?usp=drive_link"
+
+try:
+    import gdown
+
+    if os.path.exists(model_drunet_path) is False:
+        os.mkdir(model_drunet_path)
+        print(f"Created {model_drunet_path}")
+
+    model_drunet_path = os.path.join(model_drunet_path, "drunet_gray.pth")
+    gdown.download(url_drunet, model_drunet_path, quiet=False, fuzzy=True)
+
+    # Load pretrained weights
+    denoi_drunet.load_state_dict(torch.load(model_drunet_path), strict=False)
+    print(f"Model {model_drunet_path} loaded.")
+except:
+    print(f"Model {model_drunet_path} not found!")
+
+# %%
+# Plugging the DRUNet into a DCNet
+# ====================================================================
+
+######################################################################
+# We define the DCNet network by providing the forward operator, preprocessing operator, covariance prior and denoising prior. The DCNet class :class:`spyrit.core.recon.DCNet` is discussed in :ref:`Tutorial 6 `.
+
+from spyrit.core.recon import DCNet
+
+dcnet_drunet = DCNet(noise_op, prep_op, torch.from_numpy(Cov), denoi=denoi_drunet)
+dcnet_drunet = dcnet_drunet.to(device)  # Use GPU, if available
+
+######################################################################
+# Then, we reconstruct the image from the noisy measurements.
+
+with torch.no_grad():
+    z_dcnet_drunet = dcnet_drunet.reconstruct(y.to(device))
+
+# %%
+# Tuning the denoising level
+# ====================================================================
+
+######################################################################
+# We reconstruct the image for two other DRUNet noise levels, using the :meth:`set_noise_level` method.
+
+noise_level_2 = 1
+noise_level_3 = 20
+
+with torch.no_grad():
+
+    denoi_drunet.set_noise_level(noise_level_2)
+    z_dcnet_drunet_2 = dcnet_drunet.reconstruct(y.to(device))
+
+    denoi_drunet.set_noise_level(noise_level_3)
+    z_dcnet_drunet_3 = dcnet_drunet.reconstruct(y.to(device))
+
+######################################################################
+# Plot all reconstructions
+from spyrit.misc.disp import add_colorbar, noaxis
+
+x_plot = z_dcnet_drunet.view(-1, h, h).cpu().numpy()
+x_plot2 = z_dcnet_drunet_2.view(-1, h, h).cpu().numpy()
+x_plot3 = z_dcnet_drunet_3.view(-1, h, h).cpu().numpy()
+
+f, axs = plt.subplots(1, 3, figsize=(10, 5))
+im1 = axs[0].imshow(x_plot2[0, :, :], cmap="gray")
+axs[0].set_title(f"DRUNet\n (n map={noise_level_2})", fontsize=16)
+noaxis(axs[0])
+add_colorbar(im1, "bottom")
+
+im2 = axs[1].imshow(x_plot[0, :, :], cmap="gray")
+axs[1].set_title(f"DRUNet\n (n map={noise_level})", fontsize=16)
+noaxis(axs[1])
+add_colorbar(im2, "bottom")
+
+im3 = axs[2].imshow(x_plot3[0, :, :], cmap="gray")
+axs[2].set_title(f"DRUNet\n (n map={noise_level_3})", fontsize=16)
+noaxis(axs[2])
+add_colorbar(im3, "bottom")
+
+# %%
+# Alternative implementation using the original UNetRes class
+# ====================================================================
+
+##############################################################################
+# To show the advantage of the :class:`~spyrit.external.drunet.DRUNet` wrapper, we now perform the same reconstruction with the original :class:`spyrit.external.drunet.UNetRes` class. First, we consider DCNet without denoising in the image domain (default behaviour).
+
+dcnet = DCNet(noise_op, prep_op, torch.from_numpy(Cov))
+dcnet = dcnet.to(device)
+
+with torch.no_grad():
+    z_dcnet = dcnet.reconstruct(y.to(device))
+
+######################################################################
+# Then, we instantiate DRUNet using the original class :class:`spyrit.external.drunet.UNetRes`.
+
+from spyrit.external.drunet import UNetRes as drunet
+
+# Define denoising network
+n_channels = 1  # 1 for grayscale image
+drunet_den = drunet(in_nc=n_channels + 1, out_nc=n_channels)
+
+# Load pretrained model
+try:
+    drunet_den.load_state_dict(torch.load(model_drunet_path), strict=True)
+    print(f"Model {model_drunet_path} loaded.")
+except:
+    print(f"Model {model_drunet_path} not found!")
+drunet_den = drunet_den.to(device)
+
+######################################################################
+# To denoise the output of DCNet, we normalize it in [0, 1] and concatenate a noise-level map to it.
+
+x_sample = 0.5 * (z_dcnet + 1).cpu()
+
+# Concatenate a noise-level map along the channel dimension
+x_sample = torch.cat(
+    (
+        x_sample,
+        torch.FloatTensor([noise_level / 255.0]).repeat(
+            1, 1, x_sample.shape[2], x_sample.shape[3]
+        ),
+    ),
+    dim=1,
+)
+x_sample = x_sample.to(device)
+
+with torch.no_grad():
+    z_dcnet_den = drunet_den(x_sample)
+
+##############################################################################
+# We plot all results
+
+x_plot = x.view(-1, h, h).cpu().numpy()
+x_plot2 = z_dcnet.view(-1, h, h).cpu().numpy()
+x_plot3 = z_dcnet_drunet.view(-1, h, h).cpu().numpy()
+x_plot4 = z_dcnet_den.view(-1, h, h).cpu().numpy()
+
+f, axs = plt.subplots(2, 2, figsize=(10, 10))
+im1 = axs[0, 0].imshow(x_plot[0, :, :], cmap="gray")
+axs[0, 0].set_title("Ground-truth image", fontsize=16)
+noaxis(axs[0, 0])
+add_colorbar(im1, "bottom")
+
+im2 = axs[0, 1].imshow(x_plot2[0, :, :], cmap="gray")
+axs[0, 1].set_title("No denoising", fontsize=16)
+noaxis(axs[0, 1])
+add_colorbar(im2, "bottom")
+
+im3 = axs[1, 0].imshow(x_plot3[0, :, :], cmap="gray")
+axs[1, 0].set_title(f"Using DRUNet with n map={noise_level}", fontsize=16)
+noaxis(axs[1, 0])
+add_colorbar(im3, "bottom")
+
+im4 = axs[1, 1].imshow(x_plot4[0, :, :], cmap="gray")
+axs[1, 1].set_title(f"Using UNetRes with n map={noise_level}", fontsize=16)
+noaxis(axs[1, 1])
+add_colorbar(im4, "bottom")
+
+plt.show()
+
+###############################################################################
+# The results are identical to those obtained using :class:`~spyrit.external.drunet.DRUNet`.
+
+###############################################################################
+# .. note::
+#
+#    In this tutorial, we have used DRUNet with a DCNet, but it can be used with any other network, such as a PinvNet. In addition, we have considered pretrained weights, leading to a plug-and-play strategy that does not require training. However, the DCNet-DRUNet network can be trained end-to-end to improve the reconstruction performance in a specific setting (where training is done for all noise levels at once). For more details, refer to the paper [ZhLZ21]_.
+
+###############################################################################
+# .. note::
+#
+#    We refer to `spyrit-examples tutorials `_ for a comparison of different solutions (PinvNet, DCNet and DRUNet) that can be run in Colab.
+
+######################################################################
+# .. rubric:: References for DRUNet
+#
+# .. [ZhLZ21] Zhang, K.; Li, Y.; Zuo, W.; Zhang, L.; Van Gool, L.; Timofte, R.: Plug-and-Play Image Restoration with Deep Denoiser Prior. In: IEEE Transactions on Pattern Analysis and Machine Intelligence, 44(10), 6360-6376, 2021.
+# .. [ZhZG17] Zhang, K.; Zuo, W.; Gu, S.; Zhang, L.: Learning Deep CNN Denoiser Prior for Image Restoration. In: IEEE Conference on Computer Vision and Pattern Recognition, 3929-3938, 2017.