Skip to content

Commit 8935d61

Browse files
committed
speed up bpevocabulary build
1 parent 865caf0 commit 8935d61

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

src/utils/bpevocabulary.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,17 @@ def byte_pair_counts(self, words: Iterable[str]) -> Iterable[typing.Counter]:
6060
"""
6161
for token, count in self.count_tokens(words).items():
6262
bp_counts = Counter() # type: Counter
63-
for ngram in token.split(' '):
63+
sub_tokens = token.split(' ')
64+
joined_tokens = ''.join(sub_tokens)
65+
token_offsets = [0]
66+
length = 0
67+
for ngram in sub_tokens:
6468
bp_counts[ngram] += count
65-
for ngram_size in range(self.ngram_min, min([self.ngram_max, len(token)]) + 1):
66-
ngrams = [''.join(ngram) for ngram in toolz.sliding_window(ngram_size, token.split(' '))]
67-
68-
for ngram in ngrams:
69-
bp_counts[''.join(ngram)] += count
69+
length += len(ngram)
70+
token_offsets += [length]
71+
for ngram_size in range(self.ngram_min, min(self.ngram_max, len(sub_tokens)) + 1):
72+
for i in range(len(sub_tokens) - ngram_size + 1):
73+
bp_counts[joined_tokens[token_offsets[i]:token_offsets[i+ngram_size]]] += count
7074

7175
yield bp_counts
7276

@@ -89,9 +93,7 @@ def learn_bpe_vocab(self, words: Iterable[str]) -> Dict[str, int]:
8993
for token in {self.SOW, self.EOW}:
9094
vocab[token] = int(2**63)
9195
for idx, byte_pair_count in enumerate(self.byte_pair_counts(words)):
92-
for byte_pair, count in byte_pair_count.items():
93-
vocab[byte_pair] += count
94-
96+
vocab.update(byte_pair_count)
9597
if (idx + 1) % 10000 == 0:
9698
self.trim_vocab(10 * self.bpe_vocab_size, vocab)
9799

0 commit comments

Comments
 (0)