-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
79 lines (63 loc) · 3.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import argparse
import numpy as np
import torch
# load model
from model.densenet import DenseNet
from data import load_data
from traning.train import train_model
class Args:
    """Run configuration for DenseNet training, built from a plain dict.

    Every key of ``args_dict`` is optional; a missing key falls back to the
    default written next to the corresponding attribute below. ``stages`` and
    ``growth`` may be given either as a dash-separated string (e.g.
    ``'10-10-10'``) or as a sequence of int-convertible values.

    Raises:
        RuntimeError: if ``stages`` and ``growth`` have different lengths.
    """

    def __init__(self, args_dict):
        get = args_dict.get
        # dense connection rate
        self.cross_block_rate = get('cross_block_rate', 0.5)  # default = 0.5; range = [0, 1]
        self.end_block_reduction_rate = get('end_block_reduction_rate', 0.5)  # default = 0.5; range = (0, 1]
        self.seed = get('seed', 1)  # seed for selecting random connections
        # model hyperparameters
        self.stages = self._parse_int_list(get('stages', [10, 10, 10]))
        self.growth = self._parse_int_list(get('growth', [12, 12, 12]))
        self.group_1x1 = get('group_1x1', 4)
        self.group_3x3 = get('group_3x3', 4)
        self.bottleneck = get('bottleneck', 4)
        self.lr = get('lr', 1e-1)
        self.ep = get('ep', 100)
        self.optimizer = get('optimizer', 'adam')  # default = 'adam'; options = {'sgd', 'adam'}
        self.scheduler = get('scheduler', 'cos')  # default = 'cos'; options = {'none', 'clr', 'exp', 'mlr', 'cos'}
        # training batch
        self.bsize = get('bsize', 512)
        self.one_batch = get('one_batch', False)
        # folder name for saving result (default: default)
        self.save_folder = get('save_folder', 'default')
        # data validation: each stage needs a matching growth value
        if len(self.stages) != len(self.growth):
            raise RuntimeError("Stages and growth must have the same length")

    @staticmethod
    def _parse_int_list(value):
        """Return *value* as a list of ints.

        Accepts the original dash-separated string form (``'10-10-10'``) or
        any iterable of int-convertible items (e.g. ``[10, 10, 10]``).
        """
        if isinstance(value, str):
            value = value.split('-')
        return [int(v) for v in value]

    def __str__(self):
        """Return a one-line ``{key=value,...}`` summary of the configuration."""
        fields = [
            'cross_block_rate={:.1f}'.format(self.cross_block_rate),
            'end_block_reduction_rate={:.1f}'.format(self.end_block_reduction_rate),
            'stages=[{:s}]'.format(','.join(map(str, self.stages))),
            'growth=[{:s}]'.format(','.join(map(str, self.growth))),
            'group_1x1={:d}'.format(self.group_1x1),
            'group_3x3={:d}'.format(self.group_3x3),
            'bottleneck={:d}'.format(self.bottleneck),
            'lr={:.3f}'.format(self.lr),
            'ep={:d}'.format(self.ep),
            "optimizer='{:s}'".format(self.optimizer),
            "scheduler='{:s}'".format(self.scheduler),
            'bsize={:d}'.format(self.bsize),
            "one_batch='{:s}'".format(str(self.one_batch)),
            "save_folder='{:s}'".format(self.save_folder),
        ]
        return '{' + ','.join(fields) + '}'
def main(args_dict):
    """Build the run configuration, load the data, and train a DenseNet.

    Args:
        args_dict: mapping of option name -> value; keys that are absent fall
            back to the defaults defined in ``Args``.

    Returns:
        Whatever ``train_model`` returns (its structure is defined in
        ``traning.train`` and is not visible here).
    """
    args = Args(args_dict)
    print(args)

    # Use the GPU when torch can see one; otherwise run on the CPU.
    has_gpu = torch.cuda.is_available()
    args.device = torch.device('cuda' if has_gpu else 'cpu')
    print('Using device: {:s}'.format(str(args.device)))

    # The fourth value returned by load_data (classes) is not needed here.
    trainloader, validateloader, testloader, _classes = load_data(args.bsize)

    model = DenseNet(args)
    return train_model(model, trainloader, validateloader, testloader, args.device, args)
def build_arg_parser():
    """Return the command-line parser for this training script.

    Every option defaults to ``argparse.SUPPRESS``, so flags the user does not
    pass are absent from the parsed namespace and ``Args`` applies its own
    defaults for them. An empty command line therefore parses to ``{}``,
    reproducing the previous hard-coded ``main({})`` run.
    """
    parser = argparse.ArgumentParser(
        description='Train a DenseNet with configurable cross-block connections.',
        argument_default=argparse.SUPPRESS,
    )
    # dense connection rate
    parser.add_argument('--cross_block_rate', type=float, help='range [0, 1]; default 0.5')
    parser.add_argument('--end_block_reduction_rate', type=float, help='range (0, 1]; default 0.5')
    parser.add_argument('--seed', type=int, help='seed for selecting random connections; default 1')
    # model hyperparameters
    parser.add_argument('--stages', type=str, help="dash-separated ints, e.g. '10-10-10'")
    parser.add_argument('--growth', type=str, help="dash-separated ints, e.g. '12-12-12'")
    parser.add_argument('--group_1x1', type=int, help='default 4')
    parser.add_argument('--group_3x3', type=int, help='default 4')
    parser.add_argument('--bottleneck', type=int, help='default 4')
    parser.add_argument('--lr', type=float, help='learning rate; default 0.1')
    parser.add_argument('--ep', type=int, help='number of epochs; default 100')
    parser.add_argument('--optimizer', choices=['sgd', 'adam'], help="default 'adam'")
    parser.add_argument('--scheduler', choices=['none', 'clr', 'exp', 'mlr', 'cos'], help="default 'cos'")
    # training batch
    parser.add_argument('--bsize', type=int, help='batch size; default 512')
    parser.add_argument('--one_batch', action='store_true', help='default False')
    parser.add_argument('--save_folder', type=str, help="folder name for saving results; default 'default'")
    return parser


if __name__ == '__main__':
    main(vars(build_arg_parser().parse_args()))