Commit e1c31e1

stateful inference (pytorch#2513)
* stateful inference - core layer
* add grpc layer
* add google rpc submodule
* fmt
* update sequence batch img
* update sequence batch img
* fmt
* delete unused file
* fmt
* fix log and update doc
* update log
* fmt
* make BatchAggregator as base
* fix conflict
* fix conflict
* add SequenceBatchAggregator
* update ci for submodule
* refactor
* fmt
* fmt
* fix lint
* code refactor
* update readme
* update readme
* fmt
* fmt
* test workflow
* revert test
* revert test response
* fmt
* fmt
* update readme
* allow number of jobGroups to be larger than batch size
* fmt
* fix typo
* add stateful test data
* fmt
* fmt
* fmt
* fmt
* set default maxNumSequence
* fmt
* fmt
* revert back config.properties
* fmt
1 parent d0f8905 commit e1c31e1

53 files changed: +1747 −270 lines

.github/workflows/benchmark_nightly.yml (+2)

```diff
@@ -36,6 +36,8 @@ jobs:
           java-version: '17'
       - name: Checkout TorchServe
         uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - name: Install dependencies
         run: |
           sudo apt-get update -y
```

.github/workflows/ci_cpu.yml (+2)

```diff
@@ -35,6 +35,8 @@ jobs:
           java-version: '17'
       - name: Checkout TorchServe
         uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - name: Install dependencies
         run: |
           python ts_scripts/install_dependencies.py --environment=dev
```

.github/workflows/ci_gpu.yml (+2)

```diff
@@ -39,6 +39,8 @@ jobs:
           java-version: '17'
       - name: Checkout TorchServe
         uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - name: Install dependencies
         run: |
           python ts_scripts/install_dependencies.py --environment=dev --cuda=cu121
```

.github/workflows/codeql.yml (+2)

```diff
@@ -34,6 +34,8 @@ jobs:
     steps:
       - name: Checkout repository
         uses: actions/checkout@v3
+        with:
+          submodules: recursive

       - name: Setup Python 3.8
         uses: actions/setup-python@v4
```

.github/workflows/docker-ci.yaml (+2)

```diff
@@ -17,6 +17,8 @@ jobs:
         python-version: ["3.8", "3.9", "3.10"]
     steps:
       - uses: actions/checkout@v3
+        with:
+          submodules: recursive

       - name: Test build_image.sh script with custom tagging and gpu flag
         working-directory: docker
```

.github/workflows/docker-nightly-build.yml (+2)

```diff
@@ -22,6 +22,8 @@ jobs:
           architecture: x64
       - name: Checkout TorchServe
         uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - name: Login to Docker
         env:
           DOCKER_PASSWORD: ${{secrets.DOCKER_PASSWORD}}
```

.github/workflows/regression_tests_cpu.yml (+2)

```diff
@@ -34,6 +34,8 @@ jobs:
           java-version: '17'
       - name: Checkout TorchServe
         uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - name: Install dependencies
         run: |
           python ts_scripts/install_dependencies.py --environment=dev
```

.github/workflows/regression_tests_cpu_binaries.yml (+2)

```diff
@@ -21,6 +21,8 @@ jobs:
         binaries: ["pypi", "conda"]
     steps:
       - uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - name: Setup conda with Python ${{ matrix.python-version }}
         uses: s-weigand/setup-conda@v1
         with:
```

.github/workflows/regression_tests_docker.yml (+2)

```diff
@@ -29,6 +29,8 @@ jobs:
           docker system prune --all --volumes -f
       - name: Checkout TorchServe
         uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - name: Branch name
         run: |
           echo $GITHUB_REF_NAME
```

.github/workflows/regression_tests_gpu.yml (+2)

```diff
@@ -42,6 +42,8 @@ jobs:
           java-version: '17'
       - name: Checkout TorchServe
         uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - name: Install dependencies
         run: |
           python ts_scripts/install_dependencies.py --environment=dev --cuda=cu121
```

.github/workflows/regression_tests_gpu_binaries.yml (+2)

```diff
@@ -28,6 +28,8 @@ jobs:
           ls -la ./
       - name: Checkout TorchServe
         uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - uses: conda-incubator/setup-miniconda@v2
         with:
           miniconda-version: "latest"
```

.github/workflows/torchserve-nightly-build.yml (+2)

```diff
@@ -14,6 +14,8 @@ jobs:
       - run: conda install -y conda-build anaconda-client
       - name: Checkout TorchServe
         uses: actions/checkout@v3
+        with:
+          submodules: recursive
       - name: Install dependencies
         run: |
           python ts_scripts/install_dependencies.py --environment=dev
```

.gitmodules (+3)

```diff
@@ -0,0 +1,3 @@
+[submodule "third_party/google/rpc"]
+	path = third_party/google/rpc
+	url = https://github.com/googleapis/googleapis.git
```
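Outside CI (where the workflows above now pass `submodules: recursive` to `actions/checkout`), an existing clone has to fetch this new submodule explicitly; `docs/grpc_api.md` below adds `git submodule init` for the same reason. A minimal sketch of the equivalent step, assuming only that `git` is on PATH; the helper name is ours, not part of the commit:

```python
# Minimal sketch: fetch the new third_party/google/rpc submodule in an existing
# clone, equivalent to running `git submodule update --init --recursive` by hand.
# Assumes git is on PATH; fetch_submodules is a hypothetical helper name.
import subprocess


def fetch_submodules(repo_dir: str = ".") -> None:
    subprocess.run(
        ["git", "submodule", "update", "--init", "--recursive"],
        cwd=repo_dir,
        check=True,  # raise CalledProcessError if git exits nonzero
    )


if __name__ == "__main__":
    fetch_submodules()
```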

docs/grpc_api.md (+4 −3)

````diff
@@ -33,12 +33,13 @@ Run following commands to Register, run inference and unregister, densenet161 model
 ```bash
 git clone https://github.com/pytorch/serve
 cd serve
+git submodule init
 ```

 - Install gRPC python dependencies

 ```bash
-pip install -U grpcio protobuf grpcio-tools
+pip install -U grpcio protobuf grpcio-tools googleapis-common-protos
 ```

 - Start torchServe
@@ -51,7 +52,7 @@ torchserve --start --model-store models/
 - Generate python gRPC client stub using the proto files

 ```bash
-python -m grpc_tools.protoc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
+python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
 ```

 - Register densenet161 model
@@ -95,4 +96,4 @@ def handle(data, context):
     for i in range (3):
         send_intermediate_predict_response(["intermediate_response"], context.request_ids, "Intermediate Prediction success", 200, context)
     return ["hello world "]
-```
+```
````
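To show how the regenerated stubs are consumed, here is a minimal client sketch, not part of this commit. It assumes the stub modules produced by the `grpc_tools.protoc` command above (`inference_pb2`, `inference_pb2_grpc`), TorchServe's default gRPC inference port 7070, and the `InferenceAPIsServiceStub`/`PredictionsRequest` names from `inference.proto`.

```python
# Minimal client sketch (assumptions: stubs generated into ts_scripts/ as shown
# above, TorchServe running locally with its gRPC inference API on port 7070,
# and densenet161 already registered).
import grpc

import inference_pb2
import inference_pb2_grpc


def infer(model_name: str, payload: bytes) -> str:
    with grpc.insecure_channel("localhost:7070") as channel:
        stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)
        # PredictionsRequest carries the model name plus a map of named byte inputs.
        request = inference_pb2.PredictionsRequest(
            model_name=model_name, input={"data": payload}
        )
        response = stub.Predictions(request)
    return response.prediction.decode("utf-8")


if __name__ == "__main__":
    with open("kitten.jpg", "rb") as f:  # any test image works here
        print(infer("densenet161", f.read()))
```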

docs/images/stateful_batch.jpg (43.7 KB)
(new Python handler file, +140 lines; filename not shown in this view)

```python
import logging
from abc import ABC

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class LlamaHandler(BaseHandler, ABC):
    """
    Transformers handler class for text generation with a causal Llama model.
    """

    def __init__(self):
        super(LlamaHandler, self).__init__()
        self.max_length = None
        self.max_new_tokens = None
        self.tokenizer = None
        self.initialized = False

    def initialize(self, ctx: Context):
        """In this initialize function, the HF large model is loaded and
        partitioned across the available devices via the `device_map` setting.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        model_dir = ctx.system_properties.get("model_dir")
        self.max_length = int(ctx.model_yaml_config["handler"]["max_length"])
        self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}'
        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
        torch.manual_seed(seed)

        logger.info("Model %s loading tokenizer", ctx.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="balanced",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            trust_remote_code=True,
        )
        if ctx.model_yaml_config["handler"]["fast_kernels"]:
            from optimum.bettertransformer import BetterTransformer

            try:
                self.model = BetterTransformer.transform(self.model)
            except RuntimeError:
                logger.warning(
                    "HuggingFace Optimum does not support this model; for the list of "
                    "supported models, see https://huggingface.co/docs/optimum/bettertransformer/overview"
                )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        logger.info("Model %s loaded successfully", ctx.model_name)
        self.initialized = True

    def preprocess(self, requests):
        """
        Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (list): A list of dictionaries with a "data" or "body" field, each
            containing the input text to be processed.
        Returns:
            tuple: A tuple with two tensors: the batch of input ids and the batch of
            attention masks.
        """
        input_texts = [data.get("data") or data.get("body") for data in requests]
        input_ids_batch, attention_mask_batch = [], []
        for input_text in input_texts:
            input_ids, attention_mask = self.encode_input_text(input_text)
            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
        # Padding is disabled in encode_input_text, so all encoded inputs in a batch
        # must have the same length for torch.cat to succeed. Both tensors are moved
        # to the model's device (self.device is never set by this handler).
        input_ids_batch = torch.cat(input_ids_batch, dim=0).to(self.model.device)
        attention_mask_batch = torch.cat(attention_mask_batch, dim=0).to(
            self.model.device
        )
        return input_ids_batch, attention_mask_batch

    def encode_input_text(self, input_text):
        """
        Encodes a single input text using the tokenizer.
        Args:
            input_text (str): The input text to be encoded.
        Returns:
            tuple: A tuple with two tensors: the encoded input ids and the attention mask.
        """
        if isinstance(input_text, (bytes, bytearray)):
            input_text = input_text.decode("utf-8")
        logger.info("Received text: '%s'", input_text)
        inputs = self.tokenizer.encode_plus(
            input_text,
            max_length=self.max_length,
            padding=False,
            add_special_tokens=True,
            return_tensors="pt",
            truncation=True,
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        return input_ids, attention_mask

    def inference(self, input_batch):
        """
        Generates text for the received prompts using the loaded transformers
        checkpoint.
        Args:
            input_batch (tuple): A tuple with two tensors: the batch of input ids and the batch
            of attention masks, as returned by the preprocess function.
        Returns:
            list: A list of strings with the predicted values for each input text in the batch.
        """
        input_ids_batch, attention_mask_batch = input_batch
        # Tensors are already on the model's device after preprocess; note that
        # generate's max_length bounds the total sequence length (prompt plus
        # generated tokens).
        outputs = self.model.generate(
            input_ids_batch,
            attention_mask=attention_mask_batch,
            max_length=self.max_new_tokens,
        )

        inferences = self.tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        logger.info("Generated text: %s", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Post Process Function converts the predicted response into Torchserve readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the Predictions and Explanations.
        """
        return inference_output
```
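As a quick sanity check on the handler's tokenization contract, the sketch below runs the same `encode_plus` call standalone. The `"gpt2"` checkpoint is our stand-in assumption (it downloads quickly); the handler itself loads whatever `model_path` in `model-config.yaml` points to, and `max_length=256` merely mirrors the handler's `max_length` config key.

```python
# Hedged sketch: the tokenizer call inside encode_input_text, run outside
# TorchServe. "gpt2" is a stand-in checkpoint chosen only for illustration.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer.encode_plus(
    "What is the weather like today?",
    max_length=256,           # mirrors the handler's max_length config key
    padding=False,            # no padding, so batched inputs must match in length
    add_special_tokens=True,
    return_tensors="pt",
    truncation=True,
)
print(inputs["input_ids"].shape)       # torch.Size([1, num_tokens])
print(inputs["attention_mask"].shape)  # same shape as input_ids
```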
