-
Notifications
You must be signed in to change notification settings - Fork 1
/
pretrain_transformer.py
122 lines (102 loc) · 4.75 KB
/
pretrain_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
from torch.utils.data import (DataLoader, Dataset)
from torchmetrics import MeanAbsoluteError, MeanAbsolutePercentageError
from CNN import PretrainingDataset, cnn_multi_dim, output_2
from transformer import MultiViewTransformer
from training_utils import TrainingModel
import os
import os.path as osp
class PretrainTransformerDataset(Dataset):
    """Dataset of cached CNN activations used to pretrain the transformer.

    On first construction the (frozen) CNN is run once over
    ``image_dataset`` and each resulting activation tensor is saved under
    ``cache_root``; later constructions skip the CNN entirely and reload
    the cached tensors from disk.
    """

    def __init__(
            self,
            nets: nn.Module,
            image_dataset: Dataset,
            cache_root: str = './transformer_pretraining_dataset/') -> None:
        """Build (or reload) the activation cache.

        Args:
            nets (nn.Module): The frozen CNN model(s) producing activations.
            image_dataset (Dataset): The image dataset fed into the CNN.
            cache_root (str): Directory in which activations are stored.
        """
        super().__init__()
        # File paths of the cached activation tensors, one per sample;
        # list position is the dataset index.
        self.lst_data_dir = []
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if not osp.isdir(cache_root):
            # Cache miss: run the CNN once and persist every activation.
            os.makedirs(cache_root)
            nets.to(device)
            rets = output_2(nets, image_dataset, device)
            for x in rets:
                # Swap the first two axes before caching — presumably
                # (views, batch, feat) -> (batch, views, feat); the exact
                # axis semantics depend on output_2. TODO confirm.
                x = x.permute(1, 0, 2)
                # Files are named by their sequential sample index.
                name = osp.join(cache_root, str(len(self.lst_data_dir)))
                torch.save(x, name)
                self.lst_data_dir.append(name)
        else:
            # Cache hit: rebuild the file list. Files were written with
            # numeric names and os.listdir order is arbitrary, so sort
            # numerically to keep __getitem__ indices stable across runs
            # (lexicographic order would put "10" before "2").
            for fname in sorted(os.listdir(cache_root), key=int):
                self.lst_data_dir.append(osp.join(cache_root, fname))

    def __getitem__(self, idx):
        # Activations are loaded lazily from disk, one file per sample.
        return torch.load(self.lst_data_dir[idx])

    def __len__(self):
        return len(self.lst_data_dir)
if __name__ == '__main__':
    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    # 'forkserver' keeps DataLoader workers from inheriting CUDA state
    # from the parent process (fork + CUDA is unsafe).
    torch.multiprocessing.set_start_method('forkserver', force=True)

    parser = ArgumentParser()
    parser.add_argument("--n_hidden", type=int, default=10)
    parser.add_argument("--batch_size_generate", type=int, default=32)
    parser.add_argument("--n_worker_generate", type=int, default=32)
    parser.add_argument("--batch_size_pretrain", type=int, default=6)
    parser.add_argument("--n_worker_pretrain", type=int, default=6)
    parser.add_argument("--cache_root",
                        type=str,
                        default="./pretrain_transformer_dataset_cache/")
    parser.add_argument("--imaging_dataset_dir", type=str, default="./data/")
    parser.add_argument("--imaging_dataset_cache_dir", type=str, default="./cached_mri")
    parser.add_argument("--cnn_checkpoint_path", type=str, default="./cnn_checkpoints/checkpointat5.pth")
    # NOTE(review): this default ends with '/' (a directory), but it is used
    # as the torch.save target, which must be a file path — confirm callers
    # always override it with a file name.
    parser.add_argument("--store_checkpoint_path", type=str, default="./transformer_checkpoints/")
    parser.add_argument("--pretrain", type=int, default=1)
    parser = Trainer.add_argparse_args(parser)
    parser = TrainingModel.add_model_specific_args(parser)
    args = parser.parse_args()

    def _ensure_parent_dir(path: str) -> None:
        """Create the directory that will hold `path`; torch.save won't."""
        parent = osp.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    if not args.pretrain:
        # --pretrain 0: just export randomly-initialized transformer weights.
        model = MultiViewTransformer(args)
        _ensure_parent_dir(args.store_checkpoint_path)
        torch.save(model.state_dict(), args.store_checkpoint_path)
    else:
        # One CNN per view (3 views), restored from the given checkpoint,
        # which stores one state_dict per view indexed by position.
        cnn_models = nn.ModuleList([cnn_multi_dim(i, args.n_hidden) for i in range(3)])
        loaded = torch.load(args.cnn_checkpoint_path)
        for model_idx, model in enumerate(cnn_models):
            model.load_state_dict(loaded[model_idx])
        # Image dataset that feeds the CNNs.
        image_dataset = PretrainingDataset(
            path=args.imaging_dataset_dir, cache_path=args.imaging_dataset_cache_dir
        )
        # Cached-activation dataset used to pretrain the transformer.
        dataset = PretrainTransformerDataset(nets=cnn_models,
                                             image_dataset=image_dataset,
                                             cache_root=args.cache_root)
        loader = DataLoader(dataset,
                            batch_size=args.batch_size_pretrain,
                            shuffle=True,
                            num_workers=args.n_worker_pretrain)
        # Lightning wrapper driving a single masked transformer, evaluated
        # with MAE and MAPE metrics.
        trainer_model = TrainingModel(
            args=args,
            model_args={"Transformer": {"args": args}},
            models={"Transformer": MultiViewTransformer},
            model_forward_args={"Transformer": {
                "mask": True
            }},
            model_order=["Transformer"],
            metrics={"MAE": MeanAbsoluteError(), "MAPE": MeanAbsolutePercentageError()}
        )
        # NOTE(review): strategy='ddp', gpus=2 hard-coded here override any
        # --gpus/--strategy flags registered by Trainer.add_argparse_args —
        # confirm this is intentional.
        trainer = Trainer.from_argparse_args(args, strategy='ddp', gpus=2)
        trainer.fit(trainer_model, loader)
        # Persist the pretrained transformer weights.
        _ensure_parent_dir(args.store_checkpoint_path)
        torch.save(trainer_model.Transformer.state_dict(), args.store_checkpoint_path)