Skip to content

Commit

Permalink
Merge pull request #227 from openspyrit/spyrit_2.3.2
Browse files Browse the repository at this point in the history
patched tuto7
  • Loading branch information
romainphan committed Jul 10, 2024
2 parents 91ce609 + da48050 commit 2d9b65c
Showing 1 changed file with 7 additions and 25 deletions.
32 changes: 7 additions & 25 deletions tutorial/tuto_07_drunet_split_measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,35 +143,17 @@
###############################################################################
# We download the pretrained weights of the DRUNet and load them.

local_folder = "./model/"
# Create model folder
if os.path.exists(local_folder):
print(f"{local_folder} found")
else:
os.mkdir(local_folder)
print(f"Created {local_folder}")
from spyrit.misc.load_data import download_girder

# Load pretrained model
url = "https://tomoradio-warehouse.creatis.insa-lyon.fr/api/v1"
dataID = "667ebf9ebaa5a9000705895e" # unique ID of the file
local_folder = "./model/"
data_name = "tuto7_drunet_gray.pth"
model_drunet_path = os.path.join(local_folder, data_name)

if os.path.exists(model_drunet_path):
print(f"Model found : {data_name}")

else:
print(f"Model not found : {data_name}")
print(f"Downloading model... ", end="")
try:
gc = girder_client.GirderClient(apiUrl=url)
gc.downloadFile(dataID, model_drunet_path)
print("Done")
except Exception as e:
print("Failed with error: ", e)
model_drunet_abs_path = download_girder(url, dataID, local_folder, data_name)

# Load pretrained weights
denoi_drunet.load_state_dict(torch.load(model_drunet_path), strict=False)
denoi_drunet.load_state_dict(torch.load(model_drunet_abs_path), strict=False)

# %%
# Pluggind the DRUnet in a DCNet
Expand Down Expand Up @@ -257,10 +239,10 @@

# Load pretrained model
try:
drunet_den.load_state_dict(torch.load(model_drunet_path), strict=True)
print(f"Model {model_drunet_path} loaded.")
drunet_den.load_state_dict(torch.load(model_drunet_abs_path), strict=True)
print(f"Model {model_drunet_abs_path} loaded.")
except:
print(f"Model {model_drunet_path} not found!")
print(f"Model {model_drunet_abs_path} not found!")
load_drunet = False
drunet_den = drunet_den.to(device)

Expand Down

0 comments on commit 2d9b65c

Please sign in to comment.