From 40615ebf3f9007420f3662927d11ca844e3039ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Jun 2024 12:50:03 +0000 Subject: [PATCH] [pre-commit.ci] Automatic python formatting --- spyrit/core/meas.py | 168 ++++++++++++++++++---------------- spyrit/core/torch.py | 8 +- spyrit/misc/load_data.py | 8 +- spyrit/test/test_core_meas.py | 4 +- 4 files changed, 102 insertions(+), 86 deletions(-) diff --git a/spyrit/core/meas.py b/spyrit/core/meas.py index fddf8fe6..22584a8d 100644 --- a/spyrit/core/meas.py +++ b/spyrit/core/meas.py @@ -172,10 +172,10 @@ def pinv(self, x: torch.tensor, reg: str = None, eta: float = None) -> torch.ten if hasattr(self, "H_pinv"): # if the pseudo inverse has been computed ans = x @ self.H_pinv.T.to(x.dtype) - + # if not else: - + if isinstance(self, Linear): # can we compute the inverse of H ? H_to_inv = self.H_static @@ -201,14 +201,18 @@ def pinv(self, x: torch.tensor, reg: str = None, eta: float = None) -> torch.ten .to(x.dtype) .T ) - elif reg == 'H1': + elif reg == "H1": Dx, Dy = spytorch.neumann_boundary(self.img_shape) D2 = Dx.T @ Dx + Dy.T @ Dy - ans = torch.linalg.solve( - H_to_inv.T @ H_to_inv + eta * D2, - H_to_inv.T @ x.to(H_to_inv.dtype).T, - ).to(x.dtype).T - + ans = ( + torch.linalg.solve( + H_to_inv.T @ H_to_inv + eta * D2, + H_to_inv.T @ x.to(H_to_inv.dtype).T, + ) + .to(x.dtype) + .T + ) + elif reg is None: raise ValueError( "Regularization method not specified. Please compute " @@ -220,15 +224,18 @@ def pinv(self, x: torch.tensor, reg: str = None, eta: float = None) -> torch.ten f"Regularization method ({reg}) not implemented. Please " + "use 'L1' or 'L2'." ) - + # if we used bicubic b spline, convolve with the kernel - if hasattr(self, "recon_mode") and self.recon_mode == 'bicubic': + if hasattr(self, "recon_mode") and self.recon_mode == "bicubic": kernel = torch.tensor([[1, 4, 1], [4, 16, 4], [1, 4, 1]]) / 36 conv = nn.Conv2d(1, 1, kernel_size=3, padding=1, bias=False) conv.weight.data = kernel.view(1, 1, 3, 3).to(ans.dtype) - - ans = conv(ans.view(-1, 1, self.img_h, self.img_w)) \ - .view(-1, self.img_h * self.img_w).data + + ans = ( + conv(ans.view(-1, 1, self.img_h, self.img_w)) + .view(-1, self.img_h * self.img_w) + .data + ) return ans @@ -1080,16 +1087,16 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None: meas_pattern = self.P else: meas_pattern = self.H_static - + n_frames = meas_pattern.shape[0] - + # get deformation field from motion # scale from [-1;1] x [-1;1] to [0;width-1] x [0;height-1] scale_factor = torch.tensor(self.img_shape) - 1 def_field = (motion.field + 1) / 2 * scale_factor if mode == "bilinear_old": - + # get the integer part of the field for the 4 nearest neighbours # 00 point 01 # +------+--------+ @@ -1099,12 +1106,12 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None: # | | | # +------+--------+ # 10 11 - + def_field = spytorch.center_crop( def_field.moveaxis(-1, 0), self.meas_shape - ).moveaxis(0, -1) + ).moveaxis(0, -1) H_padded = meas_pattern.reshape(-1, *self.meas_shape) - + def_field_00 = def_field.floor().to(torch.int16) self.def_field_00_1 = def_field_00 def_field_01 = def_field_00 + torch.tensor([0, 1]).to(torch.int16) @@ -1125,7 +1132,7 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None: # combine with H_padded H_dxy = H_padded.to(torch.float64) * dxy - + # ================ # ALL CORRECT HERE # ================ @@ -1181,7 +1188,7 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None: # | | | # +------+--------+ # 10 11 - + # 00 01 point 02 03 # +-----------+-----+-----+-----------+ # | | | | @@ -1197,18 +1204,20 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None: # | | | | | # +-----------+-----+-----+-----------+ # 30 31 32 33 - + kernel_size = self._spline(torch.tensor([0]), mode).shape[1] kernel_width = kernel_size - 1 kernel_n_pts = kernel_size**2 - + # PART 1: SEPARATE THE INTEGER AND DECIMAL PARTS OF THE FIELD # _________________________________________________________________ # crop def_field to keep only measured area # moveaxis because crop expects (h,w) as last dimensions def_field = spytorch.center_crop( def_field.moveaxis(-1, 0), self.meas_shape - ).moveaxis(0, -1) # shape (n_frames, meas_h, meas_w, 2) + ).moveaxis( + 0, -1 + ) # shape (n_frames, meas_h, meas_w, 2) # coordinate of top-left closest corner def_field_floor = def_field.floor().to(torch.int64) # shape (n_frames, meas_h, meas_w, 2) @@ -1218,29 +1227,25 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None: # dx.shape = dy.shape = (n_frames, meas_h, meas_w) # evaluate the spline at the decimal part dxy = torch.einsum( - 'iajk,ibjk->iabjk', - self._spline(dy, mode), - self._spline(dx, mode) - ).view(n_frames, kernel_n_pts, self.h*self.w) + "iajk,ibjk->iabjk", self._spline(dy, mode), self._spline(dx, mode) + ).view(n_frames, kernel_n_pts, self.h * self.w) # shape (n_frames, kernel_n_pts, meas_h*meas_w) - + # PART 2: FLATTEN THE INDICES # _________________________________________________________________ # we consider an expanded grid (img_h+k)x(img_w+k), where k is # (kernel_width). This allows each part of the (kernel_size^2)- # point grid to contribute to the interpolation. # get coordinate of point _00 - def_field_00 = def_field_floor - (kernel_size//2 - 1) + def_field_00 = def_field_floor - (kernel_size // 2 - 1) # shift the grid for phantom rows/columns def_field_00 += kernel_width # create a mask indicating if either of the 2 indices is out of bounds # (w,h) because the def_field is in (x,y) coordinates - maxs = torch.tensor([self.img_w + kernel_width, - self.img_h + kernel_width]) + maxs = torch.tensor([self.img_w + kernel_width, self.img_h + kernel_width]) mask = torch.logical_or( - (def_field_00 < 0).any(dim=-1), - (def_field_00 >= maxs).any(dim=-1) - ) # shape (n_frames, meas_h, meas_w) + (def_field_00 < 0).any(dim=-1), (def_field_00 >= maxs).any(dim=-1) + ) # shape (n_frames, meas_h, meas_w) # trash index receives all the out-of-bounds indices trash = (maxs[0] * maxs[1]).to(torch.int64) # if the indices are out of bounds, we put the trash index @@ -1248,29 +1253,30 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None: flattened_indices = torch.where( mask, trash, - def_field_00[..., 0] \ - + def_field_00[..., 1] * (self.img_w + kernel_width) - ).view(n_frames, self.h*self.w) - + def_field_00[..., 0] + + def_field_00[..., 1] * (self.img_w + kernel_width), + ).view(n_frames, self.h * self.w) + # PART 3: WARP H MATRIX WITH FLATTENED INDICES # _________________________________________________________________ - # Build 4 submatrices with 4 weights for bilinear interpolation - meas_dxy = meas_pattern.view(n_frames, 1, self.h*self.w) \ - .to(torch.float64) * dxy + # Build 4 submatrices with 4 weights for bilinear interpolation + meas_dxy = ( + meas_pattern.view(n_frames, 1, self.h * self.w).to(torch.float64) * dxy + ) # shape (n_frames, kernel_size^2, meas_h*meas_w) # Create a larger H_dyn that will be folded meas_dxy_sorted = torch.zeros( - (n_frames, - kernel_n_pts, - (self.img_h + kernel_width) * (self.img_w + kernel_width)+ 1), - # +1 for trash - dtype=torch.float64 + ( + n_frames, + kernel_n_pts, + (self.img_h + kernel_width) * (self.img_w + kernel_width) + 1, + ), + # +1 for trash + dtype=torch.float64, ) # add at flattened_indices the values of meas_dxy (~warping) meas_dxy_sorted.scatter_add_( - 2, - flattened_indices.unsqueeze(1).expand_as(meas_dxy), - meas_dxy + 2, flattened_indices.unsqueeze(1).expand_as(meas_dxy), meas_dxy ) # drop last column (trash) meas_dxy_sorted = meas_dxy_sorted[:, :, :-1] @@ -1278,13 +1284,15 @@ def build_H_dyn(self, motion: DeformationField, mode: str = "bilinear") -> None: # PART 4: FOLD THE MATRIX # _________________________________________________________________ # define operator - fold = nn.Fold(output_size=(self.img_h, self.img_w), - kernel_size=(kernel_size, kernel_size), - padding=kernel_width ) - H_dyn = fold(meas_dxy_sorted).view(n_frames, self.img_h*self.img_w) + fold = nn.Fold( + output_size=(self.img_h, self.img_w), + kernel_size=(kernel_size, kernel_size), + padding=kernel_width, + ) + H_dyn = fold(meas_dxy_sorted).view(n_frames, self.img_h * self.img_w) # store in _param_H_dyn self._param_H_dyn = nn.Parameter(H_dyn, requires_grad=False) - + def build_H_dyn_pinv(self, reg: str = "L1", eta: float = 1e-3) -> None: """Computes the pseudo-inverse of the dynamic measurement matrix `H_dyn` and stores it in the attribute `H_dyn_pinv`. @@ -1331,13 +1339,13 @@ def build_H_dyn_pinv(self, reg: str = "L1", eta: float = 1e-3) -> None: Dx, Dy = spytorch.neumann_boundary(self.img_shape) D2 = Dx.T @ Dx + Dy.T @ Dy pinv = torch.linalg.inv(H_dyn.T @ H_dyn + eta * D2) @ H_dyn.T - + else: raise NotImplementedError( f"Regularization method '{reg}' is not implemented. Please " + "choose either 'L1' or 'L2'." # , or 'H1'." ) - + self._param_H_dyn_pinv = nn.Parameter(pinv, requires_grad=False) def forward(self, x: torch.tensor) -> torch.tensor: @@ -1441,39 +1449,45 @@ def _forward_with_static_op( ) from e else: raise e - + @staticmethod def _spline(dx, mode): """ Returns a 2D row-like tensor containing the values of dx evaluated at each B-spline (2 values for bilinear, 4 for bicubic). dx must be between 0 and 1. - + Shapes dx: (n_frames, self.h, self.w) out: (n_frames, {2,4}, self.h, self.w) """ - if mode == 'bilinear': - return torch.stack((1-dx, dx), dim=1) - if mode == 'bicubic': - return torch.stack(( - (1-dx)**3 /6, - 2/3 - dx**2 * (2-dx)/2, - 2/3 - (1-dx)**2 * (1+dx)/2, - dx**3 /6 - ), dim=1) - if mode == 'schaum': - return torch.stack(( - dx/6 * (dx - 1) * (2 - dx), - (1 - dx/2) * (1 - dx**2), - (1 + (dx-1)/2) * (1 - (dx-1)**2), - 1/6 * (dx + 1) * dx * (dx - 1), - ), dim=1) + if mode == "bilinear": + return torch.stack((1 - dx, dx), dim=1) + if mode == "bicubic": + return torch.stack( + ( + (1 - dx) ** 3 / 6, + 2 / 3 - dx**2 * (2 - dx) / 2, + 2 / 3 - (1 - dx) ** 2 * (1 + dx) / 2, + dx**3 / 6, + ), + dim=1, + ) + if mode == "schaum": + return torch.stack( + ( + dx / 6 * (dx - 1) * (2 - dx), + (1 - dx / 2) * (1 - dx**2), + (1 + (dx - 1) / 2) * (1 - (dx - 1) ** 2), + 1 / 6 * (dx + 1) * dx * (dx - 1), + ), + dim=1, + ) else: raise NotImplementedError( f"The mode {mode} is invalid, please choose bilinear or bicubic" ) - + # ============================================================================= class DynamicLinearSplit(DynamicLinear): @@ -1818,5 +1832,3 @@ def __init__( # return ans.to(orig_dtype) # BICUBIC INTERPOLATION TO BE IMPLEMENTED - - diff --git a/spyrit/core/torch.py b/spyrit/core/torch.py index 301a9401..a45d0e3a 100644 --- a/spyrit/core/torch.py +++ b/spyrit/core/torch.py @@ -234,20 +234,20 @@ def neumann_boundary(img_shape): h, w = img_shape # create h blocks of wxw matrices max_ = max(h, w) - + # create diagonals ones = torch.ones(max_) ones[0] = 0 m_ones = -torch.ones(max_) block_h = spdiags([ones[:h], m_ones[:h]], [0, -1], (h, h)) block_w = spdiags([ones[:w], m_ones[:w]], [0, -1], (w, w)) - + # create blocks using kronecker product Dx = torch.kron(block_h.to_dense(), torch.eye(w)) Dy = torch.kron(torch.eye(h), block_w.to_dense()) - + return Dx, Dy - + # ============================================================================= # Permutations and Sorting diff --git a/spyrit/misc/load_data.py b/spyrit/misc/load_data.py index e6f33743..eee247ff 100644 --- a/spyrit/misc/load_data.py +++ b/spyrit/misc/load_data.py @@ -46,7 +46,9 @@ def load_data_Comp_1D_old(Path_files, list_files, Nh, Nl, Nc): Data[:, i // 2] = Sum_coll( np.rot90(np.array(PIL.Image.open(Path_files + list_files[i])), 3), Nl, Nc ) - Sum_coll( - np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3), Nl, Nc + np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3), + Nl, + Nc, ) return Data @@ -59,7 +61,9 @@ def load_data_Comp_1D_new(Path_files, list_files, Nh, Nl, Nc): for i in range(0, 2 * Nh, 2): Data[:, i // 2] = Sum_coll( - np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3), Nl, Nc + np.rot90(np.array(PIL.Image.open(Path_files + list_files[i + 1])), 3), + Nl, + Nc, ) - Sum_coll( np.rot90(np.array(PIL.Image.open(Path_files + list_files[i])), 3), Nl, Nc ) diff --git a/spyrit/test/test_core_meas.py b/spyrit/test/test_core_meas.py index 6566ad9e..5ad4bf23 100644 --- a/spyrit/test/test_core_meas.py +++ b/spyrit/test/test_core_meas.py @@ -253,11 +253,11 @@ def f(t): time_vector = torch.linspace(0, 1, 400) field = AffineDeformationField(f, time_vector, img_shape=(50, 50)) - meas_op.build_H_dyn(field, 'bilinear') + meas_op.build_H_dyn(field, "bilinear") assert_close_all( meas_op.H, H.flip(1), "Wrong dynamic measurement matrix", atol=1e-6 ) - meas_op.build_H_dyn(field, 'bicubic') + meas_op.build_H_dyn(field, "bicubic") print("ok") # build pseudo inverse H_dyn_pinv