-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
128 lines (112 loc) · 4.55 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from datetime import datetime
import os
import shutil
import torch
from src.args import Args
from src.params import Params
from src.system import DHOSystem, CSTRSystem, HolohoverSystem
from src.model_black import DHOModelBlack, CSTRModelBlack, HolohoverModelBlack
from src.model_grey import HolohoverModelGrey, CorrectModelGrey
from src.model_black_simple import ModelBlackSimple
from src.learn import LearnGreyModel, LearnCorrection, LearnStableModel, LearnBlackSimple
from src.plot import Plot
from src.simulation import Simulation
def main():
    """Train, plot and save the model selected by ``args.model_type``.

    Steps, in order: pick the torch device and seed the RNG, load the
    arguments (and ``Params`` for the grey-box model), create a fresh
    timestamped output directory under ``models/<model_type>/``, build the
    system, equilibrium point, model and learner(s), run the optimisation,
    plot the results and finally save model, arguments and parameters.

    Raises:
        ValueError: if ``args.model_type`` is not one of the supported types.
    """
    holohover_types = ("HolohoverBlack", "HolohoverGrey", "HolohoverBlackSimple")
    supported_types = ("DHO", "CSTR") + holohover_types

    # pytorch device and random seed
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(0)

    # load arguments and parameters
    args = Args(model_type="HolohoverBlackSimple")
    model_type = args.model_type
    if model_type not in supported_types:
        # Fail early with a clear message instead of an unbound-variable or
        # NoneType error further down.
        raise ValueError(
            "unsupported model_type {!r}, expected one of {}".format(
                model_type, supported_types
            )
        )
    params = Params(args=args) if model_type == "HolohoverGrey" else None

    # create directory (an existing run from the same minute is replaced)
    now = datetime.now()
    args.dir_path = os.path.join("models", model_type, now.strftime("%Y%m%d_%H%M"))
    if os.path.exists(args.dir_path):
        shutil.rmtree(args.dir_path)
    # makedirs also creates missing parents such as models/<model_type>
    os.makedirs(args.dir_path)

    # init. system
    if model_type == "DHO":
        system = DHOSystem(args=args, dev=device)
    elif model_type == "CSTR":
        system = CSTRSystem(args=args, dev=device)
    else:
        system = HolohoverSystem(args=args, dev=device)

    # init. equilibrium point
    if model_type in holohover_types:
        ueq = torch.zeros(system.M)
        xeq = torch.zeros(system.D)
    else:
        ueq = torch.tensor([0]) if model_type == "DHO" else torch.tensor([14.19])
        xeq = system.equPoint(ueq, U_hat=False)
    ueq = system.uMap(ueq)

    # init. model
    model = None
    cor_model = None
    if model_type == "DHO":
        model = DHOModelBlack(args=args, dev=device, system=system, xref=xeq)
    elif model_type == "CSTR":
        model = CSTRModelBlack(args=args, dev=device, system=system, xref=xeq)
    elif model_type == "HolohoverBlack":
        model = HolohoverModelBlack(args=args, dev=device, system=system, xref=xeq)
    elif model_type == "HolohoverGrey":
        model = HolohoverModelGrey(args=args, params=params, dev=device)
        # pass the torch.device object, like every other constructor here
        cor_model = CorrectModelGrey(args=args, dev=device)
    elif model_type == "HolohoverBlackSimple":
        model = ModelBlackSimple(args=args, dev=device)

    # load model to continue learning process
    if args.load_model:
        # map_location lets a checkpoint saved on another device be resumed here
        state_dict = torch.load(args.model_path, map_location=device)
        model.load_state_dict(state_dict)

    # init. base learner
    ld = None
    lc = None
    if model_type in ("DHO", "CSTR", "HolohoverBlack"):
        ld = LearnStableModel(args=args, dev=device, system=system, model=model)
    elif model_type == "HolohoverGrey":
        ld = LearnGreyModel(args=args, dev=device, system=system, model=model)
        if args.learn_correction:  # init. correction learner
            lc = LearnCorrection(
                args=args, dev=device, system=system, model=cor_model, base_model=model
            )
    elif model_type == "HolohoverBlackSimple":
        ld = LearnBlackSimple(args=args, dev=device, system=system, model=model)

    # learn dynamics
    ld.optimize()
    # lc only exists for the grey model with learn_correction enabled; checking
    # the learner itself avoids calling optimize() on None for other models
    if lc is not None:
        lc.optimize()

    # plot results
    plot = Plot(
        args=args,
        params=params,
        dev=device,
        model=model,
        cor_model=cor_model,
        system=system,
        learn=ld,
        learn_cor=lc,
    )
    if model_type == "DHO":
        plot.blackDHO()
    elif model_type == "CSTR":
        plot.blackCSTR()
    elif model_type == "HolohoverGrey":
        plot.greyModel(ueq)
        if args.learn_correction:
            plot.corModel()
        plot.paramsSig2Thrust()
        plot.paramsVec()
        plot.dataHistogram()
    elif model_type in ("HolohoverBlack", "HolohoverBlackSimple"):
        plot.blackModel(ueq)
        plot.dataHistogram()

    # # simulate system
    # sim = Simulation(system, model)
    # Xreal_seq, Xreal_integ_seq, Xlearn_seq = sim.simGrey()
    # plot.simGrey(Xreal_seq, Xreal_integ_seq, Xlearn_seq)

    # save model, arguments and parameters
    ld.saveModel()
    args.save()
    if params:
        params.save(model)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()