-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
117 lines (99 loc) · 3.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import argparse
from argparse import Namespace
from doctest import testsource
from gc import callbacks
from typing import List
import torch
import torch.nn as nn
from matplotlib import transforms
from torch.utils.data import DataLoader
import DeepNoise.builders as builders
import wandb
from DeepNoise.algorithms.base_trainer import Trainer
from DeepNoise.builders.builders import build_cfg
from DeepNoise.callbacks.statistics import Callback
def str2bool(v):
    """Convert a command-line token such as "yes"/"no" into a ``bool``.

    Meant to be used as an argparse ``type=`` callable. A value that is
    already a ``bool`` is returned unchanged.

    Args:
        v: A ``bool`` or a string; string matching is case-insensitive.

    Returns:
        True for "yes", "true", "t", "y", "1"; False for "no", "false",
        "f", "n", "0".

    Raises:
        argparse.ArgumentTypeError: if the string is not one of the
            recognised spellings.
    """
    if isinstance(v, bool):
        return v

    normalized = v.lower()
    truthy = ("yes", "true", "t", "y", "1")
    falsy = ("no", "false", "f", "n", "0")

    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic launches).
            Defaults to None, which keeps the original behaviour of reading
            the real command line.

    Returns:
        argparse.Namespace with ``cfg_path``, ``noise_type``, ``noise_prob``
        and ``allow_equal_flips`` attributes.
    """
    parser = argparse.ArgumentParser()
    # Required: the path is passed straight to build_cfg, so failing fast here
    # gives a clear usage error instead of whatever build_cfg does with None.
    parser.add_argument(
        "--cfg_path",
        type=str,
        required=True,
        help="Path to the experiment config file.",
    )
    parser.add_argument(
        "--noise_type",
        type=str,
        default="SymmetricNoise",
        help="Noise injector type; replaces any noise_injector in the config.",
    )
    # 0.0 rather than 0: argparse does not apply `type` to non-string
    # defaults, so an int default would reach the noise injector as an int.
    parser.add_argument(
        "--noise_prob",
        type=float,
        default=0.0,
        help="Value for the noise injector's noise_prob setting.",
    )
    parser.add_argument(
        "--allow_equal_flips",
        type=str2bool,
        default=True,
        help="Value for the noise injector's allow_equal_flips setting "
        "(accepts yes/no, true/false, t/f, y/n, 1/0).",
    )
    return parser.parse_args(argv)
def _make_loader(dataset, cfg, shuffle: bool) -> DataLoader:
    """Wrap `dataset` in a DataLoader using the shared batch/worker settings in `cfg`."""
    return DataLoader(
        dataset,
        shuffle=shuffle,
        pin_memory=True,
        batch_size=cfg["batch_size"],
        num_workers=cfg["num_workers"],
    )


def _build_all_callbacks(cfg, optimizer) -> List[Callback]:
    """Build `cfg["callbacks"]`, then the optimizer-bound `cfg["optimizer_callbacks"]`."""
    built: List[Callback] = [
        builders.build_callbacks(callback_cfg) for callback_cfg in cfg["callbacks"]
    ]
    built.extend(
        builders.build_callbacks(callback_cfg, optimizer=optimizer)
        for callback_cfg in cfg["optimizer_callbacks"]
    )
    return built


def main():
    """Run one training job from a config file plus CLI noise settings.

    Parses the CLI, loads the config from ``--cfg_path``, and replaces the
    train set's ``noise_injector`` with the CLI noise options. It then starts
    a wandb run (project "DeepNoise") with the resulting config. Finally it
    builds the datasets, data loaders, model, optimizer, loss, callbacks and
    trainer, and calls ``trainer.start()``.
    """
    args = parse_args()
    cfg = build_cfg(args.cfg_path)

    # The CLI noise settings always replace the train set's noise_injector;
    # warn when that discards a setting coming from the config file.
    trainset_cfg = cfg["data"]["trainset"]
    if "noise_injector" in trainset_cfg:
        print(
            "WARNING: the noise_injector setting in the config file will be overwritten"
            " by the noise_settings passed from the command line arguments."
        )
    trainset_cfg["noise_injector"] = dict(
        type=args.noise_type,
        noise_prob=args.noise_prob,
        allow_equal_flips=args.allow_equal_flips,
    )

    # Initialised after the override so the logged config matches what is trained.
    wandb.init(project="DeepNoise", config=cfg)

    trainset = builders.build_dataset(cfg["data"]["trainset"])
    valset = builders.build_dataset(cfg["data"]["valset"])
    testset = builders.build_dataset(cfg["data"]["testset"])

    # Only the training loader shuffles; val/test iteration order stays fixed.
    train_loader = _make_loader(trainset, cfg, shuffle=True)
    val_loader = _make_loader(valset, cfg, shuffle=False)
    test_loader = _make_loader(testset, cfg, shuffle=False)

    model: nn.Module = builders.build_model(
        cfg["model"], num_classes=cfg["num_classes"]
    )
    optimizer: torch.optim.Optimizer = builders.build_optimizer(
        cfg["optimizer"], model=model
    )
    loss_fn: nn.Module = builders.build_loss(cfg["loss_fn"])
    # Named `train_callbacks` to avoid shadowing the module-level `callbacks` import.
    train_callbacks = _build_all_callbacks(cfg, optimizer)

    trainer: Trainer = builders.build_trainer(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        epochs=cfg["epochs"],
        callbacks=train_callbacks,
        cfg=cfg["trainer"],
    )
    trainer.start()
if __name__ == "__main__":
    # Launch training only when run as a script, not when this module is imported.
    main()