Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Romain dev : Major fixes #186

Merged
merged 16 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions docs/source/_static/css/sg_README.css
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.sphx-glr-thumbnails {
width: 100%;
margin: 0px 0px 20px 0px;
margin: 0px 0px 0px 0px;

/* align thumbnails on a grid */
justify-content: space-between;
Expand All @@ -13,18 +13,16 @@
}
.sphx-glr-thumbcontainer {
width: 100% !important;
height: 200px !important;
margin: 20px !important;
min-height: 210px !important;
margin: 0px !important;
}
.sphx-glr-thumbcontainer .figure {
min-width: 300px !important;
min-width: 100px !important;
height: 100px !important;
}
.sphx-glr-thumbcontainer img {
display: inline !important;
max-height: 112px !important;
min-width: 200px !important;
max-width: 300px !important;
object-fit: cover !important;
max-height: 150px !important;
min-width: 300px !important;
}
.sphx-glr-thumbcontainer a.internal {
padding: 10px 10px 0 !important;
}
208 changes: 135 additions & 73 deletions spyrit/core/meas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
"""
Measurement operators, static and dynamic.

There are six classes contained in this module, each representing a different
type of measurement operator. Three of them are static, i.e. they are used to
simulate measurements of still images, and three are dynamic, i.e. they are used
to simulate measurements of moving objects, represented as a sequence of images.
"""

import warnings

import torch
Expand All @@ -7,6 +16,7 @@

from spyrit.misc.walsh_hadamard import walsh2_torch, walsh2_matrix
from spyrit.misc.sampling import Permutation_Matrix, sort_by_significance
from spyrit.core.time import DeformationField, AffineDeformationField


# =============================================================================
Expand Down Expand Up @@ -99,6 +109,66 @@ def get_H(self) -> torch.tensor:
"""
return self.H.data

def get_H_pinv(self) -> torch.tensor:
r"""Returns the pseudo inverse of the measurement matrix :math:`H`.

Shape:
Output: :math:`(N, M)`

Example:
>>> H1 = np.random.random([400, 1600])
>>> meas_op = Linear(H1, True)
>>> H2 = meas_op.get_H_pinv()
>>> print(H2.shape)
torch.Size([1600, 400])
"""
try:
return self.H_pinv.data
except AttributeError as e:
if "has no attribute 'H_pinv'" in str(e):
raise AttributeError(
"The pseudo inverse has not been initialized. Please set it using self.set_H_pinv()."
)
else:
raise e

def set_H_pinv(self, reg: float = 1e-15, pinv: torch.tensor = None) -> None:
r"""
Stores in self.H_pinv the pseudo inverse of the measurement matrix :math:`H`.

If :attr:`pinv` is given, it is directly stored as the pseudo inverse.
The validity of the pseudo inverse is not checked. If :attr:`pinv` is
:obj:`False`, the pseudo inverse is computed from the existing
measurement matrix :math:`H` with regularization parameter :attr:`reg`.

Args:
:attr:`reg` (float, optional): Cutoff for small singular values.

:attr:`H_pinv` (torch.tensor, optional): If given, the tensor is
directly stored as the pseudo inverse. No checks are performed.
Otherwise, the pseudo inverse is computed from the existing
measurement matrix :math:`H`.

.. note:
Only one of :math:`H_pinv` and :math:`reg` should be given. If both
are given, :math:`H_pinv` is used and :math:`reg` is ignored.

Shape:
:attr:`H_pinv`: :math:`(N, M)`, where :math:`N` is the number of
pixels in the image and :math:`M` the number of measurements.

Example:
>>> H1 = torch.rand([400, 1600])
>>> H2 = torch.linalg.pinv(H1)
>>> meas_op = Linear(H1)
>>> meas_op.set_H_pinv(H2)
"""
if pinv is not None:
H_pinv = pinv.type(torch.FloatTensor) # to float32
else:
H_pinv = torch.linalg.pinv(self.get_H(), rcond=reg)
self.H_pinv = nn.Parameter(H_pinv, requires_grad=False)

def forward(self, x: torch.tensor) -> torch.tensor:
r"""
Simulates the measurement of a motion picture.
Expand Down Expand Up @@ -144,11 +214,11 @@ def forward(self, x: torch.tensor) -> torch.tensor:

def __str__(self):
s_begin = f"{self.__class__.__name__}(\n "
s_fill = "\n ".join([f"({k}): {v}" for k, v in self.__attributeslist__()])
s_fill = "\n ".join([f"({k}): {v}" for k, v in self._attributeslist()])
s_end = "\n )"
return s_begin + s_fill + s_end

def __attributeslist__(self):
def _attributeslist(self):
return [("Image pixels", self.N), ("H", self.H.shape)]


Expand Down Expand Up @@ -349,8 +419,8 @@ def forward_H(self, x: torch.tensor) -> torch.tensor:
"""
return super().forward(x)

def __attributeslist__(self):
return super().__attributeslist__() + [("P", self.P.shape)]
def _attributeslist(self):
return super()._attributeslist() + [("P", self.P.shape)]


# =============================================================================
Expand Down Expand Up @@ -444,9 +514,7 @@ class DynamicHadamSplit(DynamicLinearSplit):

def __init__(self, M: int, h: int, Ord: np.ndarray):
F = walsh2_matrix(h) # full matrix
H = sort_by_significance(F, Ord, "rows", False)[
:M, :
] # much faster than previously
H = sort_by_significance(F, Ord, "rows", False)[:M, :] # much faster
w = h # we assume a square image

super().__init__(torch.from_numpy(H))
Expand All @@ -461,7 +529,7 @@ def __init__(self, M: int, h: int, Ord: np.ndarray):
#######################################################################
Perm = Permutation_Matrix(Ord)
Perm = torch.from_numpy(Perm).float() # float32
self.Perm = nn.Parameter(Perm, requires_grad=False)
self.Perm = nn.Parameter(Perm.T, requires_grad=False)

def get_Perm(self) -> torch.tensor:
warnings.warn(
Expand Down Expand Up @@ -568,66 +636,6 @@ def get_H_T(self) -> torch.tensor:
"""
return self.H.T

def get_H_pinv(self) -> torch.tensor:
r"""Returns the pseudo inverse of the measurement matrix :math:`H`.

Shape:
Output: :math:`(N, M)`

Example:
>>> H1 = np.random.random([400, 1600])
>>> meas_op = Linear(H1, True)
>>> H2 = meas_op.get_H_pinv()
>>> print(H2.shape)
torch.Size([1600, 400])
"""
try:
return self.H_pinv.data
except AttributeError as e:
if "has no attribute 'H_pinv'" in str(e):
raise AttributeError(
"The pseudo inverse has not been initialized. Please set it using self.set_H_pinv()."
)
else:
raise e

def set_H_pinv(self, reg: float = 1e-15, pinv: torch.tensor = None) -> None:
r"""
Stores in self.H_pinv the pseudo inverse of the measurement matrix :math:`H`.

If :attr:`pinv` is given, it is directly stored as the pseudo inverse.
The validity of the pseudo inverse is not checked. If :attr:`pinv` is
:obj:`False`, the pseudo inverse is computed from the existing
measurement matrix :math:`H` with regularization parameter :attr:`reg`.

Args:
:attr:`reg` (float, optional): Cutoff for small singular values.

:attr:`H_pinv` (torch.tensor, optional): If given, the tensor is
directly stored as the pseudo inverse. No checks are performed.
Otherwise, the pseudo inverse is computed from the existing
measurement matrix :math:`H`.

.. note:
Only one of :math:`H_pinv` and :math:`reg` should be given. If both
are given, :math:`H_pinv` is used and :math:`reg` is ignored.

Shape:
:attr:`H_pinv`: :math:`(N, M)`, where :math:`N` is the number of
pixels in the image and :math:`M` the number of measurements.

Example:
>>> H1 = torch.rand([400, 1600])
>>> H2 = torch.linalg.pinv(H1)
>>> meas_op = Linear(H1)
>>> meas_op.set_H_pinv(H2)
"""
if pinv is not None:
H_pinv = pinv.type(torch.FloatTensor) # to float32
else:
H_pinv = torch.linalg.pinv(self.get_H(), rcond=reg)
self.H_pinv = nn.Parameter(H_pinv, requires_grad=False)

def forward(self, x: torch.tensor) -> torch.tensor:
r"""Applies linear transform to incoming images: :math:`y = Hx`.

Expand Down Expand Up @@ -702,8 +710,8 @@ def pinv(self, x: torch.tensor) -> torch.tensor:
# Pmat.transpose()*f
return torch.matmul(x, self.get_H_pinv().T)

def __attributeslist__(self):
return super().__attributeslist__() + [
def _attributeslist(self):
return super()._attributeslist() + [
("H_pinv", self.H_pinv.shape if hasattr(self, "H_pinv") else None)
]

Expand Down Expand Up @@ -971,7 +979,8 @@ def inverse(self, x: torch.tensor) -> torch.tensor:
# todo: check walsh2_S_fold_torch to speed up
b, N = x.shape

x = sort_by_significance(x, self.Ord, "cols", True) # new way
# False because self.Perm is already permuted vvvvv
x = sort_by_significance(x, self.Ord, "cols", False) # new way
# x = x @ self.Perm.T # old way

x = x.view(b, 1, self.h, self.w)
Expand All @@ -980,5 +989,58 @@ def inverse(self, x: torch.tensor) -> torch.tensor:
x = 1 / self.N * walsh2_torch(x)
return x.view(b, N)

def __attributeslist__(self):
return super().__attributeslist__() + [("Perm", self.Ord.shape)]
def _attributeslist(self):
return super()._attributeslist() + [("Perm", self.Ord.shape)]


# =============================================================================
# Functions
# =============================================================================


def set_dyn_pinv(
meas_op: DynamicLinear,
motion: DeformationField,
interp_mode: str = "bilinear",
# regularizer: str=None,
reg: float = 1e-15,
) -> None:

# get full image shape, defined by motion
Nx_img, Ny_img = motion.Nx, motion.Ny
n_frames = motion.n_frames

# get measurement matrix shape
Nx_meas = meas_op.w
Ny_meas = meas_op.h
n_meas = meas_op.M

if Nx_meas * Ny_meas != meas_op.N:
raise ValueError(
f"The image size in the measurement operator is not a square. "
+ "Please assign self.h and self.w manually."
)

# get the measurement matrix
H = meas_op.get_H().view(n_meas, 1, Nx_meas, Ny_meas)

# if the image is larger than the measurement matrix, we need to pad the
# measurement matrix with zeros
if (Nx_img > Nx_meas) or (Ny_img > Ny_meas):
pad_left = (Ny_img - Ny_meas) // 2
pad_right = Ny_img - Ny_meas - pad_left
pad_top = (Nx_img - Nx_meas) // 2
pad_bottom = Nx_img - Nx_meas - pad_top

pad = nn.ConstantPad2d((pad_left, pad_right, pad_top, pad_bottom), 0)
H = pad(H)

H_dyn_physical = motion(
H, 0, meas_op.M, mode=interp_mode
) #########################
# is it the reciprocal motion we are looking for ?

# find pseudo inverse according to regularizer
# first no regularizer
H_pinv = torch.linalg.pinv(H_dyn_physical, rcond=reg)
meas_op.set_H_pinv(pinv=H_pinv)
6 changes: 5 additions & 1 deletion spyrit/core/nnet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# ==================================================================================
"""
Neural network models for image denoising.
"""

# from __future__ import print_function, division
import torch
import torch.nn as nn
from collections import OrderedDict
import copy


# =============================================================================
Expand Down
Loading
Loading