
Commit

fix: special_tokens in the encode method to support the special tokens in the vocab (#13)

* fix: special_tokens in the encode method to support the special tokens in the vocab

* fix: encode_ord to not handle the special_tokens
Hk669 committed Jun 7, 2024
1 parent de22772 commit c66cab7
Showing 2 changed files with 23 additions and 10 deletions.
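For context, a rough usage sketch of what this fix enables. Only from_pretrained, encode(text, special_tokens=...), and decode(ids) appear in the diff below; the class name BPETokenizer, the <|sot|> token, and the special_tokens="all" argument value are assumptions for illustration.

from bpetokenizer.tokenizer import BPETokenizer  # class name assumed; the diff only shows tokenizer.py

# load the bundled pretrained tokenizer whose vocab file is edited in this commit
tokenizer = BPETokenizer.from_pretrained("wi17k_base")

# "<|sot|>" is a hypothetical special token registered in tokenizer.special_tokens
text = "def _stats():<|sot|>    return"

# encode() now splits on the registered special tokens and maps each one to its
# reserved id; "all" is assumed to be an accepted value of the special_tokens argument
ids = tokenizer.encode(text, special_tokens="all")

# decode() reverses the mapping via vocab, inverse_special_tokens and the new inverse_merges
print(tokenizer.decode(ids))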
8 changes: 6 additions & 2 deletions bpetokenizer/pretrained/wi17k_base/wi17k_base.json
@@ -17065,7 +17065,9 @@
"(17306, 195)": 17307,
"(17307, 163)": 17308,
"(1012, 7365)": 17309,
"(9137, 336)": 17310
"(9137, 336)": 17310,
"(32, 32)": 17320,
"(17320, 32)": 17321
},
"vocab": {
"0": "\\u0000",
@@ -34380,6 +34382,8 @@
"17309": " differs",
"17311": " def",
"17312": "_stats",
"17313": " get"
"17313": " get",
"17320": " ",
"17321": " "
}
}
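The two edited blocks are the pair-merge table and the id-to-token vocab of wi17k_base: an entry such as "(32, 32)": 17320 records that two space tokens merge into token 17320, and the vocab block maps that id back to its text. A minimal parsing sketch under that assumed layout (the library's real loader may differ):

import ast, json

# assumed layout, based on the entries shown above:
#   data["merges"]: "(p0, p1)" -> id of the merged token
#   data["vocab"]:  "id"       -> token text
with open("bpetokenizer/pretrained/wi17k_base/wi17k_base.json", encoding="utf-8") as f:
    data = json.load(f)

merges = {ast.literal_eval(pair): int(idx) for pair, idx in data["merges"].items()}
inverse_merges = {idx: pair for pair, idx in merges.items()}   # id -> (p0, p1), as in __init__ below

assert inverse_merges[17320] == (32, 32)      # two spaces merged into one token
assert inverse_merges[17321] == (17320, 32)   # that token plus another space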
25 changes: 17 additions & 8 deletions bpetokenizer/tokenizer.py
@@ -33,11 +33,13 @@ def __init__(self, pattern=None, special_tokens=None):
self.special_tokens = {} if special_tokens is None else special_tokens
self.inverse_special_tokens = {} if special_tokens is None else {v: k for k, v in special_tokens.items()}
self.vocab_size = len(self.vocab) if self.vocab else 0
self.inverse_merges = {int(v): k for k, v in self.merges.items()} if self.merges else {}

@classmethod
def from_pretrained(cls,
tokenizer_name: str,
verbose=False):
"""Allows you to load the pretrained tokenizers"""
tokenizer = cls()
pretrained_dir = 'bpetokenizer/pretrained'
tokenizer_file = os.path.join(pretrained_dir, tokenizer_name, f'{tokenizer_name}.json')
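The new inverse_merges attribute flips the merges mapping (pair of ids -> merged id) into merged id -> pair, so decode() further down can expand an id that is not in the vocab. A tiny sketch of the inversion, using the whitespace merges added to wi17k_base above; the assumption that merge ids arrive as strings (hence the int(v)) is illustrative, not the library's loader:

# merges as loaded from JSON: keys are pair tuples, values may still be strings
merges = {(32, 32): "17320", (17320, 32): "17321"}
inverse_merges = {int(v): k for k, v in merges.items()}
# {17320: (32, 32), 17321: (17320, 32)} -- lets decode() turn an id back into its two parts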
@@ -60,6 +62,10 @@ def train(self, texts, vocab_size, verbose=False, min_frequency=1) -> None:
vocab_size: int (the size of the vocab, gpt4 vocab size is around 100k)
verbose: bool (to get extra visibility and the overview of internal processes)
min_frequency: int (the minimum frequency of the pair to be merged and added into the vocab as a new token)
internal_args:
text_chunks: list[str]
pair: tuple(int, int)
"""
assert vocab_size >= 256
num_merges = vocab_size - 256
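A short usage sketch for train() as described by the extended docstring; the corpus file, the vocab size, and the assumption that texts is a single raw string are illustrative only:

tokenizer = BPETokenizer()                 # class name assumed, as in the earlier sketch
texts = open("corpus.txt", encoding="utf-8").read()   # hypothetical corpus file
tokenizer.train(texts, vocab_size=512, verbose=True, min_frequency=2)
# vocab_size must be >= 256 (see the assert above); with min_frequency=2,
# a pair seen only once is never merged into a new token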
@@ -123,8 +129,6 @@ def encode_ord(self, text) -> list:
for chunk in text_chunks:
if chunk in self.vocab:
ids.append(self.vocab[chunk])
elif chunk in self.special_tokens:
ids.append(self.special_tokens[chunk])
else:
_bytes = chunk.encode("utf-8")
chunk_ids = self._encode(_bytes)
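Behaviourally, encode_ord no longer short-circuits on special tokens; that responsibility moves entirely to encode() below, which first splits the text on the registered special tokens. An illustrative sketch of the split-then-encode flow (the token name, its id, and the byte-level stand-in are assumptions, not the library's code):

import re

special_tokens = {"<|sot|>": 100000}                   # hypothetical special token and id

def encode_ord_like(text):
    # after this commit: no special-token branch here, plain chunk/byte encoding only
    return list(text.encode("utf-8"))                  # stand-in for the real _encode path

def encode_like(text):
    # encode() builds a pattern from the special tokens, splits on it, then maps
    # special chunks to their ids and hands everything else to ordinary encoding
    pattern = "(" + "|".join(re.escape(tok) for tok in special_tokens) + ")"
    ids = []
    for chunk in re.split(pattern, text):
        if chunk in special_tokens:
            ids.append(special_tokens[chunk])
        elif chunk:
            ids.extend(encode_ord_like(chunk))
    return ids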
@@ -150,33 +154,38 @@ def encode(self, text, special_tokens="none") -> list:
else:
raise ValueError(f"invalid special tokens argument: {special_tokens}")


text_chunks = re.findall(self.compiled_pattern, text)
if not special:
# shortcut: if no special tokens, just use the ordinary encoding
return self.encode_ord(text)
special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
text_chunks = re.split(special_pattern, text)
ids = []
for chunk in text_chunks:
if chunk in self.inverse_vocab:
ids.append(self.inverse_vocab[chunk])
elif chunk in self.special_tokens:
elif special and chunk in self.special_tokens:
ids.append(self.special_tokens[chunk])
else:
chunk_ids = self._encode(chunk.encode("utf-8"))
ids.extend(chunk_ids)
return ids


def decode(self, ids) -> str:
def decode(self, ids, verbose=False) -> str:
part_bytes = []
for idx in ids:
if idx in self.vocab: #str conversion because vocab keys are strings when loaded from json
part_bytes.append(self.vocab[idx])
elif idx in self.inverse_special_tokens:
part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8")) # special tokens are not encoded in vocab
elif idx in self.merges:
pair = self.merges[idx]
elif idx in self.inverse_merges:
pair = self.inverse_merges[idx]
part_bytes.append(self.vocab[pair[0]] + self.vocab[pair[1]])
else:
raise ValueError(f"invalid token id: {idx}")
text_bytes = b"".join(part_bytes)
if verbose:
print("---\nText bytes: ", text_bytes)
text = text_bytes.decode("utf-8", errors="replace")
return text
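The decode() change is the consumer of inverse_merges: an id is looked up first in the vocab, then among the special tokens, and only then expanded through its merge pair. A minimal sketch of that resolution order using the whitespace merges added above (recursive here for illustration; the real code adds vocab[pair[0]] + vocab[pair[1]] directly, and the special token shown is hypothetical):

vocab = {32: b" "}                                      # id -> bytes; tiny illustrative subset
inverse_special_tokens = {100000: "<|sot|>"}            # hypothetical special token
inverse_merges = {17320: (32, 32), 17321: (17320, 32)}  # from the wi17k_base merges above

def resolve(idx):
    if idx in vocab:
        return vocab[idx]
    if idx in inverse_special_tokens:
        return inverse_special_tokens[idx].encode("utf-8")
    if idx in inverse_merges:
        left, right = inverse_merges[idx]
        return resolve(left) + resolve(right)
    raise ValueError(f"invalid token id: {idx}")

assert resolve(17321) == b"   "                         # three spaces, built from both new merges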

