Commit d60ddb0

agunapal and msaroufim authored
TorchServe quickstart chatbot example (#3003)
* TorchServe quickstart chatbot example
* Added more details in Readme
* lint failure
* code cleanup
* review comments

---------

Co-authored-by: Mark Saroufim <[email protected]>
1 parent bf0ee4b commit d60ddb0

9 files changed (+565, -0 lines)

examples/LLM/llama2/chat_app/Readme.md

+27
@@ -9,6 +9,33 @@ We are using [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) in
You can run this example on your laptop to understand how to use TorchServe

## Quick Start Guide

To get started with TorchServe, run the following commands:

```
# 1: Set HF Token as Env variable
export HUGGINGFACE_TOKEN=<Token> # get this from your HuggingFace account

# 2: Build TorchServe Image for Serving llama2-7b model with 4-bit quantization
./examples/LLM/llama2/chat_app/docker/build_image.sh meta-llama/Llama-2-7b-chat-hf

# 3: Launch the streamlit app for server & client
docker run --rm -it --platform linux/amd64 -p 127.0.0.1:8080:8080 -p 127.0.0.1:8081:8081 -p 127.0.0.1:8082:8082 -p 127.0.0.1:8084:8084 -p 127.0.0.1:8085:8085 -v <model-store>:/home/model-server/model-store pytorch/torchserve:meta-llama---Llama-2-7b-chat-hf
```

In step 3, `<model-store>` is the location where you want the model to be downloaded.
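For example, assuming you want the weights cached in a local `model_store` directory next to where you run the command (the directory name is only an illustration; any writable path works), step 3 could look like:

```
mkdir -p $(pwd)/model_store
docker run --rm -it --platform linux/amd64 -p 127.0.0.1:8080:8080 -p 127.0.0.1:8081:8081 -p 127.0.0.1:8082:8082 -p 127.0.0.1:8084:8084 -p 127.0.0.1:8085:8085 -v $(pwd)/model_store:/home/model-server/model-store pytorch/torchserve:meta-llama---Llama-2-7b-chat-hf
```

Ports 8080/8081/8082 are TorchServe's inference, management, and metrics endpoints (see `config.properties`), while 8084 and 8085 serve the two Streamlit apps.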
### What to expect

This launches two Streamlit apps:

1. TorchServe Server app to start/stop TorchServe, load the model, scale workers up/down, and configure dynamic batch_size (currently llama-cpp-python doesn't support batch_size > 1); see the management API sketch below for a command-line equivalent of the model/worker operations.
   - Since this app is targeted at Apple M1/M2 laptops, we load a 4-bit quantized version of llama2 using llama-cpp-python.
2. Client chat app where you can chat with the model. There is a slider to send concurrent requests to the model. The current app doesn't have a good mechanism to show multiple responses in parallel; you will notice a streaming response for the first request followed by a complete response for the next request.

Currently, this launches the llama2-7b model with 4-bit quantization running on CPU.

To make use of the M1/M2 GPU, you can follow the guide below to do a standalone TorchServe installation.
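If you prefer the command line over the Streamlit server app, the same model and worker state can be inspected and changed through TorchServe's standard management API on port 8081. A minimal sketch, assuming the llama2-7b image built above:

```
# Is TorchServe up?
curl http://localhost:8080/ping

# Describe the registered model and its workers
curl http://localhost:8081/models/meta-llama---Llama-2-7b-chat-hf

# Scale to 2 workers (each request is still served with batch_size=1 by llama-cpp-python)
curl -X PUT "http://localhost:8081/models/meta-llama---Llama-2-7b-chat-hf?min_worker=2&synchronous=true"
```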
## Architecture

![Chatbot Architecture](./screenshots/architecture.png)
examples/LLM/llama2/chat_app/docker/Dockerfile
@@ -0,0 +1,26 @@
ARG BASE_IMAGE=pytorch/torchserve:latest-gpu

FROM $BASE_IMAGE as server
ARG BASE_IMAGE
ARG EXAMPLE_DIR
ARG MODEL_NAME
ARG HUGGINGFACE_TOKEN

USER root

ENV MODEL_NAME=$MODEL_NAME

RUN --mount=type=cache,id=apt-dev,target=/var/cache/apt \
    apt-get update && \
    apt-get install libopenmpi-dev git -y

COPY $EXAMPLE_DIR/requirements.txt /home/model-server/chat_bot/requirements.txt
RUN pip install -r /home/model-server/chat_bot/requirements.txt && huggingface-cli login --token $HUGGINGFACE_TOKEN

COPY $EXAMPLE_DIR /home/model-server/chat_bot
COPY $EXAMPLE_DIR/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh
COPY $EXAMPLE_DIR/config.properties /home/model-server/config.properties

WORKDIR /home/model-server/chat_bot
RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh \
    && chown -R model-server /home/model-server
examples/LLM/llama2/chat_app/docker/build_image.sh
@@ -0,0 +1,28 @@
#!/bin/bash

# Check if there are enough arguments
if [ "$#" -eq 0 ] || [ "$#" -gt 1 ]; then
    echo "Usage: $0 <HF Model>"
    exit 1
fi

MODEL_NAME=$(echo "$1" | sed 's/\//---/g')
echo "Model: " $MODEL_NAME

BASE_IMAGE="pytorch/torchserve:latest-cpu"

DOCKER_TAG="pytorch/torchserve:${MODEL_NAME}"

# Get relative path of example dir
EXAMPLE_DIR=$(dirname "$(readlink -f "$0")")
ROOT_DIR=${EXAMPLE_DIR}/../../../../..
ROOT_DIR=$(realpath "$ROOT_DIR")
EXAMPLE_DIR=$(echo "$EXAMPLE_DIR" | sed "s|$ROOT_DIR|./|")

# Build docker image for the application
DOCKER_BUILDKIT=1 docker buildx build --platform=linux/amd64 --file ${EXAMPLE_DIR}/Dockerfile --build-arg BASE_IMAGE="${BASE_IMAGE}" --build-arg EXAMPLE_DIR="${EXAMPLE_DIR}" --build-arg MODEL_NAME="${MODEL_NAME}" --build-arg HUGGINGFACE_TOKEN -t "${DOCKER_TAG}" .

echo "Run the following command to start the chat bot"
echo ""
echo docker run --rm -it --platform linux/amd64 -p 127.0.0.1:8080:8080 -p 127.0.0.1:8081:8081 -p 127.0.0.1:8082:8082 -p 127.0.0.1:8084:8084 -p 127.0.0.1:8085:8085 -v $(pwd)/model_store_1:/home/model-server/model-store $DOCKER_TAG
echo ""
examples/LLM/llama2/chat_app/docker/client_app.py
@@ -0,0 +1,118 @@
import json
import os
from concurrent.futures import ThreadPoolExecutor

import requests
import streamlit as st

MODEL_NAME = os.environ["MODEL_NAME"]

# App title
st.set_page_config(page_title="TorchServe Chatbot")

with st.sidebar:
    st.title("TorchServe Chatbot")

    st.session_state.model_loaded = False
    try:
        res = requests.get(url="http://localhost:8080/ping")
        res = requests.get(url=f"http://localhost:8081/models/{MODEL_NAME}")
        status = "NOT READY"
        if res.status_code == 200:
            status = json.loads(res.text)[0]["workers"][0]["status"]

        if status == "READY":
            st.session_state.model_loaded = True
            st.success("Proceed to entering your prompt message!", icon="👉")
        else:
            st.warning("Model not loaded in TorchServe", icon="⚠️")

    except requests.ConnectionError:
        st.warning("TorchServe is not up. Try again", icon="⚠️")

    if st.session_state.model_loaded:
        st.success(f"Model loaded: {MODEL_NAME}!", icon="👉")

    st.subheader("Model parameters")
    temperature = st.sidebar.slider(
        "temperature", min_value=0.1, max_value=1.0, value=0.5, step=0.1
    )
    top_p = st.sidebar.slider(
        "top_p", min_value=0.1, max_value=1.0, value=0.5, step=0.1
    )
    max_new_tokens = st.sidebar.slider(
        "max_new_tokens", min_value=48, max_value=512, value=50, step=4
    )
    concurrent_requests = st.sidebar.select_slider(
        "concurrent_requests", options=[2**j for j in range(0, 8)]
    )

# Store LLM generated responses
if "messages" not in st.session_state.keys():
    st.session_state.messages = [
        {"role": "assistant", "content": "How may I assist you today?"}
    ]

# Display or clear chat messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])


def clear_chat_history():
    st.session_state.messages = [
        {"role": "assistant", "content": "How may I assist you today?"}
    ]


st.sidebar.button("Clear Chat History", on_click=clear_chat_history)


def generate_model_response(prompt_input, executor):
    string_dialogue = (
        "Question: What are the names of the planets in the solar system? Answer: "
    )
    headers = {"Content-type": "application/json", "Accept": "text/plain"}
    url = f"http://127.0.0.1:8080/predictions/{MODEL_NAME}"
    data = json.dumps(
        {
            "prompt": prompt_input,
            "params": {
                "max_new_tokens": max_new_tokens,
                "top_p": top_p,
                "temperature": temperature,
            },
        }
    )
    res = [
        executor.submit(requests.post, url=url, data=data, headers=headers, stream=True)
        for i in range(concurrent_requests)
    ]

    return res, max_new_tokens


# User-provided prompt
if prompt := st.chat_input():
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# Generate a new response if last message is not from assistant
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            with ThreadPoolExecutor() as executor:
                futures, max_tokens = generate_model_response(prompt, executor)
                placeholder = st.empty()
                full_response = ""
                count = 0
                for future in futures:
                    response = future.result()
                    for chunk in response.iter_content(chunk_size=None):
                        if chunk:
                            data = chunk.decode("utf-8")
                            full_response += data
                            placeholder.markdown(full_response)
                message = {"role": "assistant", "content": full_response}
                st.session_state.messages.append(message)
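Inside the container this app is launched by `dockerd-entrypoint.sh`. As a sketch, it can also be pointed at an already-running TorchServe instance outside the container, assuming the model name from this example and that ports 8080/8081 are reachable on localhost:

```
export MODEL_NAME=meta-llama---Llama-2-7b-chat-hf
pip install -r requirements.txt
streamlit run client_app.py --server.port 8085
```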
examples/LLM/llama2/chat_app/docker/config.properties
@@ -0,0 +1,9 @@
metrics_mode=prometheus
model_metrics_auto_detect=true
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
number_of_netty_threads=32
job_queue_size=1000
model_store=/home/model-server/model-store
workflow_store=/home/model-server/wf-store
examples/LLM/llama2/chat_app/docker/dockerd-entrypoint.sh
@@ -0,0 +1,81 @@
#!/bin/bash
set -e

export LLAMA2_Q4_MODEL=/home/model-server/model-store/$MODEL_NAME/model/ggml-model-q4_0.gguf


create_model_cfg_yaml() {
    # Define the YAML content with a placeholder for the model name
    yaml_content="# TorchServe frontend parameters\nminWorkers: 1\nmaxWorkers: 1\nresponseTimeout: 1200\n#deviceType: \"gpu\"\n#deviceIds: [0,1]\n#torchrun:\n# nproc-per-node: 1\n\nhandler:\n model_name: \"${2}\"\n manual_seed: 40"

    # Create the YAML file with the specified model name
    echo -e "$yaml_content" > "model-config-${1}.yaml"
}

create_model_archive() {
    MODEL_NAME=$1
    MODEL_CFG=$2
    echo "Create model archive for ${MODEL_NAME} if it doesn't already exist"
    if [ -d "/home/model-server/model-store/$MODEL_NAME" ]; then
        echo "Model archive for $MODEL_NAME exists."
    fi
    if [ -d "/home/model-server/model-store/$MODEL_NAME/model" ]; then
        echo "Model already downloaded"
        mv /home/model-server/model-store/$MODEL_NAME/model /home/model-server/model-store/
    else
        echo "Model needs to be downloaded"
    fi
    torch-model-archiver --model-name "$MODEL_NAME" --version 1.0 --handler llama_cpp_handler.py --config-file $MODEL_CFG -r requirements.txt --archive-format no-archive --export-path /home/model-server/model-store -f
    if [ -d "/home/model-server/model-store/model" ]; then
        mv /home/model-server/model-store/model /home/model-server/model-store/$MODEL_NAME/
    fi
}

download_model() {
    MODEL_NAME=$1
    HF_MODEL_NAME=$2
    if [ -d "/home/model-server/model-store/$MODEL_NAME/model" ]; then
        echo "Model $HF_MODEL_NAME already downloaded"
    else
        echo "Downloading model $HF_MODEL_NAME"
        python Download_model.py --model_path /home/model-server/model-store/$MODEL_NAME/model --model_name $HF_MODEL_NAME
    fi
}

quantize_model() {
    if [ ! -f "$LLAMA2_Q4_MODEL" ]; then
        tmp_model_name=$(echo "$MODEL_NAME" | sed 's/---/--/g')
        directory_path=/home/model-server/model-store/$MODEL_NAME/model/models--$tmp_model_name/snapshots/
        HF_MODEL_SNAPSHOT=$(find $directory_path -type d -mindepth 1)
        echo "Cleaning up previous build of llama-cpp"
        git clone https://github.com/ggerganov/llama.cpp.git build
        cd build
        make
        python -m pip install -r requirements.txt

        echo "Convert the 7B model to ggml FP16 format"
        python convert.py $HF_MODEL_SNAPSHOT --outfile ggml-model-f16.gguf

        echo "Quantize the model to 4-bits (using q4_0 method)"
        ./quantize ggml-model-f16.gguf $LLAMA2_Q4_MODEL q4_0

        cd ..
        echo "Saved quantized model weights to $LLAMA2_Q4_MODEL"
    fi
}

HF_MODEL_NAME=$(echo "$MODEL_NAME" | sed 's/---/\//g')
if [[ "$1" = "serve" ]]; then
    shift 1
    create_model_cfg_yaml $MODEL_NAME $HF_MODEL_NAME
    create_model_archive $MODEL_NAME "model-config-$MODEL_NAME.yaml"
    download_model $MODEL_NAME $HF_MODEL_NAME
    quantize_model
    streamlit run torchserve_server_app.py --server.port 8084 &
    streamlit run client_app.py --server.port 8085
else
    eval "$@"
fi

# prevent docker exit
tail -f /dev/null
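For reference, with the llama2-7b example the `create_model_cfg_yaml` function above writes `model-config-meta-llama---Llama-2-7b-chat-hf.yaml`, whose expanded content is roughly:

```
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
responseTimeout: 1200
#deviceType: "gpu"
#deviceIds: [0,1]
#torchrun:
# nproc-per-node: 1

handler:
 model_name: "meta-llama/Llama-2-7b-chat-hf"
 manual_seed: 40
```

This is the `--config-file` passed to `torch-model-archiver`, and the `handler:` section is what `llama_cpp_handler.py` reads via `ctx.model_yaml_config`.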
examples/LLM/llama2/chat_app/docker/llama_cpp_handler.py
@@ -0,0 +1,67 @@
import logging
import os
from abc import ABC

import torch
from llama_cpp import Llama

from ts.protocol.otf_message_handler import send_intermediate_predict_response
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class LlamaCppHandler(BaseHandler, ABC):
    def __init__(self):
        super(LlamaCppHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        """In this initialize function, the 4-bit quantized llama-cpp model
        (the GGUF file pointed to by LLAMA2_Q4_MODEL) is loaded.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        model_path = os.environ["LLAMA2_Q4_MODEL"]
        model_name = ctx.model_yaml_config["handler"]["model_name"]
        seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
        torch.manual_seed(seed)

        self.model = Llama(model_path=model_path)
        logger.info(f"Loaded {model_name} model successfully")

    def preprocess(self, data):
        assert (
            len(data) == 1
        ), "llama-cpp-python is currently only supported with batch_size=1"
        for row in data:
            item = row.get("body")
        return item

    def inference(self, data):
        params = data["params"]
        tokens = self.model.tokenize(bytes(data["prompt"], "utf-8"))
        generation_kwargs = dict(
            tokens=tokens,
            temp=params["temperature"],
            top_p=params["top_p"],
        )
        count = 0
        for token in self.model.generate(**generation_kwargs):
            if count >= params["max_new_tokens"]:
                break

            count += 1
            new_text = self.model.detokenize([token])
            send_intermediate_predict_response(
                [new_text],
                self.context.request_ids,
                "Intermediate Prediction success",
                200,
                self.context,
            )
        return [""]

    def postprocess(self, output):
        return output
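A quick way to exercise this handler without the Streamlit client is to POST the same JSON shape that `client_app.py` sends to the inference endpoint (a sketch, using the model name from this example):

```
curl -X POST http://localhost:8080/predictions/meta-llama---Llama-2-7b-chat-hf \
  -H "Content-Type: application/json" \
  -d '{"prompt": "What are the names of the planets in the solar system?", "params": {"max_new_tokens": 50, "top_p": 0.5, "temperature": 0.5}}'
```

Because the handler streams tokens with `send_intermediate_predict_response`, the response arrives incrementally rather than as a single JSON object.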
examples/LLM/llama2/chat_app/docker/requirements.txt
@@ -0,0 +1,4 @@
transformers
llama-cpp-python
streamlit>=1.26.0
requests_futures
