
Commit e6654ec

mreso and msaroufim authored
Enable fx_graph_cache in gpt-fast example (#2935)
* Enable fx_graph_cache in gpt-fast example
* mention fx_graph_cache in readme
* Fix spellcheck
* Update README.md

Co-authored-by: Mark Saroufim <[email protected]>
1 parent e3f8703 commit e6654ec

File tree: 7 files changed, +18 −5 lines changed


examples/large_models/gpt_fast/README.md (+3 −1)

@@ -9,7 +9,7 @@ It features:
 * No dependencies other than PyTorch and sentencepiece
 * int8/int4 quantization
 * Speculative decoding
-* Tensor parallelism
+* Supports multi-GPU inference through Tensor parallelism
 * Supports Nvidia and AMD GPUs
 
 More details about gpt-fast can be found in this [blog](https://pytorch.org/blog/accelerating-generative-ai-2/).
@@ -81,6 +81,8 @@ cd ..
 At this stage we're creating the model archive which includes the configuration of our model in [model_config.yaml](./model_config.yaml).
 It's also the point where we need to decide if we want to deploy our model on a single or multiple GPUs.
 For the single GPU case we can use the default configuration that can be found in [model_config.yaml](./model_config.yaml).
+All configs enable the current prototyping feature FxGraphCache by setting fx_graph_cache to *true*.
+This feature stores the TorchInductor output in a cache to speed up torch.compile times when rerunning the handler.
 
 ```
 torch-model-archiver --model-name gpt_fast --version 1.0 --handler handler.py --config-file model_config.yaml --extra-files "gpt-fast/generate.py,gpt-fast/model.py,gpt-fast/quantize.py,gpt-fast/tp.py" --archive-format no-archive
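The two README lines added here describe a cache for torch.compile artifacts. As a minimal standalone sketch (outside TorchServe, assuming a PyTorch build where the FX graph cache prototype is available), the same environment variable the handler sets can be exported before compiling; the toy model and timing below are illustrative only:

```python
import os
import time

# Assumption: this is the variable TorchInductor checks for its FX graph cache,
# matching what the gpt-fast handler sets; it must be set before compiling.
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"

import torch

# Toy model standing in for gpt-fast, purely for illustration.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
compiled = torch.compile(model)

start = time.time()
compiled(torch.randn(4, 16))  # first call triggers compilation
print(f"first call took {time.time() - start:.2f}s")
# Rerunning this script lets TorchInductor reuse the cached compile output
# instead of rebuilding it, which is the speed-up the README refers to.
```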

examples/large_models/gpt_fast/handler.py (+4)

@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 import time
 from pathlib import Path
 
@@ -84,6 +85,9 @@ def initialize(self, ctx):
         self.tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
 
         if ctx.model_yaml_config["handler"]["compile"]:
+            if ctx.model_yaml_config["handler"].get("fx_graph_cache", False):
+                os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+
             if self.is_speculative and use_tp:
                 torch._inductor.config.triton.cudagraph_trees = (
                     False  # Bug with cudagraph trees in this case
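The handler change above translates an optional YAML flag into the environment variable TorchInductor reads. Here is a stripped-down sketch of the same pattern with a plain dict standing in for ctx.model_yaml_config (the dict literal is illustrative, not the TorchServe API):

```python
import os

# Stand-in for the parsed model_config.yaml; illustrative values only.
model_yaml_config = {
    "handler": {
        "compile": True,
        "fx_graph_cache": True,
    }
}

handler_cfg = model_yaml_config["handler"]
if handler_cfg["compile"]:
    # .get() with a False default keeps caching off for configs
    # that never define the key, e.g. older model_config.yaml files.
    if handler_cfg.get("fx_graph_cache", False):
        os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
```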

examples/large_models/gpt_fast/model_config.yaml (+1)

@@ -8,3 +8,4 @@ handler:
     converted_ckpt_dir: "checkpoints/meta-llama/Llama-2-7b-hf/model.pth"
     max_new_tokens: 50
     compile: true
+    fx_graph_cache: True

examples/large_models/gpt_fast/model_config_speculative.yaml (+1)

@@ -14,3 +14,4 @@ handler:
     max_new_tokens: 50
     compile: true
     stream: false
+    fx_graph_cache: True

examples/large_models/gpt_fast/model_config_tp.yaml (+1)

@@ -13,3 +13,4 @@ handler:
     max_new_tokens: 50
     compile: true
     stream: false
+    fx_graph_cache: True
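All three configs spell the new flag as a capitalized True while compile uses lowercase true; YAML parses both spellings as booleans, so the handler's truthiness check behaves identically. A quick sketch to confirm, assuming PyYAML is available (as used to parse these configs):

```python
import yaml  # assumption: PyYAML is available in the environment

snippet = """
handler:
  compile: true
  fx_graph_cache: True
"""

cfg = yaml.safe_load(snippet)["handler"]
# Both spellings load as Python booleans, so the handler's
# .get("fx_graph_cache", False) check sees a plain True either way.
assert cfg["compile"] is True and cfg["fx_graph_cache"] is True
print(cfg)  # {'compile': True, 'fx_graph_cache': True}
```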

test/pytest/test_example_gpt_fast.py (+5 −4)

@@ -50,7 +50,7 @@
     {
         "nproc": 1,
         "stream": "true",
-        "compile": "false",
+        "compile": "true",
     },
     {
         "nproc": 4,
@@ -74,6 +74,7 @@
 EXPECTED_RESULTS = [
     # ", Paris, is a city of romance, fashion, and art. The city is home to the Eiffel Tower, the Louvre, and the Arc de Triomphe. Paris is also known for its cafes, restaurants",
     " is Paris.\nThe capital of Germany is Berlin.\nThe capital of Italy is Rome.\nThe capital of Spain is Madrid.\nThe capital of the United Kingdom is London.\nThe capital of the European Union is Brussels.\n",
+    " is Paris.\n\nThe capital of Germany is Berlin.\n\nThe capital of Italy is Rome.\n\nThe capital of Spain is Madrid.\n\nThe capital of the United Kingdom is London.\n\nThe capital of the United States is",
 ]
 
 
@@ -116,7 +117,7 @@ def test_handler(tmp_path, add_paths, compile, mocker):
     ctx.model_yaml_config = config
     ctx.request_ids = {0: "0"}
 
-    torch.manual_seed(42 * 42)
+    torch.manual_seed(42)
     handler.initialize(ctx)
 
     assert ("cuda:0" if torch.cuda.is_available() else "cpu") == str(handler.device)
@@ -129,7 +130,7 @@ def test_handler(tmp_path, add_paths, compile, mocker):
 
         result = "".join(c[0][0][0] for c in send_mock.call_args_list)
 
-        assert result == EXPECTED_RESULTS[0]
+        assert result == EXPECTED_RESULTS[1 if compile == "true" else 0]
     finally:
         # free memory in case of failed test
         del handler.model
@@ -241,4 +242,4 @@ def test_gpt_fast_mar(model_name_and_stdout):
 
     assert len(prediction) > 1
 
-    assert "".join(prediction) == EXPECTED_RESULTS[0]
+    assert "".join(prediction) == EXPECTED_RESULTS[1]

ts_scripts/spellcheck_conf/wordlist.txt (+3)

@@ -1181,4 +1181,7 @@ Karpathy's
 Maher's
 warmup
 SOTA
+FxGraphCache
+TorchInductor
+fx
 locustapache
