We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 294175e commit d6fa808Copy full SHA for d6fa808
examples/cpp/aot_inductor/bert/aot_compile_export.py
@@ -12,6 +12,7 @@
12
13
set_seed(1)
14
MAX_BATCH_SIZE = 15
15
+MAX_LENGTH = 1024
16
17
18
def transformers_model_dowloader(
@@ -75,8 +76,8 @@ def transformers_model_dowloader(
75
76
attention_mask = torch.cat([inputs["attention_mask"]] * batch_size, 0).to(
77
device
78
)
- batch_dim = torch.export.Dim("batch", min=1, max=8)
79
- seq_len_dim = torch.export.Dim("seq_len", min=1, max=max_length)
+ batch_dim = torch.export.Dim("batch", min=1, max=MAX_BATCH_SIZE)
80
+ seq_len_dim = torch.export.Dim("seq_len", min=1, max=MAX_LENGTH)
81
torch._C._GLIBCXX_USE_CXX11_ABI = True
82
model_so_path = torch._export.aot_compile(
83
model,
0 commit comments