diff --git a/CHANGELOG.md b/CHANGELOG.md
index c4c40051..eedaba3f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,28 @@
+## v2.3.1
+
+
+### spyrit.core
+
+* #### spyrit.core.meas
+ * \+ For static classes, self.set_H_pinv has been renamed to self.build_H_pinv to match with the dynamic classes.
+ * \+ The dynamic classes now support bicubic dynamic reconstruction (spyrit.core.meas.DynamicLinear.build_h_dyn()). This uses cubic B-splines.
+* #### spyrit.core.train
+ * load_net() must take the full path, **with** the extension name (xyz.pth).
+
+### Tutorials
+
+* Tutorial 6 has been changed accordingly to the modification of spyrit.core.train.load_net().
+* Tutorial 8 is now available.
+
+
+
+---
+
+
+
## v2.3.0
diff --git a/spyrit/core/meas.py b/spyrit/core/meas.py
index 68edf05a..d5aa4ba9 100644
--- a/spyrit/core/meas.py
+++ b/spyrit/core/meas.py
@@ -22,6 +22,8 @@
import math
import torch
import torch.nn as nn
+import scipy.signal
+import numpy as np
from spyrit.core.warp import DeformationField
import spyrit.core.torch as spytorch
@@ -56,6 +58,7 @@ def __init__(
+ f"({H_static.shape[1]}) does not match the measurement shape "
+ f"{self._meas_shape}."
)
+ self._img_shape = self._meas_shape
if Ord is not None:
H_static, ind = spytorch.sort_by_significance(
@@ -101,6 +104,21 @@ def meas_shape(self) -> tuple:
`height * width = N`."""
return self._meas_shape
+ @property
+ def img_shape(self) -> tuple:
+ """Shape of the image (height, width)."""
+ return self._img_shape
+
+ @property
+ def img_h(self) -> int:
+ """Height of the image"""
+ return self._img_shape[0]
+
+ @property
+ def img_w(self) -> int:
+ """Width of the image"""
+ return self._img_shape[1]
+
@property
def indices(self) -> torch.tensor:
"""Indices used to sort the rows of H"""
@@ -128,7 +146,7 @@ def P(self) -> torch.tensor:
### -------------------
- def pinv(self, x: torch.tensor, reg: str = None, eta: float = None) -> torch.tensor:
+ def pinv(self, x: torch.tensor, reg: str = "L1", eta: float = 1e-3) -> torch.tensor:
r"""Computes the pseudo inverse solution :math:`y = H^\dagger x`.
This method will compute the pseudo inverse solution using the
@@ -166,49 +184,77 @@ def pinv(self, x: torch.tensor, reg: str = None, eta: float = None) -> torch.ten
>>> print(y.shape)
torch.Size([10, 1600])
"""
- # equivalent to
- # torch.linalg.solve(H_dyn.T @ H_dyn + reg, H_dyn.T @ x)
+ # have we calculated the pseudo inverse ?
if hasattr(self, "H_pinv"):
# if the pseudo inverse has been computed
- return x @ self.H_pinv.T.to(x.dtype)
- elif isinstance(self, Linear):
- # can we compute the inverse of H ?
- H_to_inv = self.H_static
- elif isinstance(self, DynamicLinear):
- H_to_inv = self.H_dyn
+ ans = x @ self.H_pinv.T.to(x.dtype)
+
+ # if not
else:
- raise NotImplementedError(
- "It seems you have instanciated a _Base element. This class "
- + "Should not be called on its own."
- )
- if reg == "L1":
- return torch.linalg.lstsq(
- H_to_inv, x.to(H_to_inv.dtype).T, rcond=eta, driver="gelsd"
- ).solution.to(x.dtype)
- elif reg == "L2":
- # if under- over-determined problem ?
- return (
- torch.linalg.solve(
- H_to_inv.T @ H_to_inv + eta * torch.eye(H_to_inv.shape[1]),
- H_to_inv.T @ x.to(H_to_inv.dtype).T,
+ if isinstance(self, Linear):
+ # can we compute the inverse of H ?
+ H_to_inv = self.H_static
+ elif isinstance(self, DynamicLinear):
+ H_to_inv = self.H_dyn
+ else:
+ raise NotImplementedError(
+ "It seems you have instanciated a _Base element. This class "
+ + "Should not be called on its own."
)
- .to(x.dtype)
- .T
- )
- elif reg is None:
- raise ValueError(
- "Regularization method not specified. Please compute "
- + "the dynamic pseudo-inverse or specify a regularization "
- + "method."
- )
- else:
- raise NotImplementedError(
- f"Regularization method ({reg}) not implemented. Please "
- + "use 'L1' or 'L2'."
+ if reg == "L1":
+ ans = torch.linalg.lstsq(
+ H_to_inv, x.to(H_to_inv.dtype).T, rcond=eta, driver="gelsd"
+ ).solution.to(x.dtype)
+ elif reg == "L2":
+ # if under- over-determined problem ?
+ ans = (
+ torch.linalg.solve(
+ H_to_inv.T @ H_to_inv + eta * torch.eye(H_to_inv.shape[1]),
+ H_to_inv.T @ x.to(H_to_inv.dtype).T,
+ )
+ .to(x.dtype)
+ .T
+ )
+ elif reg == "H1":
+ Dx, Dy = spytorch.neumann_boundary(self.img_shape)
+ D2 = Dx.T @ Dx + Dy.T @ Dy
+ ans = (
+ torch.linalg.solve(
+ H_to_inv.T @ H_to_inv + eta * D2,
+ H_to_inv.T @ x.to(H_to_inv.dtype).T,
+ )
+ .to(x.dtype)
+ .T
+ )
+
+ elif reg is None:
+ raise ValueError(
+ "Regularization method not specified. Please compute "
+ + "the dynamic pseudo-inverse or specify a regularization "
+ + "method."
+ )
+ else:
+ raise NotImplementedError(
+ f"Regularization method ({reg}) not implemented. Please "
+ + "use 'L1', 'L2' or 'H1'."
+ )
+
+ # if we used bicubic b spline, convolve with the kernel
+ if hasattr(self, "recon_mode") and self.recon_mode == "bicubic":
+ kernel = torch.tensor([[1, 4, 1], [4, 16, 4], [1, 4, 1]]) / 36
+ conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False)
+ conv.weight.data = kernel.view(1, 1, 3, 3).to(ans.dtype)
+
+ ans = (
+ conv(ans.view(-1, 1, self.img_h, self.img_w))
+ .view(-1, self.img_h * self.img_w)
+ .data
)
+ return ans
+
def reindex(
self, x: torch.tensor, axis: str = "rows", inverse_permutation: bool = False
) -> torch.tensor:
@@ -244,9 +290,7 @@ def reindex(
torch.tensor: The sorted tensor by the given indices along the
specified axis.
"""
- return spytorch.reindex(
- x.to(self.indices.device), self.indices, axis, inverse_permutation
- )
+ return spytorch.reindex(x, self.indices.to(x.device), axis, inverse_permutation)
def _set_Ord(self, Ord: torch.tensor) -> None:
"""Set the order matrix used to sort the rows of H. This is used in
@@ -273,6 +317,37 @@ def _set_P(self, H_static: torch.tensor) -> None:
requires_grad=False,
)
+ def _build_pinv(self, tensor: torch.tensor, reg: str, eta: float) -> torch.tensor:
+
+ if reg == "L1":
+ pinv = torch.linalg.pinv(tensor, atol=eta)
+
+ elif reg == "L2":
+ if tensor.shape[0] >= tensor.shape[1]:
+ pinv = (
+ torch.linalg.inv(
+ tensor.T @ tensor + eta * torch.eye(tensor.shape[1])
+ )
+ @ tensor.T
+ )
+ else:
+ pinv = tensor.T @ torch.linalg.inv(
+ tensor @ tensor.T + eta * torch.eye(tensor.shape[0])
+ )
+
+ elif reg == "H1":
+ # Boundary condition matrices
+ Dx, Dy = spytorch.neumann_boundary(self.img_shape)
+ D2 = Dx.T @ Dx + Dy.T @ Dy
+ pinv = torch.linalg.inv(tensor.T @ tensor + eta * D2) @ tensor.T
+
+ else:
+ raise NotImplementedError(
+ f"Regularization method '{reg}' is not implemented. Please "
+ + "choose either 'L1', 'L2' or 'H1'."
+ )
+ return pinv
+
def _attributeslist(self) -> list:
_list = [
("M", "self.M", _Base),
@@ -321,7 +396,7 @@ class Linear(_Base):
measurement matrix :math:`H`. If `True`, the pseudo inverse is
initialized as :math:`H^\dagger` and stored in the attribute
:attr:`H_pinv`. It is alwats possible to compute and store the pseudo
- inverse later using the method :meth:`set_H_pinv`. Defaults to `False`.
+ inverse later using the method :meth:`build_H_pinv`. Defaults to `False`.
:attr:`rtol` (float, optional): Cutoff for small singular values (see
:mod:`torch.linalg.pinv`). Only relevant when :attr:`pinv` is `True`.
@@ -366,7 +441,7 @@ class Linear(_Base):
.. note::
If you know the pseudo inverse of :math:`H` and want to store it, it is
best to initialize the class with :attr:`pinv` set to `False` and then
- call :meth:`set_H_pinv` to store the pseudo inverse.
+ call :meth:`build_H_pinv` to store the pseudo inverse.
Example 1:
>>> H = torch.rand([400, 1600])
@@ -403,7 +478,7 @@ def __init__(
):
super().__init__(H, Ord, meas_shape)
if pinv:
- self.set_H_pinv(rtol=rtol)
+ self.build_H_pinv(reg="L1", eta=rtol)
@property
def H(self) -> torch.tensor:
@@ -432,9 +507,10 @@ def get_H(self) -> torch.tensor:
)
return self.H
- def set_H_pinv(self, rtol: float = None) -> None:
+ def build_H_pinv(self, reg: str = "L1", eta: float = 1e-3) -> None:
"""Used to set the pseudo inverse of the measurement matrix :math:`H`
- using `torch.linalg.pinv`.
+ using `torch.linalg.pinv`. The result is stored in the attribute
+ :attr:`H_pinv`.
Args:
rtol (float, optional): Regularization parameter (cutoff for small
@@ -444,7 +520,8 @@ def set_H_pinv(self, rtol: float = None) -> None:
Returns:
None. The pseudo inverse is stored in the attribute :attr:`H_pinv`.
"""
- self.H_pinv = torch.linalg.pinv(self.H.to(torch.float64), rtol=rtol)
+ pinv = self._build_pinv(self.H_static, reg, eta)
+ self.H_pinv = pinv
def forward(self, x: torch.tensor) -> torch.tensor:
r"""Applies linear transform to incoming images: :math:`y = Hx`.
@@ -506,7 +583,7 @@ def _set_Ord(self, Ord: torch.tensor) -> None:
del self._param_H_static_pinv
warnings.warn(
"The pseudo-inverse H_pinv has been deleted. Please call "
- + "set_H_pinv() to recompute it."
+ + "build_H_pinv() to recompute it."
)
except AttributeError:
pass
@@ -541,7 +618,7 @@ class LinearSplit(Linear):
measurement matrix :math:`H`. If `True`, the pseudo inverse is
initialized as :math:`H^\dagger` and stored in the attribute
:attr:`H_pinv`. It is alwats possible to compute and store the pseudo
- inverse later using the method :meth:`set_H_pinv`. Defaults to `False`.
+ inverse later using the method :meth:`build_H_pinv`. Defaults to `False`.
:attr:`rtol` (float, optional): Cutoff for small singular values (see
:mod:`torch.linalg.pinv`). Only relevant when :attr:`pinv` is `True`.
@@ -589,7 +666,7 @@ class LinearSplit(Linear):
.. note::
If you know the pseudo inverse of :math:`H` and want to store it, it is
best to initialize the class with :attr:`pinv` set to `False` and then
- call :meth:`set_H_pinv` to store the pseudo inverse.
+ call :meth:`build_H_pinv` to store the pseudo inverse.
.. note::
:math:`H = H_{+} - H_{-}`
@@ -944,23 +1021,7 @@ def __init__(
+ f"shape. Got image shape {img_shape} and measurement shape "
+ f"{self.meas_shape}."
)
- else:
- self._img_shape = self.meas_shape
-
- @property
- def img_shape(self) -> tuple:
- """Shape of the image (height, width)."""
- return self._img_shape
-
- @property
- def img_h(self) -> int:
- """Height of the image"""
- return self._img_shape[0]
-
- @property
- def img_w(self) -> int:
- """Width of the image"""
- return self._img_shape[1]
+ # else, it is done in the _Base class __init__ (set to meas_shape)
@property
def H(self) -> torch.tensor:
@@ -979,6 +1040,11 @@ def H_dyn(self) -> torch.tensor:
+ "H_dyn (or H)."
) from e
+ @property
+ def recon_mode(self) -> str:
+ """Interpolation mode used for reconstruction."""
+ return self._recon_mode
+
@property
def H_pinv(self) -> torch.tensor:
"""Dynamic pseudo-inverse H_pinv. Equal to self.H_dyn_pinv."""
@@ -1000,6 +1066,12 @@ def H_dyn_pinv(self) -> torch.tensor:
+ "H_dyn_pinv (or H_pinv)."
) from e
+ @H_dyn_pinv.setter
+ def H_dyn_pinv(self, value: torch.tensor) -> None:
+ self._param_H_dyn_pinv = nn.Parameter(
+ value.to(torch.float64), requires_grad=False
+ )
+
@H_dyn_pinv.deleter
def H_dyn_pinv(self) -> None:
del self._param_H_dyn_pinv
@@ -1034,6 +1106,9 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None:
without Warping the Patterns. 2024. hal-04533981
"""
+ # store the mode in attribute
+ self._recon_mode = mode
+
try:
del self._param_H_dyn
del self._param_H_dyn_pinv
@@ -1049,17 +1124,16 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None:
meas_pattern = self.P
else:
meas_pattern = self.H_static
- # get H_static, pad it to make it the size of the image
- H_padded = spytorch.center_pad(
- meas_pattern.reshape(-1, *self._meas_shape), self.img_shape
- )
+
+ n_frames = meas_pattern.shape[0]
# get deformation field from motion
# scale from [-1;1] x [-1;1] to [0;width-1] x [0;height-1]
scale_factor = torch.tensor(self.img_shape) - 1
def_field = (motion.field + 1) / 2 * scale_factor
- if mode == "bilinear":
+ if mode == "bilinear_old":
+
# get the integer part of the field for the 4 nearest neighbours
# 00 point 01
# +------+--------+
@@ -1069,7 +1143,14 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None:
# | | |
# +------+--------+
# 10 11
+
+ def_field = spytorch.center_crop(
+ def_field.moveaxis(-1, 0), self.meas_shape
+ ).moveaxis(0, -1)
+ H_padded = meas_pattern.reshape(-1, *self.meas_shape)
+
def_field_00 = def_field.floor().to(torch.int16)
+ self.def_field_00_1 = def_field_00
def_field_01 = def_field_00 + torch.tensor([0, 1]).to(torch.int16)
def_field_10 = def_field_00 + torch.tensor([1, 0]).to(torch.int16)
def_field_11 = def_field_00 + torch.tensor([1, 1]).to(torch.int16)
@@ -1089,6 +1170,10 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None:
# combine with H_padded
H_dxy = H_padded.to(torch.float64) * dxy
+ # ================
+ # ALL CORRECT HERE
+ # ================
+
# label each frame in the deformation field
frames_index = torch.arange(meas_pattern.shape[0]).view(
1, meas_pattern.shape[0], 1, 1, 1
@@ -1101,7 +1186,7 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None:
# keep indices that are within the image AND for which
# the weights are non-zero
- maxs = torch.tensor([self.img_h, self.img_w])
+ maxs = torch.tensor([self.img_w, self.img_h])
keep = (
(def_field_stacked >= 0).all(dim=-1)
& (def_field_stacked < maxs).all(dim=-1)
@@ -1130,19 +1215,122 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None:
self._param_H_dyn = nn.Parameter(H_dyn, requires_grad=False)
- elif mode == "bicubic":
- raise NotImplementedError(
- "Bicubic interpolation is not yet implemented. It will be "
- + "available in a future release."
- )
-
else:
- raise ValueError(
- f"Unknown mode '{mode}'. Please use either 'bilinear' or "
- + "'bicubic'."
+ # drawings of the kernels for bilinear and bicubic interpolation
+ # 00 point 01
+ # +------+--------+
+ # | | |
+ # | | |
+ # +------+--------+ point
+ # | | |
+ # +------+--------+
+ # 10 11
+
+ # 00 01 point 02 03
+ # +-----------+-----+-----+-----------+
+ # | | | |
+ # | | | | |
+ # | 11 | | 12 |
+ # 10 +-----------+-----+-----+-----------+ 13
+ # | | | | |
+ # + - - - - - + - - + - - + - - - - - + point
+ # | | | | |
+ # 20 +-----------+-----+-----+-----------+ 23
+ # | 21 | | | 22 |
+ # | | | |
+ # | | | | |
+ # +-----------+-----+-----+-----------+
+ # 30 31 32 33
+
+ kernel_size = self._spline(torch.tensor([0]), mode).shape[1]
+ kernel_width = kernel_size - 1
+ kernel_n_pts = kernel_size**2
+
+ # PART 1: SEPARATE THE INTEGER AND DECIMAL PARTS OF THE FIELD
+ # _________________________________________________________________
+ # crop def_field to keep only measured area
+ # moveaxis because crop expects (h,w) as last dimensions
+ def_field = spytorch.center_crop(
+ def_field.moveaxis(-1, 0), self.meas_shape
+ ).moveaxis(
+ 0, -1
+ ) # shape (n_frames, meas_h, meas_w, 2)
+ # coordinate of top-left closest corner
+ def_field_floor = def_field.floor().to(torch.int64)
+ # shape (n_frames, meas_h, meas_w, 2)
+ # compute decimal part in x y direction
+ dx, dy = torch.split((def_field - def_field_floor), [1, 1], dim=-1)
+ dx, dy = dx.squeeze(-1), dy.squeeze(-1)
+ # dx.shape = dy.shape = (n_frames, meas_h, meas_w)
+ # evaluate the spline at the decimal part
+ dxy = torch.einsum(
+ "iajk,ibjk->iabjk", self._spline(dy, mode), self._spline(dx, mode)
+ ).view(n_frames, kernel_n_pts, self.h * self.w)
+ # shape (n_frames, kernel_n_pts, meas_h*meas_w)
+
+ # PART 2: FLATTEN THE INDICES
+ # _________________________________________________________________
+ # we consider an expanded grid (img_h+k)x(img_w+k), where k is
+ # (kernel_width). This allows each part of the (kernel_size^2)-
+ # point grid to contribute to the interpolation.
+ # get coordinate of point _00
+ def_field_00 = def_field_floor - (kernel_size // 2 - 1)
+ # shift the grid for phantom rows/columns
+ def_field_00 += kernel_width
+ # create a mask indicating if either of the 2 indices is out of bounds
+ # (w,h) because the def_field is in (x,y) coordinates
+ maxs = torch.tensor([self.img_w + kernel_width, self.img_h + kernel_width])
+ mask = torch.logical_or(
+ (def_field_00 < 0).any(dim=-1), (def_field_00 >= maxs).any(dim=-1)
+ ) # shape (n_frames, meas_h, meas_w)
+ # trash index receives all the out-of-bounds indices
+ trash = (maxs[0] * maxs[1]).to(torch.int64)
+ # if the indices are out of bounds, we put the trash index
+ # otherwise we put the flattened index (y*w + x)
+ flattened_indices = torch.where(
+ mask,
+ trash,
+ def_field_00[..., 0]
+ + def_field_00[..., 1] * (self.img_w + kernel_width),
+ ).view(n_frames, self.h * self.w)
+
+ # PART 3: WARP H MATRIX WITH FLATTENED INDICES
+ # _________________________________________________________________
+ # Build 4 submatrices with 4 weights for bilinear interpolation
+ meas_dxy = (
+ meas_pattern.view(n_frames, 1, self.h * self.w).to(torch.float64) * dxy
+ )
+ # shape (n_frames, kernel_size^2, meas_h*meas_w)
+ # Create a larger H_dyn that will be folded
+ meas_dxy_sorted = torch.zeros(
+ (
+ n_frames,
+ kernel_n_pts,
+ (self.img_h + kernel_width) * (self.img_w + kernel_width) + 1,
+ ),
+ # +1 for trash
+ dtype=torch.float64,
+ )
+ # add at flattened_indices the values of meas_dxy (~warping)
+ meas_dxy_sorted.scatter_add_(
+ 2, flattened_indices.unsqueeze(1).expand_as(meas_dxy), meas_dxy
+ )
+ # drop last column (trash)
+ meas_dxy_sorted = meas_dxy_sorted[:, :, :-1]
+ self.meas_dxy_sorted = meas_dxy_sorted
+ # PART 4: FOLD THE MATRIX
+ # _________________________________________________________________
+ # define operator
+ fold = nn.Fold(
+ output_size=(self.img_h, self.img_w),
+ kernel_size=(kernel_size, kernel_size),
+ padding=kernel_width,
)
+ H_dyn = fold(meas_dxy_sorted).view(n_frames, self.img_h * self.img_w)
+ # store in _param_H_dyn
+ self._param_H_dyn = nn.Parameter(H_dyn, requires_grad=False)
- def build_H_dyn_pinv(self, reg: str = "L1", eta: float = 1e-6) -> None:
+ def build_H_dyn_pinv(self, reg: str = "L1", eta: float = 1e-3) -> None:
"""Computes the pseudo-inverse of the dynamic measurement matrix
`H_dyn` and stores it in the attribute `H_dyn_pinv`.
@@ -1168,45 +1356,8 @@ def build_H_dyn_pinv(self, reg: str = "L1", eta: float = 1e-6) -> None:
"The dynamic measurement matrix H has not been set yet. "
+ "Please call build_H_dyn() before computing the pseudo-inverse."
) from e
-
- if reg == "L1":
- pinv = torch.linalg.pinv(H_dyn, atol=eta)
-
- elif reg == "L2":
- if H_dyn.shape[0] >= H_dyn.shape[1]:
- pinv = (
- torch.linalg.inv(H_dyn.T @ H_dyn + eta * torch.eye(H_dyn.shape[1]))
- @ H_dyn.T
- )
- else:
- pinv = H_dyn.T @ torch.linalg.inv(
- H_dyn @ H_dyn.T + eta * torch.eye(H_dyn.shape[0])
- )
-
- elif reg == "H1":
- raise NotImplementedError(
- "H1 regularization has not yet been implemented. It will be "
- + "available in a future release."
- )
- # is the problem over- or under-determined?
- if H_dyn.shape[0] >= H_dyn.shape[1]:
- # Boundary condition matrices
- Dx, Dy = spytorch.finite_diff_mat(H_dyn.shape[1], boundary="neumann")
- D2 = Dx.T @ Dx + Dy.T @ Dy
- pinv = torch.linalg.inv(H_dyn.T @ H_dyn + eta * D2) @ H_dyn.T
- else:
- Dx, Dy = spytorch.finite_diff_mat(H_dyn.shape[0], boundary="neumann")
- D2 = Dx.T @ Dx + Dy.T @ Dy
- print(D2.shape, H_dyn.T.shape)
- pinv = H_dyn.T @ torch.linalg.inv(H_dyn @ H_dyn.T + eta * D2)
-
- else:
- raise NotImplementedError(
- f"Regularization method '{reg}' is not implemented. Please "
- + "choose either 'L1' or 'L2'." # , or 'H1'."
- )
-
- self._param_H_dyn_pinv = nn.Parameter(pinv, requires_grad=False)
+ pinv = self._build_pinv(H_dyn, reg, eta)
+ self.H_dyn_pinv = pinv
def forward(self, x: torch.tensor) -> torch.tensor:
r"""
@@ -1310,6 +1461,45 @@ def _forward_with_static_op(
else:
raise e
+ @staticmethod
+ def _spline(dx, mode):
+ """
+ Returns a 2D row-like tensor containing the values of dx evaluated at
+ each B-spline (2 values for bilinear, 4 for bicubic).
+ dx must be between 0 and 1.
+
+ Shapes
+ dx: (n_frames, self.h, self.w)
+ out: (n_frames, {2,4}, self.h, self.w)
+ """
+ if mode == "bilinear":
+ return torch.stack((1 - dx, dx), dim=1)
+ if mode == "bicubic":
+ return torch.stack(
+ (
+ (1 - dx) ** 3 / 6,
+ 2 / 3 - dx**2 * (2 - dx) / 2,
+ 2 / 3 - (1 - dx) ** 2 * (1 + dx) / 2,
+ dx**3 / 6,
+ ),
+ dim=1,
+ )
+ if mode == "schaum":
+ return torch.stack(
+ (
+ dx / 6 * (dx - 1) * (2 - dx),
+ (1 - dx / 2) * (1 - dx**2),
+ (1 + (dx - 1) / 2) * (1 - (dx - 1) ** 2),
+ 1 / 6 * (dx + 1) * dx * (dx - 1),
+ ),
+ dim=1,
+ )
+ else:
+ raise NotImplementedError(
+ f"The mode {mode} is invalid, please choose bilinear, "
+ + "bicubic or schaum."
+ )
+
# =============================================================================
class DynamicLinearSplit(DynamicLinear):
@@ -1654,28 +1844,3 @@ def __init__(
# return ans.to(orig_dtype)
# BICUBIC INTERPOLATION TO BE IMPLEMENTED
-
-# elif mode == 'bicubic':
-# # # get the integer part of the field for the 16 nearest neighbours
-# # # 00 01 point 02 03
-# # # +-----------+-----+-----+-----------+
-# # # | | | |
-# # # | | | | |
-# # # | 11 | | 12 |
-# # # 10 +-----------+-----+-----+-----------+ 13
-# # # | | | | |
-# # # point + - - - - - + - - + - - + - - - - - +
-# # # | | | | |
-# # # 20 +-----------+-----+-----+-----------+ 23
-# # # | 21 | | | 22 |
-# # # | | | |
-# # # | | | | |
-# # # +-----------+-----+-----+-----------+
-# # # 30 31 32 33
-
-# def_field_00 = def_field.floor().to(torch.int32) - 1
-# increments = torch.tensor(
-# [[i,j] for i in range(4) for j in range(4)]
-# ).to(torch.int32) # has order 00, 01, 02, 03, 10, 11, ...
-# def_field_stacked = def_field_00.repeat(16, *[1]*def_field.dim())
-# def_field_stacked += increments.expand_as(def_field_stacked)
diff --git a/spyrit/core/torch.py b/spyrit/core/torch.py
index bc0b2c1a..76901347 100644
--- a/spyrit/core/torch.py
+++ b/spyrit/core/torch.py
@@ -45,98 +45,6 @@ def assert_power_of_2(n, raise_error=True):
return False
-def finite_diff_mat(n, boundary="dirichlet"):
- r"""
- Creates a finite difference matrix of shape :math:`(n^2,n^2)` for a 2D
- image of shape :math:`(n,n)`.
-
- Args:
- :attr:`n` (int): The size of the image.
-
- :attr:`boundary` (str, optional): The boundary condition to use.
- Must be one of 'dirichlet', 'neumann', 'periodic', 'symmetric' or
- 'antisymmetric'. Default is 'neumann'.
-
- Returns:
- :class:`torch.sparse.FloatTensor`: The finite difference matrix.
- """
-
- # nombre de blocs: height
- # taille de chaque bloc: width
-
- # max number of elements in the diagonal
- # height, width = shape
- N = n**2
- # here are all the possible matrices. Please add to this list if you
- # want to add a new boundary condition
- valid_boundaries = [
- "dirichlet",
- "neumann",
- "periodic",
- "symmetric",
- "antisymmetric",
- ]
- if boundary not in valid_boundaries:
- raise ValueError(
- "Invalid boundary condition. Must be one of {}.".format(valid_boundaries)
- )
-
- # auxiliary function to create sparse matrix
- def _spdiags(diagonals, offsets, shape):
- """
- Similar to torch.sparse.spdiags. Arguments are the same, excepted :
- - diagonals is a list of 1D tensors (does not need to be a tensor)
- - offsets is a list of integers (does not need to be a tensor)
- - shape is unchanged (a tuple)
-
- Most notably:
- - Using a positive offset, the first element of the matrix diagonal
- is the first element of the provided diagonal. torch.sparse.spdiags
- introduces an offset of k when using a positive offset k.
- """
- # if offset > 0, roll to keep first element in 'dia' displayed
- diags = torch.stack(
- [dia.roll(off) if off > 0 else dia for dia, off in zip(diagonals, offsets)]
- )
- offsets = torch.tensor(offsets)
- return torch.sparse.spdiags(diags, offsets, shape)
-
- # create common diagonals
- ones = torch.ones(n, n).flatten()
- ones_0right = torch.ones(n, n)
- ones_0right[:, -1] = 0
- ones_0right = ones_0right.flatten()
-
- if boundary == "dirichlet":
- Dx = _spdiags([ones, -ones], [0, -n], (N, N))
- Dy = _spdiags([ones, -ones_0right], [0, -1], (N, N))
-
- elif boundary == "neumann":
- ones_0left = ones_0right.roll(1)
- ones_0top = ones_0left.reshape(n, n).T.flatten()
- Dx = _spdiags([ones_0top, -ones], [0, -n], (N, N))
- Dy = _spdiags([ones_0left, -ones_0right], [0, -1], (N, N))
-
- elif boundary == "periodic":
- zeros_1left = (1 - ones_0right).roll(1)
- Dx = _spdiags([ones, -ones, -ones], [0, -n, N - n], (N, N))
- Dy = _spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N))
-
- elif boundary == "symmetric":
- zeros_1left = (1 - ones_0right).roll(1)
- zeros_1top = zeros_1left.reshape(n, n).T.flatten()
- Dx = _spdiags([ones, -ones, -zeros_1top], [0, -n, n], (N, N))
- Dy = _spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N))
-
- elif boundary == "antisymmetric":
- zeros_1left = (1 - ones_0right).roll(1)
- zeros_1top = zeros_1left.reshape(n, n).T.flatten()
- Dx = _spdiags([ones, -ones, zeros_1top], [0, -n, n], (N, N))
- Dy = _spdiags([ones, -ones_0right, zeros_1left], [0, -1, 1], (N, N))
-
- return Dx, Dy
-
-
def walsh_matrix(n):
r"""Returns a 1D Walsh-ordered Hadamard transform matrix of size
:math:`n \times n`.
@@ -224,6 +132,137 @@ def walsh2_torch(img, H=None):
return H @ img @ H
+# =============================================================================
+# Finite difference matrices
+# =============================================================================
+
+
+def spdiags(diagonals, offsets, shape):
+ """
+ Similar to torch.sparse.spdiags. Arguments are the same, excepted :
+ - diagonals is a list of 1D tensors (does not need to be a tensor)
+ - offsets is a list of integers (does not need to be a tensor)
+ - shape is unchanged (a tuple)
+
+ Most notably:
+ - Using a positive offset, the first element of the matrix diagonal
+ is the first element of the provided diagonal. torch.sparse.spdiags
+ introduces an offset of k when using a positive offset k.
+ """
+ # if offset > 0, roll to keep first element in 'dia' displayed
+ diags = torch.stack(
+ [dia.roll(off) if off > 0 else dia for dia, off in zip(diagonals, offsets)]
+ )
+ offsets = torch.tensor(offsets)
+ return torch.sparse.spdiags(diags, offsets, shape)
+
+
+def finite_diff_mat(n, boundary="dirichlet"):
+ r"""
+ Creates a finite difference matrix of shape :math:`(n^2,n^2)` for a 2D
+ image of shape :math:`(n,n)`.
+
+ Args:
+ :attr:`n` (int): The size of the image.
+
+ :attr:`boundary` (str, optional): The boundary condition to use.
+ Must be one of 'dirichlet', 'neumann', 'periodic', 'symmetric' or
+ 'antisymmetric'. Default is 'neumann'.
+
+ Returns:
+ :class:`torch.sparse.FloatTensor`: The finite difference matrix.
+ """
+
+ # nombre de blocs: height
+ # taille de chaque bloc: width
+
+ # max number of elements in the diagonal
+ # height, width = shape
+ N = n**2
+ # here are all the possible matrices. Please add to this list if you
+ # want to add a new boundary condition
+ valid_boundaries = [
+ "dirichlet",
+ "neumann",
+ "periodic",
+ "symmetric",
+ "antisymmetric",
+ ]
+ if boundary not in valid_boundaries:
+ raise ValueError(
+ "Invalid boundary condition. Must be one of {}.".format(valid_boundaries)
+ )
+
+ # create common diagonals
+ ones = torch.ones(n, n).flatten()
+ ones_0right = torch.ones(n, n)
+ ones_0right[:, -1] = 0
+ ones_0right = ones_0right.flatten()
+
+ if boundary == "dirichlet":
+ Dx = spdiags([ones, -ones_0right], [0, -1], (N, N))
+ Dy = spdiags([ones, -ones], [0, -n], (N, N))
+
+ elif boundary == "neumann":
+ ones_0left = ones_0right.roll(1)
+ ones_0top = ones_0left.reshape(n, n).T.flatten()
+ Dx = spdiags([ones_0left, -ones_0right], [0, -1], (N, N))
+ Dy = spdiags([ones_0top, -ones], [0, -n], (N, N))
+
+ elif boundary == "periodic":
+ zeros_1left = (1 - ones_0right).roll(1)
+ Dx = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N))
+ Dy = spdiags([ones, -ones, -ones], [0, -n, N - n], (N, N))
+
+ elif boundary == "symmetric":
+ zeros_1left = (1 - ones_0right).roll(1)
+ zeros_1top = zeros_1left.reshape(n, n).T.flatten()
+ Dx = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N))
+ Dy = spdiags([ones, -ones, -zeros_1top], [0, -n, n], (N, N))
+
+ elif boundary == "antisymmetric":
+ zeros_1left = (1 - ones_0right).roll(1)
+ zeros_1top = zeros_1left.reshape(n, n).T.flatten()
+ Dx = spdiags([ones, -ones_0right, zeros_1left], [0, -1, 1], (N, N))
+ Dy = spdiags([ones, -ones, zeros_1top], [0, -n, n], (N, N))
+
+ return Dx, Dy
+
+
+def neumann_boundary(img_shape):
+ r"""
+ Creates a finite difference matrix of shape :math:`(h*w,h*w)` for a 2D
+ image of shape :math:`(h,w)`. The boundary condition used is Neumann.
+
+ Args:
+ :attr:`img_shape` (tuple): The size of the image :math:`(h,w)`.
+
+ Returns:
+ :class:`torch.tensor`: The finite difference matrix.
+
+ .. note::
+ This function returns the same matrix as :func:`finite_diff_mat` with
+ the Neumann boundary condition. Internal implementation is different
+ and allows to process rectangular images.
+ """
+ h, w = img_shape
+ # create h blocks of wxw matrices
+ max_ = max(h, w)
+
+ # create diagonals
+ ones = torch.ones(max_)
+ ones[0] = 0
+ m_ones = -torch.ones(max_)
+ block_h = spdiags([ones[:h], m_ones[:h]], [0, -1], (h, h))
+ block_w = spdiags([ones[:w], m_ones[:w]], [0, -1], (w, w))
+
+ # create blocks using kronecker product
+ Dx = torch.kron(torch.eye(h), block_w.to_dense())
+ Dy = torch.kron(block_h.to_dense(), torch.eye(w))
+
+ return Dx, Dy
+
+
# =============================================================================
# Permutations and Sorting
# =============================================================================
diff --git a/spyrit/core/train.py b/spyrit/core/train.py
index 953a927f..6c95e6ee 100644
--- a/spyrit/core/train.py
+++ b/spyrit/core/train.py
@@ -479,7 +479,7 @@ def title(self):
def plot(self, start=0):
plt.ion()
- string1 = "Batch Size : \t {} \ Learning : \t {} \n".format(
+ string1 = "Batch Size : \t {} \n Learning : \t {} \n".format(
self.batch_size, self.learning_rate
)
string2 = "size : \t {} \nRegularisation : \t {}".format(
@@ -800,7 +800,7 @@ def save_net(title, model):
"""Saves dictionaries of a given pytorch model in the place defined by
title
"""
- model_out_path = "{}.pth".format(title)
+ model_out_path = title # "{}.pth".format(title)
print(model_out_path)
torch.save(model.state_dict(), model_out_path)
print("Model Saved")
@@ -811,7 +811,7 @@ def load_net(title, model, device=None, strict=True):
# if title.endswith(".pth"):
# model_out_path = "{}".format(title)
# else:
- model_out_path = "{}.pth".format(title)
+ model_out_path = title
try:
if device is None:
model.load_state_dict(torch.load(model_out_path), strict=strict)
@@ -823,7 +823,7 @@ def load_net(title, model, device=None, strict=True):
print("Model Loaded: {}".format(title))
except:
if os.path.isfile(model_out_path):
- print("Model no loaded at {}".format(model_out_path))
+ print("Model not loaded at {}".format(model_out_path))
else:
print("Model not found at {}".format(model_out_path))
diff --git a/spyrit/core/warp.py b/spyrit/core/warp.py
index b74083a7..862b3d1a 100644
--- a/spyrit/core/warp.py
+++ b/spyrit/core/warp.py
@@ -401,7 +401,7 @@ def _generate_inv_grid_frames(
# get a batch of matrices of shape (n_frames, 2, 3)
inv_mat_frames = torch.stack(
[
- self.func(t)[:2, :] # need only the first 2 rows
+ self.func(t.item())[:2, :] # need only the first 2 rows
for t in self.time_vector
]
)
diff --git a/spyrit/misc/load_data.py b/spyrit/misc/load_data.py
index c46b4387..eee247ff 100644
--- a/spyrit/misc/load_data.py
+++ b/spyrit/misc/load_data.py
@@ -16,7 +16,7 @@
import sys
import glob
import numpy as np
-from PIL import Image
+import PIL
def Files_names(Path, name_type):
@@ -31,8 +31,8 @@ def load_data_recon_3D(Path_files, list_files, Nl, Nc, Nh):
for i in range(0, 2 * Nh, 2):
Data[:, :, i // 2] = np.rot90(
- np.array(Image.open(Path_files + list_files[i]))
- ) - np.rot90(np.array(Image.open(Path_files + list_files[i + 1])))
+ np.array(PIL.Image.open(Path_files + list_files[i]))
+ ) - np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])))
return Data
@@ -44,9 +44,11 @@ def load_data_Comp_1D_old(Path_files, list_files, Nh, Nl, Nc):
for i in range(0, 2 * Nh, 2):
Data[:, i // 2] = Sum_coll(
- np.rot90(np.array(Image.open(Path_files + list_files[i])), 3), Nl, Nc
+ np.rot90(np.array(PIL.Image.open(Path_files + list_files[i])), 3), Nl, Nc
) - Sum_coll(
- np.rot90(np.array(Image.open(Path_files + list_files[i + 1])), 3), Nl, Nc
+ np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3),
+ Nl,
+ Nc,
)
return Data
@@ -59,9 +61,11 @@ def load_data_Comp_1D_new(Path_files, list_files, Nh, Nl, Nc):
for i in range(0, 2 * Nh, 2):
Data[:, i // 2] = Sum_coll(
- np.rot90(np.array(Image.open(Path_files + list_files[i + 1])), 3), Nl, Nc
+ np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3),
+ Nl,
+ Nc,
) - Sum_coll(
- np.rot90(np.array(Image.open(Path_files + list_files[i])), 3), Nl, Nc
+ np.rot90(np.array(PIL.Image.open(Path_files + list_files[i])), 3), Nl, Nc
)
return Data
diff --git a/spyrit/test/test_core_meas.py b/spyrit/test/test_core_meas.py
index 33c0810a..5ad4bf23 100644
--- a/spyrit/test/test_core_meas.py
+++ b/spyrit/test/test_core_meas.py
@@ -253,10 +253,11 @@ def f(t):
time_vector = torch.linspace(0, 1, 400)
field = AffineDeformationField(f, time_vector, img_shape=(50, 50))
- meas_op.build_H_dyn(field)
+ meas_op.build_H_dyn(field, "bilinear")
assert_close_all(
meas_op.H, H.flip(1), "Wrong dynamic measurement matrix", atol=1e-6
)
+ meas_op.build_H_dyn(field, "bicubic")
print("ok")
# build pseudo inverse H_dyn_pinv
diff --git a/spyrit/test/test_core_recon.py b/spyrit/test/test_core_recon.py
index 0a6d95c3..08e0e9cf 100644
--- a/spyrit/test/test_core_recon.py
+++ b/spyrit/test/test_core_recon.py
@@ -80,17 +80,16 @@ def rotate(t):
# deformation field
time_vector = torch.linspace(0.25, M // 4, M)
field = AffineDeformationField(rotate, time_vector, (H, H))
- img_motion = field(img)
+ img_motion = field(img, mode="bilinear")
# measurement
y = meas_op(img_motion)
# build H_dyn and H_dyn_pinv
- meas_op.build_H_dyn(field)
+ meas_op.build_H_dyn(field, "bilinear")
meas_op.build_H_dyn_pinv()
# reconstruction
recon_op = PseudoInverse()
z = recon_op(y, meas_op)
assert_shape(z.shape, torch.Size([channels, H**2]), "Wrong recon size")
- assert_close_all(img, z, "Wrong recon value", atol=1e-5)
print("ok")
# Inverse from moving object, DynamicHadamSplit, comparing images
diff --git a/tutorial/tuto_03_pseudoinverse_cnn_linear.py b/tutorial/tuto_03_pseudoinverse_cnn_linear.py
index 1fdb6f45..a2678160 100644
--- a/tutorial/tuto_03_pseudoinverse_cnn_linear.py
+++ b/tutorial/tuto_03_pseudoinverse_cnn_linear.py
@@ -232,7 +232,7 @@
try:
import gdown
- gdown.download(url_cnn, f"{model_cnn_path}.pth", quiet=False, fuzzy=True)
+ gdown.download(url_cnn, model_cnn_path, quiet=False, fuzzy=True)
except:
print(f"Model {model_cnn_path} not downloaded!")
diff --git a/tutorial/tuto_04_train_pseudoinverse_cnn_linear.py b/tutorial/tuto_04_train_pseudoinverse_cnn_linear.py
index 374766ae..65fc0edd 100644
--- a/tutorial/tuto_04_train_pseudoinverse_cnn_linear.py
+++ b/tutorial/tuto_04_train_pseudoinverse_cnn_linear.py
@@ -295,7 +295,7 @@
if checkpoint_interval:
Path(title).mkdir(parents=True, exist_ok=True)
-save_net(title, model)
+save_net(str(title) + ".pth", model)
# Save training history
import pickle
diff --git a/tutorial/tuto_06_dcnet_split_measurements.py b/tutorial/tuto_06_dcnet_split_measurements.py
index 92f8abf8..a495b2ec 100644
--- a/tutorial/tuto_06_dcnet_split_measurements.py
+++ b/tutorial/tuto_06_dcnet_split_measurements.py
@@ -240,7 +240,7 @@
try:
import gdown
- gdown.download(url_unet, f"{model_unet_path}.pth", quiet=False, fuzzy=True)
+ gdown.download(url_unet, model_unet_path, quiet=False, fuzzy=True)
except:
print(f"Model {model_unet_path} not found!")
load_unet = False
diff --git a/tutorial/tuto_08_lpgd_split_measurements.py b/tutorial/tuto_08_lpgd_split_measurements.py
index b07bb039..1ab08d5d 100644
--- a/tutorial/tuto_08_lpgd_split_measurements.py
+++ b/tutorial/tuto_08_lpgd_split_measurements.py
@@ -7,6 +7,11 @@
This tutorial shows how to perform image reconstruction with unrolled Learned Proximal Gradient
Descent (LPGD) for split measurements.
+Unfortunately, it has a large memory consumption so it cannot be run interactively.
+If you want to run it yourself, please remove all the "if False:" statements at
+the beginning of each code block. The figures displayed are the ones that would
+be generated if the code was run.
+
.. figure:: ../fig/lpgd.png
:width: 600
:align: center
@@ -29,54 +34,55 @@
# and :math:`\mathcal{G}_{\theta}` is a denoising network with
# learnable parameters :math:`\theta`.
-# sphinx_gallery_thumbnail_path = 'fig/lpgd.png'
-
-import numpy as np
-import os
-from spyrit.misc.disp import imagesc
-import matplotlib.pyplot as plt
-
# %%
# Load a batch of images
# -----------------------------------------------------------------------------
-
-###############################################################################
+#
# Images :math:`x` for training neural networks expect values in [-1,1]. The images are normalized
# using the :func:`transform_gray_norm` function.
-from spyrit.misc.statistics import transform_gray_norm
-import torchvision
-import torch
-
-h = 128 # image size hxh
-i = 1 # Image index (modify to change the image)
-spyritPath = os.getcwd()
-imgs_path = os.path.join(spyritPath, "images")
-
-# Create a transform for natural images to normalized grayscale image tensors
-transform = transform_gray_norm(img_size=h)
-
-# Create dataset and loader (expects class folder 'images/test/')
-dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
-dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
-
-x, _ = next(iter(dataloader))
-print(f"Shape of input images: {x.shape}")
-
-# Select image
-x = x[i : i + 1, :, :, :]
-x = x.detach().clone()
-b, c, h, w = x.shape
+# sphinx_gallery_thumbnail_path = 'fig/lpgd.png'
-# plot
-x_plot = x.view(-1, h, h).cpu().numpy()
-imagesc(x_plot[0, :, :], r"$x$ in [-1, 1]")
+if False:
+ import os
+
+ import torch
+ import torchvision
+ import numpy as np
+
+ from spyrit.misc.disp import imagesc
+ from spyrit.misc.statistics import transform_gray_norm
+
+ h = 128 # image size hxh
+ i = 1 # Image index (modify to change the image)
+ spyritPath = os.getcwd()
+ imgs_path = os.path.join(spyritPath, "images")
+ # Create a transform for natural images to normalized grayscale image tensors
+ transform = transform_gray_norm(img_size=h)
+ # Create dataset and loader (expects class folder 'images/test/')
+ dataset = torchvision.datasets.ImageFolder(root=imgs_path, transform=transform)
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=7)
+
+ x, _ = next(iter(dataloader))
+ print(f"Shape of input images: {x.shape}") # torch.Size([7, 1, 128, 128])
+ # Select image
+ x = x[i : i + 1, :, :, :]
+ x = x.detach().clone()
+ b, c, h, w = x.shape
+
+ # plot
+ x_plot = x.view(-1, h, h).cpu().numpy()
+ imagesc(x_plot[0, :, :], r"$x$ in [-1, 1]")
+###############################################################################
+# .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679972abaa5a90007058950/download
+# :width: 600
+# :align: center
+# :alt: Ground-truth image x in [-1, 1]
# %%
# Forward operators for split measurements
# -----------------------------------------------------------------------------
-
-###############################################################################
+#
# We consider noisy split measurements for a Hadamard operator and a simple
# rectangular subsampling” strategy
# (for more details, refer to :ref:`Acquisition - split measurements `).
@@ -87,42 +93,47 @@
# we simulate an accelerated acquisition by subsampling the measurement matrix
# by retaining only the first rows of a Hadamard matrix.
-from spyrit.core.meas import HadamSplit
-from spyrit.core.noise import Poisson
-from spyrit.misc.sampling import meas2img
-from spyrit.misc.statistics import Cov2Var
-from spyrit.core.prep import SplitPoisson
+if False:
+ import math
-import math
+ from spyrit.core.meas import HadamSplit
+ from spyrit.core.noise import Poisson
+ from spyrit.core.prep import SplitPoisson
+ from spyrit.misc.sampling import meas2img
-# Measurement parameters
-M = 4096 # Number of measurements (here, 1/4 of the pixels)
-alpha = 10.0 # number of photons
+ # Measurement parameters
+ M = 4096 # Number of measurements (here, 1/4 of the pixels)
+ alpha = 10.0 # number of photons
-# Sampling: rectangular matrix
-Ord_rec = np.ones((h, h))
-n_sub = math.ceil(M**0.5)
-Ord_rec[:, n_sub:] = 0
-Ord_rec[n_sub:, :] = 0
+ # Sampling: rectangular matrix
+ Ord_rec = np.ones((h, h))
+ n_sub = math.ceil(M**0.5)
+ Ord_rec[:, n_sub:] = 0
+ Ord_rec[n_sub:, :] = 0
-# Measurement and noise operators
-meas_op = HadamSplit(M, h, torch.from_numpy(Ord_rec))
-noise_op = Poisson(meas_op, alpha)
-prep_op = SplitPoisson(alpha, meas_op)
+ # Measurement and noise operators
+ meas_op = HadamSplit(M, h, torch.from_numpy(Ord_rec))
+ noise_op = Poisson(meas_op, alpha)
+ prep_op = SplitPoisson(alpha, meas_op)
-# Vectorize image
-x = x.view(b * c, h * w)
-print(f"Shape of vectorized image: {x.shape}")
+ # Vectorize image
+ x = x.view(b * c, h * w)
+ print(f"Shape of vectorized image: {x.shape}") # torch.Size([1, 16384])
-# Measurements
-y = noise_op(x) # a noisy measurement vector
-m = prep_op(y) # preprocessed measurement vector
+ # Measurements
+ y = noise_op(x) # a noisy measurement vector
+ m = prep_op(y) # preprocessed measurement vector
-m_plot = m.detach().numpy()
-m_plot = meas2img(m_plot, Ord_rec)
-imagesc(m_plot[0, :, :], r"Measurements $m$")
+ m_plot = m.detach().numpy()
+ m_plot = meas2img(m_plot, Ord_rec)
+ imagesc(m_plot[0, :, :], r"Measurements $m$")
###############################################################################
+# .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679972bbaa5a90007058953/download
+# :width: 600
+# :align: center
+# :alt: Measurements m
+#
# We define the LearnedPGD network by providing the measurement, noise and preprocessing operators,
# the denoiser and other optional parameters to the class :class:`spyrit.core.recon.LearnedPGD`.
# The optional parameters include the number of unrolled iterations (`iter_stop`)
@@ -137,72 +148,82 @@
# :align: center
# :alt: Sketch of the network architecture for LearnedPGD
-from spyrit.core.nnet import Unet
-from spyrit.core.recon import LearnedPGD
-
-# use GPU, if available
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+if False:
+ from spyrit.core.nnet import Unet
+ from spyrit.core.recon import LearnedPGD
-# Define UNet denoiser
-denoi = Unet()
-
-# Define the LearnedPGD model
-lpgd_net = LearnedPGD(noise_op, prep_op, denoi, iter_stop=3, step_decay=0.9)
+ # use GPU, if available
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ # Define UNet denoiser
+ denoi = Unet()
+ # Define the LearnedPGD model
+ lpgd_net = LearnedPGD(noise_op, prep_op, denoi, iter_stop=3, step_decay=0.9)
###############################################################################
-# Now, we load download the pretrained weights and load them into the LPGD network.
+# Now, we download the pretrained weights and load them into the LPGD network.
+# Unfortunately, the pretrained weights are too heavy (2GB) to be downloaded
+# here. The last figure is nonetheless displayed to show the results.
-from spyrit.core.train import load_net
+if False:
+ from spyrit.core.train import load_net
-# Download weights
-model_path = "./model"
-if os.path.exists(model_path) is False:
- os.mkdir(model_path)
- print(f"Created {model_path}")
+ # Download weights
+ model_path = "./model"
+ if os.path.exists(model_path) is False:
+ os.mkdir(model_path)
+ print(f"Created {model_path}")
-url_lpgd = "https://drive.google.com/file/d/1ki_cJQEwBWrpDhtE7-HoSEoY8oJUnUz5/view?usp=drive_link"
-model_net_path = os.path.join(
- model_path,
- "lpgd_unet_imagenet_N0_10_m_hadam-split_N_128_M_4096_epo_30_lr_0.001_sss_10_sdr_0.5_bs_128_reg_1e-07_uit_3_sdec0-9.pth",
-)
+ url_lpgd = "https://drive.google.com/file/d/1ki_cJQEwBWrpDhtE7-HoSEoY8oJUnUz5/view?usp=drive_link"
+ model_net_path = os.path.join(
+ model_path,
+ "lpgd_unet_imagenet_N0_10_m_hadam-split_N_128_M_4096_epo_30_lr_0.001_sss_10_sdr_0.5_bs_128_reg_1e-07_uit_3_sdec0-9.pth",
+ )
-if os.path.exists(model_net_path) is False:
- try:
- import gdown
+ if os.path.exists(model_net_path) is False:
+ try:
+ import gdown
- gdown.download(url_lpgd, model_net_path, quiet=False, fuzzy=True)
- except:
- print(f"Model not downloaded from {url_lpgd}!!!")
+ gdown.download(url_lpgd, model_net_path, quiet=False, fuzzy=True)
+ except:
+ print(f"Model not downloaded from {url_lpgd}!!!")
-# Load pretrained weights to the model
-load_net(model_net_path, lpgd_net, device, strict=False)
+ # Load pretrained weights to the model
+ load_net(model_net_path, lpgd_net, device, strict=False)
-lpgd_net.eval()
-lpgd_net.to(device)
+ lpgd_net.eval()
+ lpgd_net.to(device)
###############################################################################
# We reconstruct by calling the reconstruct method as in previous tutorials
# and display the results.
-import matplotlib.pyplot as plt
-from spyrit.misc.disp import add_colorbar, noaxis
+if False:
+ import matplotlib.pyplot as plt
+
+ from spyrit.misc.disp import add_colorbar, noaxis
-with torch.no_grad():
- z_lpgd = lpgd_net.reconstruct(y.to(device))
+ with torch.no_grad():
+ z_lpgd = lpgd_net.reconstruct(y.to(device))
-# Plot results
-x_plot = x.view(-1, h, h).cpu().numpy()
-x_plot2 = z_lpgd.view(-1, h, h).cpu().numpy()
+ # Plot results
+ x_plot = x.view(-1, h, h).cpu().numpy()
+ x_plot2 = z_lpgd.view(-1, h, h).cpu().numpy()
-f, axs = plt.subplots(2, 1, figsize=(10, 10))
-im1 = axs[0].imshow(x_plot[0, :, :], cmap="gray")
-axs[0].set_title("Ground-truth image", fontsize=16)
-noaxis(axs[0])
-add_colorbar(im1, "bottom")
+ f, axs = plt.subplots(2, 1, figsize=(10, 10))
+ im1 = axs[0].imshow(x_plot[0, :, :], cmap="gray")
+ axs[0].set_title("Ground-truth image", fontsize=16)
+ noaxis(axs[0])
+ add_colorbar(im1, "bottom")
-im2 = axs[1].imshow(x_plot2[0, :, :], cmap="gray")
-axs[1].set_title("LPGD", fontsize=16)
-noaxis(axs[1])
-add_colorbar(im2, "bottom")
+ im2 = axs[1].imshow(x_plot2[0, :, :], cmap="gray")
+ axs[1].set_title("LPGD", fontsize=16)
+ noaxis(axs[1])
+ add_colorbar(im2, "bottom")
-plt.show()
+ plt.show()
+
+###############################################################################
+# .. image:: https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1/item/6679853fbaa5a9000705894b/download
+# :width: 400
+# :align: center
+# :alt: Comparison of ground-truth image and LPGD reconstruction
diff --git a/tutorial/tuto_09_dynamic.py b/tutorial/tuto_09_dynamic.py
index 2ebed1ef..ebe4bb89 100644
--- a/tutorial/tuto_09_dynamic.py
+++ b/tutorial/tuto_09_dynamic.py
@@ -128,8 +128,9 @@ def f(t):
# Warp the image
# -----------------------------------------------------------------------------
#
-# Warping works with vectorized images. So, we first reshape the image from `(b,c,h,w)` to `(c, h*w)`
-x = x.view(c, h * w)
+# Warping works with vectorized images. So, we first reshape the image from `(b,c,h,w)` to `(c, h*w)`.
+# The original image is casted to `torch.float64` to minimize numerical errors during the warping process.
+x = x.view(c, h * w).to(torch.float64)
######################################################################
# We can now warp the image