
Commit 5f3df71

mreso and namannandan authored
Asynchronous worker communication and vllm integration (#3146)
* Added dummy async comm worker thread
* First version of async worker in frontend running
* [WIP] Running async worker but requests get corrupted if parallel
* First version running with thread feeding + async predict
* shorten vllm test time
* Added AsyncVLLMEngine
* Extend vllm test with multiple possible prompts
* Batch size = 1 and remove stream in test
* Switched vllm examples to async comm and added llama3 example
* Fix typo
* Corrected java file formatting
* Cleanup and silent chatty debug message
* Added multi-gpu support to vllm examples
* fix java format
* Remove debugging messages
* Fix async comm worker test
* Added cl_socket to fixture
* Added multi worker note to vllm example readme
* Disable tests
* Enable async worker comm test
* Debug CI
* Fix python version <= 3.9 issue in async worker
* Renamed async worker test
* Update frontend/server/src/main/java/org/pytorch/serve/wlm/AsyncBatchAggregator.java: remove job from jobs_in_backend on error (Co-authored-by: Naman Nandan <[email protected]>)
* Unskip vllm example test
* Clean up async worker code
* Safely remove jobs from jobs_in_backend
* Let worker die if one of the threads in async service dies
* Add description of parallelLevel and parallelType=custom to docs/large_model_inference.md
* Added description of parallelLevel to model-archiver readme.md
* fix typo + added words
* Fix skip condition for vllm example test

---------

Co-authored-by: Naman Nandan <[email protected]>
1 parent 4c96e6f commit 5f3df71

28 files changed: +1267 -186 lines

docs/large_model_inference.md

+27-3
@@ -3,6 +3,7 @@
 This document explain how Torchserve supports large model serving, here large model refers to the models that are not able to fit into one gpu so they need be split in multiple partitions over multiple gpus.
 This page is split into the following sections:
 - [How it works](#how-it-works)
+- [Large Model Inference with vLLM](#pippy-pytorch-native-solution-for-large-model-inference)
 - [Large Model Inference with PiPPy](#pippy-pytorch-native-solution-for-large-model-inference)
 - [Large Model Inference with Deep Speed](#deepspeed)
 - [Deep Speed MII](#deepspeed-mii)
@@ -11,13 +12,36 @@ This page is split into the following sections:
 
 ## How it works?
 
-During deployment a worker of a large model, TorchServe utilizes [torchrun](https://pytorch.org/docs/stable/elastic/run.html) to set up the distributed environment for model parallel processing. TorchServe has the capability to support multiple workers for a large model. By default, TorchServe uses a round-robin algorithm to assign GPUs to a worker on a host. In case of large models inference GPUs assigned to each worker is automatically calculated based on number of GPUs specified in the model_config.yaml. CUDA_VISIBLE_DEVICES is set based this number.
+For GPU inference of smaller models, TorchServe executes a single process per worker, which gets assigned a single GPU.
+For large model inference, the model needs to be split over multiple GPUs.
+There are different modes to achieve this split, usually including pipeline parallel (PP), tensor parallel, or a combination of these.
+Which mode is selected and how the split is implemented depends on the framework being utilized.
+TorchServe allows users to utilize any framework for their model deployment and tries to accommodate the needs of these frameworks through flexible configurations.
+Some frameworks require a separate process for each of the GPUs (PiPPy, DeepSpeed), while others require a single process which gets assigned all GPUs (vLLM).
+In case multiple processes are required, TorchServe utilizes [torchrun](https://pytorch.org/docs/stable/elastic/run.html) to set up the distributed environment for the worker.
+During the setup, `torchrun` will start a new process for each GPU assigned to the worker.
+Whether torchrun is utilized depends on the parameter `parallelType`, which can be set in the `model-config.yaml` to one of the following options:
 
-For instance, suppose there are eight GPUs on a node and one worker needs 4 GPUs (ie, nproc-per-node=4) on a node. In this case, TorchServe would assign CUDA_VISIBLE_DEVICES="0,1,2,3" to worker1 and CUDA_VISIBLE_DEVICES="4,5,6,7" to worker2.
+* `pp` - for pipeline parallel
+* `tp` - for tensor parallel
+* `pptp` - for pipeline + tensor parallel
+* `custom`
 
-In addition to this default behavior, TorchServe provides the flexibility for users to specify GPUs for a worker. For instance, if the user sets "deviceIds: [2,3,4,5]" in the [model config YAML file](https://github.com/pytorch/serve/blob/5ee02e4f050c9b349025d87405b246e970ee710b/model-archiver/README.md?plain=1#L164), and nproc-per-node is set to 2, then TorchServe would assign CUDA_VISIBLE_DEVICES="2,3" to worker1 and CUDA_VISIBLE_DEVICES="4,5" to worker2.
+The first three options set up the environment using torchrun, while the "custom" option leaves the way of parallelization to the user and assigns the GPUs of a worker to a single process.
+The number of assigned GPUs is determined either by the number of processes started by torchrun, i.e. configured through nproc-per-node, OR by the parameter parallelLevel.
+This means that the parameter parallelLevel should NOT be set if nproc-per-node is set, and vice versa.
+
+By default, TorchServe uses a round-robin algorithm to assign GPUs to a worker on a host.
+In case of large model inference, the GPUs assigned to each worker are automatically calculated based on the number of GPUs specified in the model_config.yaml.
+CUDA_VISIBLE_DEVICES is set based on this number.
+
+For instance, suppose there are eight GPUs on a node and one worker needs 4 GPUs (i.e., nproc-per-node=4 OR parallelLevel=4) on a node.
+In this case, TorchServe would assign CUDA_VISIBLE_DEVICES="0,1,2,3" to worker1 and CUDA_VISIBLE_DEVICES="4,5,6,7" to worker2.
+
+In addition to this default behavior, TorchServe provides the flexibility for users to specify GPUs for a worker. For instance, if the user sets "deviceIds: [2,3,4,5]" in the [model config YAML file](https://github.com/pytorch/serve/blob/5ee02e4f050c9b349025d87405b246e970ee710b/model-archiver/README.md?plain=1#L164), and nproc-per-node (OR parallelLevel) is set to 2, then TorchServe would assign CUDA_VISIBLE_DEVICES="2,3" to worker1 and CUDA_VISIBLE_DEVICES="4,5" to worker2.
 
 Using Pippy integration as an example, the image below illustrates the internals of the TorchServe large model inference.
+For an example using vLLM see [this example](../examples/large_models/vllm/).
 
 ![ts-lmi-internal](https://raw.githubusercontent.com/pytorch/serve/master/docs/images/ts-lmi-internal.png)
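The round-robin assignment described in the hunk above can be illustrated with a short sketch. This is illustrative only, not TorchServe source code; `assign_gpus` is a hypothetical helper that mimics how contiguous blocks of device ids map to CUDA_VISIBLE_DEVICES per worker.

```python
# Illustrative sketch only -- not TorchServe source. assign_gpus() is a
# hypothetical helper mirroring the round-robin CUDA_VISIBLE_DEVICES
# assignment described in the doc change above.
def assign_gpus(num_gpus_on_host: int, gpus_per_worker: int) -> list[str]:
    """Return one CUDA_VISIBLE_DEVICES string per worker."""
    assignments = []
    for start in range(0, num_gpus_on_host, gpus_per_worker):
        device_ids = range(start, min(start + gpus_per_worker, num_gpus_on_host))
        assignments.append(",".join(str(i) for i in device_ids))
    return assignments


# 8 GPUs on the host, nproc-per-node=4 (or parallelLevel=4):
# worker1 -> "0,1,2,3", worker2 -> "4,5,6,7"
print(assign_gpus(8, 4))
```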

examples/large_models/vllm/Readme.md

+31-3
@@ -1,10 +1,16 @@
 # Example showing inference with vLLM
 
 This folder contains multiple demonstrations showcasing the integration of [vLLM Engine](https://github.com/vllm-project/vllm) with TorchServe, running inference with continuous batching.
-vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/)
+vLLM achieves high throughput using PagedAttention. More details can be found [here](https://vllm.ai/).
+The vLLM integration uses our new asynchronous worker communication mode, which decouples the communication between frontend and backend from running the actual inference.
+By using this new feature, TorchServe is capable of feeding incoming requests into the vLLM engine while asynchronously running the engine in the backend.
+As long as a single request is inside the engine, it will continue to run and asynchronously stream out the results until the request is finished.
+New requests are added to the engine in a continuous fashion, similar to the continuous batching mode shown in the other examples.
+For all examples, distributed inference can be enabled by following the instructions [here](./Readme.md#distributed-inference).
 
-- demo1: [Mistral](mistral)
-- demo2: [lora](lora)
+- demo1: [Meta-Llama3](llama3)
+- demo2: [Mistral](mistral)
+- demo3: [lora](lora)
 
 ### Supported vLLM Configuration
 * LLMEngine configuration:
@@ -13,3 +19,25 @@ vLLM [EngineArgs](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a24255
 
 * Sampling parameters for text generation:
   vLLM [SamplingParams](https://github.com/vllm-project/vllm/blob/258a2c58d08fc7a242556120877a89404861fbce/vllm/sampling_params.py#L27) is defined in the JSON format, for example, [prompt.json](lora/prompt.json).
+
+### Distributed Inference
+All examples can be easily distributed over multiple GPUs by enabling tensor parallelism in vLLM.
+To enable distributed inference, the following additions need to be made to the model-config.yaml of the examples, where 4 is the desired number of GPUs to use for the inference:
+
+```yaml
+# TorchServe frontend parameters
+...
+parallelType: "custom"
+parallelLevel: 4
+
+handler:
+    ...
+    vllm_engine_config:
+        ...
+        tensor_parallel_size: 4
+```
+
+### Multi-worker Note:
+While this example in theory works with multiple workers, it would distribute the incoming requests in a round-robin fashion, which might lead to non-optimal worker/hardware utilization.
+It is therefore advised to only use a single worker per engine and utilize tensor parallelism to distribute the model over multiple GPUs as described in the previous section.
+This will result in better hardware utilization and inference performance.
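Because the async handler streams intermediate results back as chunked HTTP responses, a client can consume the output incrementally. Below is a hypothetical client sketch (not part of this commit); it assumes a local TorchServe instance on the default inference port 8080 with the `llama3-8b` model from the llama3 example registered, and uses the prompt format from the example's prompt.json.

```python
# Hypothetical client sketch (not part of this commit): consume the chunked
# streaming response produced via send_intermediate_predict_response.
import json

import requests

payload = {"prompt": "A robot may not injure a human being", "max_new_tokens": 50}

with requests.post(
    "http://localhost:8080/predictions/llama3-8b",  # model name taken from the llama3 example
    data=json.dumps(payload),
    stream=True,
) as response:
    for chunk in response.iter_content(chunk_size=None):
        if chunk:
            # Each chunk typically carries one JSON result dict produced by the
            # handler; a robust client would buffer and split on JSON boundaries.
            print(chunk.decode("utf-8"), flush=True)
```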

examples/large_models/vllm/base_vllm_handler.py

+51-60
@@ -1,9 +1,12 @@
+import json
 import logging
 import pathlib
+import time
 
-from vllm import EngineArgs, LLMEngine, SamplingParams
+from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
 from vllm.lora.request import LoRARequest
 
+from ts.handler_utils.utils import send_intermediate_predict_response
 from ts.torch_handler.base_handler import BaseHandler
 
 logger = logging.getLogger(__name__)
@@ -21,59 +24,64 @@ def __init__(self):
         self.initialized = False
 
     def initialize(self, ctx):
-        ctx.cache = {}
-
         self.model_dir = ctx.system_properties.get("model_dir")
         vllm_engine_config = self._get_vllm_engine_config(
             ctx.model_yaml_config.get("handler", {})
         )
         self.adapters = ctx.model_yaml_config.get("handler", {}).get("adapters", {})
-        self.vllm_engine = LLMEngine.from_engine_args(vllm_engine_config)
+
+        self.vllm_engine = AsyncLLMEngine.from_engine_args(vllm_engine_config)
         self.initialized = True
 
-    def preprocess(self, requests):
-        for req_id, req_data in zip(self.context.request_ids.values(), requests):
-            if req_id not in self.context.cache:
-                data = req_data.get("data") or req_data.get("body")
-                if isinstance(data, (bytes, bytearray)):
-                    data = data.decode("utf-8")
-
-                prompt = data.get("prompt")
-                sampling_params = self._get_sampling_params(req_data)
-                lora_request = self._get_lora_request(req_data)
-                self.context.cache[req_id] = {
-                    "text_len": 0,
-                    "stopping_criteria": self._create_stopping_criteria(req_id),
-                }
-                self.vllm_engine.add_request(
-                    req_id, prompt, sampling_params, lora_request=lora_request
-                )
+    async def handle(self, data, context):
+        start_time = time.time()
 
-        return requests
+        metrics = context.metrics
 
-    def inference(self, input_batch):
-        inference_outputs = self.vllm_engine.step()
-        results = {}
+        data_preprocess = await self.preprocess(data)
+        output = await self.inference(data_preprocess, context)
+        output = await self.postprocess(output)
 
-        for output in inference_outputs:
-            req_id = output.request_id
-            results[req_id] = {
-                "text": output.outputs[0].text[
-                    self.context.cache[req_id]["text_len"] :
-                ],
+        stop_time = time.time()
+        metrics.add_time(
+            "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms"
+        )
+        return output
+
+    async def preprocess(self, requests):
+        input_batch = []
+        assert len(requests) == 1, "Expecting batch_size = 1"
+        for req_data in requests:
+            data = req_data.get("data") or req_data.get("body")
+            if isinstance(data, (bytes, bytearray)):
+                data = data.decode("utf-8")
+
+            prompt = data.get("prompt")
+            sampling_params = self._get_sampling_params(data)
+            lora_request = self._get_lora_request(data)
+            input_batch += [(prompt, sampling_params, lora_request)]
+        return input_batch
+
+    async def inference(self, input_batch, context):
+        logger.debug(f"Inputs: {input_batch[0]}")
+        prompt, params, lora = input_batch[0]
+        generator = self.vllm_engine.generate(
+            prompt, params, context.request_ids[0], lora
+        )
+        text_len = 0
+        async for output in generator:
+            result = {
+                "text": output.outputs[0].text[text_len:],
                 "tokens": output.outputs[0].token_ids[-1],
-                "finished": output.finished,
             }
-            self.context.cache[req_id]["text_len"] = len(output.outputs[0].text)
-
-        return [results[i] for i in self.context.request_ids.values()]
-
-    def postprocess(self, inference_outputs):
-        self.context.stopping_criteria = [
-            self.context.cache[req_id]["stopping_criteria"]
-            for req_id in self.context.request_ids.values()
-        ]
+            text_len = len(output.outputs[0].text)
+            if not output.finished:
+                send_intermediate_predict_response(
+                    [json.dumps(result)], context.request_ids, "Result", 200, context
                )
+        return [json.dumps(result)]
 
+    async def postprocess(self, inference_outputs):
         return inference_outputs
 
     def _get_vllm_engine_config(self, handler_config: dict):
@@ -85,8 +93,8 @@ def _get_vllm_engine_config(self, handler_config: dict):
             len(model_path) > 0
         ), "please define model in vllm_engine_config or model_path in handler"
         model = str(pathlib.Path(self.model_dir).joinpath(model_path))
-        logger.info(f"EngineArgs model={model}")
-        vllm_engine_config = EngineArgs(model=model)
+        logger.debug(f"EngineArgs model: {model}")
+        vllm_engine_config = AsyncEngineArgs(model=model)
         self._set_attr_value(vllm_engine_config, vllm_engine_params)
         return vllm_engine_config
 
@@ -104,27 +112,10 @@ def _get_lora_request(self, req_data: dict):
             assert len(adapter_path) > 0, f"{adapter_name} misses adapter path"
             lora_id = self.lora_ids.setdefault(adapter_name, len(self.lora_ids) + 1)
            adapter_path = str(pathlib.Path(self.model_dir).joinpath(adapter_path))
-            logger.info(f"adapter_path=${adapter_path}")
            return LoRARequest(adapter_name, lora_id, adapter_path)
 
        return None
 
-    def _clean_up(self, req_id):
-        del self.context.cache[req_id]
-
-    def _create_stopping_criteria(self, req_id):
-        class StoppingCriteria(object):
-            def __init__(self, outer, req_id):
-                self.req_id = req_id
-                self.outer = outer
-
-            def __call__(self, res):
-                if res["finished"]:
-                    self.outer._clean_up(self.req_id)
-                return res["finished"]
-
-        return StoppingCriteria(outer=self, req_id=req_id)
-
     def _set_attr_value(self, obj, config: dict):
         items = vars(obj)
         for k, v in config.items():
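For reference, the streaming loop in `inference()` above follows vLLM's `AsyncLLMEngine` API: `generate()` returns an async generator of `RequestOutput` objects, each carrying the full text produced so far. A minimal standalone sketch of that pattern outside TorchServe might look as follows; the model name is an assumption chosen to keep the sketch small, not something taken from this commit.

```python
# Minimal standalone sketch of the AsyncLLMEngine streaming pattern used by the
# handler above; not part of the commit. The model name is an assumption.
import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="facebook/opt-125m"))
    params = SamplingParams(max_tokens=64, temperature=0.8)

    text_len = 0
    # generate() yields RequestOutput objects as tokens are produced; each one
    # contains the full text so far, so only the new suffix is printed.
    async for output in engine.generate(
        "A robot may not injure a human being", params, request_id="req-0"
    ):
        print(output.outputs[0].text[text_len:], end="", flush=True)
        text_len = len(output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
```
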
examples/large_models/vllm/llama3/Readme.md (new file)
@@ -0,0 +1,45 @@
+# Example showing inference with vLLM on Meta-Llama-3
+
+This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `meta-llama/Meta-Llama-3-8B-Instruct` with continuous batching.
+This example supports distributed inference by following [these instructions](../Readme.md#distributed-inference).
+
+### Step 1: Download Model from HuggingFace
+
+Login with a HuggingFace account
+```
+huggingface-cli login
+# or using an environment variable
+huggingface-cli login --token $HUGGINGFACE_TOKEN
+```
+
+```bash
+python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct --use_auth_token True
+```
+
+### Step 2: Generate model artifacts
+
+Add the downloaded path to "model_path:" in `model-config.yaml` and run the following.
+
+```bash
+torch-model-archiver --model-name llama3-8b --version 1.0 --handler ../base_vllm_handler.py --config-file model-config.yaml -r ../requirements.txt --archive-format no-archive
+mv model llama3-8b
+```
+
+### Step 3: Add the model artifacts to model store
+
+```bash
+mkdir model_store
+mv llama3-8b model_store
+```
+
+### Step 4: Start torchserve
+
+```bash
+torchserve --start --ncs --ts-config ../config.properties --model-store model_store --models llama3-8b
+```
+
+### Step 5: Run inference
+
+```bash
+python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 --prompt-text "@prompt.json" --prompt-json
+```

examples/large_models/vllm/llama3/model-config.yaml (new file)
@@ -0,0 +1,13 @@
+# TorchServe frontend parameters
+minWorkers: 1
+maxWorkers: 1
+maxBatchDelay: 100
+responseTimeout: 1200
+deviceType: "gpu"
+asyncCommunication: true
+
+handler:
+    model_path: "model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/"
+    vllm_engine_config:
+        max_num_seqs: 16
+        max_model_len: 250

examples/large_models/vllm/llama3/prompt.json (new file)
@@ -0,0 +1,9 @@
+{
+  "prompt": "A robot may not injure a human being",
+  "max_new_tokens": 50,
+  "temperature": 0.8,
+  "logprobs": 1,
+  "prompt_logprobs": 1,
+  "max_tokens": 128,
+  "adapter": "adapter_1"
+}

examples/large_models/vllm/lora/Readme.md

+1
@@ -1,6 +1,7 @@
 # Example showing inference with vLLM on LoRA model
 
 This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `Llama-2-7b-hf` + LoRA model `llama-2-7b-sql-lora-test` with continuous batching.
+This example supports distributed inference by following [these instructions](../Readme.md#distributed-inference).
 
 ### Step 1: Download Model from HuggingFace
 
examples/large_models/vllm/lora/model-config.yaml

+2-3
@@ -1,14 +1,13 @@
 # TorchServe frontend parameters
 minWorkers: 1
 maxWorkers: 1
-batchSize: 16
 maxBatchDelay: 100
 responseTimeout: 1200
 deviceType: "gpu"
-continuousBatching: true
+asyncCommunication: true
 
 handler:
-    model_path: "model/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9"
+    model_path: "model/models--meta-llama--Llama-2-7b-chat-hf/snapshots/f5db02db724555f92da89c216ac04704f23d4590/"
     vllm_engine_config:
         enable_lora: true
         max_loras: 4

examples/large_models/vllm/mistral/Readme.md

+1
@@ -1,6 +1,7 @@
 # Example showing inference with vLLM on Mistral model
 
 This is an example showing how to integrate [vLLM](https://github.com/vllm-project/vllm) with TorchServe and run inference on model `mistralai/Mistral-7B-v0.1` with continuous batching.
+This example supports distributed inference by following [these instructions](../Readme.md#distributed-inference).
 
 ### Step 1: Download Model from HuggingFace
 

examples/large_models/vllm/mistral/model-config.yaml

+1-2
@@ -1,11 +1,10 @@
 # TorchServe frontend parameters
 minWorkers: 1
 maxWorkers: 1
-batchSize: 16
 maxBatchDelay: 100
 responseTimeout: 1200
 deviceType: "gpu"
-continuousBatching: true
+asyncCommunication: true
 
 handler:
     model_path: "model/models--mistralai--Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24"
