From ef9bdea5d7cf263886d7459a090c4f185e851bfa Mon Sep 17 00:00:00 2001 From: romainphan Date: Tue, 20 Feb 2024 17:15:57 +0100 Subject: [PATCH 01/11] commented skip_member_handler in conf.py --- docs/source/conf.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 7534d29b..99a65c06 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -116,20 +116,18 @@ # autodoc_mock_imports = "numpy matplotlib mpl_toolkits scipy torch torchvision Pillow opencv-python imutils PyWavelets pywt wget imageio".split() -# exclude all torch.nn.Module members from the documentation -# except forward and __init__ methods +# exclude all torch.nn.Module members (except forward method) from the docs: import torch - def skip_member_handler(app, what, name, obj, skip, options): - if name in [ - "forward", - ]: - return False + always_document = [ # complete this list if needed by adding methods + "forward", # you *always* want to see documented + ] + if name in always_document: + return None if name in dir(torch.nn.Module): return True return None - def setup(app): - app.connect("autodoc-skip-member", skip_member_handler) + app.connect("autodoc-skip-member", skip_member_handler) \ No newline at end of file From 928ea3a8a6b7928c9b78cc5a38135a91ef8bdd46 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:18:04 +0000 Subject: [PATCH 02/11] [pre-commit.ci] Automatic python formatting --- docs/source/conf.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 99a65c06..3517c1e8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -119,15 +119,17 @@ # exclude all torch.nn.Module members (except forward method) from the docs: import torch + def skip_member_handler(app, what, name, obj, skip, options): - always_document = [ # complete this list if needed by adding methods - "forward", # you *always* want to see documented + always_document = [ # complete this list if needed by adding methods + "forward", # you *always* want to see documented ] - if name in always_document: + if name in always_document: return None if name in dir(torch.nn.Module): return True return None + def setup(app): - app.connect("autodoc-skip-member", skip_member_handler) \ No newline at end of file + app.connect("autodoc-skip-member", skip_member_handler) From 1d6c0130b807396d85ecfe924f7d1da179414664 Mon Sep 17 00:00:00 2001 From: romainphan Date: Wed, 21 Feb 2024 16:23:03 +0100 Subject: [PATCH 03/11] proudly a new author on the docs website --- docs/source/conf.py | 4 ++-- docs/source/index.rst | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 3517c1e8..03d5f09d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,9 +19,9 @@ # -- Project information ----------------------------------------------------- project = "spyrit" -copyright = "2021, Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier" +copyright = "2021, Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" author = ( - "Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier" + "Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" ) # The full version, including alpha/beta/rc tags diff --git a/docs/source/index.rst b/docs/source/index.rst index e771d8c2..0949b529 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -135,6 +135,6 @@ When using SPyRiT specifically for the denoised completion network, please cite Join the project ================================== -Feel free to contact us by `e-mail ` for any question. Active developers are currently `Nicolas Ducros `_, Thomas Baudier and `Juan Abascal `_. Direct contributions via pull requests (PRs) are welcome. +Feel free to contact us by `e-mail ` for any question. Active developers are currently `Nicolas Ducros `_, Thomas Baudier, `Juan Abascal `_ and Romain Phan. Direct contributions via pull requests (PRs) are welcome. The full list of contributors can be found `here `_. From eb6155dd7e9fe2c387b82e3c9713f259cf805cd1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 21 Feb 2024 15:23:17 +0000 Subject: [PATCH 04/11] [pre-commit.ci] Automatic python formatting --- docs/source/conf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 03d5f09d..86c99181 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -20,9 +20,7 @@ # -- Project information ----------------------------------------------------- project = "spyrit" copyright = "2021, Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" -author = ( - "Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" -) +author = "Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" # The full version, including alpha/beta/rc tags release = "2.1.0" From 3ec0591fa46597658e113141f6cfd93e229d9bcd Mon Sep 17 00:00:00 2001 From: romainphan Date: Thu, 22 Feb 2024 13:27:20 +0100 Subject: [PATCH 05/11] Added dynamic measurement operators / changed inheritance architecture --- spyrit/core/meas.py | 671 +++++++++++++++++++++++++++----------------- 1 file changed, 410 insertions(+), 261 deletions(-) diff --git a/spyrit/core/meas.py b/spyrit/core/meas.py index 54cb49c6..12a1d636 100644 --- a/spyrit/core/meas.py +++ b/spyrit/core/meas.py @@ -2,22 +2,30 @@ import torch import torch.nn as nn import numpy as np -from typing import Union from spyrit.misc.walsh_hadamard import walsh2_torch, walsh2_matrix from spyrit.misc.sampling import Permutation_Matrix -# ================================================================================== -class Linear(nn.Module): - # ================================================================================== +# ============================================================================= +class DynamicLinear(nn.Module): + # ========================================================================= r""" + Simulates the measurement of a moving object using the positive and + negative components of the measurement matrix. + Computes linear measurements from incoming images: :math:`y = Hx`, where :math:`H` is a linear operator (matrix) and :math:`x` is a - vectorized image. + batch of vectorized images representing a motion picture. - The class is constructed from a :math:`M` by :math:`N` matrix :math:`H`, + The class is constructed from a matrix :math:`H` of shape :math:`(M, N)`, where :math:`N` represents the number of pixels in the image and - :math:`M` the number of measurements. + :math:`M` the number of measurements and the number of frames in the + animated object. + + .. warning:: + For each call, there must be **exactly** as many images in :math:`x` as + there are measurements in the linear operator used to initialize the class. + If not, an error will be raised. Args: :attr:`H`: measurement matrix (linear operator) with shape :math:`(M, N)`. @@ -29,70 +37,389 @@ class Linear(nn.Module): singular values, see :mod:`numpy.linal.pinv`). Only relevant when :attr:`pinv` is not `None`. - Attributes: :attr:`H`: The learnable measurement matrix of shape :math:`(M,N)` initialized as :math:`H` - :attr:`H_adjoint`: The learnable adjoint measurement matrix - of shape :math:`(N,M)` initialized as :math:`H^\top` - :attr:`H_pinv` (optional): The learnable adjoint measurement matrix of shape :math:`(N,M)` initialized as :math:`H^\dagger`. Only relevant when :attr:`pinv` is not `None`. - Example: + Example 1: >>> H = np.random.random([400, 1000]) - >>> meas_op = Linear(H) + >>> meas_op = LinearDynamic(H) >>> print(meas_op) - Linear( + LinearDynamic( (H): Linear(in_features=1000, out_features=400, bias=False) - (H_adjoint): Linear(in_features=400, out_features=1000, bias=False) - ) + ) Example 2: >>> H = np.random.random([400, 1000]) - >>> meas_op = Linear(H, True) + >>> meas_op = LinearDynamic(H, True) >>> print(meas_op) - Linear( + LinearDynamic( (H): Linear(in_features=1000, out_features=400, bias=False) - (H_adjoint): Linear(in_features=400, out_features=1000, bias=False) (H_pinv): Linear(in_features=400, out_features=1000, bias=False) ) """ - - def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): + + def __init__(self, H: np.ndarray | torch.tensor, pinv=None, reg: float=1e-15): super().__init__() - # instancier nn.linear + + # nn.Parameter are sent to the device when using .to(device), + # contrary to attributes + H = torch.tensor(H, dtype=torch.float32) + self.H = nn.Parameter(H, requires_grad=False) + self.M = H.shape[0] self.N = H.shape[1] - self.h = int(self.N**0.5) self.w = int(self.N**0.5) if self.h * self.w != self.N: warnings.warn( "N is not a square. Please assign self.h and self.w manually." ) + if pinv is not None: + H_pinv = torch.linalg.pinv(H, rcond=reg) + self.H_pinv = nn.Parameter(H_pinv, requires_grad=False) + else: + print("Pseudo inverse will not be instanciated") + + def get_H(self) -> torch.tensor: + r"""Returns the measurement matrix :math:`H`. + + Shape: + Output: :math:`(M, N)` + + Example: + >>> H1 = np.random.random([400, 1000]) + >>> meas_op = Linear(H1) + >>> H2 = meas_op.get_H() + >>> print('Matrix shape:', H2.shape) + Matrix shape: torch.Size([400, 1000]) + """ + return self.H.data + + def get_H_T(self) -> torch.tensor: + r""" + Returns the transpose of the measurement matrix :math:`H`. + + Shape: + Output: :math:`(N, M)` + + Example: + >>> H1 = np.random.random([400, 1000]) + >>> meas_op = Linear(H1) + >>> H2 = meas_op.get_H_T() + >>> print('Transpose shape:', H2.shape) + Transpose shape: torch.Size([400, 1000]) + """ + return self.H.T + + def get_H_pinv(self) -> torch.tensor: + r"""Returns the pseudo inverse of the measurement matrix :math:`H`. + + Shape: + Output: :math:`(N, M)` + + Example: + >>> H1 = np.random.random([400, 1000]) + >>> meas_op = Linear(H1, True) + >>> H2 = meas_op.get_H_pinv() + >>> print('Pseudo inverse shape:', H2.shape) + Pseudo inverse shape: torch.Size([1000, 400]) + """ + return self.H_pinv.data + + def forward(self, x: torch.tensor) -> torch.tensor: + r""" + Simulates the measurement of a motion picture. + + The output :math:`y` is computed as :math:`y = Hx`, where :math:`H` is + the measurement matrix and :math:`x` is a batch of vectorized (flattened) + images. + + .. warning:: + There must be **exactly** as many images as there are measurements + in the linear operator used to initialize the class, i.e. + `H.shape[-2:] == x.shape[-2:] + + Args: + :math:`x`: Batch of vectorized (flattened) images. + + Shape: + :math:`x`: :math:`(*, M, N)` + :math:`output`: :math:`(*, M)` + + Example: + >>> x = torch.rand([10, 400, 1000], dtype=torch.float) + >>> H = np.random.random([400, 1000]) + >>> meas_op = LinearDynamic(H) + >>> y = meas_op(x) + >>> print(y.shape) + torch.Size([10, 400]) + """ + return torch.einsum('ij,...ij->...i', self.get_H(), x) - self.H = nn.Linear(self.N, self.M, False) - self.H.weight.data = torch.from_numpy(H).float() - # Data must be of type float (or double) rather than the default float64 when creating torch tensor - self.H.weight.requires_grad = False - # adjoint (Remove?) - self.H_adjoint = nn.Linear(self.M, self.N, False) - self.H_adjoint.weight.data = torch.from_numpy(H.transpose()).float() - self.H_adjoint.weight.requires_grad = False +# ============================================================================= +class DynamicLinearSplit(DynamicLinear): + # ========================================================================= + r""" + Used to simulate the measurement of a moving object using the positive and + negative components of the measurement matrix. + + Computes linear measurements from incoming images: :math:`y = Px`, + where :math:`P` is a linear operator (matrix) and :math:`x` is a batch of + vectorized images representing a motion picture. - if pinv is None: - H_pinv = pinv - print("Pseudo inverse will not be instanciated") + The matrix :math:`P` contains only positive values and is obtained by + splitting a measurement matrix :math:`H` such that + :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where + :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. - else: - H_pinv = np.linalg.pinv(H, rcond=reg) - self.H_pinv = nn.Linear(self.M, self.N, False) - self.H_pinv.weight.data = torch.from_numpy(H_pinv).float() - self.H_pinv.weight.requires_grad = False + The class is constructed from the :math:`M` by :math:`N` matrix :math:`H`, + where :math:`N` represents the number of pixels in the image and + :math:`M` the number of measurements. + + .. warning:: + For each call, there must be **exactly** as many images in :math:`x` as + there are measurements in the linear operator used to initialize the class. + If not, an error will be raised. + + Args: + :math:`H` (np.ndarray): measurement matrix (linear operator) with + shape :math:`(M, N)`. + + Example: + >>> H = np.array(np.random.random([400,1000])) + >>> meas_op = LinearDynamicSplit(H) + """ + + def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): + # initialize self.H and self.H_pinv + super().__init__(H, pinv, reg) + # initialize self.P = [ H^+ ] + # [ H^- ] + zero = torch.zeros(1) + H_pos = torch.maximum(zero, H) + H_neg = torch.maximum(zero, -H) + # concatenate side by side, then reshape vertically + P = torch.cat([H_pos, H_neg], 1).view(2 * self.M, self.N) + self.P = nn.Parameter(P, requires_grad=False) + + def get_P(self) -> torch.tensor: + r"""Returns the measurement matrix :math:`P`. + + Shape: + Output: :math:`(2M, N)` + + Example: + >>> P = meas_op.get_P() + >>> print('Matrix shape:', P.shape) + Matrix shape: torch.Size([800, 1000]) + """ + return self.P.data + + def forward(self, x: torch.tensor) -> torch.tensor: + r""" + Simulates the measurement of a motion picture using :math:`P`. + + The output :math:`y` is computed as :math:`y = Px`, where :math:`P` is + the measurement matrix and :math:`x` is a batch of vectorized (flattened) + images. + + :math:`P` contains only positive values and is obtained by + splitting a measurement matrix :math:`H` such that + :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where + :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. + + .. warning:: + There must be **exactly** as many images as there are measurements + in the linear operator used to initialize the class, i.e. + `P.shape[-2:] == x.shape[-2:] + + Args: + :math:`x`: Batch of vectorized (flatten) images. + + Shape: + :math:`P` has a shape of :math:`(2M, N)` where :math:`M` is the + number of measurements as defined by the first dimension of :math:`H` + and :math:`N` is the number of pixels in the image. + + :math:`x`: :math:`(*, 2M, N)` + + :math:`output`: :math:`(*, 2M)` + + Example: + >>> x = torch.rand([10, 400, 1000], dtype=torch.float) + >>> H = np.random.random([400, 1000]) + >>> meas_op = LinearDynamicSplit(H) + >>> y = meas_op(x) + >>> print(y.shape) + torch.Size([10, 800]) + """ + return torch.einsum('ij,...ij->...i', self.get_P(), x) + + def forward_H(self, x: torch.tensor) -> torch.tensor: + r""" + Simulates the measurement of a motion picture using :math:`H`. + + The output :math:`y` is computed as :math:`y = Hx`, where :math:`H` is + the measurement matrix and :math:`x` is a batch of vectorized (flattened) + images. The positive and negative components of the measurement matrix + are **not** used in this method. + + .. warning:: + There must be **exactly** as many images as there are measurements + in the linear operator used to initialize the class, i.e. + `H.shape[-2:] == x.shape[-2:] + + Args: + :math:`x`: Batch of vectorized (flatten) images. + + Shape: + :math:`H` has a shape of :math:`(M, N)` where :math:`M` is the + number of measurements and :math:`N` is the number of pixels in the + image. + + :math:`x`: :math:`(*, M, N)` + + :math:`output`: :math:`(*, M)` + + Example: + >>> x = torch.rand([10, 400, 1000], dtype=torch.float) + >>> H = np.random.random([400, 1000]) + >>> meas_op = LinearDynamicSplit(H) + >>> y = meas_op.forward_H(x) + >>> print(y.shape) + torch.Size([10, 400]) + """ + return super.forward(x) + + +# ============================================================================= +class DynamicHadamSplit(DynamicLinearSplit): + # ========================================================================= + r""" + Simulates the measurement of a moving object using the positive and + negative components of a Hadamard matrix. + + Computes linear measurements from incoming images: :math:`y = Px`, + where :math:`P` is a linear operator (matrix) with positive entries and + :math:`x` is a batch of vectorized images representing a motion picture. + + The class relies on a matrix :math:`H` with + shape :math:`(M,N)` where :math:`N` represents the number of pixels in the + image and :math:`M \le N` the number of measurements. The matrix :math:`P` + is obtained by splitting the matrix :math:`H` such that + :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where + :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. + + The matrix :math:`H` is obtained by retaining the first :math:`M` rows of + a permuted Hadamard matrix :math:`GF`, where :math:`G` is a + permutation matrix with shape with shape :math:`(M,N)` and :math:`F` is a + "full" Hadamard matrix with shape :math:`(N,N)`. The computation of a + Hadamard transform :math:`Fx` benefits a fast algorithm, as well as the + computation of inverse Hadamard transforms. + + .. warning:: + For each call, there must be **exactly** as many images in :math:`x` as + there are measurements in the linear operator used to initialize the class. + If not, an error will be raised. + + Args: + :attr:`M` (int): Number of measurements + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be square. + + :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h,h)` used to + compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` + (see the :mod:`~spyrit.misc.sampling` submodule) + + .. note:: + The matrix H has shape :math:`(M,N)` with :math:`N = h^2`. + + .. note:: + :math:`H = H_{+} - H_{-}` + + Example: + >>> Ord = np.random.random([32,32]) + >>> meas_op = HadamSplitDynamic(400, 32, Ord) + """ + + def __init__(self, M: int, h: int, Ord: np.ndarray): + F = walsh2_matrix(h) # full matrix + Perm = Permutation_Matrix(Ord) + F = Perm @ F # If Perm is not learnt, could be computed mush faster + H = F[:M, :] + w = h # we assume a square image + + super().__init__(H) + + Perm = torch.tensor(Perm, dtype=torch.float32) + self.Perm = nn.Parameter(Perm, requires_grad=False) + # overwrite self.h and self.w + self.h = h + self.w = w + + +# ============================================================================= +class Linear(DynamicLinear): + # ========================================================================= + r""" + Simulates the measurement of an image using a measurement operator. + + Computes linear measurements from incoming images: :math:`y = Hx`, + where :math:`H` is a linear operator (matrix) and :math:`x` is a + vectorized image or a batch of images. + + The class is constructed from a :math:`M` by :math:`N` matrix :math:`H`, + where :math:`N` represents the number of pixels in the image and + :math:`M` the number of measurements. + + Args: + :attr:`H`: measurement matrix (linear operator) with shape :math:`(M, N)`. + + :attr:`pinv`: Option to have access to pseudo inverse solutions. + Defaults to `None` (the pseudo inverse is not initiliazed). + + :attr:`reg` (optional): Regularization parameter (cutoff for small + singular values, see :mod:`numpy.linal.pinv`). Only relevant when + :attr:`pinv` is not `None`. + + + Attributes: + :attr:`H`: The learnable measurement matrix of shape + :math:`(M,N)` initialized as :math:`H` + + :attr:`H_adjoint`: The learnable adjoint measurement matrix + of shape :math:`(N,M)` initialized as :math:`H^\top` + + :attr:`H_pinv` (optional): The learnable adjoint measurement + matrix of shape :math:`(N,M)` initialized as :math:`H^\dagger`. + Only relevant when :attr:`pinv` is not `None`. + + Example 1: + >>> H = np.random.random([400, 1000]) + >>> meas_op = Linear(H) + >>> print(meas_op) + Linear( + (H): Linear(in_features=1000, out_features=400, bias=False) + ) + + Example 2: + >>> H = np.random.random([400, 1000]) + >>> meas_op = Linear(H, True) + >>> print(meas_op) + Linear( + (H): Linear(in_features=1000, out_features=400, bias=False) + (H_pinv): Linear(in_features=400, out_features=1000, bias=False) + ) + """ + + def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): + super().__init__(H, pinv, reg) def forward(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`y = Hx`. @@ -114,9 +441,8 @@ def forward(self, x: torch.tensor) -> torch.tensor: forward: torch.Size([10, 400]) """ - # x.shape[b*c,N] - x = self.H(x) - return x + # left multiplication with transpose is equivalent to right mult + return x @ self.get_H_T() def adjoint(self, x: torch.tensor) -> torch.tensor: r"""Applies adjoint transform to incoming measurements :math:`y = H^{T}x` @@ -135,26 +461,11 @@ def adjoint(self, x: torch.tensor) -> torch.tensor: >>> print('adjoint:', y.shape) adjoint: torch.Size([10, 1000]) """ - # Pmat.transpose()*f - x = self.H_adjoint(x) - return x - - def get_H(self) -> torch.tensor: - r"""Returns the measurement matrix :math:`H`. - - Shape: - Output: :math:`(M, N)` - - Example: - >>> H = meas_op.get_H() - >>> print('get_mat:', H.shape) - get_mat: torch.Size([400, 1000]) - - """ - return self.H.weight.data + # left multiplication is equivalent to right mult with transpose + return x @ self.get_H() def pinv(self, x: torch.tensor) -> torch.tensor: - r"""Computer pseudo inverse solution :math:`y = H^\dagger x` + r"""Computes the pseudo inverse solution :math:`y = H^\dagger x` Args: :math:`x`: batch of measurement vectors. @@ -171,17 +482,19 @@ def pinv(self, x: torch.tensor) -> torch.tensor: adjoint: torch.Size([10, 1000]) """ # Pmat.transpose()*f - x = self.H_pinv(x) - return x + return x @ self.get_H_pinv().T + - -# ================================================================================== -class LinearSplit(Linear): - # ================================================================================== +# ============================================================================= +class LinearSplit(Linear, DynamicLinearSplit): + # ========================================================================= r""" + Simulates the measurement of an image using the computed positive and + negative components of the measurement matrix. + Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) and :math:`x` is a - vectorized image. + vectorized image or batch of vectorized images. The matrix :math:`P` contains only positive values and is obtained by splitting a measurement matrix :math:`H` such that @@ -202,27 +515,8 @@ class LinearSplit(Linear): """ def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): - super().__init__(H, pinv, reg) - - # [H^+, H^-] - - even_index = range(0, 2 * self.M, 2) - odd_index = range(1, 2 * self.M, 2) - - H_pos = np.zeros(H.shape) - H_neg = np.zeros(H.shape) - H_pos[H > 0] = H[H > 0] - H_neg[H < 0] = -H[H < 0] - - # pourquoi 2 *M ? - P = np.zeros((2 * self.M, self.N)) - P[even_index, :] = H_pos - P[odd_index, :] = H_neg - - self.P = nn.Linear(self.N, 2 * self.M, False) - self.P.weight.data = torch.from_numpy(P) - self.P.weight.data = self.P.weight.data.float() - self.P.weight.requires_grad = False + # initialize from DynamicLinearSplit __init__ + super(Linear, self).__init__(H, pinv, reg) def forward(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`y = Px`. @@ -245,8 +539,7 @@ def forward(self, x: torch.tensor) -> torch.tensor: """ # x.shape[b*c,N] # output shape : [b*c, 2*M] - x = self.P(x) - return x + return x @ self.get_P().T def forward_H(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`m = Hx`. @@ -268,18 +561,22 @@ def forward_H(self, x: torch.tensor) -> torch.tensor: output shape: torch.Size([10, 400]) """ - x = self.H(x) - return x - + # call Linear.forward() method + return super(LinearSplit, self).forward(x) -# ================================================================================== -class HadamSplit(LinearSplit): + +# ============================================================================= +class HadamSplit(LinearSplit, DynamicHadamSplit): + # ========================================================================= r""" + Simulates the measurement of a moving object using the positive and + negative components of a Hadamard matrix. + Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) with positive entries and - :math:`x` is a vectorized image. + :math:`x` is a vectorized image or a batch of images. - The class is relies on a matrix :math:`H` with + The class relies on a matrix :math:`H` with shape :math:`(M,N)` where :math:`N` represents the number of pixels in the image and :math:`M \le N` the number of measurements. The matrix :math:`P` is obtained by splitting the matrix :math:`H` such that @@ -293,16 +590,20 @@ class HadamSplit(LinearSplit): Hadamard transform :math:`Fx` benefits a fast algorithm, as well as the computation of inverse Hadamard transforms. - .. note:: - :math:`H = H_{+} - H_{-}` - Args: - - :attr:`M`: Number of measurements - - :attr:`h`: Image height :math:`h`. The image is assumed to be square. - - :attr:`Ord`: Order matrix with shape :math:`(h,h)` used to compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` (see the :mod:`~spyrit.misc.sampling` submodule) + :attr:`M` (int): Number of measurements + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be square. + + :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h,h)` used to + compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` + (see the :mod:`~spyrit.misc.sampling` submodule) .. note:: The matrix H has shape :math:`(M,N)` with :math:`N = h^2`. + + .. note:: + :math:`H = H_{+} - H_{-}` Example: >>> Ord = np.random.random([32,32]) @@ -310,21 +611,9 @@ class HadamSplit(LinearSplit): """ def __init__(self, M: int, h: int, Ord: np.ndarray): - F = walsh2_matrix(h) # full matrix - Perm = Permutation_Matrix(Ord) - F = Perm @ F # If Perm is not learnt, could be computed mush faster - H = F[:M, :] - w = h # we assume a square image - - super().__init__(H) - - self.Perm = nn.Linear(self.N, self.N, False) - self.Perm.weight.data = torch.from_numpy(Perm.T) - self.Perm.weight.data = self.Perm.weight.data.float() - self.Perm.weight.requires_grad = False - self.h = h - self.w = w - + # initialize from DynamicHadamSplit __init__ + super(LinearSplit, self).__init__(M, h, Ord) + def inverse(self, x: torch.tensor) -> torch.tensor: r"""Inverse transform of Hadamard-domain images :math:`x = H_{had}^{-1}G y` is a Hadamard matrix. @@ -354,8 +643,7 @@ def inverse(self, x: torch.tensor) -> torch.tensor: # inverse of full transform # todo: initialize with 1D transform to speed up x = 1 / self.N * walsh2_torch(x) - x = x.view(b, N) - return x + return x.view(b, N) def pinv(self, x: torch.tensor) -> torch.tensor: r"""Pseudo inverse transform of incoming mesurement vectors :math:`x` @@ -374,145 +662,6 @@ def pinv(self, x: torch.tensor) -> torch.tensor: >>> print(x.shape) torch.Size([85, 1024]) """ - x = self.adjoint(x) / self.N - return x - - -# ================================================================================== -class LinearRowSplit(nn.Module): - # ================================================================================== - r"""Compute linear measurement of incoming images :math:`y = Px`, where - :math:`P` is a linear operator and :math:`x` is an image. Note that - the same transform applies to each of the rows of the image :math:`x`. - - The class is constructed from the positive and negative components of - the measurement operator :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` - - Args: - - :attr:`H_pos`: Positive component of the measurement matrix :math:`H_{+}` - - :attr:`H_neg`: Negative component of the measurement matrix :math:`H_{-}` - - Shape: - :math:`H_{+}`: :math:`(M, N)`, where :math:`M` is the number of - patterns and :math:`N` is the length of the patterns. - - :math:`H_{-}`: :math:`(M, N)`, where :math:`M` is the number of - patterns and :math:`N` is the length of the patterns. - - .. note:: - The class assumes the existence of the measurement operator - :math:`H = H_{+}-H_{-}` that contains negative values that cannot be - implemented in practice (harware constraints). - - Example: - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> linop = LinearRowSplit(H_pos,H_neg) - - """ - - def __init__(self, H_pos: np.ndarray, H_neg: np.ndarray): - super().__init__() - - self.M = H_pos.shape[0] - self.N = H_pos.shape[1] - - # Split patterns ? - # N.B.: Data must be of type float (or double) rather than the default - # float64 when creating torch tensor - even_index = range(0, 2 * self.M, 2) - odd_index = range(1, 2 * self.M, 2) - P = np.zeros((2 * self.M, self.N)) - P[even_index, :] = H_pos - P[odd_index, :] = H_neg - self.P = nn.Linear(self.N, 2 * self.M, False) - self.P.weight.data = torch.from_numpy(P).float() - self.P.weight.requires_grad = False - - # "Unsplit" patterns - H = H_pos - H_neg - self.H = nn.Linear(self.N, self.M, False) - self.H.weight.data = torch.from_numpy(H).float() - self.H.weight.requires_grad = False - - def forward(self, x: torch.tensor) -> torch.tensor: - r"""Applies linear transform to incoming images: :math:`y = Px` - - Args: - x: a batch of images + # + return self.adjoint(x) / self.N - Shape: - x: :math:`(b*c, h, w)` with :math:`b` the batch size, :math:`c` the - number of channels, :math:`h` is the image height, and :math:`w` is the image - width. - - Output: :math:`(b*c, 2M, w)` with :math:`b` the batch size, - :math:`c` the number of channels, :math:`2M` is twice the number of - patterns (as it includes both positive and negative components), and - :math:`w` is the image width. - - .. warning:: - The image height :math:`h` should match the length of the patterns - :math:`N` - - Example: - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> linop = LinearRowSplit(H_pos,H_neg) - >>> x = torch.rand(10,64,92) - >>> y = linop(x) - >>> print(y.shape) - torch.Size([10,48,92]) - - """ - x = torch.transpose(x, 1, 2) # swap last two dimensions - x = self.P(x) - x = torch.transpose(x, 1, 2) # swap last two dimensions - return x - - def forward_H(self, x: torch.tensor) -> torch.tensor: - r"""Applies linear transform to incoming images: :math:`m = Hx` - - Args: - x: a batch of images - - Shape: - x: :math:`(b*c, h, w)` with :math:`b` the batch size, :math:`c` the - number of channels, :math:`h` is the image height, and :math:`w` is the image - width. - - Output: :math:`(b*c, M, w)` with :math:`b` the batch size, - :math:`c` the number of channels, :math:`M` is the number of - patterns, and :math:`w` is the image width. - - .. warning:: - The image height :math:`h` should match the length of the patterns - :math:`N` - - Example: - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> meas_op = LinearRowSplit(H_pos,H_neg) - >>> x = torch.rand(10,64,92) - >>> y = meas_op.forward_H(x) - >>> print(y.shape) - torch.Size([10,24,92]) - - """ - x = torch.transpose(x, 1, 2) # swap last two dimensions - x = self.H(x) - x = torch.transpose(x, 1, 2) # swap last two dimensions - return x - - def get_H(self) -> torch.tensor: - r"""Returns the measurement matrix :math:`H`. - - Shape: - Output: :math:`(M, N)` - - Example: - >>> H = meas_op.get_H() - >>> print(H.shape) - torch.Size([24, 64]) - """ - return self.H.weight.data From 755f159878382f97718f95b4ebcb5f1c3c45cf32 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Feb 2024 12:27:40 +0000 Subject: [PATCH 06/11] [pre-commit.ci] Automatic python formatting --- spyrit/core/meas.py | 129 ++++++++++++++++++++++---------------------- 1 file changed, 64 insertions(+), 65 deletions(-) diff --git a/spyrit/core/meas.py b/spyrit/core/meas.py index 12a1d636..539f2749 100644 --- a/spyrit/core/meas.py +++ b/spyrit/core/meas.py @@ -12,10 +12,10 @@ class DynamicLinear(nn.Module): r""" Simulates the measurement of a moving object using the positive and negative components of the measurement matrix. - + Computes linear measurements from incoming images: :math:`y = Hx`, where :math:`H` is a linear operator (matrix) and :math:`x` is a - batch of vectorized images representing a motion picture. + batch of vectorized images representing a motion picture. The class is constructed from a matrix :math:`H` of shape :math:`(M, N)`, where :math:`N` represents the number of pixels in the image and @@ -62,15 +62,15 @@ class DynamicLinear(nn.Module): (H_pinv): Linear(in_features=400, out_features=1000, bias=False) ) """ - - def __init__(self, H: np.ndarray | torch.tensor, pinv=None, reg: float=1e-15): + + def __init__(self, H: np.ndarray | torch.tensor, pinv=None, reg: float = 1e-15): super().__init__() - + # nn.Parameter are sent to the device when using .to(device), # contrary to attributes - H = torch.tensor(H, dtype=torch.float32) + H = torch.tensor(H, dtype=torch.float32) self.H = nn.Parameter(H, requires_grad=False) - + self.M = H.shape[0] self.N = H.shape[1] self.h = int(self.N**0.5) @@ -84,13 +84,13 @@ def __init__(self, H: np.ndarray | torch.tensor, pinv=None, reg: float=1e-15): self.H_pinv = nn.Parameter(H_pinv, requires_grad=False) else: print("Pseudo inverse will not be instanciated") - + def get_H(self) -> torch.tensor: r"""Returns the measurement matrix :math:`H`. - + Shape: Output: :math:`(M, N)` - + Example: >>> H1 = np.random.random([400, 1000]) >>> meas_op = Linear(H1) @@ -99,14 +99,14 @@ def get_H(self) -> torch.tensor: Matrix shape: torch.Size([400, 1000]) """ return self.H.data - + def get_H_T(self) -> torch.tensor: r""" Returns the transpose of the measurement matrix :math:`H`. - + Shape: Output: :math:`(N, M)` - + Example: >>> H1 = np.random.random([400, 1000]) >>> meas_op = Linear(H1) @@ -115,13 +115,13 @@ def get_H_T(self) -> torch.tensor: Transpose shape: torch.Size([400, 1000]) """ return self.H.T - + def get_H_pinv(self) -> torch.tensor: r"""Returns the pseudo inverse of the measurement matrix :math:`H`. - + Shape: Output: :math:`(N, M)` - + Example: >>> H1 = np.random.random([400, 1000]) >>> meas_op = Linear(H1, True) @@ -130,27 +130,27 @@ def get_H_pinv(self) -> torch.tensor: Pseudo inverse shape: torch.Size([1000, 400]) """ return self.H_pinv.data - + def forward(self, x: torch.tensor) -> torch.tensor: r""" Simulates the measurement of a motion picture. - + The output :math:`y` is computed as :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is a batch of vectorized (flattened) images. - + .. warning:: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. `H.shape[-2:] == x.shape[-2:] - + Args: :math:`x`: Batch of vectorized (flattened) images. - + Shape: :math:`x`: :math:`(*, M, N)` :math:`output`: :math:`(*, M)` - + Example: >>> x = torch.rand([10, 400, 1000], dtype=torch.float) >>> H = np.random.random([400, 1000]) @@ -159,7 +159,7 @@ def forward(self, x: torch.tensor) -> torch.tensor: >>> print(y.shape) torch.Size([10, 400]) """ - return torch.einsum('ij,...ij->...i', self.get_H(), x) + return torch.einsum("ij,...ij->...i", self.get_H(), x) # ============================================================================= @@ -168,7 +168,7 @@ class DynamicLinearSplit(DynamicLinear): r""" Used to simulate the measurement of a moving object using the positive and negative components of the measurement matrix. - + Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) and :math:`x` is a batch of vectorized images representing a motion picture. @@ -195,7 +195,7 @@ class DynamicLinearSplit(DynamicLinear): >>> H = np.array(np.random.random([400,1000])) >>> meas_op = LinearDynamicSplit(H) """ - + def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): # initialize self.H and self.H_pinv super().__init__(H, pinv, reg) @@ -207,50 +207,50 @@ def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): # concatenate side by side, then reshape vertically P = torch.cat([H_pos, H_neg], 1).view(2 * self.M, self.N) self.P = nn.Parameter(P, requires_grad=False) - + def get_P(self) -> torch.tensor: r"""Returns the measurement matrix :math:`P`. - + Shape: Output: :math:`(2M, N)` - + Example: >>> P = meas_op.get_P() >>> print('Matrix shape:', P.shape) Matrix shape: torch.Size([800, 1000]) """ return self.P.data - + def forward(self, x: torch.tensor) -> torch.tensor: r""" Simulates the measurement of a motion picture using :math:`P`. - + The output :math:`y` is computed as :math:`y = Px`, where :math:`P` is the measurement matrix and :math:`x` is a batch of vectorized (flattened) images. - + :math:`P` contains only positive values and is obtained by splitting a measurement matrix :math:`H` such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. - + .. warning:: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. `P.shape[-2:] == x.shape[-2:] - + Args: :math:`x`: Batch of vectorized (flatten) images. - + Shape: :math:`P` has a shape of :math:`(2M, N)` where :math:`M` is the number of measurements as defined by the first dimension of :math:`H` - and :math:`N` is the number of pixels in the image. - + and :math:`N` is the number of pixels in the image. + :math:`x`: :math:`(*, 2M, N)` - + :math:`output`: :math:`(*, 2M)` - + Example: >>> x = torch.rand([10, 400, 1000], dtype=torch.float) >>> H = np.random.random([400, 1000]) @@ -259,34 +259,34 @@ def forward(self, x: torch.tensor) -> torch.tensor: >>> print(y.shape) torch.Size([10, 800]) """ - return torch.einsum('ij,...ij->...i', self.get_P(), x) - + return torch.einsum("ij,...ij->...i", self.get_P(), x) + def forward_H(self, x: torch.tensor) -> torch.tensor: r""" Simulates the measurement of a motion picture using :math:`H`. - + The output :math:`y` is computed as :math:`y = Hx`, where :math:`H` is the measurement matrix and :math:`x` is a batch of vectorized (flattened) images. The positive and negative components of the measurement matrix are **not** used in this method. - + .. warning:: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. `H.shape[-2:] == x.shape[-2:] - + Args: :math:`x`: Batch of vectorized (flatten) images. - + Shape: :math:`H` has a shape of :math:`(M, N)` where :math:`M` is the number of measurements and :math:`N` is the number of pixels in the - image. - + image. + :math:`x`: :math:`(*, M, N)` - + :math:`output`: :math:`(*, M)` - + Example: >>> x = torch.rand([10, 400, 1000], dtype=torch.float) >>> H = np.random.random([400, 1000]) @@ -296,7 +296,7 @@ def forward_H(self, x: torch.tensor) -> torch.tensor: torch.Size([10, 400]) """ return super.forward(x) - + # ============================================================================= class DynamicHadamSplit(DynamicLinearSplit): @@ -304,7 +304,7 @@ class DynamicHadamSplit(DynamicLinearSplit): r""" Simulates the measurement of a moving object using the positive and negative components of a Hadamard matrix. - + Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) with positive entries and :math:`x` is a batch of vectorized images representing a motion picture. @@ -330,13 +330,13 @@ class DynamicHadamSplit(DynamicLinearSplit): Args: :attr:`M` (int): Number of measurements - + :attr:`h` (int): Image height :math:`h`. The image is assumed to be square. - + :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h,h)` used to compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` (see the :mod:`~spyrit.misc.sampling` submodule) - + .. note:: The matrix H has shape :math:`(M,N)` with :math:`N = h^2`. @@ -347,7 +347,7 @@ class DynamicHadamSplit(DynamicLinearSplit): >>> Ord = np.random.random([32,32]) >>> meas_op = HadamSplitDynamic(400, 32, Ord) """ - + def __init__(self, M: int, h: int, Ord: np.ndarray): F = walsh2_matrix(h) # full matrix Perm = Permutation_Matrix(Ord) @@ -369,7 +369,7 @@ class Linear(DynamicLinear): # ========================================================================= r""" Simulates the measurement of an image using a measurement operator. - + Computes linear measurements from incoming images: :math:`y = Hx`, where :math:`H` is a linear operator (matrix) and :math:`x` is a vectorized image or a batch of images. @@ -483,7 +483,7 @@ def pinv(self, x: torch.tensor) -> torch.tensor: """ # Pmat.transpose()*f return x @ self.get_H_pinv().T - + # ============================================================================= class LinearSplit(Linear, DynamicLinearSplit): @@ -491,7 +491,7 @@ class LinearSplit(Linear, DynamicLinearSplit): r""" Simulates the measurement of an image using the computed positive and negative components of the measurement matrix. - + Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) and :math:`x` is a vectorized image or batch of vectorized images. @@ -564,14 +564,14 @@ def forward_H(self, x: torch.tensor) -> torch.tensor: # call Linear.forward() method return super(LinearSplit, self).forward(x) - + # ============================================================================= class HadamSplit(LinearSplit, DynamicHadamSplit): # ========================================================================= r""" Simulates the measurement of a moving object using the positive and negative components of a Hadamard matrix. - + Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) with positive entries and :math:`x` is a vectorized image or a batch of images. @@ -592,16 +592,16 @@ class HadamSplit(LinearSplit, DynamicHadamSplit): Args: :attr:`M` (int): Number of measurements - + :attr:`h` (int): Image height :math:`h`. The image is assumed to be square. - + :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h,h)` used to compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` (see the :mod:`~spyrit.misc.sampling` submodule) .. note:: The matrix H has shape :math:`(M,N)` with :math:`N = h^2`. - + .. note:: :math:`H = H_{+} - H_{-}` @@ -613,7 +613,7 @@ class HadamSplit(LinearSplit, DynamicHadamSplit): def __init__(self, M: int, h: int, Ord: np.ndarray): # initialize from DynamicHadamSplit __init__ super(LinearSplit, self).__init__(M, h, Ord) - + def inverse(self, x: torch.tensor) -> torch.tensor: r"""Inverse transform of Hadamard-domain images :math:`x = H_{had}^{-1}G y` is a Hadamard matrix. @@ -662,6 +662,5 @@ def pinv(self, x: torch.tensor) -> torch.tensor: >>> print(x.shape) torch.Size([85, 1024]) """ - # + # return self.adjoint(x) / self.N - From d2de6e295918dc54c4e84c926644d4a718067d72 Mon Sep 17 00:00:00 2001 From: romainphan Date: Fri, 23 Feb 2024 17:17:31 +0100 Subject: [PATCH 07/11] improved attributes management / docstring large update --- spyrit/core/meas.py | 777 ++++++++++++++++++++++++++++++-------------- 1 file changed, 531 insertions(+), 246 deletions(-) diff --git a/spyrit/core/meas.py b/spyrit/core/meas.py index 12a1d636..65fb6813 100644 --- a/spyrit/core/meas.py +++ b/spyrit/core/meas.py @@ -1,7 +1,9 @@ import warnings + import torch import torch.nn as nn import numpy as np + from spyrit.misc.walsh_hadamard import walsh2_torch, walsh2_matrix from spyrit.misc.sampling import Permutation_Matrix @@ -10,10 +12,9 @@ class DynamicLinear(nn.Module): # ========================================================================= r""" - Simulates the measurement of a moving object using the positive and - negative components of the measurement matrix. + Simulates the measurement of a moving object using a measurement matrix. - Computes linear measurements from incoming images: :math:`y = Hx`, + Computes linear measurements :math:`y` from incoming images: :math:`y = Hx`, where :math:`H` is a linear operator (matrix) and :math:`x` is a batch of vectorized images representing a motion picture. @@ -25,50 +26,45 @@ class DynamicLinear(nn.Module): .. warning:: For each call, there must be **exactly** as many images in :math:`x` as there are measurements in the linear operator used to initialize the class. - If not, an error will be raised. Args: - :attr:`H`: measurement matrix (linear operator) with shape :math:`(M, N)`. - - :attr:`pinv`: Option to have access to pseudo inverse solutions. - Defaults to `None` (the pseudo inverse is not initiliazed). - - :attr:`reg` (optional): Regularization parameter (cutoff for small - singular values, see :mod:`numpy.linal.pinv`). Only relevant when - :attr:`pinv` is not `None`. + :attr:`H` (torch.tensor): measurement matrix (linear operator) with + shape :math:`(M, N)`. Attributes: - :attr:`H`: The learnable measurement matrix of shape - :math:`(M,N)` initialized as :math:`H` - - :attr:`H_pinv` (optional): The learnable adjoint measurement - matrix of shape :math:`(N,M)` initialized as :math:`H^\dagger`. - Only relevant when :attr:`pinv` is not `None`. + :attr:`H` (torch.nn.Parameter): The learnable measurement matrix of + shape :math:`(M,N)` initialized as :math:`H`. + + :attr:`M` (int): Number of measurements performed by the linear operator. + It is initialized as the first dimension of :math:`H`. + + :attr:`N` (int): Number of pixels in the image. It is initialized as the + second dimension of :math:`H`. + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be + square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. + + :attr:`w` (int): Image width :math:`w`. The image is assumed to be + square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. - Example 1: - >>> H = np.random.random([400, 1000]) - >>> meas_op = LinearDynamic(H) + Example: + >>> H = np.random.random([400, 1600]) + >>> meas_op = DynamicLinear(H) >>> print(meas_op) - LinearDynamic( - (H): Linear(in_features=1000, out_features=400, bias=False) + DynamicLinear( + (Image pixels): 1600 + (H): torch.Size([400, 1600]) ) - - Example 2: - >>> H = np.random.random([400, 1000]) - >>> meas_op = LinearDynamic(H, True) - >>> print(meas_op) - LinearDynamic( - (H): Linear(in_features=1000, out_features=400, bias=False) - (H_pinv): Linear(in_features=400, out_features=1000, bias=False) - ) """ - def __init__(self, H: np.ndarray | torch.tensor, pinv=None, reg: float=1e-15): + def __init__(self, H: torch.tensor): super().__init__() # nn.Parameter are sent to the device when using .to(device), - # contrary to attributes - H = torch.tensor(H, dtype=torch.float32) + # convert to float 32 for memory efficiency + H = H.type(torch.FloatTensor) self.H = nn.Parameter(H, requires_grad=False) self.M = H.shape[0] @@ -77,60 +73,24 @@ def __init__(self, H: np.ndarray | torch.tensor, pinv=None, reg: float=1e-15): self.w = int(self.N**0.5) if self.h * self.w != self.N: warnings.warn( - "N is not a square. Please assign self.h and self.w manually." + f"N ({H.shape[1]}) is not a square. Please assign self.h and self.w manually." ) - if pinv is not None: - H_pinv = torch.linalg.pinv(H, rcond=reg) - self.H_pinv = nn.Parameter(H_pinv, requires_grad=False) - else: - print("Pseudo inverse will not be instanciated") - + def get_H(self) -> torch.tensor: - r"""Returns the measurement matrix :math:`H`. + r"""Returns the attribute measurement matrix :math:`H`. Shape: Output: :math:`(M, N)` Example: - >>> H1 = np.random.random([400, 1000]) + >>> H1 = np.random.random([400, 1600]) >>> meas_op = Linear(H1) >>> H2 = meas_op.get_H() - >>> print('Matrix shape:', H2.shape) - Matrix shape: torch.Size([400, 1000]) + >>> print(H2.shape) + torch.Size([400, 1600]) """ return self.H.data - - def get_H_T(self) -> torch.tensor: - r""" - Returns the transpose of the measurement matrix :math:`H`. - - Shape: - Output: :math:`(N, M)` - - Example: - >>> H1 = np.random.random([400, 1000]) - >>> meas_op = Linear(H1) - >>> H2 = meas_op.get_H_T() - >>> print('Transpose shape:', H2.shape) - Transpose shape: torch.Size([400, 1000]) - """ - return self.H.T - - def get_H_pinv(self) -> torch.tensor: - r"""Returns the pseudo inverse of the measurement matrix :math:`H`. - - Shape: - Output: :math:`(N, M)` - - Example: - >>> H1 = np.random.random([400, 1000]) - >>> meas_op = Linear(H1, True) - >>> H2 = meas_op.get_H_pinv() - >>> print('Pseudo inverse shape:', H2.shape) - Pseudo inverse shape: torch.Size([1000, 400]) - """ - return self.H_pinv.data - + def forward(self, x: torch.tensor) -> torch.tensor: r""" Simulates the measurement of a motion picture. @@ -142,39 +102,61 @@ def forward(self, x: torch.tensor) -> torch.tensor: .. warning:: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. - `H.shape[-2:] == x.shape[-2:] + `H.shape[-2] == x.shape[-2]` Args: :math:`x`: Batch of vectorized (flattened) images. Shape: - :math:`x`: :math:`(*, M, N)` + :math:`x`: :math:`(*, M, N)`, where * denotes the batch size and + :math:`(M, N)` is the shape of the measurement matrix :math:`H`. + :math:`M` is the number of measurements (and frames) and :math:`N` + the number of pixels in the image. + :math:`output`: :math:`(*, M)` Example: - >>> x = torch.rand([10, 400, 1000], dtype=torch.float) - >>> H = np.random.random([400, 1000]) - >>> meas_op = LinearDynamic(H) + >>> x = torch.rand([10, 400, 1600]) + >>> H = np.random.random([400, 1600]) + >>> meas_op = DynamicLinear(H) >>> y = meas_op(x) >>> print(y.shape) torch.Size([10, 400]) """ - return torch.einsum('ij,...ij->...i', self.get_H(), x) + try : + return torch.einsum('ij,...ij->...i', self.get_H(), x) + except RuntimeError as e: + if "which does not broadcast with previously seen size" in str(e): + raise ValueError( + f"The shape of the input x ({x.shape}) does not match the " + + f"shape of the measurement matrix H ({self.get_H().shape})." + ) + else: + raise e + + def __str__(self): + s_begin = f"{self.__class__.__name__}(\n " + s_fill = "\n ".join([f"({k}): {v}" for k, v in self.__attributeslist__()]) + s_end = "\n )" + return s_begin + s_fill + s_end + + def __attributeslist__(self): + return [('Image pixels', self.N), ('H', self.H.shape)] # ============================================================================= class DynamicLinearSplit(DynamicLinear): # ========================================================================= r""" - Used to simulate the measurement of a moving object using the positive and + Simulates the measurement of a moving object using the positive and negative components of the measurement matrix. - Computes linear measurements from incoming images: :math:`y = Px`, + Computes linear measurements :math:`y` from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) and :math:`x` is a batch of vectorized images representing a motion picture. The matrix :math:`P` contains only positive values and is obtained by - splitting a measurement matrix :math:`H` such that + splitting a given measurement matrix :math:`H` such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. @@ -182,23 +164,53 @@ class DynamicLinearSplit(DynamicLinear): where :math:`N` represents the number of pixels in the image and :math:`M` the number of measurements. + Args: + :math:`H` (np.ndarray): measurement matrix (linear operator) with + shape :math:`(M, N)` where :math:`M` is the number of measurements and + :math:`N` the number of pixels in the image. + + Attributes: + :attr:`H` (torch.nn.Parameter): The learnable measurement matrix of + shape :math:`(M,N)`. + + :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of + shape :math:`(2M, N)` initialized as + :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` + where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)` + + :attr:`M` (int): Number of measurements performed by the linear operator. + It is initialized as the first dimension of :math:`H`. + + :attr:`N` (int): Number of pixels in the image. It is initialized as the + second dimension of :math:`H`. + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be + square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. + + :attr:`w` (int): Image width :math:`w`. The image is assumed to be + square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. + .. warning:: For each call, there must be **exactly** as many images in :math:`x` as there are measurements in the linear operator used to initialize the class. - If not, an error will be raised. - - Args: - :math:`H` (np.ndarray): measurement matrix (linear operator) with - shape :math:`(M, N)`. Example: - >>> H = np.array(np.random.random([400,1000])) - >>> meas_op = LinearDynamicSplit(H) + >>> H = np.array(np.random.random([400,1600])) + >>> meas_op = DynamicLinearSplit(H) + >>> print(meas_op) + DynamicLinearSplit( + (Image pixels): 1600 + (H): torch.Size([400, 1600]) + (P): torch.Size([800, 1600]) + ) """ - def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): + def __init__(self, H: np.ndarray): # initialize self.H and self.H_pinv - super().__init__(H, pinv, reg) + super().__init__(H) + # initialize self.P = [ H^+ ] # [ H^- ] zero = torch.zeros(1) @@ -206,18 +218,22 @@ def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): H_neg = torch.maximum(zero, -H) # concatenate side by side, then reshape vertically P = torch.cat([H_pos, H_neg], 1).view(2 * self.M, self.N) + P = P.type(torch.FloatTensor) # cast to float 32 self.P = nn.Parameter(P, requires_grad=False) def get_P(self) -> torch.tensor: - r"""Returns the measurement matrix :math:`P`. + r"""Returns the attribute measurement matrix :math:`P`. Shape: - Output: :math:`(2M, N)` + Output: :math:`(2M, N)`, where :math:`(M, N)` is the shape of the + measurement matrix :math:`H` given at initialization. Example: + >>> H = np.random.random([400, 1600]) + >>> meas_op = LinearDynamicSplit(H) >>> P = meas_op.get_P() - >>> print('Matrix shape:', P.shape) - Matrix shape: torch.Size([800, 1000]) + >>> print(P.shape) + torch.Size([800, 1600]) """ return self.P.data @@ -230,36 +246,51 @@ def forward(self, x: torch.tensor) -> torch.tensor: images. :math:`P` contains only positive values and is obtained by - splitting a measurement matrix :math:`H` such that + splitting a given measurement matrix :math:`H` such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. + The matrix :math:`H` can contain positive and negative values and is + given by the user at initialization. + .. warning:: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. - `P.shape[-2:] == x.shape[-2:] + `P.shape[-2] == x.shape[-2]` Args: - :math:`x`: Batch of vectorized (flatten) images. + :math:`x`: Batch of vectorized (flattened) images of shape + :math:`(*, 2M, N)` where * denotes the batch size, :math:`2M` the + number of measurements in the measurement matrix :math:`P` and + :math:`N` the number of pixels in the image. Shape: + :math:`x`: :math:`(*, 2M, N)` + :math:`P` has a shape of :math:`(2M, N)` where :math:`M` is the number of measurements as defined by the first dimension of :math:`H` and :math:`N` is the number of pixels in the image. - :math:`x`: :math:`(*, 2M, N)` - :math:`output`: :math:`(*, 2M)` Example: - >>> x = torch.rand([10, 400, 1000], dtype=torch.float) - >>> H = np.random.random([400, 1000]) - >>> meas_op = LinearDynamicSplit(H) + >>> x = torch.rand([10, 800, 1600]) + >>> H = np.random.random([400, 1600]) + >>> meas_op = DynamicLinearSplit(H) >>> y = meas_op(x) >>> print(y.shape) torch.Size([10, 800]) """ - return torch.einsum('ij,...ij->...i', self.get_P(), x) + try : + return torch.einsum('ij,...ij->...i', self.get_P(), x) + except RuntimeError as e: + if "which does not broadcast with previously seen size" in str(e): + raise ValueError( + f"The shape of the input x ({x.shape}) does not match the " + + f"shape of the measurement matrix P ({self.get_P().shape})." + ) + else: + raise e def forward_H(self, x: torch.tensor) -> torch.tensor: r""" @@ -270,33 +301,41 @@ def forward_H(self, x: torch.tensor) -> torch.tensor: images. The positive and negative components of the measurement matrix are **not** used in this method. + The matrix :math:`H` can contain positive and negative values and is + given by the user at initialization. + .. warning:: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. - `H.shape[-2:] == x.shape[-2:] + `H.shape[-2:] == x.shape[-2:]` Args: - :math:`x`: Batch of vectorized (flatten) images. + :math:`x`: Batch of vectorized (flatten) images of shape + :math:`(*, M, N)` where * denotes the batch size, and :math:`(M, N)` + is the shape of the measurement matrix :math:`H`. Shape: + :math:`x`: :math:`(*, M, N)` + :math:`H` has a shape of :math:`(M, N)` where :math:`M` is the number of measurements and :math:`N` is the number of pixels in the image. - :math:`x`: :math:`(*, M, N)` - :math:`output`: :math:`(*, M)` Example: - >>> x = torch.rand([10, 400, 1000], dtype=torch.float) - >>> H = np.random.random([400, 1000]) + >>> x = torch.rand([10, 400, 1600]) + >>> H = np.random.random([400, 1600]) >>> meas_op = LinearDynamicSplit(H) >>> y = meas_op.forward_H(x) >>> print(y.shape) torch.Size([10, 400]) """ - return super.forward(x) + return super().forward(x) + def __attributeslist__(self): + return super().__attributeslist__() + [('P', self.P.shape)] + # ============================================================================= class DynamicHadamSplit(DynamicLinearSplit): @@ -309,36 +348,68 @@ class DynamicHadamSplit(DynamicLinearSplit): where :math:`P` is a linear operator (matrix) with positive entries and :math:`x` is a batch of vectorized images representing a motion picture. - The class relies on a matrix :math:`H` with - shape :math:`(M,N)` where :math:`N` represents the number of pixels in the - image and :math:`M \le N` the number of measurements. The matrix :math:`P` - is obtained by splitting the matrix :math:`H` such that - :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where + The class relies on a Hadamard-based matrix :math:`H` with shape :math:`(M,N)` + where :math:`N` represents the number of pixels in the image and + :math:`M \le N` the number of measurements. :math:`H` is obtained by + selecting a re-ordered subsample of :math:`M` rows of a "full" Hadamard + matrix :math:`F` with shape :math:`(N^2, N^2)`. :math:`N` must be a power + of 2. + + The matrix :math:`P` is then obtained by splitting the matrix :math:`H` + such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. - The matrix :math:`H` is obtained by retaining the first :math:`M` rows of - a permuted Hadamard matrix :math:`GF`, where :math:`G` is a - permutation matrix with shape with shape :math:`(M,N)` and :math:`F` is a - "full" Hadamard matrix with shape :math:`(N,N)`. The computation of a - Hadamard transform :math:`Fx` benefits a fast algorithm, as well as the - computation of inverse Hadamard transforms. - - .. warning:: - For each call, there must be **exactly** as many images in :math:`x` as - there are measurements in the linear operator used to initialize the class. - If not, an error will be raised. - Args: :attr:`M` (int): Number of measurements - :attr:`h` (int): Image height :math:`h`. The image is assumed to be square. + :attr:`h` (int): Image height :math:`h`, must be a power of 2. The + image is assumed to be square, so the number of pixels in the image is + :math:`N = h^2`. - :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h,h)` used to + :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h, h)` used to + select the rows of the full Hadamard matrix :math:`F` compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` (see the :mod:`~spyrit.misc.sampling` submodule) + + Attributes: + :attr:`H` (torch.nn.Parameter): The measurement matrix of shape + :math:`(M, h^2)`. It is initialized as a re-ordered subsample of the + rows of the "full" Hadamard matrix :math:`F` with shape :math:`(N^2, N^2)`. + + :attr:`H_pinv` (torch.nn.Parameter): The pseudo inverse of the measurement + matrix of shape :math:`(h^2, M)`. It is initialized as + :math:`H^\dagger = \frac{1}{N}H^{T}` where :math:`N = h^2`. + + :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of + shape :math:`(2M, h^2)` initialized as + :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` + where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. + + :attr:`Perm` (torch.nn.Parameter): The permutation matrix :math:`G^{T}` + that is used to re-order the subsample of rows of the "full" Hadamard + matrix :math:`F` according to descreasing value of the order matrix + :math:`Ord`. It has shape :math:`(N, N)` where :math:`N = h^2`. + + :attr:`M` (int): Number of measurements performed by the linear operator. + + :attr:`N` (int): Number of pixels in the image. It is initialized as + :math:`h^2`. + + :attr:`h` (int): Image height :math:`h`. + :attr:`w` (int): Image width :math:`w`. The image is assumed to be + square, i.e. :math:`w = h`. + + .. warning:: + For each call, there must be **exactly** as many images in :math:`x` as + there are measurements in the linear operator used to initialize the class. + .. note:: - The matrix H has shape :math:`(M,N)` with :math:`N = h^2`. + The computation of a Hadamard transform :math:`Fx` benefits a fast + algorithm, as well as the computation of inverse Hadamard transforms. + + .. note:: + The matrix :math:`H` has shape :math:`(M, N)` with :math:`N = h^2`. .. note:: :math:`H = H_{+} - H_{-}` @@ -346,86 +417,200 @@ class DynamicHadamSplit(DynamicLinearSplit): Example: >>> Ord = np.random.random([32,32]) >>> meas_op = HadamSplitDynamic(400, 32, Ord) + >>> print(meas_op) + HadamSplitDynamic( + (Image pixels): 1024 + (H): torch.Size([400, 1024]) + (P): torch.Size([800, 1024]) + (Perm): torch.Size([1024, 1024]) + ) """ - + # ========================================================================= change this ????? ^ def __init__(self, M: int, h: int, Ord: np.ndarray): F = walsh2_matrix(h) # full matrix Perm = Permutation_Matrix(Ord) - F = Perm @ F # If Perm is not learnt, could be computed mush faster + F = Perm @ F # If Perm is not learnt, could be computed much faster H = F[:M, :] w = h # we assume a square image - super().__init__(H) + super().__init__(torch.from_numpy(H)) + print("h before", self.h) - Perm = torch.tensor(Perm, dtype=torch.float32) + Perm = torch.from_numpy(Perm).float() # float32 self.Perm = nn.Parameter(Perm, requires_grad=False) # overwrite self.h and self.w self.h = h self.w = w + print("h after", self.h) + + def __attributeslist__(self): + return super().__attributeslist__() + [('Perm', self.Perm.shape)] # ============================================================================= class Linear(DynamicLinear): # ========================================================================= r""" - Simulates the measurement of an image using a measurement operator. + Simulates the measurement of an still image using a measurement matrix. Computes linear measurements from incoming images: :math:`y = Hx`, - where :math:`H` is a linear operator (matrix) and :math:`x` is a - vectorized image or a batch of images. + where :math:`H` is a given linear operator (matrix) and :math:`x` is a + vectorized image or batch of images. The class is constructed from a :math:`M` by :math:`N` matrix :math:`H`, where :math:`N` represents the number of pixels in the image and :math:`M` the number of measurements. Args: - :attr:`H`: measurement matrix (linear operator) with shape :math:`(M, N)`. + :attr:`H` (:type:`torch.tensor`): measurement matrix (linear operator) with shape :math:`(M, N)`. - :attr:`pinv`: Option to have access to pseudo inverse solutions. - Defaults to `None` (the pseudo inverse is not initiliazed). + :attr:`pinv` (Any): Option to have access to pseudo inverse solutions. If not + `None`, the pseudo inverse is initialized as :math:`H^\dagger` and + stored in the attribute :attr:`H_pinv`. Defaults to `None` (the pseudo + inverse is not initiliazed). - :attr:`reg` (optional): Regularization parameter (cutoff for small + :attr:`reg` (float, optional): Regularization parameter (cutoff for small singular values, see :mod:`numpy.linal.pinv`). Only relevant when :attr:`pinv` is not `None`. - Attributes: - :attr:`H`: The learnable measurement matrix of shape - :math:`(M,N)` initialized as :math:`H` + :attr:`H` (torch.tensor): The learnable measurement matrix of shape + :math:`(M, N)` initialized as :math:`H` + + :attr:`H_pinv` (torch.tensor, optional): The learnable adjoint measurement + matrix of shape :math:`(N, M)` initialized as :math:`H^\dagger`. + Only relevant when :attr:`pinv` is not `None`. - :attr:`H_adjoint`: The learnable adjoint measurement matrix - of shape :math:`(N,M)` initialized as :math:`H^\top` + :attr:`M` (int): Number of measurements performed by the linear operator. + It is initialized as the first dimension of :math:`H`. + + :attr:`N` (int): Number of pixels in the image. It is initialized as the + second dimension of :math:`H`. + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be + square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. + + :attr:`w` (int): Image width :math:`w`. The image is assumed to be + square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. - :attr:`H_pinv` (optional): The learnable adjoint measurement - matrix of shape :math:`(N,M)` initialized as :math:`H^\dagger`. - Only relevant when :attr:`pinv` is not `None`. + .. note:: + If you know the pseudo inverse of :math:`H` and want to store it, it is + best to initialize the class with :attr:`pinv` set to `None` and then + call :meth:`set_H_pinv` to store the pseudo inverse. Example 1: - >>> H = np.random.random([400, 1000]) - >>> meas_op = Linear(H) + >>> H = np.random.random([400, 1600]) + >>> meas_op = Linear(H, pinv=None) >>> print(meas_op) Linear( - (H): Linear(in_features=1000, out_features=400, bias=False) + (Image pixels): 1600 + (H): torch.Size([400, 1600]) + (H_pinv): None ) Example 2: - >>> H = np.random.random([400, 1000]) + >>> H = np.random.random([400, 1600]) >>> meas_op = Linear(H, True) >>> print(meas_op) Linear( - (H): Linear(in_features=1000, out_features=400, bias=False) - (H_pinv): Linear(in_features=400, out_features=1000, bias=False) - ) + (Image pixels): 1600 + (H): torch.Size([400, 1600]) + (H_pinv): torch.Size([1600, 400]) + ) """ - def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): - super().__init__(H, pinv, reg) + def __init__(self, H: np.ndarray, pinv=None, reg: float=1e-15): + super().__init__(H) + if pinv is not None: + self.set_H_pinv(reg=reg) + + def get_H_T(self) -> torch.tensor: + r""" + Returns the transpose of the measurement matrix :math:`H`. + + Shape: + Output: :math:`(N, M)`, where :math:`N` is the number of pixels in + the image and :math:`M` the number of measurements. + + Example: + >>> H1 = np.random.random([400, 1600]) + >>> meas_op = Linear(H1) + >>> H2 = meas_op.get_H_T() + >>> print(H2.shape) + torch.Size([400, 1600]) + """ + return self.H.T + + def get_H_pinv(self) -> torch.tensor: + r"""Returns the pseudo inverse of the measurement matrix :math:`H`. + + Shape: + Output: :math:`(N, M)` + + Example: + >>> H1 = np.random.random([400, 1600]) + >>> meas_op = Linear(H1, True) + >>> H2 = meas_op.get_H_pinv() + >>> print(H2.shape) + torch.Size([1600, 400]) + """ + try: + return self.H_pinv.data + except AttributeError as e: + if "has no attribute 'H_pinv'" in str(e): + raise AttributeError( + "The pseudo inverse has not been initialized. Please set it using self.set_H_pinv()." + ) + else: + raise e + + def set_H_pinv(self, reg: float=1e-15, pinv: torch.tensor=None) -> None: + r""" + Stores in self.H_pinv the pseudo inverse of the measurement matrix :math:`H`. + + If :attr:`pinv` is given, it is directly stored as the pseudo inverse. + The validity of the pseudo inverse is not checked. If :attr:`pinv` is + :obj:`None`, the pseudo inverse is computed from the existing + measurement matrix :math:`H` with regularization parameter :attr:`reg`. + + Args: + :attr:`reg` (float, optional): Cutoff for small singular values. + + :attr:`H_pinv` (torch.tensor, optional): If given, the tensor is + directly stored as the pseudo inverse. No checks are performed. + Otherwise, the pseudo inverse is computed from the existing + measurement matrix :math:`H`. + + .. note: + Only one of :math:`H_pinv` and :math:`reg` should be given. If both + are given, :math:`H_pinv` is used and :math:`reg` is ignored. + + Shape: + :attr:`H_pinv`: :math:`(N, M)`, where :math:`N` is the number of + pixels in the image and :math:`M` the number of measurements. + + Example: + >>> H1 = torch.rand([400, 1600]) + >>> H2 = torch.linalg.pinv(H1) + >>> meas_op = Linear(H1) + >>> meas_op.set_H_pinv(H2) + """ + if pinv is not None: + H_pinv = pinv.type(torch.FloatTensor) # to float32 + else: + H_pinv = torch.linalg.pinv(self.get_H(), rcond=reg) + self.H_pinv = nn.Parameter(H_pinv, requires_grad=False) + def forward(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`y = Hx`. Args: - :math:`x`: Batch of vectorized (flatten) images. + :math:`x` (torch.tensor): Batch of vectorized (flattened) images. + If x has more than 1 dimension, the linear measurement is applied + to each image in the batch. Shape: :math:`x`: :math:`(*, N)` where * denotes the batch size and `N` @@ -435,20 +620,23 @@ def forward(self, x: torch.tensor) -> torch.tensor: the number of measurements. Example: - >>> x = torch.rand([10,1000], dtype=torch.float) + >>> H = torch.randn([400, 1600]) + >>> meas_op = Linear(H) + >>> x = torch.randn([10, 1600]) >>> y = meas_op(x) - >>> print('forward:', y.shape) - forward: torch.Size([10, 400]) - + >>> print(y.shape) + torch.Size([10, 400]) """ # left multiplication with transpose is equivalent to right mult - return x @ self.get_H_T() + return x @ self.get_H().T def adjoint(self, x: torch.tensor) -> torch.tensor: r"""Applies adjoint transform to incoming measurements :math:`y = H^{T}x` Args: - :math:`x`: batch of measurement vectors. + :math:`x` (torch.tensor): batch of measurement vectors. If x has + more than 1 dimension, the adjoint measurement is applied to each + measurement in the batch. Shape: :math:`x`: :math:`(*, M)` @@ -456,19 +644,23 @@ def adjoint(self, x: torch.tensor) -> torch.tensor: Output: :math:`(*, N)` Example: - >>> x = torch.rand([10,400], dtype=torch.float) + >>> H = torch.randn([400, 1600]) + >>> meas_op = Linear(H) + >>> x = torch.randn([10, 400] >>> y = meas_op.adjoint(x) - >>> print('adjoint:', y.shape) - adjoint: torch.Size([10, 1000]) + >>> print(y.shape) + torch.Size([10, 1600]) """ # left multiplication is equivalent to right mult with transpose - return x @ self.get_H() + return x @ self.get_H_T().T def pinv(self, x: torch.tensor) -> torch.tensor: r"""Computes the pseudo inverse solution :math:`y = H^\dagger x` Args: - :math:`x`: batch of measurement vectors. + :math:`x` (torch.tensor): batch of measurement vectors. If x has + more than 1 dimension, the pseudo inverse is applied to each + image in the batch. Shape: :math:`x`: :math:`(*, M)` @@ -476,20 +668,27 @@ def pinv(self, x: torch.tensor) -> torch.tensor: Output: :math:`(*, N)` Example: - >>> x = torch.rand([10,400], dtype=torch.float) + >>> H = torch.randn([400, 1600]) + >>> meas_op = Linear(H, True) + >>> x = torch.randn([10, 400]) >>> y = meas_op.pinv(x) - >>> print('pinv:', y.shape) - adjoint: torch.Size([10, 1000]) + >>> print(y.shape) + torch.Size([10, 1600]) """ # Pmat.transpose()*f return x @ self.get_H_pinv().T - + + def __attributeslist__(self): + return super().__attributeslist__() \ + + [('H_pinv', self.H_pinv.shape if hasattr(self, 'H_pinv') + else None)] + # ============================================================================= class LinearSplit(Linear, DynamicLinearSplit): # ========================================================================= r""" - Simulates the measurement of an image using the computed positive and + Simulates the measurement of a still image using the computed positive and negative components of the measurement matrix. Computes linear measurements from incoming images: :math:`y = Px`, @@ -506,23 +705,79 @@ class LinearSplit(Linear, DynamicLinearSplit): :math:`M` the number of measurements. Args: - :math:`H` (np.ndarray): measurement matrix (linear operator) with - shape :math:`(M, N)`. + :attr:`H` (torch.tensor): measurement matrix (linear operator) with + shape :math:`(M, N)`, where :math:`M` is the number of measurements and + :math:`N` the number of pixels in the image. + + :attr:`pinv` (Any): Option to have access to pseudo inverse solutions. If not + `None`, the pseudo inverse is initialized as :math:`H^\dagger` and + stored in the attribute :attr:`H_pinv`. Defaults to `None` (the pseudo + inverse is not initiliazed). + + :attr:`reg` (float, optional): Regularization parameter (cutoff for small + singular values, see :mod:`torch.linalg.pinv`). Only relevant when + :attr:`pinv` is not `None`. + + Attributes: + :attr:`H` (torch.nn.Parameter): The learnable measurement matrix of + shape :math:`(M,N)`. + + :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of + shape :math:`(2M, N)` initialized as + :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` + where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)` + + :attr:`M` (int): Number of measurements performed by the linear operator. + It is initialized as the first dimension of :math:`H`. + + :attr:`N` (int): Number of pixels in the image. It is initialized as the + second dimension of :math:`H`. + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be + square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. + + :attr:`w` (int): Image width :math:`w`. The image is assumed to be + square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. + + .. note:: + If you know the pseudo inverse of :math:`H` and want to store it, it is + best to initialize the class with :attr:`pinv` set to `None` and then + call :meth:`set_H_pinv` to store the pseudo inverse. Example: - >>> H = np.array(np.random.random([400,1000])) - >>> meas_op = LinearSplit(H) + >>> H = torch.randn(400, 1600) + >>> meas_op = LinearSplit(H, None) + >>> print(meas_op) + LinearSplit( + (Image pixels): 1600 + (H): torch.Size([400, 1600]) + (P): torch.Size([800, 1600]) + (H_pinv): None + ) """ - def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): + def __init__(self, H: np.ndarray, pinv=None, reg: float=1e-15): + print("initializing LinearSplit") # initialize from DynamicLinearSplit __init__ - super(Linear, self).__init__(H, pinv, reg) + super(Linear, self).__init__(H) + if pinv is not None: + self.set_H_pinv(reg) def forward(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`y = Px`. + This method uses the splitted measurement matrix :math:`P` to compute + the linear measurements from incoming images. :math:`P` contains only + positive values and is obtained by splitting a given measurement matrix + :math:`H` such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, + where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. + Args: - :math:`x`: Batch of vectorized (flatten) images. + :math:`x` (torch.tensor): Batch of vectorized (flattened) images. If + x has more than 1 dimension, the linear measurement is applied to + each image in the batch. Shape: :math:`x`: :math:`(*, N)` where * denotes the batch size and `N` @@ -532,34 +787,40 @@ def forward(self, x: torch.tensor) -> torch.tensor: the number of measurements. Example: - >>> x = torch.rand([10,1000], dtype=torch.float) + >>> H = torch.randn(400, 1600) + >>> meas_op = LinearSplit(H) + >>> x = torch.randn(10, 1600) >>> y = meas_op(x) - >>> print('Output:', y.shape) - Output: torch.Size([10, 800]) + >>> print(y.shape) + torch.Size([10, 800]) """ - # x.shape[b*c,N] - # output shape : [b*c, 2*M] return x @ self.get_P().T def forward_H(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`m = Hx`. + This method uses the measurement matrix :math:`H` to compute the linear + measurements from incoming images. + Args: - :math:`x`: Batch of vectorized (flatten) images. + :attr:`x` (torch.tensor): Batch of vectorized (flatten) images. If + x has more than 1 dimension, the linear measurement is applied to + each image in the batch. Shape: - :math:`x`: :math:`(*, N)` where * denotes the batch size and `N` + :attr:`x`: :math:`(*, N)` where * denotes the batch size and `N` the total number of pixels in the image. Output: :math:`(*, M)` where * denotes the batch size and `M` the number of measurements. Example: - >>> x = torch.rand([10,1000], dtype=torch.float) + >>> H = torch.randn(400, 1600) + >>> meas_op = LinearSplit(H) + >>> x = torch.randn(10, 1600) >>> y = meas_op.forward_H(x) - >>> print('Output:', y.shape) - output shape: torch.Size([10, 400]) - + >>> print(y.shape) + torch.Size([10, 400]) """ # call Linear.forward() method return super(LinearSplit, self).forward(x) @@ -569,36 +830,68 @@ def forward_H(self, x: torch.tensor) -> torch.tensor: class HadamSplit(LinearSplit, DynamicHadamSplit): # ========================================================================= r""" - Simulates the measurement of a moving object using the positive and + Simulates the measurement of a still image using the positive and negative components of a Hadamard matrix. Computes linear measurements from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) with positive entries and :math:`x` is a vectorized image or a batch of images. - The class relies on a matrix :math:`H` with - shape :math:`(M,N)` where :math:`N` represents the number of pixels in the - image and :math:`M \le N` the number of measurements. The matrix :math:`P` - is obtained by splitting the matrix :math:`H` such that - :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where + The class relies on a Hadamard-based matrix :math:`H` with shape :math:`(M,N)` + where :math:`N` represents the number of pixels in the image and + :math:`M \le N` the number of measurements. :math:`H` is obtained by + selecting a re-ordered subsample of :math:`M` rows of a "full" Hadamard + matrix :math:`F` with shape :math:`(N^2, N^2)`. :math:`N` must be a power + of 2. + + The matrix :math:`P` is then obtained by splitting the matrix :math:`H` + such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. - The matrix :math:`H` is obtained by retaining the first :math:`M` rows of - a permuted Hadamard matrix :math:`GF`, where :math:`G` is a - permutation matrix with shape with shape :math:`(M,N)` and :math:`F` is a - "full" Hadamard matrix with shape :math:`(N,N)`. The computation of a - Hadamard transform :math:`Fx` benefits a fast algorithm, as well as the - computation of inverse Hadamard transforms. - Args: :attr:`M` (int): Number of measurements - :attr:`h` (int): Image height :math:`h`. The image is assumed to be square. + :attr:`h` (int): Image height :math:`h`, must be a power of 2. The + image is assumed to be square, so the number of pixels in the image is + :math:`N = h^2`. - :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h,h)` used to + :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h, h)` used to compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` (see the :mod:`~spyrit.misc.sampling` submodule) + Attributes: + :attr:`H` (torch.nn.Parameter): The measurement matrix of shape + :math:`(M, h^2)`. It is initialized as a re-ordered subsample of the + rows of the "full" Hadamard matrix :math:`F` with shape :math:`(N^2, N^2)`. + + :attr:`H_pinv` (torch.nn.Parameter): The pseudo inverse of the measurement + matrix of shape :math:`(h^2, M)`. It is initialized as + :math:`H^\dagger = \frac{1}{N}H^{T}` where :math:`N = h^2`. + + :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of + shape :math:`(2M, h^2)` initialized as + :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` + where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. + + :attr:`Perm` (torch.nn.Parameter): The permutation matrix :math:`G^{T}` + that is used to re-order the subsample of rows of the "full" Hadamard + matrix :math:`F` according to descreasing value of the order matrix + :math:`Ord`. It has shape :math:`(N, N)` where :math:`N = h^2`. + + :attr:`M` (int): Number of measurements performed by the linear operator. + + :attr:`N` (int): Number of pixels in the image. It is initialized as + :math:`h^2`. + + :attr:`h` (int): Image height :math:`h`. + + :attr:`w` (int): Image width :math:`w`. The image is assumed to be + square, i.e. :math:`w = h`. + + .. note:: + The computation of a Hadamard transform :math:`Fx` benefits a fast + algorithm, as well as the computation of inverse Hadamard transforms. + .. note:: The matrix H has shape :math:`(M,N)` with :math:`N = h^2`. @@ -606,14 +899,25 @@ class HadamSplit(LinearSplit, DynamicHadamSplit): :math:`H = H_{+} - H_{-}` Example: - >>> Ord = np.random.random([32,32]) - >>> meas_op = HadamSplit(400, 32, Ord) + >>> h = 32 + >>> Ord = torch.randn(h, h) + >>> meas_op = HadamSplit(400, h, Ord) + >>> print(meas_op) + HadamSplit( + (Image pixels): 1024 + (H): torch.Size([400, 1024]) + (P): torch.Size([800, 1024]) + (Perm): torch.Size([1024, 1024]) + (H_pinv): torch.Size([1024, 400]) + ) """ def __init__(self, M: int, h: int, Ord: np.ndarray): - # initialize from DynamicHadamSplit __init__ - super(LinearSplit, self).__init__(M, h, Ord) - + print("initializing HadamSplit") + # initialize from DynamicHadamSplit (the MRO is not trivial here) + super(Linear, self).__init__(M, h, Ord) + self.set_H_pinv(pinv = 1 / self.N * self.get_H_T()) + def inverse(self, x: torch.tensor) -> torch.tensor: r"""Inverse transform of Hadamard-domain images :math:`x = H_{had}^{-1}G y` is a Hadamard matrix. @@ -629,39 +933,20 @@ def inverse(self, x: torch.tensor) -> torch.tensor: Output: math:`(b*c, N)` Example: - - >>> y = torch.rand([85,32*32], dtype=torch.float) + >>> h = 32 + >>> Ord = torch.randn(h, h) + >>> meas_op = HadamSplit(400, h, Ord) + >>> y = torch.randn(10, h**2) >>> x = meas_op.inverse(y) - >>> print('Inverse:', x.shape) - Inverse: torch.Size([85, 1024]) + >>> print(x.shape) + torch.Size([10, 1024]) """ # permutations # todo: check walsh2_S_fold_torch to speed up b, N = x.shape - x = self.Perm(x) + x = x @ self.Perm.T x = x.view(b, 1, self.h, self.w) # inverse of full transform # todo: initialize with 1D transform to speed up x = 1 / self.N * walsh2_torch(x) - return x.view(b, N) - - def pinv(self, x: torch.tensor) -> torch.tensor: - r"""Pseudo inverse transform of incoming mesurement vectors :math:`x` - - Args: - :attr:`x`: batch of measurement vectors. - - Shape: - x: :math:`(*, M)` - - Output: :math:`(*, N)` - - Example: - >>> y = torch.rand([85,400], dtype=torch.float) - >>> x = meas_op.pinv(y) - >>> print(x.shape) - torch.Size([85, 1024]) - """ - # - return self.adjoint(x) / self.N - + return x.view(b, N) \ No newline at end of file From 37de02e77a664141dceb4a3a52bd1557b2d0615f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Feb 2024 16:19:44 +0000 Subject: [PATCH 08/11] [pre-commit.ci] Automatic python formatting --- spyrit/core/meas.py | 230 ++++++++++++++++++++++---------------------- 1 file changed, 115 insertions(+), 115 deletions(-) diff --git a/spyrit/core/meas.py b/spyrit/core/meas.py index d05c13e8..d1b31646 100644 --- a/spyrit/core/meas.py +++ b/spyrit/core/meas.py @@ -13,7 +13,7 @@ class DynamicLinear(nn.Module): # ========================================================================= r""" Simulates the measurement of a moving object using a measurement matrix. - + Computes linear measurements :math:`y` from incoming images: :math:`y = Hx`, where :math:`H` is a linear operator (matrix) and :math:`x` is a batch of vectorized images representing a motion picture. @@ -34,17 +34,17 @@ class DynamicLinear(nn.Module): Attributes: :attr:`H` (torch.nn.Parameter): The learnable measurement matrix of shape :math:`(M,N)` initialized as :math:`H`. - + :attr:`M` (int): Number of measurements performed by the linear operator. It is initialized as the first dimension of :math:`H`. - + :attr:`N` (int): Number of pixels in the image. It is initialized as the second dimension of :math:`H`. - - :attr:`h` (int): Image height :math:`h`. The image is assumed to be + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign :attr:`h` and :attr:`w` manually. - + :attr:`w` (int): Image width :math:`w`. The image is assumed to be square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign :attr:`h` and :attr:`w` manually. @@ -58,7 +58,7 @@ class DynamicLinear(nn.Module): (H): torch.Size([400, 1600]) ) """ - + def __init__(self, H: torch.tensor): super().__init__() @@ -75,10 +75,10 @@ def __init__(self, H: torch.tensor): warnings.warn( f"N ({H.shape[1]}) is not a square. Please assign self.h and self.w manually." ) - + def get_H(self) -> torch.tensor: r"""Returns the attribute measurement matrix :math:`H`. - + Shape: Output: :math:`(M, N)` @@ -103,7 +103,7 @@ def forward(self, x: torch.tensor) -> torch.tensor: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. `H.shape[-2] == x.shape[-2]` - + Args: :math:`x`: Batch of vectorized (flattened) images. @@ -112,7 +112,7 @@ def forward(self, x: torch.tensor) -> torch.tensor: :math:`(M, N)` is the shape of the measurement matrix :math:`H`. :math:`M` is the number of measurements (and frames) and :math:`N` the number of pixels in the image. - + :math:`output`: :math:`(*, M)` Example: @@ -123,8 +123,8 @@ def forward(self, x: torch.tensor) -> torch.tensor: >>> print(y.shape) torch.Size([10, 400]) """ - try : - return torch.einsum('ij,...ij->...i', self.get_H(), x) + try: + return torch.einsum("ij,...ij->...i", self.get_H(), x) except RuntimeError as e: if "which does not broadcast with previously seen size" in str(e): raise ValueError( @@ -141,7 +141,7 @@ def __str__(self): return s_begin + s_fill + s_end def __attributeslist__(self): - return [('Image pixels', self.N), ('H', self.H.shape)] + return [("Image pixels", self.N), ("H", self.H.shape)] # ============================================================================= @@ -150,7 +150,7 @@ class DynamicLinearSplit(DynamicLinear): r""" Simulates the measurement of a moving object using the positive and negative components of the measurement matrix. - + Computes linear measurements :math:`y` from incoming images: :math:`y = Px`, where :math:`P` is a linear operator (matrix) and :math:`x` is a batch of vectorized images representing a motion picture. @@ -172,7 +172,7 @@ class DynamicLinearSplit(DynamicLinear): Attributes: :attr:`H` (torch.nn.Parameter): The learnable measurement matrix of shape :math:`(M,N)`. - + :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of shape :math:`(2M, N)` initialized as :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` @@ -180,14 +180,14 @@ class DynamicLinearSplit(DynamicLinear): :attr:`M` (int): Number of measurements performed by the linear operator. It is initialized as the first dimension of :math:`H`. - + :attr:`N` (int): Number of pixels in the image. It is initialized as the second dimension of :math:`H`. - - :attr:`h` (int): Image height :math:`h`. The image is assumed to be + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign :attr:`h` and :attr:`w` manually. - + :attr:`w` (int): Image width :math:`w`. The image is assumed to be square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign :attr:`h` and :attr:`w` manually. @@ -206,11 +206,11 @@ class DynamicLinearSplit(DynamicLinear): (P): torch.Size([800, 1600]) ) """ - + def __init__(self, H: np.ndarray): # initialize self.H and self.H_pinv super().__init__(H) - + # initialize self.P = [ H^+ ] # [ H^- ] zero = torch.zeros(1) @@ -218,16 +218,16 @@ def __init__(self, H: np.ndarray): H_neg = torch.maximum(zero, -H) # concatenate side by side, then reshape vertically P = torch.cat([H_pos, H_neg], 1).view(2 * self.M, self.N) - P = P.type(torch.FloatTensor) # cast to float 32 + P = P.type(torch.FloatTensor) # cast to float 32 self.P = nn.Parameter(P, requires_grad=False) def get_P(self) -> torch.tensor: r"""Returns the attribute measurement matrix :math:`P`. - + Shape: Output: :math:`(2M, N)`, where :math:`(M, N)` is the shape of the measurement matrix :math:`H` given at initialization. - + Example: >>> H = np.random.random([400, 1600]) >>> meas_op = LinearDynamicSplit(H) @@ -249,28 +249,28 @@ def forward(self, x: torch.tensor) -> torch.tensor: splitting a given measurement matrix :math:`H` such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. - + The matrix :math:`H` can contain positive and negative values and is given by the user at initialization. - + .. warning:: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. `P.shape[-2] == x.shape[-2]` - + Args: - :math:`x`: Batch of vectorized (flattened) images of shape + :math:`x`: Batch of vectorized (flattened) images of shape :math:`(*, 2M, N)` where * denotes the batch size, :math:`2M` the number of measurements in the measurement matrix :math:`P` and :math:`N` the number of pixels in the image. - + Shape: :math:`x`: :math:`(*, 2M, N)` - + :math:`P` has a shape of :math:`(2M, N)` where :math:`M` is the number of measurements as defined by the first dimension of :math:`H` - and :math:`N` is the number of pixels in the image. - + and :math:`N` is the number of pixels in the image. + :math:`output`: :math:`(*, 2M)` Example: @@ -281,8 +281,8 @@ def forward(self, x: torch.tensor) -> torch.tensor: >>> print(y.shape) torch.Size([10, 800]) """ - try : - return torch.einsum('ij,...ij->...i', self.get_P(), x) + try: + return torch.einsum("ij,...ij->...i", self.get_P(), x) except RuntimeError as e: if "which does not broadcast with previously seen size" in str(e): raise ValueError( @@ -291,7 +291,7 @@ def forward(self, x: torch.tensor) -> torch.tensor: ) else: raise e - + def forward_H(self, x: torch.tensor) -> torch.tensor: r""" Simulates the measurement of a motion picture using :math:`H`. @@ -300,27 +300,27 @@ def forward_H(self, x: torch.tensor) -> torch.tensor: the measurement matrix and :math:`x` is a batch of vectorized (flattened) images. The positive and negative components of the measurement matrix are **not** used in this method. - + The matrix :math:`H` can contain positive and negative values and is given by the user at initialization. - + .. warning:: There must be **exactly** as many images as there are measurements in the linear operator used to initialize the class, i.e. `H.shape[-2:] == x.shape[-2:]` - + Args: :math:`x`: Batch of vectorized (flatten) images of shape :math:`(*, M, N)` where * denotes the batch size, and :math:`(M, N)` - is the shape of the measurement matrix :math:`H`. - + is the shape of the measurement matrix :math:`H`. + Shape: :math:`x`: :math:`(*, M, N)` - + :math:`H` has a shape of :math:`(M, N)` where :math:`M` is the number of measurements and :math:`N` is the number of pixels in the - image. - + image. + :math:`output`: :math:`(*, M)` Example: @@ -332,9 +332,9 @@ def forward_H(self, x: torch.tensor) -> torch.tensor: torch.Size([10, 400]) """ return super().forward(x) - + def __attributeslist__(self): - return super().__attributeslist__() + [('P', self.P.shape)] + return super().__attributeslist__() + [("P", self.P.shape)] # ============================================================================= @@ -354,20 +354,20 @@ class DynamicHadamSplit(DynamicLinearSplit): selecting a re-ordered subsample of :math:`M` rows of a "full" Hadamard matrix :math:`F` with shape :math:`(N^2, N^2)`. :math:`N` must be a power of 2. - + The matrix :math:`P` is then obtained by splitting the matrix :math:`H` such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. Args: :attr:`M` (int): Number of measurements - + :attr:`h` (int): Image height :math:`h`, must be a power of 2. The image is assumed to be square, so the number of pixels in the image is :math:`N = h^2`. - + :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h, h)` used to - select the rows of the full Hadamard matrix :math:`F` + select the rows of the full Hadamard matrix :math:`F` compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` (see the :mod:`~spyrit.misc.sampling` submodule) @@ -375,28 +375,28 @@ class DynamicHadamSplit(DynamicLinearSplit): :attr:`H` (torch.nn.Parameter): The measurement matrix of shape :math:`(M, h^2)`. It is initialized as a re-ordered subsample of the rows of the "full" Hadamard matrix :math:`F` with shape :math:`(N^2, N^2)`. - + :attr:`H_pinv` (torch.nn.Parameter): The pseudo inverse of the measurement - matrix of shape :math:`(h^2, M)`. It is initialized as + matrix of shape :math:`(h^2, M)`. It is initialized as :math:`H^\dagger = \frac{1}{N}H^{T}` where :math:`N = h^2`. - + :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of shape :math:`(2M, h^2)` initialized as :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` - where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. - + where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. + :attr:`Perm` (torch.nn.Parameter): The permutation matrix :math:`G^{T}` that is used to re-order the subsample of rows of the "full" Hadamard matrix :math:`F` according to descreasing value of the order matrix :math:`Ord`. It has shape :math:`(N, N)` where :math:`N = h^2`. - + :attr:`M` (int): Number of measurements performed by the linear operator. - - :attr:`N` (int): Number of pixels in the image. It is initialized as + + :attr:`N` (int): Number of pixels in the image. It is initialized as :math:`h^2`. - + :attr:`h` (int): Image height :math:`h`. - + :attr:`w` (int): Image width :math:`w`. The image is assumed to be square, i.e. :math:`w = h`. @@ -407,7 +407,7 @@ class DynamicHadamSplit(DynamicLinearSplit): .. note:: The computation of a Hadamard transform :math:`Fx` benefits a fast algorithm, as well as the computation of inverse Hadamard transforms. - + .. note:: The matrix :math:`H` has shape :math:`(M, N)` with :math:`N = h^2`. @@ -425,6 +425,7 @@ class DynamicHadamSplit(DynamicLinearSplit): (Perm): torch.Size([1024, 1024]) ) """ + # ========================================================================= change this ????? ^ def __init__(self, M: int, h: int, Ord: np.ndarray): F = walsh2_matrix(h) # full matrix @@ -436,7 +437,7 @@ def __init__(self, M: int, h: int, Ord: np.ndarray): super().__init__(torch.from_numpy(H)) print("h before", self.h) - Perm = torch.from_numpy(Perm).float() # float32 + Perm = torch.from_numpy(Perm).float() # float32 self.Perm = nn.Parameter(Perm, requires_grad=False) # overwrite self.h and self.w self.h = h @@ -444,7 +445,7 @@ def __init__(self, M: int, h: int, Ord: np.ndarray): print("h after", self.h) def __attributeslist__(self): - return super().__attributeslist__() + [('Perm', self.Perm.shape)] + return super().__attributeslist__() + [("Perm", self.Perm.shape)] # ============================================================================= @@ -452,7 +453,7 @@ class Linear(DynamicLinear): # ========================================================================= r""" Simulates the measurement of an still image using a measurement matrix. - + Computes linear measurements from incoming images: :math:`y = Hx`, where :math:`H` is a given linear operator (matrix) and :math:`x` is a vectorized image or batch of images. @@ -483,14 +484,14 @@ class Linear(DynamicLinear): :attr:`M` (int): Number of measurements performed by the linear operator. It is initialized as the first dimension of :math:`H`. - + :attr:`N` (int): Number of pixels in the image. It is initialized as the second dimension of :math:`H`. - - :attr:`h` (int): Image height :math:`h`. The image is assumed to be + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign :attr:`h` and :attr:`w` manually. - + :attr:`w` (int): Image width :math:`w`. The image is assumed to be square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign :attr:`h` and :attr:`w` manually. @@ -521,19 +522,19 @@ class Linear(DynamicLinear): ) """ - def __init__(self, H: np.ndarray, pinv=None, reg: float=1e-15): + def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): super().__init__(H) if pinv is not None: self.set_H_pinv(reg=reg) - + def get_H_T(self) -> torch.tensor: r""" Returns the transpose of the measurement matrix :math:`H`. - + Shape: Output: :math:`(N, M)`, where :math:`N` is the number of pixels in the image and :math:`M` the number of measurements. - + Example: >>> H1 = np.random.random([400, 1600]) >>> meas_op = Linear(H1) @@ -542,13 +543,13 @@ def get_H_T(self) -> torch.tensor: torch.Size([400, 1600]) """ return self.H.T - + def get_H_pinv(self) -> torch.tensor: r"""Returns the pseudo inverse of the measurement matrix :math:`H`. - + Shape: Output: :math:`(N, M)` - + Example: >>> H1 = np.random.random([400, 1600]) >>> meas_op = Linear(H1, True) @@ -565,32 +566,32 @@ def get_H_pinv(self) -> torch.tensor: ) else: raise e - - def set_H_pinv(self, reg: float=1e-15, pinv: torch.tensor=None) -> None: + + def set_H_pinv(self, reg: float = 1e-15, pinv: torch.tensor = None) -> None: r""" Stores in self.H_pinv the pseudo inverse of the measurement matrix :math:`H`. - - If :attr:`pinv` is given, it is directly stored as the pseudo inverse. + + If :attr:`pinv` is given, it is directly stored as the pseudo inverse. The validity of the pseudo inverse is not checked. If :attr:`pinv` is :obj:`None`, the pseudo inverse is computed from the existing measurement matrix :math:`H` with regularization parameter :attr:`reg`. - - Args: + + Args: :attr:`reg` (float, optional): Cutoff for small singular values. :attr:`H_pinv` (torch.tensor, optional): If given, the tensor is directly stored as the pseudo inverse. No checks are performed. Otherwise, the pseudo inverse is computed from the existing measurement matrix :math:`H`. - + .. note: Only one of :math:`H_pinv` and :math:`reg` should be given. If both are given, :math:`H_pinv` is used and :math:`reg` is ignored. - + Shape: :attr:`H_pinv`: :math:`(N, M)`, where :math:`N` is the number of pixels in the image and :math:`M` the number of measurements. - + Example: >>> H1 = torch.rand([400, 1600]) >>> H2 = torch.linalg.pinv(H1) @@ -598,12 +599,11 @@ def set_H_pinv(self, reg: float=1e-15, pinv: torch.tensor=None) -> None: >>> meas_op.set_H_pinv(H2) """ if pinv is not None: - H_pinv = pinv.type(torch.FloatTensor) # to float32 + H_pinv = pinv.type(torch.FloatTensor) # to float32 else: H_pinv = torch.linalg.pinv(self.get_H(), rcond=reg) self.H_pinv = nn.Parameter(H_pinv, requires_grad=False) - def forward(self, x: torch.tensor) -> torch.tensor: r"""Applies linear transform to incoming images: :math:`y = Hx`. @@ -679,10 +679,10 @@ def pinv(self, x: torch.tensor) -> torch.tensor: return x @ self.get_H_pinv().T def __attributeslist__(self): - return super().__attributeslist__() \ - + [('H_pinv', self.H_pinv.shape if hasattr(self, 'H_pinv') - else None)] - + return super().__attributeslist__() + [ + ("H_pinv", self.H_pinv.shape if hasattr(self, "H_pinv") else None) + ] + # ============================================================================= class LinearSplit(Linear, DynamicLinearSplit): @@ -708,12 +708,12 @@ class LinearSplit(Linear, DynamicLinearSplit): :attr:`H` (torch.tensor): measurement matrix (linear operator) with shape :math:`(M, N)`, where :math:`M` is the number of measurements and :math:`N` the number of pixels in the image. - + :attr:`pinv` (Any): Option to have access to pseudo inverse solutions. If not `None`, the pseudo inverse is initialized as :math:`H^\dagger` and stored in the attribute :attr:`H_pinv`. Defaults to `None` (the pseudo inverse is not initiliazed). - + :attr:`reg` (float, optional): Regularization parameter (cutoff for small singular values, see :mod:`torch.linalg.pinv`). Only relevant when :attr:`pinv` is not `None`. @@ -721,7 +721,7 @@ class LinearSplit(Linear, DynamicLinearSplit): Attributes: :attr:`H` (torch.nn.Parameter): The learnable measurement matrix of shape :math:`(M,N)`. - + :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of shape :math:`(2M, N)` initialized as :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` @@ -729,18 +729,18 @@ class LinearSplit(Linear, DynamicLinearSplit): :attr:`M` (int): Number of measurements performed by the linear operator. It is initialized as the first dimension of :math:`H`. - + :attr:`N` (int): Number of pixels in the image. It is initialized as the second dimension of :math:`H`. - - :attr:`h` (int): Image height :math:`h`. The image is assumed to be + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign :attr:`h` and :attr:`w` manually. - + :attr:`w` (int): Image width :math:`w`. The image is assumed to be square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign :attr:`h` and :attr:`w` manually. - + .. note:: If you know the pseudo inverse of :math:`H` and want to store it, it is best to initialize the class with :attr:`pinv` set to `None` and then @@ -758,7 +758,7 @@ class LinearSplit(Linear, DynamicLinearSplit): ) """ - def __init__(self, H: np.ndarray, pinv=None, reg: float=1e-15): + def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): print("initializing LinearSplit") # initialize from DynamicLinearSplit __init__ super(Linear, self).__init__(H) @@ -843,18 +843,18 @@ class HadamSplit(LinearSplit, DynamicHadamSplit): selecting a re-ordered subsample of :math:`M` rows of a "full" Hadamard matrix :math:`F` with shape :math:`(N^2, N^2)`. :math:`N` must be a power of 2. - + The matrix :math:`P` is then obtained by splitting the matrix :math:`H` such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. Args: :attr:`M` (int): Number of measurements - + :attr:`h` (int): Image height :math:`h`, must be a power of 2. The image is assumed to be square, so the number of pixels in the image is :math:`N = h^2`. - + :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h, h)` used to compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` (see the :mod:`~spyrit.misc.sampling` submodule) @@ -863,35 +863,35 @@ class HadamSplit(LinearSplit, DynamicHadamSplit): :attr:`H` (torch.nn.Parameter): The measurement matrix of shape :math:`(M, h^2)`. It is initialized as a re-ordered subsample of the rows of the "full" Hadamard matrix :math:`F` with shape :math:`(N^2, N^2)`. - + :attr:`H_pinv` (torch.nn.Parameter): The pseudo inverse of the measurement - matrix of shape :math:`(h^2, M)`. It is initialized as + matrix of shape :math:`(h^2, M)`. It is initialized as :math:`H^\dagger = \frac{1}{N}H^{T}` where :math:`N = h^2`. - + :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of shape :math:`(2M, h^2)` initialized as :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` - where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. - + where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. + :attr:`Perm` (torch.nn.Parameter): The permutation matrix :math:`G^{T}` that is used to re-order the subsample of rows of the "full" Hadamard matrix :math:`F` according to descreasing value of the order matrix :math:`Ord`. It has shape :math:`(N, N)` where :math:`N = h^2`. - + :attr:`M` (int): Number of measurements performed by the linear operator. - - :attr:`N` (int): Number of pixels in the image. It is initialized as + + :attr:`N` (int): Number of pixels in the image. It is initialized as :math:`h^2`. - + :attr:`h` (int): Image height :math:`h`. - + :attr:`w` (int): Image width :math:`w`. The image is assumed to be square, i.e. :math:`w = h`. .. note:: The computation of a Hadamard transform :math:`Fx` benefits a fast algorithm, as well as the computation of inverse Hadamard transforms. - + .. note:: The matrix H has shape :math:`(M,N)` with :math:`N = h^2`. @@ -916,7 +916,7 @@ def __init__(self, M: int, h: int, Ord: np.ndarray): print("initializing HadamSplit") # initialize from DynamicHadamSplit (the MRO is not trivial here) super(Linear, self).__init__(M, h, Ord) - self.set_H_pinv(pinv = 1 / self.N * self.get_H_T()) + self.set_H_pinv(pinv=1 / self.N * self.get_H_T()) def inverse(self, x: torch.tensor) -> torch.tensor: r"""Inverse transform of Hadamard-domain images @@ -949,4 +949,4 @@ def inverse(self, x: torch.tensor) -> torch.tensor: # inverse of full transform # todo: initialize with 1D transform to speed up x = 1 / self.N * walsh2_torch(x) - return x.view(b, N) \ No newline at end of file + return x.view(b, N) From 415f24aae0e7e42cc4143e6099c7969b5e3d8306 Mon Sep 17 00:00:00 2001 From: romainphan Date: Fri, 23 Feb 2024 17:22:07 +0100 Subject: [PATCH 09/11] recon.py > reconstruct clarified documentation and called reconstruct_pinv --- spyrit/core/recon.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/spyrit/core/recon.py b/spyrit/core/recon.py index 9aa86da3..dd680601 100644 --- a/spyrit/core/recon.py +++ b/spyrit/core/recon.py @@ -443,7 +443,7 @@ def meas2img(self, y): return z.view(-1, 1, self.acqu.meas_op.h, self.acqu.meas_op.w) def reconstruct(self, x): - r"""Reconstruction step of a reconstruction network + r"""Preprocesses, reconstructs, and denoises raw measurement vectors. Args: :attr:`x`: raw measurement vectors @@ -465,23 +465,11 @@ def reconstruct(self, x): >>> print(z.shape) torch.Size([10, 1, 64, 64]) """ - # Measurement to image domain mapping - bc, _ = x.shape - - # Preprocessing in the measurement domain - x = self.prep(x) # shape x = [b*c, M] - - # measurements to image-domain processing - x = self.pinv(x, self.acqu.meas_op) # shape x = [b*c,N] - - # Image-domain denoising - x = x.view( - bc, 1, self.acqu.meas_op.h, self.acqu.meas_op.w - ) # shape x = [b*c,1,h,w] - return self.denoi(x) + # Denoise image-domain + return self.denoi(self.reconstruct_pinv(x)) def reconstruct_pinv(self, x): - r"""Reconstruction step of a reconstruction network + r"""Preprocesses and reconstructs raw measurement vectors. Args: :attr:`x`: raw measurement vectors From 422f3003aebe47947f59e604cc3722f9a9506bd8 Mon Sep 17 00:00:00 2001 From: romainphan Date: Fri, 23 Feb 2024 17:22:39 +0100 Subject: [PATCH 10/11] tiny readability improvement --- spyrit/misc/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spyrit/misc/sampling.py b/spyrit/misc/sampling.py index 54d919d7..78adf66a 100644 --- a/spyrit/misc/sampling.py +++ b/spyrit/misc/sampling.py @@ -96,7 +96,7 @@ def Permutation_Matrix(Mat: np.ndarray) -> np.ndarray: N-by-N sampling matrix, where high values indicate high significance. Returns: - P (np.ndarray): N*N-by-N*N permutation matrix (boolean) + P (np.ndarray): N^2-by-N^2 permutation matrix (boolean) """ (nx, ny) = Mat.shape Reorder = rankdata(-Mat, method="ordinal") From b0bd0ba965924a1b056338ffcddc59cdd8e3bd25 Mon Sep 17 00:00:00 2001 From: romainphan Date: Fri, 23 Feb 2024 17:23:36 +0100 Subject: [PATCH 11/11] minor changes to walsh2_matrix --- spyrit/misc/walsh_hadamard.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/spyrit/misc/walsh_hadamard.py b/spyrit/misc/walsh_hadamard.py index 84ab5662..5da3a587 100644 --- a/spyrit/misc/walsh_hadamard.py +++ b/spyrit/misc/walsh_hadamard.py @@ -646,10 +646,8 @@ def walsh2_matrix(n): Returns: H (np.ndarray): A n*n-by-n*n matrix """ - H = np.zeros((n**2, n**2)) H1d = walsh_matrix(n) - H = np.kron(H1d, H1d) - return H + return np.kron(H1d, H1d) def walsh2(X, H=None): @@ -1278,11 +1276,12 @@ def walsh2_torch(im, H=None): """Return 2D Walsh-ordered Hadamard transform of an image Args: - im (torch.Tensor): Image, typically a B-by-C-by-W-by-H Tensor - H (torch.Tensor, optional): 1D Walsh-ordered Hadamard transformation matrix. A 2-D tensor of size W-by-H. + im (torch.tensor): Image, typically a B-by-C-by-W-by-H Tensor + H (torch.tensor, optional): 1D Walsh-ordered Hadamard transformation + matrix. A 2-D tensor of size W-by-H. Returns: - torch.Tensor: Hadamard transformed image. Same size as im + torch.tensor: Hadamard transformed image. Same size as im Examples: >>> im = torch.randn(256, 1, 64, 64)