Skip to content

Commit

Permalink
Merge pull request #154 from openspyrit/romain_dev
Browse files Browse the repository at this point in the history
ReadTheDocs documentation update, docstring improved
  • Loading branch information
romainphan committed Feb 20, 2024
2 parents 3802091 + 410b6ed commit e55ed09
Show file tree
Hide file tree
Showing 4 changed files with 296 additions and 268 deletions.
14 changes: 0 additions & 14 deletions docs/source/_templates/spyrit-class-template.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,3 @@
{%- endfor %}
{% endif %}
{% endblock %}

{% block attributes %}
{% if attributes %}
.. rubric:: {{ _('Attributes') }}

.. autosummary::
:toctree:
{% for item in attributes %}
{%- if item != "training" %}
~{{ name }}.{{ item }}
{%- endif %}
{%- endfor %}
{% endif %}
{% endblock %}
177 changes: 98 additions & 79 deletions spyrit/core/recon.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,19 @@
# -*- coding: utf-8 -*-
"""
Reconstruction methods
Created on Fri Jan 20 11:03:12 2023
@author: ducros
Reconstruction methods and networks.
"""
import math
import torch
import torch.nn as nn
import numpy as np
from spyrit.core.meas import HadamSplit, LinearRowSplit, Linear
import math

from spyrit.core.meas import HadamSplit

# ==================================================================================

# =============================================================================
class PseudoInverse(nn.Module):
# ==================================================================================
r"""Moore-Penrose Pseudoinverse
# =========================================================================
r"""Moore-Penrose pseudoinverse.
Considering linear measurements :math:`y = Hx`, where :math:`H` is the
measurement matrix and :math:`x` is a vectorized image, it estimates
Expand All @@ -38,7 +35,7 @@ def __init__(self):
super().__init__()

def forward(self, x: torch.tensor, meas_op) -> torch.tensor:
r"""Compute pseudo-inverse of measurements.
r"""Computes pseudo-inverse of measurements.
Args:
:attr:`x`: Batch of measurement vectors.
Expand All @@ -65,15 +62,14 @@ def forward(self, x: torch.tensor, meas_op) -> torch.tensor:
>>> print(x.shape)
torch.Size([85, 1024])
"""
x = meas_op.pinv(x)
return x
return meas_op.pinv(x)


# ===========================================================================================
# =============================================================================
class TikhonovMeasurementPriorDiag(nn.Module):
# ===========================================================================================
# =========================================================================
r"""
Tikhonov regularization with prior in the measurement domain
Tikhonov regularization with prior in the measurement domain.
Considering linear measurements :math:`y = Hx`, where :math:`H = GF` is the
measurement matrix and :math:`x` is a vectorized image, it estimates
Expand Down Expand Up @@ -132,8 +128,9 @@ def forward(
self, x: torch.tensor, x_0: torch.tensor, var: torch.tensor, meas_op: HadamSplit
) -> torch.tensor:
r"""
Computes the Tikhonov regularization with prior in the measurement domain.
We approximate the solution as
We approximate the solution as:
.. math::
\hat{x} = x_0 + F^{-1} \begin{bmatrix} y_1 \\ y_2\end{bmatrix}
Expand Down Expand Up @@ -188,31 +185,32 @@ def forward(
return x


# ===========================================================================================
# =============================================================================
class Denoise_layer(nn.Module):
# ===========================================================================================
r"""Wiener filter that assumes additive white Gaussian noise
# =========================================================================
r"""Wiener filter that assumes additive white Gaussian noise.
.. math::
y = \sigma_\text{prior}^2/(\sigma^2_\text{prior} + \sigma^2_\text{meas}) x,
where :math:`\sigma^2_\text{prior}` is the variance prior and
:math:`\sigma^2_\text{meas}` is the variance of the measurement,
x is the input vector and y is the output vector.
where :math:`\sigma^2_\text{prior}` is the variance prior and
:math:`\sigma^2_\text{meas}` is the variance of the measurement,
x is the input vector and y is the output vector.
Args:
:attr:`M`: size of incoming vector
:attr:`M` (int): size of incoming vector
Shape:
- Input: :math:`(*, M)`.
- Output: :math:`(*, M)`.
Attributes:
:attr:`sigma`:
the learnable standard deviation prior
:math:`\sigma_\text{prior}` of shape :math:`(M, 1)`. The
values are initialized from
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
:math:`k = 1/M`.
:attr:`weight`:
The learnable standard deviation prior :math:`\sigma_\text{prior}` of
shape :math:`(M, 1)`. The values are initialized from
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where :math:`k = 1/M`.
:attr:`in_features`:
The number of input features equal to :math:`M`.
Example:
>>> m = Denoise_layer(30)
Expand All @@ -222,53 +220,88 @@ class Denoise_layer(nn.Module):
torch.Size([128, 30])
"""

def __init__(self, M):
def __init__(self, M: int):
super(Denoise_layer, self).__init__()
self.in_features = M
self.weight = nn.Parameter(torch.Tensor(M))
self.reset_parameters()

def reset_parameters(self):
r"""
Resets the standard deviation prior :math:`\sigma_\text{prior}`.
The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`,
where :math:`k = 1/M`. They are stored in the :attr:`weight` attribute.
"""
nn.init.uniform_(self.weight, 0, 2 / math.sqrt(self.in_features))

def forward(self, inputs):
def forward(self, inputs: torch.tensor) -> torch.tensor:
r"""
Applies a transformation to the incoming data: :math:`y = A^2/(A^2+x)`.
:math:`x` is the input tensor (see :attr:`inputs`) and :math:`A` is the
standard deviation prior (see :attr:`self.weight`).
Args:
:attr:`inputs` (torch.tensor): input tensor :math:`x` of shape
:math:`(N, *, in\_features)`
Returns:
torch.tensor: The transformed data :math:`y` of shape
:math:`(N, in\_features)`
Shape:
"""
return self.tikho(inputs, self.weight)

def extra_repr(self):
return "in_features={}".format(self.in_features)

@staticmethod
def tikho(inputs, weight):
def tikho(inputs: torch.tensor, weight: torch.tensor) -> torch.tensor:
# type: (torch.Tensor, torch.Tensor) -> torch.Tensor
r"""
Applies a transformation to the incoming data: :math:`y = A^2/(A^2+x)`.
:math:`x` is the input tensor (see :attr:`inputs`) and :math:`A` is the
standard deviation prior (see :attr:`weight`).
Args:
:attr:`inputs` (torch.tensor): input tensor :math:`x` of shape
:math:`(N, *, in\_features)`
:attr:`weight` (torch.tensor): standard deviation prior :math:`A` of
shape :math:`(in\_features)`
Returns:
torch.tensor: The transformed data :math:`y` of shape
:math:`(N, in\_features)`
Shape:
- Input: :math:`(N, *, in\_features)` where `*` means any number of
- :attr:`inputs`: :math:`(N, *, in\_features)` where `*` means any number of
additional dimensions - Variance of measurements
- Weight: :math:`(in\_features)` - corresponds to the standard deviation
- :attr:`weight`: :math:`(in\_features)` - corresponds to the standard deviation
of our prior.
- Output: :math:`(N, in\_features)`
- :attr:`output`: :math:`(N, in\_features)`
"""
var = weight**2 # prefer to square it, because when leant, it can got to the
a = weight**2 # prefer to square it, because when learnt, it can go to the
# negative, which we do not want to happen.
# TO BE Potentially done : square inputs.
den = var + inputs
ret = var / den
return ret
b = a + inputs
return a / b


# -----------------------------------------------------------------------------
# | RECONSTRUCTION NETWORKS |
# -----------------------------------------------------------------------------


# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# RECONSTRUCTION NETWORKS
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# =============================================================================
class PinvNet(nn.Module):
# =============================================================================
# =========================================================================
r"""Pseudo inverse reconstruction network
.. math:
Args:
:attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
Expand All @@ -280,8 +313,10 @@ class PinvNet(nn.Module):
Input / Output:
:attr:`input`: Ground-truth images with shape :math:`(B,C,H,W)`
corresponding to the batch size, number of channels, height, and width.
:attr:`output`: Reconstructed images with shape :math:`(B,C,H,W)`
corresponding to the batch size, number of channels, height, and width.
Attributes:
:attr:`Acq`: Acquisition operator initialized as :attr:`noise`
Expand Down Expand Up @@ -341,7 +376,6 @@ def forward(self, x):
torch.Size([10, 1, 64, 64])
tensor(5.8912e-06)
"""

b, c, _, _ = x.shape

# Acquisition
Expand All @@ -350,12 +384,10 @@ def forward(self, x):

# Reconstruction
x = self.reconstruct(x) # shape x = [bc, 1, h,w]
x = x.view(b, c, self.acqu.meas_op.h, self.acqu.meas_op.w)

return x
return x.view(b, c, self.acqu.meas_op.h, self.acqu.meas_op.w)

def acquire(self, x):
r"""Simulate data acquisition
r"""Simulates data acquisition
Args:
:attr:`x`: ground-truth images
Expand All @@ -377,17 +409,13 @@ def acquire(self, x):
>>> print(z.shape)
torch.Size([10, 8192])
"""

b, c, _, _ = x.shape

# Acquisition
x = x.view(b * c, self.acqu.meas_op.N) # shape x = [b*c,h*w] = [b*c,N]
x = self.acqu(x) # shape x = [b*c, 2*M]

return x
return self.acqu(x) # shape x = [b*c, 2*M]

def meas2img(self, y):
"""Return images from raw measurement vectors
"""Returns images from raw measurement vectors
Args:
:attr:`x`: raw measurement vectors
Expand All @@ -412,9 +440,7 @@ def meas2img(self, y):
m = self.prep(y)
m = torch.nn.functional.pad(m, (0, self.acqu.meas_op.N - self.acqu.meas_op.M))
z = m @ self.acqu.meas_op.Perm.weight.data.T
z = z.view(-1, 1, self.acqu.meas_op.h, self.acqu.meas_op.w)

return z
return z.view(-1, 1, self.acqu.meas_op.h, self.acqu.meas_op.w)

def reconstruct(self, x):
r"""Reconstruction step of a reconstruction network
Expand Down Expand Up @@ -452,9 +478,7 @@ def reconstruct(self, x):
x = x.view(
bc, 1, self.acqu.meas_op.h, self.acqu.meas_op.w
) # shape x = [b*c,1,h,w]
x = self.denoi(x)

return x
return self.denoi(x)

def reconstruct_pinv(self, x):
r"""Reconstruction step of a reconstruction network
Expand Down Expand Up @@ -492,7 +516,6 @@ def reconstruct_pinv(self, x):
x = x.view(
bc, 1, self.acqu.meas_op.h, self.acqu.meas_op.w
) # shape x = [b*c,1,h,w]

return x

def reconstruct_expe(self, x):
Expand All @@ -511,7 +534,6 @@ def reconstruct_expe(self, x):
:attr:`x`: :math:`(BC,2M)`
:attr:`output`: :math:`(BC,1,H,W)`
"""
# x of shape [b*c, 2M]
bc, _ = x.shape
Expand All @@ -537,14 +559,11 @@ def reconstruct_expe(self, x):
return x


# %%===========================================================================================
# =============================================================================
class DCNet(nn.Module):
# ===========================================================================================
# =========================================================================
r"""Denoised completion reconstruction network
.. math:
Args:
:attr:`noise`: Acquisition operator (see :class:`~spyrit.core.noise`)
Expand Down Expand Up @@ -759,9 +778,9 @@ def reconstruct_expe(self, x):
return x


# %%===========================================================================================
# =============================================================================
class DCDRUNet(DCNet):
# ===========================================================================================
# =========================================================================
r"""Denoised completion reconstruction network based on DRUNet wich concatenates a
noise level map to the input
Expand Down Expand Up @@ -901,9 +920,9 @@ def set_noise_level(self, noise_level):
self.noise_level = torch.FloatTensor([noise_level / 255.0])


# %%===========================================================================================
# =============================================================================
class PositiveParameters(nn.Module):
# ===========================================================================================
# ==========================================================================
def __init__(self, size, val_min=1e-6):
super(PositiveParameters, self).__init__()
self.val_min = torch.tensor(val_min)
Expand All @@ -915,9 +934,9 @@ def forward(self):
return torch.abs(self.params)


# %%===========================================================================================
# =============================================================================
class PositiveMonoIncreaseParameters(PositiveParameters):
# ===========================================================================================
# =========================================================================
def __init__(self, size, val_min=0.000001):
super().__init__(size, val_min)

Expand All @@ -926,9 +945,9 @@ def forward(self):
return super().forward().cumsum(dim=0).flip(dims=[0])


# %%===========================================================================================
# =============================================================================
class UPGD(PinvNet):
# ===========================================================================================
# =========================================================================
def __init__(
self,
noise,
Expand Down
Loading

0 comments on commit e55ed09

Please sign in to comment.