Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Romain dev from clean #203

Merged
merged 20 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/source/fig/tuto9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 9 additions & 2 deletions spyrit/core/meas.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
"""

import warnings
from typing import Union

import math
import torch
import torch.nn as nn

from spyrit.core.time import DeformationField
from spyrit.core.warp import DeformationField
import spyrit.core.torch as spytorch


Expand Down Expand Up @@ -383,6 +382,10 @@ def H_pinv(self, value: torch.tensor) -> None:
value.to(torch.float64), requires_grad=False
)

@H_pinv.deleter
def H_pinv(self) -> None:
del self._param_H_static_pinv

def set_H_pinv(self, rtol: float = None) -> None:
"""Used to set the pseudo inverse of the measurement matrix :math:`H`
using `torch.linalg.pinv`.
Expand Down Expand Up @@ -866,6 +869,10 @@ def H_pinv(self) -> torch.tensor:
"""Dynamic pseudo-inverse H_pinv. Equal to self.H_dyn_pinv."""
return self.H_dyn_pinv

@H_pinv.deleter
def H_pinv(self) -> None:
del self._param_H_dyn_pinv

@property
def H_dyn_pinv(self) -> torch.tensor:
"""Dynamic pseudo-inverse H_pinv."""
Expand Down
2 changes: 1 addition & 1 deletion spyrit/core/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
import torch.nn as nn

from spyrit.core.meas import Linear, LinearSplit, HadamSplit # , LinearRowSplit
from spyrit.core.meas import LinearSplit, HadamSplit # , Linear


# ==============================================================================
Expand Down
8 changes: 4 additions & 4 deletions spyrit/core/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@

import math
import torch
import torchvision

# import torchvision
import torch.nn as nn
import numpy as np

import spyrit.core.torch as spytorch
from spyrit.core.meas import Linear, DynamicLinear, HadamSplit
from spyrit.core.time import DeformationField
from spyrit.core.noise import NoNoise
from spyrit.core.prep import DirectPoisson, SplitPoisson

Expand Down Expand Up @@ -48,6 +47,7 @@ def forward(
self,
x: torch.tensor,
meas_op: Union[Linear, DynamicLinear],
**kwargs,
) -> torch.tensor:
r"""Computes pseudo-inverse of measurements.

Expand Down Expand Up @@ -76,7 +76,7 @@ def forward(
>>> print(x.shape)
torch.Size([85, 1024])
"""
return meas_op.pinv(x)
return meas_op.pinv(x, **kwargs)


# =============================================================================
Expand Down
5 changes: 1 addition & 4 deletions spyrit/core/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@
functions, but using pytorch tensors instead of numpy arrays.
"""

import warnings
# import warnings

import torch
import torch.nn as nn
import torchvision
import math
import scipy
import numpy as np


# =============================================================================
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions spyrit/misc/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def imagesc(
cax = plt.axes([0.85, 0.1, 0.075, 0.8])
plt.colorbar(cax=cax, orientation="vertical")

fig.tight_layout()
if show is True:
plt.show()

Expand Down
6 changes: 3 additions & 3 deletions spyrit/test/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from test_core_noise import test_core_noise
from test_core_prep import test_core_prep
from test_core_recon import test_core_recon
from test_core_time import test_core_time
from test_core_warp import test_core_warp


def run_tests():
# order matters ! Please change it if you have failing tests
test_core_meas()
test_core_noise()
test_core_prep()
test_core_time()
test_core_recon() # must be after time
test_core_warp()
test_core_recon() # must be after warp


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion spyrit/test/test_core_meas.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_core_meas():
print("ok")

# Build dynamic measurement matrix
from spyrit.core.time import AffineDeformationField
from spyrit.core.warp import AffineDeformationField

print("\tBuild dynamic measurement matrix... ", end="")
H = torch.rand(400, 2500, dtype=torch.float64)
Expand Down
2 changes: 1 addition & 1 deletion spyrit/test/test_core_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from spyrit.core.meas import HadamSplit, DynamicLinear, DynamicHadamSplit
from spyrit.core.noise import NoNoise
from spyrit.core.prep import SplitPoisson
from spyrit.core.time import AffineDeformationField
from spyrit.core.warp import AffineDeformationField


def test_core_recon():
Expand Down
14 changes: 7 additions & 7 deletions spyrit/test/test_core_time.py → spyrit/test/test_core_warp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Test for module time.py
Test for module warp.py
Author: Romain Phan
"""

Expand All @@ -9,14 +9,14 @@
from test_helpers import *


def test_core_time():
def test_core_warp():

print("\n*** Testing time.py ***")
print("\n*** Testing warp.py ***")

# =========================================================================
## DeformationField
print("DeformationField")
from spyrit.core.time import DeformationField
from spyrit.core.warp import DeformationField

# constructor
print("\tconstructor... ", end="")
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_core_time():
# =========================================================================
## AffineDeformationField
print("AffineDeformationField")
from spyrit.core.time import AffineDeformationField
from spyrit.core.warp import AffineDeformationField

# constructor
print("\tconstructor... ", end="")
Expand Down Expand Up @@ -146,10 +146,10 @@ def f(t):
# print("ok")

# =========================================================================
print("All tests passed for time.py")
print("All tests passed for warp.py")
print("==============================")
return True


if __name__ == "__main__":
test_core_time()
test_core_warp()
30 changes: 13 additions & 17 deletions tutorial/README.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,25 @@ these tutorials can be found on `this page`_ of the Spyrit GitHub.
Below is a diagram of the entire image processing pipeline. Each tutorial
focuseson a specific part of the pipeline.

* :ref:`Tutorial 1 <sphx_glr_gallery_tuto_01_acquisition_operators.py>` focuses
on the measurement operators, with or without noise
* :ref:`Tutorial 1 <sphx_glr_gallery_tuto_01_acquisition_operators.py>` focuses on the measurement operators, with or without noise

* :ref:`Tutorial 2 <sphx_glr_gallery_tuto_02_pseudoinverse_linear.py>` explains
the pseudo-inverse reconstruction process from the (possibly noisy)
measurements
* :ref:`Tutorial 2 <sphx_glr_gallery_tuto_02_pseudoinverse_linear.py>` explains the pseudo-inverse reconstruction process from the (possibly noisy) measurements

* :ref:`Tutorial 3 <sphx_glr_gallery_tuto_03_pseudoinverse_cnn_linear.py>` uses
a CNN to denoise the image if necessary
* :ref:`Tutorial 3 <sphx_glr_gallery_tuto_03_pseudoinverse_cnn_linear.py>` uses a CNN to denoise the image if necessary

* :ref:`Tutorial 4 <sphx_glr_gallery_tuto_04_train_pseudoinverse_cnn_linear.py>`
is used to train the CNN introduced in Tutorial 3
* :ref:`Tutorial 4 <sphx_glr_gallery_tuto_04_train_pseudoinverse_cnn_linear.py>` is used to train the CNN introduced in Tutorial 3

* :ref:`Tutorial 5 <sphx_glr_gallery_tuto_05_acquisition_split_measurements.py>`
introduces a new type of measurement operator ('split') that simulates positive
and negative measurements
* :ref:`Tutorial 5 <sphx_glr_gallery_tuto_05_acquisition_split_measurements.py>` introduces a new type of measurement operator ('split') that simulates positive and negative measurements

* :ref:`Tutorial 6 <sphx_glr_gallery_tuto_06_dcnet_split_measurements.py>` uses
a Denoised Completion Network with a trainable image denoiser to improve the
results obtained in Tutorial 5
* :ref:`Tutorial 6 <sphx_glr_gallery_tuto_06_dcnet_split_measurements.py>` uses a Denoised Completion Network with a trainable image denoiser to improve the results obtained in Tutorial 5

* Explore :ref:`Bonus Tutorial <sphx_glr_gallery_tuto_bonus_advanced_methods_colab.py>`
if you want to go deeper into Spyrit's capabilities
* :ref:`Tutorial 7 <sphx_glr_gallery_tuto_07_drunet_split_measurements.py>` shows how to perform image reconstruction using a pretrained plug-and-play denoising network.

* :ref:`Tutorial 8 <sphx_glr_gallery_tuto_08_lpgd_split_measurements.py>` shows how to perform image reconstruction using a learnt proximal gradient descent (AVAILABLE SOON).

* :ref:`Tutorial 9 <sphx_glr_gallery_tuto_09_dynamic.py>` explains motion simulation from an image, dynamic measurements and reconstruction.

* Explore :ref:`Bonus Tutorial <sphx_glr_gallery_tuto_bonus_advanced_methods_colab.py>` if you want to go deeper into Spyrit's capabilities


.. image:: ../fig/full.png
Expand Down
61 changes: 56 additions & 5 deletions tutorial/tuto_02_pseudoinverse_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,20 +178,71 @@
# -----------------------------------------------------------------------------

###############################################################################
# There are two ways to perform the pseudo inverse reconstruction from the
# measurements :attr:`y`. The first consists of explicitly computing the
# pseudo inverse of the measurement matrix :attr:`H` and applying it to the
# measurements. The second computes a least-squares solution using :func:`torch.linalg.lstsq`
# to compute the pseudo inverse solution.
# The choice is made automatically: if the measurement operator has a pseudo-inverse
# already computed, it is used; otherwise, the least-squares solution is used.
#
# .. note::
# Generally, the second method is preferred because it is faster and more
# numerically stable. However, if you will use the pseudo inverse multiple
# times, it becomes more efficient to compute it explicitly.
#
# First way: explicit computation of the pseudo inverse
# We can use the :class:`spyrit.core.recon.PseudoInverse` class to perform the
# pseudo inverse reconstruction from the measurements :attr:`y`
# pseudo inverse reconstruction from the measurements :attr:`y`.

from spyrit.core.recon import PseudoInverse

# Pseudo-inverse reconstruction operator
recon_op = PseudoInverse()

# Reconstruction
x_rec = recon_op(y, meas_op)
x_rec1 = recon_op(y, meas_op) # equivalent to: meas_op.pinv(y)

###############################################################################
# Second way: calling pinv method from the Linear operator
# The code is very similar to the previous case, but we need to make sure the
# measurement operator has no pseudo-inverse computed. We can also specify
# regularization parameters for the least-squares solution when calling
# `recon_op`.

print(f"Pseudo-inverse computed: {hasattr(meas_op, 'H_pinv')}")
temp = meas_op.H_pinv # save the pseudo-inverse
del meas_op.H_pinv # delete the pseudo-inverse
print(f"Pseudo-inverse computed: {hasattr(meas_op, 'H_pinv')}")

# Reconstruction
x_rec2 = recon_op(y, meas_op, reg="L1", eta=1e-6)

# restore the pseudo-inverse
meas_op.H_pinv = temp

##############################################################################
# .. note::
# This choice is also offered for dynamic measurement operators which are
# explained in :ref:`Tutorial 9 <sphx_glr_gallery_tuto_09_dynamic.py>`.

# plot side by side
import matplotlib.pyplot as plt
from spyrit.misc.disp import add_colorbar

x_plot1 = x_rec1.squeeze().view(h, h).cpu().numpy()
x_plot2 = x_rec2.squeeze().view(h, h).cpu().numpy()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

im1 = ax1.imshow(x_plot1, cmap="gray")
ax1.set_title("Explicit pseudo-inverse reconstruction")
add_colorbar(im1, "right", size="20%")

im2 = ax2.imshow(x_plot2, cmap="gray")
ax2.set_title("Least-squares pseudo-inverse reconstruction")
add_colorbar(im2, "right", size="20%")

# plot
x_plot = x_rec.squeeze().view(h, h).cpu().numpy()
imagesc(x_plot, "Pseudoinverse reconstruction (no noise)", title_fontsize=20)

# %%
# PinvNet Network
Expand Down
Loading