diff --git a/spyrit/core/torch.py b/spyrit/core/torch.py index a45d0e3a..f9d2a053 100644 --- a/spyrit/core/torch.py +++ b/spyrit/core/torch.py @@ -200,37 +200,51 @@ def finite_diff_mat(n, boundary="dirichlet"): ones_0right = ones_0right.flatten() if boundary == "dirichlet": - Dx = spdiags([ones, -ones], [0, -n], (N, N)) - Dy = spdiags([ones, -ones_0right], [0, -1], (N, N)) + Dx = spdiags([ones, -ones_0right], [0, -1], (N, N)) + Dy = spdiags([ones, -ones], [0, -n], (N, N)) elif boundary == "neumann": ones_0left = ones_0right.roll(1) ones_0top = ones_0left.reshape(n, n).T.flatten() - Dx = spdiags([ones_0top, -ones], [0, -n], (N, N)) - Dy = spdiags([ones_0left, -ones_0right], [0, -1], (N, N)) + Dx = spdiags([ones_0left, -ones_0right], [0, -1], (N, N)) + Dy = spdiags([ones_0top, -ones], [0, -n], (N, N)) elif boundary == "periodic": zeros_1left = (1 - ones_0right).roll(1) - Dx = spdiags([ones, -ones, -ones], [0, -n, N - n], (N, N)) - Dy = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N)) + Dx = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N)) + Dy = spdiags([ones, -ones, -ones], [0, -n, N - n], (N, N)) elif boundary == "symmetric": zeros_1left = (1 - ones_0right).roll(1) zeros_1top = zeros_1left.reshape(n, n).T.flatten() - Dx = spdiags([ones, -ones, -zeros_1top], [0, -n, n], (N, N)) - Dy = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N)) + Dx = spdiags([ones, -ones_0right, -zeros_1left], [0, -1, n - 1], (N, N)) + Dy = spdiags([ones, -ones, -zeros_1top], [0, -n, n], (N, N)) elif boundary == "antisymmetric": zeros_1left = (1 - ones_0right).roll(1) zeros_1top = zeros_1left.reshape(n, n).T.flatten() - Dx = spdiags([ones, -ones, zeros_1top], [0, -n, n], (N, N)) - Dy = spdiags([ones, -ones_0right, zeros_1left], [0, -1, 1], (N, N)) + Dx = spdiags([ones, -ones_0right, zeros_1left], [0, -1, 1], (N, N)) + Dy = spdiags([ones, -ones, zeros_1top], [0, -n, n], (N, N)) return Dx, Dy def neumann_boundary(img_shape): - """""" + r""" + Creates a finite difference matrix of shape :math:`(h*w,h*w)` for a 2D + image of shape :math:`(h,w)`. The boundary condition used is Neumann. + + Args: + :attr:`img_shape` (tuple): The size of the image :math:`(h,w)`. + + Returns: + :class:`torch.tensor`: The finite difference matrix. + + .. note:: + This function returns the same matrix as :func:`finite_diff_mat` with + the Neumann boundary condition. Internal implementation is different + and allows to process rectangular images. + """ h, w = img_shape # create h blocks of wxw matrices max_ = max(h, w) @@ -243,8 +257,8 @@ def neumann_boundary(img_shape): block_w = spdiags([ones[:w], m_ones[:w]], [0, -1], (w, w)) # create blocks using kronecker product - Dx = torch.kron(block_h.to_dense(), torch.eye(w)) - Dy = torch.kron(torch.eye(h), block_w.to_dense()) + Dx = torch.kron(torch.eye(h), block_w.to_dense()) + Dy = torch.kron(block_h.to_dense(), torch.eye(w)) return Dx, Dy