-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
57 lines (35 loc) · 1.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Import Packages
import torch
from nltk.tokenize import word_tokenize
def translate_sentence(sentence, src_field, trg_field, model, device, max_len = 50):
    """
    Greedily decode a translation of ``sentence`` with a transformer seq2seq model.

    Parameters
    ----------
    sentence : iterable
        Source input; its elements are stringified and joined with spaces
        before tokenization (so a pre-tokenized list round-trips cleanly).
    src_field, trg_field : torchtext-style fields
        Provide ``init_token``, ``eos_token`` and a ``vocab`` with
        ``stoi`` / ``itos`` mappings.
    model : nn.Module
        Must expose ``encoder``, ``decoder``, ``make_src_mask`` and
        ``make_trg_mask``.
    device : torch.device
        Where the input tensors are placed.
    max_len : int, optional
        Hard cap on the number of decoding steps (default 50).

    Returns
    -------
    tuple
        (predicted target tokens without the leading init token,
         attention tensor from the final decoder call).
    """
    model.eval()

    # Normalise the input into a single lowercase token list wrapped in
    # <sos>/<eos> markers.
    text = ' '.join(str(piece) for piece in sentence)
    src_tokens = [src_field.init_token]
    src_tokens += [tok.lower() for tok in word_tokenize(text)]
    src_tokens.append(src_field.eos_token)

    # Numericalise and run the encoder once; no gradients needed at inference.
    src_ids = [src_field.vocab.stoi[tok] for tok in src_tokens]
    src_tensor = torch.LongTensor(src_ids).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    # Greedy decode: start from <sos>, append the argmax token each step,
    # stop on <eos> or after max_len steps.
    out_ids = [trg_field.vocab.stoi[trg_field.init_token]]
    eos_id = trg_field.vocab.stoi[trg_field.eos_token]
    for _ in range(max_len):
        trg_tensor = torch.LongTensor(out_ids).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)
        with torch.no_grad():
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        next_id = output.argmax(2)[:, -1].item()
        out_ids.append(next_id)
        if next_id == eos_id:
            break

    decoded = [trg_field.vocab.itos[i] for i in out_ids]
    # Drop the leading init token; attention comes from the last decoder step.
    return decoded[1:], attention
def tokenize_src(text):
    """
    Tokenize source-language text from a string into a list of string tokens.

    NOTE(review): the previous docstring claimed the token list was reversed,
    but no reversal was ever performed; the docstring is corrected here and
    behavior is unchanged. It also asserted the source language is German,
    which this code cannot guarantee — it simply applies NLTK's word_tokenize.
    """
    # word_tokenize already returns a list of str; no copy needed.
    return word_tokenize(text)
def tokenize_trg(text):
    """
    Tokenize target-language (English) text from a string into a list of
    string tokens.
    """
    # word_tokenize already returns a list of str, so the former identity
    # comprehension ([tok for tok in ...]) was a redundant copy.
    return word_tokenize(text)