-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
98 lines (76 loc) · 3.01 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import random
import torch
import numpy as np
def seed_everything(seed):
    """Seed the RNGs used by this project so that runs are repeatable.

    Seeds Python's ``random`` module, PyTorch and NumPy's global generator.
    It also stores the seed in the ``PYTHONHASHSEED`` environment variable
    and puts cuDNN into deterministic, non-benchmarking mode.

    Args:
        seed: Seed value applied to every generator.

    Note:
        ``PYTHONHASHSEED`` is read when an interpreter starts. Assigning it
        here does not change hash randomisation of the running process; it
        only affects child processes that inherit this environment.
    """
    random.seed(seed)
    seed_text = str(seed)
    os.environ['PYTHONHASHSEED'] = seed_text
    torch.manual_seed(seed)
    # Prefer reproducible cuDNN kernels over autotuned (possibly faster) ones.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
    np.random.seed(seed)
def token_level_to_char_level(text, offsets, preds):
    """Spread per-token values onto the characters each token covers.

    Args:
        text: Source string; only its length is used, to size the output.
        offsets: Per-token ``(start, end)`` character spans. A span whose
            start and end are both 0 is skipped.
        preds: Per-token values, indexed in step with ``offsets``.

    Returns:
        A float ``np.ndarray`` of length ``len(text)``. Characters not
        covered by any kept span stay 0. When spans overlap, the later
        token's value overwrites the earlier one.
    """
    char_scores = np.zeros(len(text))
    for token_pos, span in enumerate(offsets):
        begin, finish = span[0], span[1]
        if not (begin or finish):
            continue  # (0, 0) span: nothing to write
        char_scores[begin:finish] = preds[token_pos]
    return char_scores
def jaccard(str1, str2):
    """Word-level Jaccard similarity (original metric implementation).

    Both strings are lower-cased and split on whitespace into word sets
    ``A`` and ``B``. The score is ``|A & B| / |A | B|``.

    When neither string contains any words the union is empty. The original
    formula divided by zero in that case; this version returns 1.0, since
    two empty word sets are identical.

    Args:
        str1: First string.
        str2: Second string.

    Returns:
        Similarity as a float in ``[0, 1]``.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    # |A | B| == len(a) + len(b) - len(a & b), the original denominator.
    union_size = len(a | b)
    if union_size == 0:
        return 1.0
    return float(len(a & b)) / union_size
def get_best_start_end_idx(start_logits, end_logits,
                           orig_start, orig_end):
    """Pick the highest-scoring answer span, following the BERT paper.

    Every pair ``(s, e)`` with ``orig_start <= s <= e <= orig_end`` is
    considered; the window is clipped to the logits' length by slicing.
    Each pair is scored as ``start_logits[s] + end_logits[e]``, and the best
    pair is returned as absolute indices. Ties keep the first pair in
    ``(s, e)`` scan order. The search is quadratic in the window size.

    Returns:
        ``(start, end)`` tuple, or ``None`` when the window is empty or no
        score compares greater than ``-inf`` (e.g. all NaN).
    """
    window_starts = start_logits[orig_start:orig_end + 1]
    window_ends = end_logits[orig_start:orig_end + 1]
    top_score = -np.inf
    top_pair = None
    for rel_start, start_score in enumerate(window_starts):
        # End offsets are absolute within the window and never precede the start.
        for rel_end in range(rel_start, len(window_ends)):
            candidate = start_score + window_ends[rel_end]
            if candidate > top_score:
                top_score = candidate
                top_pair = (orig_start + rel_start, orig_start + rel_end)
    return top_pair
def calculate_jaccard(original_tweet, target_string,
                      start_logits, end_logits,
                      orig_start, orig_end,
                      offsets,
                      verbose=False):
    """Calculate the final Jaccard score from span predictions.

    The steps are:

    1. Decode the best token span with ``get_best_start_end_idx``.
    2. Rebuild the span's text from the character ``offsets``.
    3. Apply the post-processing rules below.
    4. Score the result against ``target_string`` with ``jaccard``.

    Args:
        original_tweet: Text the offsets index into.
        target_string: Reference text to score against.
        start_logits: Per-token start scores, indexed like ``offsets``.
        end_logits: Per-token end scores, indexed like ``offsets``.
        orig_start: First token index of the window searched for the span.
        orig_end: Last token index (inclusive) of that window.
        offsets: Per-token ``(start, end)`` character spans into
            ``original_tweet``.
        verbose: Accepted for backward compatibility; currently unused.

    Returns:
        ``(jaccard_score, predicted_text)``.

    Raises:
        TypeError: ``get_best_start_end_idx`` returns ``None`` when the token
            window is empty, and unpacking that ``None`` fails.
    """
    start_idx, end_idx = get_best_start_end_idx(
        start_logits, end_logits, orig_start, orig_end)

    pieces = []
    for ix in range(start_idx, end_idx + 1):
        pieces.append(original_tweet[offsets[ix][0]:offsets[ix][1]])
        # Insert a single space when the next token starts after this one
        # ends, i.e. the original text had a gap between them.
        if (ix + 1) < len(offsets) and offsets[ix][1] < offsets[ix + 1][0]:
            pieces.append(' ')
    filtered_output = ''.join(pieces)

    # Return the original tweet if it has fewer than 2 words.
    if len(original_tweet.split()) < 2:
        filtered_output = original_tweet

    if len(filtered_output.split()) == 1:
        # Collapse repeated punctuation in single-word predictions.
        # '...' must be replaced before '..'. str.replace works left to
        # right without overlap, so replacing '..' first turns '...' into
        # '..' and leaves the '...' rule unreachable.
        filtered_output = filtered_output.replace('!!!!', '!')
        filtered_output = filtered_output.replace('...', '.')
        filtered_output = filtered_output.replace('..', '.')
        filtered_output = filtered_output.replace('ïï', 'ï')
        filtered_output = filtered_output.replace('¿¿', '¿')

    jac = jaccard(target_string.strip(), filtered_output.strip())
    return jac, filtered_output
class AverageMeter:
    """Compute and store the latest value and a running weighted average.

    Attributes:
        val: Most recent value passed to ``update``.
        sum: Running total of ``val * n`` over all updates.
        count: Running total of the ``n`` weights.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: New observation.
            n: Weight applied to ``val`` in the running sum and count.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against count == 0 (e.g. a first update with n=0), which
        # would otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0