Skip to content

Commit 1364126

Browse files
authored
s2s CPU
1 parent 3fde4a2 commit 1364126

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

s2s-ft/s2s_ft/modeling_decoding.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1656,7 +1656,7 @@ def get_dup_ngram_candidates(seq, n):
16561656
forbid_word_mask = torch.tensor(
16571657
buf_matrix, dtype=log_scores.dtype)
16581658
forbid_word_mask = torch.reshape(
1659-
forbid_word_mask, [batch_size * K, 1, vocab_size]).cuda()
1659+
forbid_word_mask, [batch_size * K, 1, vocab_size]).to(input_ids.device)
16601660
else:
16611661
forbid_word_mask = None
16621662
next_pos += 1

0 commit comments

Comments
 (0)