diff --git a/docs/source/fig/tuto9.png b/docs/source/fig/tuto9.png new file mode 100644 index 00000000..f1038fd0 Binary files /dev/null and b/docs/source/fig/tuto9.png differ diff --git a/spyrit/core/meas.py b/spyrit/core/meas.py index f218b1e4..f352e309 100644 --- a/spyrit/core/meas.py +++ b/spyrit/core/meas.py @@ -25,13 +25,12 @@ """ import warnings -from typing import Union import math import torch import torch.nn as nn -from spyrit.core.time import DeformationField +from spyrit.core.warp import DeformationField import spyrit.core.torch as spytorch @@ -383,6 +382,10 @@ def H_pinv(self, value: torch.tensor) -> None: value.to(torch.float64), requires_grad=False ) + @H_pinv.deleter + def H_pinv(self) -> None: + del self._param_H_static_pinv + def set_H_pinv(self, rtol: float = None) -> None: """Used to set the pseudo inverse of the measurement matrix :math:`H` using `torch.linalg.pinv`. @@ -866,6 +869,10 @@ def H_pinv(self) -> torch.tensor: """Dynamic pseudo-inverse H_pinv. Equal to self.H_dyn_pinv.""" return self.H_dyn_pinv + @H_pinv.deleter + def H_pinv(self) -> None: + del self._param_H_dyn_pinv + @property def H_dyn_pinv(self) -> torch.tensor: """Dynamic pseudo-inverse H_pinv.""" diff --git a/spyrit/core/prep.py b/spyrit/core/prep.py index 20b34c65..f9f8d367 100644 --- a/spyrit/core/prep.py +++ b/spyrit/core/prep.py @@ -12,7 +12,7 @@ import torch import torch.nn as nn -from spyrit.core.meas import Linear, LinearSplit, HadamSplit # , LinearRowSplit +from spyrit.core.meas import LinearSplit, HadamSplit # , Linear # ============================================================================== diff --git a/spyrit/core/recon.py b/spyrit/core/recon.py index 87ea55b5..0344e315 100644 --- a/spyrit/core/recon.py +++ b/spyrit/core/recon.py @@ -7,13 +7,12 @@ import math import torch -import torchvision + +# import torchvision import torch.nn as nn import numpy as np -import spyrit.core.torch as spytorch from spyrit.core.meas import Linear, DynamicLinear, HadamSplit -from spyrit.core.time import DeformationField from spyrit.core.noise import NoNoise from spyrit.core.prep import DirectPoisson, SplitPoisson @@ -48,6 +47,7 @@ def forward( self, x: torch.tensor, meas_op: Union[Linear, DynamicLinear], + **kwargs, ) -> torch.tensor: r"""Computes pseudo-inverse of measurements. @@ -76,7 +76,7 @@ def forward( >>> print(x.shape) torch.Size([85, 1024]) """ - return meas_op.pinv(x) + return meas_op.pinv(x, **kwargs) # ============================================================================= diff --git a/spyrit/core/torch.py b/spyrit/core/torch.py index d29c3d06..b90c8243 100644 --- a/spyrit/core/torch.py +++ b/spyrit/core/torch.py @@ -7,14 +7,11 @@ functions, but using pytorch tensors instead of numpy arrays. 
""" -import warnings +# import warnings import torch import torch.nn as nn import torchvision -import math -import scipy -import numpy as np # ============================================================================= diff --git a/spyrit/core/time.py b/spyrit/core/warp.py similarity index 100% rename from spyrit/core/time.py rename to spyrit/core/warp.py diff --git a/spyrit/misc/disp.py b/spyrit/misc/disp.py index c26cfeb3..5d5b43ad 100644 --- a/spyrit/misc/disp.py +++ b/spyrit/misc/disp.py @@ -175,6 +175,7 @@ def imagesc( cax = plt.axes([0.85, 0.1, 0.075, 0.8]) plt.colorbar(cax=cax, orientation="vertical") + fig.tight_layout() if show is True: plt.show() diff --git a/spyrit/test/run_tests.py b/spyrit/test/run_tests.py index d70aa5c6..7db6cef1 100644 --- a/spyrit/test/run_tests.py +++ b/spyrit/test/run_tests.py @@ -2,7 +2,7 @@ from test_core_noise import test_core_noise from test_core_prep import test_core_prep from test_core_recon import test_core_recon -from test_core_time import test_core_time +from test_core_warp import test_core_warp def run_tests(): @@ -10,8 +10,8 @@ def run_tests(): test_core_meas() test_core_noise() test_core_prep() - test_core_time() - test_core_recon() # must be after time + test_core_warp() + test_core_recon() # must be after warp if __name__ == "__main__": diff --git a/spyrit/test/test_core_meas.py b/spyrit/test/test_core_meas.py index 3de7ce21..33c0810a 100644 --- a/spyrit/test/test_core_meas.py +++ b/spyrit/test/test_core_meas.py @@ -241,7 +241,7 @@ def test_core_meas(): print("ok") # Build dynamic measurement matrix - from spyrit.core.time import AffineDeformationField + from spyrit.core.warp import AffineDeformationField print("\tBuild dynamic measurement matrix... ", end="") H = torch.rand(400, 2500, dtype=torch.float64) diff --git a/spyrit/test/test_core_recon.py b/spyrit/test/test_core_recon.py index c818f0fb..0a6d95c3 100644 --- a/spyrit/test/test_core_recon.py +++ b/spyrit/test/test_core_recon.py @@ -13,7 +13,7 @@ from spyrit.core.meas import HadamSplit, DynamicLinear, DynamicHadamSplit from spyrit.core.noise import NoNoise from spyrit.core.prep import SplitPoisson -from spyrit.core.time import AffineDeformationField +from spyrit.core.warp import AffineDeformationField def test_core_recon(): diff --git a/spyrit/test/test_core_time.py b/spyrit/test/test_core_warp.py similarity index 94% rename from spyrit/test/test_core_time.py rename to spyrit/test/test_core_warp.py index 9c003e3a..749e82a3 100644 --- a/spyrit/test/test_core_time.py +++ b/spyrit/test/test_core_warp.py @@ -1,5 +1,5 @@ """ -Test for module time.py +Test for module warp.py Author: Romain Phan """ @@ -9,14 +9,14 @@ from test_helpers import * -def test_core_time(): +def test_core_warp(): - print("\n*** Testing time.py ***") + print("\n*** Testing warp.py ***") # ========================================================================= ## DeformationField print("DeformationField") - from spyrit.core.time import DeformationField + from spyrit.core.warp import DeformationField # constructor print("\tconstructor... ", end="") @@ -88,7 +88,7 @@ def test_core_time(): # ========================================================================= ## AffineDeformationField print("AffineDeformationField") - from spyrit.core.time import AffineDeformationField + from spyrit.core.warp import AffineDeformationField # constructor print("\tconstructor... 
", end="") @@ -146,10 +146,10 @@ def f(t): # print("ok") # ========================================================================= - print("All tests passed for time.py") + print("All tests passed for warp.py") print("==============================") return True if __name__ == "__main__": - test_core_time() + test_core_warp() diff --git a/tutorial/README.txt b/tutorial/README.txt index 644cce7a..33defce3 100644 --- a/tutorial/README.txt +++ b/tutorial/README.txt @@ -11,29 +11,25 @@ these tutorials can be found on `this page`_ of the Spyrit GitHub. Below is a diagram of the entire image processing pipeline. Each tutorial focuseson a specific part of the pipeline. -* :ref:`Tutorial 1 ` focuses -on the measurement operators, with or without noise +* :ref:`Tutorial 1 ` focuses on the measurement operators, with or without noise -* :ref:`Tutorial 2 ` explains -the pseudo-inverse reconstruction process from the (possibly noisy) -measurements +* :ref:`Tutorial 2 ` explains the pseudo-inverse reconstruction process from the (possibly noisy) measurements -* :ref:`Tutorial 3 ` uses -a CNN to denoise the image if necessary +* :ref:`Tutorial 3 ` uses a CNN to denoise the image if necessary -* :ref:`Tutorial 4 ` -is used to train the CNN introduced in Tutorial 3 +* :ref:`Tutorial 4 ` is used to train the CNN introduced in Tutorial 3 -* :ref:`Tutorial 5 ` -introduces a new type of measurement operator ('split') that simulates positive -and negative measurements +* :ref:`Tutorial 5 ` introduces a new type of measurement operator ('split') that simulates positive and negative measurements -* :ref:`Tutorial 6 ` uses -a Denoised Completion Network with a trainable image denoiser to improve the -results obtained in Tutorial 5 +* :ref:`Tutorial 6 ` uses a Denoised Completion Network with a trainable image denoiser to improve the results obtained in Tutorial 5 -* Explore :ref:`Bonus Tutorial ` -if you want to go deeper into Spyrit's capabilities +* :ref:`Tutorial 7 ` shows how to perform image reconstruction using a pretrained plug-and-play denoising network. + +* :ref:`Tutorial 8 ` shows how to perform image reconstruction using a learnt proximal gradient descent (AVAILABLE SOON). + +* :ref:`Tutorial 9 ` explains motion simulation from an image, dynamic measurements and reconstruction. + +* Explore :ref:`Bonus Tutorial ` if you want to go deeper into Spyrit's capabilities .. image:: ../fig/full.png diff --git a/tutorial/tuto_02_pseudoinverse_linear.py b/tutorial/tuto_02_pseudoinverse_linear.py index 4fa93134..f0b0e627 100644 --- a/tutorial/tuto_02_pseudoinverse_linear.py +++ b/tutorial/tuto_02_pseudoinverse_linear.py @@ -178,8 +178,22 @@ # ----------------------------------------------------------------------------- ############################################################################### +# There are two ways to perform the pseudo inverse reconstruction from the +# measurements :attr:`y`. The first consists of explicitly computing the +# pseudo inverse of the measurement matrix :attr:`H` and applying it to the +# measurements. The second computes a least-squares solution using :func:`torch.linalg.lstsq` +# to compute the pseudo inverse solution. +# The choice is made automatically: if the measurement operator has a pseudo-inverse +# already computed, it is used; otherwise, the least-squares solution is used. +# +# .. note:: +# Generally, the second method is preferred because it is faster and more +# numerically stable. 
However, if you need to apply the pseudo inverse multiple +# times, it becomes more efficient to compute it explicitly. +# +# First way: explicit computation of the pseudo inverse # We can use the :class:`spyrit.core.recon.PseudoInverse` class to perform the -# pseudo inverse reconstruction from the measurements :attr:`y` +# pseudo inverse reconstruction from the measurements :attr:`y`. from spyrit.core.recon import PseudoInverse @@ -187,11 +201,48 @@ recon_op = PseudoInverse() # Reconstruction -x_rec = recon_op(y, meas_op) +x_rec1 = recon_op(y, meas_op) # equivalent to: meas_op.pinv(y) + +############################################################################### +# Second way: calling the pinv method of the Linear operator +# The code is very similar to the previous case, but we need to make sure the +# measurement operator has no pseudo-inverse computed. We can also specify +# regularization parameters for the least-squares solution when calling +# `recon_op`. + +print(f"Pseudo-inverse computed: {hasattr(meas_op, 'H_pinv')}") +temp = meas_op.H_pinv # save the pseudo-inverse +del meas_op.H_pinv # delete the pseudo-inverse +print(f"Pseudo-inverse computed: {hasattr(meas_op, 'H_pinv')}") + +# Reconstruction +x_rec2 = recon_op(y, meas_op, reg="L1", eta=1e-6) + +# restore the pseudo-inverse +meas_op.H_pinv = temp + +############################################################################## +# .. note:: +# This choice is also offered for dynamic measurement operators, which are +# explained in :ref:`Tutorial 9 `. + +# plot side by side +import matplotlib.pyplot as plt +from spyrit.misc.disp import add_colorbar + +x_plot1 = x_rec1.squeeze().view(h, h).cpu().numpy() +x_plot2 = x_rec2.squeeze().view(h, h).cpu().numpy() + +fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) + +im1 = ax1.imshow(x_plot1, cmap="gray") +ax1.set_title("Explicit pseudo-inverse reconstruction") +add_colorbar(im1, "right", size="20%") + +im2 = ax2.imshow(x_plot2, cmap="gray") +ax2.set_title("Least-squares pseudo-inverse reconstruction") +add_colorbar(im2, "right", size="20%") -# plot -x_plot = x_rec.squeeze().view(h, h).cpu().numpy() -imagesc(x_plot, "Pseudoinverse reconstruction (no noise)", title_fontsize=20) # %% # PinvNet Network diff --git a/tutorial/tuto_06_dcnet_split_measurements.py b/tutorial/tuto_06_dcnet_split_measurements.py index 56b6b1de..4950d2e2 100644 --- a/tutorial/tuto_06_dcnet_split_measurements.py +++ b/tutorial/tuto_06_dcnet_split_measurements.py @@ -1,89 +1,85 @@ r""" -06. DCNet solution for split measurements +========================================= +06. Denoised Completion Network (DCNet) ========================================= .. _tuto_dcnet_split_measurements: - -This tutorial shows how to perform image reconstruction using DCNet (denoised -completion network) with -and without a trainable image denoiser. In the previous tutorial -:ref:`Acquisition - split measurements ` -we showed how to handle split measurements for a Hadamard operator -and how to perform a pseudo-inverse reconstruction with PinvNet. - -.. image:: ../fig/tuto6.png +This tutorial shows how to perform image reconstruction using the denoised +completion network (DCNet) with a trainable image denoiser. In the next +tutorial, we will plug a pretrained denoiser into a DCNet, which requires no training. +.. figure:: ../fig/tuto6.png :width: 600 :align: center - :alt: Reconstruction and neural network denoising architecture sketch using split measurements - -These tutorials load image samples from `/images/`. 
+    :alt: Reconstruction and neural network denoising architecture sketch using +    split measurements """ +###################################################################### +# .. note:: +# As in the previous tutorials, we consider a split Hadamard operator and +# measurements corrupted by Poisson noise (see :ref:`Tutorial 5 `). + # %% # Load a batch of images -# ----------------------------------------------------------------------------- - -############################################################################### -# Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized -# using the :func:`transform_gray_norm` function. - -import os - -import torch -import torchvision -import numpy as np -import matplotlib.pyplot as plt +# ========================================= -from spyrit.misc.disp import imagesc -from spyrit.misc.statistics import transform_gray_norm +###################################################################### +# Update search path # sphinx_gallery_thumbnail_path = 'fig/tuto6.png' +import os -h = 64 # image size hxh -i = 1 # Image index (modify to change the image) spyritPath = os.getcwd() imgs_path = os.path.join(spyritPath, "images/") +###################################################################### +# Images :math:`x` used for training neural networks are expected to take values in [-1,1]. The images are normalized and resized using the :func:`transform_gray_norm` function. +from spyrit.misc.statistics import transform_gray_norm -# Create a transform for natural images to normalized grayscale image tensors +h = 64 # image is resized to h x h transform = transform_gray_norm(img_size=h) -# Create dataset and loader (expects class folder 'images/test/') +###################################################################### +# Create a data loader from some dataset (images must be in the folder `images/test/`) +import torch +import torchvision + dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform) dataloader = torch.utils.data.DataLoader(dataset, batch_size=7) x, _ = next(iter(dataloader)) print(f"Shape of input images: {x.shape}") -# Select image +###################################################################### +# Select the `i`-th image in the batch +i = 1 # Image index (modify to change the image) x = x[i : i + 1, :, :, :] x = x.detach().clone() b, c, h, w = x.shape -# plot +###################################################################### +# Plot the selected image +from spyrit.misc.disp import imagesc + x_plot = x.view(-1, h, h).cpu().numpy() imagesc(x_plot[0, :, :], r"$x$ in [-1, 1]") # %% # Forward operators for split measurements -# ----------------------------------------------------------------------------- +# ========================================= -############################################################################### -# We consider noisy split measurements for a Hadamard operator and a -# "variance subsampling" strategy that preserves the coefficients with the largest variance, -# obtained from a previously estimated covariance matrix (for more details, -# refer to :ref:`Acquisition - split measurements `). +###################################################################### +# We consider noisy measurements obtained from a split Hadamard operator, and a subsampling strategy that retains the coefficients with the largest variance (for more details, refer to :ref:`Tutorial 5 `). 
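######################################################################
# Editor's sketch (not part of this PR's tutorial code): the "largest
# variance" subsampling mentioned above can be pictured on a toy covariance
# matrix. The diagonal of the covariance gives a per-coefficient variance,
# and only the patterns associated with the M largest variances are kept.
# All names and sizes below are illustrative assumptions, not spyrit's
# implementation.

import torch

h_demo = 8                                    # toy image side
M_demo = 16                                   # keep 16 of the 64 coefficients
Cov_demo = torch.diag(torch.rand(h_demo**2))  # toy diagonal covariance matrix
var_demo = torch.diagonal(Cov_demo)           # per-coefficient variance
order_demo = torch.argsort(var_demo, descending=True)
print("Indices of the coefficients kept:", order_demo[:M_demo])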
-############################################################################### -# First, we download the covariance matrix and load it. +###################################################################### +# First, we download the covariance matrix from our warehouse. import girder_client +import numpy as np # api Rest url of the warehouse url = "https://pilot-warehouse.creatis.insa-lyon.fr/api/v1" - # Generate the warehouse client gc = girder_client.GirderClient(apiUrl=url) - # Download the covariance matrix and mean image data_folder = "./stat/" dataId_list = [ @@ -91,7 +87,6 @@ "63935a224d15dd536f048496", # for reconstruction (imageNet, 64) ] cov_name = "./stat/Cov_64x64.npy" - try: Cov = np.load(cov_name) print(f"Cov matrix {cov_name} loaded") @@ -99,20 +94,15 @@ for dataId in dataId_list: myfile = gc.getFile(dataId) gc.downloadFile(dataId, data_folder + myfile["name"]) - print(f"Created {data_folder}") - Cov = np.load(cov_name) print(f"Cov matrix {cov_name} loaded") except: Cov = np.eye(h * h) print(f"Cov matrix {cov_name} not found! Set to the identity") -############################################################################### -# We define the measurement, noise and preprocessing operators and then -# simulate a noiseless measurement vector :math:`y`. As in the previous tutorial, -# we simulate an accelerated acquisition by subsampling the measurement matrix -# by retaining only the first :math:`M` rows of a Hadamard matrix :math:`\textrm{Perm} H`. +###################################################################### +# We define the measurement, noise and preprocessing operators and then simulate a measurement vector corrupted by Poisson noise. As in the previous tutorials, we simulate an accelerated acquisition by subsampling the measurement matrix by retaining only the first rows of a Hadamard matrix that is permuted according to the diagonal of the covariance matrix. from spyrit.core.meas import HadamSplit from spyrit.core.noise import Poisson @@ -124,10 +114,8 @@ M = 64 * 64 // 4 # Number of measurements (here, 1/4 of the pixels) alpha = 100.0 # number of photons -# Ordering matrix -Ord = Cov2Var(Cov) - # Measurement and noise operators +Ord = Cov2Var(Cov) meas_op = HadamSplit(M, h, torch.from_numpy(Ord)) noise_op = Poisson(meas_op, alpha) prep_op = SplitPoisson(alpha, meas_op) @@ -135,98 +123,108 @@ # Vectorize image x = x.view(b * c, h * w) print(f"Shape of vectorized image: {x.shape}") - # Measurements y = noise_op(x) # a noisy measurement vector m = prep_op(y) # preprocessed measurement vector m_plot = m.detach().numpy() m_plot = meas2img(m_plot, Ord) -imagesc(np.moveaxis(m_plot, 0, -1), r"Measurements $m$") +imagesc(m_plot[0, :, :], r"Measurements $m$") # %% -# PinvNet network -# ----------------------------------------------------------------------------- +# Pseudo inverse solution +# ========================================= -############################################################################### -# We reconstruct with the pseudo inverse using :class:`spyrit.core.recon.PinvNet` class -# as in the previous tutorial. For this, we define the neural network and then perform the reconstruction. +###################################################################### +# We compute the pseudo inverse solution using the :class:`spyrit.core.recon.PinvNet` class as in the previous tutorial. 
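######################################################################
# Editor's sketch (not part of this PR's tutorial code): for a generic linear
# model y = H x, the pseudo inverse solution can be obtained either from an
# explicit pseudo inverse or from :func:`torch.linalg.lstsq`; for a full-rank
# matrix the two coincide up to numerical precision. The toy sizes below are
# arbitrary and unrelated to the spyrit operators.

import torch

H_toy = torch.randn(64, 256, dtype=torch.float64)   # wide matrix, as in subsampled acquisitions
y_toy = H_toy @ torch.randn(256, 1, dtype=torch.float64)

x_pinv = torch.linalg.pinv(H_toy) @ y_toy            # explicit pseudo inverse
x_lstsq = torch.linalg.lstsq(H_toy, y_toy).solution  # least-squares solver
print("Maximum difference:", (x_pinv - x_lstsq).abs().max().item())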
+ +# Instantiate a PinvNet (with no denoising by default) from spyrit.core.recon import PinvNet -from spyrit.misc.disp import add_colorbar, noaxis -# Reconstruction with for Core module (linear net) pinvnet = PinvNet(noise_op, prep_op) -# use GPU, if available +# Use GPU, if available device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - -# Pseudo-inverse net pinvnet = pinvnet.to(device) +y = y.to(device) # Reconstruction with torch.no_grad(): - z_invnet = pinvnet.reconstruct(y.to(device)) # reconstruct from raw measurements + z_invnet = pinvnet.reconstruct(y) # %% -# DCNet network -# ----------------------------------------------------------------------------- +# Denoised completion network (DCNet) +# ========================================= + +###################################################################### +# .. image:: ../fig/dcnet.png +# :width: 400 +# :align: center +# :alt: Sketch of the DCNet architecture -############################################################################### -# We can improve PinvNet results by using the *denoised* completion network DCNet with the -# :class:`spyrit.core.recon.DCNet` class. It has four sequential steps: +###################################################################### +# The DCNet is based on four sequential steps: # -# i) denoising of the acquired measurements, +# i) Denoising in the measurement domain. # -# ii) estimation of the missing measurements from the denoised ones, +# ii) Estimation of the missing measurements from the denoised ones. # -# iii) mapping them to the image domain, and +# iii) Image-domain mapping. # -# iv) denoising in the image-domain. +# iv) (Learned) Denoising in the image domain. # -# Only the last step involves learnable parameters. +# Typically, only the last step involves learnable parameters. -############################################################################### -# .. image:: ../fig/dcnet.png -# :width: 400 -# :align: center -# :alt: Sketch of the DCNet architecture -############################################################################### -# For the denoiser, we compare the default unit matrix (no denoising) with the UNet denoiser -# with the :class:`spyrit.core.nnet.Unet` class. For the latter, we load the pretrained model -# weights. +# %% +# Denoised completion +# ========================================= + +###################################################################### +# The first three steps implement denoised completion, which corresponds to Tikhonov regularization. Considering linear measurements :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is the unknown image, it estimates :math:`x` from :math:`y` by minimizing +# +# .. math:: +# \| y - Hx \|^2_{\Sigma^{-1}_\alpha} + \|x\|^2_{\Sigma^{-1}}, +# +# where :math:`\Sigma` is a covariance prior and :math:`\Sigma_\alpha` is the noise covariance. Denoised completion can be performed using the :class:`~spyrit.core.recon.TikhonovMeasurementPriorDiag` class (see documentation for more details). -############################################################################### -# Without *learnable image-domain* denoising +###################################################################### +# In practice, it is more convenient to use the :class:`spyrit.core.recon.DCNet` class, which relies on a forward operator, a preprocessing operator, and a covariance prior. 
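######################################################################
# Editor's sketch (not the spyrit implementation): the Tikhonov problem above
# has the closed-form solution
# :math:`\hat{x} = \Sigma H^T (H \Sigma H^T + \Sigma_\alpha)^{-1} y`,
# illustrated below on toy matrices of arbitrary size. All variable names are
# made up for this illustration.

import torch

N_toy, M_toy = 256, 64
H_tik = torch.randn(M_toy, N_toy, dtype=torch.float64)       # toy measurement matrix
Sigma_prior = torch.eye(N_toy, dtype=torch.float64)          # toy image covariance prior
Sigma_noise = 0.01 * torch.eye(M_toy, dtype=torch.float64)   # toy noise covariance
y_tik = H_tik @ torch.randn(N_toy, 1, dtype=torch.float64)

kernel = H_tik @ Sigma_prior @ H_tik.T + Sigma_noise         # (M, M) system to solve
x_tik = Sigma_prior @ H_tik.T @ torch.linalg.solve(kernel, y_tik)
print("Tikhonov estimate shape:", x_tik.shape)               # (N, 1)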
from spyrit.core.recon import DCNet -from spyrit.core.nnet import Unet -from torch import nn -# Reconstruction with for DCNet (linear net) -dcnet = DCNet(noise_op, prep_op, torch.from_numpy(Cov), denoi=nn.Identity()) +dcnet = DCNet(noise_op, prep_op, torch.from_numpy(Cov)) + +# Use GPU, if available dcnet = dcnet.to(device) +y = y.to(device) -# Reconstruction with torch.no_grad(): - z_dcnet = dcnet.reconstruct(y.to(device)) # reconstruct from raw measurements + z_dcnet = dcnet.reconstruct(y) -############################################################################### -# With a UNet denoising layer, we define the denoising network and -# then load the pretrained weights. +###################################################################### +# .. note:: +# In this tutorial, the covariance matrix used to define subsampling is also used as prior knowledge during reconstruction. -from spyrit.core.train import load_net -import matplotlib.pyplot as plt -from spyrit.misc.disp import add_colorbar, noaxis -# Define UNet denoiser -denoi = Unet() +# %% +# (Learned) Denoising in the image domain +# ========================================= + +###################################################################### +# To implement denoising in the image domain, we provide a :class:`spyrit.core.nnet.Unet` denoiser to a :class:`spyrit.core.recon.DCNet`. -# Define DCNet (with UNet denoising) +from spyrit.core.nnet import Unet + +denoi = Unet() dcnet_unet = DCNet(noise_op, prep_op, torch.from_numpy(Cov), denoi) -dcnet_unet = dcnet_unet.to(device) +dcnet_unet = dcnet_unet.to(device) # Use GPU, if available + +######################################################################## +# We load pretrained weights for the UNet + +from spyrit.core.train import load_net -# Load previously trained model # Download weights url_unet = "https://drive.google.com/file/d/15PRRZj5OxKpn1iJw78lGwUUBtTbFco1l/view?usp=drive_link" model_path = "./model" @@ -237,7 +235,6 @@ model_path, "dc-net_unet_stl10_N0_100_N_64_M_1024_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07.pth", ) - load_unet = True if os.path.exists(model_unet_path) is False: try: @@ -247,58 +244,59 @@ except: print(f"Model {model_unet_path} not found!") load_unet = False - if load_unet: # Load pretrained model load_net(model_unet_path, dcnet_unet, device, False) # print(f"Model {model_unet_path} loaded.") - -# Reconstruction +###################################################################### +# We reconstruct the image with torch.no_grad(): - z_dcnet_unet = dcnet_unet.reconstruct( - y.to(device) - ) # reconstruct from raw measurements + z_dcnet_unet = dcnet_unet.reconstruct(y) -############################################################################### -# We plot all results +# %% +# Results +# ========================================= + +import matplotlib.pyplot as plt +from spyrit.misc.disp import add_colorbar, noaxis -# plot reconstruction side by side x_plot = x.view(-1, h, h).cpu().numpy() x_plot2 = z_invnet.view(-1, h, h).cpu().numpy() x_plot3 = z_dcnet.view(-1, h, h).cpu().numpy() x_plot4 = z_dcnet_unet.view(-1, h, h).cpu().numpy() f, axs = plt.subplots(2, 2, figsize=(10, 10)) + +# Plot the ground-truth image im1 = axs[0, 0].imshow(x_plot[0, :, :], cmap="gray") axs[0, 0].set_title("Ground-truth image", fontsize=16) noaxis(axs[0, 0]) add_colorbar(im1, "bottom") +# Plot the pseudo inverse solution im2 = axs[0, 1].imshow(x_plot2[0, :, :], cmap="gray") -axs[0, 1].set_title("PinvNet", fontsize=16) +axs[0, 1].set_title("Pseudo inverse", fontsize=16) 
noaxis(axs[0, 1]) add_colorbar(im2, "bottom") +# Plot the solution obtained from denoised completion im3 = axs[1, 0].imshow(x_plot3[0, :, :], cmap="gray") -axs[1, 0].set_title(f"DCNet (without denoising)", fontsize=16) +axs[1, 0].set_title(f"Denoised completion", fontsize=16) noaxis(axs[1, 0]) add_colorbar(im3, "bottom") +# Plot the solution obtained from denoised completion with UNet denoising im4 = axs[1, 1].imshow(x_plot4[0, :, :], cmap="gray") -axs[1, 1].set_title(f"DCNet (UNet denoising)", fontsize=16) +axs[1, 1].set_title(f"Denoised completion with UNet denoising", fontsize=16) noaxis(axs[1, 1]) add_colorbar(im4, "bottom") plt.show() -############################################################################### -# Comparing results, PinvNet provides pixelized reconstruction, DCNet with no denoising -# leads to a smoother reconstruction, as expected by a Tikonov regularization, and -# DCNet with UNet denoising provides the best reconstruction. +###################################################################### +# .. note:: +# While the pseudo inverse reconstruction is pixelized, the solution obtained by denoised completion is smoother. DCNet with UNet denoising in the image domain provides the best reconstruction. -############################################################################### +###################################################################### # .. note:: -# -# In this tutorial, we have used DCNet with a UNet denoising layer for split measurements. -# We refer to `spyrit-examples tutorials `_ -# for a comparison of different solutions for split measurements (pinvNet, DCNet and DRUNet). +# We refer to `spyrit-examples tutorials `_ for a comparison of different solutions (pinvNet, DCNet and DRUNet) that can be run in Colab. diff --git a/tutorial/tuto_07_drunet_split_measurements.py b/tutorial/tuto_07_drunet_split_measurements.py index 42169844..07571560 100644 --- a/tutorial/tuto_07_drunet_split_measurements.py +++ b/tutorial/tuto_07_drunet_split_measurements.py @@ -107,7 +107,7 @@ from spyrit.core.meas import HadamSplit from spyrit.core.noise import Poisson -from spyrit.misc.sampling import meas2img2 +from spyrit.misc.sampling import meas2img from spyrit.misc.statistics import Cov2Var from spyrit.core.prep import SplitPoisson @@ -130,7 +130,7 @@ m = prep_op(y) # preprocessed measurement vector m_plot = m.detach().numpy() -m_plot = meas2img2(m_plot.T, Ord) +m_plot = meas2img(m_plot, Ord) imagesc(m_plot[0, :, :], r"Measurements $m$") # %% diff --git a/tutorial/tuto_09_dynamic.py b/tutorial/tuto_09_dynamic.py index 8143fbe2..f3d74bb6 100644 --- a/tutorial/tuto_09_dynamic.py +++ b/tutorial/tuto_09_dynamic.py @@ -7,18 +7,19 @@ of a moving object. There are three steps in this process: 1. First, a still image is deformed to generate multiple frames. This step -simulates movement of the object. The module :mod:`spyrit.core.time` is used +simulates movement of the object. The module :mod:`spyrit.core.warp` is used to warp images. 2. Second, the measurement is performed on the series of frames. The 'Dynamic' classes from :mod:`spyrit.core.meas` are used. 3. Third, the reconstruction from pesudo-inverse matrices is used to reconstruct -the original image. +the motion-compensated image. This tutorial will present an example in which all three steps will be -explained in an example. To understand the specificities of each class, a more -detailed explanation of each class is included at the end of the case study. +explained. 
To understand the specificities of the module +:mod:`spyrit.core.warp`, a more detailed explanation is included at the end +of the example. .. image:: ../fig/tuto9.png :width: 600 @@ -34,14 +35,10 @@ # This tutorial loads example images from the relative folder `/images/`. # %% -# 1.a Example: load an image from a batch of images +# 1.a Load an image from a batch of images # ----------------------------------------------------------------------------- -# This part is identical to other tutorials, but for the image size. Here, we -# consider a square image of side 50 pixels, and the measurement patterns will -# correspond to a Hadamard matrix of size 32x32. It is the center of the image -# that will be measured with those patterns. This leaves a border of 9 pixels -# on each side of the image, allowing for the object to move in and out of the -# measurement area. +# This part is identical to other tutorials. We consider an image of size +# 32x32 pixels. import os @@ -54,7 +51,7 @@ # sphinx_gallery_thumbnail_path = 'fig/tuto9.png' -img_size = 50 # full image side's size in pixels +img_size = 32 # full image side's size in pixels meas_size = 32 # measurement pattern side's size in pixels (Hadamard matrix) img_shape = (img_size, img_size) meas_shape = (meas_size, meas_size) @@ -79,14 +76,14 @@ b, c, h, w = x.shape # plot -x_plot = x.view(img_size, img_size).cpu() +x_plot = x.view(img_shape).cpu() imagesc(x_plot, r"Original image $x$ in [-1, 1]") # %% -# 1.b Example: defining an affine transformation +# 1.b Define an affine transformation # ----------------------------------------------------------------------------- # Here we will define an affine transformation using a matrix and the class -# :class:`spyrit.core.time.AffineDeformationField`. +# :class:`spyrit.core.warp.AffineDeformationField`. # # This class takes 3 arguments: # a function :math:`f(t) = Mat`, where :math:`t` represents the time @@ -99,7 +96,7 @@ # # Let's first see th construction of the function :math:`f`. -from spyrit.core.time import AffineDeformationField +from spyrit.core.warp import AffineDeformationField # we want to define a deformation similar to that see in [ref to Thomas]. @@ -131,34 +128,44 @@ def f(t): # Next, we will create the time vector and define the image shape. # # The measurement size (the size of the Hadamard patterns applied to the image) -# detemrines the number of measurements - if there is no subsampling. The +# determines the number of measurements - if there is no subsampling. The # number of patterns must match the number of frames of the motion picture. It # is for this reason that the number of frames is set to the square of the # measurement size. -time_vector = torch.linspace(0, 10, 2 * meas_size**2) # *2 because of the splitting +time_vector = torch.linspace(0, 10, (meas_size**2) * 2) # *2 because of the splitting aff_field = AffineDeformationField(f, time_vector, img_shape) # %% -# 1.c Example: warping the image +# 1.c Warp the image # ----------------------------------------------------------------------------- # Now that the field is defined, we can warp the image. Spyrit works mostly # with vectorized images, and warping images is no exception. Currently, the -# classes :class:`spyrit.core.time.AffineDeformationField` and -# :class:`spyrit.core.time.DeformationField` can only warp a single image at a +# classes :class:`spyrit.core.warp.AffineDeformationField` and +# :class:`spyrit.core.warp.DeformationField` can only warp a single image at a # time. 
+import matplotlib.pyplot as plt +from spyrit.misc.disp import add_colorbar + # Reshape the image from (b,c,h,w) to (c, h*w) x = x.view(c, h * w) -x_motion = aff_field(x, 0, 2 * meas_size**2) +x_motion = aff_field(x, 0, (meas_size**2) * 2) c, n_frames, n_pixels = x_motion.shape # show random frames frames = [100, 300] -for f in frames: - imagesc(x_motion[0, f, :].view(img_shape).cpu().numpy(), f"Frame {f}") + +plot, axes = plt.subplots(1, len(frames), figsize=(10, 5)) + +for i, f in enumerate(frames): + im = axes[i].imshow(x_motion[0, f, :].view(img_shape).cpu().numpy(), cmap="gray") + axes[i].set_title(f"Frame {f}") + add_colorbar(im, "right", size="20%") +plot.tight_layout() +plt.show() # %% @@ -173,7 +180,7 @@ def f(t): # unavailable. # %% -# 2.a Example: defining the dynamic measurement operator +# 2.a Define the measurement operator # ----------------------------------------------------------------------------- # The class :class:`spyrit.core.meas.DynamicHadamSplit` is the mirror class of # :class:`spyrit.core.meas.HadamardSplit`. The difference is that the dynamic @@ -185,19 +192,18 @@ def f(t): # may want to set the number of patterns to the number of frames using the # parameter `M`. -from spyrit.core.noise import NoNoise from spyrit.core.meas import DynamicHadamSplit meas_op = DynamicHadamSplit(M=meas_size**2, h=meas_size, Ord=None, img_shape=img_shape) # show the measurement matrix H -imagesc(meas_op.H_static.cpu().numpy(), "Measurement matrix") - -# if we wanted to apply NoNoise, we would do it here -# noise = NoNoise(meas_op) (not available yet) +print("Shape of the measurement matrix H:", meas_op.H_static.shape) +# as we are using split measurements, it is the matrix P that is effectively +# used when computing the measurements +print("Shape of the measurement matrix P:", meas_op.P.shape) # %% -# 2.b Example: measuring the moving object +# 2.b Measure the moving object # ----------------------------------------------------------------------------- # Now that the measurement operator is defined, we can measure the moving # object. As with the static case, this is done by using the implicit forward @@ -211,49 +217,79 @@ def f(t): # %% -# 3. Example: reconstructing the still image +# 3. Example: reconstructing the motion-compensated image # ***************************************************************************** -# In this section, we will reconstruct the still image from the measurements. +# In this section, we will reconstruct the motion-compensated image from the measurements. # This is done by combining the information contained in the measurement -# patterns and in the deformation field. The class -# :class:`spyrit.core.meas.DynamicHadamSplit` (and the other dynamic classes) -# can handle the dynamic reconstruction through various methods. +# patterns and in the deformation field. This theoretical work has been +# explained in [1]_ and [2]_. The reconstruction follows the physical +# discretization of the problem, thus avoiding to warp the Hadamard patterns. +# The class :class:`spyrit.core.meas.DynamicHadamSplit` (and the other dynamic classes) +# handle the dynamic reconstruction through various methods. 
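######################################################################
# Editor's sketch (not the spyrit implementation) of the idea behind the
# dynamic matrix built in section 3.a below: if frame k is the reference image
# warped by a matrix W_k (here a toy permutation standing in for a
# nearest-neighbour warp), then y_k = p_k^T (W_k x_ref), so the k-th row of
# the dynamic matrix is p_k^T W_k and acts directly on the reference image.
# All names and sizes are illustrative assumptions.

import torch

n_pix_toy, n_frames_toy = 16, 4
x_ref_toy = torch.randn(n_pix_toy, dtype=torch.float64)
patterns_toy = torch.randn(n_frames_toy, n_pix_toy, dtype=torch.float64)
warps_toy = [
    torch.eye(n_pix_toy, dtype=torch.float64)[torch.randperm(n_pix_toy)]
    for _ in range(n_frames_toy)
]

# one measurement per frame: warp the reference image, then apply the pattern
y_frames = torch.stack(
    [patterns_toy[k] @ (warps_toy[k] @ x_ref_toy) for k in range(n_frames_toy)]
)

# equivalent dynamic matrix acting on the (static) reference image
H_dyn_toy = torch.stack(
    [patterns_toy[k] @ warps_toy[k] for k in range(n_frames_toy)]
)
print("Same measurements:", torch.allclose(y_frames, H_dyn_toy @ x_ref_toy))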
# %% -# 3.a Example: computing the dynamic measurement matrix +# 3.a Compute the dynamic measurement matrix # ----------------------------------------------------------------------------- -# The dynamic measurement matrix (:math:`H_dyn`) is defined as the measurement +# The dynamic measurement matrix :math:`H_{dyn}` is defined as the measurement # matrix that would give the same measurement vector :math:`y` as the one -# computed before when applied to a still image :math:`x_ref`. To build the -# dynamic measurement matrix, we need the measurement patterns and the +# computed before when applied to a still image :math:`x_{ref}`: +# +# .. math:: +# y = H_{dyn} x_{ref} +# +# Or, following the notations from [2]_, :math:`m = H_{dyn} f_{ref}`. +# +# To build the # dynamic measurement matrix, we need the measurement patterns and the # deformation field. In this case, the deformation field is known, but in some # cases it might have to be estimated. # -# The dynamic measurement matrix `H_dyn` is built from the measurement operator -# itself. +# The dynamic measurement matrix `H_dyn` is built using the measurement +# operator itself. # compute the dynamic measurement matrix print("H_dyn computed:", hasattr(meas_op, "H_dyn")) meas_op.build_H_dyn(aff_field, mode="bilinear") print("H_dyn computed:", hasattr(meas_op, "H_dyn")) +############################################################################### +# .. important:: +# Because :math:`P` is the actual matrix used for measuring, the attribute +# :attr:`H_dyn` is computed using the matrix :math:`P`. This can be seen in +# their shapes, which are the transpose of each other. + # recommended way print("H_dyn shape:", meas_op.H_dyn.shape) +print("P shape:", meas_op.P.shape) # NOT recommended, can cause confusions print("H shape:", meas_op.H.shape) print("H_dyn is same as H:", (meas_op.H == meas_op.H_dyn).all()) +# show P and H_dyn side by side + +plot, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 6)) + +im1 = ax1.imshow(meas_op.P.cpu().numpy(), vmin=0, vmax=1.5, cmap="gray") +ax1.set_title("Measurement matrix P") +add_colorbar(im1, "right", size="20%") + +im2 = ax2.imshow(meas_op.H_dyn.cpu().numpy(), vmin=0, vmax=1.5, cmap="gray") +ax2.set_title("Dynamic measurement matrix H_dyn") +add_colorbar(im2, "right", size="20%") + +plot.tight_layout() +plt.show() + ############################################################################### # This method adds to the measurement operator a new attribute named # :attr:`H_dyn`. It can also be accessed using the attribute name :attr:`H` for # compatibility reasons, although it is NOT recommended. # %% -# 3.b Example: reconstruct the original undeformed image +# 3.b Reconstruct the motion-compensated image # ----------------------------------------------------------------------------- # Now that the dynamic measurement matrix has been computed, we can reconstruct -# the original image. To do this, we can first compute the pseudo-inverse of -# our dynamic measurement matrix: +# the motion-compensated image. To do this, we can first compute the pseudo-inverse +# of our dynamic measurement matrix: # compute the pseudo-inverse using the requested regularizers print("H_dyn_pinv computed:", hasattr(meas_op, "H_dyn_pinv")) @@ -273,12 +309,14 @@ def f(t): # is not recommended. 
# # Once the pseudo-inverse has been computed, we can simply call the method -# :meth:`pinv` associated with some measurements to reconstruct the original +# :meth:`pinv` associated with some measurements to reconstruct the motion-compensated # image. As with the static case, this can also be done through the class # :class:`spyrit.core.recon.PseudoInverse` # using self.pinv directly x_hat1 = meas_op.pinv(y) +print("x_hat1 shape:", x_hat1.shape) + # using a PseudoInverse instance, no difference from spyrit.core.recon import PseudoInverse @@ -286,63 +324,59 @@ def f(t): x_hat2 = recon_op(y, meas_op) print("x_hat1 and x_hat2 are equal:", (x_hat1 == x_hat2).all()) -print("x_hat1 shape:", x_hat1.shape) -# show the reconstructed image and the difference with the original image -imagesc(x_hat1.view(img_shape), "Reconstructed image with pinv") -imagesc( - x_plot.view(img_shape) - x_hat1.view(img_shape), - "Difference between original\nand reconstructed image", -) +# show the motion-compensated image and the difference with the original image -############################################################################### -# It is possible to reconstruct the original image without having to compute -# the pseudo-inverse but using the least-squares function that is provided in -# torch: :func:`torch.linalg.lstsq`. -# -# This allows for a much faster reconstruction if you need to reconstruct only -# once, but is much slower if you need to reconstruct an image with the same -# parameters (measurement patterns and deformation field) more than 5-10 times. +plot, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) -# delete the H_dyn_pinv attribute -print("H_dyn_pinv computed:", hasattr(meas_op, "H_dyn_pinv")) -del meas_op.H_dyn_pinv -print("H_dyn_pinv computed:", hasattr(meas_op, "H_dyn_pinv")) +im1 = ax1.imshow(x_hat1.view(img_shape), cmap="gray") +ax1.set_title("Motion-compensated image,\nusing pinv") +add_colorbar(im1, "right", size="20%") -# use the pinv method directly, can specify reg and eta -x_hat3 = meas_op.pinv(y, reg="L1", eta=1e-6) +im2 = ax2.imshow(x_plot.view(img_shape) - x_hat1.view(img_shape), cmap="gray") +ax2.set_title("Difference between original\nand motion-compensated image") +add_colorbar(im2, "right", size="20%") -# show the reconstructed image and the difference with the original image -imagesc(x_hat3.view(img_shape), "Reconstructed image with lstsq") -imagesc( - x_plot.view(img_shape) - x_hat3.view(img_shape), - "Difference between original\nand reconstructed image", -) +# plot.tight_layout() +plt.show() + +# imagesc(x_hat1.view(img_shape), "Motion-compensated image, using pinv") +# imagesc(x_plot.view(img_shape) - x_hat1.view(img_shape), +# "Difference between original\nand motion-compensated image", +# ) + +############################################################################### +# .. important:: +# As with static reconstruction, it is possible to reconstruct the +# motion-compensated image without having to compute the pseudo-inverse +# explicitly. Calling the method :meth:`pinv` while the attribute +# :attr:`H_dyn_pinv` is not defined will result in using the least-squares +# function provided in torch: :func:`torch.linalg.lstsq`. # %% # 4. Warping detailed explanation # ***************************************************************************** -# This tutorial uses the class :class:`spyrit.core.time.AffineDeformationField` +# This tutorial uses the class :class:`spyrit.core.warp.AffineDeformationField` # to simulate the movement of a still image. 
This class is a subclass of -# :class:`spyrit.core.time.DeformationField`, which can be used to deform an +# :class:`spyrit.core.warp.DeformationField`, which can be used to deform an # image in a more general manner. This is particularly useful for experimental # setups where the deformation is estimated from real measurements. # # Here, we provide an example of how to use the class -# :class:`spyrit.core.time.DeformationField`. The class takes one argument: +# :class:`spyrit.core.warp.DeformationField`. The class takes one argument: # the deformation field itself of shape :math:`(n_frames,h,w,2)`, where # :math:`n_frames` is the number of frames, and :math:`h` and :math:`w` are the # height and width of the image. The last dimension represents the 2D # pixel from where to interpolate the new pixel value at the coordinate # :math:`(h,w)`. # -# We will first use an instance of :class:`spyrit.core.time.AffineDeformationField` +# We will first use an instance of :class:`spyrit.core.warp.AffineDeformationField` # to create the deformation field. Then, a separate instance of -# :class:`spyrit.core.time.DeformationField` will be created using the +# :class:`spyrit.core.warp.DeformationField` will be created using the # deformation field from the affine deformation field. -from spyrit.core.time import DeformationField +from spyrit.core.warp import DeformationField # define a rotation function omega = 2 * math.pi # angular velocity @@ -373,7 +407,7 @@ def rot(t): ############################################################################### # Now that the affine deformation field is created, we can access the # deformation field through the attribute :attr:`field`. Its value can then be -# used to create a new instance of :class:`spyrit.core.time.DeformationField`. +# used to create a new instance of :class:`spyrit.core.warp.DeformationField`. # get the deformation field field = aff_field2.field @@ -384,3 +418,10 @@ def rot(t): # def_field and aff_field2 are the same print("def_field and aff_field2 are the same:", (def_field == aff_field2)) + +############################################################################### +# .. rubric:: References for dynamic reconstruction +# +# .. [1] Thomas Maitre, Elie Bretin, L. Mahieu-Williame, Michaël Sdika, Nicolas Ducros. Hybrid single-pixel camera for dynamic hyperspectral imaging. 2023. hal-04310110 + +# .. [2] (MICCAI 2024 paper #883) Thomas Maitre, Elie Bretin, Romain Phan, Nicolas Ducros, Michaël Sdika. Dynamic Single-Pixel Imaging on an Extended Field of View without Warping the Patterns. 2024. hal-04533981
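######################################################################
# Editor's sketch (independent of the spyrit classes discussed above): a
# deformation field of shape (n_frames, h, w, 2) that stores, for every output
# pixel, the 2D location to sample from can be applied with
# :func:`torch.nn.functional.grid_sample` once the coordinates are normalized
# to [-1, 1]. The identity field below leaves the image unchanged; modifying
# the field would translate or rotate it. All names are illustrative.

import torch
import torch.nn.functional as F

side = 8
img_toy = torch.arange(side * side, dtype=torch.float32).view(1, 1, side, side)

# identity sampling grid of shape (1, h, w, 2), (x, y) coordinates in [-1, 1]
ys, xs = torch.meshgrid(
    torch.linspace(-1, 1, side), torch.linspace(-1, 1, side), indexing="ij"
)
field_toy = torch.stack((xs, ys), dim=-1).unsqueeze(0)

warped_toy = F.grid_sample(img_toy, field_toy, align_corners=True)
print("Identity field leaves the image unchanged:", torch.allclose(warped_toy, img_toy))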