
Commit d5e10de

agunapal and mreso authored Sep 17, 2024
TRT LLM Integration with LORA (#3305)
* TRT LLM Integration with LORA
* TRT LLM Integration with LORA
* TRT LLM Integration with LORA
* TRT LLM Integration with LORA
* Added launcher support for trt_llm
* updated README
* updated README
* Using the API that supports async generate
* Review comments
* Apply suggestions from code review

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>

* addressed review comments
* Addressed review comments
* Updated the async logic based on review comments
* Made max_batch_size and kv_cache size configurable for the launcher
* fixing lint

---------

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
1 parent 2dfbff7 commit d5e10de

14 files changed: +428 −173 lines changed
 

‎README.md

+13-2
@@ -62,12 +62,23 @@ Refer to [torchserve docker](docker/README.md) for details.
 
 ### 🤖 Quick Start LLM Deployment
 
+#### VLLM Engine
 ```bash
 # Make sure to install torchserve with pip or conda as described above and login with `huggingface-cli login`
-python -m ts.llm_launcher --model_id meta-llama/Meta-Llama-3-8B-Instruct --disable_token_auth
+python -m ts.llm_launcher --model_id meta-llama/Meta-Llama-3.1-8B-Instruct --disable_token_auth
 
 # Try it out
-curl -X POST -d '{"model":"meta-llama/Meta-Llama-3-8B-Instruct", "prompt":"Hello, my name is", "max_tokens": 200}' --header "Content-Type: application/json" "http://localhost:8080/predictions/model/1.0/v1/completions"
+curl -X POST -d '{"model":"meta-llama/Meta-Llama-3.1-8B-Instruct", "prompt":"Hello, my name is", "max_tokens": 200}' --header "Content-Type: application/json" "http://localhost:8080/predictions/model/1.0/v1/completions"
+```
+
+#### TRT-LLM Engine
+```bash
+# Make sure to install torchserve with python venv as described above and login with `huggingface-cli login`
+# pip install -U --use-deprecated=legacy-resolver -r requirements/trt_llm.txt
+python -m ts.llm_launcher --model_id meta-llama/Meta-Llama-3.1-8B-Instruct --engine trt_llm --disable_token_auth
+
+# Try it out
+curl -X POST -d '{"prompt":"count from 1 to 9 in french ", "max_tokens": 100}' --header "Content-Type: application/json" "http://localhost:8080/predictions/model"
 ```
 
 ### 🚢 Quick Start LLM Deployment with Docker

‎examples/large_models/trt_llm/llama/README.md

+18-15
@@ -4,19 +4,19 @@
 
 ## Pre-requisites
 
-TRT-LLM requires Python 3.10
+- TRT-LLM requires Python 3.10
+- TRT-LLM works well with python venv (vs conda)
 This example is tested with CUDA 12.1
 Once TorchServe is installed, install TensorRT-LLM using the following.
-This will downgrade the versions of PyTorch & Triton but this doesn't cause any issue.
 
 ```
-pip install tensorrt_llm==0.10.0 --extra-index-url https://pypi.nvidia.com
-pip install tensorrt-cu12==10.1.0
+pip install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com
+pip install transformers>=4.44.2
 python -c "import tensorrt_llm"
 ```
 shows
 ```
-[TensorRT-LLM] TensorRT-LLM version: 0.10.0
+[TensorRT-LLM] TensorRT-LLM version: 0.13.0.dev2024090300
 ```
 
 ## Download model from HuggingFace
@@ -26,29 +26,32 @@ huggingface-cli login
 huggingface-cli login --token $HUGGINGFACE_TOKEN
 ```
 ```
-python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct
+python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --use_auth_token True
 ```
 
 ## Create TensorRT-LLM Engine
 Clone TensorRT-LLM which will be used to create the TensorRT-LLM Engine
 
 ```
-git clone -b v0.10.0 https://github.com/NVIDIA/TensorRT-LLM.git
+git clone https://github.com/NVIDIA/TensorRT-LLM.git
 ```
 
 Compile the model into a TensorRT engine with model weights and a model definition written in the TensorRT-LLM Python API.
 
 ```
-python TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16
+python TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f/ --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16
 ```
+
 ```
-trtllm-build --checkpoint_dir tllm_checkpoint_1gpu_bf16 --gemm_plugin bfloat16 --gpt_attention_plugin bfloat16 --output_dir ./llama-3-8b-engine
+trtllm-build --checkpoint_dir tllm_checkpoint_1gpu_bf16 --gemm_plugin bfloat16 --gpt_attention_plugin bfloat16 --max_batch_size 4 --output_dir ./llama-3.1-8b-engine
 ```
+If you have enough GPU memory, you can try increasing the `max_batch_size`
 
 You can test if TensorRT-LLM Engine has been compiled correctly by running the following
 ```
-python TensorRT-LLM/examples/run.py --engine_dir ./llama-3-8b-engine --max_output_len 100 --tokenizer_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --input_text "How do I count to nine in French?"
+python TensorRT-LLM/examples/run.py --engine_dir ./llama-3.1-8b-engine --max_output_len 100 --tokenizer_dir model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f/ --input_text "How do I count to nine in French?"
 ```
+If you are running into OOM, try reducing `kv_cache_free_gpu_memory_fraction`
 
 You should see an output as follows
 ```
@@ -70,17 +73,17 @@ That's it! You can now count to nine in French. Just remember that the numbers o
 
 ```
 mkdir model_store
-torch-model-archiver --model-name llama3-8b --version 1.0 --handler trt_llm_handler.py --config-file model-config.yaml --archive-format no-archive --export-path model_store -f
-mv model model_store/llama3-8b/.
-mv llama-3-8b-engine model_store/llama3-8b/.
+torch-model-archiver --model-name llama3.1-8b --version 1.0 --handler trt_llm_handler --config-file model-config.yaml --archive-format no-archive --export-path model_store -f
+mv model model_store/llama3.1-8b/.
+mv llama-3.1-8b-engine model_store/llama3.1-8b/.
 ```
 
 ## Start TorchServe
 ```
-torchserve --start --ncs --model-store model_store --models llama3-8b --disable-token-auth
+torchserve --start --ncs --model-store model_store --models llama3.1-8b --disable-token-auth
 ```
 
 ## Run Inference
 ```
-python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 -m llama3-8b --prompt-text "@prompt.json" --prompt-json
+python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 -m llama3.1-8b --prompt-text "@prompt.json" --prompt-json
 ```

‎examples/large_models/trt_llm/llama/model-config.yaml

+4-3
@@ -7,6 +7,7 @@ deviceType: "gpu"
 asyncCommunication: true
 
 handler:
-    tokenizer_dir: "model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/"
-    trt_llm_engine_config:
-        engine_dir: "llama-3-8b-engine"
+    tokenizer_dir: "model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f/"
+    engine_dir: "llama-3.1-8b-engine"
+    kv_cache_config:
+        free_gpu_memory_fraction: 0.1

examples/large_models/trt_llm/llama/prompt.json

+2-1

@@ -1,3 +1,4 @@
 {"prompt": "How is the climate in San Francisco?",
 "temperature":0.5,
-"max_new_tokens": 200}
+"max_tokens": 400,
+"streaming": true}

‎examples/large_models/trt_llm/llama/trt_llm_handler.py

-118
This file was deleted.
examples/large_models/trt_llm/lora/README.md

+83

@@ -0,0 +1,83 @@
+# Llama TensorRT-LLM Engine + LoRA model integration with TorchServe
+
+[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) provides users with an option to build TensorRT engines for LLMs that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.
+
+## Pre-requisites
+
+- TRT-LLM requires Python 3.10
+- TRT-LLM works well with python venv (vs conda)
+This example is tested with CUDA 12.1
+Once TorchServe is installed, install TensorRT-LLM using the following.
+
+```
+pip install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com
+pip install transformers>=4.44.2
+python -c "import tensorrt_llm"
+```
+shows
+```
+[TensorRT-LLM] TensorRT-LLM version: 0.13.0.dev2024090300
+```
+
+## Download Base model & LoRA adapter from Hugging Face
+```
+huggingface-cli login
+# or using an environment variable
+huggingface-cli login --token $HUGGINGFACE_TOKEN
+```
+```
+python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --use_auth_token True
+python ../../utils/Download_model.py --model_path model --model_name llama-duo/llama3.1-8b-summarize-gpt4o-128k --use_auth_token True
+```
+
+## Create TensorRT-LLM Engine
+Clone TensorRT-LLM which will be used to create the TensorRT-LLM Engine
+
+```
+git clone https://github.com/NVIDIA/TensorRT-LLM.git
+```
+
+Compile the model into a TensorRT engine with model weights and a model definition written in the TensorRT-LLM Python API.
+
+```
+python TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16
+```
+
+```
+trtllm-build --checkpoint_dir tllm_checkpoint_1gpu_bf16 --gemm_plugin bfloat16 --gpt_attention_plugin bfloat16 --output_dir ./llama-3.1-8b-engine-lora --max_batch_size 4 --lora_dir model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825 --lora_plugin bfloat16
+```
+If you have enough GPU memory, you can try increasing the `max_batch_size`
+
+You can test if TensorRT-LLM Engine has been compiled correctly by running the following
+```
+python TensorRT-LLM/examples/run.py --engine_dir ./llama-3.1-8b-engine-lora --max_output_len 100 --tokenizer_dir model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825 --input_text "Amanda: I baked cookies. Do you want some?\nJerry: Sure \nAmanda: I will bring you tomorrow :-)\n\nSummarize the dialog:" --lora_dir model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825 --kv_cache_free_gpu_memory_fraction 0.3 --use_py_session
+```
+If you are running into OOM, try reducing `kv_cache_free_gpu_memory_fraction`
+
+You should see an output as follows
+```
+Input [Text 0]: "<|begin_of_text|>Amanda: I baked cookies. Do you want some?\nJerry: Sure \nAmanda: I will bring you tomorrow :-)\n\nSummarize the dialog:"
+Output [Text 0 Beam 0]: " Amanda offered Jerry cookies and said she would bring them to him tomorrow.
+Amanda offered Jerry cookies and said she would bring them to him tomorrow.
+The dialogue is between Amanda and Jerry. Amanda offers Jerry cookies and says she will bring them to him tomorrow. The dialogue is a simple exchange between two people, with no complex plot or themes. The tone is casual and friendly. The dialogue is a good example of a short, everyday conversation.
+The dialogue is a good example of a short,"
+```
+
+## Create model archive
+
+```
+mkdir model_store
+torch-model-archiver --model-name llama3.1-8b --version 1.0 --handler trt_llm_handler --config-file model-config.yaml --archive-format no-archive --export-path model_store -f
+mv model model_store/llama3.1-8b/.
+mv llama-3.1-8b-engine-lora model_store/llama3.1-8b/.
+```
+
+## Start TorchServe
+```
+torchserve --start --ncs --model-store model_store --models llama3.1-8b --disable-token-auth
+```
+
+## Run Inference
+```
+python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 -m llama3.1-8b --prompt-text "@prompt.json" --prompt-json
+```

examples/large_models/trt_llm/lora/model-config.yaml

+13

@@ -0,0 +1,13 @@
+# TorchServe frontend parameters
+minWorkers: 1
+maxWorkers: 1
+maxBatchDelay: 100
+responseTimeout: 1200
+deviceType: "gpu"
+asyncCommunication: true
+
+handler:
+    tokenizer_dir: "model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825"
+    engine_dir: "llama-3.1-8b-engine-lora"
+    kv_cache_config:
+        free_gpu_memory_fraction: 0.1

examples/large_models/trt_llm/lora/prompt.json

+4

@@ -0,0 +1,4 @@
+{"prompt": "Amanda: I baked cookies. Do you want some?\nJerry: Sure \nAmanda: I will bring you tomorrow :-)\n\nSummarize the dialog:",
+"temperature":0.0,
+"max_new_tokens": 100,
+"streaming": true}

‎model-archiver/model_archiver/model_packaging_utils.py

+1
@@ -35,6 +35,7 @@
     "image_segmenter": "vision",
     "dali_image_classifier": "vision",
     "vllm_handler": "text",
+    "trt_llm_handler": "text",
 }
 
 MODEL_SERVER_VERSION = "1.0"

‎requirements/trt_llm.txt

+3
@@ -0,0 +1,3 @@
+--pre --extra-index-url https://pypi.nvidia.com
+tensorrt_llm
+transformers>=4.44.2

‎ts/llm_launcher.py

+146-34
@@ -1,6 +1,8 @@
 import argparse
 import contextlib
+import os
 import shutil
+import subprocess
 from pathlib import Path
 from signal import pause
 
@@ -10,14 +12,55 @@
 from model_archiver.model_packaging import generate_model_archive
 
 from ts.launcher import start, stop
+from ts.utils.hf_utils import download_model
+
+
+def create_tensorrt_llm_engine(
+    model_store, model_name, dtype, snapshot_path, max_batch_size
+):
+    if not Path("/tmp/TensorRT-LLM").exists():
+        subprocess.run(
+            [
+                "git",
+                "clone",
+                "https://github.com/NVIDIA/TensorRT-LLM.git",
+                "-b",
+                "v0.12.0",
+                "/tmp/TensorRT-LLM",
+            ]
+        )
+    if not Path(f"{model_store}/{model_name}/tllm_checkpoint_1gpu_bf16").exists():
+        subprocess.run(
+            [
+                "python",
+                "/tmp/TensorRT-LLM/examples/llama/convert_checkpoint.py",
+                "--model_dir",
+                snapshot_path,
+                "--output_dir",
+                f"{model_store}/{model_name}/tllm_checkpoint_1gpu_bf16",
+                "--dtype",
+                dtype,
+            ]
+        )
+    if not Path(f"{model_store}/{model_name}/{model_name}-engine").exists():
+        subprocess.run(
+            [
+                "trtllm-build",
+                "--checkpoint_dir",
+                f"{model_store}/{model_name}/tllm_checkpoint_1gpu_bf16",
+                "--gemm_plugin",
+                dtype,
+                "--gpt_attention_plugin",
+                dtype,
+                "--max_batch_size",
+                f"{max_batch_size}",
+                "--output_dir",
+                f"{model_store}/{model_name}/{model_name}-engine",
+            ]
+        )
 
 
-def get_model_config(args):
-    download_dir = getattr(args, "vllm_engine.download_dir")
-    download_dir = (
-        Path(download_dir).resolve().as_posix() if download_dir else download_dir
-    )
-
+def get_model_config(args, model_snapshot_path=None):
     model_config = {
         "minWorkers": 1,
         "maxWorkers": 1,
@@ -26,74 +69,116 @@ def get_model_config(args):
         "responseTimeout": 1200,
         "deviceType": "gpu",
         "asyncCommunication": True,
-        "parallelLevel": torch.cuda.device_count() if torch.cuda.is_available else 1,
-        "handler": {
-            "model_path": args.model_id,
-            "vllm_engine_config": {
-                "max_num_seqs": getattr(args, "vllm_engine.max_num_seqs"),
-                "max_model_len": getattr(args, "vllm_engine.max_model_len"),
-                "download_dir": download_dir,
-                "tensor_parallel_size": torch.cuda.device_count()
+    }
+
+    if args.engine == "vllm":
+        download_dir = getattr(args, "vllm_engine.download_dir")
+        download_dir = (
+            Path(download_dir).resolve().as_posix() if download_dir else download_dir
+        )
+
+        model_config.update(
+            {
+                "parallelLevel": torch.cuda.device_count()
                 if torch.cuda.is_available
                 else 1,
-            },
-        },
-    }
+                "handler": {
+                    "model_path": args.model_id,
+                    "vllm_engine_config": {
+                        "max_num_seqs": getattr(args, "vllm_engine.max_num_seqs"),
+                        "max_model_len": getattr(args, "vllm_engine.max_model_len"),
+                        "download_dir": download_dir,
+                        "tensor_parallel_size": torch.cuda.device_count()
+                        if torch.cuda.is_available
+                        else 1,
+                    },
+                },
+            }
+        )
 
-    if hasattr(args, "lora_adapter_ids"):
-        raise NotImplementedError("Lora setting needs to be implemented")
-        lora_adapter_ids = args.lora_adapter_ids.split(";")
+        if hasattr(args, "lora_adapter_ids"):
+            raise NotImplementedError("Lora setting needs to be implemented")
+            lora_adapter_ids = args.lora_adapter_ids.split(";")
 
-        model_config["handler"]["vllm_engine_config"].update(
+            model_config["handler"]["vllm_engine_config"].update(
+                {
+                    "enable_lora": True,
+                }
+            )
+
+    elif args.engine == "trt_llm":
+        model_config.update(
             {
-                "enable_lora": True,
+                "handler": {
+                    "tokenizer_dir": os.path.join(os.getcwd(), model_snapshot_path),
+                    "engine_dir": f"{args.model_name}-engine",
+                    "kv_cache_config": {
+                        "free_gpu_memory_fraction": getattr(
+                            args, "trt_llm_engine.kv_cache_free_gpu_memory_fraction"
+                        ),
+                    },
+                }
            }
        )
+    else:
+        raise RuntimeError("Unsupported LLM Engine")
 
     return model_config
 
 
 @contextlib.contextmanager
-def create_mar_file(args):
-    model_store_path = Path(args.model_store)
-    model_store_path.mkdir(parents=True, exist_ok=True)
-
-    mar_file_path = model_store_path / args.model_name
+def create_mar_file(args, model_snapshot_path=None):
+    mar_file_path = Path(args.model_store) / args.model_name
 
     model_config_yaml = Path(args.model_store) / "model-config.yaml"
     with model_config_yaml.open("w") as f:
-        yaml.dump(get_model_config(args), f)
+        yaml.dump(get_model_config(args, model_snapshot_path), f)
 
     config = ModelArchiverConfig(
         model_name=args.model_name,
         version="1.0",
-        handler="vllm_handler",
+        handler=f"{args.engine}_handler",
         serialized_file=None,
         export_path=args.model_store,
         requirements_file=None,
         runtime="python",
-        force=False,
+        force=True,
         config_file=model_config_yaml.as_posix(),
         archive_format="no-archive",
     )
 
-    generate_model_archive(config)
+    if not mar_file_path.exists():
+        generate_model_archive(config)
 
     model_config_yaml.unlink()
 
     assert mar_file_path.exists()
 
     yield mar_file_path.as_posix()
 
-    shutil.rmtree(mar_file_path)
+    if args.engine == "vllm":
+        shutil.rmtree(mar_file_path)
 
 
 def main(args):
     """
     Register the model in torchserve
     """
 
-    with create_mar_file(args):
+    model_store_path = Path(args.model_store)
+    model_store_path.mkdir(parents=True, exist_ok=True)
+    if args.engine == "trt_llm":
+        model_snapshot_path = download_model(args.model_id)
+
+    with create_mar_file(args, model_snapshot_path):
+        if args.engine == "trt_llm":
+            create_tensorrt_llm_engine(
+                args.model_store,
+                args.model_name,
+                args.dtype,
+                model_snapshot_path,
+                getattr(args, "trt_llm_engine.max_batch_size"),
+            )
         try:
             start(
                 model_store=args.model_store,
@@ -129,7 +214,7 @@ def main(args):
     parser.add_argument(
         "--model_id",
         type=str,
-        default="meta-llama/Meta-Llama-3-8B-Instruct",
+        default="meta-llama/Meta-Llama-3.1-8B-Instruct",
         help="Model id",
     )
 
@@ -160,6 +245,33 @@ def main(args):
         help="Cache dir",
     )
 
+    parser.add_argument(
+        "--engine",
+        type=str,
+        default="vllm",
+        help="LLM engine",
+    )
+
+    parser.add_argument(
+        "--dtype",
+        type=str,
+        default="bfloat16",
+        help="Data type",
+    )
+
+    parser.add_argument(
+        "--trt_llm_engine.max_batch_size",
+        type=int,
+        default=4,
+        help="Max batch size",
+    )
+
+    parser.add_argument(
+        "--trt_llm_engine.kv_cache_free_gpu_memory_fraction",
+        type=int,
+        default=0.1,
+        help="KV Cache free gpu memory fraction",
+    )
     args = parser.parse_args()
 
     main(args)
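For readers following the launcher changes, the sketch below (not part of the commit) spells out the handler section that the new `trt_llm` branch of `get_model_config` writes into the generated model-config.yaml. The model name and snapshot path are hypothetical placeholders chosen only for illustration.

```python
# Illustrative sketch: mirrors the trt_llm branch of get_model_config above.
# "model" and the snapshot path are placeholder values, not fixed by the commit.
import os

import yaml

model_name = "model"
model_snapshot_path = (
    ".cache/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/<hash>"
)

handler_section = {
    "handler": {
        "tokenizer_dir": os.path.join(os.getcwd(), model_snapshot_path),
        "engine_dir": f"{model_name}-engine",
        "kv_cache_config": {
            # default of --trt_llm_engine.kv_cache_free_gpu_memory_fraction
            "free_gpu_memory_fraction": 0.1,
        },
    }
}

print(yaml.dump(handler_section, default_flow_style=False))
```

The TorchServe frontend reads the remaining keys (workers, asyncCommunication, and so on) from the same generated file, while the handler section is consumed by the TRT-LLM handler shown in the next file.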

‎ts/torch_handler/trt_llm_handler.py

+109
@@ -0,0 +1,109 @@
+import json
+import logging
+import time
+
+from tensorrt_llm.hlapi import LLM, KvCacheConfig, SamplingParams
+from transformers import AutoTokenizer
+
+from ts.handler_utils.utils import send_intermediate_predict_response
+from ts.torch_handler.base_handler import BaseHandler
+
+logger = logging.getLogger(__name__)
+
+
+class TRTLLMHandler(BaseHandler):
+    def __init__(self):
+        super().__init__()
+
+        self.trt_llm_engine = None
+        self.tokenizer = None
+        self.model = None
+        self.model_dir = None
+        self.initialized = False
+
+    def initialize(self, ctx):
+        self.model_dir = ctx.system_properties.get("model_dir")
+
+        engine_dir = ctx.model_yaml_config.get("handler").get("engine_dir")
+        kv_cache_cfg = ctx.model_yaml_config.get("handler").get("kv_cache_config", {})
+
+        tokenizer_dir = ctx.model_yaml_config.get("handler").get("tokenizer_dir")
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_dir,
+            legacy=False,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=True,
+            use_fast=True,
+        )
+
+        if self.tokenizer.pad_token_id is None:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+
+        kv_cache_config = KvCacheConfig(**kv_cache_cfg)
+
+        self.trt_llm_engine = LLM(
+            model=engine_dir, tokenizer=self.tokenizer, kv_cache_config=kv_cache_config
+        )
+        self.initialized = True
+
+    async def handle(self, data, context):
+        start_time = time.time()
+
+        metrics = context.metrics
+
+        data_preprocess = await self.preprocess(data)
+        output = await self.inference(data_preprocess, context)
+        output = await self.postprocess(output)
+
+        stop_time = time.time()
+        metrics.add_time(
+            "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms"
+        )
+        return output
+
+    async def preprocess(self, requests):
+        assert len(requests) == 1, "Expecting batch_size = 1"
+        req_data = requests[0]
+        data = req_data.get("data") or req_data.get("body")
+        if isinstance(data, (bytes, bytearray)):
+            data = data.decode("utf-8")
+        return data
+
+    async def inference(self, data, context):
+        generate_kwargs = {
+            "end_id": self.tokenizer.eos_token_id,
+            "pad_id": self.tokenizer.pad_token_id,
+        }
+        prompt = data.get("prompt")
+        streaming = data.get("streaming", False)
+        del data["prompt"]
+        if "streaming" in data:
+            del data["streaming"]
+        generate_kwargs.update(data)
+        sampling_params = SamplingParams(**generate_kwargs)
+
+        outputs = self.trt_llm_engine.generate_async(
+            prompt, streaming=streaming, sampling_params=sampling_params
+        )
+
+        async for output in outputs:
+            output_text, output_ids = (
+                output.outputs[0].text,
+                output.outputs[0].token_ids,
+            )
+            if not streaming:
+                return [output_text]
+            else:
+                output_text = self.tokenizer.decode([output_ids[-1]])
+                send_intermediate_predict_response(
+                    [json.dumps({"text": output_text})],
+                    context.request_ids,
+                    "Result",
+                    200,
+                    context,
+                )
+        return [""]
+
+    async def postprocess(self, outputs):
+        return outputs
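To show how request keys flow through `preprocess` and `inference` above, here is a small client sketch (not part of the commit). The endpoint name `model` matches the quick-start launcher registration and is an assumption, as is the use of the `requests` library; adjust the model name if you archived it as, for example, `llama3.1-8b`.

```python
# Minimal client sketch against the handler above; endpoint and values are
# illustrative assumptions, not mandated by the commit.
import json

import requests

payload = {
    "prompt": "count from 1 to 9 in french",
    # Everything except "prompt" and "streaming" is forwarded to SamplingParams.
    "max_tokens": 100,
    "temperature": 0.5,
    "streaming": True,  # popped by the handler; enables intermediate responses
}

with requests.post(
    "http://localhost:8080/predictions/model",
    data=json.dumps(payload),
    headers={"Content-Type": "application/json"},
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None):
        if chunk:
            print(chunk.decode("utf-8"), flush=True)
```

With `"streaming": true`, each intermediate chunk is a JSON object of the form `{"text": "..."}` produced by `send_intermediate_predict_response`; without it, the full completion is returned in a single response.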

‎ts/utils/hf_utils.py

+30
@@ -0,0 +1,30 @@
+from huggingface_hub import snapshot_download
+
+
+def download_model(
+    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
+    revision="main",
+    model_path=".cache",
+    use_auth_token=True,
+):
+    # Only download pytorch checkpoint files
+    allow_patterns = [
+        "*.json",
+        "*.pt",
+        "*.bin",
+        "*.txt",
+        "*.model",
+        "*.pth",
+        "*.safetensors",
+        "original/*",
+    ]
+
+    snapshot_path = snapshot_download(
+        repo_id=model_id,
+        revision=revision,
+        allow_patterns=allow_patterns,
+        cache_dir=model_path,
+        use_auth_token=use_auth_token,
+    )
+
+    return snapshot_path
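A quick usage sketch for the new helper (not part of the commit; the values simply mirror the defaults above, and gated models still require a prior `huggingface-cli login`):

```python
# Illustrative usage of ts.utils.hf_utils.download_model.
from ts.utils.hf_utils import download_model

snapshot_path = download_model(
    model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
    revision="main",
    model_path=".cache",   # local Hugging Face cache directory
    use_auth_token=True,
)

# snapshot_path points at the downloaded snapshot directory, which the launcher
# passes to convert_checkpoint.py as --model_dir and uses as tokenizer_dir.
print(snapshot_path)
```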

‎ts_scripts/spellcheck_conf/wordlist.txt

+2
@@ -1299,3 +1299,5 @@ torchaudio
 ln
 OpenAI
 openai
+kv
+OOM

0 commit comments
