Skip to content

Commit

Permalink
core.torch correctedswap Dx Dy
Browse files Browse the repository at this point in the history
  • Loading branch information
romainphan committed Jun 11, 2024
1 parent 52e422a commit 6e5fcf3
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions spyrit/core/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,37 +200,51 @@ def finite_diff_mat(n, boundary="dirichlet"):
ones_0right = ones_0right.flatten()

if boundary == "dirichlet":
Dx = spdiags([ones, -ones], [0, -n], (N, N))
Dy = spdiags([ones, -ones_0right], [0, -1], (N, N))
Dx = spdiags([ones, -ones_0right], [0, -1], (N, N))
Dy = spdiags([ones, -ones], [0, -n], (N, N))

elif boundary == "neumann":
ones_0left = ones_0right.roll(1)
ones_0top = ones_0left.reshape(n, n).T.flatten()
Dx = spdiags([ones_0top, -ones], [0, -n], (N, N))
Dy = spdiags([ones_0left, -ones_0right], [0, -1], (N, N))
Dx = spdiags([ones_0left, -ones_0right], [0, -1], (N, N))
Dy = spdiags([ones_0top, -ones], [0, -n], (N, N))

elif boundary == "periodic":
zeros_1left = (1 - ones_0right).roll(1)
Dx = spdiags([ones, -ones, -ones], [0, -n, N - n], (N, N))
Dy = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N))
Dx = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N))
Dy = spdiags([ones, -ones, -ones], [0, -n, N - n], (N, N))

elif boundary == "symmetric":
zeros_1left = (1 - ones_0right).roll(1)
zeros_1top = zeros_1left.reshape(n, n).T.flatten()
Dx = spdiags([ones, -ones, -zeros_1top], [0, -n, n], (N, N))
Dy = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N))
Dx = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N))
Dy = spdiags([ones, -ones, -zeros_1top], [0, -n, n], (N, N))

elif boundary == "antisymmetric":
zeros_1left = (1 - ones_0right).roll(1)
zeros_1top = zeros_1left.reshape(n, n).T.flatten()
Dx = spdiags([ones, -ones, zeros_1top], [0, -n, n], (N, N))
Dy = spdiags([ones, -ones_0right, zeros_1left], [0, -1, 1], (N, N))
Dx = spdiags([ones, -ones_0right, zeros_1left], [0, -1, 1], (N, N))
Dy = spdiags([ones, -ones, zeros_1top], [0, -n, n], (N, N))

return Dx, Dy


def neumann_boundary(img_shape):
""""""
r"""
Creates a finite difference matrix of shape :math:`(h*w,h*w)` for a 2D
image of shape :math:`(h,w)`. The boundary condition used is Neumann.
Args:
:attr:`img_shape` (tuple): The size of the image :math:`(h,w)`.
Returns:
:class:`torch.tensor`: The finite difference matrix.
.. note::
This function returns the same matrix as :func:`finite_diff_mat` with
the Neumann boundary condition. Internal implementation is different
and allows to process rectangular images.
"""
h, w = img_shape
# create h blocks of wxw matrices
max_ = max(h, w)
Expand All @@ -243,8 +257,8 @@ def neumann_boundary(img_shape):
block_w = spdiags([ones[:w], m_ones[:w]], [0, -1], (w, w))

# create blocks using kronecker product
Dx = torch.kron(block_h.to_dense(), torch.eye(w))
Dy = torch.kron(torch.eye(h), block_w.to_dense())
Dx = torch.kron(torch.eye(h), block_w.to_dense())
Dy = torch.kron(block_h.to_dense(), torch.eye(w))

return Dx, Dy

Expand Down

0 comments on commit 6e5fcf3

Please sign in to comment.