def byte_pair_counts(self, words: Iterable[str]) -> Iterable[typing.Counter]:
    """Count contiguous character n-grams over space-separated sub-tokens.

    For each distinct token produced by ``self.count_tokens(words)`` —
    assumed to map a space-separated token string to its occurrence
    count (TODO confirm against count_tokens) — yield one Counter that
    maps every contiguous run of ``ngram_size`` sub-tokens, for
    ``ngram_size`` in ``[self.ngram_min, self.ngram_max]``, to that
    token's count. Each individual sub-token is also counted.

    :param words: iterable of raw words to tokenize and count.
    :return: iterator of one ``Counter`` (n-gram -> count) per token.
    """
    for token, count in self.count_tokens(words).items():
        bp_counts = Counter()  # type: Counter
        sub_tokens = token.split(' ')
        joined_tokens = ''.join(sub_tokens)
        # token_offsets[i] is the start index of sub_tokens[i] inside
        # joined_tokens, with a trailing sentinel of len(joined_tokens).
        # This lets every n-gram be sliced directly out of the joined
        # string instead of re-joining each sliding window, which is what
        # the code this replaced did (one ''.join per window).
        token_offsets = [0]
        length = 0
        for sub_token in sub_tokens:
            bp_counts[sub_token] += count
            length += len(sub_token)
            # fix: source read 'token_offsets + = [length]' (a syntax
            # error); append is the intended, allocation-free form.
            token_offsets.append(length)
        # n-grams longer than the token itself are impossible, hence the min().
        for ngram_size in range(self.ngram_min, min(self.ngram_max, len(sub_tokens)) + 1):
            for i in range(len(sub_tokens) - ngram_size + 1):
                bp_counts[joined_tokens[token_offsets[i]:token_offsets[i + ngram_size]]] += count
        yield bp_counts
@@ -89,9 +93,7 @@ def learn_bpe_vocab(self, words: Iterable[str]) -> Dict[str, int]:
89
93
for token in {self .SOW , self .EOW }:
90
94
vocab [token ] = int (2 ** 63 )
91
95
for idx , byte_pair_count in enumerate (self .byte_pair_counts (words )):
92
- for byte_pair , count in byte_pair_count .items ():
93
- vocab [byte_pair ] += count
94
-
96
+ vocab .update (byte_pair_count )
95
97
if (idx + 1 ) % 10000 == 0 :
96
98
self .trim_vocab (10 * self .bpe_vocab_size , vocab )
97
99
0 commit comments