-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
136 lines (114 loc) · 4.33 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
import torch.nn as nn
from torch.utils.data import Dataset
class BilingualDataset(Dataset):
    """Dataset for a bilingual translation task.

    Each item turns one source/target sentence pair into fixed-length
    (``seq_len``) int64 tensors for an encoder-decoder model, together with
    the attention masks that accompany them.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        """Initialize the BilingualDataset.

        Args:
            ds: Indexable dataset; ``ds[i]["translation"]`` maps language
                codes to sentence strings.
            tokenizer_src: Tokenizer for the source language. Must provide
                ``encode(text).ids`` and ``token_to_id(token)``.
            tokenizer_tgt: Tokenizer for the target language (same interface).
            src_lang: Source language code (key into ``"translation"``).
            tgt_lang: Target language code (key into ``"translation"``).
            seq_len (int): Length every returned sequence is padded to.

        Raises:
            ValueError: If either tokenizer has no id for ``[SOS]``,
                ``[EOS]`` or ``[PAD]``.
        """
        super().__init__()
        self.seq_len = seq_len
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # Target-vocabulary special tokens as 1-element int64 tensors. They
        # keep their original attribute names for existing callers.
        self.sos_token = self._special_token(tokenizer_tgt, "[SOS]")
        self.eos_token = self._special_token(tokenizer_tgt, "[EOS]")
        self.pad_token = self._special_token(tokenizer_tgt, "[PAD]")

        # The encoder consumes source-vocabulary ids, so its special tokens
        # must be looked up in tokenizer_src; the two vocabularies are not
        # guaranteed to assign the same ids.
        self.src_sos_token = self._special_token(tokenizer_src, "[SOS]")
        self.src_eos_token = self._special_token(tokenizer_src, "[EOS]")
        self.src_pad_token = self._special_token(tokenizer_src, "[PAD]")

    @staticmethod
    def _special_token(tokenizer, token):
        """Return ``token``'s id in ``tokenizer`` as a 1-element int64 tensor.

        Raises:
            ValueError: If ``tokenizer.token_to_id(token)`` returns ``None``.
        """
        token_id = tokenizer.token_to_id(token)
        if token_id is None:
            raise ValueError(f"Tokenizer has no id for special token {token!r}")
        return torch.tensor([token_id], dtype=torch.int64)

    def _pad_to_seq_len(self, parts, pad_token):
        """Concatenate 1-D int64 ``parts`` and right-pad with ``pad_token``.

        Args:
            parts: Sequence of 1-D int64 tensors, concatenated in order.
            pad_token: 1-element int64 tensor used as padding.

        Returns:
            torch.Tensor: 1-D int64 tensor of length ``self.seq_len``.
        """
        body = torch.cat(parts, dim=0)
        num_padding = self.seq_len - body.size(0)
        # repeat() on a 1-element tensor yields a flat padding run (empty when
        # num_padding == 0), avoiding torch.tensor() over a list of tensors.
        return torch.cat([body, pad_token.repeat(num_padding)], dim=0)

    def __len__(self):
        """Get the length of the dataset."""
        return len(self.ds)

    def __getitem__(self, index: int):
        """Get an item from the dataset.

        Args:
            index: Index of the item to retrieve.

        Returns:
            dict: Dictionary containing encoder input, decoder input, masks,
            label, and source/target texts.

        Raises:
            ValueError: If the tokenized sentence plus its special tokens
                does not fit in ``seq_len``.
        """
        src_target_pair = self.ds[index]
        src_text = src_target_pair["translation"][self.src_lang]
        tgt_text = src_target_pair["translation"][self.tgt_lang]

        enc_tokens = torch.tensor(
            self.tokenizer_src.encode(src_text).ids, dtype=torch.int64
        )
        dec_tokens = torch.tensor(
            self.tokenizer_tgt.encode(tgt_text).ids, dtype=torch.int64
        )

        # The encoder input carries two extra tokens (SOS + EOS); the decoder
        # input and the label each carry one (SOS or EOS respectively).
        if (
            enc_tokens.size(0) + 2 > self.seq_len
            or dec_tokens.size(0) + 1 > self.seq_len
        ):
            raise ValueError("Sentence is too long")

        # [SOS] source tokens [EOS] [PAD] ...  (source vocabulary ids)
        encoder_input = self._pad_to_seq_len(
            [self.src_sos_token, enc_tokens, self.src_eos_token],
            self.src_pad_token,
        )
        # [SOS] target tokens [PAD] ...
        decoder_input = self._pad_to_seq_len(
            [self.sos_token, dec_tokens], self.pad_token
        )
        # target tokens [EOS] [PAD] ...  (what we expect as decoder output)
        label = self._pad_to_seq_len([dec_tokens, self.eos_token], self.pad_token)

        # Internal invariants guaranteed by the length check above.
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (Seq_len)
            "decoder_input": decoder_input,  # (Seq_len)
            "encoder_mask": (encoder_input != self.src_pad_token)
            .unsqueeze(0)
            .unsqueeze(0)
            .int(),  # (1, 1, Seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int()
            & casual_mask(decoder_input.size(0)),
            # (1, Seq_Len) & (1, Seq_Len, Seq_Len) -> (1, Seq_Len, Seq_Len)
            "label": label,  # (Seq_Len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }
def casual_mask(size):
    """Build a causal (look-ahead) self-attention mask.

    Query position ``i`` may attend to key positions ``0..i`` only.

    Args:
        size (int): Sequence length.

    Returns:
        torch.Tensor: Bool tensor of shape ``(1, size, size)`` that is ``True``
        on and below the main diagonal and ``False`` above it.
    """
    # Keep the lower triangle (diagonal included) of an all-ones matrix;
    # every non-zero entry is an allowed attention position.
    lower_triangle = torch.tril(torch.ones((1, size, size)))
    return lower_triangle.bool()