diff --git a/docs/source/conf.py b/docs/source/conf.py index 7534d29b..86c99181 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,10 +19,8 @@ # -- Project information ----------------------------------------------------- project = "spyrit" -copyright = "2021, Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier" -author = ( - "Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier" -) +copyright = "2021, Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" +author = "Antonio Tomas Lorente Mur - Nicolas Ducros - Sebastien Crombez - Thomas Baudier - Romain Phan" # The full version, including alpha/beta/rc tags release = "2.1.0" @@ -116,16 +114,16 @@ # autodoc_mock_imports = "numpy matplotlib mpl_toolkits scipy torch torchvision Pillow opencv-python imutils PyWavelets pywt wget imageio".split() -# exclude all torch.nn.Module members from the documentation -# except forward and __init__ methods +# exclude all torch.nn.Module members (except forward method) from the docs: import torch def skip_member_handler(app, what, name, obj, skip, options): - if name in [ - "forward", - ]: - return False + always_document = [ # complete this list if needed by adding methods + "forward", # you *always* want to see documented + ] + if name in always_document: + return None if name in dir(torch.nn.Module): return True return None diff --git a/docs/source/index.rst b/docs/source/index.rst index e771d8c2..0949b529 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -135,6 +135,6 @@ When using SPyRiT specifically for the denoised completion network, please cite Join the project ================================== -Feel free to contact us by `e-mail ` for any question. Active developers are currently `Nicolas Ducros `_, Thomas Baudier and `Juan Abascal `_. Direct contributions via pull requests (PRs) are welcome. +Feel free to contact us by `e-mail ` for any question. Active developers are currently `Nicolas Ducros `_, Thomas Baudier, `Juan Abascal `_ and Romain Phan. Direct contributions via pull requests (PRs) are welcome. The full list of contributors can be found `here `_. diff --git a/spyrit/core/meas.py b/spyrit/core/meas.py index 54cb49c6..ec87cb18 100644 --- a/spyrit/core/meas.py +++ b/spyrit/core/meas.py @@ -1,190 +1,169 @@ import warnings + import torch import torch.nn as nn import numpy as np -from typing import Union + from spyrit.misc.walsh_hadamard import walsh2_torch, walsh2_matrix from spyrit.misc.sampling import Permutation_Matrix -# ================================================================================== -class Linear(nn.Module): - # ================================================================================== +# ============================================================================= +class DynamicLinear(nn.Module): + # ========================================================================= r""" - Computes linear measurements from incoming images: :math:`y = Hx`, + Simulates the measurement of a moving object using a measurement matrix. + + Computes linear measurements :math:`y` from incoming images: :math:`y = Hx`, where :math:`H` is a linear operator (matrix) and :math:`x` is a - vectorized image. + batch of vectorized images representing a motion picture. 
- The class is constructed from a :math:`M` by :math:`N` matrix :math:`H`, + The class is constructed from a matrix :math:`H` of shape :math:`(M, N)`, where :math:`N` represents the number of pixels in the image and - :math:`M` the number of measurements. + :math:`M` the number of measurements and the number of frames in the + animated object. - Args: - :attr:`H`: measurement matrix (linear operator) with shape :math:`(M, N)`. + .. warning:: + For each call, there must be **exactly** as many images in :math:`x` as + there are measurements in the linear operator used to initialize the class. - :attr:`pinv`: Option to have access to pseudo inverse solutions. - Defaults to `None` (the pseudo inverse is not initiliazed). + Args: + :attr:`H` (torch.tensor): measurement matrix (linear operator) with + shape :math:`(M, N)`. - :attr:`reg` (optional): Regularization parameter (cutoff for small - singular values, see :mod:`numpy.linal.pinv`). Only relevant when - :attr:`pinv` is not `None`. + Attributes: + :attr:`H` (torch.nn.Parameter): The learnable measurement matrix of + shape :math:`(M,N)` initialized as :math:`H`. + :attr:`M` (int): Number of measurements performed by the linear operator. + It is initialized as the first dimension of :math:`H`. - Attributes: - :attr:`H`: The learnable measurement matrix of shape - :math:`(M,N)` initialized as :math:`H` + :attr:`N` (int): Number of pixels in the image. It is initialized as the + second dimension of :math:`H`. - :attr:`H_adjoint`: The learnable adjoint measurement matrix - of shape :math:`(N,M)` initialized as :math:`H^\top` + :attr:`h` (int): Image height :math:`h`. The image is assumed to be + square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. - :attr:`H_pinv` (optional): The learnable adjoint measurement - matrix of shape :math:`(N,M)` initialized as :math:`H^\dagger`. - Only relevant when :attr:`pinv` is not `None`. + :attr:`w` (int): Image width :math:`w`. The image is assumed to be + square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign + :attr:`h` and :attr:`w` manually. Example: - >>> H = np.random.random([400, 1000]) - >>> meas_op = Linear(H) + >>> H = np.random.random([400, 1600]) + >>> meas_op = DynamicLinear(H) >>> print(meas_op) - Linear( - (H): Linear(in_features=1000, out_features=400, bias=False) - (H_adjoint): Linear(in_features=400, out_features=1000, bias=False) - ) - - Example 2: - >>> H = np.random.random([400, 1000]) - >>> meas_op = Linear(H, True) - >>> print(meas_op) - Linear( - (H): Linear(in_features=1000, out_features=400, bias=False) - (H_adjoint): Linear(in_features=400, out_features=1000, bias=False) - (H_pinv): Linear(in_features=400, out_features=1000, bias=False) - ) + DynamicLinear( + (Image pixels): 1600 + (H): torch.Size([400, 1600]) + ) """ - def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15): + def __init__(self, H: torch.tensor): super().__init__() - # instancier nn.linear + + # convert H from numpy to torch tensor if needed + # convert to float 32 for memory efficiency + if isinstance(H, np.ndarray): + H = torch.from_numpy(H) + warnings.warn( + "Using a numpy array is deprecated. 
Please use a torch tensor instead.", + DeprecationWarning + ) + H = H.type(torch.float32) + # nn.Parameter are sent to the device when using .to(device), + self.H = nn.Parameter(H, requires_grad=False) + self.M = H.shape[0] self.N = H.shape[1] - self.h = int(self.N**0.5) self.w = int(self.N**0.5) if self.h * self.w != self.N: warnings.warn( - "N is not a square. Please assign self.h and self.w manually." + f"N ({H.shape[1]}) is not a square. Please assign self.h and self.w manually." ) - self.H = nn.Linear(self.N, self.M, False) - self.H.weight.data = torch.from_numpy(H).float() - # Data must be of type float (or double) rather than the default float64 when creating torch tensor - self.H.weight.requires_grad = False - - # adjoint (Remove?) - self.H_adjoint = nn.Linear(self.M, self.N, False) - self.H_adjoint.weight.data = torch.from_numpy(H.transpose()).float() - self.H_adjoint.weight.requires_grad = False - - if pinv is None: - H_pinv = pinv - print("Pseudo inverse will not be instanciated") - - else: - H_pinv = np.linalg.pinv(H, rcond=reg) - self.H_pinv = nn.Linear(self.M, self.N, False) - self.H_pinv.weight.data = torch.from_numpy(H_pinv).float() - self.H_pinv.weight.requires_grad = False - - def forward(self, x: torch.tensor) -> torch.tensor: - r"""Applies linear transform to incoming images: :math:`y = Hx`. - - Args: - :math:`x`: Batch of vectorized (flatten) images. - - Shape: - :math:`x`: :math:`(*, N)` where * denotes the batch size and `N` - the total number of pixels in the image. - - Output: :math:`(*, M)` where * denotes the batch size and `M` - the number of measurements. - - Example: - >>> x = torch.rand([10,1000], dtype=torch.float) - >>> y = meas_op(x) - >>> print('forward:', y.shape) - forward: torch.Size([10, 400]) - - """ - # x.shape[b*c,N] - x = self.H(x) - return x - - def adjoint(self, x: torch.tensor) -> torch.tensor: - r"""Applies adjoint transform to incoming measurements :math:`y = H^{T}x` - - Args: - :math:`x`: batch of measurement vectors. - - Shape: - :math:`x`: :math:`(*, M)` - - Output: :math:`(*, N)` - - Example: - >>> x = torch.rand([10,400], dtype=torch.float) - >>> y = meas_op.adjoint(x) - >>> print('adjoint:', y.shape) - adjoint: torch.Size([10, 1000]) - """ - # Pmat.transpose()*f - x = self.H_adjoint(x) - return x - def get_H(self) -> torch.tensor: - r"""Returns the measurement matrix :math:`H`. + r"""Returns the attribute measurement matrix :math:`H`. Shape: Output: :math:`(M, N)` Example: - >>> H = meas_op.get_H() - >>> print('get_mat:', H.shape) - get_mat: torch.Size([400, 1000]) - + >>> H1 = np.random.random([400, 1600]) + >>> meas_op = Linear(H1) + >>> H2 = meas_op.get_H() + >>> print(H2.shape) + torch.Size([400, 1600]) """ - return self.H.weight.data + return self.H.data - def pinv(self, x: torch.tensor) -> torch.tensor: - r"""Computer pseudo inverse solution :math:`y = H^\dagger x` + def forward(self, x: torch.tensor) -> torch.tensor: + r""" + Simulates the measurement of a motion picture. + + The output :math:`y` is computed as :math:`y = Hx`, where :math:`H` is + the measurement matrix and :math:`x` is a batch of vectorized (flattened) + images. + + .. warning:: + There must be **exactly** as many images as there are measurements + in the linear operator used to initialize the class, i.e. + `H.shape[-2] == x.shape[-2]` Args: - :math:`x`: batch of measurement vectors. + :math:`x`: Batch of vectorized (flattened) images. 
Shape: - :math:`x`: :math:`(*, M)` + :math:`x`: :math:`(*, M, N)`, where * denotes the batch size and + :math:`(M, N)` is the shape of the measurement matrix :math:`H`. + :math:`M` is the number of measurements (and frames) and :math:`N` + the number of pixels in the image. - Output: :math:`(*, N)` + :math:`output`: :math:`(*, M)` Example: - >>> x = torch.rand([10,400], dtype=torch.float) - >>> y = meas_op.pinv(x) - >>> print('pinv:', y.shape) - adjoint: torch.Size([10, 1000]) + >>> x = torch.rand([10, 400, 1600]) + >>> H = np.random.random([400, 1600]) + >>> meas_op = DynamicLinear(H) + >>> y = meas_op(x) + >>> print(y.shape) + torch.Size([10, 400]) """ - # Pmat.transpose()*f - x = self.H_pinv(x) - return x - - -# ================================================================================== -class LinearSplit(Linear): - # ================================================================================== + try: + return torch.einsum("ij,...ij->...i", self.get_H(), x) + except RuntimeError as e: + if "which does not broadcast with previously seen size" in str(e): + raise ValueError( + f"The shape of the input x ({x.shape}) does not match the " + + f"shape of the measurement matrix H ({self.get_H().shape})." + ) + else: + raise e + + def __str__(self): + s_begin = f"{self.__class__.__name__}(\n " + s_fill = "\n ".join([f"({k}): {v}" for k, v in self.__attributeslist__()]) + s_end = "\n )" + return s_begin + s_fill + s_end + + def __attributeslist__(self): + return [("Image pixels", self.N), ("H", self.H.shape)] + + +# ============================================================================= +class DynamicLinearSplit(DynamicLinear): + # ========================================================================= r""" - Computes linear measurements from incoming images: :math:`y = Px`, - where :math:`P` is a linear operator (matrix) and :math:`x` is a - vectorized image. + Simulates the measurement of a moving object using the positive and + negative components of the measurement matrix. + + Computes linear measurements :math:`y` from incoming images: :math:`y = Px`, + where :math:`P` is a linear operator (matrix) and :math:`x` is a batch of + vectorized images representing a motion picture. The matrix :math:`P` contains only positive values and is obtained by - splitting a measurement matrix :math:`H` such that + splitting a given measurement matrix :math:`H` such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. @@ -194,325 +173,795 @@ class LinearSplit(Linear): Args: :math:`H` (np.ndarray): measurement matrix (linear operator) with - shape :math:`(M, N)`. + shape :math:`(M, N)` where :math:`M` is the number of measurements and + :math:`N` the number of pixels in the image. + + Attributes: + :attr:`H` (torch.nn.Parameter): The learnable measurement matrix of + shape :math:`(M,N)`. + + :attr:`P` (torch.nn.Parameter): The splitted measurement matrix of + shape :math:`(2M, N)` initialized as + :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` + where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)` + + :attr:`M` (int): Number of measurements performed by the linear operator. + It is initialized as the first dimension of :math:`H`. + + :attr:`N` (int): Number of pixels in the image. It is initialized as the + second dimension of :math:`H`. + + :attr:`h` (int): Image height :math:`h`. The image is assumed to be + square, i.e. :math:`h = \text{floor}(\sqrt{N})`. 
If not, please assign
+        :attr:`h` and :attr:`w` manually.
+
+        :attr:`w` (int): Image width :math:`w`. The image is assumed to be
+        square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign
+        :attr:`h` and :attr:`w` manually.
+
+    .. warning::
+        For each call, there must be **exactly** as many images in :math:`x` as
+        there are measurements in the linear operator used to initialize the class.
 
     Example:
-        >>> H = np.array(np.random.random([400,1000]))
-        >>> meas_op = LinearSplit(H)
+        >>> H = np.array(np.random.random([400,1600]))
+        >>> meas_op = DynamicLinearSplit(H)
+        >>> print(meas_op)
+        DynamicLinearSplit(
+          (Image pixels): 1600
+          (H): torch.Size([400, 1600])
+          (P): torch.Size([800, 1600])
+          )
     """
 
-    def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15):
-        super().__init__(H, pinv, reg)
+    def __init__(self, H: torch.tensor):
+        # initialize self.H
+        if isinstance(H, np.ndarray):
+            H = torch.from_numpy(H)
+            warnings.warn(
+                "Using a numpy array is deprecated. Please use a torch tensor instead.",
+                DeprecationWarning
+            )
+        H = H.type(torch.float32)
 
-        # [H^+, H^-]
+        super().__init__(H)
 
-        even_index = range(0, 2 * self.M, 2)
-        odd_index = range(1, 2 * self.M, 2)
+        # initialize self.P = [ H^+ ]
+        #                     [ H^- ]
+        zero = torch.zeros(1)
+        H_pos = torch.maximum(zero, H)
+        H_neg = torch.maximum(zero, -H)
+        # concatenate side by side, then reshape vertically
+        P = torch.cat([H_pos, H_neg], 1).view(2 * self.M, self.N)
+        P = P.type(torch.FloatTensor)  # cast to float 32
+        self.P = nn.Parameter(P, requires_grad=False)
 
-        H_pos = np.zeros(H.shape)
-        H_neg = np.zeros(H.shape)
-        H_pos[H > 0] = H[H > 0]
-        H_neg[H < 0] = -H[H < 0]
+    def get_P(self) -> torch.tensor:
+        r"""Returns the attribute measurement matrix :math:`P`.
 
-        # pourquoi 2 *M ?
-        P = np.zeros((2 * self.M, self.N))
-        P[even_index, :] = H_pos
-        P[odd_index, :] = H_neg
+        Shape:
+            Output: :math:`(2M, N)`, where :math:`(M, N)` is the shape of the
+            measurement matrix :math:`H` given at initialization.
 
-        self.P = nn.Linear(self.N, 2 * self.M, False)
-        self.P.weight.data = torch.from_numpy(P)
-        self.P.weight.data = self.P.weight.data.float()
-        self.P.weight.requires_grad = False
+        Example:
+            >>> H = np.random.random([400, 1600])
+            >>> meas_op = DynamicLinearSplit(H)
+            >>> P = meas_op.get_P()
+            >>> print(P.shape)
+            torch.Size([800, 1600])
+        """
+        return self.P.data
 
     def forward(self, x: torch.tensor) -> torch.tensor:
-        r"""Applies linear transform to incoming images: :math:`y = Px`.
+        r"""
+        Simulates the measurement of a motion picture using :math:`P`.
+
+        The output :math:`y` is computed as :math:`y = Px`, where :math:`P` is
+        the measurement matrix and :math:`x` is a batch of vectorized (flattened)
+        images.
+
+        :math:`P` contains only positive values and is obtained by
+        splitting a given measurement matrix :math:`H` such that
+        :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where
+        :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`.
+
+        The matrix :math:`H` can contain positive and negative values and is
+        given by the user at initialization.
+
+        .. warning::
+            There must be **exactly** as many images as there are measurements
+            in the linear operator used to initialize the class, i.e.
+            `P.shape[-2] == x.shape[-2]`
 
         Args:
-            :math:`x`: Batch of vectorized (flatten) images.
+            :math:`x`: Batch of vectorized (flattened) images of shape
+            :math:`(*, 2M, N)` where * denotes the batch size, :math:`2M` the
+            number of measurements in the measurement matrix :math:`P` and
+            :math:`N` the number of pixels in the image.
 
         Shape:
-            :math:`x`: :math:`(*, N)` where * denotes the batch size and `N`
-            the total number of pixels in the image.
+            :math:`x`: :math:`(*, 2M, N)`
 
-            Output: :math:`(*, 2M)` where * denotes the batch size and `M`
-            the number of measurements.
+            :math:`P` has a shape of :math:`(2M, N)` where :math:`M` is the
+            number of measurements as defined by the first dimension of :math:`H`
+            and :math:`N` is the number of pixels in the image.
+
+            :math:`output`: :math:`(*, 2M)`
 
         Example:
-            >>> x = torch.rand([10,1000], dtype=torch.float)
+            >>> x = torch.rand([10, 800, 1600])
+            >>> H = np.random.random([400, 1600])
+            >>> meas_op = DynamicLinearSplit(H)
            >>> y = meas_op(x)
-            >>> print('Output:', y.shape)
-            Output: torch.Size([10, 800])
+            >>> print(y.shape)
+            torch.Size([10, 800])
        """
-        # x.shape[b*c,N]
-        # output shape : [b*c, 2*M]
-        x = self.P(x)
-        return x
+        try:
+            return torch.einsum("ij,...ij->...i", self.get_P(), x)
+        except RuntimeError as e:
+            if "which does not broadcast with previously seen size" in str(e):
+                raise ValueError(
+                    f"The shape of the input x ({x.shape}) does not match the "
+                    + f"shape of the measurement matrix P ({self.get_P().shape})."
+                )
+            else:
+                raise e
 
     def forward_H(self, x: torch.tensor) -> torch.tensor:
-        r"""Applies linear transform to incoming images: :math:`m = Hx`.
+        r"""
+        Simulates the measurement of a motion picture using :math:`H`.
+
+        The output :math:`y` is computed as :math:`y = Hx`, where :math:`H` is
+        the measurement matrix and :math:`x` is a batch of vectorized (flattened)
+        images. The positive and negative components of the measurement matrix
+        are **not** used in this method.
+
+        The matrix :math:`H` can contain positive and negative values and is
+        given by the user at initialization.
+
+        .. warning::
+            There must be **exactly** as many images as there are measurements
+            in the linear operator used to initialize the class, i.e.
+            `H.shape[-2:] == x.shape[-2:]`
 
         Args:
-            :math:`x`: Batch of vectorized (flatten) images.
+            :math:`x`: Batch of vectorized (flattened) images of shape
+            :math:`(*, M, N)` where * denotes the batch size, and :math:`(M, N)`
+            is the shape of the measurement matrix :math:`H`.
 
         Shape:
-            :math:`x`: :math:`(*, N)` where * denotes the batch size and `N`
-            the total number of pixels in the image.
+            :math:`x`: :math:`(*, M, N)`
 
-            Output: :math:`(*, M)` where * denotes the batch size and `M`
-            the number of measurements.
+            :math:`H` has a shape of :math:`(M, N)` where :math:`M` is the
+            number of measurements and :math:`N` is the number of pixels in the
+            image.
+
+            :math:`output`: :math:`(*, M)`
 
         Example:
-            >>> x = torch.rand([10,1000], dtype=torch.float)
+            >>> x = torch.rand([10, 400, 1600])
+            >>> H = np.random.random([400, 1600])
+            >>> meas_op = DynamicLinearSplit(H)
            >>> y = meas_op.forward_H(x)
-            >>> print('Output:', y.shape)
-            output shape: torch.Size([10, 400])
+            >>> print(y.shape)
+            torch.Size([10, 400])
        """
-        x = self.H(x)
-        return x
+        return super().forward(x)
+
+    def __attributeslist__(self):
+        return super().__attributeslist__() + [("P", self.P.shape)]
 
 
-# ==================================================================================
-class HadamSplit(LinearSplit):
+# =============================================================================
+class DynamicHadamSplit(DynamicLinearSplit):
+    # =========================================================================
     r"""
+    Simulates the measurement of a moving object using the positive and
+    negative components of a Hadamard matrix.
+
     Computes linear measurements from incoming images: :math:`y = Px`,
     where :math:`P` is a linear operator (matrix) with positive entries and
-    :math:`x` is a vectorized image.
+    :math:`x` is a batch of vectorized images representing a motion picture.
 
-    The class is relies on a matrix :math:`H` with
-    shape :math:`(M,N)` where :math:`N` represents the number of pixels in the
-    image and :math:`M \le N` the number of measurements. The matrix :math:`P`
-    is obtained by splitting the matrix :math:`H` such that
-    :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where
+    The class relies on a Hadamard-based matrix :math:`H` with shape :math:`(M,N)`
+    where :math:`N` represents the number of pixels in the image and
+    :math:`M \le N` the number of measurements. :math:`H` is obtained by
+    selecting a re-ordered subsample of :math:`M` rows of a "full" Hadamard
+    matrix :math:`F` with shape :math:`(N, N)`. :math:`h` must be a power
+    of 2.
+
+    The matrix :math:`P` is then obtained by splitting the matrix :math:`H`
+    such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where
     :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`.
 
-    The matrix :math:`H` is obtained by retaining the first :math:`M` rows of
-    a permuted Hadamard matrix :math:`GF`, where :math:`G` is a
-    permutation matrix with shape with shape :math:`(M,N)` and :math:`F` is a
-    "full" Hadamard matrix with shape :math:`(N,N)`. The computation of a
-    Hadamard transform :math:`Fx` benefits a fast algorithm, as well as the
-    computation of inverse Hadamard transforms.
+    Args:
+        :attr:`M` (int): Number of measurements
+
+        :attr:`h` (int): Image height :math:`h`, must be a power of 2. The
+        image is assumed to be square, so the number of pixels in the image is
+        :math:`N = h^2`.
+
+        :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h, h)` used to
+        select the rows of the full Hadamard matrix :math:`F` and to
+        compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)`
+        (see the :mod:`~spyrit.misc.sampling` submodule)
+
+    Attributes:
+        :attr:`H` (torch.nn.Parameter): The measurement matrix of shape
+        :math:`(M, h^2)`. It is initialized as a re-ordered subsample of the
+        rows of the "full" Hadamard matrix :math:`F` with shape :math:`(N, N)`.
+
+        :attr:`H_pinv` (torch.nn.Parameter): The pseudo inverse of the measurement
+        matrix of shape :math:`(h^2, M)`. It is initialized as
+        :math:`H^\dagger = \frac{1}{N}H^{T}` where :math:`N = h^2`.
+
+        :attr:`P` (torch.nn.Parameter): The split measurement matrix of
+        shape :math:`(2M, h^2)` initialized as
+        :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`
+        where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`.
+
+        :attr:`Perm` (torch.nn.Parameter): The permutation matrix :math:`G^{T}`
+        that is used to re-order the subsample of rows of the "full" Hadamard
+        matrix :math:`F` according to decreasing value of the order matrix
+        :math:`Ord`. It has shape :math:`(N, N)` where :math:`N = h^2`.
+
+        :attr:`M` (int): Number of measurements performed by the linear operator.
+
+        :attr:`N` (int): Number of pixels in the image. It is initialized as
+        :math:`h^2`.
+
+        :attr:`h` (int): Image height :math:`h`.
+
+        :attr:`w` (int): Image width :math:`w`. The image is assumed to be
+        square, i.e. :math:`w = h`.
+
+    .. warning::
+        For each call, there must be **exactly** as many images in :math:`x` as
+        there are measurements in the linear operator used to initialize the class.
 
     .. note::
-        :math:`H = H_{+} - H_{-}`
+        The computation of a Hadamard transform :math:`Fx` benefits from a fast
+        algorithm, as well as the computation of inverse Hadamard transforms.
 
-    Args:
-        - :attr:`M`: Number of measurements
-        - :attr:`h`: Image height :math:`h`. The image is assumed to be square.
-        - :attr:`Ord`: Order matrix with shape :math:`(h,h)` used to compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)` (see the :mod:`~spyrit.misc.sampling` submodule)
+    .. note::
+        The matrix :math:`H` has shape :math:`(M, N)` with :math:`N = h^2`.
 
     .. note::
-        The matrix H has shape :math:`(M,N)` with :math:`N = h^2`.
+        :math:`H = H_{+} - H_{-}`
 
     Example:
         >>> Ord = np.random.random([32,32])
-        >>> meas_op = HadamSplit(400, 32, Ord)
+        >>> meas_op = DynamicHadamSplit(400, 32, Ord)
+        >>> print(meas_op)
+        DynamicHadamSplit(
+          (Image pixels): 1024
+          (H): torch.Size([400, 1024])
+          (P): torch.Size([800, 1024])
+          (Perm): torch.Size([1024, 1024])
+          )
     """
 
     def __init__(self, M: int, h: int, Ord: np.ndarray):
         F = walsh2_matrix(h)  # full matrix
         Perm = Permutation_Matrix(Ord)
-        F = Perm @ F  # If Perm is not learnt, could be computed mush faster
+        F = Perm @ F  # If Perm is not learnt, could be computed much faster
         H = F[:M, :]
         w = h  # we assume a square image
 
-        super().__init__(H)
+        super().__init__(torch.from_numpy(H))
 
-        self.Perm = nn.Linear(self.N, self.N, False)
-        self.Perm.weight.data = torch.from_numpy(Perm.T)
-        self.Perm.weight.data = self.Perm.weight.data.float()
-        self.Perm.weight.requires_grad = False
+        Perm = torch.from_numpy(Perm).float()  # float32
+        self.Perm = nn.Parameter(Perm, requires_grad=False)
+        # overwrite self.h and self.w
         self.h = h
         self.w = w
 
-    def inverse(self, x: torch.tensor) -> torch.tensor:
-        r"""Inverse transform of Hadamard-domain images
-        :math:`x = H_{had}^{-1}G y` is a Hadamard matrix.
+    def __attributeslist__(self):
+        return super().__attributeslist__() + [("Perm", self.Perm.shape)]
+
+
+# =============================================================================
+class Linear(DynamicLinear):
+    # =========================================================================
+    r"""
+    Simulates the measurement of a still image using a measurement matrix.
+
+    Computes linear measurements from incoming images: :math:`y = Hx`,
+    where :math:`H` is a given linear operator (matrix) and :math:`x` is a
+    vectorized image or batch of images.
+
+    The class is constructed from a :math:`M` by :math:`N` matrix :math:`H`,
+    where :math:`N` represents the number of pixels in the image and
+    :math:`M` the number of measurements.
+
+    Args:
+        :attr:`H` (:type:`torch.tensor`): measurement matrix (linear operator) with shape :math:`(M, N)`.
+
+        :attr:`pinv` (Any): Option to have access to pseudo inverse solutions. If not
+        `None`, the pseudo inverse is initialized as :math:`H^\dagger` and
+        stored in the attribute :attr:`H_pinv`. Defaults to `None` (the pseudo
+        inverse is not initialized).
+
+        :attr:`reg` (float, optional): Regularization parameter (cutoff for small
+        singular values, see :mod:`torch.linalg.pinv`). Only relevant when
+        :attr:`pinv` is not `None`.
+
+    Attributes:
+        :attr:`H` (torch.tensor): The learnable measurement matrix of shape
+        :math:`(M, N)` initialized as :math:`H`
+
+        :attr:`H_pinv` (torch.tensor, optional): The pseudo inverse of the
+        measurement matrix, of shape :math:`(N, M)`, initialized as :math:`H^\dagger`.
+        Only relevant when :attr:`pinv` is not `None`.
+
+        :attr:`M` (int): Number of measurements performed by the linear operator.
+        It is initialized as the first dimension of :math:`H`.
+
+        :attr:`N` (int): Number of pixels in the image. It is initialized as the
+        second dimension of :math:`H`.
+
+        :attr:`h` (int): Image height :math:`h`. The image is assumed to be
+        square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign
+        :attr:`h` and :attr:`w` manually.
+
+        :attr:`w` (int): Image width :math:`w`. The image is assumed to be
+        square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign
+        :attr:`h` and :attr:`w` manually.
+
+    .. note::
+        If you know the pseudo inverse of :math:`H` and want to store it, it is
+        best to initialize the class with :attr:`pinv` set to `None` and then
+        call :meth:`set_H_pinv` to store the pseudo inverse.
+
+    Example 1:
+        >>> H = np.random.random([400, 1600])
+        >>> meas_op = Linear(H, pinv=None)
+        >>> print(meas_op)
+        Linear(
+          (Image pixels): 1600
+          (H): torch.Size([400, 1600])
+          (H_pinv): None
+          )
+
+    Example 2:
+        >>> H = np.random.random([400, 1600])
+        >>> meas_op = Linear(H, True)
+        >>> print(meas_op)
+        Linear(
+          (Image pixels): 1600
+          (H): torch.Size([400, 1600])
+          (H_pinv): torch.Size([1600, 400])
+          )
+    """
+
+    def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15):
+        super().__init__(H)
+        if pinv is not None:
+            self.set_H_pinv(reg=reg)
+
+    def get_H_T(self) -> torch.tensor:
+        r"""
+        Returns the transpose of the measurement matrix :math:`H`.
+
+        Shape:
+            Output: :math:`(N, M)`, where :math:`N` is the number of pixels in
+            the image and :math:`M` the number of measurements.
+
+        Example:
+            >>> H1 = np.random.random([400, 1600])
+            >>> meas_op = Linear(H1)
+            >>> H2 = meas_op.get_H_T()
+            >>> print(H2.shape)
+            torch.Size([1600, 400])
+        """
+        return self.H.T
+
+    def get_H_pinv(self) -> torch.tensor:
+        r"""Returns the pseudo inverse of the measurement matrix :math:`H`.
+
+        Shape:
+            Output: :math:`(N, M)`
+
+        Example:
+            >>> H1 = np.random.random([400, 1600])
+            >>> meas_op = Linear(H1, True)
+            >>> H2 = meas_op.get_H_pinv()
+            >>> print(H2.shape)
+            torch.Size([1600, 400])
+        """
+        try:
+            return self.H_pinv.data
+        except AttributeError as e:
+            if "has no attribute 'H_pinv'" in str(e):
+                raise AttributeError(
+                    "The pseudo inverse has not been initialized. Please set it using self.set_H_pinv()."
+                )
+            else:
+                raise e
+
+    def set_H_pinv(self, reg: float = 1e-15, pinv: torch.tensor = None) -> None:
+        r"""
+        Stores in self.H_pinv the pseudo inverse of the measurement matrix :math:`H`.
+
+        If :attr:`pinv` is given, it is directly stored as the pseudo inverse.
+        The validity of the pseudo inverse is not checked. If :attr:`pinv` is
+        :obj:`None`, the pseudo inverse is computed from the existing
+        measurement matrix :math:`H` with regularization parameter :attr:`reg`.
 
         Args:
-            :math:`x`: batch of images in the Hadamard domain
+            :attr:`reg` (float, optional): Cutoff for small singular values.
+
+            :attr:`pinv` (torch.tensor, optional): If given, the tensor is
+            directly stored as the pseudo inverse. No checks are performed.
+            Otherwise, the pseudo inverse is computed from the existing
+            measurement matrix :math:`H`.
+
+        .. note::
+            Only one of :attr:`pinv` and :attr:`reg` should be given. If both
+            are given, :attr:`pinv` is used and :attr:`reg` is ignored.
 
         Shape:
-            :math:`x`: :math:`(b*c, N)` with :math:`b` the batch size,
-            :math:`c` the number of channels, and :math:`N` the number of
-            pixels in the image.
+            :attr:`pinv`: :math:`(N, M)`, where :math:`N` is the number of
+            pixels in the image and :math:`M` the number of measurements.
 
-            Output: math:`(b*c, N)`
+        Example:
+            >>> H1 = torch.rand([400, 1600])
+            >>> H2 = torch.linalg.pinv(H1)
+            >>> meas_op = Linear(H1)
+            >>> meas_op.set_H_pinv(pinv=H2)
        """
+        if pinv is not None:
+            H_pinv = pinv.type(torch.FloatTensor)  # to float32
+        else:
+            H_pinv = torch.linalg.pinv(self.get_H(), rcond=reg)
+        self.H_pinv = nn.Parameter(H_pinv, requires_grad=False)
+
+    def forward(self, x: torch.tensor) -> torch.tensor:
+        r"""Applies linear transform to incoming images: :math:`y = Hx`.
+
+        Args:
+            :math:`x` (torch.tensor): Batch of vectorized (flattened) images.
+            If x has more than 1 dimension, the linear measurement is applied
+            to each image in the batch.
+
+        Shape:
+            :math:`x`: :math:`(*, N)` where * denotes the batch size and `N`
+            the total number of pixels in the image.
+
+            Output: :math:`(*, M)` where * denotes the batch size and `M`
+            the number of measurements.
 
         Example:
+            >>> H = torch.randn([400, 1600])
+            >>> meas_op = Linear(H)
+            >>> x = torch.randn([10, 1600])
+            >>> y = meas_op(x)
+            >>> print(y.shape)
+            torch.Size([10, 400])
+        """
+        # left multiplication with transpose is equivalent to right mult
+        return x @ self.get_H().T
 
-            >>> y = torch.rand([85,32*32], dtype=torch.float)
-            >>> x = meas_op.inverse(y)
-            >>> print('Inverse:', x.shape)
-            Inverse: torch.Size([85, 1024])
+    def adjoint(self, x: torch.tensor) -> torch.tensor:
+        r"""Applies adjoint transform to incoming measurements :math:`y = H^{T}x`
+
+        Args:
+            :math:`x` (torch.tensor): batch of measurement vectors. If x has
+            more than 1 dimension, the adjoint measurement is applied to each
+            measurement in the batch.
+
+        Shape:
+            :math:`x`: :math:`(*, M)`
+
+            Output: :math:`(*, N)`
+
+        Example:
+            >>> H = torch.randn([400, 1600])
+            >>> meas_op = Linear(H)
+            >>> x = torch.randn([10, 400])
+            >>> y = meas_op.adjoint(x)
+            >>> print(y.shape)
+            torch.Size([10, 1600])
        """
-        # permutations
-        # todo: check walsh2_S_fold_torch to speed up
-        b, N = x.shape
-        x = self.Perm(x)
-        x = x.view(b, 1, self.h, self.w)
-        # inverse of full transform
-        # todo: initialize with 1D transform to speed up
-        x = 1 / self.N * walsh2_torch(x)
-        x = x.view(b, N)
-        return x
+        # left multiplication is equivalent to right mult with transpose
+        return x @ self.get_H_T().T
 
     def pinv(self, x: torch.tensor) -> torch.tensor:
-        r"""Pseudo inverse transform of incoming mesurement vectors :math:`x`
+        r"""Computes the pseudo inverse solution :math:`y = H^\dagger x`
 
         Args:
-            :attr:`x`: batch of measurement vectors.
+            :math:`x` (torch.tensor): batch of measurement vectors. If x has
+            more than 1 dimension, the pseudo inverse is applied to each
+            image in the batch.
Shape: - x: :math:`(*, M)` + :math:`x`: :math:`(*, M)` Output: :math:`(*, N)` Example: - >>> y = torch.rand([85,400], dtype=torch.float) - >>> x = meas_op.pinv(y) - >>> print(x.shape) - torch.Size([85, 1024]) + >>> H = torch.randn([400, 1600]) + >>> meas_op = Linear(H, True) + >>> x = torch.randn([10, 400]) + >>> y = meas_op.pinv(x) + >>> print(y.shape) + torch.Size([10, 1600]) """ - x = self.adjoint(x) / self.N - return x + # Pmat.transpose()*f + return x @ self.get_H_pinv().T + def __attributeslist__(self): + return super().__attributeslist__() + [ + ("H_pinv", self.H_pinv.shape if hasattr(self, "H_pinv") else None) + ] -# ================================================================================== -class LinearRowSplit(nn.Module): - # ================================================================================== - r"""Compute linear measurement of incoming images :math:`y = Px`, where - :math:`P` is a linear operator and :math:`x` is an image. Note that - the same transform applies to each of the rows of the image :math:`x`. - The class is constructed from the positive and negative components of - the measurement operator :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}` +# ============================================================================= +class LinearSplit(Linear, DynamicLinearSplit): + # ========================================================================= + r""" + Simulates the measurement of a still image using the computed positive and + negative components of the measurement matrix. + + Computes linear measurements from incoming images: :math:`y = Px`, + where :math:`P` is a linear operator (matrix) and :math:`x` is a + vectorized image or batch of vectorized images. + + The matrix :math:`P` contains only positive values and is obtained by + splitting a measurement matrix :math:`H` such that + :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where + :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`. + + The class is constructed from the :math:`M` by :math:`N` matrix :math:`H`, + where :math:`N` represents the number of pixels in the image and + :math:`M` the number of measurements. Args: - - :attr:`H_pos`: Positive component of the measurement matrix :math:`H_{+}` - - :attr:`H_neg`: Negative component of the measurement matrix :math:`H_{-}` + :attr:`H` (torch.tensor): measurement matrix (linear operator) with + shape :math:`(M, N)`, where :math:`M` is the number of measurements and + :math:`N` the number of pixels in the image. - Shape: - :math:`H_{+}`: :math:`(M, N)`, where :math:`M` is the number of - patterns and :math:`N` is the length of the patterns. + :attr:`pinv` (Any): Option to have access to pseudo inverse solutions. If not + `None`, the pseudo inverse is initialized as :math:`H^\dagger` and + stored in the attribute :attr:`H_pinv`. Defaults to `None` (the pseudo + inverse is not initiliazed). - :math:`H_{-}`: :math:`(M, N)`, where :math:`M` is the number of - patterns and :math:`N` is the length of the patterns. + :attr:`reg` (float, optional): Regularization parameter (cutoff for small + singular values, see :mod:`torch.linalg.pinv`). Only relevant when + :attr:`pinv` is not `None`. + + Attributes: + :attr:`H` (torch.nn.Parameter): The learnable measurement matrix of + shape :math:`(M,N)`. 
+
+        :attr:`P` (torch.nn.Parameter): The split measurement matrix of
+        shape :math:`(2M, N)` initialized as
+        :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`
+        where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`
+
+        :attr:`M` (int): Number of measurements performed by the linear operator.
+        It is initialized as the first dimension of :math:`H`.
+
+        :attr:`N` (int): Number of pixels in the image. It is initialized as the
+        second dimension of :math:`H`.
+
+        :attr:`h` (int): Image height :math:`h`. The image is assumed to be
+        square, i.e. :math:`h = \text{floor}(\sqrt{N})`. If not, please assign
+        :attr:`h` and :attr:`w` manually.
+
+        :attr:`w` (int): Image width :math:`w`. The image is assumed to be
+        square, i.e. :math:`w = \text{floor}(\sqrt{N})`. If not, please assign
+        :attr:`h` and :attr:`w` manually.
 
     .. note::
+        If you know the pseudo inverse of :math:`H` and want to store it, it is
+        best to initialize the class with :attr:`pinv` set to `None` and then
+        call :meth:`set_H_pinv` to store the pseudo inverse.
 
     Example:
-        >>> H_pos = np.random.rand(24,64)
-        >>> H_neg = np.random.rand(24,64)
-        >>> linop = LinearRowSplit(H_pos,H_neg)
-
+        >>> H = torch.randn(400, 1600)
+        >>> meas_op = LinearSplit(H, None)
+        >>> print(meas_op)
+        LinearSplit(
+          (Image pixels): 1600
+          (H): torch.Size([400, 1600])
+          (P): torch.Size([800, 1600])
+          (H_pinv): None
+          )
     """
 
-    def __init__(self, H_pos: np.ndarray, H_neg: np.ndarray):
-        super().__init__()
-
-        self.M = H_pos.shape[0]
-        self.N = H_pos.shape[1]
-
-        # Split patterns ?
-        # N.B.: Data must be of type float (or double) rather than the default
-        # float64 when creating torch tensor
-        even_index = range(0, 2 * self.M, 2)
-        odd_index = range(1, 2 * self.M, 2)
-        P = np.zeros((2 * self.M, self.N))
-        P[even_index, :] = H_pos
-        P[odd_index, :] = H_neg
-        self.P = nn.Linear(self.N, 2 * self.M, False)
-        self.P.weight.data = torch.from_numpy(P).float()
-        self.P.weight.requires_grad = False
-
-        # "Unsplit" patterns
-        H = H_pos - H_neg
-        self.H = nn.Linear(self.N, self.M, False)
-        self.H.weight.data = torch.from_numpy(H).float()
-        self.H.weight.requires_grad = False
+    def __init__(self, H: np.ndarray, pinv=None, reg: float = 1e-15):
+        # initialize from DynamicLinearSplit __init__
+        super(Linear, self).__init__(H)
+        if pinv is not None:
+            self.set_H_pinv(reg)
 
     def forward(self, x: torch.tensor) -> torch.tensor:
-        r"""Applies linear transform to incoming images: :math:`y = Px`
+        r"""Applies linear transform to incoming images: :math:`y = Px`.
+
+        This method uses the split measurement matrix :math:`P` to compute
+        the linear measurements from incoming images. :math:`P` contains only
+        positive values and is obtained by splitting a given measurement matrix
+        :math:`H` such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`,
+        where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`.
 
         Args:
-            x: a batch of images
+            :math:`x` (torch.tensor): Batch of vectorized (flattened) images. If
+            x has more than 1 dimension, the linear measurement is applied to
+            each image in the batch.
 
         Shape:
-            x: :math:`(b*c, h, w)` with :math:`b` the batch size, :math:`c` the
-            number of channels, :math:`h` is the image height, and :math:`w` is the image
-            width.
- - Output: :math:`(b*c, 2M, w)` with :math:`b` the batch size, - :math:`c` the number of channels, :math:`2M` is twice the number of - patterns (as it includes both positive and negative components), and - :math:`w` is the image width. + :math:`x`: :math:`(*, N)` where * denotes the batch size and `N` + the total number of pixels in the image. - .. warning:: - The image height :math:`h` should match the length of the patterns - :math:`N` + Output: :math:`(*, 2M)` where * denotes the batch size and `M` + the number of measurements. Example: - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> linop = LinearRowSplit(H_pos,H_neg) - >>> x = torch.rand(10,64,92) - >>> y = linop(x) + >>> H = torch.randn(400, 1600) + >>> meas_op = LinearSplit(H) + >>> x = torch.randn(10, 1600) + >>> y = meas_op(x) >>> print(y.shape) - torch.Size([10,48,92]) - + torch.Size([10, 800]) """ - x = torch.transpose(x, 1, 2) # swap last two dimensions - x = self.P(x) - x = torch.transpose(x, 1, 2) # swap last two dimensions - return x + return x @ self.get_P().T def forward_H(self, x: torch.tensor) -> torch.tensor: - r"""Applies linear transform to incoming images: :math:`m = Hx` + r"""Applies linear transform to incoming images: :math:`m = Hx`. + + This method uses the measurement matrix :math:`H` to compute the linear + measurements from incoming images. Args: - x: a batch of images + :attr:`x` (torch.tensor): Batch of vectorized (flatten) images. If + x has more than 1 dimension, the linear measurement is applied to + each image in the batch. Shape: - x: :math:`(b*c, h, w)` with :math:`b` the batch size, :math:`c` the - number of channels, :math:`h` is the image height, and :math:`w` is the image - width. - - Output: :math:`(b*c, M, w)` with :math:`b` the batch size, - :math:`c` the number of channels, :math:`M` is the number of - patterns, and :math:`w` is the image width. + :attr:`x`: :math:`(*, N)` where * denotes the batch size and `N` + the total number of pixels in the image. - .. warning:: - The image height :math:`h` should match the length of the patterns - :math:`N` + Output: :math:`(*, M)` where * denotes the batch size and `M` + the number of measurements. Example: - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> meas_op = LinearRowSplit(H_pos,H_neg) - >>> x = torch.rand(10,64,92) + >>> H = torch.randn(400, 1600) + >>> meas_op = LinearSplit(H) + >>> x = torch.randn(10, 1600) >>> y = meas_op.forward_H(x) >>> print(y.shape) - torch.Size([10,24,92]) - + torch.Size([10, 400]) """ - x = torch.transpose(x, 1, 2) # swap last two dimensions - x = self.H(x) - x = torch.transpose(x, 1, 2) # swap last two dimensions - return x + # call Linear.forward() method + return super(LinearSplit, self).forward(x) - def get_H(self) -> torch.tensor: - r"""Returns the measurement matrix :math:`H`. + +# ============================================================================= +class HadamSplit(LinearSplit, DynamicHadamSplit): + # ========================================================================= + r""" + Simulates the measurement of a still image using the positive and + negative components of a Hadamard matrix. + + Computes linear measurements from incoming images: :math:`y = Px`, + where :math:`P` is a linear operator (matrix) with positive entries and + :math:`x` is a vectorized image or a batch of images. 
+
+    The class relies on a Hadamard-based matrix :math:`H` with shape :math:`(M,N)`
+    where :math:`N` represents the number of pixels in the image and
+    :math:`M \le N` the number of measurements. :math:`H` is obtained by
+    selecting a re-ordered subsample of :math:`M` rows of a "full" Hadamard
+    matrix :math:`F` with shape :math:`(N, N)`. :math:`h` must be a power
+    of 2.
+
+    The matrix :math:`P` is then obtained by splitting the matrix :math:`H`
+    such that :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`, where
+    :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`.
+
+    Args:
+        :attr:`M` (int): Number of measurements
+
+        :attr:`h` (int): Image height :math:`h`, must be a power of 2. The
+        image is assumed to be square, so the number of pixels in the image is
+        :math:`N = h^2`.
+
+        :attr:`Ord` (np.ndarray): Order matrix with shape :math:`(h, h)` used to
+        compute the permutation matrix :math:`G^{T}` with shape :math:`(N, N)`
+        (see the :mod:`~spyrit.misc.sampling` submodule)
+
+    Attributes:
+        :attr:`H` (torch.nn.Parameter): The measurement matrix of shape
+        :math:`(M, h^2)`. It is initialized as a re-ordered subsample of the
+        rows of the "full" Hadamard matrix :math:`F` with shape :math:`(N, N)`.
+
+        :attr:`H_pinv` (torch.nn.Parameter): The pseudo inverse of the measurement
+        matrix of shape :math:`(h^2, M)`. It is initialized as
+        :math:`H^\dagger = \frac{1}{N}H^{T}` where :math:`N = h^2`.
+
+        :attr:`P` (torch.nn.Parameter): The split measurement matrix of
+        shape :math:`(2M, h^2)` initialized as
+        :math:`P = \begin{bmatrix}{H_{+}}\\{H_{-}}\end{bmatrix}`
+        where :math:`H_{+} = \max(0,H)` and :math:`H_{-} = \max(0,-H)`.
+
+        :attr:`Perm` (torch.nn.Parameter): The permutation matrix :math:`G^{T}`
+        that is used to re-order the subsample of rows of the "full" Hadamard
+        matrix :math:`F` according to decreasing value of the order matrix
+        :math:`Ord`. It has shape :math:`(N, N)` where :math:`N = h^2`.
+
+        :attr:`M` (int): Number of measurements performed by the linear operator.
+
+        :attr:`N` (int): Number of pixels in the image. It is initialized as
+        :math:`h^2`.
+
+        :attr:`h` (int): Image height :math:`h`.
+
+        :attr:`w` (int): Image width :math:`w`. The image is assumed to be
+        square, i.e. :math:`w = h`.
+
+    .. note::
+        The computation of a Hadamard transform :math:`Fx` benefits from a fast
+        algorithm, as well as the computation of inverse Hadamard transforms.
+
+    .. note::
+        The matrix :math:`H` has shape :math:`(M,N)` with :math:`N = h^2`.
+
+    .. note::
+        :math:`H = H_{+} - H_{-}`
+
+    Example:
+        >>> h = 32
+        >>> Ord = torch.randn(h, h)
+        >>> meas_op = HadamSplit(400, h, Ord)
+        >>> print(meas_op)
+        HadamSplit(
+          (Image pixels): 1024
+          (H): torch.Size([400, 1024])
+          (P): torch.Size([800, 1024])
+          (Perm): torch.Size([1024, 1024])
+          (H_pinv): torch.Size([1024, 400])
+          )
+    """
+
+    def __init__(self, M: int, h: int, Ord: np.ndarray):
+        # initialize from DynamicHadamSplit (the MRO is not trivial here)
+        super(Linear, self).__init__(M, h, Ord)
+        self.set_H_pinv(pinv=1 / self.N * self.get_H_T())
+
+    def inverse(self, x: torch.tensor) -> torch.tensor:
+        r"""Inverse transform of Hadamard-domain images
+        :math:`x = H_{had}^{-1}G y`, where :math:`H_{had}` is a Hadamard matrix.
+
+        Args:
+            :math:`x`: batch of images in the Hadamard domain
 
         Shape:
-            Output: :math:`(M, N)`
+            :math:`x`: :math:`(b*c, N)` with :math:`b` the batch size,
+            :math:`c` the number of channels, and :math:`N` the number of
+            pixels in the image.
+ + Output: math:`(b*c, N)` Example: - >>> H = meas_op.get_H() - >>> print(H.shape) - torch.Size([24, 64]) + >>> h = 32 + >>> Ord = torch.randn(h, h) + >>> meas_op = HadamSplit(400, h, Ord) + >>> y = torch.randn(10, h**2) + >>> x = meas_op.inverse(y) + >>> print(x.shape) + torch.Size([10, 1024]) """ - return self.H.weight.data + # permutations + # todo: check walsh2_S_fold_torch to speed up + b, N = x.shape + x = x @ self.Perm.T + x = x.view(b, 1, self.h, self.w) + # inverse of full transform + # todo: initialize with 1D transform to speed up + x = 1 / self.N * walsh2_torch(x) + return x.view(b, N) diff --git a/spyrit/core/noise.py b/spyrit/core/noise.py index 3443d41b..a57fd696 100644 --- a/spyrit/core/noise.py +++ b/spyrit/core/noise.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import poisson -from spyrit.core.meas import Linear, LinearSplit, LinearRowSplit, HadamSplit +from spyrit.core.meas import Linear, LinearSplit, HadamSplit # , LinearRowSplit from typing import Union @@ -36,7 +36,7 @@ class NoNoise(nn.Module): >>> split_acq = NoNoise(split_op) """ - def __init__(self, meas_op: Union[Linear, LinearSplit, HadamSplit, LinearRowSplit]): + def __init__(self, meas_op: Union[Linear, LinearSplit, HadamSplit]): super().__init__() self.meas_op = meas_op @@ -101,16 +101,15 @@ class Poisson(NoNoise): >>> split_op = HadamSplit(H, Perm, 32, 32) >>> split_acq = Poisson(split_op, 200.0) - Example 3: Using a :class:`~spyrit.core.meas.LinearRowSplit` measurement operator - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> split_row_op = LinearRowSplit(H_pos,H_neg) + Example 3: Using a :class:`~spyrit.core.meas.LinearSplit` measurement operator + >>> H = np.random.rand(24,64) + >>> split_row_op = LinearSplit(H) >>> split_acq = Poisson(split_row_op, 50.0) """ def __init__( self, - meas_op: Union[Linear, LinearSplit, HadamSplit, LinearRowSplit], + meas_op: Union[Linear, LinearSplit, HadamSplit], alpha=50.0, ): super().__init__(meas_op) @@ -155,10 +154,9 @@ def forward(self, x): Measurements in (0.00 , 55338.00) Measurements in (0.00 , 55077.00) - Example 3: Two noisy measurement vectors from a :class:`~spyrit.core.meas.LinearRowSplit` operator - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> meas_op = LinearRowSplit(H_pos,H_neg) + Example 3: Two noisy measurement vectors from a :class:`~spyrit.core.meas.LinearSplit` operator + >>> H = np.random.rand(24,64) + >>> meas_op = LinearSplit(H) >>> noise_op = Poisson(meas_op, 50.0) >>> x = torch.FloatTensor(10, 64, 92).uniform_(-1, 1) >>> y = noise_op(x) @@ -213,16 +211,15 @@ class PoissonApproxGauss(NoNoise): >>> meas_op = HadamSplit(H, Perm, 32, 32) >>> noise_op = PoissonApproxGauss(meas_op, 200.0) - Example 3: Using a :class:`~spyrit.core.meas.LinearRowSplit` operator - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> meas_op = LinearRowSplit(H_pos,H_neg) + Example 3: Using a :class:`~spyrit.core.meas.LinearSplit` operator + >>> H = np.random.rand(24,64) + >>> meas_op = LinearSplit(H) >>> noise_op = PoissonApproxGauss(meas_op, 50.0) """ def __init__( self, - meas_op: Union[Linear, LinearSplit, HadamSplit, LinearRowSplit], + meas_op: Union[Linear, LinearSplit, HadamSplit], alpha: float, ): super().__init__(meas_op) @@ -267,10 +264,9 @@ def forward(self, x: torch.tensor) -> torch.tensor: Measurements in (0.00 , 55951.41) Measurements in (0.00 , 56216.86) - Example 3: Two noisy measurement vectors from a 
:class:`~spyrit.core.meas.LinearRowSplit` operator - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> meas_op = LinearRowSplit(H_pos,H_neg) + Example 3: Two noisy measurement vectors from a :class:`~spyrit.core.meas.LinearSplit` operator + >>> H = np.random.rand(24,64) + >>> meas_op = LinearSplit(H) >>> noise_op = PoissonApproxGauss(meas_op, 50.0) >>> x = torch.FloatTensor(10, 64, 92).uniform_(-1, 1) >>> y = noise_op(x) diff --git a/spyrit/core/prep.py b/spyrit/core/prep.py index 3c9acaab..835708a1 100644 --- a/spyrit/core/prep.py +++ b/spyrit/core/prep.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from spyrit.core.meas import Linear, LinearSplit, LinearRowSplit, HadamSplit +from spyrit.core.meas import Linear, LinearSplit, HadamSplit # , LinearRowSplit from typing import Union, Tuple import math @@ -456,7 +456,7 @@ class SplitRowPoisson(nn.Module): It computes :math:`m = \frac{y_{+}-y_{-}}{\alpha}` and the variance :math:`\sigma^2 = \frac{2(y_{+} + y_{-})}{\alpha^{2}}`, where :math:`y_{+} = H_{+}x` and :math:`y_{-} = H_{-}x` are obtained using - a split measurement operator such as :class:`spyrit.core.LinearRowSplit`. + a split measurement operator such as :class:`spyrit.core.LinearSplit`. Args: - :math:`\alpha` (float): maximun image intensity (in counts) @@ -481,7 +481,7 @@ def __init__(self, alpha: float, M: int, h: int): def forward( self, x: torch.tensor, - meas_op: LinearRowSplit, + meas_op: LinearSplit, ) -> torch.tensor: """ Args: @@ -504,7 +504,7 @@ def forward( >>> x = torch.rand([10,48,64], dtype=torch.float) >>> H_pos = np.random.random([24,64]) >>> H_neg = np.random.random([24,64]) - >>> meas_op = LinearRowSplit(H_pos, H_neg) + >>> meas_op = LinearSplit(H_pos, H_neg) >>> m = split_op(x, meas_op) >>> print(m.shape) torch.Size([10, 24, 64]) diff --git a/spyrit/core/recon.py b/spyrit/core/recon.py index 9aa86da3..dd680601 100644 --- a/spyrit/core/recon.py +++ b/spyrit/core/recon.py @@ -443,7 +443,7 @@ def meas2img(self, y): return z.view(-1, 1, self.acqu.meas_op.h, self.acqu.meas_op.w) def reconstruct(self, x): - r"""Reconstruction step of a reconstruction network + r"""Preprocesses, reconstructs, and denoises raw measurement vectors. Args: :attr:`x`: raw measurement vectors @@ -465,23 +465,11 @@ def reconstruct(self, x): >>> print(z.shape) torch.Size([10, 1, 64, 64]) """ - # Measurement to image domain mapping - bc, _ = x.shape - - # Preprocessing in the measurement domain - x = self.prep(x) # shape x = [b*c, M] - - # measurements to image-domain processing - x = self.pinv(x, self.acqu.meas_op) # shape x = [b*c,N] - - # Image-domain denoising - x = x.view( - bc, 1, self.acqu.meas_op.h, self.acqu.meas_op.w - ) # shape x = [b*c,1,h,w] - return self.denoi(x) + # Denoise image-domain + return self.denoi(self.reconstruct_pinv(x)) def reconstruct_pinv(self, x): - r"""Reconstruction step of a reconstruction network + r"""Preprocesses and reconstructs raw measurement vectors. Args: :attr:`x`: raw measurement vectors diff --git a/spyrit/dev/recon.py b/spyrit/dev/recon.py index a068a866..da3ff2be 100644 --- a/spyrit/dev/recon.py +++ b/spyrit/dev/recon.py @@ -85,7 +85,7 @@ class PseudoInverseStore2(nn.Module): Args: :attr:`meas_op`: Measurement operator that defines :math:`H`. Any class that implements a :meth:`get_H` method can be used, e.g., - :class:`~spyrit.core.forwop.LinearRowSplit`. + :class:`~spyrit.core.forwop.LinearSplit`. :attr:`reg` (optional): Regularization parameter (cutoff for small singular values, see :mod:`numpy.linal.pinv`). 
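For reference, the regularized pseudo inverse that PseudoInverseStore2 stores can be reproduced by hand. The following is a minimal sketch (not part of the patch) assuming the updated LinearSplit API from spyrit/core/meas.py above; the 24-by-64 matrix is illustrative:

    import numpy as np
    import torch
    from spyrit.core.meas import LinearSplit

    H = np.random.rand(24, 64)                  # illustrative measurement matrix (M=24, N=64)
    meas_op = LinearSplit(H)                    # numpy input triggers a DeprecationWarning
    # regularized pseudo inverse of the stored measurement matrix, shape (64, 24)
    H_pinv = np.linalg.pinv(meas_op.get_H().numpy(), rcond=1e-15)
    y = torch.rand(10, 24)                      # batch of measurement vectors
    x_hat = y @ torch.from_numpy(H_pinv).float().T   # back to the 64-pixel image domain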
@@ -102,9 +102,8 @@ class PseudoInverseStore2(nn.Module): matrix is stored and therefore learnable Example 1: - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> meas_op = LinearRowSplit(H_pos,H_neg) + >>> H = np.random.rand(24,64) + >>> meas_op = LinearSplit(H) >>> recon_op = PseudoInverseStore2(meas_op) Example 2: @@ -112,16 +111,12 @@ class PseudoInverseStore2(nn.Module): >>> N = 64 >>> B = 1 >>> H = walsh_matrix(N) - >>> H_pos = np.where(H>0,H,0)[:M,:] - >>> H_neg = np.where(H<0,-H,0)[:M,:] - >>> meas_op = LinearRowSplit(H_pos,H_neg) + >>> meas_op = LinearSplit(H) >>> recon_op = PseudoInverseStore2(meas_op) """ - def __init__( - self, meas_op: LinearRowSplit, reg: float = 1e-15, learn: bool = False - ): + def __init__(self, meas_op: LinearSplit, reg: float = 1e-15, learn: bool = False): H = meas_op.get_H() M, N = H.shape H_pinv = np.linalg.pinv(H, rcond=reg) @@ -143,9 +138,8 @@ def forward(self, x: torch.tensor) -> torch.tensor: - :attr:`output`: :math:`(*, N)` Example 1: - >>> H_pos = np.random.rand(24,64) - >>> H_neg = np.random.rand(24,64) - >>> meas_op = LinearRowSplit(H_pos,H_neg) + >>> H = np.random.rand(24,64) + >>> meas_op = LinearSplit(H) >>> recon_op = PseudoInverseStore2(meas_op) >>> x = torch.rand([10,24,92], dtype=torch.float) >>> y = recon_op(x) @@ -157,9 +151,7 @@ def forward(self, x: torch.tensor) -> torch.tensor: >>> N = 64 >>> B = 1 >>> H = walsh_matrix(N) - >>> H_pos = np.where(H>0,H,0)[:M,:] - >>> H_neg = np.where(H<0,-H,0)[:M,:] - >>> meas_op = LinearRowSplit(H_pos,H_neg) + >>> meas_op = LinearSplit(H) >>> noise_op = NoNoise(meas_op) >>> split_op = SplitRowPoisson(1.0, M, 92) >>> recon_op = PseudoInverseStore2(meas_op) diff --git a/spyrit/misc/sampling.py b/spyrit/misc/sampling.py index 54d919d7..78adf66a 100644 --- a/spyrit/misc/sampling.py +++ b/spyrit/misc/sampling.py @@ -96,7 +96,7 @@ def Permutation_Matrix(Mat: np.ndarray) -> np.ndarray: N-by-N sampling matrix, where high values indicate high significance. Returns: - P (np.ndarray): N*N-by-N*N permutation matrix (boolean) + P (np.ndarray): N^2-by-N^2 permutation matrix (boolean) """ (nx, ny) = Mat.shape Reorder = rankdata(-Mat, method="ordinal") diff --git a/spyrit/misc/walsh_hadamard.py b/spyrit/misc/walsh_hadamard.py index 84ab5662..5da3a587 100644 --- a/spyrit/misc/walsh_hadamard.py +++ b/spyrit/misc/walsh_hadamard.py @@ -646,10 +646,8 @@ def walsh2_matrix(n): Returns: H (np.ndarray): A n*n-by-n*n matrix """ - H = np.zeros((n**2, n**2)) H1d = walsh_matrix(n) - H = np.kron(H1d, H1d) - return H + return np.kron(H1d, H1d) def walsh2(X, H=None): @@ -1278,11 +1276,12 @@ def walsh2_torch(im, H=None): """Return 2D Walsh-ordered Hadamard transform of an image Args: - im (torch.Tensor): Image, typically a B-by-C-by-W-by-H Tensor - H (torch.Tensor, optional): 1D Walsh-ordered Hadamard transformation matrix. A 2-D tensor of size W-by-H. + im (torch.tensor): Image, typically a B-by-C-by-W-by-H Tensor + H (torch.tensor, optional): 1D Walsh-ordered Hadamard transformation + matrix. A 2-D tensor of size W-by-H. Returns: - torch.Tensor: Hadamard transformed image. Same size as im + torch.tensor: Hadamard transformed image. 
Same size as im Examples: >>> im = torch.randn(256, 1, 64, 64) diff --git a/spyrit/test/test_core_meas.py b/spyrit/test/test_core_meas.py index 6bb2c2ec..86b485b4 100644 --- a/spyrit/test/test_core_meas.py +++ b/spyrit/test/test_core_meas.py @@ -109,31 +109,31 @@ def test_core_meas(): assert_test(x.shape, torch.Size([85, 1024]), "Wrong inverse size") # %% Test LinearRowSplit - from spyrit.core.meas import LinearRowSplit - - # constructor - H_pos = np.random.rand(24, 64) - H_neg = np.random.rand(24, 64) - meas_op = LinearRowSplit(H_pos, H_neg) - - # forward - x = torch.rand([10, 64, 92], dtype=torch.float) - y = meas_op(x) - print(y.shape) - assert_test(y.shape, torch.Size([10, 48, 92]), "Wrong forward size") - - # forward_H - x = torch.rand([10, 64, 92], dtype=torch.float) - y = meas_op(x) - print(y.shape) - assert_test(y.shape, torch.Size([10, 48, 92]), "Wrong forward size") - - # get_H - H = meas_op.get_H() - print(H.shape) - assert_test(H.shape, torch.Size([24, 64]), "Wrong measurement matrix size") - - return True + # from spyrit.core.meas import LinearRowSplit + + # # constructor + # H_pos = np.random.rand(24, 64) + # H_neg = np.random.rand(24, 64) + # meas_op = LinearRowSplit(H_pos, H_neg) + + # # forward + # x = torch.rand([10, 64, 92], dtype=torch.float) + # y = meas_op(x) + # print(y.shape) + # assert_test(y.shape, torch.Size([10, 48, 92]), "Wrong forward size") + + # # forward_H + # x = torch.rand([10, 64, 92], dtype=torch.float) + # y = meas_op(x) + # print(y.shape) + # assert_test(y.shape, torch.Size([10, 48, 92]), "Wrong forward size") + + # # get_H + # H = meas_op.get_H() + # print(H.shape) + # assert_test(H.shape, torch.Size([24, 64]), "Wrong measurement matrix size") + + # return True if __name__ == "__main__": diff --git a/spyrit/test/test_core_noise.py b/spyrit/test/test_core_noise.py index dfc64860..12dbbbe1 100644 --- a/spyrit/test/test_core_noise.py +++ b/spyrit/test/test_core_noise.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import numpy as np import torch -from spyrit.core.meas import Linear, LinearSplit, LinearRowSplit, HadamSplit +from spyrit.core.meas import Linear, LinearSplit, HadamSplit # , LinearRowSplit from test_helpers import assert_test @@ -67,9 +67,8 @@ def test_core_noise(): print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})") # EXAMPLE 3 - H_pos = np.random.rand(24, 64) - H_neg = np.random.rand(24, 64) - meas_op = LinearRowSplit(H_pos, H_neg) + H = np.random.rand(24, 64) + meas_op = LinearSplit(H) noise_op = Poisson(meas_op, 50.0) x = torch.FloatTensor(10, 64, 92).uniform_(-1, 1) @@ -117,9 +116,8 @@ def test_core_noise(): print(f"Measurements in ({torch.min(y):.2f} , {torch.max(y):.2f})") # EXAMPLE 3 - H_pos = np.random.rand(24, 64) - H_neg = np.random.rand(24, 64) - meas_op = LinearRowSplit(H_pos, H_neg) + H = np.random.rand(24, 64) + meas_op = LinearSplit(H) noise_op = PoissonApproxGauss(meas_op, 50.0) x = torch.FloatTensor(10, 64, 92).uniform_(-1, 1) diff --git a/spyrit/test/test_core_prep.py b/spyrit/test/test_core_prep.py index 2363c919..55123e54 100644 --- a/spyrit/test/test_core_prep.py +++ b/spyrit/test/test_core_prep.py @@ -94,17 +94,16 @@ def test_core_prep(): assert_test(y.shape, torch.Size([10, 1, 32, 32]), "Wrong matrix size") # %% Test SplitRowPoisson - from spyrit.core.meas import LinearRowSplit + from spyrit.core.meas import LinearSplit from spyrit.core.prep import SplitRowPoisson # constructor split_op = SplitRowPoisson(2.0, 24, 64) - # forward with LinearRowSplit + # forward with LinearSplit x = torch.rand([10, 48, 64], 
dtype=torch.float) - H_pos = np.random.random([24, 64]) - H_neg = np.random.random([24, 64]) - meas_op = LinearRowSplit(H_pos, H_neg) + H = np.random.random([24, 64]) + meas_op = LinearSplit(H) # forward m = split_op(x, meas_op)
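For quick reference, here is a short usage sketch of the refactored measurement operators, assembled from the docstrings above. The matrices are random placeholders and the shapes are illustrative; class names and calling conventions follow this diff (still operators take :math:`(*, N)` inputs, dynamic operators expect one frame per measurement):

    import numpy as np
    import torch
    from spyrit.core.meas import Linear, DynamicLinear, HadamSplit

    # Still image: Linear stores an optional pseudo inverse
    H = torch.rand(400, 1600)                  # M = 400 measurements, N = 1600 pixels
    meas_op = Linear(H, pinv=True)
    x = torch.rand(10, 1600)                   # batch of 10 vectorized 40x40 images
    y = meas_op(x)                             # shape (10, 400)
    x_rec = meas_op.pinv(y)                    # shape (10, 1600)

    # Moving object: DynamicLinear expects one frame per measurement
    dyn_op = DynamicLinear(H)
    x_movie = torch.rand(10, 400, 1600)        # 400 frames of 1600 pixels each
    y_dyn = dyn_op(x_movie)                    # shape (10, 400)

    # Split Hadamard operator for a 32x32 image (N = 1024, 2M = 800 split measurements)
    Ord = np.random.random([32, 32])
    had_op = HadamSplit(400, 32, Ord)
    y_split = had_op(torch.rand(10, 32 * 32))  # shape (10, 800)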