-
Notifications
You must be signed in to change notification settings - Fork 9
/
train.py
153 lines (120 loc) · 5.85 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import numpy as np
import os
import argparse
import tensorflow as tf
import cv2
import random
from predictor import resfcn256
import math
from datetime import datetime
class TrainData(object):
    """Shuffled mini-batch iterator over (image, UV-position-map) pairs.

    Each line of *train_data_file* is a whitespace-separated record:
    an image path followed by the path of its ``.npy`` position map.
    Calling the instance with a batch size returns the next batch,
    reshuffling and restarting when a full batch can no longer be drawn.
    """

    def __init__(self, train_data_file):
        super(TrainData, self).__init__()
        self.train_data_file = train_data_file
        self.train_data_list = []
        self.readTrainData()
        self.index = 0  # cursor into train_data_list for the next batch
        self.num_data = len(self.train_data_list)

    def readTrainData(self):
        """Load the record list from the data file and shuffle it once."""
        with open(self.train_data_file) as fp:
            for line in fp:
                item = line.strip().split()
                if item:  # skip blank lines (original kept empty records)
                    self.train_data_list.append(item)
        random.shuffle(self.train_data_list)

    def getBatch(self, batch_list):
        """Read the files of one batch from disk.

        Returns ``[imgs, labels]``: two parallel lists of float32 arrays,
        both scaled by 1 / (256 * 1.1) so values fall roughly in [0, 1).

        Raises IOError when an image file cannot be read, instead of
        letting a None from cv2.imread fail confusingly downstream.
        """
        imgs = []
        labels = []
        for item in batch_list:
            img = cv2.imread(item[0])
            if img is None:
                raise IOError('Failed to read image: %s' % item[0])
            label = np.load(item[1])
            imgs.append(np.array(img, dtype=np.float32) / 256.0 / 1.1)
            labels.append(np.array(label, dtype=np.float32) / 256 / 1.1)
        return [imgs, labels]

    def __call__(self, batch_num):
        """Return the next batch of *batch_num* records.

        When fewer than *batch_num* records remain, the leftover tail is
        dropped, the list is reshuffled, and a fresh pass begins — the
        same epoch policy as before, with the duplicated branch removed.
        """
        if (self.index + batch_num) > self.num_data:
            self.index = 0
            random.shuffle(self.train_data_list)
        batch_list = self.train_data_list[self.index:self.index + batch_num]
        self.index += batch_num
        return self.getBatch(batch_list)
def main(args):
    """Train the resfcn256 position-map regression network.

    Resumes from *args.model_path* when a matching TF checkpoint exists
    (checkpoint names end in ``_<epoch>``), trains through *args.epochs*,
    saves one checkpoint per epoch, and appends progress lines to a
    timestamped log file.
    """
    # Environment / hyper-parameters.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    batch_size = args.batch_size
    epochs = args.epochs
    train_data_file = args.train_data_file
    model_path = args.model_path
    save_dir = args.checkpoint
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Training data.
    data = TrainData(train_data_file)

    # Resume support: derive the next epoch from the checkpoint suffix.
    begin_epoch = 0
    if os.path.exists(model_path + '.data-00000-of-00001'):
        begin_epoch = int(model_path.split('_')[-1]) + 1

    # Integer division: epoch_iters seeds tf step counters, which must be
    # integral (plain '/' yields a float under Python 3).
    epoch_iters = data.num_data // batch_size
    global_step = tf.Variable(epoch_iters * begin_epoch, trainable=False)
    # Halve the learning rate every 5 epochs:
    # learning_rate = initial * 0.5 ** (global_step // decay_steps)
    decay_steps = 5 * epoch_iters
    learning_rate = tf.train.exponential_decay(args.learning_rate, global_step,
                                               decay_steps, 0.5, staircase=True)

    x = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])
    label = tf.placeholder(tf.float32, shape=[None, 256, 256, 3])

    # Train net.
    net = resfcn256(256, 256)
    x_op = net(x, is_training=True)

    # Weighted MSE loss; the mask image supplies per-pixel weights.
    weights = cv2.imread("Data/uv-data/weight_mask_final.jpg")  # [256, 256, 3]
    if weights is None:
        raise IOError('Cannot read weight mask: Data/uv-data/weight_mask_final.jpg')
    weights_data = np.zeros([1, 256, 256, 3], dtype=np.float32)
    weights_data[0, :, :, :] = weights  # / 16.0
    loss = tf.losses.mean_squared_error(label, x_op, weights_data)

    # Batch-norm moving averages must update alongside the train op.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                            beta1=0.9, beta2=0.999, epsilon=1e-08,
                                            use_locking=False).minimize(loss, global_step=global_step)

    sess = tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)))
    sess.run(tf.global_variables_initializer())

    if os.path.exists(model_path + '.data-00000-of-00001'):
        tf.train.Saver(net.vars).restore(sess, model_path)

    saver = tf.train.Saver(var_list=tf.global_variables())
    save_path = model_path

    # Begin training; 'with' guarantees the log file closes even on error.
    time_now = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    iters_total_each_epoch = int(math.ceil(1.0 * data.num_data / batch_size))
    with open("log_" + time_now + ".txt", "w") as fp_log:
        for epoch in range(begin_epoch, epochs):
            for iters in range(iters_total_each_epoch):
                batch = data(batch_size)
                loss_res, _, global_step_res, learning_rate_res = sess.run(
                    [loss, train_step, global_step, learning_rate],
                    feed_dict={x: batch[0], label: batch[1]})
                time_now_tmp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
                log_line = ('[' + time_now_tmp + ']:'
                            + 'global_step:%d:iters:%d/epoch:%d,learning rate:%f,loss:%f'
                            % (global_step_res, iters, epoch, learning_rate_res, loss_res))
                print(log_line)
                fp_log.write(log_line + "\n")
            saver.save(sess=sess, save_path=save_path + '_' + str(epoch))
if __name__ == '__main__':
    # Command-line entry point: build the argument parser and hand the
    # parsed options to main().
    parser = argparse.ArgumentParser(
        description='Joint 3D Face Reconstruction and Dense Alignment with Position Map Regression Network')
    parser.add_argument('--train_data_file', type=str,
                        default='face3d/examples/trainDataLabel.txt',
                        help='The training data file')
    parser.add_argument('--learning_rate', type=float, default=0.0002,
                        help='The learning rate')
    parser.add_argument('--epochs', type=int, default=50,
                        help='Total epochs')
    parser.add_argument('--batch_size', type=int, default=16,
                        help='Batch sizes')
    parser.add_argument('--checkpoint', type=str, default='checkpoint/',
                        help='The path of checkpoint')
    parser.add_argument('--model_path', type=str,
                        default='checkpoint/256_256_resfcn256_weight',
                        help='The path of pretrained model')
    parser.add_argument('--gpu', type=str, default='0',
                        help='The GPU ID')
    main(parser.parse_args())