-
Notifications
You must be signed in to change notification settings - Fork 0
/
load.py
48 lines (38 loc) · 1.54 KB
/
load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import pyro
import pyro.distributions as dist
import functools
import os
import tyxe
import network
def CNN(path="output/cnn/model.pth", device=None):
    """Build the deterministic CNN and load its pretrained weights.

    Args:
        path: Path to the saved ``state_dict`` of the network.
        device: ``torch.device`` to build the network on and to map the saved
            tensors onto. Defaults to CUDA when available, otherwise CPU.

    Returns:
        The ``network.network`` instance with the weights from ``path``
        loaded, switched to eval mode.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = network.network(in_channels=1, output_size=24, device=device)
    # map_location remaps the saved tensors onto `device`; without it,
    # torch.load restores them to the device they were saved from, so a
    # GPU-trained checkpoint would fail to load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    net.load_state_dict(torch.load(path, map_location=device))
    net.eval()
    return net
def BCNN(path="output/bcnn", dataset_size=27455, device=None):
    """Build the TyXe variational Bayesian CNN and restore its trained state.

    A ``network.network(in_channels=1, output_size=24)`` is wrapped in a
    ``tyxe.VariationalBNN`` with an IID standard-normal prior, a categorical
    likelihood and an ``AutoNormal`` guide. Then the network weights and the
    Pyro parameters saved under ``path`` are loaded.

    Args:
        path: Directory containing ``state_dict.pt`` (network weights) and
            ``param_store.pt`` (Pyro parameter store).
        dataset_size: Passed as the first argument of
            ``tyxe.likelihoods.Categorical``. Presumably the training-set size
            the model was fit with; TODO confirm.
        device: ``torch.device`` to build the model on and to map the saved
            tensors onto. Defaults to CUDA when available, otherwise CPU.

    Returns:
        The restored ``tyxe.VariationalBNN``.

    Note:
        Calls ``pyro.clear_param_store()`` before loading, so any parameters
        already held in the global Pyro parameter store are discarded.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = network.network(in_channels=1, output_size=24, device=device)

    # Build the Bayesian wrapper =========================================
    # NOTE: no extra IIDPrior options are set here. The original code had
    # commented-out `expose_all=False, hide_module_types=(nn.BatchNorm2d,)`
    # as a possible alternative.
    prior_kwargs = {}
    likelihood = tyxe.likelihoods.Categorical(dataset_size)
    prior = tyxe.priors.IIDPrior(
        dist.Normal(torch.zeros(1, device=device), torch.ones(1, device=device)),
        **prior_kwargs
    )
    # Guide locations are initialised from `net`'s current weights. These are
    # presumably superseded by the values restored from param_store.pt below.
    guide = functools.partial(
        tyxe.guides.AutoNormal,
        init_loc_fn=tyxe.guides.PretrainedInitializer.from_net(net, prefix="net"),
        init_scale=1e-4,
        max_guide_scale=1
    )
    bnn = tyxe.VariationalBNN(net, prior, likelihood, guide)

    # Restore the trained state ==========================================
    pyro.clear_param_store()
    # map_location keeps GPU-saved checkpoints loadable on CPU-only hosts,
    # matching the param-store load below.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    bnn.net.load_state_dict(
        torch.load(os.path.join(path, "state_dict.pt"), map_location=device)
    )
    pyro.get_param_store().load(os.path.join(path, "param_store.pt"), map_location=device)
    return bnn