-
Notifications
You must be signed in to change notification settings - Fork 0
/
lang.py
112 lines (89 loc) · 3.24 KB
/
lang.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# coding=utf-8
import re
import string
import unicodedata
from abc import ABC, abstractmethod
from typing import List
# Reserved indices for the '<PAD>', '<SOS>' and '<EOS>' marker tokens.
# Every vocabulary pre-registers them, so learned tokens start at index 3.
PAD_token, SOS_token, EOS_token = 0, 1, 2
class AbstractVocabulary(ABC):
    """Base class for a two-way mapping between tokens and integer indices.

    Indices ``PAD_token``, ``SOS_token`` and ``EOS_token`` are reserved for the
    '<PAD>', '<SOS>' and '<EOS>' markers. Each newly seen token receives the
    next free index. Subclasses decide how a sentence is split into tokens.
    """

    def __init__(self):
        # token -> index. The reserved markers are intentionally not stored here.
        self.token2index = {}
        # token -> number of times the token has been registered via _add_token.
        self.token2count = {}
        # index -> token. Seeded with the markers, so len(self) starts at 3.
        self.index2token = {
            PAD_token: '<PAD>',
            SOS_token: '<SOS>',
            EOS_token: '<EOS>'
        }

    def __len__(self):
        """Return the number of indices in use, including the reserved markers."""
        return len(self.index2token)

    def _add_token(self, token: str) -> int:
        """Record one occurrence of ``token`` and return its index.

        A token seen for the first time is assigned index ``len(self)`` with a
        count of 1. A known token keeps its index and has its count bumped.
        """
        if token in self.token2index:
            self.token2count[token] += 1
            return self.token2index[token]
        fresh_index = len(self)
        self.token2index[token] = fresh_index
        self.token2count[token] = 1
        self.index2token[fresh_index] = token
        return fresh_index

    @abstractmethod
    def add_sentence(self, sentence: str) -> List[int]:
        """Register the tokens of ``sentence`` and return the encoded sentence."""

    @abstractmethod
    def to_list(self, sequence: str) -> List[int]:
        """Encode the given text as a list of token indices."""

    @abstractmethod
    def to_string(self, sequence: List[int]) -> str:
        """Decode a list of token indices back into text."""
class WordVocabulary(AbstractVocabulary):
    """Vocabulary whose tokens are the words produced by ``_split_sentence``."""

    def add_sentence(self, sentence: str) -> List[int]:
        """Register every word of ``sentence`` and return its encoding.

        The encoding is the list of word indices followed by ``EOS_token``.
        """
        tokens = _split_sentence(_unicode_to_ascii(sentence))
        return [self._add_token(tok) for tok in tokens] + [EOS_token]

    def to_list(self, sentence: str) -> List[int]:
        """Encode ``sentence`` without modifying the vocabulary.

        Returns the word indices followed by ``EOS_token``.

        Raises:
            KeyError: if a word of the sentence has never been added.
        """
        tokens = _split_sentence(_unicode_to_ascii(sentence))
        encoded = [self.token2index[tok] for tok in tokens]
        encoded.append(EOS_token)
        return encoded

    def to_string(self, sequence: List[int]) -> str:
        """Decode indices into a space-separated string.

        Decoding stops before the first ``EOS_token``. If no EOS is present,
        the whole sequence is decoded.
        """
        cut = sequence.index(EOS_token) if EOS_token in sequence else len(sequence)
        words = [self.index2token[idx] for idx in sequence[:cut]]
        return ' '.join(words)
class CharVocabulary(AbstractVocabulary):
    """Vocabulary whose tokens are single characters.

    It is pre-seeded with every character of ``string.printable``, in that
    order, immediately after the reserved marker indices.
    """

    def __init__(self):
        super().__init__()
        for c in string.printable:
            self._add_token(c)

    def __str__(self):
        """Return every known character in index order (markers excluded)."""
        return ''.join(self.token2index.keys())

    def add_sentence(self, sentence: str) -> List[int]:
        """Register every character of ``sentence`` and return its encoding.

        The encoding is the list of character indices followed by
        ``EOS_token``.
        """
        sentence = _unicode_to_ascii(sentence)
        # _add_token already returns each character's index, so encode in the
        # same pass rather than normalizing the text and looking every
        # character up a second time via to_list.
        return [self._add_token(c) for c in sentence] + [EOS_token]

    def to_list(self, sentence: str) -> List[int]:
        """Encode ``sentence`` without modifying the vocabulary.

        Returns the character indices followed by ``EOS_token``.

        Raises:
            KeyError: if a character has never been added.
        """
        sentence = _unicode_to_ascii(sentence)
        return [self.token2index[c] for c in sentence] + [EOS_token]

    def to_string(self, sequence: List[int]) -> str:
        """Decode indices back into text, stopping before the first ``EOS_token``."""
        if EOS_token in sequence:
            eos_position = sequence.index(EOS_token)
        else:
            eos_position = len(sequence)
        return ''.join([self.index2token[i] for i in sequence[:eos_position]])
def _split_sentence(sentence: str) -> List[str]:
sentence = re.sub(r'([^A-Za-z ])', r' \1 ', sentence)
words = filter(lambda s: s != '', sentence.split(' '))
return list(words)
def _unicode_to_ascii(s: str):
"""Turn a Unicode string to plain ASCII, thanks to http://stackoverflow.com/a/518232/2809427"""
return ''.join(
c for c in unicodedata.normalize('NFD', s)
if unicodedata.category(c) != 'Mn'
)