Skip to content

Commit

Permalink
tuto8 correction
Browse files Browse the repository at this point in the history
  • Loading branch information
romainphan committed Jun 11, 2024
1 parent 7cbf9e3 commit db555ca
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions tutorial/tuto_08_lpgd_split_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@

from spyrit.core.meas import HadamSplit
from spyrit.core.noise import Poisson
from spyrit.misc.sampling import meas2img2
from spyrit.misc.sampling import meas2img
from spyrit.misc.statistics import Cov2Var
from spyrit.core.prep import SplitPoisson

Expand All @@ -106,7 +106,7 @@
Ord_rec[n_sub:, :] = 0

# Measurement and noise operators
meas_op = HadamSplit(M, h, Ord_rec)
meas_op = HadamSplit(M, h, torch.from_numpy(Ord_rec))
noise_op = Poisson(meas_op, alpha)
prep_op = SplitPoisson(alpha, meas_op)

Expand All @@ -119,8 +119,8 @@
m = prep_op(y) # preprocessed measurement vector

m_plot = m.detach().numpy()
m_plot = meas2img2(m_plot.T, Ord_rec)
imagesc(m_plot, r"Measurements $m$")
m_plot = meas2img(m_plot, Ord_rec)
imagesc(m_plot[0, :, :], r"Measurements $m$")

###############################################################################
# We define the LearnedPGD network by providing the measurement, noise and preprocessing operators,
Expand Down Expand Up @@ -163,15 +163,15 @@
url_lpgd = "https://drive.google.com/file/d/1ki_cJQEwBWrpDhtE7-HoSEoY8oJUnUz5/view?usp=drive_link"
model_net_path = os.path.join(
model_path,
"lpgd_unet_imagenet_N0_10_m_hadam-split_N_128_M_4096_epo_30_lr_0.001_sss_10_sdr_0.5_bs_128_reg_1e-07_uit_3_sdec0-9",
"lpgd_unet_imagenet_N0_10_m_hadam-split_N_128_M_4096_epo_30_lr_0.001_sss_10_sdr_0.5_bs_128_reg_1e-07_uit_3_sdec0-9.pth",
)

try:
import gdown

gdown.download(url_lpgd, model_path, quiet=False, fuzzy=True)
except:
print(f"Model not downloaded from {url_lpgd}!!!")
if os.path.exists(model_net_path) is False:
try:
import gdown
gdown.download(url_lpgd, model_net_path, quiet=False, fuzzy=True)
except:
print(f"Model not downloaded from {url_lpgd}!!!")

# Load pretrained weights to the model
load_net(model_net_path, lpgd_net, device, strict=False)
Expand Down

0 comments on commit db555ca

Please sign in to comment.