-
Notifications
You must be signed in to change notification settings - Fork 49
/
utils.py
152 lines (120 loc) · 4.48 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os
import math
import torch
import torch.nn.functional as F
from torch.distributions.bernoulli import Bernoulli
from build_vocab import PAD_TOKEN, UNK_TOKEN
def collate_fn(sign2id, batch):
    """Batch (image, formula) pairs for the data loader.

    Drops any sample whose image size differs from the first sample's,
    sorts the remainder by formula length (longest first), then builds two
    target tensors: one prefixed with the start token (decoder input for
    training) and one suffixed with the end token (target for the loss).
    """
    # keep only images that share the first sample's spatial size
    ref_size = batch[0][0].size()
    same_sized = [pair for pair in batch if pair[0].size() == ref_size]
    # longest formulas first (useful for packed/steppped decoding)
    same_sized.sort(key=lambda pair: len(pair[1].split()), reverse=True)
    imgs, formulas = zip(*same_sized)
    token_lists = [formula.split() for formula in formulas]
    # decoder input begins with START_TOKEN
    tgt4training = formulas2tensor(add_start_token(token_lists), sign2id)
    # loss target ends with END_TOKEN
    tgt4cal_loss = formulas2tensor(add_end_token(token_lists), sign2id)
    return torch.stack(imgs, dim=0), tgt4training, tgt4cal_loss
def formulas2tensor(formulas, sign2id):
    """Convert tokenized formulas into a [batch, max_len] LongTensor.

    Unknown tokens map to UNK_TOKEN; positions past a formula's length keep
    the PAD_TOKEN fill value. Assumes formulas are sorted longest-first, so
    formulas[0] sets the width — shorter rows are right-padded.
    """
    n_rows = len(formulas)
    width = len(formulas[0])
    out = torch.full((n_rows, width), PAD_TOKEN, dtype=torch.long)
    for row, formula in enumerate(formulas):
        for col, sign in enumerate(formula):
            out[row, col] = sign2id.get(sign, UNK_TOKEN)
    return out
def add_start_token(formulas):
    """Return new token lists with the '<s>' start token prepended to each."""
    return [['<s>', *formula] for formula in formulas]
def add_end_token(formulas):
    """Return new token lists with the '</s>' end token appended to each."""
    return [[*formula, '</s>'] for formula in formulas]
def count_parameters(model):
    """Return the total number of trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.

    Each slice along ``dim`` is repeated ``count`` times consecutively,
    e.g. rows [a, b] become [a, a, b, b] for count=2.
    """
    # The original permute/view/repeat/view dance is exactly what
    # torch.repeat_interleave does natively (element [b] maps to output
    # rows b*count .. b*count+count-1), so delegate to the library call.
    return x.repeat_interleave(count, dim=dim)
def load_formulas(filename):
    """Read formulas from *filename*, one per line.

    Returns a dict mapping the 0-based line index to the stripped line.
    """
    with open(filename) as fh:
        formulas = {idx: line.strip() for idx, line in enumerate(fh)}
    print("Loaded {} formulas from {}".format(len(formulas), filename))
    return formulas
def cal_loss(logits, targets):
    """args:
        logits: probability distribution return by model
                [B, MAX_LEN, voc_size]
        targets: target formulas
                [B, MAX_LEN]
    """
    vocab_size = logits.size(2)
    # drop every padded position from both targets and predictions
    keep = targets != (torch.ones_like(targets) * PAD_TOKEN)
    flat_targets = targets.masked_select(keep)
    flat_probs = logits.masked_select(
        keep.unsqueeze(2).expand(-1, -1, vocab_size)
    ).contiguous().view(-1, vocab_size)
    assert flat_probs.size(0) == flat_targets.size(0)
    # NOTE(review): logits are assumed to be probabilities (post-softmax);
    # torch.log yields -inf for exact zeros — confirm upstream output range.
    return F.nll_loss(torch.log(flat_probs), flat_targets)
def get_checkpoint(ckpt_dir):
    """Return the full path of the newest checkpoint in *ckpt_dir*.

    Checkpoint files are expected to be named ``ckpt-<epoch>...`` so the
    epoch is the integer between the first and second '-'.

    Raises:
        FileNotFoundError: if *ckpt_dir* is missing or holds no ckpt files.
    """
    if not os.path.isdir(ckpt_dir):
        raise FileNotFoundError("No checkpoint found in {}".format(ckpt_dir))
    ckpts = [f for f in os.listdir(ckpt_dir) if f.startswith('ckpt')]
    if not ckpts:
        raise FileNotFoundError("No checkpoint found in {}".format(ckpt_dir))
    # Bug fix: the old loop initialized max_epoch = 0 and only replaced
    # last_ckpt on epoch > max_epoch, so a lone 'ckpt-0' left last_ckpt
    # as None and crashed os.path.join. max() handles every epoch value.
    last_ckpt = max(ckpts, key=lambda name: int(name.split('-')[1]))
    full_path = os.path.join(ckpt_dir, last_ckpt)
    print("Get checkpoint from {} for training".format(full_path))
    return full_path
def schedule_sample(prev_logit, prev_tgt, epsilon):
    """Scheduled sampling: pick the next decoder input per batch element.

    With probability ``epsilon`` the ground-truth token ``prev_tgt`` is
    chosen; otherwise the model's own greedy prediction (argmax of
    ``prev_logit``) is used. Returns a [B, 1] LongTensor.
    """
    greedy = torch.argmax(prev_logit, dim=1, keepdim=True)
    candidates = torch.cat([greedy, prev_tgt], dim=1)  # [B, 2]
    n = candidates.size(0)
    # column index per row: 0 -> model output, 1 -> ground truth
    coin = Bernoulli(torch.tensor([epsilon] * n).unsqueeze(1))
    picks = coin.sample().long().to(prev_tgt.device)
    return torch.gather(candidates, 1, picks)
def cal_epsilon(k, step, method):
    """Decay schedule for the scheduled-sampling probability epsilon.

    Reference:
        Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks
        See details in https://arxiv.org/pdf/1506.03099.pdf
    """
    assert method in ['inv_sigmoid', 'exp', 'teacher_forcing']
    if method == 'teacher_forcing':
        # always feed the ground truth
        return 1.
    if method == 'exp':
        return k ** step
    # inverse-sigmoid decay
    return k / (k + math.exp(step / k))